| """ | |
| RAPO++ Text-to-Video Prompt Optimization Demo | |
| This demo showcases Stage 1 (RAPO): Retrieval-Augmented Prompt Optimization | |
| It demonstrates how simple prompts can be enriched with contextually relevant modifiers | |
| retrieved from a knowledge graph for better text-to-video generation. | |
| """ | |
# CRITICAL: Import spaces FIRST, before any CUDA-related packages
import spaces
import gradio as gr
import torch
from sentence_transformers import SentenceTransformer
from torch.nn.functional import cosine_similarity
import networkx as nx
import os
import random
from huggingface_hub import snapshot_download
# =============================================================================
# Model and Data Setup (runs once at startup)
# =============================================================================
print("=" * 60)
print("Setting up RAPO++ Demo...")
print("=" * 60)

# Create necessary directories
os.makedirs("./ckpt", exist_ok=True)
os.makedirs("./relation_graph/graph_data", exist_ok=True)
# Download the SentenceTransformer model used for embeddings
SENTENCE_TRANSFORMER_PATH = "./ckpt/all-MiniLM-L6-v2"
if not os.path.exists(SENTENCE_TRANSFORMER_PATH):
    print("Downloading SentenceTransformer model...")
    # local_dir_use_symlinks is deprecated (and ignored) in recent
    # huggingface_hub releases, so it is omitted here.
    snapshot_download(
        repo_id="sentence-transformers/all-MiniLM-L6-v2",
        local_dir=SENTENCE_TRANSFORMER_PATH
    )
    print("✓ SentenceTransformer downloaded")
else:
    print("✓ SentenceTransformer already cached")
# Load the SentenceTransformer model
print("Loading SentenceTransformer model...")
embedding_model = SentenceTransformer(SENTENCE_TRANSFORMER_PATH)
print("✓ Model loaded")
# =============================================================================
# Simple Demo Graph Creation (the full graph data would require a large download)
# =============================================================================
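# The demo graphs are star-shaped: each place node connects directly to its
# modifier nodes, e.g. "forest" -- "walking through" in G_place_verb and
# "forest" -- "dense trees" in G_place_scene. Retrieval later walks from a
# matched place to its modifier neighbors.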
def create_demo_graph():
    """Create a simplified demo graph with common T2V generation concepts."""
    # Sample place-verb and place-scene graphs
    G_place_verb = nx.Graph()
    G_place_scene = nx.Graph()

    # Places (central nodes)
    places = [
        "forest", "beach", "city street", "mountain", "room", "park",
        "studio", "kitchen", "bridge", "parking lot", "desert", "lake"
    ]

    # Verbs/actions for each place
    place_verbs = {
        "forest": ["walking through", "hiking in", "exploring", "camping in", "running through"],
        "beach": ["walking on", "swimming at", "surfing at", "relaxing on", "playing on"],
        "city street": ["walking down", "driving through", "running along", "biking through"],
        "mountain": ["climbing", "hiking up", "descending", "exploring", "camping on"],
        "room": ["sitting in", "working in", "relaxing in", "reading in", "sleeping in"],
        "park": ["walking in", "playing in", "jogging through", "sitting in", "picnicking in"],
        "studio": ["working in", "dancing in", "recording in", "practicing in"],
        "kitchen": ["cooking in", "preparing food in", "baking in", "cleaning"],
        "bridge": ["walking across", "driving across", "standing on", "running across"],
        "parking lot": ["standing in", "walking through", "driving in", "parking in"],
        "desert": ["walking through", "driving through", "camping in", "exploring"],
        "lake": ["swimming in", "boating on", "fishing at", "relaxing by"]
    }

    # Scenarios/atmospheres for each place
    place_scenes = {
        "forest": ["dense trees", "peaceful atmosphere", "natural setting", "quiet surroundings"],
        "beach": ["ocean waves", "sunny day", "sandy shore", "coastal view"],
        "city street": ["busy traffic", "urban environment", "city lights", "crowded sidewalk"],
        "mountain": ["scenic view", "high altitude", "rocky terrain", "mountain peak"],
        "room": ["indoor setting", "comfortable space", "quiet environment", "cozy atmosphere"],
        "park": ["green grass", "open space", "trees around", "peaceful setting"],
        "studio": ["professional lighting", "indoor space", "creative environment"],
        "kitchen": ["modern appliances", "cooking area", "indoor setting", "bright lighting"],
        "bridge": ["elevated view", "water below", "connecting path", "architectural structure"],
        "parking lot": ["outdoor area", "vehicles around", "paved surface", "open space"],
        "desert": ["sandy terrain", "hot climate", "barren landscape", "vast expanse"],
        "lake": ["calm water", "natural scenery", "peaceful setting", "reflection on water"]
    }

    # Build the graphs
    for place in places:
        # Place-verb connections
        for verb in place_verbs.get(place, []):
            G_place_verb.add_edge(place, verb)
        # Place-scene connections
        for scene in place_scenes.get(place, []):
            G_place_scene.add_edge(place, scene)

    # Embeddings for all places
    place_embeddings = embedding_model.encode(places)

    # Lookup dictionaries
    place_to_idx = {place: idx for idx, place in enumerate(places)}
    idx_to_place = {idx: place for place, idx in place_to_idx.items()}

    return G_place_verb, G_place_scene, place_embeddings, place_to_idx, idx_to_place
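# For example, list(G_place_verb.neighbors("forest")) yields the five verbs
# defined above, such as "walking through" and "hiking in".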
# Initialize the demo graph
print("Creating demo knowledge graph...")
G_place_verb, G_place_scene, place_embeddings, place_to_idx, idx_to_place = create_demo_graph()
print("✓ Demo graph created")

print("=" * 60)
print("✓ Setup complete!")
print("=" * 60)
# =============================================================================
# Core RAPO Functions
# =============================================================================
def retrieve_and_augment_prompt(prompt: str, place_num: int = 2, modifier_num: int = 5) -> tuple:
    """
    Main RAPO function: retrieves relevant modifiers from the graph and augments the prompt.

    Args:
        prompt: Input text-to-video generation prompt
        place_num: Number of top places to retrieve
        modifier_num: Number of modifiers to sample per place

    Returns:
        Tuple of (augmented_prompt, retrieved_info, places_found)
    """
    # Encode the input prompt
    prompt_embedding = embedding_model.encode(prompt)

    # Compute similarity with all places
    similarities = cosine_similarity(
        torch.tensor(prompt_embedding).unsqueeze(0),
        torch.tensor(place_embeddings)
    )
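    # Shape note: the unsqueezed prompt embedding is (1, 384) and
    # place_embeddings is (num_places, 384) for all-MiniLM-L6-v2;
    # cosine_similarity broadcasts along dim=1 and returns a
    # (num_places,) score vector.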
    # Get the top-K most similar places
    top_indices = torch.topk(similarities, min(place_num, len(place_to_idx))).indices

    # Retrieve modifiers from the graphs
    retrieved_verbs = []
    retrieved_scenes = []
    places_found = []
    for idx in top_indices.numpy().tolist():
        place = idx_to_place[idx]
        places_found.append(place)

        # Verb neighbors
        verb_neighbors = list(G_place_verb.neighbors(place))
        verb_samples = random.sample(verb_neighbors, min(modifier_num, len(verb_neighbors)))
        retrieved_verbs.extend(verb_samples)

        # Scene neighbors
        scene_neighbors = list(G_place_scene.neighbors(place))
        scene_samples = random.sample(scene_neighbors, min(modifier_num, len(scene_neighbors)))
        retrieved_scenes.extend(scene_samples)

    # Remove duplicates while preserving order (dict.fromkeys keeps first occurrences)
    retrieved_verbs = list(dict.fromkeys(retrieved_verbs))
    retrieved_scenes = list(dict.fromkeys(retrieved_scenes))
    # Create the augmented prompt (simple version: append contextual details)
    augmented_parts = [prompt.strip()]

    # Add the most relevant modifiers
    if retrieved_verbs:
        augmented_parts.append(f"The scene shows {retrieved_verbs[0]}")
    if retrieved_scenes:
        augmented_parts.append(f"with {retrieved_scenes[0]}")
    augmented_prompt = ", ".join(augmented_parts) + "."
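    # Illustrative result of the join above (actual modifiers vary per run):
    #   "A person walking, The scene shows walking in, with green grass."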
    # Format retrieved info for display
    retrieved_info = {
        "Places": places_found,
        "Actions": retrieved_verbs[:5],
        "Atmosphere": retrieved_scenes[:5]
    }

    return augmented_prompt, retrieved_info, places_found
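# Illustrative call (hypothetical output; modifiers are sampled at random):
#   retrieve_and_augment_prompt("A person walking", place_num=2, modifier_num=5)
#   -> (augmented prompt string,
#       {"Places": [...], "Actions": [...], "Atmosphere": [...]},
#       list of matched places)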
# =============================================================================
# Gradio Interface
# =============================================================================
def process_prompt(prompt, place_num, modifier_num):
    """Process a prompt and return results for Gradio."""
    if not prompt.strip():
        # Must return one value per output component wired to this callback
        return "Please enter a prompt.", ""

    try:
        augmented_prompt, retrieved_info, places = retrieve_and_augment_prompt(
            prompt, place_num, modifier_num
        )

        # Format retrieved info for display
        info_text = "**Retrieved Modifiers:**\n\n"
        info_text += f"**📍 Top Places:** {', '.join(places)}\n\n"
        info_text += f"**🎬 Actions:** {', '.join(retrieved_info['Actions'])}\n\n"
        info_text += f"**🌍 Atmosphere:** {', '.join(retrieved_info['Atmosphere'])}\n\n"

        return augmented_prompt, info_text
    except Exception as e:
        return f"Error: {str(e)}", ""
# Create the Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="blue"
    ),
    title="RAPO++ Text-to-Video Prompt Optimization"
) as demo:
    gr.Markdown("""
    # 🎬 RAPO++ Text-to-Video Prompt Optimization

    This demo showcases **Stage 1 (RAPO)**: Retrieval-Augmented Prompt Optimization using knowledge graphs.

    **How it works:**
    1. Enter a simple text-to-video prompt
    2. The system retrieves contextually relevant modifiers from a knowledge graph
    3. Your prompt is enhanced with specific actions and atmospheric details
    4. Use the optimized prompt for better T2V generation results!

    **Example prompts to try:**
    - "A person walking"
    - "A car driving"
    - "Someone cooking"
    - "A group of people talking"

    Based on the paper: [RAPO++ (arXiv:2510.20206)](https://arxiv.org/abs/2510.20206)
    """)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            input_prompt = gr.Textbox(
                label="Original Prompt",
                placeholder="Enter your text-to-video prompt (e.g., 'A person walking')",
                lines=3
            )

            with gr.Accordion("Advanced Settings", open=False):
                place_num = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of Places to Retrieve",
                    info="How many related places to search in the knowledge graph"
                )
                modifier_num = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Modifiers per Place",
                    info="How many modifiers to sample from each place"
                )

            process_btn = gr.Button("✨ Optimize Prompt", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_prompt = gr.Textbox(
                label="Optimized Prompt",
                lines=5,
                show_copy_button=True
            )
            retrieved_info = gr.Markdown(
                label="Retrieved Information"
            )
    # Example prompts
    gr.Examples(
        examples=[
            ["A person walking", 2, 5],
            ["A car driving at night", 2, 5],
            ["Someone cooking in a kitchen", 2, 5],
            ["A group of people talking", 2, 5],
            ["A bird flying", 2, 5],
            ["Someone sitting and reading", 2, 5],
        ],
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info],
        fn=process_prompt,
        cache_examples=False
    )
| gr.Markdown(""" | |
| --- | |
| ### About RAPO++ | |
| RAPO++ is a three-stage framework for text-to-video generation prompt optimization: | |
| - **Stage 1 (RAPO)**: Retrieval-Augmented Prompt Optimization using relation graphs *(demonstrated here)* | |
| - **Stage 2 (SSPO)**: Self-Supervised Prompt Optimization with test-time iterative refinement | |
| - **Stage 3**: LLM fine-tuning on collected feedback data | |
| The system is model-agnostic and works with various T2V models (Wan2.1, Open-Sora-Plan, HunyuanVideo, etc.). | |
| **Papers:** | |
| - [RAPO (CVPR 2025)](https://arxiv.org/abs/2502.07516): The Devil is in the Prompts: Retrieval-Augmented Prompt Optimization for Text-to-Video Generation | |
| - [RAPO++ (arXiv:2510.20206)](https://arxiv.org/abs/2510.20206): Cross-Stage Prompt Optimization for Text-to-Video Generation via Data Alignment and Test-Time Scaling | |
| **Project Page:** [https://whynothaha.github.io/RAPO_plus_github/](https://whynothaha.github.io/RAPO_plus_github/) | |
| **GitHub:** [https://github.com/Vchitect/RAPO](https://github.com/Vchitect/RAPO) | |
| """) | |
    # Event handlers
    process_btn.click(
        fn=process_prompt,
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info]
    )
    input_prompt.submit(
        fn=process_prompt,
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info]
    )
# Launch the app
if __name__ == "__main__":
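    # Optional sanity check: exercise the retrieval pipeline once before
    # starting the UI (output varies because modifiers are sampled at random).
    sample_out, _, _ = retrieve_and_augment_prompt("A person walking")
    print(f"Sanity check -> {sample_out}")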
    demo.launch()