from __future__ import annotations

import os
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import torch
from dotenv import load_dotenv
from faster_whisper import WhisperModel
from huggingface_hub import snapshot_download

# Reduce NeMo import verbosity on environments lacking certain torch.distributed symbols
os.environ.setdefault("NEMO_LOG_LEVEL", "ERROR")

# Import NeMo for NVIDIA Sortformer diarization
try:
    from nemo.collections.asr.models import SortformerEncLabelModel

    NEMO_AVAILABLE = True
except ImportError:
    NEMO_AVAILABLE = False
    SortformerEncLabelModel = None

load_dotenv()

# ---------------------------------------------------------------------------
# Configuration via environment variables (override inside HF Space settings)
# ---------------------------------------------------------------------------

# Whisper model: use the same model names as the Django app (tiny, base, small, medium, large-v3).
# faster-whisper downloads these automatically from Hugging Face on first run.
WHISPER_MODEL_SIZE = os.environ.get("WHISPER_MODEL_SIZE", "large-v3")


# Prefer GPU on Hugging Face Spaces if available, but allow override via env
def _default_device() -> str:
    try:
        return "cuda" if torch.cuda.is_available() else "cpu"
    except Exception:
        return "cpu"


WHISPER_DEVICE = os.environ.get("WHISPER_DEVICE") or _default_device()

# Choose a sensible default compute type based on device (can be overridden by env):
# - GPU: float16 is fastest and fits a T4 for small/medium; use int8_float16 to save VRAM for large-v3
# - CPU: int8_float32 works well
WHISPER_COMPUTE_TYPE = os.environ.get("WHISPER_COMPUTE_TYPE") or (
    "float16" if WHISPER_DEVICE == "cuda" else "int8_float32"
)

# Diarization: NVIDIA NeMo Sortformer model
DIARIZATION_MODEL_NAME = os.environ.get(
    "DIARIZATION_MODEL_NAME", "nvidia/diar_streaming_sortformer_4spk-v2"
)

# Diarization streaming configuration (using the "high latency" preset for better accuracy).
# See https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 for other configs.
CHUNK_SIZE = int(os.environ.get("DIAR_CHUNK_SIZE", 124))  # high latency: 124 frames
RIGHT_CONTEXT = int(os.environ.get("DIAR_RIGHT_CONTEXT", 1))
FIFO_SIZE = int(os.environ.get("DIAR_FIFO_SIZE", 124))
UPDATE_PERIOD = int(os.environ.get("DIAR_UPDATE_PERIOD", 124))
SPEAKER_CACHE_SIZE = int(os.environ.get("DIAR_CACHE_SIZE", 188))

HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Preload prompts/parameters. The default INITIAL_PROMPT is Arabic and roughly reads:
# "A customer-service call in the Saudi dialect. Common words: 'at your service',
# 'thank you' (lit. may God give you strength), order number, shipment, payment,
# invoice, delivery."
default_language = os.environ.get("DEFAULT_LANGUAGE", "ar")
initial_prompt = os.environ.get(
    "INITIAL_PROMPT",
    "مكالمة خدمة عملاء باللهجة السعودية. "
    "كلمات شائعة: أبشر، يعطيك العافية، رقم الطلب، الشحنة، الدفع، الفاتورة، التسليم.",
)
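
# Example overrides (hypothetical values) set via the Space's Variables panel or a
# local .env file picked up by load_dotenv(). The variable names match the
# os.environ.get() calls above; the values shown are only illustrative:
#
#   WHISPER_MODEL_SIZE=medium
#   WHISPER_DEVICE=cuda
#   WHISPER_COMPUTE_TYPE=int8_float16
#   DEFAULT_LANGUAGE=ar
#   EXPECTED_SPEAKERS=2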
beam_size_default = int(os.environ.get("WHISPER_BEAM_SIZE", 5))
best_of_default = int(os.environ.get("WHISPER_BEST_OF", 5))
expected_speakers_default = int(os.environ.get("EXPECTED_SPEAKERS", 2))

# ---------------------------------------------------------------------------
# Lazy singletons for the heavy models
# ---------------------------------------------------------------------------
_whisper_model: Optional[WhisperModel] = None
_diarization_model: Optional[Any] = None


def _ensure_snapshot(
    repo_id: str, local_dir: str, allow_patterns: Optional[List[str]] = None
) -> str:
    """Download a model snapshot locally if it is missing."""
    if os.path.isdir(local_dir) and any(os.scandir(local_dir)):
        return local_dir
    os.makedirs(local_dir, exist_ok=True)
    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        token=HF_TOKEN,
        allow_patterns=allow_patterns,
    )
    return local_dir
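
# Example use of _ensure_snapshot (hypothetical): pre-fetch the diarization
# weights into a fixed cache directory at build time so a cold Space does not
# pay the download cost at request time. The path and pattern are illustrative:
#
#   _ensure_snapshot(
#       repo_id=DIARIZATION_MODEL_NAME,
#       local_dir="/data/models/sortformer",
#       allow_patterns=["*.nemo"],
#   )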

def _load_whisper_model() -> WhisperModel:
    """Load the faster-whisper model lazily (singleton) - same approach as the Django app."""
    global _whisper_model
    if _whisper_model is None:
        print(
            f"Loading Faster Whisper model: {WHISPER_MODEL_SIZE} on {WHISPER_DEVICE} "
            f"with compute_type={WHISPER_COMPUTE_TYPE}"
        )
        # Load the model by name - faster-whisper downloads it automatically from Hugging Face
        _whisper_model = WhisperModel(
            WHISPER_MODEL_SIZE,  # Model name: tiny, base, small, medium, large-v3
            device=WHISPER_DEVICE,
            compute_type=WHISPER_COMPUTE_TYPE,
        )
    return _whisper_model


def _load_diarization_model() -> Any:
    """Load the NVIDIA NeMo Sortformer diarization model lazily (singleton)."""
    global _diarization_model
    if _diarization_model is None:
        if not NEMO_AVAILABLE:
            raise gr.Error(
                "NeMo is not installed. Please install it with: pip install nemo_toolkit[asr]"
            )
        print(f"Loading NVIDIA Sortformer diarization model: {DIARIZATION_MODEL_NAME}...")
        try:
            # Load the model directly from Hugging Face
            _diarization_model = SortformerEncLabelModel.from_pretrained(DIARIZATION_MODEL_NAME)
            # Switch to evaluation mode
            _diarization_model.eval()
            # Move to GPU if one is available on the Space
            if torch.cuda.is_available():
                try:
                    _diarization_model.to("cuda")
                    print("[DEBUG] Moved Sortformer model to CUDA device")
                except Exception:
                    # Fallback for modules exposing .cuda()
                    try:
                        _diarization_model.cuda()
                        print("[DEBUG] Moved Sortformer model to CUDA via .cuda()")
                    except Exception as _e:
                        print(f"[WARN] Could not move Sortformer model to GPU: {_e}")
            # Configure streaming parameters (high latency preset for better accuracy).
            # See: https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2#setting-up-streaming-configuration
            _diarization_model.sortformer_modules.chunk_len = CHUNK_SIZE
            _diarization_model.sortformer_modules.chunk_right_context = RIGHT_CONTEXT
            _diarization_model.sortformer_modules.fifo_len = FIFO_SIZE
            _diarization_model.sortformer_modules.spkcache_update_period = UPDATE_PERIOD
            _diarization_model.sortformer_modules.spkcache_len = SPEAKER_CACHE_SIZE
            _diarization_model.sortformer_modules._check_streaming_parameters()
            print("Sortformer model loaded successfully!")
        except Exception as e:
            raise gr.Error(
                f"Failed to load NVIDIA Sortformer diarization model.\n\n"
                f"Error details: {e}\n\n"
                "Solutions:\n"
                "  - Make sure NeMo is properly installed: pip install nemo_toolkit[asr]\n"
                "  - Check that you have internet access to download from Hugging Face.\n"
                "  - The model will be downloaded automatically on first use (~700MB)."
            )
    if _diarization_model is None:
        raise gr.Error("Diarization model returned None after loading.")
    return _diarization_model
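
# Optional warm-up (hypothetical): calling the lazy loaders once at import time
# would move the download/initialization cost out of the first user request.
# Left disabled here because Space images are often built on hardware without a GPU:
#
#   _load_whisper_model()
#   if NEMO_AVAILABLE:
#       _load_diarization_model()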

# ---------------------------------------------------------------------------
# Core inference function
# ---------------------------------------------------------------------------

def transcribe(
    audio_path: str,
    language: str,
    enable_diarization: bool,
    expected_speakers: int,
    beam_size: int,
    best_of: int,
) -> Tuple[str, Dict]:
    if audio_path is None:
        raise gr.Error("Please upload an audio file.")

    model = _load_whisper_model()

    # Decide whether to pass the initial prompt to the model.
    # Users reported that the large model sometimes echoes the prompt in its output,
    # so only pass initial_prompt for non-large models by default.
    use_initial_prompt = True
    try:
        # Treat any model name that starts with 'large' as a large model
        if isinstance(WHISPER_MODEL_SIZE, str) and WHISPER_MODEL_SIZE.startswith("large"):
            use_initial_prompt = False
    except Exception:
        use_initial_prompt = True
    initial_prompt_param = initial_prompt if use_initial_prompt else None

    # Transcription parameters matching the Django app configuration
    try:
        segments, info = model.transcribe(
            audio_path,
            language=language if language else None,
            beam_size=beam_size,
            best_of=best_of,
            temperature=[0.0, 0.2, 0.4, 0.6],  # Matching the Django app
            vad_filter=True,
            vad_parameters=dict(
                min_silence_duration_ms=300,  # Split sooner on short pauses
                speech_pad_ms=120,
            ),
            condition_on_previous_text=False,  # KEY: stops cross-segment repetition
            initial_prompt=initial_prompt_param,
            compression_ratio_threshold=2.4,
            log_prob_threshold=-1.0,
            no_speech_threshold=0.45,
            word_timestamps=True,
        )
    except Exception as e:
        msg = str(e)
        # Provide a clearer message for common decode failures (e.g., HTML instead of audio)
        if "Invalid data found when processing input" in msg or "av.error" in msg.lower():
            raise gr.Error(
                "Could not decode the audio input.\n\n"
                "Likely causes:\n"
                "  - The provided URL is not a direct audio download (e.g., a Google Drive 'view' page).\n"
                "  - The file is corrupted or zero-length.\n\n"
                "Fixes:\n"
                "  - Use a direct download URL (for Drive: https://drive.google.com/uc?export=download&id=).\n"
                "  - Or first POST /upload, then pass the returned file pointer to the API."
            )
        raise gr.Error(f"Transcription failed: {msg}")

    # Convert the generator to a list so we can iterate multiple times
    segments_list = list(segments)
    transcript_text = "".join(segment.text for segment in segments_list).strip()

    # Defensive: if the model echoed the initial prompt at the start of the
    # transcript, remove it. This handles cases where a prompt sneaks into
    # the returned text (observed for some model sizes).
    if initial_prompt and transcript_text.startswith(initial_prompt):
        # Remove the prompt and any leading punctuation/whitespace after it
        transcript_text = transcript_text[len(initial_prompt):].lstrip("\n\r\t .،,؛:-")

    segments_payload: List[Dict] = []
    for segment in segments_list:
        segments_payload.append(
            {
                "start": segment.start,
                "end": segment.end,
                "text": segment.text,
                "words": [
                    {
                        "start": word.start,
                        "end": word.end,
                        "word": word.word,
                        "probability": word.probability,
                    }
                    for word in (segment.words or [])
                ],
            }
        )
    print(f"[DEBUG] Created {len(segments_payload)} transcript segments")

    response: Dict[str, object] = {
        "text": transcript_text,
        "language": info.language,
        "language_probability": getattr(info, "language_probability", None),
        "duration": info.duration,
        "segments": segments_payload,
    }
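
    # Illustrative shape of the payload at this point (values are made up, keys
    # match the dict built above):
    #   {
    #       "text": "...", "language": "ar", "language_probability": 0.97,
    #       "duration": 42.3,
    #       "segments": [{"start": 0.0, "end": 3.1, "text": "...", "words": [...]}],
    #   }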

    if enable_diarization:
        diar_model = _load_diarization_model()
        # Run diarization using NeMo Sortformer. For a single file this returns a
        # list with one result whose segments are "start end speaker" entries.
        try:
            predicted_segments = diar_model.diarize(audio=audio_path, batch_size=1)
            print(f"[DEBUG] Diarization raw output type: {type(predicted_segments)}")
            print(f"[DEBUG] Diarization raw output: {predicted_segments}")
        except Exception as e:
            raise gr.Error(f"Diarization failed: {e}")

        # Parse the NeMo output format
        speaker_turns: List[Dict] = []
        if predicted_segments and len(predicted_segments) > 0:
            # Get the first (and only) result for single-file input
            diar_result = predicted_segments[0]
            print(f"[DEBUG] Diar result type: {type(diar_result)}, content: {diar_result}")

            # NeMo diarize() returns a list where each element is a string formatted
            # as "start_time end_time speaker_label" (space-separated)
            if isinstance(diar_result, str):
                # Single string with newline-separated segments
                for line in diar_result.strip().split("\n"):
                    parts = line.strip().split()
                    if len(parts) >= 3:
                        speaker_turns.append(
                            {"start": float(parts[0]), "end": float(parts[1]), "speaker": parts[2]}
                        )
            elif hasattr(diar_result, "__iter__") and not isinstance(diar_result, (str, bytes)):
                # List/array of segments
                for item in diar_result:
                    # Each item may be a string "start end speaker" or an array-like [start, end, speaker]
                    if isinstance(item, str):
                        parts = item.strip().split()
                        if len(parts) >= 3:
                            speaker_turns.append(
                                {"start": float(parts[0]), "end": float(parts[1]), "speaker": parts[2]}
                            )
                    elif hasattr(item, "__getitem__"):
                        # Array-like: [start, end, speaker]
                        if hasattr(item, "tolist"):
                            item = item.tolist()
                        if len(item) >= 3:
                            start_time = float(item[0])
                            end_time = float(item[1])
                            # The speaker could be numeric or a string
                            speaker_label = item[2] if isinstance(item[2], str) else str(item[2])
                            # Ensure a consistent speaker format
                            if not speaker_label.startswith("SPEAKER_"):
                                speaker_label = f"SPEAKER_{speaker_label}"
                            speaker_turns.append(
                                {"start": start_time, "end": end_time, "speaker": speaker_label}
                            )

        print(f"[DEBUG] Parsed {len(speaker_turns)} speaker turns")
        if speaker_turns:
            print(f"[DEBUG] First few turns: {speaker_turns[:5]}")

        # Sort speaker turns by start time
        speaker_turns.sort(key=lambda x: x["start"])

        # Consolidate speakers if we detected more than expected
        unique_speakers = set(turn["speaker"] for turn in speaker_turns)
        print(f"[DEBUG] Detected {len(unique_speakers)} unique speakers: {unique_speakers}")

        if expected_speakers > 0 and len(unique_speakers) > expected_speakers:
            print(f"[DEBUG] Consolidating from {len(unique_speakers)} to {expected_speakers} speakers")
            # Build a mapping to merge speakers.
            # Strategy: keep the most active speakers and fold the rest into them
            # by proximity of first appearance.
            speaker_stats: Dict[str, Dict] = {}
            for turn in speaker_turns:
                spk = turn["speaker"]
                if spk not in speaker_stats:
                    speaker_stats[spk] = {"first_appear": turn["start"], "duration": 0.0, "count": 0}
                speaker_stats[spk]["duration"] += turn["end"] - turn["start"]
                speaker_stats[spk]["count"] += 1

            # Sort speakers by total speaking duration (most active first)
            sorted_speakers = sorted(speaker_stats.items(), key=lambda x: x[1]["duration"], reverse=True)
            print(f"[DEBUG] Speaker activity: {[(s, round(stats['duration'], 1)) for s, stats in sorted_speakers]}")

            # Keep the top N most active speakers and map the others onto them
            kept_speakers = [s[0] for s in sorted_speakers[:expected_speakers]]
            speaker_mapping = {}
            for spk, stats in sorted_speakers:
                if spk in kept_speakers:
                    speaker_mapping[spk] = spk
                else:
                    # Map this speaker to the closest kept speaker by first-appearance time
                    closest_kept = min(
                        kept_speakers,
                        key=lambda k: abs(speaker_stats[k]["first_appear"] - stats["first_appear"]),
                    )
                    speaker_mapping[spk] = closest_kept
                    print(f"[DEBUG] Mapping {spk} -> {closest_kept}")

            # Apply the mapping
            for turn in speaker_turns:
                turn["speaker"] = speaker_mapping[turn["speaker"]]

            # Merge consecutive turns from the same speaker
            merged_turns: List[Dict] = []
            for turn in speaker_turns:
                if (
                    merged_turns
                    and merged_turns[-1]["speaker"] == turn["speaker"]
                    and turn["start"] - merged_turns[-1]["end"] < 1.0  # less than a 1-second gap
                ):
                    # Extend the previous turn
                    merged_turns[-1]["end"] = turn["end"]
                else:
                    merged_turns.append(turn.copy())
            speaker_turns = merged_turns
            print(f"[DEBUG] After consolidation: {len(speaker_turns)} speaker turns")

        # Assign a speaker to each transcript segment: pick the speaker turn
        # with the largest time overlap with the segment.
        for segment in response["segments"]:
            seg_start = segment["start"]
            seg_end = segment["end"]
            best_overlap = 0.0
            best_speaker = None
            for speaker_info in speaker_turns:
                # Overlap between [seg_start, seg_end] and this speaker turn
                overlap_start = max(seg_start, speaker_info["start"])
                overlap_end = min(seg_end, speaker_info["end"])
                overlap_duration = max(0.0, overlap_end - overlap_start)
                if overlap_duration > best_overlap:
                    best_overlap = overlap_duration
                    best_speaker = speaker_info["speaker"]
            segment["speaker"] = best_speaker

        response["speakers"] = speaker_turns
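
        # Example of the consolidated turn list at this point (illustrative values;
        # the exact label format depends on what NeMo returned above):
        #   [{"start": 0.32, "end": 5.10, "speaker": "speaker_0"},
        #    {"start": 5.45, "end": 9.80, "speaker": "speaker_1"}, ...]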
        # Generate a conversation-style transcript with speaker labels
        print(f"[DEBUG] Generating conversation from {len(response['segments'])} segments")
        conversation_lines = []
        current_speaker = None
        current_text: List[str] = []
        for idx, segment in enumerate(response["segments"]):
            speaker = segment.get("speaker", "UNKNOWN")
            text = segment.get("text", "").strip()
            if idx < 3:  # Debug the first few segments
                print(f"[DEBUG] Segment {idx}: speaker={speaker}, text='{text[:50]}...'")
            if speaker != current_speaker:
                # Speaker changed - save the previous speaker's text
                if current_speaker and current_text:
                    line = f"{current_speaker}: {' '.join(current_text)}"
                    conversation_lines.append(line)
                    if len(conversation_lines) <= 3:
                        print(f"[DEBUG] Added line: {line[:100]}...")
                # Start the new speaker's text
                current_speaker = speaker
                current_text = [text] if text else []
            else:
                # Same speaker - accumulate text
                if text:
                    current_text.append(text)
        # Add the last speaker's text
        if current_speaker and current_text:
            line = f"{current_speaker}: {' '.join(current_text)}"
            conversation_lines.append(line)
            print(f"[DEBUG] Added final line: {line[:100]}...")

        print(f"[DEBUG] Generated {len(conversation_lines)} conversation lines")
        response["conversation"] = "\n\n".join(conversation_lines)
        # Use the conversation format for display when diarization is enabled
        display_text = response["conversation"] if conversation_lines else transcript_text
    else:
        display_text = transcript_text

    # Gradio expects two outputs (transcript textbox, JSON); return both.
    return display_text, response


# ---------------------------------------------------------------------------
# Gradio UI definition
# ---------------------------------------------------------------------------

def build_interface() -> gr.Blocks:
    with gr.Blocks(title="VTT with Diarization (faster-whisper + NVIDIA NeMo)") as demo:
        gr.Markdown(
            """
# Voice-to-Text with Optional Diarization

Powered by **faster-whisper** and **NVIDIA NeMo Sortformer** (runs locally on this Space).
"""
        )
        gr.Markdown(
            f"Running on device: `{WHISPER_DEVICE}` with compute type: `{WHISPER_COMPUTE_TYPE}`"
        )
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Upload audio (mp3, wav, m4a, ...)")
            with gr.Column():
                language_input = gr.Dropdown(
                    label="Language",
                    choices=["", "ar", "en", "fr", "de", "es", "ru", "zh"],
                    value=default_language,
                    info="Leave blank for auto-detect.",
                )
                diarization_toggle = gr.Checkbox(
                    label="Enable Speaker Diarization",
                    value=False,
                    info="Uses the NVIDIA Sortformer model (max 4 speakers, downloads ~700MB on first use).",
                )
                expected_speakers_slider = gr.Slider(
                    label="Expected Number of Speakers",
                    minimum=0,
                    maximum=4,
                    step=1,
                    value=expected_speakers_default,
                    info="Set to 0 for automatic detection, or specify 2-4 to consolidate speakers.",
                )
                beam_slider = gr.Slider(
                    label="Beam Size",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=beam_size_default,
                )
                best_of_slider = gr.Slider(
                    label="Best Of",
                    minimum=1,
                    maximum=10,
                    step=1,
                    value=best_of_default,
                )
        transcript_output = gr.Textbox(label="Transcript", lines=8)
        json_output = gr.JSON(label="Detailed Segments & Speakers")
        run_button = gr.Button("Transcribe")
        run_button.click(
            fn=transcribe,
            inputs=[
                audio_input,
                language_input,
                diarization_toggle,
                expected_speakers_slider,
                beam_slider,
                best_of_slider,
            ],
            outputs=[transcript_output, json_output],
            api_name="predict",
        )
        gr.Markdown(
            f"""
## Tips
- **Whisper model**: `{WHISPER_MODEL_SIZE}` (the first run downloads the model automatically)
- **Diarization model**: NVIDIA `{DIARIZATION_MODEL_NAME}` (streaming, max 4 speakers)
- Diarization downloads ~700MB on first use (cached afterwards)
- Change `WHISPER_MODEL_SIZE` in the Space Variables to `medium` or `large-v3` for higher accuracy
- Optimized for Arabic customer-service calls via a specialized initial prompt
- Streaming configuration: high-latency preset (10s latency, better accuracy)
"""
        )
    # Use a queue to serialize work on the GPU and avoid OOM on free/shared Space GPUs.
    # In Gradio 4.x the concurrency_count parameter was removed; control concurrency
    # via Spaces hardware settings or per-route concurrency limits.
    demo.queue(max_size=16)
    return demo


demo = build_interface()

if __name__ == "__main__":
    demo.launch()
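
# Example client call (hypothetical host; the endpoint name follows from
# api_name="predict" on the click handler, and positional args follow the
# `inputs` order: audio, language, diarization flag, expected speakers,
# beam size, best of):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("https://<your-space>.hf.space")
#   text, payload = client.predict(
#       handle_file("call.wav"), "ar", True, 2, 5, 5, api_name="/predict"
#   )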