import shutil import gradio as gr import spaces import yt_dlp import os import tempfile import re import subprocess import socket import time import atexit import torch from transformers import AutoModel, AutoProcessor PROXY_URL = None _tunnel_proc = None def _write_temp_key_and_kh(key_str, kh_line): key_clean = key_str.replace("\r\n", "\n").replace("\r", "\n") if not key_clean.endswith("\n"): key_clean += "\n" keyf = tempfile.NamedTemporaryFile("w", delete=False) keyf.write(key_clean) keyf.flush() os.chmod(keyf.name, 0o600) keyf.close() khf = tempfile.NamedTemporaryFile("w", delete=False) khf.write(kh_line.strip() + "\n") khf.flush() khf.close() return keyf.name, khf.name def _validate_private_key(path): if not shutil.which("ssh-keygen"): return True try: subprocess.check_output(["ssh-keygen", "-y", "-f", path], stderr=subprocess.STDOUT) return True except subprocess.CalledProcessError: return False def _ensure_local_socks_tunnel(): global PROXY_URL, _tunnel_proc if PROXY_URL: return srv = os.getenv("SSH_SERVER") port = os.getenv("SSH_PORT", "22") key = os.getenv("SSH_PRIVATE_KEY") hk = os.getenv("SSH_HOSTKEY") if not (srv and key and hk and shutil.which("ssh")): return key_path, kh_path = _write_temp_key_and_kh(key, hk) if not _validate_private_key(key_path): return cmd = [ "ssh","-NT","-p", port,"-i", key_path, "-D","127.0.0.1:1080", "-o","IdentitiesOnly=yes", "-o","ExitOnForwardFailure=yes", "-o","BatchMode=yes", "-o","StrictHostKeyChecking=yes", "-o", f"UserKnownHostsFile={kh_path}", "-o","GlobalKnownHostsFile=/dev/null", "-o","ServerAliveInterval=30","-o","ServerAliveCountMax=3", srv, ] with open("/tmp/ssh_tunnel.log", "w") as lf: _tunnel_proc = subprocess.Popen(cmd, stdout=lf, stderr=lf) for _ in range(40): if _tunnel_proc.poll() is not None: return try: socket.create_connection(("127.0.0.1", 1080), 0.5).close() PROXY_URL = "socks5h://127.0.0.1:1080" break except OSError: time.sleep(0.25) atexit.register(lambda: _tunnel_proc and _tunnel_proc.terminate()) _ensure_local_socks_tunnel() MODEL_ID = "nvidia/music-flamingo-hf" HERO_IMAGE_URL = "https://musicflamingo.github.io/logo-no-bg.png" HERO_TITLE = "Music Flamingo: Scaling Music Understanding in Audio Language Models" HERO_SUBTITLE = "Upload a song and ask anything β€” including captions, lyrics, genre, key, chords, or complex questions. Music Flamingo gives detailed answers." HERO_AUTHORS = """

Authors: Sreyan Ghosh1,2*, Arushi Goel1*, Lasha Koroshinadze2**, Sang-gil Lee1, Zhifeng Kong1, Joao Felipe Santos1,
Ramani Duraiswami2, Dinesh Manocha2, Wei Ping1, Mohammad Shoeybi1, Bryan Catanzaro1

1NVIDIA, CA, USA | 2University of Maryland, College Park, USA

*Equally contributed and led the project. Names randomly ordered. **Significant technical contribution.

Correspondence: sreyang@umd.edu, arushig@nvidia.com

""" HERO_BADGES = """
arXiv Demo page Github Stars Checkpoints Dataset
""" APP_CSS = """ :root { --font-sans: ui-sans-serif, system-ui, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji"; --font-mono: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; --app-font: var(--font-sans); } body { font-family: var(--app-font); } .gradio-container { font-family: var(--app-font); max-width: 80rem !important; /* Tailwind max-w-7xl (1280px) */ width: 100%; margin-inline: auto; /* mx-auto */ padding-inline: 1rem; /* px-4 */ padding-bottom: 64px; } .hero { display: flex; flex-direction: column; align-items: center; gap: 12px; padding: 24px 24px 32px; text-align: center; } .hero__logo { width: 112px; height: 112px; border-radius: 50%; box-shadow: 0 12px 40px rgba(0, 0, 0, 0.15); } .hero__title { font-size: clamp(2.4rem, 5.4vw, 3.2rem); font-weight: 700; line-height: 1.5; letter-spacing: -0.01em; background: linear-gradient(120deg, #ff6bd6 0%, #af66ff 35%, #4e9cff 100%); -webkit-background-clip: text; background-clip: text; color: transparent; } .hero__subtitle { max-width: none; font-size: 1.08rem; opacity: 0.8; } .tab-nav { border-radius: 18px; border: 1px solid var(--border-color-primary); padding: 6px; margin: 0 18px 12px; } .tab-nav button { border-radius: 12px !important; } .tab-nav button[aria-selected="true"] { box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1); } .panel-row { gap: 24px !important; align-items: stretch; flex-wrap: wrap; } .glass-card { border: 1px solid var(--border-color-primary); border-radius: 26px; padding: 28px; box-shadow: 0 8px 25px rgba(0, 0, 0, 0.1); display: flex; flex-direction: column; gap: 18px; } /* Glass card content styling */ .glass-card .gradio-input, .glass-card .gradio-output { /* Let Gradio handle default styling */ } .glass-card label { font-weight: 600; letter-spacing: 0.01em; } /* Text input styling */ .glass-card textarea { border-radius: 18px !important; } .glass-card textarea:focus { box-shadow: 0 0 0 3px rgba(0, 123, 255, 0.25) !important; } /* Audio component fix */ .glass-card [data-testid="Audio"] .wrap { /* Let Gradio handle default styling */ } /* YouTube embed styling */ .glass-card [data-testid="HTML"] { margin: 12px 0; } /* Load button styling */ .glass-card button[variant="secondary"] { border-radius: 12px !important; font-weight: 500 !important; } /* Action button styling */ .accent-button { background: linear-gradient(120deg, #ff6bd6 0%, #8f5bff 45%, #4e9cff 100%) !important; border-radius: 14px !important; box-shadow: 0 6px 20px rgba(0, 0, 0, 0.15); color: #ffffff !important; font-weight: 600 !important; letter-spacing: 0.01em; padding: 0.85rem 1.5rem !important; transition: transform 0.18s ease, box-shadow 0.18s ease; } .accent-button:hover { transform: translateY(-2px); box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2); } .accent-button:active { transform: translateY(0px); box-shadow: 0 4px 15px rgba(0, 0, 0, 0.15); } .footer-note { text-align: center; opacity: 0.6; margin-top: 28px; font-size: 0.95rem; } """ EXAMPLE_YOUTUBE_PROMPTS = [ [ "https://youtu.be/ko70cExuzZM", "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.", ], [ "https://youtu.be/iywaBOMvYLI", "Generate a structured lyric sheet from the input music.", ], [ "https://youtu.be/_mTRvJ9fugM", "Which line directly precedes the chorus?", ], ] processor = AutoProcessor.from_pretrained(MODEL_ID) model = AutoModel.from_pretrained(MODEL_ID, device_map="cpu", torch_dtype=torch.float32).eval() _youtube_cache = {} def clear_youtube_cache(): """Clear the YouTube audio cache and delete cached files.""" import shutil for url, (file_path, title) in _youtube_cache.items(): try: if os.path.exists(file_path): temp_dir = os.path.dirname(file_path) shutil.rmtree(temp_dir) except Exception: pass _youtube_cache.clear() def truncate_title(title, max_length=50): """Truncate long titles with ellipsis to prevent UI wrapping.""" if len(title) <= max_length: return title return title[: max_length - 3] + "..." def extract_youtube_id(url): """Extract YouTube video ID from various YouTube URL formats.""" patterns = [ r"(?:https?://)?(?:www\.)?youtube\.com/watch\?v=([^&=%\?]{11})", r"(?:https?://)?(?:www\.)?youtu\.be/([^&=%\?]{11})", r"(?:https?://)?(?:www\.)?youtube\.com/embed/([^&=%\?]{11})", r"(?:https?://)?(?:www\.)?youtube-nocookie\.com/embed/([^&=%\?]{11})", r"(?:https?://)?(?:www\.)?youtube\.com/v/([^&=%\?]{11})", ] for pattern in patterns: match = re.search(pattern, url) if match: return match.group(1) return None def generate_youtube_embed(url, title="YouTube Video"): """Generate YouTube embed HTML from URL.""" video_id = extract_youtube_id(url) if not video_id: return "" embed_html = f"""
""" return embed_html def download_youtube_audio(url, force_reload=False): """Download audio from YouTube URL and return the file path.""" try: youtube_regex = re.compile(r"(https?://)?(www\.)?(youtube|youtu|youtube-nocookie)\.(com|be)/" r"(watch\?v=|embed/|v/|.+\?v=)?([^&=%\?]{11})") if not youtube_regex.match(url): return None, "❌ Invalid YouTube URL format" if not force_reload and url in _youtube_cache: cached_path, cached_title = _youtube_cache[url] if os.path.exists(cached_path): return cached_path, f"βœ… Using cached: {truncate_title(cached_title)}" if force_reload and url in _youtube_cache: old_path, _ = _youtube_cache[url] try: if os.path.exists(old_path): import shutil temp_dir = os.path.dirname(old_path) shutil.rmtree(temp_dir) except Exception: pass del _youtube_cache[url] temp_dir = tempfile.mkdtemp() ydl_opts = { "format": "bestaudio/best", "outtmpl": os.path.join(temp_dir, "%(title)s.%(ext)s"), "postprocessors": [ { "key": "FFmpegExtractAudio", "preferredcodec": "mp3", "preferredquality": "128", } ], "noplaylist": True, } if PROXY_URL: ydl_opts["proxy"] = PROXY_URL with yt_dlp.YoutubeDL(ydl_opts) as ydl: info = ydl.extract_info(url, download=False) title = info.get("title", "Unknown") ydl.download([url]) for file in os.listdir(temp_dir): if file.endswith(".mp3"): file_path = os.path.join(temp_dir, file) _youtube_cache[url] = (file_path, title) return file_path, f"βœ… Downloaded: {truncate_title(title)}" return None, "❌ Failed to download audio file" except Exception as e: return None, f"❌ Download error: {str(e)}" @spaces.GPU def infer(audio_path, youtube_url, prompt_text): try: device = "cuda" if torch.cuda.is_available() else "cpu" model.to(device) final_audio_path = None status_message = "" if audio_path: final_audio_path = audio_path status_message = "βœ… Using audio file" elif youtube_url.strip(): final_audio_path, status_message = download_youtube_audio(youtube_url.strip()) if not final_audio_path: return status_message else: return "❌ Please either upload an audio file or provide a YouTube URL." conversations = [ [ { "role": "user", "content": [ {"type": "text", "text": prompt_text or ""}, {"type": "audio", "path": final_audio_path}, ], } ] ] # NOTE: If `conversations` includes audio, apply_chat_template() decodes via load_audio() # to MONO float32 at 16 kHz by default. We omit `sampling_rate`, so the 16k default is used. # Processor assumes mono 1-D audio; stereo would require code changes. No audio β‡’ no effect here. batch = processor.apply_chat_template( conversations, tokenize=True, add_generation_prompt=True, return_dict=True, ).to(model.device) gen_ids = model.generate(**batch, max_new_tokens=4096, repetition_penalty=1.2) inp_len = batch["input_ids"].shape[1] new_tokens = gen_ids[:, inp_len:] texts = processor.batch_decode(new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False) result = texts[0] if texts else "" return f"{status_message}\n\n{result}" except Exception as e: return f"❌ Error: {str(e)}" def load_youtube_audio(youtube_url): """Load YouTube audio into the Audio component and generate video embed.""" if not youtube_url.strip(): return None, "❌ Please enter a YouTube URL", "" embed_html = generate_youtube_embed(youtube_url.strip()) audio_path, message = download_youtube_audio(youtube_url.strip(), force_reload=True) if audio_path: return audio_path, message, embed_html else: return None, message, embed_html with gr.Blocks(css=APP_CSS, theme=gr.themes.Soft(primary_hue="purple", secondary_hue="fuchsia")) as demo: gr.HTML( f"""

{HERO_TITLE}

{HERO_SUBTITLE}

{HERO_AUTHORS} {HERO_BADGES}
""" ) with gr.Tabs(elem_classes="tab-nav"): with gr.Row(elem_classes="panel-row"): with gr.Column(elem_classes=["glass-card"]): gr.Markdown("### 🎡 Audio Input") audio_in = gr.Audio( sources=["upload", "microphone"], type="filepath", label="Upload Audio File", show_label=True, ) gr.Markdown("**OR**") youtube_url = gr.Textbox(label="YouTube URL", placeholder="https://www.youtube.com/watch?v=...", info="Paste any YouTube URL - we'll extract high-quality audio automatically") load_btn = gr.Button("πŸ”„ Load Audio", variant="secondary", size="sm") status_text = gr.Textbox(label="Status", interactive=False, visible=False) youtube_embed = gr.HTML(label="Video Preview", visible=False) prompt_in = gr.Textbox( label="Prompt", value="Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.", placeholder="Ask a question about the audio…", lines=6, ) gr.Examples( examples=EXAMPLE_YOUTUBE_PROMPTS, inputs=[youtube_url, prompt_in], label="🎡 Example Prompts", ) btn = gr.Button("Generate Answer", elem_classes="accent-button") with gr.Column(elem_classes=["glass-card"]): out = gr.Textbox( label="Model Response", lines=25, placeholder="Model answers will appear here with audio-informed insights…", ) load_btn.click(lambda: [None, "πŸ”„ Loading audio...", gr.update(visible=True)], outputs=[audio_in, status_text, status_text]).then( fn=load_youtube_audio, inputs=[youtube_url], outputs=[audio_in, status_text, youtube_embed] ).then(lambda: gr.update(visible=True), outputs=[youtube_embed]) btn.click(fn=infer, inputs=[audio_in, youtube_url, prompt_in], outputs=out) gr.HTML( """ """ ) if __name__ == "__main__": demo.launch(share=True)