import gradio as gr
import torch
import os
import sys
import subprocess
import tempfile
import numpy as np
import spaces
from PIL import Image

# --- 1. Setup: Repository & Model Weights ---
# Define paths
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")

# Clone the repository if it doesn't exist
if not os.path.exists(REPO_PATH):
    print(f"Cloning LongCat-Video repository to '{REPO_PATH}'...")
    try:
        subprocess.run(
            ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
            check=True,
            capture_output=True
        )
        print("Repository cloned successfully.")
    except subprocess.CalledProcessError as e:
        print(f"Error cloning repository: {e.stderr.decode()}")
        sys.exit(1)

# Add the cloned repository to the Python path to allow imports
sys.path.insert(0, os.path.abspath(REPO_PATH))

# Now that the repo is in the path, we can import its modules
from huggingface_hub import snapshot_download
from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel import context_parallel_util
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.utils import export_to_video

# Download model weights from Hugging Face Hub if they don't exist
if not os.path.exists(CHECKPOINT_DIR):
    print(f"Downloading model weights to '{CHECKPOINT_DIR}'...")
    try:
        snapshot_download(
            repo_id="meituan-longcat/LongCat-Video",
            local_dir=CHECKPOINT_DIR,
            local_dir_use_symlinks=False,  # Use False for better Windows compatibility
            ignore_patterns=["*.md", "*.gitattributes", "assets/*"]  # ignore non-essential files
        )
        print("Model weights downloaded successfully.")
    except Exception as e:
        print(f"Error downloading model weights: {e}")
        sys.exit(1)

# --- 2. Model Loading ---
# Global placeholder for the pipeline and device configuration
pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

print("--- Initializing Models (loaded once at startup) ---")
try:
    # Context parallel is not used in this single-instance demo, but the model requires the config.
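    # Assumption: get_optimal_split(1) computes the spatial split for a single device
    # (no context parallelism); the DiT constructor still expects a cp_split_hw value.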
    cp_split_hw = context_parallel_util.get_optimal_split(1)

    print("Loading tokenizer and text encoder...")
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)
    text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch_dtype)

    print("Loading VAE and Scheduler...")
    vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler", torch_dtype=torch_dtype)

    print("Loading DiT model...")
    dit = LongCatVideoTransformer3DModel.from_pretrained(
        CHECKPOINT_DIR,
        enable_flashattn3=False,
        enable_flashattn2=False,
        enable_xformers=True,
        subfolder="dit",
        cp_split_hw=cp_split_hw,
        torch_dtype=torch_dtype
    )

    print("Creating LongCatVideoPipeline...")
    pipe = LongCatVideoPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        dit=dit,
    )
    pipe.to(device)

    print("Loading LoRA weights for optional modes...")
    cfg_step_lora_path = os.path.join(CHECKPOINT_DIR, 'lora/cfg_step_lora.safetensors')
    pipe.dit.load_lora(cfg_step_lora_path, 'cfg_step_lora')
    refinement_lora_path = os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors')
    pipe.dit.load_lora(refinement_lora_path, 'refinement_lora')

    print("--- Models loaded successfully and are ready for inference. ---")
except Exception as e:
    print("--- FATAL ERROR: Failed to load models. ---")
    print(f"Details: {e}")
    # The app will still run, but generation will fail with an error message.
    pipe = None


# --- 3. Generation Logic ---
def torch_gc():
    """Helper function to clean up GPU memory."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


# Estimate the GPU time (in seconds) to request from ZeroGPU. @spaces.GPU accepts a
# callable for `duration`, which is invoked with the same arguments as the wrapped function.
def check_duration(
    mode, prompt, neg_prompt, image, height, width,
    resolution, seed, use_distill, use_refine, progress
):
    if use_distill and resolution == "480p":
        return 180
    elif resolution == "720p":
        return 360
    else:
        return 900


@spaces.GPU(duration=check_duration)
def generate_video(
    mode, prompt, neg_prompt, image, height, width,
    resolution, seed, use_distill, use_refine,
    progress=gr.Progress(track_tqdm=True)
):
    """Universal video generation function."""
    if pipe is None:
        raise gr.Error("Models failed to load. Please check the console output for errors and restart the app.")
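
    # Seed a generator so repeated runs with the same settings are reproducible;
    # it is passed to every pipeline call below.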
    generator = torch.Generator(device=device).manual_seed(int(seed))

    # --- Stage 1: Base Generation (Standard or Distill) ---
    progress(0, desc="Starting Stage 1: Base Generation")

    num_frames = 93  # Default from demo scripts
    is_distill = use_distill or use_refine  # Refinement requires a distilled video as input

    if is_distill:
        pipe.dit.enable_loras(['cfg_step_lora'])
        num_inference_steps = 16
        guidance_scale = 1.0
        current_neg_prompt = ""
    else:
        num_inference_steps = 50
        guidance_scale = 4.0
        current_neg_prompt = neg_prompt

    if mode == "t2v":
        output = pipe.generate_t2v(
            prompt=prompt,
            negative_prompt=current_neg_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            use_distill=is_distill,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]
    elif mode == "i2v":
        pil_image = Image.fromarray(image)
        output = pipe.generate_i2v(
            image=pil_image,
            prompt=prompt,
            negative_prompt=current_neg_prompt,
            resolution=resolution,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            use_distill=is_distill,
            guidance_scale=guidance_scale,
            generator=generator,
        )[0]

    if is_distill:
        pipe.dit.disable_all_loras()
    torch_gc()

    # --- Stage 2: Refinement (Optional) ---
    if use_refine:
        progress(0.5, desc="Starting Stage 2: Refinement")
        pipe.dit.enable_loras(['refinement_lora'])
        pipe.dit.enable_bsa()

        stage1_video_pil = [(frame * 255).astype(np.uint8) for frame in output]
        stage1_video_pil = [Image.fromarray(img) for img in stage1_video_pil]
        refine_image = Image.fromarray(image) if mode == 'i2v' else None

        output = pipe.generate_refine(
            image=refine_image,
            prompt=prompt,
            stage1_video=stage1_video_pil,
            num_cond_frames=1 if mode == 'i2v' else 0,
            num_inference_steps=50,
            generator=generator,
        )[0]

        pipe.dit.disable_all_loras()
        pipe.dit.disable_bsa()
        torch_gc()

    # --- Post-processing and Output ---
    progress(1.0, desc="Exporting video")
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video_file:
        fps = 30 if use_refine else 15
        export_to_video(output, temp_video_file.name, fps=fps)
        return temp_video_file.name


# --- 4. Gradio UI Definition ---
css = '''
.fillable{max-width: 960px !important}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🎬 LongCat-Video")
    gr.Markdown('''13.6B parameter dense video-generation model by Meituan — [[Model](https://huggingface.co/meituan-longcat/LongCat-Video)]''')

    with gr.Tabs() as tabs:
        with gr.TabItem("Image-to-Video", id=1):
            mode_i2v = gr.State("i2v")
            with gr.Row():
                with gr.Column(scale=2):
                    image_i2v = gr.Image(type="numpy", label="Input Image")
                    prompt_i2v = gr.Textbox(label="Prompt", lines=4, placeholder="The cat in the image wags its tail and blinks.")
                    with gr.Accordion(label="Advanced Options", open=False):
                        neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles, watermark")
                        resolution_i2v = gr.Dropdown(label="Resolution", choices=["480p", "720p"], value="480p")
                        seed_i2v = gr.Number(label="Seed", value=42, precision=0)
                        distill_i2v = gr.Checkbox(label="Use Distill Mode", value=True, info="Faster, lower quality base generation.")
                        refine_i2v = gr.Checkbox(label="Use Refine Mode", value=False, info="Higher quality & resolution, but slower. Uses Distill mode for its first stage.")
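                    # The I2V path sizes its output from the `resolution` dropdown; explicit
                    # height/width sliders exist only in the Text-to-Video tab, and unused
                    # inputs are passed to generate_video as None placeholders (see below).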
                    i2v_button = gr.Button("Generate 6s video", variant="primary")
                with gr.Column(scale=3):
                    video_output_i2v = gr.Video(label="Generated Video", interactive=False)

        with gr.TabItem("Text-to-Video", id=0):
            mode_t2v = gr.State("t2v")
            with gr.Row():
                with gr.Column(scale=2):
                    prompt_t2v = gr.Textbox(label="Prompt", lines=4, placeholder="A cinematic shot of a Corgi walking on the beach.")
                    with gr.Accordion(label="Advanced Options", open=False):
                        neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles")
                        with gr.Row():
                            height_t2v = gr.Slider(label="Height", minimum=256, maximum=1024, value=480, step=64)
                            width_t2v = gr.Slider(label="Width", minimum=256, maximum=1024, value=832, step=64)
                        with gr.Row():
                            seed_t2v = gr.Number(label="Seed", value=42, precision=0)
                            distill_t2v = gr.Checkbox(label="Use Distill Mode", value=True, info="Faster, lower quality base generation.")
                            refine_t2v = gr.Checkbox(label="Use Refine Mode", value=False, info="Higher quality & resolution, but slower. Uses Distill mode for its first stage.")
                    t2v_button = gr.Button("Generate Video", variant="primary")
                with gr.Column(scale=3):
                    video_output_t2v = gr.Video(label="Generated 6s video", interactive=False)

    # --- Event Handlers ---
    t2v_inputs = [
        mode_t2v, prompt_t2v, neg_prompt_t2v,
        gr.State(None),  # Placeholder for image
        height_t2v, width_t2v,
        gr.State(None),  # Placeholder for resolution
        seed_t2v, distill_t2v, refine_t2v
    ]
    t2v_button.click(fn=generate_video, inputs=t2v_inputs, outputs=video_output_t2v)

    i2v_inputs = [
        mode_i2v, prompt_i2v, neg_prompt_i2v, image_i2v,
        gr.State(None), gr.State(None),  # Placeholders for height/width
        resolution_i2v, seed_i2v, distill_i2v, refine_i2v
    ]
    i2v_button.click(fn=generate_video, inputs=i2v_inputs, outputs=video_output_i2v)

# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch()