from __future__ import annotations

import os
import tempfile
from typing import Dict, List, Optional, Any

import gradio as gr
import torch
from dotenv import load_dotenv
from faster_whisper import WhisperModel
from huggingface_hub import snapshot_download

os.environ.setdefault("NEMO_LOG_LEVEL", "ERROR")

try:
    from nemo.collections.asr.models import SortformerEncLabelModel
    NEMO_AVAILABLE = True
except ImportError:
    NEMO_AVAILABLE = False
    SortformerEncLabelModel = None

load_dotenv()

WHISPER_MODEL_SIZE = os.environ.get("WHISPER_MODEL_SIZE", "large-v3")


def _default_device() -> str:
    try:
        return "cuda" if torch.cuda.is_available() else "cpu"
    except Exception:
        return "cpu"


WHISPER_DEVICE = os.environ.get("WHISPER_DEVICE") or _default_device()

WHISPER_COMPUTE_TYPE = os.environ.get("WHISPER_COMPUTE_TYPE") or (
    "float16" if WHISPER_DEVICE == "cuda" else "int8_float32"
)

DIARIZATION_MODEL_NAME = os.environ.get(
    "DIARIZATION_MODEL_NAME", "nvidia/diar_streaming_sortformer_4spk-v2"
)
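# Streaming Sortformer knobs, expressed in encoder frames. The defaults mirror the
# "high latency" preset referenced in the Tips section of the UI; assuming each
# encoder frame covers roughly 80 ms (an assumption based on the model card), a
# 124-frame chunk corresponds to about 10 s of audio.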
CHUNK_SIZE = int(os.environ.get("DIAR_CHUNK_SIZE", 124))
RIGHT_CONTEXT = int(os.environ.get("DIAR_RIGHT_CONTEXT", 1))
FIFO_SIZE = int(os.environ.get("DIAR_FIFO_SIZE", 124))
UPDATE_PERIOD = int(os.environ.get("DIAR_UPDATE_PERIOD", 124))
SPEAKER_CACHE_SIZE = int(os.environ.get("DIAR_CACHE_SIZE", 188))

HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN")

default_language = os.environ.get("DEFAULT_LANGUAGE", "ar")
initial_prompt = os.environ.get(
    "INITIAL_PROMPT",
    "مكالمة خدمة عملاء باللهجة السعودية. كلمات شائعة: أبشر، يعطيك العافية، رقم الطلب، الشحنة، الدفع، الفاتورة، التسليم.",
)

beam_size_default = int(os.environ.get("WHISPER_BEAM_SIZE", 5))
best_of_default = int(os.environ.get("WHISPER_BEST_OF", 5))
expected_speakers_default = int(os.environ.get("EXPECTED_SPEAKERS", 2))

_whisper_model: Optional[WhisperModel] = None
_diarization_model: Optional[Any] = None
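# Note: _ensure_snapshot is not called anywhere else in this module; it is kept as a
# helper for pre-downloading a Hugging Face snapshot (e.g. for cached or offline setups).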
def _ensure_snapshot(repo_id: str, local_dir: str, allow_patterns: Optional[List[str]] = None) -> str:
    """Download model snapshot locally if missing."""
    if os.path.isdir(local_dir) and any(os.scandir(local_dir)):
        return local_dir

    os.makedirs(local_dir, exist_ok=True)
    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        token=HF_TOKEN,
        allow_patterns=allow_patterns,
    )
    return local_dir


def _load_whisper_model() -> WhisperModel:
    """Load Faster Whisper model lazily (singleton) - same approach as Django app"""
    global _whisper_model
    if _whisper_model is None:
        print(f"Loading Faster Whisper model: {WHISPER_MODEL_SIZE} on {WHISPER_DEVICE} with compute_type={WHISPER_COMPUTE_TYPE}")
        _whisper_model = WhisperModel(
            WHISPER_MODEL_SIZE,
            device=WHISPER_DEVICE,
            compute_type=WHISPER_COMPUTE_TYPE,
        )
    return _whisper_model


def _load_diarization_model() -> Optional[Any]:
    """Load NVIDIA NeMo Sortformer diarization model lazily (singleton)"""
    global _diarization_model
    if _diarization_model is None:
        if not NEMO_AVAILABLE:
            raise gr.Error(
                "NeMo is not installed. Please install it with: pip install nemo_toolkit[asr]"
            )

        print(f"Loading NVIDIA Sortformer diarization model: {DIARIZATION_MODEL_NAME}...")

        try:
            _diarization_model = SortformerEncLabelModel.from_pretrained(
                DIARIZATION_MODEL_NAME
            )
            _diarization_model.eval()

            if torch.cuda.is_available():
                try:
                    _diarization_model.to("cuda")
                    print("[DEBUG] Moved Sortformer model to CUDA device")
                except Exception:
                    try:
                        _diarization_model.cuda()
                        print("[DEBUG] Moved Sortformer model to CUDA via .cuda()")
                    except Exception as _e:
                        print(f"[WARN] Could not move Sortformer model to GPU: {_e}")
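            # Apply the streaming configuration from the module-level constants; these
            # attribute names follow the Sortformer streaming modules shipped with this
            # model and are validated by _check_streaming_parameters() below.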
            _diarization_model.sortformer_modules.chunk_len = CHUNK_SIZE
            _diarization_model.sortformer_modules.chunk_right_context = RIGHT_CONTEXT
            _diarization_model.sortformer_modules.fifo_len = FIFO_SIZE
            _diarization_model.sortformer_modules.spkcache_update_period = UPDATE_PERIOD
            _diarization_model.sortformer_modules.spkcache_len = SPEAKER_CACHE_SIZE
            _diarization_model.sortformer_modules._check_streaming_parameters()

            print("Sortformer model loaded successfully!")

        except Exception as e:
            raise gr.Error(
                f"Failed to load NVIDIA Sortformer diarization model.\n\n"
                f"Error details: {e}\n\n"
                "Solutions:\n"
                " - Make sure NeMo is properly installed: pip install nemo_toolkit[asr]\n"
                " - Check that you have internet access to download from Hugging Face.\n"
                " - The model will be downloaded automatically on first use (~700MB)."
            )

    if _diarization_model is None:
        raise gr.Error("Diarization model returned None after loading.")

    return _diarization_model


def transcribe(
    audio_path: str,
    language: str,
    enable_diarization: bool,
    expected_speakers: int,
    beam_size: int,
    best_of: int,
) -> tuple[str, Dict]:
    """Transcribe an audio file with faster-whisper and optionally attach speaker labels.

    Returns a (display_text, response) pair: the text shown in the transcript box and the
    detailed JSON payload with segments, words, and (if enabled) speaker turns.
    """
    if audio_path is None:
        raise gr.Error("Please upload an audio file.")

    model = _load_whisper_model()
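    # The initial prompt is skipped for "large*" checkpoints, presumably because they tend
    # to echo the prompt into the output; any echoed prompt text is also stripped from the
    # transcript further below as a safeguard.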
    use_initial_prompt = True
    try:
        if isinstance(WHISPER_MODEL_SIZE, str) and WHISPER_MODEL_SIZE.startswith("large"):
            use_initial_prompt = False
    except Exception:
        use_initial_prompt = True

    initial_prompt_param = initial_prompt if use_initial_prompt else None
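    # Decoding settings: beam search with a temperature fallback schedule, the built-in VAD
    # filter to drop long silences, and word-level timestamps for the detailed JSON payload.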
    try:
        segments, info = model.transcribe(
            audio_path,
            language=language if language else None,
            beam_size=beam_size,
            best_of=best_of,
            temperature=[0.0, 0.2, 0.4, 0.6],
            vad_filter=True,
            vad_parameters=dict(
                min_silence_duration_ms=300,
                speech_pad_ms=120,
            ),
            condition_on_previous_text=False,
            initial_prompt=initial_prompt_param,
            compression_ratio_threshold=2.4,
            log_prob_threshold=-1.0,
            no_speech_threshold=0.45,
            word_timestamps=True,
        )
    except Exception as e:
        msg = str(e)
        if "Invalid data found when processing input" in msg or "av.error" in msg.lower():
            raise gr.Error(
                "Could not decode the audio input.\n\n"
                "Likely causes:\n"
                " - The provided URL is not a direct audio download (e.g., a Google Drive 'view' page).\n"
                " - The file is corrupted or zero-length.\n\n"
                "Fixes:\n"
                " - Use a direct download URL (for Drive: https://drive.google.com/uc?export=download&id=<FILE_ID>).\n"
                " - Or first POST /upload, then pass the returned file pointer to the API."
            )
        raise gr.Error(f"Transcription failed: {msg}")

    segments_list = list(segments)

    transcript_text = "".join(segment.text for segment in segments_list).strip()

    if initial_prompt and transcript_text.startswith(initial_prompt):
        transcript_text = transcript_text[len(initial_prompt) :].lstrip("\n\r\t .،,؛:-")

    segments_payload: List[Dict] = []
    for segment in segments_list:
        segments_payload.append(
            {
                "start": segment.start,
                "end": segment.end,
                "text": segment.text,
                "words": [
                    {
                        "start": word.start,
                        "end": word.end,
                        "word": word.word,
                        "probability": word.probability,
                    }
                    for word in (segment.words or [])
                ],
            }
        )

    print(f"[DEBUG] Created {len(segments_payload)} transcript segments")

    response: Dict[str, object] = {
        "text": transcript_text,
        "language": info.language,
        "language_probability": getattr(info, "language_probability", None),
        "duration": info.duration,
        "segments": segments_payload,
    }

    if enable_diarization:
        diar_model = _load_diarization_model()

        try:
            predicted_segments = diar_model.diarize(audio=audio_path, batch_size=1)
            print(f"[DEBUG] Diarization raw output type: {type(predicted_segments)}")
            print(f"[DEBUG] Diarization raw output: {predicted_segments}")
        except Exception as e:
            raise gr.Error(f"Diarization failed: {e}")

        speaker_turns: List[Dict[str, float]] = []
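        # diarize() returns one result per input file; depending on the NeMo version the
        # entry may be a newline-separated "start end speaker" string or an iterable of
        # per-turn items (strings or numeric arrays), so both shapes are parsed below.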
        if predicted_segments and len(predicted_segments) > 0:
            diar_result = predicted_segments[0]
            print(f"[DEBUG] Diar result type: {type(diar_result)}, content: {diar_result}")

            if isinstance(diar_result, str):
                lines = diar_result.strip().split('\n')
                for line in lines:
                    parts = line.strip().split()
                    if len(parts) >= 3:
                        start_time = float(parts[0])
                        end_time = float(parts[1])
                        speaker_label = parts[2]
                        speaker_turns.append({
                            "start": start_time,
                            "end": end_time,
                            "speaker": speaker_label,
                        })
            elif hasattr(diar_result, '__iter__') and not isinstance(diar_result, (str, bytes)):
                for item in diar_result:
                    if isinstance(item, str):
                        parts = item.strip().split()
                        if len(parts) >= 3:
                            start_time = float(parts[0])
                            end_time = float(parts[1])
                            speaker_label = parts[2]
                            speaker_turns.append({
                                "start": start_time,
                                "end": end_time,
                                "speaker": speaker_label,
                            })
                    elif hasattr(item, '__getitem__'):
                        if hasattr(item, 'tolist'):
                            item = item.tolist()
                        if len(item) >= 3:
                            start_time = float(item[0])
                            end_time = float(item[1])
                            speaker_label = str(item[2]) if not isinstance(item[2], str) else item[2]
                            if not speaker_label.startswith("SPEAKER_"):
                                speaker_label = f"SPEAKER_{speaker_label}"
                            speaker_turns.append({
                                "start": start_time,
                                "end": end_time,
                                "speaker": speaker_label,
                            })

        print(f"[DEBUG] Parsed {len(speaker_turns)} speaker turns")
        if speaker_turns:
            print(f"[DEBUG] First few turns: {speaker_turns[:5]}")

        speaker_turns.sort(key=lambda x: x["start"])

        unique_speakers = set(turn["speaker"] for turn in speaker_turns)
        print(f"[DEBUG] Detected {len(unique_speakers)} unique speakers: {unique_speakers}")
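        # If more speakers were detected than requested, keep the N most talkative ones,
        # remap each extra label to the kept speaker whose first appearance is closest in
        # time, then merge adjacent turns from the same speaker separated by less than 1 s.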
        if expected_speakers > 0 and len(unique_speakers) > expected_speakers:
            print(f"[DEBUG] Consolidating from {len(unique_speakers)} to {expected_speakers} speakers")

            speaker_stats = {}
            for turn in speaker_turns:
                spk = turn["speaker"]
                if spk not in speaker_stats:
                    speaker_stats[spk] = {"first_appear": turn["start"], "duration": 0, "count": 0}
                speaker_stats[spk]["duration"] += turn["end"] - turn["start"]
                speaker_stats[spk]["count"] += 1

            sorted_speakers = sorted(speaker_stats.items(), key=lambda x: x[1]["duration"], reverse=True)
            print(f"[DEBUG] Speaker activity: {[(s, round(stats['duration'], 1)) for s, stats in sorted_speakers]}")

            kept_speakers = [s[0] for s in sorted_speakers[:expected_speakers]]
            speaker_mapping = {}

            for spk, stats in sorted_speakers:
                if spk in kept_speakers:
                    speaker_mapping[spk] = spk
                else:
                    closest_kept = min(
                        kept_speakers,
                        key=lambda k: abs(speaker_stats[k]["first_appear"] - stats["first_appear"]),
                    )
                    speaker_mapping[spk] = closest_kept
                    print(f"[DEBUG] Mapping {spk} -> {closest_kept}")

            for turn in speaker_turns:
                turn["speaker"] = speaker_mapping[turn["speaker"]]

            merged_turns = []
            for turn in speaker_turns:
                if (
                    merged_turns
                    and merged_turns[-1]["speaker"] == turn["speaker"]
                    and turn["start"] - merged_turns[-1]["end"] < 1.0
                ):
                    merged_turns[-1]["end"] = turn["end"]
                else:
                    merged_turns.append(turn.copy())

            speaker_turns = merged_turns
            print(f"[DEBUG] After consolidation: {len(speaker_turns)} speaker turns")
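        # Label each transcript segment with the speaker turn it overlaps the most;
        # segments with no overlapping turn keep speaker = None.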
        for segment in response["segments"]:
            seg_start = segment["start"]
            seg_end = segment["end"]
            segment["speaker"] = None

            best_overlap = 0
            best_speaker = None

            for speaker_info in speaker_turns:
                spk_start = speaker_info["start"]
                spk_end = speaker_info["end"]

                overlap_start = max(seg_start, spk_start)
                overlap_end = min(seg_end, spk_end)
                overlap_duration = max(0, overlap_end - overlap_start)

                if overlap_duration > best_overlap:
                    best_overlap = overlap_duration
                    best_speaker = speaker_info["speaker"]

            segment["speaker"] = best_speaker

        response["speakers"] = speaker_turns
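        # Build a readable conversation view by concatenating consecutive segments that
        # share the same speaker into a single "SPEAKER_X: ..." line.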
        print(f"[DEBUG] Generating conversation from {len(response['segments'])} segments")
        conversation_lines = []
        current_speaker = None
        current_text = []

        for idx, segment in enumerate(response["segments"]):
            # The "speaker" key is always present but may be None, so fall back explicitly.
            speaker = segment.get("speaker") or "UNKNOWN"
            text = segment.get("text", "").strip()

            if idx < 3:
                print(f"[DEBUG] Segment {idx}: speaker={speaker}, text='{text[:50]}...'")

            if speaker != current_speaker:
                if current_speaker and current_text:
                    line = f"{current_speaker}: {' '.join(current_text)}"
                    conversation_lines.append(line)
                    if len(conversation_lines) <= 3:
                        print(f"[DEBUG] Added line: {line[:100]}...")

                current_speaker = speaker
                current_text = [text] if text else []
            else:
                if text:
                    current_text.append(text)

        if current_speaker and current_text:
            line = f"{current_speaker}: {' '.join(current_text)}"
            conversation_lines.append(line)
            print(f"[DEBUG] Added final line: {line[:100]}...")

        print(f"[DEBUG] Generated {len(conversation_lines)} conversation lines")
        response["conversation"] = "\n\n".join(conversation_lines)

        display_text = response["conversation"] if conversation_lines else transcript_text
    else:
        display_text = transcript_text

    return display_text, response


def build_interface() -> gr.Blocks:
    with gr.Blocks(title="VTT with Diarization (faster-whisper + NVIDIA NeMo)") as demo:
        gr.Markdown(
            """
            # Voice-to-Text with Optional Diarization
            Powered by **faster-whisper** and **NVIDIA NeMo Sortformer** (runs locally on this Space).
            """
        )

        gr.Markdown(
            f"Running on device: `{WHISPER_DEVICE}` with compute type: `{WHISPER_COMPUTE_TYPE}`"
        )

        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Upload audio (mp3, wav, m4a, ...)")
            options = gr.Column()

        with options:
            language_input = gr.Dropdown(
                label="Language",
                choices=["", "ar", "en", "fr", "de", "es", "ru", "zh"],
                value=default_language,
                info="Leave blank for auto-detect.",
            )
            diarization_toggle = gr.Checkbox(
                label="Enable Speaker Diarization",
                value=False,
                info="Uses NVIDIA Sortformer model (max 4 speakers, downloads ~700MB on first use).",
            )
            expected_speakers_slider = gr.Slider(
                label="Expected Number of Speakers",
                minimum=0,
                maximum=4,
                step=1,
                value=expected_speakers_default,
                info="Set to 0 for automatic detection, or specify 2-4 to consolidate speakers.",
            )
            beam_slider = gr.Slider(
                label="Beam Size",
                minimum=1,
                maximum=10,
                step=1,
                value=beam_size_default,
            )
            best_of_slider = gr.Slider(
                label="Best Of",
                minimum=1,
                maximum=10,
                step=1,
                value=best_of_default,
            )

        transcript_output = gr.Textbox(label="Transcript", lines=8)
        json_output = gr.JSON(label="Detailed Segments & Speakers")

        run_button = gr.Button("Transcribe")
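        # api_name="predict" exposes this handler as a named /predict endpoint, so the Space
        # can also be called programmatically (e.g. via gradio_client) with the same inputs.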
        run_button.click(
            fn=transcribe,
            inputs=[
                audio_input,
                language_input,
                diarization_toggle,
                expected_speakers_slider,
                beam_slider,
                best_of_slider,
            ],
            outputs=[transcript_output, json_output],
            api_name="predict",
        )

        gr.Markdown(
            f"""
            ## Tips
            - **Whisper model**: `{WHISPER_MODEL_SIZE}` (downloaded automatically on first run)
            - **Diarization model**: NVIDIA `{DIARIZATION_MODEL_NAME}` (streaming, max 4 speakers)
            - Diarization downloads ~700MB on first use (cached afterward)
            - Set `WHISPER_MODEL_SIZE` in Space Variables (e.g. `medium` for lower memory use, `large-v3` for highest accuracy)
            - Optimized for Arabic customer service calls via a specialized initial prompt
            - Streaming configuration: high-latency preset (~10 s latency, better accuracy)
            """
        )

    demo.queue(max_size=16)

    return demo


demo = build_interface()

if __name__ == "__main__":
    demo.launch()