File size: 12,123 Bytes
ecda25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6660ea
 
 
ecda25f
 
 
 
 
 
 
b68f017
ecda25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efe2ba1
 
 
024771c
efe2ba1
 
 
ecda25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82fc0ee
 
 
 
 
 
 
 
 
a57f68f
82fc0ee
96362e7
7b893b3
1576fbd
 
82fc0ee
 
 
6c880b4
ecda25f
 
 
 
 
 
 
 
 
214d9f8
ecda25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbd966d
 
 
 
4f5e7b5
b37f148
ecda25f
 
ffec15c
 
 
 
 
 
353b507
 
 
 
ffec15c
6fd121e
ffec15c
 
7b893b3
ffec15c
 
 
ecda25f
 
 
 
 
 
353b507
 
 
 
 
 
 
 
 
 
ecda25f
 
96362e7
ecda25f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import gradio as gr
import torch
import os
import sys
import subprocess
import tempfile
import numpy as np
import spaces
from PIL import Image

# --- 1. Source Code ---
# Define paths
REPO_PATH = "LongCat-Video"
CHECKPOINT_DIR = os.path.join(REPO_PATH, "weights", "LongCat-Video")

# Clone the repository if it doesn't exist. The pipeline and model classes
# imported below come from this checkout (added to sys.path at the end).
if not os.path.exists(REPO_PATH):
    print(f"Cloning LongCat-Video repository to '{REPO_PATH}'...")
    try:
        subprocess.run(
            ["git", "clone", "https://github.com/meituan-longcat/LongCat-Video.git", REPO_PATH],
            check=True,
            capture_output=True,
        )
        print("Repository cloned successfully.")
    except subprocess.CalledProcessError as e:
        # capture_output without text=True yields raw bytes; decode leniently so a
        # non-UTF-8 git message cannot replace the real failure with a
        # UnicodeDecodeError.
        print(f"Error cloning repository: {e.stderr.decode(errors='replace')}")
        sys.exit(1)
    except FileNotFoundError:
        # subprocess.run raises this when the `git` executable itself is missing.
        print("Error cloning repository: 'git' executable not found on PATH.")
        sys.exit(1)

# Add the cloned repository to the Python path to allow imports
sys.path.insert(0, os.path.abspath(REPO_PATH))

# Now that the repo is in the path, we can import its modules
from huggingface_hub import snapshot_download
from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
from longcat_video.modules.longcat_video_dit import LongCatVideoTransformer3DModel
from longcat_video.context_parallel import context_parallel_util
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.utils import export_to_video

# Fetch the model weights from the Hugging Face Hub on first run only. If
# CHECKPOINT_DIR already exists it is trusted as-is and no download happens.
if not os.path.exists(CHECKPOINT_DIR):
    print(f"Downloading model weights to '{CHECKPOINT_DIR}'...")
    try:
        snapshot_download(
            repo_id="meituan-longcat/LongCat-Video",
            local_dir=CHECKPOINT_DIR,
            # False for better Windows compatibility.
            # NOTE(review): newer huggingface_hub releases deprecate this argument;
            # confirm its effect against the pinned version.
            local_dir_use_symlinks=False,
            # Docs, git metadata and the assets/ folder are not needed for inference.
            ignore_patterns=["*.md", "*.gitattributes", "assets/*"],
        )
    except Exception as e:
        print(f"Error downloading model weights: {e}")
        sys.exit(1)
    else:
        print("Model weights downloaded successfully.")

# --- 2. Model Loading ---
# Global placeholder for the pipeline and device configuration
pipe = None
device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 weights on GPU, full float32 on CPU.
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32

print("--- Initializing Models (loaded once at startup) ---")
try:
    # Context parallel is not used in this single-instance demo, but the model requires the config.
    cp_split_hw = context_parallel_util.get_optimal_split(1)

    print("Loading tokenizer and text encoder...")
    # NOTE(review): torch_dtype is forwarded to the tokenizer and scheduler loaders
    # too; confirm those loaders accept/ignore it.
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR, subfolder="tokenizer", torch_dtype=torch_dtype)
    text_encoder = UMT5EncoderModel.from_pretrained(CHECKPOINT_DIR, subfolder="text_encoder", torch_dtype=torch_dtype)

    print("Loading VAE and Scheduler...")
    vae = AutoencoderKLWan.from_pretrained(CHECKPOINT_DIR, subfolder="vae", torch_dtype=torch_dtype)
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(CHECKPOINT_DIR, subfolder="scheduler", torch_dtype=torch_dtype)

    print("Loading DiT model...")
    # Attention backend: xformers enabled, flash-attention 2/3 disabled.
    dit = LongCatVideoTransformer3DModel.from_pretrained(CHECKPOINT_DIR,
                                                         enable_flashattn3=False,
                                                         enable_flashattn2=False,
                                                         enable_xformers=True,
                                                         subfolder="dit",
                                                         cp_split_hw=cp_split_hw,
                                                         torch_dtype=torch_dtype)

    print("Creating LongCatVideoPipeline...")
    pipe = LongCatVideoPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        scheduler=scheduler,
        dit=dit,
    )
    pipe.to(device)

    # Both LoRAs are registered up front; generate_video enables them per request.
    print("Loading LoRA weights for optional modes...")
    cfg_step_lora_path = os.path.join(CHECKPOINT_DIR, 'lora/cfg_step_lora.safetensors')
    pipe.dit.load_lora(cfg_step_lora_path, 'cfg_step_lora')

    refinement_lora_path = os.path.join(CHECKPOINT_DIR, 'lora/refinement_lora.safetensors')
    pipe.dit.load_lora(refinement_lora_path, 'refinement_lora')

    print("--- Models loaded successfully and are ready for inference. ---")

except Exception as e:
    import traceback

    print("--- FATAL ERROR: Failed to load models. ---")
    print(f"Details: {e}")
    # The one-line message rarely shows which loader (tokenizer, VAE, DiT, LoRA)
    # failed or why; keep the full stack trace in the logs.
    traceback.print_exc()
    # The app will still run, but generation will fail with an error message.
    pipe = None


# --- 3. Generation Logic ---

def torch_gc():
    """Release cached CUDA memory and IPC handles; a no-op when no GPU is available."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

def check_duration(
    mode,
    prompt, 
    neg_prompt, 
    image,
    height, width, resolution,
    seed,
    use_distill,
    use_refine,
    progress=None
):
    """
    Estimate the GPU time budget, in seconds, for one ``generate_video`` call.

    This is passed as the ``duration`` callable to ``spaces.GPU`` on
    ``generate_video``, so its parameters mirror that function's. Only
    ``resolution``, ``use_distill`` and ``use_refine`` affect the result.

    Returns:
        900 for any run with refinement (a second 50-step pass on top of the
        base stage); 180 for a distilled 480p run; 360 for a 720p run; and 900
        otherwise. Text-to-video calls pass ``resolution=None``, so without
        refinement they always get 900.
    """
    if use_refine:
        # generate_video always runs the distilled base stage and then a 50-step
        # refinement pass, so budget it at least as generously as any other run.
        return 900
    if use_distill and resolution == "480p":
        return 180
    if resolution == "720p":
        # NOTE(review): a non-distilled 720p run gets less time than a
        # non-distilled 480p run (900 s); confirm this budget is intentional.
        return 360
    return 900

@spaces.GPU(duration=check_duration)
def generate_video(
    mode,
    prompt, 
    neg_prompt, 
    image,
    height, width, resolution,
    seed,
    use_distill,
    use_refine,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Generate a video with the global LongCat-Video pipeline and return its file path.

    Runs a base generation stage (standard or distilled) and, when requested,
    a refinement stage that takes the stage-1 frames as input.

    Args:
        mode: "t2v" (text-to-video) or "i2v" (image-to-video).
        prompt: Text prompt describing the video.
        neg_prompt: Negative prompt; only used by the non-distilled base stage.
        image: Input image as a numpy array; required for "i2v", ignored for "t2v".
        height, width: Output size in pixels for "t2v"; ignored for "i2v".
        resolution: Resolution preset for "i2v" (e.g. "480p"); ignored for "t2v".
        seed: Seed for the torch random generator (shared by both stages).
        use_distill: Use the cfg-step LoRA (16 steps, guidance 1.0) for stage 1.
        use_refine: Run the refinement stage; forces the distilled stage 1.
        progress: Gradio progress tracker.

    Returns:
        Path to a temporary .mp4 file (15 fps, or 30 fps after refinement).

    Raises:
        gr.Error: If the models failed to load, ``mode`` is unknown, or an
            image-to-video request has no input image.
    """
    if pipe is None:
        raise gr.Error("Models failed to load. Please check the console output for errors and restart the app.")
    if mode not in ("t2v", "i2v"):
        # Without this guard an unknown mode would leave `output` unbound below.
        raise gr.Error(f"Unknown generation mode: {mode!r}")
    if mode == "i2v" and image is None:
        raise gr.Error("Please upload an input image for Image-to-Video generation.")

    generator = torch.Generator(device=device).manual_seed(int(seed))
    # The i2v conditioning image is used by both stages; convert it once.
    pil_image = Image.fromarray(image) if mode == "i2v" else None

    # --- Stage 1: Base Generation (Standard or Distill) ---
    progress(0, desc="Starting Stage 1: Base Generation")

    num_frames = 93  # Default from demo scripts
    is_distill = use_distill or use_refine  # Refinement requires a distilled video as input

    if is_distill:
        num_inference_steps = 16
        guidance_scale = 1.0
        current_neg_prompt = ""
    else:
        num_inference_steps = 50
        guidance_scale = 4.0
        current_neg_prompt = neg_prompt

    # `pipe` is a module-level object: reset LoRA state in `finally` so a failed
    # generation cannot leave the LoRA enabled for a later request.
    if is_distill:
        pipe.dit.enable_loras(['cfg_step_lora'])
    try:
        if mode == "t2v":
            output = pipe.generate_t2v(
                prompt=prompt,
                negative_prompt=current_neg_prompt,
                # Slider values may arrive as floats; pass integer pixel sizes.
                height=int(height),
                width=int(width),
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                use_distill=is_distill,
                guidance_scale=guidance_scale,
                generator=generator,
            )[0]
        else:
            output = pipe.generate_i2v(
                image=pil_image,
                prompt=prompt,
                negative_prompt=current_neg_prompt,
                resolution=resolution,
                num_frames=num_frames,
                num_inference_steps=num_inference_steps,
                use_distill=is_distill,
                guidance_scale=guidance_scale,
                generator=generator,
            )[0]
    finally:
        if is_distill:
            pipe.dit.disable_all_loras()

    torch_gc()

    # --- Stage 2: Refinement (Optional) ---
    if use_refine:
        progress(0.5, desc="Starting Stage 2: Refinement")

        # Scale stage-1 frames by 255 and cast to uint8 for PIL
        # (assumes float frames in [0, 1] — TODO confirm against the pipeline).
        stage1_video_pil = [Image.fromarray((frame * 255).astype(np.uint8)) for frame in output]

        pipe.dit.enable_loras(['refinement_lora'])
        pipe.dit.enable_bsa()
        try:
            output = pipe.generate_refine(
                image=pil_image,
                prompt=prompt,
                stage1_video=stage1_video_pil,
                num_cond_frames=1 if mode == 'i2v' else 0,
                num_inference_steps=50,
                generator=generator,
            )[0]
        finally:
            pipe.dit.disable_all_loras()
            pipe.dit.disable_bsa()
        torch_gc()

    # --- Post-processing and Output ---
    progress(1.0, desc="Exporting video")

    # Only reserve the path here and close the handle before writing: some
    # platforms (e.g. Windows) refuse to reopen a file that is still open.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video_file:
        video_path = temp_video_file.name
    fps = 30 if use_refine else 15
    export_to_video(output, video_path, fps=fps)
    return video_path

# --- 4. Gradio UI Definition ---
# Cap the width of elements carrying the `fillable` class.
css = '''
.fillable{max-width: 960px !important}
'''
# Two tabs (Image-to-Video, Text-to-Video) share the single `generate_video`
# handler; a per-tab gr.State supplies the mode string. Component creation
# order inside the nested `with` blocks defines the page layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# 🎬 LongCat-Video")
    gr.Markdown('''13.6B parameter dense video-generation model by Meituan — [[Model](https://huggingface.co/meituan-longcat/LongCat-Video)]''')

    with gr.Tabs() as tabs:
        with gr.TabItem("Image-to-Video", id=1):
            # Fixed value delivered as generate_video's `mode` argument.
            mode_i2v = gr.State("i2v")
            with gr.Row():
                with gr.Column(scale=2):
                    # type="numpy": the handler receives an array and converts it with Image.fromarray.
                    image_i2v = gr.Image(type="numpy", label="Input Image")
                    prompt_i2v = gr.Textbox(label="Prompt", lines=4, placeholder="The cat in the image wags its tail and blinks.")
                    
                    with gr.Accordion(label="Advanced Options", open=False):
                        neg_prompt_i2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles, watermark")
                        resolution_i2v = gr.Dropdown(label="Resolution", choices=["480p", "720p"], value="480p")
                        seed_i2v = gr.Number(label="Seed", value=42, precision=0)
                        distill_i2v = gr.Checkbox(label="Use Distill Mode", value=True, info="Faster, lower quality base generation.")
                        refine_i2v = gr.Checkbox(label="Use Refine Mode", value=False, info="Higher quality & resolution, but slower. Uses Distill mode for its first stage.")

                    # NOTE(review): "6s" appears on this tab's button but on the Text-to-Video
                    # tab's output label; the two tabs look swapped. Also, 93 frames at 15 fps is
                    # ~6 s, but refine exports at 30 fps — confirm the label still holds there.
                    i2v_button = gr.Button("Generate 6s video", variant="primary")
                with gr.Column(scale=3):
                    video_output_i2v = gr.Video(label="Generated Video", interactive=False)
                    
        with gr.TabItem("Text-to-Video", id=0):
            # Fixed value delivered as generate_video's `mode` argument.
            mode_t2v = gr.State("t2v")
            with gr.Row():
                with gr.Column(scale=2):
                    prompt_t2v = gr.Textbox(label="Prompt", lines=4, placeholder="A cinematic shot of a Corgi walking on the beach.")
                    
                    with gr.Accordion(label="Advanced Options", open=False):
                        neg_prompt_t2v = gr.Textbox(label="Negative Prompt", lines=2, value="ugly, blurry, low quality, static, subtitles")
                        # Explicit output size in pixels (text-to-video has no resolution preset).
                        with gr.Row():
                            height_t2v = gr.Slider(label="Height", minimum=256, maximum=1024, value=480, step=64)
                            width_t2v = gr.Slider(label="Width", minimum=256, maximum=1024, value=832, step=64)
                        with gr.Row():
                            seed_t2v = gr.Number(label="Seed", value=42, precision=0)
                            distill_t2v = gr.Checkbox(label="Use Distill Mode", value=True, info="Faster, lower quality base generation.")
                            refine_t2v = gr.Checkbox(label="Use Refine Mode", value=False, info="Higher quality & resolution, but slower. Uses Distill mode for its first stage.")
                        
                    t2v_button = gr.Button("Generate Video", variant="primary")
                with gr.Column(scale=3):
                    video_output_t2v = gr.Video(label="Generated 6s video", interactive=False)

    # --- Event Handlers ---
    # Each inputs list is passed positionally, so its order must match
    # generate_video(mode, prompt, neg_prompt, image, height, width, resolution,
    # seed, use_distill, use_refine). gr.State(None) fills parameters a tab lacks.
    t2v_inputs = [
        mode_t2v, prompt_t2v, neg_prompt_t2v, 
        gr.State(None), # Placeholder for image
        height_t2v, width_t2v, 
        gr.State(None), # Placeholder for resolution
        seed_t2v, distill_t2v, refine_t2v
    ]
    t2v_button.click(fn=generate_video, inputs=t2v_inputs, outputs=video_output_t2v)

    i2v_inputs = [
        mode_i2v, prompt_i2v, neg_prompt_i2v, image_i2v,
        gr.State(None), gr.State(None), # Placeholders for height/width
        resolution_i2v,
        seed_i2v, distill_i2v, refine_i2v
    ]
    i2v_button.click(fn=generate_video, inputs=i2v_inputs, outputs=video_output_i2v)

# --- 5. Launch the App ---
# Start the Gradio server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()