"""
RAPO++ Text-to-Video Prompt Optimization Demo
This demo showcases Stage 1 (RAPO): Retrieval-Augmented Prompt Optimization.
It demonstrates how simple prompts can be enriched with contextually relevant modifiers
retrieved from a knowledge graph for better text-to-video generation.
"""
# CRITICAL: Import spaces FIRST before any CUDA-related packages
import spaces
import gradio as gr
import torch
from sentence_transformers import SentenceTransformer
from torch.nn.functional import cosine_similarity
import networkx as nx
import json
import os
import random
from huggingface_hub import snapshot_download, hf_hub_download
# =============================================================================
# Model and Data Setup (runs once at startup)
# =============================================================================
print("=" * 60)
print("Setting up RAPO++ Demo...")
print("=" * 60)
# Create necessary directories
os.makedirs("./ckpt", exist_ok=True)
os.makedirs("./relation_graph/graph_data", exist_ok=True)
# Download SentenceTransformer model for embeddings
SENTENCE_TRANSFORMER_PATH = "./ckpt/all-MiniLM-L6-v2"
if not os.path.exists(SENTENCE_TRANSFORMER_PATH):
    print("Downloading SentenceTransformer model...")
    snapshot_download(
        repo_id="sentence-transformers/all-MiniLM-L6-v2",
        local_dir=SENTENCE_TRANSFORMER_PATH,
        local_dir_use_symlinks=False
    )
    print("✓ SentenceTransformer downloaded")
else:
    print("✓ SentenceTransformer already cached")
# Load SentenceTransformer model
print("Loading SentenceTransformer model...")
embedding_model = SentenceTransformer(SENTENCE_TRANSFORMER_PATH)
print("β Model loaded")
# =============================================================================
# Simple Demo Graph Creation (since full graph data requires large download)
# =============================================================================
def create_demo_graph():
"""Create a simplified demo graph with common T2V generation concepts"""
# Create sample place-verb and place-scene graphs
G_place_verb = nx.Graph()
G_place_scene = nx.Graph()
# Define places (central nodes)
places = [
"forest", "beach", "city street", "mountain", "room", "park",
"studio", "kitchen", "bridge", "parking lot", "desert", "lake"
]
# Define verbs/actions for each place
place_verbs = {
"forest": ["walking through", "hiking in", "exploring", "camping in", "running through"],
"beach": ["walking on", "swimming at", "surfing at", "relaxing on", "playing on"],
"city street": ["walking down", "driving through", "running along", "biking through"],
"mountain": ["climbing", "hiking up", "descending", "exploring", "camping on"],
"room": ["sitting in", "working in", "relaxing in", "reading in", "sleeping in"],
"park": ["walking in", "playing in", "jogging through", "sitting in", "picnicking in"],
"studio": ["working in", "dancing in", "recording in", "practicing in"],
"kitchen": ["cooking in", "preparing food in", "baking in", "cleaning"],
"bridge": ["walking across", "driving across", "standing on", "running across"],
"parking lot": ["standing in", "walking through", "driving in", "parking in"],
"desert": ["walking through", "driving through", "camping in", "exploring"],
"lake": ["swimming in", "boating on", "fishing at", "relaxing by"]
}
# Define scenarios/atmospheres for each place
place_scenes = {
"forest": ["dense trees", "peaceful atmosphere", "natural setting", "quiet surroundings"],
"beach": ["ocean waves", "sunny day", "sandy shore", "coastal view"],
"city street": ["busy traffic", "urban environment", "city lights", "crowded sidewalk"],
"mountain": ["scenic view", "high altitude", "rocky terrain", "mountain peak"],
"room": ["indoor setting", "comfortable space", "quiet environment", "cozy atmosphere"],
"park": ["green grass", "open space", "trees around", "peaceful setting"],
"studio": ["professional lighting", "indoor space", "creative environment"],
"kitchen": ["modern appliances", "cooking area", "indoor setting", "bright lighting"],
"bridge": ["elevated view", "water below", "connecting path", "architectural structure"],
"parking lot": ["outdoor area", "vehicles around", "paved surface", "open space"],
"desert": ["sandy terrain", "hot climate", "barren landscape", "vast expanse"],
"lake": ["calm water", "natural scenery", "peaceful setting", "reflection on water"]
}
# Build graphs
for place in places:
# Add place-verb connections
for verb in place_verbs.get(place, []):
G_place_verb.add_edge(place, verb)
# Add place-scene connections
for scene in place_scenes.get(place, []):
G_place_scene.add_edge(place, scene)
# Create embeddings for all places
place_embeddings = embedding_model.encode(places)
# Create lookup dictionaries
place_to_idx = {place: idx for idx, place in enumerate(places)}
idx_to_place = {idx: place for place, idx in place_to_idx.items()}
return G_place_verb, G_place_scene, place_embeddings, place_to_idx, idx_to_place
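# For example, with the demo graph above, list(G_place_verb.neighbors("forest"))
# returns the verb phrases attached to "forest" (e.g. "walking through", "hiking in");
# the retrieval step below samples modifiers from these neighbor lists.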
# Initialize demo graph
print("Creating demo knowledge graph...")
G_place_verb, G_place_scene, place_embeddings, place_to_idx, idx_to_place = create_demo_graph()
print("β Demo graph created")
print("=" * 60)
print("β Setup complete!")
print("=" * 60)
# =============================================================================
# Core RAPO Functions
# =============================================================================
@spaces.GPU
def retrieve_and_augment_prompt(prompt: str, place_num: int = 2, modifier_num: int = 5) -> tuple:
"""
Main RAPO function: Retrieves relevant modifiers from the graph and augments the prompt.
Args:
prompt: Input text-to-video generation prompt
place_num: Number of top places to retrieve
modifier_num: Number of modifiers to sample per place
Returns:
Tuple of (augmented_prompt, retrieved_info, places_found)
"""
# Encode input prompt
prompt_embedding = embedding_model.encode(prompt)
# Compute similarity with all places
similarities = cosine_similarity(
torch.tensor(prompt_embedding).unsqueeze(0),
torch.tensor(place_embeddings)
)
# Get top-K most similar places
top_indices = torch.topk(similarities, min(place_num, len(place_to_idx))).indices
# Retrieve modifiers from graph
retrieved_verbs = []
retrieved_scenes = []
places_found = []
for idx in top_indices.numpy().tolist():
place = idx_to_place[idx]
places_found.append(place)
# Get verb neighbors
verb_neighbors = list(G_place_verb.neighbors(place))
verb_samples = random.sample(verb_neighbors, min(modifier_num, len(verb_neighbors)))
retrieved_verbs.extend(verb_samples)
# Get scene neighbors
scene_neighbors = list(G_place_scene.neighbors(place))
scene_samples = random.sample(scene_neighbors, min(modifier_num, len(scene_neighbors)))
retrieved_scenes.extend(scene_samples)
# Remove duplicates while preserving order
retrieved_verbs = list(dict.fromkeys(retrieved_verbs))
retrieved_scenes = list(dict.fromkeys(retrieved_scenes))
# Create augmented prompt (simple version - just add contextual details)
augmented_parts = [prompt.strip()]
# Add most relevant modifiers
if retrieved_verbs:
augmented_parts.append(f"The scene shows {retrieved_verbs[0]}")
if retrieved_scenes:
augmented_parts.append(f"with {retrieved_scenes[0]}")
augmented_prompt = ", ".join(augmented_parts) + "."
# Format retrieved info for display
retrieved_info = {
"Places": places_found,
"Actions": retrieved_verbs[:5],
"Atmosphere": retrieved_scenes[:5]
}
return augmented_prompt, retrieved_info, places_found
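# A minimal usage sketch (not wired into the Gradio UI): calling the retrieval
# function directly on a sample prompt. Exact output varies from run to run
# because modifiers are randomly sampled from each place's neighbor lists.
def _example_usage():
    augmented, info, places = retrieve_and_augment_prompt(
        "A person walking", place_num=2, modifier_num=5
    )
    print("Top places:", places)
    print("Augmented prompt:", augmented)
    print("Retrieved modifiers:", json.dumps(info, indent=2))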
# =============================================================================
# Gradio Interface
# =============================================================================
def process_prompt(prompt, place_num, modifier_num):
"""Process prompt and return results for Gradio"""
if not prompt.strip():
return "Please enter a prompt.", {}, []
try:
augmented_prompt, retrieved_info, places = retrieve_and_augment_prompt(
prompt, place_num, modifier_num
)
# Format retrieved info for display
info_text = "**Retrieved Modifiers:**\n\n"
info_text += f"**π Top Places:** {', '.join(places)}\n\n"
info_text += f"**π¬ Actions:** {', '.join(retrieved_info['Actions'])}\n\n"
info_text += f"**π
Atmosphere:** {', '.join(retrieved_info['Atmosphere'])}\n\n"
return augmented_prompt, info_text
except Exception as e:
return f"Error: {str(e)}", ""
# Create Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="blue"
    ),
    title="RAPO++ Text-to-Video Prompt Optimization"
) as demo:
gr.Markdown("""
# π¬ RAPO++ Text-to-Video Prompt Optimization
This demo showcases **Stage 1 (RAPO)**: Retrieval-Augmented Prompt Optimization using knowledge graphs.
**How it works:**
1. Enter a simple text-to-video prompt
2. The system retrieves contextually relevant modifiers from a knowledge graph
3. Your prompt is enhanced with specific actions and atmospheric details
4. Use the optimized prompt for better T2V generation results!
**Example prompts to try:**
- "A person walking"
- "A car driving"
- "Someone cooking"
- "A group of people talking"
Based on the paper: [RAPO++ (arXiv:2510.20206)](https://arxiv.org/abs/2510.20206)
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            input_prompt = gr.Textbox(
                label="Original Prompt",
                placeholder="Enter your text-to-video prompt (e.g., 'A person walking')",
                lines=3
            )
            with gr.Accordion("Advanced Settings", open=False):
                place_num = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of Places to Retrieve",
                    info="How many related places to search in the knowledge graph"
                )
                modifier_num = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Modifiers per Place",
                    info="How many modifiers to sample from each place"
                )
            process_btn = gr.Button("✨ Optimize Prompt", variant="primary", size="lg")
        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_prompt = gr.Textbox(
                label="Optimized Prompt",
                lines=5,
                show_copy_button=True
            )
            retrieved_info = gr.Markdown(
                label="Retrieved Information"
            )
    # Example prompts
    gr.Examples(
        examples=[
            ["A person walking", 2, 5],
            ["A car driving at night", 2, 5],
            ["Someone cooking in a kitchen", 2, 5],
            ["A group of people talking", 2, 5],
            ["A bird flying", 2, 5],
            ["Someone sitting and reading", 2, 5],
        ],
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info],
        fn=process_prompt,
        cache_examples=False
    )
gr.Markdown("""
---
### About RAPO++
RAPO++ is a three-stage framework for text-to-video generation prompt optimization:
- **Stage 1 (RAPO)**: Retrieval-Augmented Prompt Optimization using relation graphs *(demonstrated here)*
- **Stage 2 (SSPO)**: Self-Supervised Prompt Optimization with test-time iterative refinement
- **Stage 3**: LLM fine-tuning on collected feedback data
The system is model-agnostic and works with various T2V models (Wan2.1, Open-Sora-Plan, HunyuanVideo, etc.).
**Papers:**
- [RAPO (CVPR 2025)](https://arxiv.org/abs/2502.07516): The Devil is in the Prompts: Retrieval-Augmented Prompt Optimization for Text-to-Video Generation
- [RAPO++ (arXiv:2510.20206)](https://arxiv.org/abs/2510.20206): Cross-Stage Prompt Optimization for Text-to-Video Generation via Data Alignment and Test-Time Scaling
**Project Page:** [https://whynothaha.github.io/RAPO_plus_github/](https://whynothaha.github.io/RAPO_plus_github/)
**GitHub:** [https://github.com/Vchitect/RAPO](https://github.com/Vchitect/RAPO)
""")
    # Event handlers
    process_btn.click(
        fn=process_prompt,
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info]
    )
    input_prompt.submit(
        fn=process_prompt,
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()