"""
RAPO++ Text-to-Video Prompt Optimization Demo

This demo showcases Stage 1 (RAPO): Retrieval-Augmented Prompt Optimization.
It demonstrates how simple prompts can be enriched with contextually relevant modifiers
retrieved from a knowledge graph for better text-to-video generation.
"""

# CRITICAL: Import spaces FIRST before any CUDA-related packages
import spaces

import gradio as gr
import torch
from sentence_transformers import SentenceTransformer
from torch.nn.functional import cosine_similarity
import networkx as nx
import os
import random
from huggingface_hub import snapshot_download

# =============================================================================
# Model and Data Setup (runs once at startup)
# =============================================================================

print("=" * 60)
print("Setting up RAPO++ Demo...")
print("=" * 60)

# Create necessary directories
os.makedirs("./ckpt", exist_ok=True)
os.makedirs("./relation_graph/graph_data", exist_ok=True)

# Download SentenceTransformer model for embeddings
SENTENCE_TRANSFORMER_PATH = "./ckpt/all-MiniLM-L6-v2"
if not os.path.exists(SENTENCE_TRANSFORMER_PATH):
    print("Downloading SentenceTransformer model...")
    snapshot_download(
        repo_id="sentence-transformers/all-MiniLM-L6-v2",
        local_dir=SENTENCE_TRANSFORMER_PATH
    )
    print("βœ“ SentenceTransformer downloaded")
else:
    print("βœ“ SentenceTransformer already cached")

# Load SentenceTransformer model
print("Loading SentenceTransformer model...")
embedding_model = SentenceTransformer(SENTENCE_TRANSFORMER_PATH)
print("βœ“ Model loaded")

# =============================================================================
# Simple Demo Graph Creation (since full graph data requires large download)
# =============================================================================

def create_demo_graph():
    """Create a simplified demo graph with common T2V generation concepts"""

    # Create sample place-verb and place-scene graphs
    G_place_verb = nx.Graph()
    G_place_scene = nx.Graph()

    # Define places (central nodes)
    places = [
        "forest", "beach", "city street", "mountain", "room", "park",
        "studio", "kitchen", "bridge", "parking lot", "desert", "lake"
    ]

    # Define verbs/actions for each place
    place_verbs = {
        "forest": ["walking through", "hiking in", "exploring", "camping in", "running through"],
        "beach": ["walking on", "swimming at", "surfing at", "relaxing on", "playing on"],
        "city street": ["walking down", "driving through", "running along", "biking through"],
        "mountain": ["climbing", "hiking up", "descending", "exploring", "camping on"],
        "room": ["sitting in", "working in", "relaxing in", "reading in", "sleeping in"],
        "park": ["walking in", "playing in", "jogging through", "sitting in", "picnicking in"],
        "studio": ["working in", "dancing in", "recording in", "practicing in"],
        "kitchen": ["cooking in", "preparing food in", "baking in", "cleaning"],
        "bridge": ["walking across", "driving across", "standing on", "running across"],
        "parking lot": ["standing in", "walking through", "driving in", "parking in"],
        "desert": ["walking through", "driving through", "camping in", "exploring"],
        "lake": ["swimming in", "boating on", "fishing at", "relaxing by"]
    }

    # Define scenarios/atmospheres for each place
    place_scenes = {
        "forest": ["dense trees", "peaceful atmosphere", "natural setting", "quiet surroundings"],
        "beach": ["ocean waves", "sunny day", "sandy shore", "coastal view"],
        "city street": ["busy traffic", "urban environment", "city lights", "crowded sidewalk"],
        "mountain": ["scenic view", "high altitude", "rocky terrain", "mountain peak"],
        "room": ["indoor setting", "comfortable space", "quiet environment", "cozy atmosphere"],
        "park": ["green grass", "open space", "trees around", "peaceful setting"],
        "studio": ["professional lighting", "indoor space", "creative environment"],
        "kitchen": ["modern appliances", "cooking area", "indoor setting", "bright lighting"],
        "bridge": ["elevated view", "water below", "connecting path", "architectural structure"],
        "parking lot": ["outdoor area", "vehicles around", "paved surface", "open space"],
        "desert": ["sandy terrain", "hot climate", "barren landscape", "vast expanse"],
        "lake": ["calm water", "natural scenery", "peaceful setting", "reflection on water"]
    }

    # Build graphs
    for place in places:
        # Add place-verb connections
        for verb in place_verbs.get(place, []):
            G_place_verb.add_edge(place, verb)

        # Add place-scene connections
        for scene in place_scenes.get(place, []):
            G_place_scene.add_edge(place, scene)

    # Create embeddings for all places
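    # (encode returns a NumPy array of shape (len(places), 384);
    # all-MiniLM-L6-v2 produces 384-dimensional embeddings)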
    place_embeddings = embedding_model.encode(places)

    # Create lookup dictionaries
    place_to_idx = {place: idx for idx, place in enumerate(places)}
    idx_to_place = {idx: place for place, idx in place_to_idx.items()}

    return G_place_verb, G_place_scene, place_embeddings, place_to_idx, idx_to_place

# Initialize demo graph
print("Creating demo knowledge graph...")
G_place_verb, G_place_scene, place_embeddings, place_to_idx, idx_to_place = create_demo_graph()
print("βœ“ Demo graph created")
print("=" * 60)
print("βœ“ Setup complete!")
print("=" * 60)

# =============================================================================
# Core RAPO Functions
# =============================================================================
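# Note: on ZeroGPU Spaces, the @spaces.GPU decorator requests a GPU for the
# duration of each decorated call; it is effectively a no-op on dedicated hardware.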

@spaces.GPU
def retrieve_and_augment_prompt(prompt: str, place_num: int = 2, modifier_num: int = 5) -> tuple:
    """
    Main RAPO function: Retrieves relevant modifiers from the graph and augments the prompt.

    Args:
        prompt: Input text-to-video generation prompt
        place_num: Number of top places to retrieve
        modifier_num: Number of modifiers to sample per place

    Returns:
        Tuple of (augmented_prompt, retrieved_info, places_found)
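
    Example (illustrative; modifiers are randomly sampled, so output varies):
        augmented, info, places = retrieve_and_augment_prompt("A person walking")
        # augmented -> e.g. "A person walking, the scene shows hiking in, with dense trees."
        # places    -> e.g. ["forest", "park"]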
    """
    # Encode input prompt
    prompt_embedding = embedding_model.encode(prompt)

    # Compute similarity with all places
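    # ((1, 384) prompt vector against the (num_places, 384) place matrix
    # yields one cosine score per place)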
    similarities = cosine_similarity(
        torch.tensor(prompt_embedding).unsqueeze(0),
        torch.tensor(place_embeddings)
    )

    # Get top-K most similar places
    top_indices = torch.topk(similarities, min(place_num, len(place_to_idx))).indices

    # Retrieve modifiers from graph
    retrieved_verbs = []
    retrieved_scenes = []
    places_found = []

    for idx in top_indices.numpy().tolist():
        place = idx_to_place[idx]
        places_found.append(place)

        # Get verb neighbors
        verb_neighbors = list(G_place_verb.neighbors(place))
        verb_samples = random.sample(verb_neighbors, min(modifier_num, len(verb_neighbors)))
        retrieved_verbs.extend(verb_samples)

        # Get scene neighbors
        scene_neighbors = list(G_place_scene.neighbors(place))
        scene_samples = random.sample(scene_neighbors, min(modifier_num, len(scene_neighbors)))
        retrieved_scenes.extend(scene_samples)

    # Remove duplicates while preserving order
    retrieved_verbs = list(dict.fromkeys(retrieved_verbs))
    retrieved_scenes = list(dict.fromkeys(retrieved_scenes))

    # Create augmented prompt (simple version - just add contextual details)
    augmented_parts = [prompt.strip()]

    # Add the most relevant modifiers (lowercased so the comma-joined sentence reads naturally)
    if retrieved_verbs:
        augmented_parts.append(f"the scene shows {retrieved_verbs[0]}")
    if retrieved_scenes:
        augmented_parts.append(f"with {retrieved_scenes[0]}")

    augmented_prompt = ", ".join(augmented_parts) + "."

    # Format retrieved info for display
    retrieved_info = {
        "Places": places_found,
        "Actions": retrieved_verbs[:5],
        "Atmosphere": retrieved_scenes[:5]
    }

    return augmented_prompt, retrieved_info, places_found

# =============================================================================
# Gradio Interface
# =============================================================================

def process_prompt(prompt, place_num, modifier_num):
    """Process prompt and return results for Gradio"""
    if not prompt.strip():
        # Must match the two Gradio outputs (prompt textbox, info markdown)
        return "Please enter a prompt.", ""

    try:
        # Sliders may deliver floats; torch.topk and random.sample need integer counts
        augmented_prompt, retrieved_info, places = retrieve_and_augment_prompt(
            prompt, int(place_num), int(modifier_num)
        )

        # Format retrieved info for display
        info_text = "**Retrieved Modifiers:**\n\n"
        info_text += f"**πŸ“ Top Places:** {', '.join(places)}\n\n"
        info_text += f"**🎬 Actions:** {', '.join(retrieved_info['Actions'])}\n\n"
        info_text += f"**πŸŒ… Atmosphere:** {', '.join(retrieved_info['Atmosphere'])}\n\n"

        return augmented_prompt, info_text
    except Exception as e:
        return f"Error: {str(e)}", ""

# Create Gradio interface
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue="purple",
        secondary_hue="blue"
    ),
    title="RAPO++ Text-to-Video Prompt Optimization"
) as demo:

    gr.Markdown("""
    # 🎬 RAPO++ Text-to-Video Prompt Optimization

    This demo showcases **Stage 1 (RAPO)**: Retrieval-Augmented Prompt Optimization using knowledge graphs.

    **How it works:**
    1. Enter a simple text-to-video prompt
    2. The system retrieves contextually relevant modifiers from a knowledge graph
    3. Your prompt is enhanced with specific actions and atmospheric details
    4. Use the optimized prompt for better T2V generation results!

    **Example prompts to try:**
    - "A person walking"
    - "A car driving"
    - "Someone cooking"
    - "A group of people talking"

    Based on the paper: [RAPO++ (arXiv:2510.20206)](https://arxiv.org/abs/2510.20206)
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")

            input_prompt = gr.Textbox(
                label="Original Prompt",
                placeholder="Enter your text-to-video prompt (e.g., 'A person walking')",
                lines=3
            )

            with gr.Accordion("Advanced Settings", open=False):
                place_num = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=2,
                    step=1,
                    label="Number of Places to Retrieve",
                    info="How many related places to search in the knowledge graph"
                )

                modifier_num = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Modifiers per Place",
                    info="How many modifiers to sample from each place"
                )

            process_btn = gr.Button("✨ Optimize Prompt", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### Results")

            output_prompt = gr.Textbox(
                label="Optimized Prompt",
                lines=5,
                show_copy_button=True
            )

            retrieved_info = gr.Markdown(
                label="Retrieved Information"
            )

    # Example prompts
    gr.Examples(
        examples=[
            ["A person walking", 2, 5],
            ["A car driving at night", 2, 5],
            ["Someone cooking in a kitchen", 2, 5],
            ["A group of people talking", 2, 5],
            ["A bird flying", 2, 5],
            ["Someone sitting and reading", 2, 5],
        ],
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info],
        fn=process_prompt,
        cache_examples=False
    )

    gr.Markdown("""
    ---
    ### About RAPO++

    RAPO++ is a three-stage framework for text-to-video generation prompt optimization:

    - **Stage 1 (RAPO)**: Retrieval-Augmented Prompt Optimization using relation graphs *(demonstrated here)*
    - **Stage 2 (SSPO)**: Self-Supervised Prompt Optimization with test-time iterative refinement
    - **Stage 3**: LLM fine-tuning on collected feedback data

    The system is model-agnostic and works with various T2V models (Wan2.1, Open-Sora-Plan, HunyuanVideo, etc.).

    **Papers:**
    - [RAPO (CVPR 2025)](https://arxiv.org/abs/2502.07516): The Devil is in the Prompts: Retrieval-Augmented Prompt Optimization for Text-to-Video Generation
    - [RAPO++ (arXiv:2510.20206)](https://arxiv.org/abs/2510.20206): Cross-Stage Prompt Optimization for Text-to-Video Generation via Data Alignment and Test-Time Scaling

    **Project Page:** [https://whynothaha.github.io/RAPO_plus_github/](https://whynothaha.github.io/RAPO_plus_github/)

    **GitHub:** [https://github.com/Vchitect/RAPO](https://github.com/Vchitect/RAPO)
    """)

    # Event handlers
    process_btn.click(
        fn=process_prompt,
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info]
    )

    input_prompt.submit(
        fn=process_prompt,
        inputs=[input_prompt, place_num, modifier_num],
        outputs=[output_prompt, retrieved_info]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()