import gc
import os
import random
import time

import gradio as gr
import imageio
import torch
from diffusers.utils import load_image

from skyreels_v2_infer import DiffusionForcingPipeline
from skyreels_v2_infer.modules import download_model
from skyreels_v2_infer.pipelines import PromptEnhancer, resizecrop

def generate_diffusion_forced_video(
    prompt,
    model_id,
    resolution,
    num_frames,
    image=None,
    ar_step=0,
    causal_attention=False,
    causal_block_size=1,
    base_num_frames=97,
    overlap_history=None,
    addnoise_condition=0,
    guidance_scale=6.0,
    shift=8.0,
    inference_steps=30,
    use_usp=False,
    offload=True,
    fps=24,
    seed=None,
    prompt_enhancer=False,
    teacache=False,
    teacache_thresh=0.2,
    use_ret_steps=False,
):
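    """Generate a video with the SkyReels-V2 diffusion-forcing pipeline.

    Handles both text-to-video and image-to-video. Videos longer than
    `base_num_frames` are generated in overlapping windows (see
    `overlap_history`). Returns the path of the saved MP4 file.
    """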
    model_id = download_model(model_id)

    if resolution == "540P":
        height, width = 544, 960
    elif resolution == "720P":
        height, width = 720, 1280
    else:
        raise ValueError(f"Invalid resolution: {resolution}")

    if seed is None:
        random.seed(time.time())
        seed = int(random.randrange(4294967294))

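    # Long videos are stitched from overlapping windows, so `overlap_history`
    # must be set whenever `num_frames` exceeds `base_num_frames`.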
    if num_frames > base_num_frames and overlap_history is None:
        raise ValueError("Specify `overlap_history` for long video generation. Try 17 or 37.")
    if addnoise_condition > 60:
        print("Warning: large `addnoise_condition` may reduce consistency. Recommended: 20.")

    if image is not None:
        image = load_image(image).convert("RGB")
        image_width, image_height = image.size
        # Switch to a portrait canvas when the input image is taller than wide.
        if image_height > image_width:
            height, width = width, height
        image = resizecrop(image, height, width)

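    # Default negative prompt, kept in Chinese to match the model's training
    # data. Roughly: "vivid tones, overexposed, static, blurry details,
    # subtitles, style, artwork, painting, still frame, overall gray, worst
    # quality, low quality, JPEG artifacts, ugly, mutilated, extra fingers,
    # poorly drawn hands/faces, deformed, disfigured, malformed limbs, fused
    # fingers, motionless frame, cluttered background, three legs, crowded
    # background, walking backwards".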
    negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

    prompt_input = prompt
    if prompt_enhancer and image is None:
        # Enhance text-only prompts, then free the enhancer before the
        # pipeline is loaded.
        enhancer = PromptEnhancer()
        prompt_input = enhancer(prompt_input)
        del enhancer
        gc.collect()
        torch.cuda.empty_cache()

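    # Build the diffusion-forcing pipeline; `offload` trades speed for lower
    # peak VRAM by keeping idle weights on the CPU.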
    pipe = DiffusionForcingPipeline(
        model_id,
        dit_path=model_id,
        device=torch.device("cuda"),
        weight_dtype=torch.bfloat16,
        use_usp=use_usp,
        offload=offload,
    )

    if causal_attention:
        pipe.transformer.set_ar_attention(causal_block_size)

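    # TeaCache skips near-redundant transformer passes. With asynchronous
    # generation (ar_step > 0) the total step count grows by `ar_step` for
    # every causal block beyond the first.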
    if teacache:
        if ar_step > 0:
            num_steps = (
                inference_steps
                + (((base_num_frames - 1) // 4 + 1) // causal_block_size - 1) * ar_step
            )
        else:
            num_steps = inference_steps
        pipe.transformer.initialize_teacache(
            enable_teacache=True,
            num_steps=num_steps,
            teacache_thresh=teacache_thresh,
            use_ret_steps=use_ret_steps,
            ckpt_dir=model_id,
        )

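    # Run denoising under autocast; the pipeline returns a batch of videos,
    # so [0] selects the single generated clip.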
    with torch.amp.autocast("cuda", dtype=pipe.transformer.dtype), torch.no_grad():
        video_frames = pipe(
            prompt=prompt_input,
            negative_prompt=negative_prompt,
            image=image,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=inference_steps,
            shift=shift,
            guidance_scale=guidance_scale,
            generator=torch.Generator(device="cuda").manual_seed(seed),
            overlap_history=overlap_history,
            addnoise_condition=addnoise_condition,
            base_num_frames=base_num_frames,
            ar_step=ar_step,
            causal_block_size=causal_block_size,
            fps=fps,
        )[0]

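    # Save under a filename derived from the prompt, seed, and timestamp.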
    os.makedirs("gradio_df_videos", exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_path = f"gradio_df_videos/{prompt[:50].replace('/', '')}_{seed}_{timestamp}.mp4"
    imageio.mimwrite(output_path, video_frames, fps=fps, quality=8, output_params=["-loglevel", "error"])
    return output_path

# Gradio UI
resolution_options = ["540P", "720P"]
model_options = ["Skywork/SkyReels-V2-DF-1.3B-540P"]  # update if more checkpoints become available

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# SkyReels V2")
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                model_id = gr.Dropdown(choices=model_options, value=model_options[0], label="Model ID")
                resolution = gr.Radio(choices=resolution_options, value="540P", label="Resolution", interactive=False)
                num_frames = gr.Slider(minimum=16, maximum=200, value=97, step=1, label="Number of Frames")
                image = gr.Image(type="filepath", label="Input Image (optional)")
                with gr.Accordion("Advanced Settings", open=False):
                    ar_step = gr.Number(label="AR Step", value=0)
                    causal_attention = gr.Checkbox(label="Causal Attention")
                    causal_block_size = gr.Number(label="Causal Block Size", value=1)
                    base_num_frames = gr.Number(label="Base Num Frames", value=97)
                    overlap_history = gr.Number(label="Overlap History (set for long videos)", value=None)
                    addnoise_condition = gr.Number(label="AddNoise Condition", value=0)
                    guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, value=6.0, step=0.1, label="Guidance Scale")
                    shift = gr.Slider(minimum=0.0, maximum=20.0, value=8.0, step=0.1, label="Shift")
                    inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Inference Steps")
                    use_usp = gr.Checkbox(label="Use USP")
                    offload = gr.Checkbox(label="Offload", value=True, interactive=False)
                    fps = gr.Slider(minimum=1, maximum=60, value=24, step=1, label="FPS")
                    seed = gr.Number(label="Seed (optional)", precision=0)
                    prompt_enhancer = gr.Checkbox(label="Prompt Enhancer")
                    use_teacache = gr.Checkbox(label="Use TeaCache")
                    teacache_thresh = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.01, label="TeaCache Threshold")
                    use_ret_steps = gr.Checkbox(label="Use Retention Steps")
                submit_btn = gr.Button("Generate")
            with gr.Column():
                output_video = gr.Video(label="Generated Video")
    submit_btn.click(
        fn=generate_diffusion_forced_video,
        inputs=[
            prompt,
            model_id,
            resolution,
            num_frames,
            image,
            ar_step,
            causal_attention,
            causal_block_size,
            base_num_frames,
            overlap_history,
            addnoise_condition,
            guidance_scale,
            shift,
            inference_steps,
            use_usp,
            offload,
            fps,
            seed,
            prompt_enhancer,
            use_teacache,
            teacache_thresh,
            use_ret_steps,
        ],
        outputs=[output_video],
    )

demo.launch(show_error=True, show_api=False, share=False)