""" RAPO++ Text-to-Video Prompt Optimization Demo This demo showcases Stage 1 (RAPO): Retrieval-Augmented Prompt Optimization It demonstrates how simple prompts can be enriched with contextually relevant modifiers retrieved from a knowledge graph for better text-to-video generation. """ # CRITICAL: Import spaces FIRST before any CUDA-related packages import spaces import gradio as gr import torch from sentence_transformers import SentenceTransformer from torch.nn.functional import cosine_similarity import networkx as nx import json import os import random from huggingface_hub import snapshot_download, hf_hub_download # ============================================================================= # Model and Data Setup (runs once at startup) # ============================================================================= print("=" * 60) print("Setting up RAPO++ Demo...") print("=" * 60) # Create necessary directories os.makedirs("./ckpt", exist_ok=True) os.makedirs("./relation_graph/graph_data", exist_ok=True) # Download SentenceTransformer model for embeddings SENTENCE_TRANSFORMER_PATH = "./ckpt/all-MiniLM-L6-v2" if not os.path.exists(SENTENCE_TRANSFORMER_PATH): print("Downloading SentenceTransformer model...") snapshot_download( repo_id="sentence-transformers/all-MiniLM-L6-v2", local_dir=SENTENCE_TRANSFORMER_PATH, local_dir_use_symlinks=False ) print("✓ SentenceTransformer downloaded") else: print("✓ SentenceTransformer already cached") # Load SentenceTransformer model print("Loading SentenceTransformer model...") embedding_model = SentenceTransformer(SENTENCE_TRANSFORMER_PATH) print("✓ Model loaded") # ============================================================================= # Simple Demo Graph Creation (since full graph data requires large download) # ============================================================================= def create_demo_graph(): """Create a simplified demo graph with common T2V generation concepts""" # Create sample place-verb and place-scene graphs G_place_verb = nx.Graph() G_place_scene = nx.Graph() # Define places (central nodes) places = [ "forest", "beach", "city street", "mountain", "room", "park", "studio", "kitchen", "bridge", "parking lot", "desert", "lake" ] # Define verbs/actions for each place place_verbs = { "forest": ["walking through", "hiking in", "exploring", "camping in", "running through"], "beach": ["walking on", "swimming at", "surfing at", "relaxing on", "playing on"], "city street": ["walking down", "driving through", "running along", "biking through"], "mountain": ["climbing", "hiking up", "descending", "exploring", "camping on"], "room": ["sitting in", "working in", "relaxing in", "reading in", "sleeping in"], "park": ["walking in", "playing in", "jogging through", "sitting in", "picnicking in"], "studio": ["working in", "dancing in", "recording in", "practicing in"], "kitchen": ["cooking in", "preparing food in", "baking in", "cleaning"], "bridge": ["walking across", "driving across", "standing on", "running across"], "parking lot": ["standing in", "walking through", "driving in", "parking in"], "desert": ["walking through", "driving through", "camping in", "exploring"], "lake": ["swimming in", "boating on", "fishing at", "relaxing by"] } # Define scenarios/atmospheres for each place place_scenes = { "forest": ["dense trees", "peaceful atmosphere", "natural setting", "quiet surroundings"], "beach": ["ocean waves", "sunny day", "sandy shore", "coastal view"], "city street": ["busy traffic", "urban environment", "city lights", "crowded 
sidewalk"], "mountain": ["scenic view", "high altitude", "rocky terrain", "mountain peak"], "room": ["indoor setting", "comfortable space", "quiet environment", "cozy atmosphere"], "park": ["green grass", "open space", "trees around", "peaceful setting"], "studio": ["professional lighting", "indoor space", "creative environment"], "kitchen": ["modern appliances", "cooking area", "indoor setting", "bright lighting"], "bridge": ["elevated view", "water below", "connecting path", "architectural structure"], "parking lot": ["outdoor area", "vehicles around", "paved surface", "open space"], "desert": ["sandy terrain", "hot climate", "barren landscape", "vast expanse"], "lake": ["calm water", "natural scenery", "peaceful setting", "reflection on water"] } # Build graphs for place in places: # Add place-verb connections for verb in place_verbs.get(place, []): G_place_verb.add_edge(place, verb) # Add place-scene connections for scene in place_scenes.get(place, []): G_place_scene.add_edge(place, scene) # Create embeddings for all places place_embeddings = embedding_model.encode(places) # Create lookup dictionaries place_to_idx = {place: idx for idx, place in enumerate(places)} idx_to_place = {idx: place for place, idx in place_to_idx.items()} return G_place_verb, G_place_scene, place_embeddings, place_to_idx, idx_to_place # Initialize demo graph print("Creating demo knowledge graph...") G_place_verb, G_place_scene, place_embeddings, place_to_idx, idx_to_place = create_demo_graph() print("✓ Demo graph created") print("=" * 60) print("✓ Setup complete!") print("=" * 60) # ============================================================================= # Core RAPO Functions # ============================================================================= @spaces.GPU def retrieve_and_augment_prompt(prompt: str, place_num: int = 2, modifier_num: int = 5) -> tuple: """ Main RAPO function: Retrieves relevant modifiers from the graph and augments the prompt. 

# =============================================================================
# Core RAPO Functions
# =============================================================================
@spaces.GPU
def retrieve_and_augment_prompt(prompt: str, place_num: int = 2, modifier_num: int = 5) -> tuple:
    """
    Main RAPO function: retrieves relevant modifiers from the graph and
    augments the prompt.

    Args:
        prompt: Input text-to-video generation prompt
        place_num: Number of top places to retrieve
        modifier_num: Number of modifiers to sample per place

    Returns:
        Tuple of (augmented_prompt, retrieved_info, places_found)
    """
    # Gradio sliders may pass floats; topk and random.sample need integer counts
    place_num = int(place_num)
    modifier_num = int(modifier_num)

    # Encode input prompt
    prompt_embedding = embedding_model.encode(prompt)

    # Compute similarity with all places
    similarities = cosine_similarity(
        torch.tensor(prompt_embedding).unsqueeze(0),
        torch.tensor(place_embeddings)
    )

    # Get top-K most similar places
    top_indices = torch.topk(similarities, min(place_num, len(place_to_idx))).indices

    # Retrieve modifiers from graph
    retrieved_verbs = []
    retrieved_scenes = []
    places_found = []

    for idx in top_indices.numpy().tolist():
        place = idx_to_place[idx]
        places_found.append(place)

        # Get verb neighbors
        verb_neighbors = list(G_place_verb.neighbors(place))
        verb_samples = random.sample(verb_neighbors, min(modifier_num, len(verb_neighbors)))
        retrieved_verbs.extend(verb_samples)

        # Get scene neighbors
        scene_neighbors = list(G_place_scene.neighbors(place))
        scene_samples = random.sample(scene_neighbors, min(modifier_num, len(scene_neighbors)))
        retrieved_scenes.extend(scene_samples)

    # Remove duplicates while preserving order
    retrieved_verbs = list(dict.fromkeys(retrieved_verbs))
    retrieved_scenes = list(dict.fromkeys(retrieved_scenes))

    # Create augmented prompt (simple version - just add contextual details)
    augmented_parts = [prompt.strip()]

    # Add most relevant modifiers
    if retrieved_verbs:
        augmented_parts.append(f"The scene shows {retrieved_verbs[0]}")
    if retrieved_scenes:
        augmented_parts.append(f"with {retrieved_scenes[0]}")

    augmented_prompt = ", ".join(augmented_parts) + "."

    # Format retrieved info for display
    retrieved_info = {
        "Places": places_found,
        "Actions": retrieved_verbs[:5],
        "Atmosphere": retrieved_scenes[:5]
    }

    return augmented_prompt, retrieved_info, places_found
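
# Illustrative example (output varies run to run, since modifiers are randomly
# sampled from the retrieved places' neighbors):
#   optimized, info, places = retrieve_and_augment_prompt("A person walking", place_num=2)
#   # places might be, e.g., ["park", "city street"], and the optimized prompt
#   # would then read something like:
#   # "A person walking, The scene shows walking in, with green grass."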


# =============================================================================
# Gradio Interface
# =============================================================================
def process_prompt(prompt, place_num, modifier_num):
    """Process prompt and return results for Gradio"""
    if not prompt.strip():
        return "Please enter a prompt.", ""

    try:
        augmented_prompt, retrieved_info, places = retrieve_and_augment_prompt(
            prompt, place_num, modifier_num
        )

        # Format retrieved info for display
        info_text = "**Retrieved Modifiers:**\n\n"
        info_text += f"**📍 Top Places:** {', '.join(places)}\n\n"
        info_text += f"**🎬 Actions:** {', '.join(retrieved_info['Actions'])}\n\n"
        info_text += f"**🌅 Atmosphere:** {', '.join(retrieved_info['Atmosphere'])}\n\n"

        return augmented_prompt, info_text
    except Exception as e:
        return f"Error: {str(e)}", ""


# Create Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="blue"
    ),
    title="RAPO++ Text-to-Video Prompt Optimization"
) as demo:
    gr.Markdown("""
    # 🎬 RAPO++ Text-to-Video Prompt Optimization

    This demo showcases **Stage 1 (RAPO)**: Retrieval-Augmented Prompt Optimization using knowledge graphs.

    **How it works:**
    1. Enter a simple text-to-video prompt
    2. The system retrieves contextually relevant modifiers from a knowledge graph
    3. Your prompt is enhanced with specific actions and atmospheric details
    4. Use the optimized prompt for better T2V generation results!

    **Example prompts to try:**
    - "A person walking"
    - "A car driving"
    - "Someone cooking"
    - "A group of people talking"

    Based on the paper: [RAPO++ (arXiv:2510.20206)](https://arxiv.org/abs/2510.20206)
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            input_prompt = gr.Textbox(
                label="Original Prompt",
                placeholder="Enter your text-to-video prompt (e.g., 'A person walking')",
                lines=3
            )

            with gr.Accordion("Advanced Settings", open=False):
                place_num = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of Places to Retrieve",
                    info="How many related places to search in the knowledge graph"
                )
                modifier_num = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Modifiers per Place",
                    info="How many modifiers to sample from each place"
                )

            process_btn = gr.Button("✨ Optimize Prompt", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_prompt = gr.Textbox(
                label="Optimized Prompt",
                lines=5,
                show_copy_button=True
            )
            retrieved_info = gr.Markdown(
                label="Retrieved Information"
            )

    # Example prompts
    gr.Examples(
        examples=[
            ["A person walking", 2, 5],
            ["A car driving at night", 2, 5],
            ["Someone cooking in a kitchen", 2, 5],
            ["A group of people talking", 2, 5],
            ["A bird flying", 2, 5],
            ["Someone sitting and reading", 2, 5],
        ],
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info],
        fn=process_prompt,
        cache_examples=False
    )

    gr.Markdown("""
    ---
    ### About RAPO++

    RAPO++ is a three-stage framework for text-to-video generation prompt optimization:

    - **Stage 1 (RAPO)**: Retrieval-Augmented Prompt Optimization using relation graphs *(demonstrated here)*
    - **Stage 2 (SSPO)**: Self-Supervised Prompt Optimization with test-time iterative refinement
    - **Stage 3**: LLM fine-tuning on collected feedback data

    The system is model-agnostic and works with various T2V models (Wan2.1, Open-Sora-Plan, HunyuanVideo, etc.).

    **Papers:**
    - [RAPO (CVPR 2025)](https://arxiv.org/abs/2502.07516): The Devil is in the Prompts: Retrieval-Augmented Prompt Optimization for Text-to-Video Generation
    - [RAPO++ (arXiv:2510.20206)](https://arxiv.org/abs/2510.20206): Cross-Stage Prompt Optimization for Text-to-Video Generation via Data Alignment and Test-Time Scaling

    **Project Page:** [https://whynothaha.github.io/RAPO_plus_github/](https://whynothaha.github.io/RAPO_plus_github/)

    **GitHub:** [https://github.com/Vchitect/RAPO](https://github.com/Vchitect/RAPO)
    """)

    # Event handlers
    process_btn.click(
        fn=process_prompt,
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info]
    )
    input_prompt.submit(
        fn=process_prompt,
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()