euIaxs22 committed
Commit 4e8356d · verified · 1 Parent(s): 98528dd

Upload inference.py

Files changed (1)
  1. inference.py +735 -0
inference.py ADDED
@@ -0,0 +1,735 @@
+ import os
+ import random
+ from datetime import datetime
+ from pathlib import Path
+ from diffusers.utils import logging
+ from typing import Optional, List, Union
+ import yaml
+
+ import imageio
+ import json
+ import numpy as np
+ import torch
+ from safetensors import safe_open
+ from PIL import Image
+ import torchvision.transforms.functional as TVF
+ from transformers import (
+     T5EncoderModel,
+     T5Tokenizer,
+     AutoModelForCausalLM,
+     AutoProcessor,
+     AutoTokenizer,
+ )
+ from huggingface_hub import hf_hub_download
+ from dataclasses import dataclass, field
+
+ from ltx_video.models.autoencoders.causal_video_autoencoder import (
+     CausalVideoAutoencoder,
+ )
+ from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
+ from ltx_video.models.transformers.transformer3d import Transformer3DModel
+ from ltx_video.pipelines.pipeline_ltx_video import (
+     ConditioningItem,
+     LTXVideoPipeline,
+     LTXMultiScalePipeline,
+ )
+ from ltx_video.schedulers.rf import RectifiedFlowScheduler
+ from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
+ from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
+ import ltx_video.pipelines.crf_compressor as crf_compressor
+
+ logger = logging.get_logger("LTX-Video")
+
+
+ def get_total_gpu_memory():
+     if torch.cuda.is_available():
+         total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+         return total_memory
+     return 0
+
+
+ def get_device():
+     if torch.cuda.is_available():
+         return "cuda"
+     elif torch.backends.mps.is_available():
+         return "mps"
+     return "cpu"
+
+
+ def load_image_to_tensor_with_resize_and_crop(
+     image_input: Union[str, Image.Image],
+     target_height: int = 512,
+     target_width: int = 768,
+     just_crop: bool = False,
+ ) -> torch.Tensor:
+     """Load and process an image into a tensor.
+
+     Args:
+         image_input: Either a file path (str) or a PIL Image object
+         target_height: Desired height of output tensor
+         target_width: Desired width of output tensor
+         just_crop: If True, only crop the image to the target size without resizing
+     """
+     if isinstance(image_input, str):
+         image = Image.open(image_input).convert("RGB")
+     elif isinstance(image_input, Image.Image):
+         image = image_input
+     else:
+         raise ValueError("image_input must be either a file path or a PIL Image object")
+
+     input_width, input_height = image.size
+     aspect_ratio_target = target_width / target_height
+     aspect_ratio_frame = input_width / input_height
+     if aspect_ratio_frame > aspect_ratio_target:
+         new_width = int(input_height * aspect_ratio_target)
+         new_height = input_height
+         x_start = (input_width - new_width) // 2
+         y_start = 0
+     else:
+         new_width = input_width
+         new_height = int(input_width / aspect_ratio_target)
+         x_start = 0
+         y_start = (input_height - new_height) // 2
+
+     image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+     if not just_crop:
+         image = image.resize((target_width, target_height))
+
+     frame_tensor = TVF.to_tensor(image)  # PIL -> tensor (C, H, W), [0,1]
+     frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=3, sigma=1.0)
+     frame_tensor_hwc = frame_tensor.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
+     frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
+     frame_tensor = frame_tensor_hwc.permute(2, 0, 1) * 255.0  # (H, W, C) -> (C, H, W)
+     frame_tensor = (frame_tensor / 127.5) - 1.0
+     # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
+     return frame_tensor.unsqueeze(0).unsqueeze(2)
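+
+ # Illustrative note (the file name below is a placeholder): with the defaults above,
+ # any RGB input, e.g. a 1920x1080 frame, comes back center-cropped and resized as a
+ # tensor of shape (1, 3, 1, 512, 768), normalized to [-1, 1]:
+ #   tensor = load_image_to_tensor_with_resize_and_crop("frame.png")
+ #   assert tuple(tensor.shape) == (1, 3, 1, 512, 768)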
+
+
+ def calculate_padding(
+     source_height: int, source_width: int, target_height: int, target_width: int
+ ) -> tuple[int, int, int, int]:
+
+     # Calculate total padding needed
+     pad_height = target_height - source_height
+     pad_width = target_width - source_width
+
+     # Calculate padding for each side
+     pad_top = pad_height // 2
+     pad_bottom = pad_height - pad_top  # Handles odd padding
+     pad_left = pad_width // 2
+     pad_right = pad_width - pad_left  # Handles odd padding
+
+     # Return the padding tuple in torch.nn.functional.pad format:
+     # (left, right, top, bottom)
+     padding = (pad_left, pad_right, pad_top, pad_bottom)
+     return padding
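+
+ # Illustrative usage (mirrors load_media_file below): the returned tuple plugs
+ # directly into torch.nn.functional.pad, which interprets the 4-tuple as
+ # (left, right, top, bottom) over the last two dimensions, e.g.
+ #   padding = calculate_padding(704, 1200, 704, 1216)  # -> (8, 8, 0, 0)
+ #   frames = torch.nn.functional.pad(frames, padding)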
+
+
+ def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
+     # Remove non-letters and convert to lowercase
+     clean_text = "".join(
+         char.lower() for char in text if char.isalpha() or char.isspace()
+     )
+
+     # Split into words
+     words = clean_text.split()
+
+     # Build result string keeping track of length
+     result = []
+     current_length = 0
+
+     for word in words:
+         # Stop once adding the next word would exceed max_len
+         new_length = current_length + len(word)
+
+         if new_length <= max_len:
+             result.append(word)
+             current_length += len(word)
+         else:
+             break
+
+     return "-".join(result)
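+
+ # For example, with the default max_len=20:
+ #   convert_prompt_to_filename("A sailboat gliding across a calm bay") -> "a-sailboat-gliding"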
+
+
+ # Generate a unique output file name
+ def get_unique_filename(
+     base: str,
+     ext: str,
+     prompt: str,
+     seed: int,
+     resolution: tuple[int, int, int],
+     dir: Path,
+     endswith=None,
+     index_range=1000,
+ ) -> Path:
+     base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
+     for i in range(index_range):
+         filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
+         if not os.path.exists(filename):
+             return filename
+     raise FileExistsError(
+         f"Could not find a unique filename after {index_range} attempts."
+     )
+
+
+ def seed_everything(seed: int):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+     if torch.backends.mps.is_available():
+         torch.mps.manual_seed(seed)
+
+
+ def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
+     if precision == "float8_e4m3fn":
+         try:
+             from q8_kernels.integration.patch_transformer import (
+                 patch_diffusers_transformer as patch_transformer_for_q8_kernels,
+             )
+
+             transformer = Transformer3DModel.from_pretrained(
+                 ckpt_path, dtype=torch.float8_e4m3fn
+             )
+             patch_transformer_for_q8_kernels(transformer)
+             return transformer
+         except ImportError:
+             raise ValueError(
+                 "Q8-Kernels not found. To use FP8 checkpoint, please install Q8 kernels from https://github.com/Lightricks/LTXVideo-Q8-Kernels"
+             )
+     elif precision == "bfloat16":
+         return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
+     else:
+         return Transformer3DModel.from_pretrained(ckpt_path)
+
+
+ def create_ltx_video_pipeline(
+     ckpt_path: str,
+     precision: str,
+     text_encoder_model_name_or_path: str,
+     sampler: Optional[str] = None,
+     device: Optional[str] = None,
+     enhance_prompt: bool = False,
+     prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
+     prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
+ ) -> LTXVideoPipeline:
+     ckpt_path = Path(ckpt_path)
+     assert os.path.exists(
+         ckpt_path
+     ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
+
+     with safe_open(ckpt_path, framework="pt") as f:
+         metadata = f.metadata()
+         config_str = metadata.get("config")
+         configs = json.loads(config_str)
+         allowed_inference_steps = configs.get("allowed_inference_steps", None)
+
+     vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
+     transformer = create_transformer(ckpt_path, precision)
+
+     # Load the scheduler from the checkpoint unless an explicit sampler was requested
+     if sampler == "from_checkpoint" or not sampler:
+         scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
+     else:
+         scheduler = RectifiedFlowScheduler(
+             sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
+         )
+
+     text_encoder = T5EncoderModel.from_pretrained(
+         text_encoder_model_name_or_path, subfolder="text_encoder"
+     )
+     patchifier = SymmetricPatchifier(patch_size=1)
+     tokenizer = T5Tokenizer.from_pretrained(
+         text_encoder_model_name_or_path, subfolder="tokenizer"
+     )
+
+     transformer = transformer.to(device)
+     vae = vae.to(device)
+     text_encoder = text_encoder.to(device)
+
+     if enhance_prompt:
+         prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
+             prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
+         )
+         prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
+             prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
+         )
+         prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
+             prompt_enhancer_llm_model_name_or_path,
+             torch_dtype="bfloat16",
+         )
+         prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
+             prompt_enhancer_llm_model_name_or_path,
+         )
+     else:
+         prompt_enhancer_image_caption_model = None
+         prompt_enhancer_image_caption_processor = None
+         prompt_enhancer_llm_model = None
+         prompt_enhancer_llm_tokenizer = None
+
+     vae = vae.to(torch.bfloat16)
+     text_encoder = text_encoder.to(torch.bfloat16)
+
+     # Use submodels for the pipeline
+     submodel_dict = {
+         "transformer": transformer,
+         "patchifier": patchifier,
+         "text_encoder": text_encoder,
+         "tokenizer": tokenizer,
+         "scheduler": scheduler,
+         "vae": vae,
+         "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
+         "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
+         "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
+         "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
+         "allowed_inference_steps": allowed_inference_steps,
+     }
+
+     pipeline = LTXVideoPipeline(**submodel_dict)
+     pipeline = pipeline.to(device)
+     return pipeline
+
+
+ def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
+     latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
+     latent_upsampler.to(device)
+     latent_upsampler.eval()
+     return latent_upsampler
+
+
+ def load_pipeline_config(pipeline_config: str):
+     current_file = Path(__file__)
+
+     path = None
+     if os.path.isfile(current_file.parent / pipeline_config):
+         path = current_file.parent / pipeline_config
+     elif os.path.isfile(pipeline_config):
+         path = pipeline_config
+     else:
+         raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
+
+     with open(path, "r") as f:
+         return yaml.safe_load(f)
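+
+ # The YAML is expected to contain at least the keys read in infer() below
+ # (checkpoint_path, precision, text_encoder_model_name_or_path,
+ # prompt_enhancement_words_threshold, prompt_enhancer_image_caption_model_name_or_path,
+ # prompt_enhancer_llm_model_name_or_path, and optionally sampler,
+ # spatial_upscaler_model_path, pipeline_type and stg_mode); any remaining keys are
+ # forwarded verbatim to the pipeline call.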
+
+
+ @dataclass
+ class InferenceConfig:
+     prompt: str = field(metadata={"help": "Prompt for the generation"})
+
+     output_path: str = field(
+         default_factory=lambda: Path(
+             f"outputs/{datetime.today().strftime('%Y-%m-%d')}"
+         ),
+         metadata={"help": "Path to the folder to save the output video"},
+     )
+
+     # Pipeline settings
+     pipeline_config: str = field(
+         default="configs/ltxv-13b-0.9.7-dev.yaml",
+         metadata={"help": "Path to the pipeline config file"},
+     )
+     seed: int = field(
+         default=171198, metadata={"help": "Random seed for the inference"}
+     )
+     height: int = field(
+         default=704, metadata={"help": "Height of the output video frames"}
+     )
+     width: int = field(
+         default=1216, metadata={"help": "Width of the output video frames"}
+     )
+     num_frames: int = field(
+         default=121,
+         metadata={"help": "Number of frames to generate in the output video"},
+     )
+     frame_rate: int = field(
+         default=30, metadata={"help": "Frame rate for the output video"}
+     )
+     offload_to_cpu: bool = field(
+         default=False, metadata={"help": "Offload unnecessary computations to CPU."}
+     )
+     negative_prompt: str = field(
+         default="worst quality, inconsistent motion, blurry, jittery, distorted",
+         metadata={"help": "Negative prompt for undesired features"},
+     )
+
+     # Video-to-video arguments
+     input_media_path: Optional[str] = field(
+         default=None,
+         metadata={
+             "help": "Path to the input video (or image) to be modified using the video-to-video pipeline"
+         },
+     )
+
+     # Conditioning
+     image_cond_noise_scale: float = field(
+         default=0.15,
+         metadata={"help": "Amount of noise to add to the conditioned image"},
+     )
+     conditioning_media_paths: Optional[List[str]] = field(
+         default=None,
+         metadata={
+             "help": "List of paths to conditioning media (images or videos). Each path will be used as a conditioning item."
+         },
+     )
+     conditioning_strengths: Optional[List[float]] = field(
+         default=None,
+         metadata={
+             "help": "List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items."
+         },
+     )
+     conditioning_start_frames: Optional[List[int]] = field(
+         default=None,
+         metadata={
+             "help": "List of frame indices where each conditioning item should be applied. Must match the number of conditioning items."
+         },
+     )
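+
+ # Illustrative construction (field names come from the dataclass above; the prompt
+ # and media path are placeholders):
+ #   config = InferenceConfig(
+ #       prompt="A sailboat gliding across a calm bay at sunset",
+ #       conditioning_media_paths=["first_frame.png"],
+ #       conditioning_start_frames=[0],
+ #   )
+ #   infer(config)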
+
+
+ def infer(config: InferenceConfig):
+     pipeline_config = load_pipeline_config(config.pipeline_config)
+
+     ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
+     if not os.path.isfile(ltxv_model_name_or_path):
+         ltxv_model_path = hf_hub_download(
+             repo_id="Lightricks/LTX-Video",
+             filename=ltxv_model_name_or_path,
+             repo_type="model",
+         )
+     else:
+         ltxv_model_path = ltxv_model_name_or_path
+
+     spatial_upscaler_model_name_or_path = pipeline_config.get(
+         "spatial_upscaler_model_path"
+     )
+     if spatial_upscaler_model_name_or_path and not os.path.isfile(
+         spatial_upscaler_model_name_or_path
+     ):
+         spatial_upscaler_model_path = hf_hub_download(
+             repo_id="Lightricks/LTX-Video",
+             filename=spatial_upscaler_model_name_or_path,
+             repo_type="model",
+         )
+     else:
+         spatial_upscaler_model_path = spatial_upscaler_model_name_or_path
+
+     conditioning_media_paths = config.conditioning_media_paths
+     conditioning_strengths = config.conditioning_strengths
+     conditioning_start_frames = config.conditioning_start_frames
+
+     # Validate conditioning arguments
+     if conditioning_media_paths:
+         # Use default strengths of 1.0
+         if not conditioning_strengths:
+             conditioning_strengths = [1.0] * len(conditioning_media_paths)
+         if not conditioning_start_frames:
+             raise ValueError(
+                 "If `conditioning_media_paths` is provided, "
+                 "`conditioning_start_frames` must also be provided"
+             )
+         if len(conditioning_media_paths) != len(conditioning_strengths) or len(
+             conditioning_media_paths
+         ) != len(conditioning_start_frames):
+             raise ValueError(
+                 "`conditioning_media_paths`, `conditioning_strengths`, "
+                 "and `conditioning_start_frames` must have the same length"
+             )
+         if any(s < 0 or s > 1 for s in conditioning_strengths):
+             raise ValueError("All conditioning strengths must be between 0 and 1")
+         if any(f < 0 or f >= config.num_frames for f in conditioning_start_frames):
+             raise ValueError(
+                 f"All conditioning start frames must be between 0 and {config.num_frames - 1}"
+             )
+
+     seed_everything(config.seed)
+     if config.offload_to_cpu and not torch.cuda.is_available():
+         logger.warning(
+             "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
+         )
+         offload_to_cpu = False
+     else:
+         offload_to_cpu = config.offload_to_cpu and get_total_gpu_memory() < 30
+
+     output_dir = (
+         Path(config.output_path)
+         if config.output_path
+         else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
+     )
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
+     height_padded = ((config.height - 1) // 32 + 1) * 32
+     width_padded = ((config.width - 1) // 32 + 1) * 32
+     num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
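+     # e.g. a 720x1280 request with 100 frames is padded to 736x1280 and 105 frames
+     # (next multiple of 32 per spatial dim, next value of the form 8 * N + 1 frames)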
+
+     padding = calculate_padding(
+         config.height, config.width, height_padded, width_padded
+     )
+
+     logger.warning(
+         f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
+     )
+
+     device = get_device()
+
+     prompt_enhancement_words_threshold = pipeline_config[
+         "prompt_enhancement_words_threshold"
+     ]
+
+     prompt_word_count = len(config.prompt.split())
+     enhance_prompt = (
+         prompt_enhancement_words_threshold > 0
+         and prompt_word_count < prompt_enhancement_words_threshold
+     )
+
+     if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
+         logger.info(
+             f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
+         )
+
+     precision = pipeline_config["precision"]
+     text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
+     sampler = pipeline_config.get("sampler", None)
+     prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
+         "prompt_enhancer_image_caption_model_name_or_path"
+     ]
+     prompt_enhancer_llm_model_name_or_path = pipeline_config[
+         "prompt_enhancer_llm_model_name_or_path"
+     ]
+
+     pipeline = create_ltx_video_pipeline(
+         ckpt_path=ltxv_model_path,
+         precision=precision,
+         text_encoder_model_name_or_path=text_encoder_model_name_or_path,
+         sampler=sampler,
+         device=device,
+         enhance_prompt=enhance_prompt,
+         prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
+         prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
+     )
+
+     if pipeline_config.get("pipeline_type", None) == "multi-scale":
+         if not spatial_upscaler_model_path:
+             raise ValueError(
+                 "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
+             )
+         latent_upsampler = create_latent_upsampler(
+             spatial_upscaler_model_path, pipeline.device
+         )
+         pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)
+
+     media_item = None
+     if config.input_media_path:
+         media_item = load_media_file(
+             media_path=config.input_media_path,
+             height=config.height,
+             width=config.width,
+             max_frames=num_frames_padded,
+             padding=padding,
+         )
+
+     conditioning_items = (
+         prepare_conditioning(
+             conditioning_media_paths=conditioning_media_paths,
+             conditioning_strengths=conditioning_strengths,
+             conditioning_start_frames=conditioning_start_frames,
+             height=config.height,
+             width=config.width,
+             num_frames=config.num_frames,
+             padding=padding,
+             pipeline=pipeline,
+         )
+         if conditioning_media_paths
+         else None
+     )
+
+     stg_mode = pipeline_config.pop("stg_mode", "attention_values")
+     if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
+         skip_layer_strategy = SkipLayerStrategy.AttentionValues
+     elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
+         skip_layer_strategy = SkipLayerStrategy.AttentionSkip
+     elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
+         skip_layer_strategy = SkipLayerStrategy.Residual
+     elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
+         skip_layer_strategy = SkipLayerStrategy.TransformerBlock
+     else:
+         raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")
+
+     # Prepare input for the pipeline
+     sample = {
+         "prompt": config.prompt,
+         "prompt_attention_mask": None,
+         "negative_prompt": config.negative_prompt,
+         "negative_prompt_attention_mask": None,
+     }
+
+     generator = torch.Generator(device=device).manual_seed(config.seed)
+
+     images = pipeline(
+         **pipeline_config,
+         skip_layer_strategy=skip_layer_strategy,
+         generator=generator,
+         output_type="pt",
+         callback_on_step_end=None,
+         height=height_padded,
+         width=width_padded,
+         num_frames=num_frames_padded,
+         frame_rate=config.frame_rate,
+         **sample,
+         media_items=media_item,
+         conditioning_items=conditioning_items,
+         is_video=True,
+         vae_per_channel_normalize=True,
+         image_cond_noise_scale=config.image_cond_noise_scale,
+         mixed_precision=(precision == "mixed_precision"),
+         offload_to_cpu=offload_to_cpu,
+         device=device,
+         enhance_prompt=enhance_prompt,
+     ).images
+
+     # Crop the padded images to the desired resolution and number of frames
+     (pad_left, pad_right, pad_top, pad_bottom) = padding
+     pad_bottom = -pad_bottom
+     pad_right = -pad_right
+     if pad_bottom == 0:
+         pad_bottom = images.shape[3]
+     if pad_right == 0:
+         pad_right = images.shape[4]
+     images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right]
+
+     for i in range(images.shape[0]):
+         # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
+         video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
+         # Unnormalizing images to [0, 255] range
+         video_np = (video_np * 255).astype(np.uint8)
+         fps = config.frame_rate
+         height, width = video_np.shape[1:3]
+         # In case a single image is generated
+         if video_np.shape[0] == 1:
+             output_filename = get_unique_filename(
+                 f"image_output_{i}",
+                 ".png",
+                 prompt=config.prompt,
+                 seed=config.seed,
+                 resolution=(height, width, config.num_frames),
+                 dir=output_dir,
+             )
+             imageio.imwrite(output_filename, video_np[0])
+         else:
+             output_filename = get_unique_filename(
+                 f"video_output_{i}",
+                 ".mp4",
+                 prompt=config.prompt,
+                 seed=config.seed,
+                 resolution=(height, width, config.num_frames),
+                 dir=output_dir,
+             )
+
+             # Write video
+             with imageio.get_writer(output_filename, fps=fps) as video:
+                 for frame in video_np:
+                     video.append_data(frame)
+
+         logger.warning(f"Output saved to {output_filename}")
+
+
+ def prepare_conditioning(
+     conditioning_media_paths: List[str],
+     conditioning_strengths: List[float],
+     conditioning_start_frames: List[int],
+     height: int,
+     width: int,
+     num_frames: int,
+     padding: tuple[int, int, int, int],
+     pipeline: LTXVideoPipeline,
+ ) -> Optional[List[ConditioningItem]]:
+     """Prepare conditioning items based on input media paths and their parameters.
+
+     Args:
+         conditioning_media_paths: List of paths to conditioning media (images or videos)
+         conditioning_strengths: List of conditioning strengths for each media item
+         conditioning_start_frames: List of frame indices where each item should be applied
+         height: Height of the output frames
+         width: Width of the output frames
+         num_frames: Number of frames in the output video
+         padding: Padding to apply to the frames
+         pipeline: LTXVideoPipeline object used for condition video trimming
+
+     Returns:
+         A list of ConditioningItem objects.
+     """
+     conditioning_items = []
+     for path, strength, start_frame in zip(
+         conditioning_media_paths, conditioning_strengths, conditioning_start_frames
+     ):
+         num_input_frames = orig_num_input_frames = get_media_num_frames(path)
+         if hasattr(pipeline, "trim_conditioning_sequence") and callable(
+             getattr(pipeline, "trim_conditioning_sequence")
+         ):
+             num_input_frames = pipeline.trim_conditioning_sequence(
+                 start_frame, orig_num_input_frames, num_frames
+             )
+         if num_input_frames < orig_num_input_frames:
+             logger.warning(
+                 f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
+             )
+
+         media_tensor = load_media_file(
+             media_path=path,
+             height=height,
+             width=width,
+             max_frames=num_input_frames,
+             padding=padding,
+             just_crop=True,
+         )
+         conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
+     return conditioning_items
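+
+ # Illustrative call (the media paths are placeholders; `pipeline` is an already
+ # constructed LTXVideoPipeline): condition on an image at frame 0 and a short clip
+ # starting at frame 40, both at full strength:
+ #   items = prepare_conditioning(
+ #       conditioning_media_paths=["start.png", "mid_clip.mp4"],
+ #       conditioning_strengths=[1.0, 1.0],
+ #       conditioning_start_frames=[0, 40],
+ #       height=704, width=1216, num_frames=121,
+ #       padding=(0, 0, 0, 0), pipeline=pipeline,
+ #   )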
+
+
+ def get_media_num_frames(media_path: str) -> int:
+     is_video = any(
+         media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
+     )
+     num_frames = 1
+     if is_video:
+         reader = imageio.get_reader(media_path)
+         num_frames = reader.count_frames()
+         reader.close()
+     return num_frames
+
+
+ def load_media_file(
+     media_path: str,
+     height: int,
+     width: int,
+     max_frames: int,
+     padding: tuple[int, int, int, int],
+     just_crop: bool = False,
+ ) -> torch.Tensor:
+     is_video = any(
+         media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
+     )
+     if is_video:
+         reader = imageio.get_reader(media_path)
+         num_input_frames = min(reader.count_frames(), max_frames)
+
+         # Read and preprocess the relevant frames from the video file.
+         frames = []
+         for i in range(num_input_frames):
+             frame = Image.fromarray(reader.get_data(i))
+             frame_tensor = load_image_to_tensor_with_resize_and_crop(
+                 frame, height, width, just_crop=just_crop
+             )
+             frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
+             frames.append(frame_tensor)
+         reader.close()
+
+         # Stack frames along the temporal dimension
+         media_tensor = torch.cat(frames, dim=2)
+     else:  # Input image
+         media_tensor = load_image_to_tensor_with_resize_and_crop(
+             media_path, height, width, just_crop=just_crop
+         )
+         media_tensor = torch.nn.functional.pad(media_tensor, padding)
+     return media_tensor
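+
+
+ # Illustrative entry point (a sketch: assumes the default pipeline config referenced
+ # by InferenceConfig is resolvable next to this file; the prompt is a placeholder).
+ if __name__ == "__main__":
+     infer(InferenceConfig(prompt="A sailboat gliding across a calm bay at sunset"))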