"""
RAPO++ Text-to-Video Prompt Optimization Demo

This demo showcases Stage 1 (RAPO): Retrieval-Augmented Prompt Optimization.
It demonstrates how simple prompts can be enriched with contextually relevant
modifiers retrieved from a knowledge graph for better text-to-video generation.
"""
# CRITICAL: Import spaces FIRST before any CUDA-related packages
import spaces
import gradio as gr
import torch
from sentence_transformers import SentenceTransformer
from torch.nn.functional import cosine_similarity
import networkx as nx
import json
import os
import random
from huggingface_hub import snapshot_download, hf_hub_download
# =============================================================================
# Model and Data Setup (runs once at startup)
# =============================================================================
print("=" * 60)
print("Setting up RAPO++ Demo...")
print("=" * 60)
# Create necessary directories
os.makedirs("./ckpt", exist_ok=True)
os.makedirs("./relation_graph/graph_data", exist_ok=True)
# Download SentenceTransformer model for embeddings
SENTENCE_TRANSFORMER_PATH = "./ckpt/all-MiniLM-L6-v2"
if not os.path.exists(SENTENCE_TRANSFORMER_PATH):
    print("Downloading SentenceTransformer model...")
    snapshot_download(
        repo_id="sentence-transformers/all-MiniLM-L6-v2",
        local_dir=SENTENCE_TRANSFORMER_PATH,
        local_dir_use_symlinks=False
    )
    print("✓ SentenceTransformer downloaded")
else:
    print("✓ SentenceTransformer already cached")
# Load SentenceTransformer model
print("Loading SentenceTransformer model...")
embedding_model = SentenceTransformer(SENTENCE_TRANSFORMER_PATH)
print("✓ Model loaded")
# =============================================================================
# Simple Demo Graph Creation (the full graph data would require a large download)
# =============================================================================
def create_demo_graph():
    """Create a simplified demo graph with common T2V generation concepts"""
    # Create sample place-verb and place-scene graphs
    G_place_verb = nx.Graph()
    G_place_scene = nx.Graph()

    # Define places (central nodes)
    places = [
        "forest", "beach", "city street", "mountain", "room", "park",
        "studio", "kitchen", "bridge", "parking lot", "desert", "lake"
    ]

    # Define verbs/actions for each place
    place_verbs = {
        "forest": ["walking through", "hiking in", "exploring", "camping in", "running through"],
        "beach": ["walking on", "swimming at", "surfing at", "relaxing on", "playing on"],
        "city street": ["walking down", "driving through", "running along", "biking through"],
        "mountain": ["climbing", "hiking up", "descending", "exploring", "camping on"],
        "room": ["sitting in", "working in", "relaxing in", "reading in", "sleeping in"],
        "park": ["walking in", "playing in", "jogging through", "sitting in", "picnicking in"],
        "studio": ["working in", "dancing in", "recording in", "practicing in"],
        "kitchen": ["cooking in", "preparing food in", "baking in", "cleaning"],
        "bridge": ["walking across", "driving across", "standing on", "running across"],
        "parking lot": ["standing in", "walking through", "driving in", "parking in"],
        "desert": ["walking through", "driving through", "camping in", "exploring"],
        "lake": ["swimming in", "boating on", "fishing at", "relaxing by"]
    }

    # Define scenarios/atmospheres for each place
    place_scenes = {
        "forest": ["dense trees", "peaceful atmosphere", "natural setting", "quiet surroundings"],
        "beach": ["ocean waves", "sunny day", "sandy shore", "coastal view"],
        "city street": ["busy traffic", "urban environment", "city lights", "crowded sidewalk"],
        "mountain": ["scenic view", "high altitude", "rocky terrain", "mountain peak"],
        "room": ["indoor setting", "comfortable space", "quiet environment", "cozy atmosphere"],
        "park": ["green grass", "open space", "trees around", "peaceful setting"],
        "studio": ["professional lighting", "indoor space", "creative environment"],
        "kitchen": ["modern appliances", "cooking area", "indoor setting", "bright lighting"],
        "bridge": ["elevated view", "water below", "connecting path", "architectural structure"],
        "parking lot": ["outdoor area", "vehicles around", "paved surface", "open space"],
        "desert": ["sandy terrain", "hot climate", "barren landscape", "vast expanse"],
        "lake": ["calm water", "natural scenery", "peaceful setting", "reflection on water"]
    }

    # Build graphs
    for place in places:
        # Add place-verb connections
        for verb in place_verbs.get(place, []):
            G_place_verb.add_edge(place, verb)
        # Add place-scene connections
        for scene in place_scenes.get(place, []):
            G_place_scene.add_edge(place, scene)

    # Create embeddings for all places
    place_embeddings = embedding_model.encode(places)

    # Create lookup dictionaries
    place_to_idx = {place: idx for idx, place in enumerate(places)}
    idx_to_place = {idx: place for place, idx in place_to_idx.items()}

    return G_place_verb, G_place_scene, place_embeddings, place_to_idx, idx_to_place
# Initialize demo graph
print("Creating demo knowledge graph...")
G_place_verb, G_place_scene, place_embeddings, place_to_idx, idx_to_place = create_demo_graph()
print("✓ Demo graph created")
print("=" * 60)
print("✓ Setup complete!")
print("=" * 60)
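
# Optional inspection helper (not part of the original RAPO code; added here only to
# illustrate the structure of the demo graphs). Given a place node, it lists the action
# and scene modifiers the retrieval step can draw from.
def describe_place(place: str) -> dict:
    """Return the modifiers attached to one place in the demo graphs."""
    return {
        "actions": sorted(G_place_verb.neighbors(place)) if place in G_place_verb else [],
        "scenes": sorted(G_place_scene.neighbors(place)) if place in G_place_scene else [],
    }

# Example: describe_place("park") returns the five verbs (e.g. "walking in", "jogging through")
# and four scene modifiers (e.g. "green grass", "open space") attached to the "park" node.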
# =============================================================================
# Core RAPO Functions
# =============================================================================
@spaces.GPU
def retrieve_and_augment_prompt(prompt: str, place_num: int = 2, modifier_num: int = 5) -> tuple:
    """
    Main RAPO function: Retrieves relevant modifiers from the graph and augments the prompt.

    Args:
        prompt: Input text-to-video generation prompt
        place_num: Number of top places to retrieve
        modifier_num: Number of modifiers to sample per place

    Returns:
        Tuple of (augmented_prompt, retrieved_info, places_found)
    """
    # Encode input prompt
    prompt_embedding = embedding_model.encode(prompt)

    # Compute similarity with all places
    similarities = cosine_similarity(
        torch.tensor(prompt_embedding).unsqueeze(0),
        torch.tensor(place_embeddings)
    )

    # Get top-K most similar places
    top_indices = torch.topk(similarities, min(place_num, len(place_to_idx))).indices

    # Retrieve modifiers from graph
    retrieved_verbs = []
    retrieved_scenes = []
    places_found = []
    for idx in top_indices.numpy().tolist():
        place = idx_to_place[idx]
        places_found.append(place)
        # Get verb neighbors
        verb_neighbors = list(G_place_verb.neighbors(place))
        verb_samples = random.sample(verb_neighbors, min(modifier_num, len(verb_neighbors)))
        retrieved_verbs.extend(verb_samples)
        # Get scene neighbors
        scene_neighbors = list(G_place_scene.neighbors(place))
        scene_samples = random.sample(scene_neighbors, min(modifier_num, len(scene_neighbors)))
        retrieved_scenes.extend(scene_samples)

    # Remove duplicates while preserving order
    retrieved_verbs = list(dict.fromkeys(retrieved_verbs))
    retrieved_scenes = list(dict.fromkeys(retrieved_scenes))

    # Create augmented prompt (simple version - just add contextual details)
    augmented_parts = [prompt.strip()]
    # Add most relevant modifiers
    if retrieved_verbs:
        augmented_parts.append(f"The scene shows {retrieved_verbs[0]}")
    if retrieved_scenes:
        augmented_parts.append(f"with {retrieved_scenes[0]}")
    augmented_prompt = ", ".join(augmented_parts) + "."

    # Format retrieved info for display
    retrieved_info = {
        "Places": places_found,
        "Actions": retrieved_verbs[:5],
        "Atmosphere": retrieved_scenes[:5]
    }

    return augmented_prompt, retrieved_info, places_found
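
# Example of calling the core function directly, outside Gradio (illustrative only;
# the sampled modifiers, and therefore the exact output, change from run to run):
#
#   augmented, info, places = retrieve_and_augment_prompt("A person walking")
#   # augmented -> e.g. "A person walking, The scene shows walking in, with green grass."
#   # info      -> {"Places": [...], "Actions": [...], "Atmosphere": [...]}
#   # places    -> e.g. ["park", "city street"]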
# =============================================================================
# Gradio Interface
# =============================================================================
def process_prompt(prompt, place_num, modifier_num):
    """Process the prompt and return results for Gradio"""
    if not prompt.strip():
        return "Please enter a prompt.", ""
    try:
        # Coerce slider values to int (Gradio sliders can deliver floats)
        augmented_prompt, retrieved_info, places = retrieve_and_augment_prompt(
            prompt, int(place_num), int(modifier_num)
        )
        # Format retrieved info for display
        info_text = "**Retrieved Modifiers:**\n\n"
        info_text += f"**📍 Top Places:** {', '.join(places)}\n\n"
        info_text += f"**🎬 Actions:** {', '.join(retrieved_info['Actions'])}\n\n"
        info_text += f"**🌅 Atmosphere:** {', '.join(retrieved_info['Atmosphere'])}\n\n"
        return augmented_prompt, info_text
    except Exception as e:
        return f"Error: {str(e)}", ""
# Create Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="blue"
    ),
    title="RAPO++ Text-to-Video Prompt Optimization"
) as demo:
    gr.Markdown("""
    # 🎬 RAPO++ Text-to-Video Prompt Optimization

    This demo showcases **Stage 1 (RAPO)**: Retrieval-Augmented Prompt Optimization using knowledge graphs.

    **How it works:**
    1. Enter a simple text-to-video prompt
    2. The system retrieves contextually relevant modifiers from a knowledge graph
    3. Your prompt is enhanced with specific actions and atmospheric details
    4. Use the optimized prompt for better T2V generation results!

    **Example prompts to try:**
    - "A person walking"
    - "A car driving"
    - "Someone cooking"
    - "A group of people talking"

    Based on the paper: [RAPO++ (arXiv:2510.20206)](https://arxiv.org/abs/2510.20206)
    """)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            input_prompt = gr.Textbox(
                label="Original Prompt",
                placeholder="Enter your text-to-video prompt (e.g., 'A person walking')",
                lines=3
            )
            with gr.Accordion("Advanced Settings", open=False):
                place_num = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of Places to Retrieve",
                    info="How many related places to search in the knowledge graph"
                )
                modifier_num = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Modifiers per Place",
                    info="How many modifiers to sample from each place"
                )
            process_btn = gr.Button("✨ Optimize Prompt", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_prompt = gr.Textbox(
                label="Optimized Prompt",
                lines=5,
                show_copy_button=True
            )
            retrieved_info = gr.Markdown(
                label="Retrieved Information"
            )
    # Example prompts
    gr.Examples(
        examples=[
            ["A person walking", 2, 5],
            ["A car driving at night", 2, 5],
            ["Someone cooking in a kitchen", 2, 5],
            ["A group of people talking", 2, 5],
            ["A bird flying", 2, 5],
            ["Someone sitting and reading", 2, 5],
        ],
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info],
        fn=process_prompt,
        cache_examples=False
    )
    gr.Markdown("""
    ---
    ### About RAPO++

    RAPO++ is a three-stage prompt-optimization framework for text-to-video generation:

    - **Stage 1 (RAPO)**: Retrieval-Augmented Prompt Optimization using relation graphs *(demonstrated here)*
    - **Stage 2 (SSPO)**: Self-Supervised Prompt Optimization with test-time iterative refinement
    - **Stage 3**: LLM fine-tuning on collected feedback data

    The system is model-agnostic and works with various T2V models (Wan2.1, Open-Sora-Plan, HunyuanVideo, etc.).

    **Papers:**
    - [RAPO (CVPR 2025)](https://arxiv.org/abs/2502.07516): The Devil is in the Prompts: Retrieval-Augmented Prompt Optimization for Text-to-Video Generation
    - [RAPO++ (arXiv:2510.20206)](https://arxiv.org/abs/2510.20206): Cross-Stage Prompt Optimization for Text-to-Video Generation via Data Alignment and Test-Time Scaling

    **Project Page:** [https://whynothaha.github.io/RAPO_plus_github/](https://whynothaha.github.io/RAPO_plus_github/)

    **GitHub:** [https://github.com/Vchitect/RAPO](https://github.com/Vchitect/RAPO)
    """)
    # Event handlers
    process_btn.click(
        fn=process_prompt,
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info]
    )
    input_prompt.submit(
        fn=process_prompt,
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()
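
# Running locally (a sketch, assuming the dependencies visible in the imports above are
# installed: gradio, torch, sentence-transformers, networkx, huggingface_hub, spaces):
#   python app.py
# then open the local URL that Gradio prints in a browser. On the Hugging Face Space,
# the app is launched automatically and @spaces.GPU handles GPU allocation.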