import shutil
import gradio as gr
import spaces
import yt_dlp
import os
import tempfile
import re
import subprocess
import socket
import time
import atexit
import torch
from transformers import AutoModel, AutoProcessor
# socks5h proxy URL for yt-dlp once the SSH tunnel is up; None = direct connection
PROXY_URL = None
# Handle to the background `ssh -D` process so atexit can terminate it.
_tunnel_proc = None
def _write_temp_key_and_kh(key_str, kh_line):
key_clean = key_str.replace("\r\n", "\n").replace("\r", "\n")
if not key_clean.endswith("\n"):
key_clean += "\n"
keyf = tempfile.NamedTemporaryFile("w", delete=False)
keyf.write(key_clean)
keyf.flush()
os.chmod(keyf.name, 0o600)
keyf.close()
khf = tempfile.NamedTemporaryFile("w", delete=False)
khf.write(kh_line.strip() + "\n")
khf.flush()
khf.close()
return keyf.name, khf.name
def _validate_private_key(path):
if not shutil.which("ssh-keygen"):
return True
try:
subprocess.check_output(["ssh-keygen", "-y", "-f", path], stderr=subprocess.STDOUT)
return True
except subprocess.CalledProcessError:
return False
def _ensure_local_socks_tunnel():
    """Best-effort: start an ``ssh -D`` SOCKS5 tunnel on 127.0.0.1:1080.

    Reads SSH_SERVER / SSH_PORT / SSH_PRIVATE_KEY / SSH_HOSTKEY from the
    environment and silently does nothing if any is missing or the ssh
    binary is unavailable. On success sets the module-global PROXY_URL,
    which download_youtube_audio() hands to yt-dlp.
    """
    global PROXY_URL, _tunnel_proc
    # Tunnel already established on an earlier call — nothing to do.
    if PROXY_URL:
        return
    srv = os.getenv("SSH_SERVER")
    port = os.getenv("SSH_PORT", "22")
    key = os.getenv("SSH_PRIVATE_KEY")
    hk = os.getenv("SSH_HOSTKEY")
    if not (srv and key and hk and shutil.which("ssh")):
        return
    key_path, kh_path = _write_temp_key_and_kh(key, hk)
    # NOTE(review): on validation failure the temp key/known_hosts files are
    # left on disk — confirm whether cleanup is desired here.
    if not _validate_private_key(key_path):
        return
    cmd = [
        "ssh", "-NT", "-p", port, "-i", key_path,
        "-D", "127.0.0.1:1080",            # local dynamic (SOCKS) forward
        "-o", "IdentitiesOnly=yes",
        "-o", "ExitOnForwardFailure=yes",  # exit instead of running without the forward
        "-o", "BatchMode=yes",             # never prompt for credentials
        "-o", "StrictHostKeyChecking=yes",
        "-o", f"UserKnownHostsFile={kh_path}",
        "-o", "GlobalKnownHostsFile=/dev/null",
        "-o", "ServerAliveInterval=30", "-o", "ServerAliveCountMax=3",
        srv,
    ]
    # Capture ssh's stdout/stderr for post-mortem debugging of failed tunnels.
    with open("/tmp/ssh_tunnel.log", "w") as lf:
        _tunnel_proc = subprocess.Popen(cmd, stdout=lf, stderr=lf)
    # Poll for up to ~10s (40 * 0.25s) until the SOCKS port accepts connections.
    for _ in range(40):
        if _tunnel_proc.poll() is not None:
            # ssh exited early (auth / host-key / forward failure) — give up.
            return
        try:
            socket.create_connection(("127.0.0.1", 1080), 0.5).close()
            PROXY_URL = "socks5h://127.0.0.1:1080"
            break
        except OSError:
            time.sleep(0.25)
    # Terminate the background ssh process when the interpreter exits.
    atexit.register(lambda: _tunnel_proc and _tunnel_proc.terminate())
# Bring up the optional SOCKS tunnel before any yt-dlp download happens.
_ensure_local_socks_tunnel()

# Hugging Face model id loaded below for both processor and model.
MODEL_ID = "nvidia/music-flamingo-hf"

# Hero-section copy for the page header. NOTE(review): several characters
# ("β", digits fused onto author names) look mojibake/extraction-garbled;
# left byte-identical here because these are runtime strings.
HERO_IMAGE_URL = "https://musicflamingo.github.io/logo-no-bg.png"
HERO_TITLE = "Music Flamingo: Scaling Music Understanding in Audio Language Models"
HERO_SUBTITLE = "Upload a song and ask anything β including captions, lyrics, genre, key, chords, or complex questions. Music Flamingo gives detailed answers."
HERO_AUTHORS = """
Authors: Sreyan Ghosh1,2*, Arushi Goel1*, Lasha Koroshinadze2**, Sang-gil Lee1, Zhifeng Kong1, Joao Felipe Santos1,
Ramani Duraiswami2, Dinesh Manocha2, Wei Ping1, Mohammad Shoeybi1, Bryan Catanzaro1
1NVIDIA, CA, USA | 2University of Maryland, College Park, USA
*Equally contributed and led the project. Names randomly ordered. **Significant technical contribution.
Correspondence: sreyang@umd.edu, arushig@nvidia.com
"""
# Badge markup placeholder (currently empty).
HERO_BADGES = """
"""

# Custom CSS injected via gr.Blocks(css=APP_CSS) further below.
APP_CSS = """
:root {
--font-sans: ui-sans-serif, system-ui, sans-serif,
"Apple Color Emoji", "Segoe UI Emoji",
"Segoe UI Symbol", "Noto Color Emoji";
--font-mono: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas,
"Liberation Mono", "Courier New", monospace;
--app-font: var(--font-sans);
}
body {
font-family: var(--app-font);
}
.gradio-container {
font-family: var(--app-font);
max-width: 80rem !important; /* Tailwind max-w-7xl (1280px) */
width: 100%;
margin-inline: auto; /* mx-auto */
padding-inline: 1rem; /* px-4 */
padding-bottom: 64px;
}
.hero {
display: flex;
flex-direction: column;
align-items: center;
gap: 12px;
padding: 24px 24px 32px;
text-align: center;
}
.hero__logo {
width: 112px;
height: 112px;
border-radius: 50%;
box-shadow: 0 12px 40px rgba(0, 0, 0, 0.15);
}
.hero__title {
font-size: clamp(2.4rem, 5.4vw, 3.2rem);
font-weight: 700;
line-height: 1.5;
letter-spacing: -0.01em;
background: linear-gradient(120deg, #ff6bd6 0%, #af66ff 35%, #4e9cff 100%);
-webkit-background-clip: text;
background-clip: text;
color: transparent;
}
.hero__subtitle {
max-width: none;
font-size: 1.08rem;
opacity: 0.8;
}
.tab-nav {
border-radius: 18px;
border: 1px solid var(--border-color-primary);
padding: 6px;
margin: 0 18px 12px;
}
.tab-nav button {
border-radius: 12px !important;
}
.tab-nav button[aria-selected="true"] {
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
}
.panel-row {
gap: 24px !important;
align-items: stretch;
flex-wrap: wrap;
}
.glass-card {
border: 1px solid var(--border-color-primary);
border-radius: 26px;
padding: 28px;
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.1);
display: flex;
flex-direction: column;
gap: 18px;
}
/* Glass card content styling */
.glass-card .gradio-input,
.glass-card .gradio-output {
/* Let Gradio handle default styling */
}
.glass-card label {
font-weight: 600;
letter-spacing: 0.01em;
}
/* Text input styling */
.glass-card textarea {
border-radius: 18px !important;
}
.glass-card textarea:focus {
box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.25) !important;
}
/* Audio component fix */
.glass-card [data-testid="Audio"] .wrap {
/* Let Gradio handle default styling */
}
/* YouTube embed styling */
.glass-card [data-testid="HTML"] {
margin: 12px 0;
}
/* Load button styling */
.glass-card button[variant="secondary"] {
border-radius: 12px !important;
font-weight: 500 !important;
}
/* Action button styling */
.accent-button {
background: linear-gradient(120deg, #ff6bd6 0%, #8f5bff 45%, #4e9cff 100%) !important;
border-radius: 14px !important;
box-shadow: 0 6px 20px rgba(0, 0, 0, 0.15);
color: #ffffff !important;
font-weight: 600 !important;
letter-spacing: 0.01em;
padding: 0.85rem 1.5rem !important;
transition: transform 0.18s ease, box-shadow 0.18s ease;
}
.accent-button:hover {
transform: translateY(-2px);
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2);
}
.accent-button:active {
transform: translateY(0px);
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.15);
}
.footer-note {
text-align: center;
opacity: 0.6;
margin-top: 28px;
font-size: 0.95rem;
}
"""

# [YouTube URL, prompt] pairs surfaced through gr.Examples in the UI.
EXAMPLE_YOUTUBE_PROMPTS = [
    [
        "https://youtu.be/ko70cExuzZM",
        "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.",
    ],
    [
        "https://youtu.be/iywaBOMvYLI",
        "Generate a structured lyric sheet from the input music.",
    ],
    [
        "https://youtu.be/_mTRvJ9fugM",
        "Which line directly precedes the chorus?",
    ],
]
# Load the Music Flamingo processor/model on CPU at import time; infer()
# moves the model to CUDA lazily inside the @spaces.GPU-decorated call.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID, device_map="cpu", torch_dtype=torch.float32).eval()
# url -> (downloaded mp3 path, video title); avoids re-downloading the same video.
_youtube_cache = {}
def clear_youtube_cache():
    """Clear the YouTube audio cache and delete cached files.

    Each cached download lives in its own mkdtemp() directory, so the
    whole directory is removed per entry. Deletion is best-effort:
    filesystem errors are swallowed so a half-deleted cache never raises.
    Finally the in-memory mapping is emptied.
    """
    # Iterate values only (URL keys are unused); `shutil` comes from the
    # module-level import — the previous local re-import was redundant.
    for file_path, _title in _youtube_cache.values():
        try:
            if os.path.exists(file_path):
                shutil.rmtree(os.path.dirname(file_path))
        except OSError:
            # Best-effort cleanup; leftover temp dirs are harmless.
            pass
    _youtube_cache.clear()
def truncate_title(title, max_length=50):
    """Truncate long titles with ellipsis to prevent UI wrapping.

    Titles at or under *max_length* pass through unchanged; longer ones
    are cut to exactly *max_length* characters including the "..." tail.
    """
    return title if len(title) <= max_length else title[: max_length - 3] + "..."
def extract_youtube_id(url):
    """Extract the 11-character YouTube video ID from a URL.

    Supports watch, youtu.be, embed, nocookie-embed and /v/ URL forms,
    and additionally the newer /shorts/ and /live/ forms (backward-
    compatible generalization). Returns the ID string, or None when the
    URL matches no known pattern.
    """
    patterns = [
        r"(?:https?://)?(?:www\.)?youtube\.com/watch\?v=([^&=%\?]{11})",
        r"(?:https?://)?(?:www\.)?youtu\.be/([^&=%\?]{11})",
        r"(?:https?://)?(?:www\.)?youtube\.com/embed/([^&=%\?]{11})",
        r"(?:https?://)?(?:www\.)?youtube-nocookie\.com/embed/([^&=%\?]{11})",
        r"(?:https?://)?(?:www\.)?youtube\.com/v/([^&=%\?]{11})",
        r"(?:https?://)?(?:www\.)?youtube\.com/shorts/([^&=%\?]{11})",
        r"(?:https?://)?(?:www\.)?youtube\.com/live/([^&=%\?]{11})",
    ]
    for pattern in patterns:
        match = re.search(pattern, url)
        if match:
            return match.group(1)
    return None
def generate_youtube_embed(url, title="YouTube Video"):
    """Generate YouTube embed HTML from URL.

    Returns "" when no video ID can be extracted. The previous body
    returned a whitespace-only f-string (the iframe markup appears to
    have been lost), so the video preview rendered blank; this rebuilds
    a standard responsive, privacy-enhanced iframe embed.
    """
    video_id = extract_youtube_id(url)
    if not video_id:
        return ""
    # 56.25% padding-bottom = 16:9 aspect ratio responsive container.
    embed_html = f"""
<div style="position:relative;width:100%;padding-bottom:56.25%;height:0;overflow:hidden;border-radius:18px;">
  <iframe
    src="https://www.youtube-nocookie.com/embed/{video_id}"
    title="{title}"
    style="position:absolute;top:0;left:0;width:100%;height:100%;border:0;"
    allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
    allowfullscreen>
  </iframe>
</div>
"""
    return embed_html
def download_youtube_audio(url, force_reload=False):
    """Download audio from YouTube URL and return the file path.

    Returns ``(path_or_None, status_message)``. Successful downloads are
    memoized in the module-level _youtube_cache keyed by URL;
    *force_reload* evicts (and deletes) any cached copy first. Routes
    through the SSH SOCKS proxy when PROXY_URL is set.

    NOTE(review): the "β"-prefixed status strings look mojibake-garbled
    in the source (likely once emoji); left as-is. Status f-strings that
    were split across lines in the source are rejoined here, as a literal
    newline inside a single-quoted f-string is invalid syntax.
    """
    try:
        # Quick sanity check: URL must look like a YouTube link with an 11-char ID.
        youtube_regex = re.compile(r"(https?://)?(www\.)?(youtube|youtu|youtube-nocookie)\.(com|be)/" r"(watch\?v=|embed/|v/|.+\?v=)?([^&=%\?]{11})")
        if not youtube_regex.match(url):
            return None, "β Invalid YouTube URL format"
        # Cache hit: reuse the previous download if the file still exists.
        if not force_reload and url in _youtube_cache:
            cached_path, cached_title = _youtube_cache[url]
            if os.path.exists(cached_path):
                return cached_path, f"β Using cached: {truncate_title(cached_title)}"
        # Forced reload: delete the old download's directory and drop the entry.
        if force_reload and url in _youtube_cache:
            old_path, _ = _youtube_cache[url]
            try:
                if os.path.exists(old_path):
                    import shutil
                    temp_dir = os.path.dirname(old_path)
                    shutil.rmtree(temp_dir)
            except Exception:
                # Best-effort cleanup; the cache entry is dropped regardless.
                pass
            del _youtube_cache[url]
        # Each download gets its own temp dir so cleanup is a single rmtree.
        temp_dir = tempfile.mkdtemp()
        ydl_opts = {
            "format": "bestaudio/best",
            "outtmpl": os.path.join(temp_dir, "%(title)s.%(ext)s"),
            "postprocessors": [
                {
                    # Re-encode whatever yt-dlp fetches to 128 kbps mp3 via ffmpeg.
                    "key": "FFmpegExtractAudio",
                    "preferredcodec": "mp3",
                    "preferredquality": "128",
                }
            ],
            "noplaylist": True,
        }
        # Route through the SSH SOCKS tunnel when it was established at startup.
        if PROXY_URL:
            ydl_opts["proxy"] = PROXY_URL
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            # Fetch metadata first so the title is known even after the
            # postprocessor renames the output file.
            info = ydl.extract_info(url, download=False)
            title = info.get("title", "Unknown")
            ydl.download([url])
        # The FFmpegExtractAudio postprocessor leaves a single .mp3 in temp_dir.
        for file in os.listdir(temp_dir):
            if file.endswith(".mp3"):
                file_path = os.path.join(temp_dir, file)
                _youtube_cache[url] = (file_path, title)
                return file_path, f"β Downloaded: {truncate_title(title)}"
        return None, "β Failed to download audio file"
    except Exception as e:
        # Surface any yt-dlp/network failure as a status message instead of raising.
        return None, f"β Download error: {str(e)}"
@spaces.GPU
def infer(audio_path, youtube_url, prompt_text):
    """Run Music Flamingo on an uploaded file or a YouTube URL.

    *audio_path* takes precedence over *youtube_url*. Returns the status
    message followed by the model's decoded answer, or an error string.
    The @spaces.GPU decorator allocates a GPU only for the duration of
    this call, hence the lazy model.to(device) below.
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        final_audio_path = None
        status_message = ""
        if audio_path:
            final_audio_path = audio_path
            status_message = "β Using audio file"
        elif youtube_url.strip():
            final_audio_path, status_message = download_youtube_audio(youtube_url.strip())
            if not final_audio_path:
                # Download failed — surface its error message directly.
                return status_message
        else:
            return "β Please either upload an audio file or provide a YouTube URL."
        # Single-turn chat with one text part and one audio part.
        conversations = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt_text or ""},
                        {"type": "audio", "path": final_audio_path},
                    ],
                }
            ]
        ]
        # NOTE: If `conversations` includes audio, apply_chat_template() decodes via load_audio()
        # to MONO float32 at 16 kHz by default. We omit `sampling_rate`, so the 16k default is used.
        # Processor assumes mono 1-D audio; stereo would require code changes. No audio β no effect here.
        batch = processor.apply_chat_template(
            conversations,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
        ).to(model.device)
        gen_ids = model.generate(**batch, max_new_tokens=4096, repetition_penalty=1.2)
        # Strip the prompt tokens so only the newly generated answer is decoded.
        inp_len = batch["input_ids"].shape[1]
        new_tokens = gen_ids[:, inp_len:]
        texts = processor.batch_decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        result = texts[0] if texts else ""
        return f"{status_message}\n\n{result}"
    except Exception as e:
        # Top-level boundary: report any failure to the UI instead of raising.
        return f"β Error: {str(e)}"
def load_youtube_audio(youtube_url):
    """Load YouTube audio into the Audio component and generate video embed.

    Returns ``(audio_path_or_None, status_message, embed_html)`` for the
    audio widget, the status textbox, and the HTML preview respectively.
    """
    cleaned = youtube_url.strip()
    if not cleaned:
        return None, "β Please enter a YouTube URL", ""
    embed_html = generate_youtube_embed(cleaned)
    # Always force a fresh download when loading explicitly from the UI.
    audio_path, message = download_youtube_audio(cleaned, force_reload=True)
    # audio_path is already None on failure, so one return covers both cases.
    return audio_path, message, embed_html
# --- Gradio UI -------------------------------------------------------------
# NOTE(review): the hero HTML below interpolates only the text constants; any
# surrounding markup (img/div tags) appears to have been lost from the source.
with gr.Blocks(css=APP_CSS, theme=gr.themes.Soft(primary_hue="purple", secondary_hue="fuchsia")) as demo:
    gr.HTML(
        f"""
{HERO_TITLE}
{HERO_SUBTITLE}
{HERO_AUTHORS}
{HERO_BADGES}
"""
    )
    with gr.Tabs(elem_classes="tab-nav"):
        with gr.Row(elem_classes="panel-row"):
            # Left card: audio input (upload/mic or YouTube) and the prompt.
            with gr.Column(elem_classes=["glass-card"]):
                gr.Markdown("### π΅ Audio Input")
                audio_in = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Upload Audio File",
                    show_label=True,
                )
                gr.Markdown("**OR**")
                youtube_url = gr.Textbox(label="YouTube URL", placeholder="https://www.youtube.com/watch?v=...", info="Paste any YouTube URL - we'll extract high-quality audio automatically")
                load_btn = gr.Button("π Load Audio", variant="secondary", size="sm")
                # Hidden until a load is triggered; made visible by the click chain below.
                status_text = gr.Textbox(label="Status", interactive=False, visible=False)
                youtube_embed = gr.HTML(label="Video Preview", visible=False)
                prompt_in = gr.Textbox(
                    label="Prompt",
                    value="Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.",
                    placeholder="Ask a question about the audioβ¦",
                    lines=6,
                )
                gr.Examples(
                    examples=EXAMPLE_YOUTUBE_PROMPTS,
                    inputs=[youtube_url, prompt_in],
                    label="π΅ Example Prompts",
                )
                btn = gr.Button("Generate Answer", elem_classes="accent-button")
            # Right card: model output.
            with gr.Column(elem_classes=["glass-card"]):
                out = gr.Textbox(
                    label="Model Response",
                    lines=25,
                    placeholder="Model answers will appear here with audio-informed insightsβ¦",
                )
    # Chain: clear audio + show "Loading", then download, then reveal the embed.
    # NOTE(review): `status_text` appears twice in outputs (value + visibility
    # update) — confirm the installed Gradio version accepts duplicate outputs.
    load_btn.click(lambda: [None, "π Loading audio...", gr.update(visible=True)], outputs=[audio_in, status_text, status_text]).then(
        fn=load_youtube_audio, inputs=[youtube_url], outputs=[audio_in, status_text, youtube_embed]
    ).then(lambda: gr.update(visible=True), outputs=[youtube_embed])
    btn.click(fn=infer, inputs=[audio_in, youtube_url, prompt_in], outputs=out)
    # Footer placeholder (content currently empty in source).
    gr.HTML(
        """
"""
    )
if __name__ == "__main__":
    demo.launch(share=True)