# -*- coding: utf-8 -*- """ ROBOTSMALI — Sous-titrage Bambara """ import os import shlex import subprocess import tempfile import traceback import random import textwrap from pathlib import Path import numpy as np import torch import soundfile as sf import librosa from huggingface_hub import snapshot_download from nemo.collections import asr as nemo_asr import gradio as gr # ---------------------------- # CONFIG # ---------------------------- DEVICE = "cuda" if torch.cuda.is_available() else "cpu" random.seed(1234) np.random.seed(1234) torch.manual_seed(1234) MODELS = { "Soloni V1 (RNNT)": ("RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"), "Soloni V0 (RNNT)": ("RobotsMali/soloni-114m-tdt-ctc-v0", "rnnt"), "Soloba V1 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v1", "ctc"), "Soloba V0 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v0", "ctc"), "QuartzNet V1 (CTC-char)": ("RobotsMali/stt-bm-quartznet15x5-v1", "ctc_char"), "QuartzNet V0 (CTC-char)": ("RobotsMali/stt-bm-quartznet15x5-v0", "ctc_char"), } _cache = {} # ---------------------------- # UTIL: run_cmd, ffprobe_duration # ---------------------------- def run_cmd(cmd): """Execute a shell command and raise on non-zero exit.""" print("RUN:", cmd) res = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) if res.returncode != 0: raise RuntimeError(f"Commande échouée [{cmd}]\nOutput:\n{res.stdout}") return res.stdout def ffprobe_duration(path): cmd = f'ffprobe -v error -select_streams v:0 -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {shlex.quote(path)}' out = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) if out.returncode != 0: print("ffprobe erreur:", out.stderr) return None try: return float(out.stdout.strip()) except: return None # ---------------------------- # LOAD MODEL (robust) # ---------------------------- def load_model(name): """Charge le modèle NeMo correct selon type (rnnt / ctc / ctc_char).""" if name in _cache: return _cache[name] repo, mode = MODELS[name] print(f"[LOAD] snapshot_download {repo} ...") folder = snapshot_download(repo, local_dir_use_symlinks=False) nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None) if not nemo_file: raise FileNotFoundError(f"Aucun .nemo trouvé pour {name} dans {folder}") print(f"[LOAD] .nemo trouvé: {nemo_file}; mode={mode}") # Sélection de la classe NeMo selon le mode if mode == "rnnt": model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(nemo_file) elif mode == "ctc_char": # QuartzNet (char) : pas de tokenizer BPE dans cfg -> utiliser EncDecCTCModel model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file) else: # mode == "ctc" (BPE) try: model = nemo_asr.models.EncDecCTCModelBPE.restore_from(nemo_file) except Exception as e: # fallback sur EncDecCTCModel si BPE absent (prudence) print(f"[WARN] EncDecCTCModelBPE failed ({e}), fallback EncDecCTCModel") model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file) model.to(DEVICE).eval() _cache[name] = model print(f"[OK] Modèle {name} chargé sur {DEVICE}") return model # ---------------------------- # AUDIO EXTRACTION & CLEANING # ---------------------------- def extract_audio(video_path, out_wav): """Extract mono 16k WAV using ffmpeg.""" cmd = f'ffmpeg -hide_banner -loglevel error -y -i {shlex.quote(video_path)} -vn -ac 1 -ar 16000 -f wav {shlex.quote(out_wav)}' run_cmd(cmd) def clean_audio(wav_path, target_sr=16000): """Load audio, ensure mono, resample to target_sr, normalize, write cleaned wav.""" audio, sr = sf.read(wav_path) if audio.ndim == 2: audio = audio.mean(axis=1) if sr != target_sr: audio = librosa.resample(audio.astype(float), orig_sr=sr, target_sr=target_sr) sr = target_sr max_val = np.max(np.abs(audio)) if audio.size > 0 else 0.0 if max_val > 1e-6: audio = audio / max_val * 0.9 clean_path = str(Path(wav_path).with_name(Path(wav_path).stem + "_clean.wav")) sf.write(clean_path, audio, sr) return clean_path, audio, sr # ---------------------------- # TRANSCRIPTION # ---------------------------- def transcribe(model, wav_path): """Robuste: essaie model.transcribe et nettoie la sortie.""" if not hasattr(model, "transcribe"): raise RuntimeError("Le modèle ne supporte pas model.transcribe()") out = model.transcribe([wav_path]) # Différentes formes de sortie possibles if isinstance(out, list): if len(out) == 0: return "" first = out[0] if isinstance(first, str): return first.strip() if hasattr(first, "text"): return first.text.strip() return str(first).strip() if hasattr(out, "text"): return out.text.strip() return str(out).strip() # ---------------------------- # UTILITAIRES sous-titres / packing # ---------------------------- def keep_bambara(words): res = [] for w in words: wl = w.lower() if any(c in wl for c in ["ɛ","ɔ","ŋ"]) or sum(1 for c in wl if c in "aeiou") >= 2: res.append(w) return res MAX_CHARS = 45; MIN_DUR = 0.3; MAX_DUR = 3.2; MAX_WORDS = 8 def wrap2(txt): parts = textwrap.wrap(txt, MAX_CHARS) if len(parts) <= 1: return txt mid = len(txt) // 2 left = txt.rfind(" ", 0, mid) right = txt.find(" ", mid) cut = left if (mid - left) <= ((right - mid) if right != -1 else 1e9) else right l1 = txt[:cut].strip(); l2 = txt[cut:].strip() return l1 + "\n" + l2 if l2 else l1 def pack(spans, total): tmp = [] for s, e, t in spans: s = max(0, min(s, total)); e = max(0, min(e, total)) if e <= s or not t.strip(): continue tmp.append((s, e, t.strip())) merged = [] for seg in tmp: if not merged: merged.append(seg); continue ps, pe, pt = merged[-1]; s, e, t = seg if (e - s) < MIN_DUR or (s - pe) < 0.1: merged[-1] = (ps, max(pe, e), (pt + " " + t).strip()) else: merged.append(seg) out = []; last_end = 0 for s, e, t in merged: dur = e - s; words = t.split() blocks = [" ".join(words[i:i+MAX_WORDS]) for i in range(0, len(words), MAX_WORDS)] step = dur / max(1, len(blocks)) base = s for b in blocks: st = base; en = min(base + step, e); base = en if en <= st: en = min(st + 0.05, total) txt = wrap2(b) if st < last_end: st = last_end + 1e-3; en = max(en, st + 0.05) out.append((st, en, txt)); last_end = en return out # ---------------------------- # VAD ALIGN (fallback alignment) # ---------------------------- def align_vad(text, audio, sr, total_dur, top_db=28): words = keep_bambara(text.split()) total = total_dur if audio is None or len(audio) == 0 or not words: return pack([(0, total, " ".join(words[:MAX_WORDS]))], total) iv = librosa.effects.split(audio, top_db=top_db) if len(iv) == 0: return pack([(0, total, " ".join(words[:MAX_WORDS]))], total) spans = [] L = sum(e - s for s, e in iv) idx = 0 for s, e in iv: seg = e - s; segt = seg / sr k = max(1, int(round(len(words) * (seg / L)))) chunk = words[idx:idx+k]; idx += k if not chunk: continue lines = [chunk[i:i+MAX_WORDS] for i in range(0, len(chunk), MAX_WORDS)] step = max(MIN_DUR, min(MAX_DUR, segt / max(1, len(lines)))) base = s / sr for j, ln in enumerate(lines): st = base + j * step; en = base + (j + 1) * step spans.append((st, en, " ".join(ln))) return pack(spans, total) # ---------------------------- # Écriture SRT + Burn (réencode) # ---------------------------- def burn(video_path, subs, output_path=None): if output_path is None: output_path = "RobotsMali_Subtitled.mp4" tmp_fd, tmp_srt = tempfile.mkstemp(suffix=".srt") os.close(tmp_fd) def sec_to_srt(t): h = int(t // 3600); m = int((t % 3600) // 60); s = int(t % 60); ms = int((t - int(t)) * 1000) return f"{h:02}:{m:02}:{s:02},{ms:03}" with open(tmp_srt, "w", encoding="utf-8") as f: for i, (start, end, text) in enumerate(subs, 1): f.write(f"{i}\n{sec_to_srt(start)} --> {sec_to_srt(end)}\n{text}\n\n") # On réencode (libx264) car on applique subtitles filter vf = f"subtitles={shlex.quote(tmp_srt)}:force_style='Fontsize=22,PrimaryColour=&HFFFFFF&,OutlineColour=&H000000&'" cmd = f'ffmpeg -hide_banner -loglevel error -y -i {shlex.quote(video_path)} -vf {shlex.quote(vf)} -c:v libx264 -preset fast -crf 23 -c:a aac -b:a 192k {shlex.quote(output_path)}' try: run_cmd(cmd) finally: if os.path.exists(tmp_srt): os.remove(tmp_srt) return output_path # ---------------------------- # PIPELINE PRINCIPAL (V41) # ---------------------------- def pipeline(video_input, model_name): """ video_input : chemin ou dict Gradio (tmp_path) model_name : clé dans MODELS """ try: # support Gradio dict (tmp_path) if isinstance(video_input, dict) and "tmp_path" in video_input: video_path = video_input["tmp_path"] else: video_path = video_input duration = ffprobe_duration(video_path) if duration is None: raise RuntimeError("Impossible d'obtenir la durée de la vidéo via ffprobe") # fichiers temporaires tmp_fd, tmp_wav = tempfile.mkstemp(suffix=".wav") os.close(tmp_fd) # extraction + nettoyage extract_audio(video_path, tmp_wav) clean_wav, audio, sr = clean_audio(tmp_wav) # charger modèle model = load_model(model_name) text = transcribe(model, clean_wav) mode = MODELS[model_name][1] # segmentation / alignement subs = None if mode == "rnnt": # RNNT : tentative de segmentation via logits + ctc_segmentation si dispo try: from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text words = keep_bambara(text.split()) if not words: return ("⚠️ Aucun sous-titre utilisable (texte vide après filtrage)", None) x = torch.tensor(audio).float().unsqueeze(0).to(DEVICE) ln = torch.tensor([x.shape[1]]).to(DEVICE) with torch.no_grad(): logits = model(input_signal=x, input_signal_length=ln)[0] # heuristique mapping frames -> seconds time_per_frame = duration / max(1, logits.shape[1]) # build char list try: raw = model.tokenizer.vocab vocab = list(raw.keys()) if isinstance(raw, dict) else list(raw) except Exception: vocab = None cfg = CtcSegmentationParameters() if vocab: cfg.char_list = vocab gt = prepare_text(cfg, words)[0] try: timing, _, _ = ctc_segmentation(cfg, logits.detach().cpu().numpy()[0], gt) spans = [(timing[i] * time_per_frame, timing[i+1] * time_per_frame, words[i]) for i in range(len(words) - 1)] subs = pack(spans, duration) except AssertionError: print("[WARN] Audio shorter than text -> fallback to VAD alignment") subs = align_vad(text, audio, sr, duration) except Exception as e: print(f"[WARN] ctc_segmentation not available or failed ({e}) -> fallback VAD") subs = align_vad(text, audio, sr, duration) elif mode == "ctc_char": # QuartzNet : pas de tokenizer BPE, on procède avec VAD (ou on peut essayer timestamps si model le permet) # On essaie d'obtenir timestamps via model.transcribe() si disponible (mais souvent non) try: subs = align_vad(text, audio, sr, duration) except Exception as e: print(f"[WARN] QuartzNet alignment failed: {e}") subs = align_vad(text, audio, sr, duration) else: # ctc (BPE) # Pour les modèles CTC-BPE, VAD reste une option raisonnable si segmentation manque try: subs = align_vad(text, audio, sr, duration) except Exception as e: print(f"[WARN] CTC alignment failed: {e}") subs = align_vad(text, audio, sr, duration) if not subs: return ("⚠️ Aucun sous-titre utilisable (sub list vide)", None) out_video = burn(video_path, subs) return ("✅ Terminé avec succès", out_video) except Exception as e: traceback.print_exc() return (f"❌ Erreur — {str(e)}", None) # ---------------------------- # INTERFACE GRADIO (optionnel) # ---------------------------- with gr.Blocks(title="RobotsMali - Sous-titrage") as demo: gr.Markdown(" RobotsMali — Sous-titrage") v = gr.Video(label="Vidéo à sous-titrer") m = gr.Dropdown(list(MODELS.keys()), value="Soloba V1 (CTC)", label="Modèle ASR") b = gr.Button("▶️ Générer") s = gr.Markdown() o = gr.Video(label="Vidéo sous-titrée") b.click(pipeline, [v, m], [s, o]) # Pour exécuter l'interface : # demo.launch(share=True, debug=False) demo.launch(share=True, debug=False)