""" |
|
|
Optimized MelodyFlow API for concurrent request handling on T4 GPU |
|
|
This version focuses on high-throughput API serving with batching |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
|
|
|
|
|
|
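# Pin OpenMP/BLAS backends to a single thread each: the heavy math runs on the
# GPU, and extra CPU threads would mostly oversubscribe the cores when several
# requests are processed concurrently.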
os.environ['OMP_NUM_THREADS'] = '1' |
|
|
os.environ['MKL_NUM_THREADS'] = '1' |
|
|
os.environ['NUMEXPR_NUM_THREADS'] = '1' |
|
|
os.environ['OPENBLAS_NUM_THREADS'] = '1' |
|
|
|
|
|
|
|
|
def ensure_thread_env(): |
|
|
"""Ensure threading environment variables stay set""" |
|
|
for key, value in [('OMP_NUM_THREADS', '1'), ('MKL_NUM_THREADS', '1'), |
|
|
('NUMEXPR_NUM_THREADS', '1'), ('OPENBLAS_NUM_THREADS', '1')]: |
|
|
if os.environ.get(key) != value: |
|
|
os.environ[key] = value |
|
|
print(f"Reset {key} to {value}") |
|
|
|
|
|
|
|
|
ensure_thread_env() |
|
|
|
|
|
import spaces |
|
|
import asyncio |
|
|
import threading |
|
|
import time |
|
|
import uuid |
|
|
import base64 |
|
|
import logging |
|
|
from concurrent.futures import ThreadPoolExecutor, Future, TimeoutError as FuturesTimeoutError
|
|
from queue import Queue, Empty, Full
|
|
from tempfile import NamedTemporaryFile |
|
|
from pathlib import Path |
|
|
import typing as tp |
|
|
from dataclasses import dataclass |
|
|
|
|
|
import torch |
|
|
import gradio as gr |
|
|
from audiocraft.data.audio_utils import convert_audio |
|
|
from audiocraft.data.audio import audio_read, audio_write |
|
|
from audiocraft.models import MelodyFlow |
|
|
|
|
|
|
|
|
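# Raise the CSV field size limit; presumably for Gradio's CSV-based
# example/flagging files, which can contain very long prompt entries.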
import csv |
|
|
csv.field_size_limit(1000000) |
|
|
|
|
|
|
|
|
|
|
|
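# Batching configuration:
#   BATCH_SIZE             - maximum number of requests folded into one GPU batch
#   BATCH_TIMEOUT          - seconds to wait before flushing a partial batch
#   MAX_QUEUE_SIZE         - pending requests allowed before "server busy" errors
#   MAX_CONCURRENT_BATCHES - worker threads processing batches in parallel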
MODEL_PREFIX = "facebook/" |
|
|
BATCH_SIZE = 4 |
|
|
BATCH_TIMEOUT = 1.5 |
|
|
MAX_QUEUE_SIZE = 100 |
|
|
MAX_CONCURRENT_BATCHES = 2 |
|
|
|
|
|
|
|
|
class FileCleaner: |
|
|
"""Simple file cleaner for temporary audio files""" |
|
|
def __init__(self, file_lifetime: float = 3600): |
|
|
self.file_lifetime = file_lifetime |
|
|
self.files = [] |
|
|
|
|
|
def add(self, path: tp.Union[str, Path]): |
|
|
self._cleanup() |
|
|
self.files.append((time.time(), Path(path))) |
|
|
|
|
|
    def _cleanup(self):
        now = time.time()
        # Entries are stored oldest-first, so stop at the first non-expired file.
        while self.files:
            time_added, path = self.files[0]
            if now - time_added <= self.file_lifetime:
                break
            if path.exists():
                path.unlink()
            self.files.pop(0)
|
|
|
|
|
|
|
|
|
|
|
file_cleaner = FileCleaner() |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class GenerationRequest: |
|
|
"""Represents a single generation request""" |
|
|
request_id: str |
|
|
text: str |
|
|
melody: tp.Optional[str] |
|
|
solver: str |
|
|
steps: int |
|
|
target_flowstep: float |
|
|
regularize: bool |
|
|
regularization_strength: float |
|
|
duration: float |
|
|
model: str |
|
|
future: Future |
|
|
created_at: float |
|
|
|
|
|
|
|
|
class OptimizedBatchProcessor: |
|
|
"""Highly optimized batch processor for T4 GPU""" |
|
|
|
|
|
def __init__(self): |
|
|
self.model = None |
|
|
self.model_lock = threading.Lock() |
|
|
self.request_queue = Queue(maxsize=MAX_QUEUE_SIZE) |
|
|
self.current_batch = [] |
|
|
self.batch_start_time = None |
|
|
self.processing = False |
|
|
self.stop_event = threading.Event() |
|
|
self.executor = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_BATCHES) |
|
|
|
|
|
def start(self): |
|
|
"""Start the batch processing service""" |
|
|
self.thread = threading.Thread(target=self._batch_loop, daemon=True) |
|
|
self.thread.start() |
|
|
logging.info("Batch processor started") |
|
|
|
|
|
def stop(self): |
|
|
"""Stop the batch processing service""" |
|
|
self.stop_event.set() |
|
|
self.executor.shutdown(wait=True) |
|
|
|
|
|
def submit_request(self, text: str, melody: tp.Optional[str], |
|
|
solver: str, steps: int, target_flowstep: float, |
|
|
regularize: bool, regularization_strength: float, |
|
|
duration: float, model: str) -> Future: |
|
|
"""Submit a generation request and return a future""" |
|
|
|
|
|
request = GenerationRequest( |
|
|
request_id=str(uuid.uuid4()), |
|
|
text=text, |
|
|
melody=melody, |
|
|
solver=solver, |
|
|
steps=steps, |
|
|
target_flowstep=target_flowstep, |
|
|
regularize=regularize, |
|
|
regularization_strength=regularization_strength, |
|
|
duration=duration, |
|
|
model=model, |
|
|
future=Future(), |
|
|
created_at=time.time() |
|
|
) |
|
|
|
|
|
print(f"📝 Submitting request {request.request_id} with text: '{text[:30]}...'") |
|
|
|
|
|
try: |
|
|
self.request_queue.put_nowait(request) |
|
|
print(f"✅ Request {request.request_id} queued successfully") |
|
|
return request.future |
|
|
        except Full:
|
|
|
|
|
print(f"❌ Queue full for request {request.request_id}") |
|
|
request.future.set_exception(Exception("Server is busy, please try again")) |
|
|
return request.future |
|
|
|
|
|
def _batch_loop(self): |
|
|
"""Main batch processing loop""" |
|
|
while not self.stop_event.is_set(): |
|
|
try: |
|
|
|
|
|
try: |
|
|
request = self.request_queue.get(timeout=0.1) |
|
|
self.current_batch.append(request) |
|
|
|
|
|
if self.batch_start_time is None: |
|
|
self.batch_start_time = time.time() |
|
|
|
|
|
except Empty: |
|
|
|
|
|
if self._should_process_batch(): |
|
|
self._submit_batch() |
|
|
continue |
|
|
|
|
|
|
|
|
if self._should_process_batch(): |
|
|
self._submit_batch() |
|
|
|
|
|
except Exception as e: |
|
|
logging.error(f"Error in batch loop: {e}") |
|
|
|
|
|
def _should_process_batch(self) -> bool: |
|
|
"""Determine if current batch should be processed""" |
|
|
if not self.current_batch: |
|
|
return False |
|
|
|
|
|
batch_age = time.time() - (self.batch_start_time or time.time()) |
|
|
return (len(self.current_batch) >= BATCH_SIZE or |
|
|
batch_age >= BATCH_TIMEOUT) |
|
|
|
|
|
def _submit_batch(self): |
|
|
"""Submit current batch for processing""" |
|
|
if not self.current_batch: |
|
|
return |
|
|
|
|
|
batch = self.current_batch.copy() |
|
|
self.current_batch = [] |
|
|
self.batch_start_time = None |
|
|
|
|
|
|
|
|
self.executor.submit(self._process_batch, batch) |
|
|
|
|
|
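    # On Hugging Face ZeroGPU Spaces this decorator reserves a GPU for up to
    # 60 s per batch; adjust the duration if larger batches need more time.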
@spaces.GPU(duration=60) |
|
|
def _process_batch(self, batch: tp.List[GenerationRequest]): |
|
|
"""Process a batch of requests on GPU""" |
|
|
try: |
|
|
|
|
|
ensure_thread_env() |
|
|
|
|
|
logging.info(f"Processing batch of {len(batch)} requests") |
|
|
start_time = time.time() |
|
|
|
|
|
|
|
|
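            # All requests in this batch are served by the model of the first
            # request; mixed-model batches are not supported.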
model_version = batch[0].model |
|
|
self._load_model(model_version) |
|
|
|
|
|
|
|
|
gen_requests = [req for req in batch if req.melody is None] |
|
|
edit_requests = [req for req in batch if req.melody is not None] |
|
|
|
|
|
results = {} |
|
|
|
|
|
|
|
|
if gen_requests: |
|
|
gen_results = self._process_generation_batch(gen_requests) |
|
|
results.update(gen_results) |
|
|
|
|
|
|
|
|
if edit_requests: |
|
|
edit_results = self._process_editing_batch(edit_requests) |
|
|
results.update(edit_results) |
|
|
|
|
|
|
|
|
for request in batch: |
|
|
if request.request_id in results: |
|
|
result_data = results[request.request_id] |
|
|
print(f"🔄 Setting result for request {request.request_id}: {type(result_data)}") |
|
|
request.future.set_result(result_data) |
|
|
else: |
|
|
print(f"❌ No result found for request {request.request_id}") |
|
|
request.future.set_exception(Exception("Processing failed")) |
|
|
|
|
|
processing_time = time.time() - start_time |
|
|
logging.info(f"Batch processed in {processing_time:.2f}s") |
|
|
|
|
|
except Exception as e: |
|
|
logging.error(f"Batch processing error: {e}") |
|
|
for request in batch: |
|
|
request.future.set_exception(e) |
|
|
|
|
|
def _load_model(self, version: str): |
|
|
"""Thread-safe model loading""" |
|
|
|
|
|
ensure_thread_env() |
|
|
|
|
|
with self.model_lock: |
|
|
if self.model is None or self.model.name != version: |
|
|
if self.model is not None: |
|
|
del self.model |
|
|
if torch.cuda.is_available(): |
|
|
torch.cuda.empty_cache() |
|
|
self.model = MelodyFlow.get_pretrained(version) |
|
|
logging.info(f"Model {version} loaded") |
|
|
|
|
|
def _process_generation_batch(self, requests: tp.List[GenerationRequest]) -> dict: |
|
|
"""Process generation requests in batch""" |
|
|
if not requests: |
|
|
return {} |
|
|
|
|
|
|
|
|
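        # Generation parameters come from the first request, so every request
        # batched together shares the same solver, step count and duration.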
params = requests[0] |
|
|
self.model.set_generation_params( |
|
|
solver=params.solver, |
|
|
steps=params.steps, |
|
|
duration=params.duration |
|
|
) |
|
|
|
|
|
|
|
|
texts = [req.text for req in requests] |
|
|
|
|
|
|
|
|
outputs = self.model.generate(texts, progress=False, return_tokens=False) |
|
|
outputs = outputs.detach().cpu().float() |
|
|
|
|
|
|
|
|
results = {} |
|
|
for i, request in enumerate(requests): |
|
|
audio_base64 = self._audio_to_base64(outputs[i]) |
|
|
results[request.request_id] = { |
|
|
"audio": audio_base64, |
|
|
"format": "wav" |
|
|
} |
|
|
|
|
|
return results |
|
|
|
|
|
def _process_editing_batch(self, requests: tp.List[GenerationRequest]) -> dict: |
|
|
"""Process editing requests individually""" |
|
|
results = {} |
|
|
|
|
|
for request in requests: |
|
|
try: |
|
|
self.model.set_editing_params( |
|
|
solver=request.solver, |
|
|
steps=request.steps, |
|
|
target_flowstep=request.target_flowstep, |
|
|
regularize=request.regularize, |
|
|
lambda_kl=request.regularization_strength |
|
|
) |
|
|
|
|
|
|
|
|
melody, sr = audio_read(request.melody) |
|
|
if melody.dim() == 2: |
|
|
melody = melody[None] |
|
|
if melody.shape[-1] > int(sr * self.model.duration): |
|
|
melody = melody[..., :int(sr * self.model.duration)] |
|
|
|
|
|
melody = convert_audio(melody, sr, 48000, 2) |
|
|
melody = self.model.encode_audio(melody.to(self.model.device)) |
|
|
|
|
|
|
|
|
output = self.model.edit( |
|
|
prompt_tokens=melody, |
|
|
descriptions=[request.text], |
|
|
src_descriptions=[""], |
|
|
progress=False, |
|
|
return_tokens=False |
|
|
) |
|
|
|
|
|
output = output.detach().cpu().float()[0] |
|
|
audio_base64 = self._audio_to_base64(output) |
|
|
|
|
|
results[request.request_id] = { |
|
|
"audio": audio_base64, |
|
|
"format": "wav" |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
logging.error(f"Error processing edit request {request.request_id}: {e}") |
|
|
|
|
|
|
|
|
return results |
|
|
|
|
|
def _audio_to_base64(self, audio_tensor: torch.Tensor) -> str: |
|
|
"""Convert audio tensor to base64 string""" |
|
|
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file: |
|
|
audio_write( |
|
|
file.name, audio_tensor, self.model.sample_rate, |
|
|
strategy="loudness", loudness_headroom_db=16, |
|
|
loudness_compressor=True, add_suffix=False |
|
|
) |
|
|
|
|
|
with open(file.name, 'rb') as f: |
|
|
audio_bytes = f.read() |
|
|
|
|
|
|
|
|
Path(file.name).unlink() |
|
|
|
|
|
return base64.b64encode(audio_bytes).decode('utf-8') |
|
|
|
|
|
|
|
|
|
|
|
batch_processor = OptimizedBatchProcessor() |
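# The processing thread is started from the __main__ block via batch_processor.start().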
|
|
|
|
|
|
|
|
def predict_concurrent_status(model: str, text: str, solver: str = "euler",
                              steps: int = 50, target_flowstep: float = 0.0,
                              regularize: bool = False, regularization_strength: float = 0.0,
                              duration: float = 10.0, melody: tp.Optional[str] = None) -> dict:
    """
    Diagnostic variant of predict_concurrent: non-blocking, but returns a
    status/metadata dict (including a short audio preview) rather than the
    raw audio payload.
    """
|
|
|
|
|
|
|
|
if melody is not None: |
|
|
steps = steps // 2 if solver == "midpoint" else steps // 5 |
|
|
|
|
|
|
|
|
future = batch_processor.submit_request( |
|
|
text=text, |
|
|
melody=melody, |
|
|
solver=solver, |
|
|
steps=steps, |
|
|
target_flowstep=target_flowstep, |
|
|
regularize=regularize, |
|
|
regularization_strength=regularization_strength, |
|
|
duration=duration, |
|
|
model=model |
|
|
) |
|
|
|
|
|
|
|
|
try: |
|
|
result = future.result(timeout=120) |
|
|
|
|
|
|
|
|
if isinstance(result, dict): |
|
|
print(f"✅ Received result with keys: {list(result.keys())}") |
|
|
if "audio" in result: |
|
|
audio_len = len(result["audio"]) if result["audio"] else 0 |
|
|
print(f"🎵 Audio data length: {audio_len} characters") |
|
|
|
|
|
|
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
"message": f"Audio generated successfully ({audio_len} bytes)", |
|
|
"format": result.get("format", "wav"), |
|
|
"duration": duration, |
|
|
"text_prompt": text[:50] + "..." if len(text) > 50 else text, |
|
|
|
|
|
|
|
|
"audio_preview": result["audio"][:100] + "..." if result["audio"] else "No audio data" |
|
|
} |
|
|
else: |
|
|
print("⚠️ No audio key in result") |
|
|
return {"status": "error", "message": "No audio generated"} |
|
|
else: |
|
|
print(f"⚠️ Unexpected result type: {type(result)}") |
|
|
return {"status": "error", "message": f"Unexpected result type: {type(result)}"} |
|
|
|
|
|
    except FuturesTimeoutError:
|
|
print("⏰ Request timeout") |
|
|
raise gr.Error("Request timeout - server is overloaded") |
|
|
except Exception as e: |
|
|
print(f"💥 Exception: {str(e)}") |
|
|
raise gr.Error(f"Generation failed: {str(e)}") |
|
|
|
|
|
|
|
|
def predict_concurrent_ui(model: str, text: str, solver: str = "euler", |
|
|
steps: int = 50, target_flowstep: float = 0.0, |
|
|
regularize: bool = False, regularization_strength: float = 0.0, |
|
|
duration: float = 10.0, melody: tp.Optional[str] = None) -> str: |
|
|
""" |
|
|
UI-optimized predict function that returns audio file path for Gradio Audio component |
|
|
""" |
|
|
|
|
|
|
|
|
if melody is not None: |
|
|
steps = steps // 2 if solver == "midpoint" else steps // 5 |
|
|
|
|
|
|
|
|
future = batch_processor.submit_request( |
|
|
text=text, |
|
|
melody=melody, |
|
|
solver=solver, |
|
|
steps=steps, |
|
|
target_flowstep=target_flowstep, |
|
|
regularize=regularize, |
|
|
regularization_strength=regularization_strength, |
|
|
duration=duration, |
|
|
model=model |
|
|
) |
|
|
|
|
|
|
|
|
try: |
|
|
result = future.result(timeout=120) |
|
|
|
|
|
|
|
|
if isinstance(result, dict) and "audio" in result: |
|
|
print(f"✅ Received audio result, converting to file...") |
|
|
|
|
|
|
|
|
|
|
|
|
|
audio_data = base64.b64decode(result["audio"]) |
|
|
with NamedTemporaryFile(mode="wb", suffix=".wav", delete=False) as temp_file: |
|
|
temp_file.write(audio_data) |
|
|
temp_file_path = temp_file.name |
|
|
|
|
|
file_cleaner.add(temp_file_path) |
|
|
print(f"🎵 Audio saved to: {temp_file_path}") |
|
|
return temp_file_path |
|
|
else: |
|
|
raise gr.Error("No audio data received") |
|
|
|
|
|
    except FuturesTimeoutError:
|
|
print("⏰ Request timeout") |
|
|
raise gr.Error("Request timeout - server is overloaded") |
|
|
except Exception as e: |
|
|
print(f"💥 Exception: {str(e)}") |
|
|
raise gr.Error(f"Generation failed: {str(e)}") |
|
|
|
|
|
|
|
|
def predict_concurrent(model: str, text: str, solver: str = "euler", |
|
|
steps: int = 50, target_flowstep: float = 0.0, |
|
|
regularize: bool = False, regularization_strength: float = 0.0, |
|
|
duration: float = 10.0, melody: tp.Optional[str] = None) -> dict: |
|
|
""" |
|
|
API predict function that returns base64 audio data (for API endpoints) |
|
|
""" |
|
|
|
|
|
|
|
|
if melody is not None: |
|
|
steps = steps // 2 if solver == "midpoint" else steps // 5 |
|
|
|
|
|
|
|
|
future = batch_processor.submit_request( |
|
|
text=text, |
|
|
melody=melody, |
|
|
solver=solver, |
|
|
steps=steps, |
|
|
target_flowstep=target_flowstep, |
|
|
regularize=regularize, |
|
|
regularization_strength=regularization_strength, |
|
|
duration=duration, |
|
|
model=model |
|
|
) |
|
|
|
|
|
|
|
|
try: |
|
|
result = future.result(timeout=120) |
|
|
return result |
|
|
    except FuturesTimeoutError:
|
|
raise gr.Error("Request timeout - server is overloaded") |
|
|
except Exception as e: |
|
|
raise gr.Error(f"Generation failed: {str(e)}") |
|
|
|
|
|
|
|
|
def create_optimized_interface(): |
|
|
"""Create Gradio interface optimized for concurrent usage""" |
|
|
|
|
|
with gr.Blocks(title="MelodyFlow - Concurrent API") as interface: |
|
|
gr.Markdown(""" |
|
|
# MelodyFlow - Optimized for Concurrent Requests |
|
|
|
|
|
This version is optimized for handling multiple concurrent requests efficiently. |
|
|
Requests are automatically batched for optimal GPU utilization. |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
text = gr.Text(label="Text Description", placeholder="Describe the music you want to generate...") |
|
|
melody = gr.Audio(label="Reference Audio (optional)", type="filepath") |
|
|
|
|
|
with gr.Row(): |
|
|
solver = gr.Radio(["euler", "midpoint"], label="Solver", value="euler") |
|
|
steps = gr.Slider(1, 128, value=50, label="Steps") |
|
|
|
|
|
with gr.Row(): |
|
|
duration = gr.Slider(1, 30, value=10, label="Duration (s)") |
|
|
model = gr.Dropdown( |
|
|
[f"{MODEL_PREFIX}melodyflow-t24-30secs"], |
|
|
value=f"{MODEL_PREFIX}melodyflow-t24-30secs", |
|
|
label="Model" |
|
|
) |
|
|
|
|
|
generate_btn = gr.Button("Generate", variant="primary") |
|
|
|
|
|
with gr.Column(): |
|
|
output = gr.Audio(label="Generated Audio") |
|
|
|
|
|
generate_btn.click( |
|
|
fn=predict_concurrent_ui, |
|
|
inputs=[model, text, solver, steps, gr.State(0.0), |
|
|
gr.State(False), gr.State(0.0), duration, melody], |
|
|
outputs=output, |
|
|
concurrency_limit=20 |
|
|
) |
|
|
|
|
|
gr.Examples( |
|
|
fn=predict_concurrent_ui, |
|
|
examples=[ |
|
|
[f"{MODEL_PREFIX}melodyflow-t24-30secs", |
|
|
"80s electronic track with melodic synthesizers", |
|
|
"euler", 50, 0.0, False, 0.0, 10.0, None], |
|
|
[f"{MODEL_PREFIX}melodyflow-t24-30secs", |
|
|
"Cheerful country song with acoustic guitars", |
|
|
"euler", 50, 0.0, False, 0.0, 15.0, None] |
|
|
], |
|
|
inputs=[model, text, solver, steps, gr.State(0.0), |
|
|
gr.State(False), gr.State(0.0), duration, melody], |
|
|
outputs=output, |
|
|
cache_examples=False |
|
|
) |
|
|
|
|
|
return interface |
|
|
|
|
|
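# Remote usage sketch (assumptions: the app is reachable at http://localhost:7860
# and Gradio exposes the click handler under the default api_name
# "/predict_concurrent_ui"; confirm the exact endpoint and argument list on the
# running app's "Use via API" page, since some inputs are gr.State components):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   wav_path = client.predict(
#       "facebook/melodyflow-t24-30secs",                      # model
#       "80s electronic track with melodic synthesizers",      # text
#       "euler",                                               # solver
#       50,                                                    # steps
#       10.0,                                                  # duration (s)
#       None,                                                  # melody (optional reference audio)
#       api_name="/predict_concurrent_ui",
#   )
#   # wav_path is a local copy of the generated .wav file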
|
|
|
if __name__ == "__main__": |
|
|
import argparse |
|
|
|
|
|
parser = argparse.ArgumentParser() |
|
|
parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") |
|
|
parser.add_argument("--port", type=int, default=7860, help="Port to bind to") |
|
|
parser.add_argument("--share", action="store_true", help="Create public link") |
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
|
level=logging.INFO, |
|
|
format='%(asctime)s - %(levelname)s - %(message)s' |
|
|
) |
|
|
|
|
|
|
|
|
ensure_thread_env() |
|
|
|
|
|
|
|
|
batch_processor.start() |
|
|
|
|
|
|
|
|
interface = create_optimized_interface() |
|
|
|
|
|
try: |
|
|
interface.queue( |
|
|
max_size=200, |
|
|
api_open=True |
|
|
).launch( |
|
|
server_name=args.host, |
|
|
server_port=args.port, |
|
|
share=args.share, |
|
|
show_api=True, |
|
|
max_threads=40 |
|
|
) |
|
|
finally: |
|
|
batch_processor.stop() |