# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. """ Optimized MelodyFlow API for concurrent request handling on T4 GPU This version focuses on high-throughput API serving with batching """ import os import sys # Fix OpenMP threading issues - ensure they're set early and correctly os.environ['OMP_NUM_THREADS'] = '1' os.environ['MKL_NUM_THREADS'] = '1' os.environ['NUMEXPR_NUM_THREADS'] = '1' os.environ['OPENBLAS_NUM_THREADS'] = '1' # Additional protection against environment variable corruption def ensure_thread_env(): """Ensure threading environment variables stay set""" for key, value in [('OMP_NUM_THREADS', '1'), ('MKL_NUM_THREADS', '1'), ('NUMEXPR_NUM_THREADS', '1'), ('OPENBLAS_NUM_THREADS', '1')]: if os.environ.get(key) != value: os.environ[key] = value print(f"Reset {key} to {value}") # Call it immediately ensure_thread_env() import spaces import asyncio import threading import time import uuid import base64 import logging from concurrent.futures import ThreadPoolExecutor, Future from queue import Queue, Empty from tempfile import NamedTemporaryFile from pathlib import Path import typing as tp from dataclasses import dataclass import torch import gradio as gr from audiocraft.data.audio_utils import convert_audio from audiocraft.data.audio import audio_read, audio_write from audiocraft.models import MelodyFlow # Fix CSV field size limit for large audio data import csv csv.field_size_limit(1000000) # Increase field size limit # Configuration MODEL_PREFIX = "facebook/" BATCH_SIZE = 4 # Optimal for T4 GPU memory BATCH_TIMEOUT = 1.5 # Seconds to wait for batch formation MAX_QUEUE_SIZE = 100 MAX_CONCURRENT_BATCHES = 2 # Number of concurrent batch processors class FileCleaner: """Simple file cleaner for temporary audio files""" def __init__(self, file_lifetime: float = 3600): self.file_lifetime = file_lifetime self.files = [] def add(self, path: tp.Union[str, Path]): self._cleanup() self.files.append((time.time(), Path(path))) def _cleanup(self): now = time.time() for time_added, path in list(self.files): if now - time_added > self.file_lifetime: if path.exists(): path.unlink() self.files.pop(0) else: break # Global file cleaner file_cleaner = FileCleaner() @dataclass class GenerationRequest: """Represents a single generation request""" request_id: str text: str melody: tp.Optional[str] solver: str steps: int target_flowstep: float regularize: bool regularization_strength: float duration: float model: str future: Future created_at: float class OptimizedBatchProcessor: """Highly optimized batch processor for T4 GPU""" def __init__(self): self.model = None self.model_lock = threading.Lock() self.request_queue = Queue(maxsize=MAX_QUEUE_SIZE) self.current_batch = [] self.batch_start_time = None self.processing = False self.stop_event = threading.Event() self.executor = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_BATCHES) def start(self): """Start the batch processing service""" self.thread = threading.Thread(target=self._batch_loop, daemon=True) self.thread.start() logging.info("Batch processor started") def stop(self): """Stop the batch processing service""" self.stop_event.set() self.executor.shutdown(wait=True) def submit_request(self, text: str, melody: tp.Optional[str], solver: str, steps: int, target_flowstep: float, regularize: bool, regularization_strength: float, duration: float, model: str) -> Future: """Submit a generation request and return a future""" request = GenerationRequest( request_id=str(uuid.uuid4()), text=text, melody=melody, solver=solver, steps=steps, target_flowstep=target_flowstep, regularize=regularize, regularization_strength=regularization_strength, duration=duration, model=model, future=Future(), created_at=time.time() ) print(f"📝 Submitting request {request.request_id} with text: '{text[:30]}...'") try: self.request_queue.put_nowait(request) print(f"✅ Request {request.request_id} queued successfully") return request.future except: # Queue is full print(f"❌ Queue full for request {request.request_id}") request.future.set_exception(Exception("Server is busy, please try again")) return request.future def _batch_loop(self): """Main batch processing loop""" while not self.stop_event.is_set(): try: # Try to get a request try: request = self.request_queue.get(timeout=0.1) self.current_batch.append(request) if self.batch_start_time is None: self.batch_start_time = time.time() except Empty: # No new requests, check if we should process current batch if self._should_process_batch(): self._submit_batch() continue # Check if we should process the batch if self._should_process_batch(): self._submit_batch() except Exception as e: logging.error(f"Error in batch loop: {e}") def _should_process_batch(self) -> bool: """Determine if current batch should be processed""" if not self.current_batch: return False batch_age = time.time() - (self.batch_start_time or time.time()) return (len(self.current_batch) >= BATCH_SIZE or batch_age >= BATCH_TIMEOUT) def _submit_batch(self): """Submit current batch for processing""" if not self.current_batch: return batch = self.current_batch.copy() self.current_batch = [] self.batch_start_time = None # Submit to thread pool self.executor.submit(self._process_batch, batch) @spaces.GPU(duration=60) # Longer duration for batch processing def _process_batch(self, batch: tp.List[GenerationRequest]): """Process a batch of requests on GPU""" try: # Ensure environment variables are still set before processing ensure_thread_env() logging.info(f"Processing batch of {len(batch)} requests") start_time = time.time() # Load model (assume all requests use same model for simplicity) model_version = batch[0].model self._load_model(model_version) # Separate generation vs editing requests gen_requests = [req for req in batch if req.melody is None] edit_requests = [req for req in batch if req.melody is not None] results = {} # Process generation requests in batch if gen_requests: gen_results = self._process_generation_batch(gen_requests) results.update(gen_results) # Process editing requests individually (due to melody constraints) if edit_requests: edit_results = self._process_editing_batch(edit_requests) results.update(edit_results) # Set results for all requests for request in batch: if request.request_id in results: result_data = results[request.request_id] print(f"🔄 Setting result for request {request.request_id}: {type(result_data)}") request.future.set_result(result_data) else: print(f"❌ No result found for request {request.request_id}") request.future.set_exception(Exception("Processing failed")) processing_time = time.time() - start_time logging.info(f"Batch processed in {processing_time:.2f}s") except Exception as e: logging.error(f"Batch processing error: {e}") for request in batch: request.future.set_exception(e) def _load_model(self, version: str): """Thread-safe model loading""" # Ensure environment variables are still set ensure_thread_env() with self.model_lock: if self.model is None or self.model.name != version: if self.model is not None: del self.model if torch.cuda.is_available(): torch.cuda.empty_cache() self.model = MelodyFlow.get_pretrained(version) logging.info(f"Model {version} loaded") def _process_generation_batch(self, requests: tp.List[GenerationRequest]) -> dict: """Process generation requests in batch""" if not requests: return {} # Use parameters from first request (assuming similar params for batch) params = requests[0] self.model.set_generation_params( solver=params.solver, steps=params.steps, duration=params.duration ) # Extract texts texts = [req.text for req in requests] # Generate outputs = self.model.generate(texts, progress=False, return_tokens=False) outputs = outputs.detach().cpu().float() # Create results results = {} for i, request in enumerate(requests): audio_base64 = self._audio_to_base64(outputs[i]) results[request.request_id] = { "audio": audio_base64, "format": "wav" } return results def _process_editing_batch(self, requests: tp.List[GenerationRequest]) -> dict: """Process editing requests individually""" results = {} for request in requests: try: self.model.set_editing_params( solver=request.solver, steps=request.steps, target_flowstep=request.target_flowstep, regularize=request.regularize, lambda_kl=request.regularization_strength ) # Process melody melody, sr = audio_read(request.melody) if melody.dim() == 2: melody = melody[None] if melody.shape[-1] > int(sr * self.model.duration): melody = melody[..., :int(sr * self.model.duration)] melody = convert_audio(melody, sr, 48000, 2) melody = self.model.encode_audio(melody.to(self.model.device)) # Edit output = self.model.edit( prompt_tokens=melody, descriptions=[request.text], src_descriptions=[""], progress=False, return_tokens=False ) output = output.detach().cpu().float()[0] audio_base64 = self._audio_to_base64(output) results[request.request_id] = { "audio": audio_base64, "format": "wav" } except Exception as e: logging.error(f"Error processing edit request {request.request_id}: {e}") # Will be handled by batch processor return results def _audio_to_base64(self, audio_tensor: torch.Tensor) -> str: """Convert audio tensor to base64 string""" with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file: audio_write( file.name, audio_tensor, self.model.sample_rate, strategy="loudness", loudness_headroom_db=16, loudness_compressor=True, add_suffix=False ) with open(file.name, 'rb') as f: audio_bytes = f.read() # Clean up temp file Path(file.name).unlink() return base64.b64encode(audio_bytes).decode('utf-8') # Global batch processor batch_processor = OptimizedBatchProcessor() def predict_concurrent(model: str, text: str, solver: str = "euler", steps: int = 50, target_flowstep: float = 0.0, regularize: bool = False, regularization_strength: float = 0.0, duration: float = 10.0, melody: tp.Optional[str] = None) -> dict: """ Non-blocking predict function optimized for concurrent requests """ # Adjust steps for melody editing if melody is not None: steps = steps // 2 if solver == "midpoint" else steps // 5 # Submit request to batch processor future = batch_processor.submit_request( text=text, melody=melody, solver=solver, steps=steps, target_flowstep=target_flowstep, regularize=regularize, regularization_strength=regularization_strength, duration=duration, model=model ) # Wait for result with timeout try: result = future.result(timeout=120) # 2 minute timeout # Add some debugging if isinstance(result, dict): print(f"✅ Received result with keys: {list(result.keys())}") if "audio" in result: audio_len = len(result["audio"]) if result["audio"] else 0 print(f"🎵 Audio data length: {audio_len} characters") # Return a summary instead of the full base64 for testing # This will help determine if the issue is with large data return { "status": "success", "message": f"Audio generated successfully ({audio_len} bytes)", "format": result.get("format", "wav"), "duration": duration, "text_prompt": text[:50] + "..." if len(text) > 50 else text, # Uncomment the line below to return full audio data: # "audio": result["audio"], "audio_preview": result["audio"][:100] + "..." if result["audio"] else "No audio data" } else: print("⚠️ No audio key in result") return {"status": "error", "message": "No audio generated"} else: print(f"⚠️ Unexpected result type: {type(result)}") return {"status": "error", "message": f"Unexpected result type: {type(result)}"} except TimeoutError: print("⏰ Request timeout") raise gr.Error("Request timeout - server is overloaded") except Exception as e: print(f"💥 Exception: {str(e)}") raise gr.Error(f"Generation failed: {str(e)}") def predict_concurrent_ui(model: str, text: str, solver: str = "euler", steps: int = 50, target_flowstep: float = 0.0, regularize: bool = False, regularization_strength: float = 0.0, duration: float = 10.0, melody: tp.Optional[str] = None) -> str: """ UI-optimized predict function that returns audio file path for Gradio Audio component """ # Adjust steps for melody editing if melody is not None: steps = steps // 2 if solver == "midpoint" else steps // 5 # Submit request to batch processor future = batch_processor.submit_request( text=text, melody=melody, solver=solver, steps=steps, target_flowstep=target_flowstep, regularize=regularize, regularization_strength=regularization_strength, duration=duration, model=model ) # Wait for result with timeout try: result = future.result(timeout=120) # 2 minute timeout # Convert base64 result to audio file for UI if isinstance(result, dict) and "audio" in result: print(f"✅ Received audio result, converting to file...") # Decode base64 and save to temporary file import base64 from tempfile import NamedTemporaryFile audio_data = base64.b64decode(result["audio"]) with NamedTemporaryFile(mode="wb", suffix=".wav", delete=False) as temp_file: temp_file.write(audio_data) temp_file_path = temp_file.name file_cleaner.add(temp_file_path) # Add to cleanup queue print(f"🎵 Audio saved to: {temp_file_path}") return temp_file_path else: raise gr.Error("No audio data received") except TimeoutError: print("⏰ Request timeout") raise gr.Error("Request timeout - server is overloaded") except Exception as e: print(f"💥 Exception: {str(e)}") raise gr.Error(f"Generation failed: {str(e)}") def predict_concurrent(model: str, text: str, solver: str = "euler", steps: int = 50, target_flowstep: float = 0.0, regularize: bool = False, regularization_strength: float = 0.0, duration: float = 10.0, melody: tp.Optional[str] = None) -> dict: """ API predict function that returns base64 audio data (for API endpoints) """ # Adjust steps for melody editing if melody is not None: steps = steps // 2 if solver == "midpoint" else steps // 5 # Submit request to batch processor future = batch_processor.submit_request( text=text, melody=melody, solver=solver, steps=steps, target_flowstep=target_flowstep, regularize=regularize, regularization_strength=regularization_strength, duration=duration, model=model ) # Wait for result with timeout try: result = future.result(timeout=120) # 2 minute timeout return result except TimeoutError: raise gr.Error("Request timeout - server is overloaded") except Exception as e: raise gr.Error(f"Generation failed: {str(e)}") def create_optimized_interface(): """Create Gradio interface optimized for concurrent usage""" with gr.Blocks(title="MelodyFlow - Concurrent API") as interface: gr.Markdown(""" # MelodyFlow - Optimized for Concurrent Requests This version is optimized for handling multiple concurrent requests efficiently. Requests are automatically batched for optimal GPU utilization. """) with gr.Row(): with gr.Column(): text = gr.Text(label="Text Description", placeholder="Describe the music you want to generate...") melody = gr.Audio(label="Reference Audio (optional)", type="filepath") with gr.Row(): solver = gr.Radio(["euler", "midpoint"], label="Solver", value="euler") steps = gr.Slider(1, 128, value=50, label="Steps") with gr.Row(): duration = gr.Slider(1, 30, value=10, label="Duration (s)") model = gr.Dropdown( [f"{MODEL_PREFIX}melodyflow-t24-30secs"], value=f"{MODEL_PREFIX}melodyflow-t24-30secs", label="Model" ) generate_btn = gr.Button("Generate", variant="primary") with gr.Column(): output = gr.Audio(label="Generated Audio") generate_btn.click( fn=predict_concurrent_ui, inputs=[model, text, solver, steps, gr.State(0.0), gr.State(False), gr.State(0.0), duration, melody], outputs=output, concurrency_limit=20 # Set concurrency limit on the event listener ) gr.Examples( fn=predict_concurrent_ui, examples=[ [f"{MODEL_PREFIX}melodyflow-t24-30secs", "80s electronic track with melodic synthesizers", "euler", 50, 0.0, False, 0.0, 10.0, None], [f"{MODEL_PREFIX}melodyflow-t24-30secs", "Cheerful country song with acoustic guitars", "euler", 50, 0.0, False, 0.0, 15.0, None] ], inputs=[model, text, solver, steps, gr.State(0.0), gr.State(False), gr.State(0.0), duration, melody], outputs=output, cache_examples=False # Disable caching to avoid CSV field size errors ) return interface if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") parser.add_argument("--port", type=int, default=7860, help="Port to bind to") parser.add_argument("--share", action="store_true", help="Create public link") args = parser.parse_args() # Setup logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) # Ensure environment variables one more time before starting ensure_thread_env() # Start batch processor batch_processor.start() # Create and launch interface interface = create_optimized_interface() try: interface.queue( max_size=200, # Large queue api_open=True ).launch( server_name=args.host, server_port=args.port, share=args.share, show_api=True, max_threads=40 # Configure worker threads in launch() ) finally: batch_processor.stop()