""" Gradio interface for WAN-VACE video generation """ import gradio as gr import torch # ----------------------------------------------------------------------------- # XPU shim for CPU‑only environments # # `diffusers` attempts to access `torch.xpu.empty_cache()` when cleaning up # device memory. On CPU‑only builds of PyTorch (or builds without Intel # extension support), the `xpu` attribute does not exist on the `torch` # module. Defining a dummy `torch.xpu` prevents AttributeError during # import. # ----------------------------------------------------------------------------- if not hasattr(torch, "xpu"): class _DummyXPU: @staticmethod def empty_cache(): return None @staticmethod def manual_seed(_seed: int): return None @staticmethod def is_available(): return False @staticmethod def device_count(): return 0 @staticmethod def current_device(): return 0 @staticmethod def set_device(_idx: int): return None torch.xpu = _DummyXPU() # type: ignore import time from typing import Optional # Import the simple planner from planning import plan_from_topic from config import UI_CONFIG, DEFAULT_PARAMS, SERVER_CONFIG from model_handler import model_handler from utils import cleanup_temp_files def load_model_interface(progress=gr.Progress()): """Interface function for loading the model""" def progress_callback(value, message): progress(value, desc=message) success, message = model_handler.load_model(progress_callback) if success: return ( gr.update(visible=False), # Hide load button gr.update(visible=True), # Show generation interface gr.update(value=message, visible=True), # Show success message gr.update(visible=False) # Hide error message ) else: return ( gr.update(visible=True), # Keep load button visible gr.update(visible=False), # Keep generation interface hidden gr.update(visible=False), # Hide success message gr.update(value=message, visible=True) # Show error message ) def generate_video_interface( prompt: str, negative_prompt: str, width: int, height: int, num_frames: int, num_inference_steps: int, guidance_scale: float, seed: Optional[int], progress=gr.Progress() ): """Interface function for video generation""" def progress_callback(value, message): progress(value, desc=message) # Plan the prompt: treat the user input as a high‑level concept and let the # agent craft a refined prompt and recommended negative prompt. If the user # supplies a negative prompt, it overrides the recommended negative prompt. 

    # Plan the prompt: treat the user input as a high-level concept and let the
    # agent craft a refined prompt and a recommended negative prompt. If the
    # user supplies a negative prompt, it overrides the recommended one.
    plan = plan_from_topic(prompt)

    # Use the refined prompt from the plan
    effective_prompt = plan.prompt

    # If the user provided a negative prompt, use it; otherwise fall back to
    # the planner's recommendation.
    effective_negative = (
        negative_prompt.strip()
        if negative_prompt and negative_prompt.strip()
        else plan.negative_prompt
    )
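
    # gr.Number can deliver a float (or None when the field is cleared), while
    # this handler is typed for Optional[int]. Casting here is a defensive
    # sketch; it assumes model_handler.generate_video expects an int or None.
    if seed is not None:
        seed = int(seed)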

    success, video_path, error_msg, gen_info = model_handler.generate_video(
        prompt=effective_prompt,
        negative_prompt=effective_negative,
        width=width,
        height=height,
        num_frames=num_frames,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        progress_callback=progress_callback
    )

    if success:
        return (
            gr.update(value=video_path, visible=True),  # Video output
            gr.update(value=gen_info, visible=True),    # Generation info
            gr.update(visible=False)                    # Hide error message
        )
    else:
        return (
            gr.update(value=None, visible=False),       # Hide video output
            gr.update(visible=False),                   # Hide generation info
            gr.update(value=error_msg, visible=True)    # Show error message
        )


def create_interface():
    """Create the Gradio interface"""
    with gr.Blocks(
        title=UI_CONFIG["title"],
        theme=UI_CONFIG["theme"]
    ) as demo:
        # Header
        gr.Markdown(f"# {UI_CONFIG['title']}")
        gr.Markdown(UI_CONFIG["description"])

        # Model loading section
        with gr.Row():
            with gr.Column():
                load_btn = gr.Button(
                    "🚀 Load Video Generation Model",
                    variant="primary",
                    size="lg"
                )
                load_success_msg = gr.Markdown(visible=False)
                load_error_msg = gr.Markdown(visible=False)

        # Main generation interface (initially hidden)
        with gr.Column(visible=False) as generation_interface:
            # Input section
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Group():
                        gr.Markdown("### 📝 Concept & Prompts")
                        # The user supplies a high-level concept or topic. The
                        # agent refines it into a detailed prompt automatically.
                        prompt_input = gr.Textbox(
                            label="Video Concept",
                            placeholder="Describe the concept you want to generate, e.g. 'a pig in a winter forest'...",
                            lines=3,
                            value="a pig moving quickly in a beautiful winter scenery nature trees sunset tracking camera"
                        )
                        # Optional negative prompt: overrides the agent's
                        # recommended negative prompt.
                        negative_prompt_input = gr.Textbox(
                            label="Negative Prompt (Optional)",
                            placeholder="Things you don't want in the video; leave empty to use the agent's recommendation...",
                            lines=2,
                            value=""
                        )

                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown("### ⚙️ Generation Parameters")

                        with gr.Row():
                            width_slider = gr.Slider(
                                label="Width",
                                minimum=64,
                                maximum=1920,
                                step=8,
                                value=DEFAULT_PARAMS["width"]
                            )
                            height_slider = gr.Slider(
                                label="Height",
                                minimum=64,
                                maximum=1080,
                                step=8,
                                value=DEFAULT_PARAMS["height"]
                            )

                        num_frames_slider = gr.Slider(
                            label="Number of Frames",
                            minimum=1,
                            maximum=200,
                            step=1,
                            value=DEFAULT_PARAMS["num_frames"]
                        )

                        inference_steps_slider = gr.Slider(
                            label="Inference Steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=DEFAULT_PARAMS["num_inference_steps"]
                        )

                        guidance_scale_slider = gr.Slider(
                            label="Guidance Scale",
                            minimum=0.0,
                            maximum=20.0,
                            step=0.1,
                            value=DEFAULT_PARAMS["guidance_scale"]
                        )

                        seed_input = gr.Number(
                            label="Seed (Optional)",
                            value=0,
                            precision=0
                        )

            # Generation button
            with gr.Row():
                generate_btn = gr.Button(
                    "🎬 Generate Video",
                    variant="primary",
                    size="lg"
                )

            # Output section
            with gr.Row():
                with gr.Column():
                    video_output = gr.Video(
                        label="Generated Video",
                        visible=False
                    )
                    generation_info = gr.Markdown(
                        label="Generation Information",
                        visible=False
                    )
                    generation_error = gr.Markdown(
                        visible=False
                    )

            # Additional controls
            with gr.Row():
                with gr.Column():
                    gr.Markdown("""
                    ### 💡 Tips:
                    - Enter a short **concept** (e.g. “a busy city street at dawn”). The agent will expand it into a detailed prompt.
                    - Adjust the **guidance scale**: higher values make the video adhere more closely to the refined prompt.
                    - Increasing **inference steps** improves quality at the cost of generation time.
                    - Use the optional **Negative Prompt** field only if you want to override the agent's recommended terms.
                    - Keep width and height multiples of 8 for optimal performance.
                    """)

                with gr.Column():
                    if torch.cuda.is_available():
                        gpu_info = f"🎮 GPU: {torch.cuda.get_device_name()}"
                    else:
                        gpu_info = "💻 Running on CPU"

                    gr.Markdown(f"""
                    ### 🖥️ System Information:
                    {gpu_info}

                    ### 📊 Model Information:
                    - **Model:** WAN-VACE 1.3B (Q4_0 Quantized)
                    - **Text Encoder:** UMT5-XXL
                    - **Scheduler:** UniPC Multistep

                    ### 🤖 Agent Details:
                    - **Planning:** The agent automatically crafts a detailed prompt and a recommended negative prompt based on your concept.
                    - **Override:** Supply your own negative prompt to override the recommendation if desired.
                    """)

        # Event handlers
        load_btn.click(
            fn=load_model_interface,
            outputs=[
                load_btn,
                generation_interface,
                load_success_msg,
                load_error_msg
            ]
        )

        generate_btn.click(
            fn=generate_video_interface,
            inputs=[
                prompt_input,
                negative_prompt_input,
                width_slider,
                height_slider,
                num_frames_slider,
                inference_steps_slider,
                guidance_scale_slider,
                seed_input
            ],
            outputs=[
                video_output,
                generation_info,
                generation_error
            ]
        )

    return demo


def main():
    """Main function to launch the application"""
    print(f"🚀 Starting {UI_CONFIG['title']}...")
    print(f"🔧 Server configuration: {SERVER_CONFIG['host']}:{SERVER_CONFIG['port']}")

    # Check GPU availability
    if torch.cuda.is_available():
        print(f"🎮 GPU detected: {torch.cuda.get_device_name()}")
        print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
    else:
        print("💻 Running on CPU (GPU recommended for better performance)")

    # Create the interface and enable the event queue to support multiple users.
    demo = create_interface()

    # Hugging Face Spaces expect `.queue()` to be called for handling request
    # concurrency. Limiting concurrency_count to 1 helps prevent excessive
    # memory usage on CPU-only hardware.
    demo = demo.queue(concurrency_count=1)

    # Launch the interface.
    demo.launch(
        server_name=SERVER_CONFIG["host"],
        server_port=SERVER_CONFIG["port"],
        share=SERVER_CONFIG["share"],
        show_error=True,
        show_tips=True
    )


if __name__ == "__main__":
    main()