import gradio as gr
import torch
import uuid
from transformers import AutoProcessor, AutoModelForImageTextToText
import spaces
from molmo_utils import process_vision_info
from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# Register a custom orange-red palette on gradio's shared `colors` module so it
# can be referenced as `colors.orange_red` (used as OrangeRedTheme's default
# secondary hue). Shades run from lightest (c50) to darkest (c950).
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5", c100="#FFE0CC", c200="#FFC299", c300="#FFA366",
    c400="#FF8533", c500="#FF4500", c600="#E63E00", c700="#CC3700",
    c800="#B33000", c900="#992900", c950="#802200",
)
class OrangeRedTheme(Soft):
    """Gradio `Soft` theme with gray primary, orange-red secondary and slate neutral hues.

    All constructor arguments are keyword-only and are forwarded to `Soft`.
    Afterwards, a fixed set of theme variables (backgrounds, button gradients,
    slider colours, block styling) is overridden via `set`.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)

        overrides = {
            # Page and panel backgrounds (light / dark).
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
            # Primary buttons: white text on an orange-red gradient.
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            # Secondary buttons: gray fills.
            "button_secondary_text_color": "black",
            "button_secondary_text_color_hover": "white",
            "button_secondary_background_fill": "linear-gradient(90deg, *primary_300, *primary_300)",
            "button_secondary_background_fill_hover": "linear-gradient(90deg, *primary_400, *primary_400)",
            "button_secondary_background_fill_dark": "linear-gradient(90deg, *primary_500, *primary_600)",
            "button_secondary_background_fill_hover_dark": "linear-gradient(90deg, *primary_500, *primary_500)",
            # Sliders.
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            # Blocks, shadows and misc accents.
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(**overrides)
# Theme instance; it is passed to demo.launch() in the __main__ block.
orange_red_theme = OrangeRedTheme()
# Custom CSS, also passed to demo.launch(). It does three things:
# - enlarges the #main-title heading;
# - styles the RadioAnimated pill selector (.ra-* classes) in light and dark mode;
# - boxes the #gpu-duration-container row.
css = """
#main-title h1 {font-size: 2.4em !important;}
/* RadioAnimated Styles */
.ra-wrap{ width: fit-content; }
.ra-inner{
position: relative; display: inline-flex; align-items: center; gap: 0; padding: 6px;
background: var(--neutral-200); border-radius: 9999px; overflow: hidden;
}
.ra-input{ display: none; }
.ra-label{
position: relative; z-index: 2; padding: 8px 16px;
font-family: inherit; font-size: 14px; font-weight: 600;
color: var(--neutral-500); cursor: pointer; transition: color 0.2s; white-space: nowrap;
}
.ra-highlight{
position: absolute; z-index: 1; top: 6px; left: 6px;
height: calc(100% - 12px); border-radius: 9999px;
background: white; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
transition: transform 0.2s, width 0.2s;
}
.ra-input:checked + .ra-label{ color: black; }
/* Dark mode adjustments for Radio */
.dark .ra-inner { background: var(--neutral-800); }
.dark .ra-label { color: var(--neutral-400); }
.dark .ra-highlight { background: var(--neutral-600); }
.dark .ra-input:checked + .ra-label { color: white; }
#gpu-duration-container {
padding: 10px;
border-radius: 8px;
background: var(--background-fill-secondary);
border: 1px solid var(--border-color-primary);
margin-top: 10px;
}
"""
class RadioAnimated(gr.HTML):
    """Pill-shaped radio selector rendered as custom HTML with a sliding highlight.

    The component's value is the selected choice's label (a string). The
    `js_on_load` script keeps `props.value` in sync with the checked input and
    fires a `change` event when the user picks a different option.
    """

    def __init__(self, choices, value=None, **kwargs):
        """Build the radio markup and client-side behaviour.

        Args:
            choices: Option labels; at least two are required.
            value: Initially selected label; defaults to the first choice.
            **kwargs: Forwarded unchanged to `gr.HTML` (e.g. `elem_id`).

        Raises:
            ValueError: If fewer than two choices are given.
        """
        from html import escape  # only needed here, to build safe markup

        if not choices or len(choices) < 2:
            raise ValueError("RadioAnimated requires at least 2 choices.")
        if value is None:
            value = choices[0]
        uid = uuid.uuid4().hex[:8]
        # A unique radio-group name keeps several instances on one page independent.
        group_name = f"ra-{uid}"

        # Each <input> is immediately followed by its <label>. The CSS rule
        # `.ra-input:checked + .ra-label` relies on that sibling order to style
        # the selected option. Labels are escaped so arbitrary text can't
        # inject markup.
        option_parts = []
        for i, c in enumerate(choices):
            label = escape(str(c))
            checked = " checked" if c == value else ""
            option_parts.append(
                f'<input class="ra-input" type="radio" name="{group_name}" '
                f'id="{group_name}-{i}" value="{label}"{checked}>\n'
                f'<label class="ra-label" for="{group_name}-{i}">{label}</label>'
            )
        inputs_html = "\n".join(option_parts)

        # Structure expected by both the CSS and the JS below:
        # .ra-wrap > .ra-inner > (.ra-highlight, input.ra-input + label.ra-label ...)
        html_template = f"""
<div class="ra-wrap" data-ra="{uid}">
  <div class="ra-inner">
    <div class="ra-highlight"></div>
    {inputs_html}
  </div>
</div>
"""
        js_on_load = r"""
(() => {
    const wrap = element.querySelector('.ra-wrap');
    const inner = element.querySelector('.ra-inner');
    const highlight = element.querySelector('.ra-highlight');
    const inputs = Array.from(element.querySelectorAll('.ra-input'));
    if (!inputs.length) return;
    const choices = inputs.map(i => i.value);
    function setHighlightByIndex(idx) {
        const n = choices.length;
        const pct = 100 / n;
        highlight.style.width = `calc(${pct}% - 6px)`;
        highlight.style.transform = `translateX(${idx * 100}%)`;
    }
    function setCheckedByValue(val, shouldTrigger=false) {
        const idx = Math.max(0, choices.indexOf(val));
        inputs.forEach((inp, i) => { inp.checked = (i === idx); });
        setHighlightByIndex(idx);
        props.value = choices[idx];
        if (shouldTrigger) trigger('change', props.value);
    }
    setCheckedByValue(props.value ?? choices[0], false);
    inputs.forEach((inp) => {
        inp.addEventListener('change', () => {
            setCheckedByValue(inp.value, true);
        });
    });
})();
"""
        super().__init__(
            value=value,
            html_template=html_template,
            js_on_load=js_on_load,
            **kwargs
        )
def apply_gpu_duration(val: str):
    """Convert the GPU-duration radio selection (a numeric string like "120") to an int."""
    seconds = int(val)
    return seconds
MODEL_ID = "allenai/SAGE-MM-Qwen3-VL-4B-SFT_RL"
print(f"Loading {MODEL_ID}...")
# The processor only tokenizes and pre-processes inputs. Dtype, device placement
# and the attention kernel are model options, so they are not passed here.
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
)
_model_kwargs = dict(trust_remote_code=True, dtype="auto", device_map="auto")
try:
    # The attention implementation belongs on the model. It was previously given
    # to the processor, where it had no effect.
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        attn_implementation="kernels-community/flash-attn2",
        **_model_kwargs,
    )
except Exception as exc:
    # The Hub kernel may be unavailable (e.g. the `kernels` package is missing).
    # Keep the app usable with the default attention instead of failing startup.
    print(f"Could not load flash-attention kernel ({exc}); using default attention.")
    model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **_model_kwargs)
print("Model loaded successfully.")
def calc_timeout_video(user_text: str, video_path: str, max_new_tokens: int, gpu_timeout: int = 120):
    """Calculate GPU timeout duration for video processing.

    The signature mirrors `process_video`, which passes this function as the
    `duration` of its `spaces.GPU` decorator. Only `gpu_timeout` is used.

    `gpu_timeout` defaults to 120, the same as `process_video`'s default. Calls
    that omit it therefore get the matching duration instead of raising a
    TypeError.

    Returns:
        `int(gpu_timeout)`, or 90 when the value cannot be converted to an int.
    """
    try:
        return int(gpu_timeout)
    except (TypeError, ValueError, OverflowError):
        # Missing or non-numeric selection: fall back to the UI's default of 90.
        return 90
@spaces.GPU(duration=calc_timeout_video)
def process_video(user_text, video_path, max_new_tokens, gpu_timeout: int = 120):
    """Answer a prompt about an uploaded video with the loaded vision-language model.

    Args:
        user_text: Question or instruction. None or blank falls back to
            "Describe this video in detail."
        video_path: Path of the uploaded video; falsy when nothing was uploaded.
        max_new_tokens: Upper bound on generated tokens (coerced to int).
        gpu_timeout: Not read in the body. It is part of the signature so that
            `calc_timeout_video`, the decorator's `duration` callable with a
            matching signature, can size the GPU request.

    Returns:
        The decoded model response. If no video was given or frame extraction
        failed, returns a human-readable message instead.
    """
    if not video_path:
        return "Please upload a video."
    # Treat None (e.g. an API/MCP call without a prompt) like a blank prompt
    # instead of crashing on `.strip()`.
    if not user_text or not user_text.strip():
        user_text = "Describe this video in detail."
    messages = [
        {
            "role": "user",
            "content": [
                dict(type="text", text=user_text),
                dict(type="video", video=video_path),
            ],
        }
    ]
    try:
        # The first returned element is discarded (presumably images). Each video
        # entry is unpacked below as a (frames, metadata) pair.
        _, videos, video_kwargs = process_vision_info(messages)
        videos, video_metadatas = zip(*videos)
        videos, video_metadatas = list(videos), list(video_metadatas)
    except Exception as e:
        return f"Error processing video frames: {e}"
    # Apply chat template
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Prepare inputs
    inputs = processor(
        videos=videos,
        video_metadata=video_metadatas,
        text=text,
        padding=True,
        return_tensors="pt",
        **video_kwargs,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Generate
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            # UI sliders / API clients may deliver a float; generate() expects an int.
            max_new_tokens=int(max_new_tokens),
        )
    # Skip the prompt tokens and decode only the newly generated continuation.
    generated_tokens = generated_ids[0, inputs["input_ids"].size(1):]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
# Build the UI. The left column holds the inputs (video, prompt, advanced
# settings, run button). The right column holds the model output and the
# GPU-duration selector.
with gr.Blocks() as demo:
    gr.Markdown("# **SAGE-MM-Video-Reasoning**", elem_id="main-title")
    gr.Markdown("Upload a video to get a detailed explanation or ask specific questions using [SAGE-MM-Qwen3-VL](https://huggingface.co/allenai/SAGE-MM-Qwen3-VL-4B-SFT_RL).")
    with gr.Row():
        with gr.Column():
            vid_input = gr.Video(label="Input Video", format="mp4", height=350)
            vid_prompt = gr.Textbox(
                label="Prompt",
                value="Describe this video in detail.",
                placeholder="Type your question here..."
            )
            with gr.Accordion("Advanced Settings", open=False):
                max_tokens_slider = gr.Slider(
                    minimum=128,
                    maximum=4096,
                    value=1024,
                    step=128,
                    label="Max New Tokens",
                    info="Controls the length of the generated text."
                )
            vid_btn = gr.Button("Analyze Video", variant="primary")
        with gr.Column():
            vid_text_out = gr.Textbox(label="Model Response", interactive=True, lines=18)
            with gr.Row(elem_id="gpu-duration-container"):
                with gr.Column():
                    gr.Markdown("**GPU Duration (seconds)**")
                    # Custom radio whose value is the selected choice string, e.g. "90".
                    radioanimated_gpu_duration = RadioAnimated(
                        choices=["90", "120", "180", "240", "300"],
                        value="90",
                        elem_id="radioanimated_gpu_duration"
                    )
                    # Hidden numeric mirror of the radio selection. It is passed to
                    # process_video as `gpu_timeout`; its initial 90 matches the
                    # radio's default.
                    gpu_duration_state = gr.Number(value=90, visible=False)
            gr.Markdown("*Note: Higher GPU duration allows for longer processing but consumes more GPU quota.*")
    gr.Examples(
        examples=[
            ["example-videos/1.mp4"],
            ["example-videos/2.mp4"],
            ["example-videos/3.mp4"],
            ["example-videos/4.mp4"],
            ["example-videos/5.mp4"],
        ],
        inputs=[vid_input],
        label="Video Examples"
    )
    # Keep the hidden number in sync with the radio (string -> int).
    # This internal handler is marked private so it is not exposed as an API endpoint.
    radioanimated_gpu_duration.change(
        fn=apply_gpu_duration,
        inputs=radioanimated_gpu_duration,
        outputs=[gpu_duration_state],
        api_visibility="private"
    )
    # Main action. The argument order must match
    # process_video(user_text, video_path, max_new_tokens, gpu_timeout).
    vid_btn.click(
        fn=process_video,
        inputs=[vid_prompt, vid_input, max_tokens_slider, gpu_duration_state],
        outputs=[vid_text_out]
    )
if __name__ == "__main__":
    # Serve the app with the custom theme and CSS defined earlier in this file.
    demo.launch(
        theme=orange_red_theme,
        css=css,
        mcp_server=True,
        ssr_mode=False,
    )