import os
import random
from datetime import datetime
from pathlib import Path
from diffusers.utils import logging
from typing import Optional, List, Union
import yaml
import imageio
import json
import numpy as np
import torch
from safetensors import safe_open
from PIL import Image
import torchvision.transforms.functional as TVF
from transformers import (
    T5EncoderModel,
    T5Tokenizer,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
)
from huggingface_hub import hf_hub_download
from dataclasses import dataclass, field
from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.pipelines.pipeline_ltx_video import (
    ConditioningItem,
    LTXVideoPipeline,
    LTXMultiScalePipeline,
)
from ltx_video.schedulers.rf import RectifiedFlowScheduler
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
import ltx_video.pipelines.crf_compressor as crf_compressor

logger = logging.get_logger("LTX-Video")

def get_total_gpu_memory():
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        return total_memory
    return 0

def get_device():
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"

def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, Image.Image],
    target_height: int = 512,
    target_width: int = 768,
    just_crop: bool = False,
) -> torch.Tensor:
    """Load and process an image into a tensor.

    Args:
        image_input: Either a file path (str) or a PIL Image object.
        target_height: Desired height of the output tensor.
        target_width: Desired width of the output tensor.
        just_crop: If True, only center-crop the image to the target aspect
            ratio without resizing it to (target_height, target_width).
    """
    if isinstance(image_input, str):
        image = Image.open(image_input).convert("RGB")
    elif isinstance(image_input, Image.Image):
        image = image_input
    else:
        raise ValueError(
            "image_input must be either a file path or a PIL Image object"
        )

    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height

    # Center-crop to the target aspect ratio.
    if aspect_ratio_frame > aspect_ratio_target:
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    if not just_crop:
        image = image.resize((target_width, target_height))

    frame_tensor = TVF.to_tensor(image)  # PIL -> tensor (C, H, W), values in [0, 1]
    frame_tensor = TVF.gaussian_blur(frame_tensor, kernel_size=3, sigma=1.0)
    frame_tensor_hwc = frame_tensor.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    frame_tensor_hwc = crf_compressor.compress(frame_tensor_hwc)
    frame_tensor = frame_tensor_hwc.permute(2, 0, 1) * 255.0  # (H, W, C) -> (C, H, W)
    frame_tensor = (frame_tensor / 127.5) - 1.0
    # Create a 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
    return frame_tensor.unsqueeze(0).unsqueeze(2)
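
# Illustrative example (hypothetical file name; shown as a comment so it is not
# executed at import time):
#
#     frame = load_image_to_tensor_with_resize_and_crop("input.png", 512, 768)
#     frame.shape  # -> torch.Size([1, 3, 1, 512, 768]), values roughly in [-1, 1]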

def calculate_padding(
    source_height: int, source_width: int, target_height: int, target_width: int
) -> tuple[int, int, int, int]:
    # Calculate total padding needed
    pad_height = target_height - source_height
    pad_width = target_width - source_width

    # Calculate padding for each side
    pad_top = pad_height // 2
    pad_bottom = pad_height - pad_top  # Handles odd padding
    pad_left = pad_width // 2
    pad_right = pad_width - pad_left  # Handles odd padding

    # Return the padding amounts in the (left, right, top, bottom) order expected
    # by torch.nn.functional.pad for the last two dimensions.
    padding = (pad_left, pad_right, pad_top, pad_bottom)
    return padding
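
# Worked example (assumed sizes): padding a 704x1216 frame to 736x1280 splits the
# extra 32 rows and 64 columns evenly across the four sides:
#
#     padding = calculate_padding(704, 1216, 736, 1280)  # -> (32, 32, 16, 16)
#     padded = torch.nn.functional.pad(frame, padding)   # (..., 704, 1216) -> (..., 736, 1280)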

def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
    # Keep only letters and spaces, lowercased
    clean_text = "".join(
        char.lower() for char in text if char.isalpha() or char.isspace()
    )
    # Split into words
    words = clean_text.split()
    # Build the result, tracking the total length including separators
    result = []
    current_length = 0
    for word in words:
        # Add the word length plus 1 for the hyphen separator (except for the first word)
        new_length = current_length + len(word) + (1 if result else 0)
        if new_length <= max_len:
            result.append(word)
            current_length = new_length
        else:
            break
    return "-".join(result)
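
# Example (assumed prompt): punctuation is dropped and words are joined with
# hyphens until max_len would be exceeded:
#
#     convert_prompt_to_filename("A red fox runs through the snow!", max_len=20)
#     # -> "a-red-fox-runs"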

# Generate a unique output file name
def get_unique_filename(
    base: str,
    ext: str,
    prompt: str,
    seed: int,
    resolution: tuple[int, int, int],
    dir: Path,
    endswith=None,
    index_range=1000,
) -> Path:
    base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
    for i in range(index_range):
        filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
        if not os.path.exists(filename):
            return filename
    raise FileExistsError(
        f"Could not find a unique filename after {index_range} attempts."
    )
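
# Example (assumed arguments): for base "video_output_0", prompt "A red fox",
# seed 42, and resolution (704, 1216, 121), the first free name under out_dir is
#
#     out_dir / "video_output_0_a-red-fox_42_704x1216x121_0.mp4"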

def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)

def create_transformer(ckpt_path: str, precision: str) -> Transformer3DModel:
    if precision == "float8_e4m3fn":
        try:
            from q8_kernels.integration.patch_transformer import (
                patch_diffusers_transformer as patch_transformer_for_q8_kernels,
            )

            transformer = Transformer3DModel.from_pretrained(
                ckpt_path, dtype=torch.float8_e4m3fn
            )
            patch_transformer_for_q8_kernels(transformer)
            return transformer
        except ImportError:
            raise ValueError(
                "Q8-Kernels not found. To use an FP8 checkpoint, please install Q8 kernels from https://github.com/Lightricks/LTXVideo-Q8-Kernels"
            )
    elif precision == "bfloat16":
        return Transformer3DModel.from_pretrained(ckpt_path).to(torch.bfloat16)
    else:
        return Transformer3DModel.from_pretrained(ckpt_path)

def create_ltx_video_pipeline(
    ckpt_path: str,
    precision: str,
    text_encoder_model_name_or_path: str,
    sampler: Optional[str] = None,
    device: Optional[str] = None,
    enhance_prompt: bool = False,
    prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
    prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
) -> LTXVideoPipeline:
    ckpt_path = Path(ckpt_path)
    assert os.path.exists(
        ckpt_path
    ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"

    with safe_open(ckpt_path, framework="pt") as f:
        metadata = f.metadata()
        config_str = metadata.get("config")
        configs = json.loads(config_str)
        allowed_inference_steps = configs.get("allowed_inference_steps", None)

    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
    transformer = create_transformer(ckpt_path, precision)

    # Use constructor if sampler is specified, otherwise use from_pretrained
    if sampler == "from_checkpoint" or not sampler:
        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
    else:
        scheduler = RectifiedFlowScheduler(
            sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
        )

    text_encoder = T5EncoderModel.from_pretrained(
        text_encoder_model_name_or_path, subfolder="text_encoder"
    )
    patchifier = SymmetricPatchifier(patch_size=1)
    tokenizer = T5Tokenizer.from_pretrained(
        text_encoder_model_name_or_path, subfolder="tokenizer"
    )

    transformer = transformer.to(device)
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)

    if enhance_prompt:
        prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
            prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
        )
        prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
            prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
        )
        prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
            prompt_enhancer_llm_model_name_or_path,
            torch_dtype="bfloat16",
        )
        prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
            prompt_enhancer_llm_model_name_or_path,
        )
    else:
        prompt_enhancer_image_caption_model = None
        prompt_enhancer_image_caption_processor = None
        prompt_enhancer_llm_model = None
        prompt_enhancer_llm_tokenizer = None

    vae = vae.to(torch.bfloat16)
    text_encoder = text_encoder.to(torch.bfloat16)

    # Use submodels for the pipeline
    submodel_dict = {
        "transformer": transformer,
        "patchifier": patchifier,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
        "scheduler": scheduler,
        "vae": vae,
        "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
        "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
        "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
        "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
        "allowed_inference_steps": allowed_inference_steps,
    }

    pipeline = LTXVideoPipeline(**submodel_dict)
    pipeline = pipeline.to(device)
    return pipeline
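
# Minimal usage sketch (the local checkpoint path is hypothetical, and the text
# encoder id is an assumption, not a requirement of this function):
#
#     pipeline = create_ltx_video_pipeline(
#         ckpt_path="checkpoints/ltxv-13b.safetensors",
#         precision="bfloat16",
#         text_encoder_model_name_or_path="PixArt-alpha/PixArt-XL-2-1024-MS",
#         sampler="from_checkpoint",
#         device=get_device(),
#     )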

def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
    latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
    latent_upsampler.to(device)
    latent_upsampler.eval()
    return latent_upsampler

def load_pipeline_config(pipeline_config: str):
    current_file = Path(__file__)
    path = None
    if os.path.isfile(current_file.parent / pipeline_config):
        path = current_file.parent / pipeline_config
    elif os.path.isfile(pipeline_config):
        path = pipeline_config
    else:
        raise ValueError(f"Pipeline config file {pipeline_config} does not exist")

    with open(path, "r") as f:
        return yaml.safe_load(f)
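
# Sketch of the YAML keys this script reads from the pipeline config (the values
# below are illustrative assumptions, not canonical defaults):
#
#     checkpoint_path: "ltxv-13b-0.9.7-dev.safetensors"
#     precision: "bfloat16"
#     text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
#     sampler: "from_checkpoint"
#     prompt_enhancement_words_threshold: 120
#     prompt_enhancer_image_caption_model_name_or_path: "..."
#     prompt_enhancer_llm_model_name_or_path: "..."
#     stg_mode: "attention_values"
#     # Optional, enables two-stage rendering:
#     pipeline_type: "multi-scale"
#     spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
#
# Any remaining keys are forwarded verbatim to the pipeline call in `infer` below.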

@dataclass
class InferenceConfig:
    prompt: str = field(metadata={"help": "Prompt for the generation"})
    output_path: Union[str, Path] = field(
        default_factory=lambda: Path(
            f"outputs/{datetime.today().strftime('%Y-%m-%d')}"
        ),
        metadata={"help": "Path to the folder to save the output video"},
    )

    # Pipeline settings
    pipeline_config: str = field(
        default="configs/ltxv-13b-0.9.7-dev.yaml",
        metadata={"help": "Path to the pipeline config file"},
    )
    seed: int = field(
        default=171198, metadata={"help": "Random seed for the inference"}
    )
    height: int = field(
        default=704, metadata={"help": "Height of the output video frames"}
    )
    width: int = field(
        default=1216, metadata={"help": "Width of the output video frames"}
    )
    num_frames: int = field(
        default=121,
        metadata={"help": "Number of frames to generate in the output video"},
    )
    frame_rate: int = field(
        default=30, metadata={"help": "Frame rate for the output video"}
    )
    offload_to_cpu: bool = field(
        default=False, metadata={"help": "Offload unneeded computations to the CPU."}
    )
    negative_prompt: str = field(
        default="worst quality, inconsistent motion, blurry, jittery, distorted",
        metadata={"help": "Negative prompt for undesired features"},
    )

    # Video-to-video arguments
    input_media_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to the input video (or image) to be modified using the video-to-video pipeline"
        },
    )

    # Conditioning
    image_cond_noise_scale: float = field(
        default=0.15,
        metadata={"help": "Amount of noise to add to the conditioned image"},
    )
    conditioning_media_paths: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "List of paths to conditioning media (images or videos). Each path will be used as a conditioning item."
        },
    )
    conditioning_strengths: Optional[List[float]] = field(
        default=None,
        metadata={
            "help": "List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items."
        },
    )
    conditioning_start_frames: Optional[List[int]] = field(
        default=None,
        metadata={
            "help": "List of frame indices where each conditioning item should be applied. Must match the number of conditioning items."
        },
    )
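
# Example configuration sketch (hypothetical prompt and image path): condition
# the first output frame on an input image at full strength, image-to-video style:
#
#     config = InferenceConfig(
#         prompt="A red fox running through snowy woods",
#         conditioning_media_paths=["fox.png"],
#         conditioning_start_frames=[0],
#         conditioning_strengths=[1.0],
#     )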

def infer(config: InferenceConfig):
    pipeline_config = load_pipeline_config(config.pipeline_config)

    # Resolve the main checkpoint: use a local file if it exists, otherwise
    # download it from the Hugging Face Hub.
    ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
    if not os.path.isfile(ltxv_model_name_or_path):
        ltxv_model_path = hf_hub_download(
            repo_id="Lightricks/LTX-Video",
            filename=ltxv_model_name_or_path,
            repo_type="model",
        )
    else:
        ltxv_model_path = ltxv_model_name_or_path

    spatial_upscaler_model_name_or_path = pipeline_config.get(
        "spatial_upscaler_model_path"
    )
    if spatial_upscaler_model_name_or_path and not os.path.isfile(
        spatial_upscaler_model_name_or_path
    ):
        spatial_upscaler_model_path = hf_hub_download(
            repo_id="Lightricks/LTX-Video",
            filename=spatial_upscaler_model_name_or_path,
            repo_type="model",
        )
    else:
        spatial_upscaler_model_path = spatial_upscaler_model_name_or_path

    conditioning_media_paths = config.conditioning_media_paths
    conditioning_strengths = config.conditioning_strengths
    conditioning_start_frames = config.conditioning_start_frames

    # Validate conditioning arguments
    if conditioning_media_paths:
        # Use default strengths of 1.0
        if not conditioning_strengths:
            conditioning_strengths = [1.0] * len(conditioning_media_paths)
        if not conditioning_start_frames:
            raise ValueError(
                "If `conditioning_media_paths` is provided, "
                "`conditioning_start_frames` must also be provided"
            )
        if len(conditioning_media_paths) != len(conditioning_strengths) or len(
            conditioning_media_paths
        ) != len(conditioning_start_frames):
            raise ValueError(
                "`conditioning_media_paths`, `conditioning_strengths`, "
                "and `conditioning_start_frames` must have the same length"
            )
        if any(s < 0 or s > 1 for s in conditioning_strengths):
            raise ValueError("All conditioning strengths must be between 0 and 1")
        if any(f < 0 or f >= config.num_frames for f in conditioning_start_frames):
            raise ValueError(
                f"All conditioning start frames must be between 0 and {config.num_frames - 1}"
            )

    seed_everything(config.seed)

    if config.offload_to_cpu and not torch.cuda.is_available():
        logger.warning(
            "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
        )
        offload_to_cpu = False
    else:
        offload_to_cpu = config.offload_to_cpu and get_total_gpu_memory() < 30

    output_dir = (
        Path(config.output_path)
        if config.output_path
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
    height_padded = ((config.height - 1) // 32 + 1) * 32
    width_padded = ((config.width - 1) // 32 + 1) * 32
    num_frames_padded = ((config.num_frames - 2) // 8 + 1) * 8 + 1
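    # Worked examples of the rounding above (assumed requested sizes):
    #   height 704 -> 704, height 720 -> 736 (next multiple of 32)
    #   num_frames 121 -> 121, num_frames 100 -> 105 (next N * 8 + 1)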

    padding = calculate_padding(
        config.height, config.width, height_padded, width_padded
    )
    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    device = get_device()

    # Enable prompt enhancement only for short prompts below the configured threshold
    prompt_enhancement_words_threshold = pipeline_config[
        "prompt_enhancement_words_threshold"
    ]
    prompt_word_count = len(config.prompt.split())
    enhance_prompt = (
        prompt_enhancement_words_threshold > 0
        and prompt_word_count < prompt_enhancement_words_threshold
    )
    if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
        logger.info(
            f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
        )

    precision = pipeline_config["precision"]
    text_encoder_model_name_or_path = pipeline_config[
        "text_encoder_model_name_or_path"
    ]
    sampler = pipeline_config.get("sampler", None)
    prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
        "prompt_enhancer_image_caption_model_name_or_path"
    ]
    prompt_enhancer_llm_model_name_or_path = pipeline_config[
        "prompt_enhancer_llm_model_name_or_path"
    ]

    pipeline = create_ltx_video_pipeline(
        ckpt_path=ltxv_model_path,
        precision=precision,
        text_encoder_model_name_or_path=text_encoder_model_name_or_path,
        sampler=sampler,
        device=device,
        enhance_prompt=enhance_prompt,
        prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
        prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
    )

    if pipeline_config.get("pipeline_type", None) == "multi-scale":
        if not spatial_upscaler_model_path:
            raise ValueError(
                "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
            )
        latent_upsampler = create_latent_upsampler(
            spatial_upscaler_model_path, pipeline.device
        )
        pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)

    media_item = None
    if config.input_media_path:
        media_item = load_media_file(
            media_path=config.input_media_path,
            height=config.height,
            width=config.width,
            max_frames=num_frames_padded,
            padding=padding,
        )

    conditioning_items = (
        prepare_conditioning(
            conditioning_media_paths=conditioning_media_paths,
            conditioning_strengths=conditioning_strengths,
            conditioning_start_frames=conditioning_start_frames,
            height=config.height,
            width=config.width,
            num_frames=config.num_frames,
            padding=padding,
            pipeline=pipeline,
        )
        if conditioning_media_paths
        else None
    )

    # Map the spatiotemporal guidance (STG) mode name to a skip-layer strategy.
    # pop() removes the key before pipeline_config is forwarded to the pipeline
    # call below, without raising if the key is absent.
    stg_mode = pipeline_config.pop("stg_mode", "attention_values")
    if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
        skip_layer_strategy = SkipLayerStrategy.AttentionValues
    elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
        skip_layer_strategy = SkipLayerStrategy.AttentionSkip
    elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
        skip_layer_strategy = SkipLayerStrategy.Residual
    elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
        skip_layer_strategy = SkipLayerStrategy.TransformerBlock
    else:
        raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")

    # Prepare inputs for the pipeline
    sample = {
        "prompt": config.prompt,
        "prompt_attention_mask": None,
        "negative_prompt": config.negative_prompt,
        "negative_prompt_attention_mask": None,
    }

    generator = torch.Generator(device=device).manual_seed(config.seed)

    images = pipeline(
        **pipeline_config,
        skip_layer_strategy=skip_layer_strategy,
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=config.frame_rate,
        **sample,
        media_items=media_item,
        conditioning_items=conditioning_items,
        is_video=True,
        vae_per_channel_normalize=True,
        image_cond_noise_scale=config.image_cond_noise_scale,
        mixed_precision=(precision == "mixed_precision"),
        offload_to_cpu=offload_to_cpu,
        device=device,
        enhance_prompt=enhance_prompt,
    ).images

    # Crop the padded images to the desired resolution and number of frames.
    # Negative indices slice the padding off the bottom/right edges; a zero pad
    # would produce an empty slice, so fall back to the full extent in that case.
    (pad_left, pad_right, pad_top, pad_bottom) = padding
    pad_bottom = -pad_bottom
    pad_right = -pad_right
    if pad_bottom == 0:
        pad_bottom = images.shape[3]
    if pad_right == 0:
        pad_right = images.shape[4]
    images = images[:, :, : config.num_frames, pad_top:pad_bottom, pad_left:pad_right]

    for i in range(images.shape[0]):
        # Gather from (B, C, F, H, W) to (C, F, H, W), then permute to (F, H, W, C)
        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
        # Unnormalize from [0, 1] to the [0, 255] uint8 range
        video_np = (video_np * 255).astype(np.uint8)
        fps = config.frame_rate
        height, width = video_np.shape[1:3]
        # In case a single image is generated
        if video_np.shape[0] == 1:
            output_filename = get_unique_filename(
                f"image_output_{i}",
                ".png",
                prompt=config.prompt,
                seed=config.seed,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )
            imageio.imwrite(output_filename, video_np[0])
        else:
            output_filename = get_unique_filename(
                f"video_output_{i}",
                ".mp4",
                prompt=config.prompt,
                seed=config.seed,
                resolution=(height, width, config.num_frames),
                dir=output_dir,
            )
            # Write video
            with imageio.get_writer(output_filename, fps=fps) as video:
                for frame in video_np:
                    video.append_data(frame)

        logger.warning(f"Output saved to {output_filename}")

def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
    padding: tuple[int, int, int, int],
    pipeline: LTXVideoPipeline,
) -> Optional[List[ConditioningItem]]:
    """Prepare conditioning items based on input media paths and their parameters.

    Args:
        conditioning_media_paths: List of paths to conditioning media (images or videos)
        conditioning_strengths: List of conditioning strengths for each media item
        conditioning_start_frames: List of frame indices where each item should be applied
        height: Height of the output frames
        width: Width of the output frames
        num_frames: Number of frames in the output video
        padding: Padding to apply to the frames
        pipeline: LTXVideoPipeline object used for condition video trimming

    Returns:
        A list of ConditioningItem objects.
    """
    conditioning_items = []
    for path, strength, start_frame in zip(
        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
    ):
        num_input_frames = orig_num_input_frames = get_media_num_frames(path)
        if hasattr(pipeline, "trim_conditioning_sequence") and callable(
            getattr(pipeline, "trim_conditioning_sequence")
        ):
            num_input_frames = pipeline.trim_conditioning_sequence(
                start_frame, orig_num_input_frames, num_frames
            )
        if num_input_frames < orig_num_input_frames:
            logger.warning(
                f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
            )

        media_tensor = load_media_file(
            media_path=path,
            height=height,
            width=width,
            max_frames=num_input_frames,
            padding=padding,
            just_crop=True,
        )
        conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
    return conditioning_items
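
# Usage sketch (hypothetical path and an already-created `pipeline`): a single
# image conditioning the first frame of a 121-frame generation, with no padding:
#
#     items = prepare_conditioning(
#         conditioning_media_paths=["fox.png"],
#         conditioning_strengths=[1.0],
#         conditioning_start_frames=[0],
#         height=704, width=1216, num_frames=121,
#         padding=(0, 0, 0, 0),
#         pipeline=pipeline,
#     )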

def get_media_num_frames(media_path: str) -> int:
    is_video = any(
        media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
    )
    num_frames = 1
    if is_video:
        reader = imageio.get_reader(media_path)
        num_frames = reader.count_frames()
        reader.close()
    return num_frames

def load_media_file(
    media_path: str,
    height: int,
    width: int,
    max_frames: int,
    padding: tuple[int, int, int, int],
    just_crop: bool = False,
) -> torch.Tensor:
    is_video = any(
        media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
    )
    if is_video:
        reader = imageio.get_reader(media_path)
        num_input_frames = min(reader.count_frames(), max_frames)

        # Read and preprocess the relevant frames from the video file.
        frames = []
        for i in range(num_input_frames):
            frame = Image.fromarray(reader.get_data(i))
            frame_tensor = load_image_to_tensor_with_resize_and_crop(
                frame, height, width, just_crop=just_crop
            )
            frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
            frames.append(frame_tensor)
        reader.close()

        # Stack frames along the temporal dimension
        media_tensor = torch.cat(frames, dim=2)
    else:  # Input image
        media_tensor = load_image_to_tensor_with_resize_and_crop(
            media_path, height, width, just_crop=just_crop
        )
        media_tensor = torch.nn.functional.pad(media_tensor, padding)
    return media_tensor
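

if __name__ == "__main__":
    # Minimal entry-point sketch: generate a clip using the defaults declared in
    # InferenceConfig. The prompt is an assumed example, not part of the library.
    infer(InferenceConfig(prompt="A red fox running through snowy woods"))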