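"""Gradio Space: voice-to-text with optional speaker diarization.

Transcription runs locally with faster-whisper; diarization uses the NVIDIA NeMo
streaming Sortformer model. Heavy models are loaded lazily as process-wide
singletons, and configuration is read from environment variables.
"""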
from __future__ import annotations
import os
import tempfile
from typing import Dict, List, Optional, Any
import gradio as gr
import torch
from dotenv import load_dotenv
from faster_whisper import WhisperModel
from huggingface_hub import snapshot_download
# Reduce NeMo import verbosity in environments lacking certain torch.distributed symbols
os.environ.setdefault("NEMO_LOG_LEVEL", "ERROR")
# Import NeMo for NVIDIA Sortformer diarization
try:
from nemo.collections.asr.models import SortformerEncLabelModel
NEMO_AVAILABLE = True
except ImportError:
NEMO_AVAILABLE = False
SortformerEncLabelModel = None
load_dotenv()
# ---------------------------------------------------------------------------
# Configuration via environment variables (override inside HF Space settings)
# ---------------------------------------------------------------------------
# Whisper model: use same model names as Django app (tiny, base, small, medium, large-v3)
# faster-whisper will download these automatically from Hugging Face on first run
WHISPER_MODEL_SIZE = os.environ.get("WHISPER_MODEL_SIZE", "large-v3")
# Prefer GPU on Hugging Face Spaces if available, but allow override via env
def _default_device() -> str:
try:
return "cuda" if torch.cuda.is_available() else "cpu"
except Exception:
return "cpu"
WHISPER_DEVICE = os.environ.get("WHISPER_DEVICE") or _default_device()
# Choose a sensible default compute type based on device (can be overridden by env)
# - GPU: float16 is fastest and fits T4 for small/medium; use int8_float16 to save VRAM for large-v3
# - CPU: int8_float32 works well
WHISPER_COMPUTE_TYPE = os.environ.get("WHISPER_COMPUTE_TYPE") or (
"float16" if WHISPER_DEVICE == "cuda" else "int8_float32"
)
# Diarization: NVIDIA NeMo Sortformer model
DIARIZATION_MODEL_NAME = os.environ.get(
"DIARIZATION_MODEL_NAME", "nvidia/diar_streaming_sortformer_4spk-v2"
)
# Diarization streaming configuration (using "high latency" preset for better accuracy)
# See https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 for other configs
CHUNK_SIZE = int(os.environ.get("DIAR_CHUNK_SIZE", 124)) # high latency: 124 frames
RIGHT_CONTEXT = int(os.environ.get("DIAR_RIGHT_CONTEXT", 1))
FIFO_SIZE = int(os.environ.get("DIAR_FIFO_SIZE", 124))
UPDATE_PERIOD = int(os.environ.get("DIAR_UPDATE_PERIOD", 124))
SPEAKER_CACHE_SIZE = int(os.environ.get("DIAR_CACHE_SIZE", 188))
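# At roughly 80 ms per model output frame, a 124-frame chunk spans about 10 seconds of
# audio, which is the ~10 s latency the "high latency" preset trades for accuracy.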
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
# Default prompt/decoding parameters (overridable via environment variables)
default_language = os.environ.get("DEFAULT_LANGUAGE", "ar")
initial_prompt = os.environ.get(
"INITIAL_PROMPT",
"مكالمة خدمة عملاء باللهجة السعودية. كلمات شائعة: أبشر، يعطيك العافية، رقم الطلب، الشحنة، الدفع، الفاتورة، التسليم.",
)
beam_size_default = int(os.environ.get("WHISPER_BEAM_SIZE", 5))
best_of_default = int(os.environ.get("WHISPER_BEST_OF", 5))
expected_speakers_default = int(os.environ.get("EXPECTED_SPEAKERS", 2))
# ---------------------------------------------------------------------------
# Lazy singletons for the heavy models
# ---------------------------------------------------------------------------
_whisper_model: Optional[WhisperModel] = None
_diarization_model: Optional[Any] = None
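# Optional helper for pre-fetching a model snapshot from the Hugging Face Hub
# (not used by the default flow below, where models are downloaded on first load).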
def _ensure_snapshot(repo_id: str, local_dir: str, allow_patterns: Optional[List[str]] = None) -> str:
"""Download model snapshot locally if missing."""
if os.path.isdir(local_dir) and any(os.scandir(local_dir)):
return local_dir
os.makedirs(local_dir, exist_ok=True)
snapshot_download(
repo_id=repo_id,
local_dir=local_dir,
local_dir_use_symlinks=False,
token=HF_TOKEN,
allow_patterns=allow_patterns,
)
return local_dir
def _load_whisper_model() -> WhisperModel:
"""Load Faster Whisper model lazily (singleton) - same approach as Django app"""
global _whisper_model
if _whisper_model is None:
print(f"Loading Faster Whisper model: {WHISPER_MODEL_SIZE} on {WHISPER_DEVICE} with compute_type={WHISPER_COMPUTE_TYPE}")
# Load model by name - faster-whisper downloads automatically from HuggingFace
# This is the same approach used in the Django app
_whisper_model = WhisperModel(
WHISPER_MODEL_SIZE, # Model name: tiny, base, small, medium, large-v3
device=WHISPER_DEVICE,
compute_type=WHISPER_COMPUTE_TYPE,
)
return _whisper_model
def _load_diarization_model() -> Optional[Any]:
"""Load NVIDIA NeMo Sortformer diarization model lazily (singleton)"""
global _diarization_model
if _diarization_model is None:
if not NEMO_AVAILABLE:
raise gr.Error(
"NeMo is not installed. Please install it with: pip install nemo_toolkit[asr]"
)
print(f"Loading NVIDIA Sortformer diarization model: {DIARIZATION_MODEL_NAME}...")
try:
# Load model directly from Hugging Face
_diarization_model = SortformerEncLabelModel.from_pretrained(
DIARIZATION_MODEL_NAME
)
# Switch to evaluation mode
_diarization_model.eval()
# Move to GPU if available on Spaces
if torch.cuda.is_available():
try:
_diarization_model.to("cuda")
print("[DEBUG] Moved Sortformer model to CUDA device")
except Exception:
# Fallback for modules exposing .cuda()
try:
_diarization_model.cuda()
print("[DEBUG] Moved Sortformer model to CUDA via .cuda()")
except Exception as _e:
print(f"[WARN] Could not move Sortformer model to GPU: {_e}")
# Configure streaming parameters (high latency preset for better accuracy)
# See: https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2#setting-up-streaming-configuration
_diarization_model.sortformer_modules.chunk_len = CHUNK_SIZE
_diarization_model.sortformer_modules.chunk_right_context = RIGHT_CONTEXT
_diarization_model.sortformer_modules.fifo_len = FIFO_SIZE
_diarization_model.sortformer_modules.spkcache_update_period = UPDATE_PERIOD
_diarization_model.sortformer_modules.spkcache_len = SPEAKER_CACHE_SIZE
_diarization_model.sortformer_modules._check_streaming_parameters()
print("Sortformer model loaded successfully!")
except Exception as e:
raise gr.Error(
f"Failed to load NVIDIA Sortformer diarization model.\n\n"
f"Error details: {e}\n\n"
"Solutions:\n"
" - Make sure NeMo is properly installed: pip install nemo_toolkit[asr]\n"
" - Check that you have internet access to download from Hugging Face.\n"
" - The model will be downloaded automatically on first use (~700MB)."
)
if _diarization_model is None:
raise gr.Error(
"Diarization model returned None after loading."
)
return _diarization_model
# ---------------------------------------------------------------------------
# Core inference function
# ---------------------------------------------------------------------------
def transcribe(
audio_path: str,
language: str,
enable_diarization: bool,
expected_speakers: int,
beam_size: int,
best_of: int,
) -> tuple[str, Dict]:
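    """Run faster-whisper transcription and, optionally, Sortformer diarization.

    Returns a (display_text, response) pair: display_text is the plain transcript
    (or a speaker-labelled conversation when diarization is enabled) and response
    is a JSON-serializable dict with text, language, duration, segments, and speakers.
    """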
if audio_path is None:
raise gr.Error("Please upload an audio file.")
model = _load_whisper_model()
# Decide whether to pass the initial prompt to the model.
# Users reported the large model sometimes echoes the prompt in output.
# Only pass initial_prompt for non-large models by default.
    # Treat any model name that starts with "large" as a large model and skip the prompt.
    use_initial_prompt = not (
        isinstance(WHISPER_MODEL_SIZE, str) and WHISPER_MODEL_SIZE.startswith("large")
    )
initial_prompt_param = initial_prompt if use_initial_prompt else None
# Transcription parameters matching Django app configuration
try:
segments, info = model.transcribe(
audio_path,
language=language if language else None,
beam_size=beam_size,
best_of=best_of,
temperature=[0.0, 0.2, 0.4, 0.6], # Matching Django app
vad_filter=True,
vad_parameters=dict(
min_silence_duration_ms=300, # Split sooner on short pauses
speech_pad_ms=120
),
condition_on_previous_text=False, # KEY: stop cross-segment repetition
initial_prompt=initial_prompt_param,
compression_ratio_threshold=2.4,
log_prob_threshold=-1.0,
no_speech_threshold=0.45,
word_timestamps=True,
)
except Exception as e:
msg = str(e)
# Provide a clearer message for common decode failures (e.g., HTML instead of audio)
if "Invalid data found when processing input" in msg or "av.error" in msg.lower():
raise gr.Error(
"Could not decode the audio input.\n\n"
"Likely causes:\n"
" - The provided URL is not a direct audio download (e.g., a Google Drive 'view' page).\n"
" - The file is corrupted or zero-length.\n\n"
"Fixes:\n"
" - Use a direct download URL (for Drive: https://drive.google.com/uc?export=download&id=<FILE_ID>).\n"
" - Or first POST /upload, then pass the returned file pointer to the API."
)
raise gr.Error(f"Transcription failed: {msg}")
# Convert generator to list so we can iterate multiple times
segments_list = list(segments)
transcript_text = "".join(segment.text for segment in segments_list).strip()
# Defensive: if the model echoed the initial prompt at the start of the
# transcript, remove it. This handles cases where a prompt sneaks into
# the returned text (observed for some model sizes).
if initial_prompt and transcript_text.startswith(initial_prompt):
# remove the prompt and any leading punctuation/whitespace after it
transcript_text = transcript_text[len(initial_prompt) :].lstrip("\n\r\t .،,؛:-")
segments_payload: List[Dict] = []
for segment in segments_list:
segments_payload.append(
{
"start": segment.start,
"end": segment.end,
"text": segment.text,
"words": [
{
"start": word.start,
"end": word.end,
"word": word.word,
"probability": word.probability,
}
for word in (segment.words or [])
],
}
)
print(f"[DEBUG] Created {len(segments_payload)} transcript segments")
response: Dict[str, object] = {
"text": transcript_text,
"language": info.language,
"language_probability": getattr(info, "language_probability", None),
"duration": info.duration,
"segments": segments_payload,
}
if enable_diarization:
diar_model = _load_diarization_model()
# Run diarization using NeMo Sortformer
# Returns list of lists: [[start_sec, end_sec, speaker_id], ...]
try:
predicted_segments = diar_model.diarize(audio=audio_path, batch_size=1)
print(f"[DEBUG] Diarization raw output type: {type(predicted_segments)}")
print(f"[DEBUG] Diarization raw output: {predicted_segments}")
except Exception as e:
raise gr.Error(f"Diarization failed: {e}")
# Parse NeMo output format
speaker_turns: List[Dict[str, float]] = []
if predicted_segments and len(predicted_segments) > 0:
# Get the first (and only) result for single file input
diar_result = predicted_segments[0]
print(f"[DEBUG] Diar result type: {type(diar_result)}, content: {diar_result}")
# NeMo diarize() returns a list where each element is a string formatted as:
# "start_time end_time speaker_label" (space-separated)
if isinstance(diar_result, str):
# Single string with newline-separated segments
lines = diar_result.strip().split('\n')
for line in lines:
parts = line.strip().split()
if len(parts) >= 3:
start_time = float(parts[0])
end_time = float(parts[1])
speaker_label = parts[2]
speaker_turns.append({
"start": start_time,
"end": end_time,
"speaker": speaker_label,
})
elif hasattr(diar_result, '__iter__') and not isinstance(diar_result, (str, bytes)):
# List/array of segments
for item in diar_result:
# Each item might be a string "start end speaker" or a list/array [start, end, speaker]
if isinstance(item, str):
parts = item.strip().split()
if len(parts) >= 3:
start_time = float(parts[0])
end_time = float(parts[1])
speaker_label = parts[2]
speaker_turns.append({
"start": start_time,
"end": end_time,
"speaker": speaker_label,
})
elif hasattr(item, '__getitem__'):
# Array-like: [start, end, speaker]
if hasattr(item, 'tolist'):
item = item.tolist()
if len(item) >= 3:
start_time = float(item[0])
end_time = float(item[1])
# Speaker could be numeric or string
speaker_label = str(item[2]) if not isinstance(item[2], str) else item[2]
# Ensure proper speaker format
if not speaker_label.startswith("SPEAKER_"):
speaker_label = f"SPEAKER_{speaker_label}"
speaker_turns.append({
"start": start_time,
"end": end_time,
"speaker": speaker_label,
})
print(f"[DEBUG] Parsed {len(speaker_turns)} speaker turns")
if speaker_turns:
print(f"[DEBUG] First few turns: {speaker_turns[:5]}")
# Sort speaker turns by start time
speaker_turns.sort(key=lambda x: x["start"])
# Consolidate speakers if we detected more than expected
unique_speakers = set(turn["speaker"] for turn in speaker_turns)
print(f"[DEBUG] Detected {len(unique_speakers)} unique speakers: {unique_speakers}")
if expected_speakers > 0 and len(unique_speakers) > expected_speakers:
print(f"[DEBUG] Consolidating from {len(unique_speakers)} to {expected_speakers} speakers")
            # Create a mapping to merge speakers.
            # Strategy: keep the N most active speakers (by total speaking time) and remap
            # every other speaker to the kept speaker with the closest first-appearance time.
speaker_stats = {}
for turn in speaker_turns:
spk = turn["speaker"]
if spk not in speaker_stats:
speaker_stats[spk] = {"first_appear": turn["start"], "duration": 0, "count": 0}
speaker_stats[spk]["duration"] += turn["end"] - turn["start"]
speaker_stats[spk]["count"] += 1
# Sort speakers by total speaking duration (most active first)
sorted_speakers = sorted(speaker_stats.items(), key=lambda x: x[1]["duration"], reverse=True)
print(f"[DEBUG] Speaker activity: {[(s, round(stats['duration'], 1)) for s, stats in sorted_speakers]}")
# Keep the top N most active speakers, map others to them
kept_speakers = [s[0] for s in sorted_speakers[:expected_speakers]]
speaker_mapping = {}
for spk, stats in sorted_speakers:
if spk in kept_speakers:
speaker_mapping[spk] = spk
else:
# Map this speaker to the closest kept speaker by first appearance time
closest_kept = min(kept_speakers,
key=lambda k: abs(speaker_stats[k]["first_appear"] - stats["first_appear"]))
speaker_mapping[spk] = closest_kept
print(f"[DEBUG] Mapping {spk} -> {closest_kept}")
# Apply the mapping
for turn in speaker_turns:
turn["speaker"] = speaker_mapping[turn["speaker"]]
# Merge consecutive turns from the same speaker
merged_turns = []
for turn in speaker_turns:
if merged_turns and merged_turns[-1]["speaker"] == turn["speaker"] and \
turn["start"] - merged_turns[-1]["end"] < 1.0: # Less than 1 second gap
# Extend the previous turn
merged_turns[-1]["end"] = turn["end"]
else:
merged_turns.append(turn.copy())
speaker_turns = merged_turns
print(f"[DEBUG] After consolidation: {len(speaker_turns)} speaker turns")
# Assign speakers to each transcript segment
for segment in response["segments"]:
seg_start = segment["start"]
seg_end = segment["end"]
segment["speaker"] = None
# Find the speaker turn that has the most overlap with this segment
best_overlap = 0
best_speaker = None
for speaker_info in speaker_turns:
spk_start = speaker_info["start"]
spk_end = speaker_info["end"]
# Calculate overlap between segment and speaker turn
overlap_start = max(seg_start, spk_start)
overlap_end = min(seg_end, spk_end)
overlap_duration = max(0, overlap_end - overlap_start)
if overlap_duration > best_overlap:
best_overlap = overlap_duration
best_speaker = speaker_info["speaker"]
segment["speaker"] = best_speaker
response["speakers"] = speaker_turns
# Generate a conversation-style transcript with speaker labels
print(f"[DEBUG] Generating conversation from {len(response['segments'])} segments")
conversation_lines = []
current_speaker = None
current_text = []
for idx, segment in enumerate(response["segments"]):
            speaker = segment.get("speaker") or "UNKNOWN"
text = segment.get("text", "").strip()
if idx < 3: # Debug first few segments
print(f"[DEBUG] Segment {idx}: speaker={speaker}, text='{text[:50]}...'")
if speaker != current_speaker:
# Speaker changed - save previous speaker's text
if current_speaker and current_text:
line = f"{current_speaker}: {' '.join(current_text)}"
conversation_lines.append(line)
if len(conversation_lines) <= 3:
print(f"[DEBUG] Added line: {line[:100]}...")
# Start new speaker's text
current_speaker = speaker
current_text = [text] if text else []
else:
# Same speaker - accumulate text
if text:
current_text.append(text)
# Add the last speaker's text
if current_speaker and current_text:
line = f"{current_speaker}: {' '.join(current_text)}"
conversation_lines.append(line)
print(f"[DEBUG] Added final line: {line[:100]}...")
print(f"[DEBUG] Generated {len(conversation_lines)} conversation lines")
response["conversation"] = "\n\n".join(conversation_lines)
# Use conversation format for display when diarization is enabled
display_text = response["conversation"] if conversation_lines else transcript_text
else:
display_text = transcript_text
# Gradio expects two outputs (transcript textbox, json). Return both.
return display_text, response
# ---------------------------------------------------------------------------
# Gradio UI definition
# ---------------------------------------------------------------------------
def build_interface() -> gr.Blocks:
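    """Assemble the Gradio Blocks UI and wire it to the transcribe() handler."""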
with gr.Blocks(title="VTT with Diarization (faster-whisper + NVIDIA NeMo)") as demo:
gr.Markdown(
"""
# Voice-to-Text with Optional Diarization
Powered by **faster-whisper** and **NVIDIA NeMo Sortformer** (runs locally on this Space).
"""
)
gr.Markdown(
f"Running on device: `{WHISPER_DEVICE}` with compute type: `{WHISPER_COMPUTE_TYPE}`"
)
with gr.Row():
audio_input = gr.Audio(type="filepath", label="Upload audio (mp3, wav, m4a, ...)")
options = gr.Column()
with options:
language_input = gr.Dropdown(
label="Language",
choices=["", "ar", "en", "fr", "de", "es", "ru", "zh"],
value=default_language,
info="Leave blank for auto-detect.",
)
diarization_toggle = gr.Checkbox(
label="Enable Speaker Diarization",
value=False,
info="Uses NVIDIA Sortformer model (max 4 speakers, downloads ~700MB on first use).",
)
expected_speakers_slider = gr.Slider(
label="Expected Number of Speakers",
minimum=0,
maximum=4,
step=1,
value=expected_speakers_default,
info="Set to 0 for automatic detection, or specify 2-4 to consolidate speakers.",
)
beam_slider = gr.Slider(
label="Beam Size",
minimum=1,
maximum=10,
step=1,
value=beam_size_default,
)
best_of_slider = gr.Slider(
label="Best Of",
minimum=1,
maximum=10,
step=1,
value=best_of_default,
)
transcript_output = gr.Textbox(label="Transcript", lines=8)
json_output = gr.JSON(label="Detailed Segments & Speakers")
run_button = gr.Button("Transcribe")
run_button.click(
fn=transcribe,
inputs=[
audio_input,
language_input,
diarization_toggle,
expected_speakers_slider,
beam_slider,
best_of_slider,
],
outputs=[transcript_output, json_output],
api_name="predict",
)
gr.Markdown(
f"""
## Tips
- **Whisper model**: `{WHISPER_MODEL_SIZE}` (first run downloads model automatically)
- **Diarization model**: NVIDIA `{DIARIZATION_MODEL_NAME}` (streaming, max 4 speakers)
- Diarization downloads ~700MB on first use (cached afterward)
            - Set `WHISPER_MODEL_SIZE` in Space Variables (e.g. `medium` for a smaller, faster model; `large-v3` for highest accuracy)
- Optimized for Arabic customer service calls with specialized initial prompt
- Streaming configuration: High latency preset (10s latency, better accuracy)
"""
)
# Use a queue to serialize work on GPU and avoid OOM on Spaces free/shared GPUs
    # In Gradio 4.x the concurrency_count parameter was removed; control concurrency with
    # queue(default_concurrency_limit=...) or a per-event concurrency_limit instead.
demo.queue(max_size=16)
return demo
demo = build_interface()
if __name__ == "__main__":
demo.launch()
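# ---------------------------------------------------------------------------
# Example client usage (sketch)
# ---------------------------------------------------------------------------
# A minimal sketch of calling the /predict endpoint registered above via
# api_name="predict". The Space id is a placeholder; gradio_client must be
# installed separately (pip install gradio_client).
#
#   from gradio_client import Client, handle_file
#
#   client = Client("<username>/<space-name>")   # placeholder Space id
#   transcript, details = client.predict(
#       handle_file("call.mp3"),   # audio file
#       "ar",                      # language ("" = auto-detect)
#       True,                      # enable speaker diarization
#       2,                         # expected number of speakers
#       5,                         # beam size
#       5,                         # best of
#       api_name="/predict",
#   )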