# -*- coding: utf-8 -*-
"""
ROBOTSMALI — Bambara subtitling
"""
import os
import shlex
import subprocess
import tempfile
import traceback
import random
import textwrap
from pathlib import Path
import numpy as np
import torch
import soundfile as sf
import librosa
from huggingface_hub import snapshot_download
from nemo.collections import asr as nemo_asr
import gradio as gr
# ----------------------------
# CONFIG
# ----------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
MODELS = {
    "Soloni V1 (RNNT)": ("RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
    "Soloni V0 (RNNT)": ("RobotsMali/soloni-114m-tdt-ctc-v0", "rnnt"),
    "Soloba V1 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v1", "ctc"),
    "Soloba V0 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v0", "ctc"),
    "QuartzNet V1 (CTC-char)": ("RobotsMali/stt-bm-quartznet15x5-v1", "ctc_char"),
    "QuartzNet V0 (CTC-char)": ("RobotsMali/stt-bm-quartznet15x5-v0", "ctc_char"),
}
_cache = {}
# ----------------------------
# UTIL: run_cmd, ffprobe_duration
# ----------------------------
def run_cmd(cmd):
    """Execute a shell command and raise on non-zero exit."""
    print("RUN:", cmd)
    res = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    if res.returncode != 0:
        raise RuntimeError(f"Command failed [{cmd}]\nOutput:\n{res.stdout}")
    return res.stdout
def ffprobe_duration(path):
    """Return the media duration in seconds via ffprobe, or None on failure."""
    cmd = f'ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {shlex.quote(path)}'
    out = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if out.returncode != 0:
        print("ffprobe error:", out.stderr)
        return None
    try:
        return float(out.stdout.strip())
    except ValueError:
        return None
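# Usage sketch (illustrative; "sample.mp4" is a hypothetical local file):
#   dur = ffprobe_duration("sample.mp4")
#   if dur is not None:
#       print(f"duration: {dur:.2f}s")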
# ----------------------------
# LOAD MODEL (robust)
# ----------------------------
def load_model(name):
    """Load the right NeMo model class for the given mode (rnnt / ctc / ctc_char)."""
    if name in _cache:
        return _cache[name]
    repo, mode = MODELS[name]
    print(f"[LOAD] snapshot_download {repo} ...")
    folder = snapshot_download(repo, local_dir_use_symlinks=False)
    nemo_file = next((os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(".nemo")), None)
    if not nemo_file:
        raise FileNotFoundError(f"No .nemo file found for {name} in {folder}")
    print(f"[LOAD] .nemo found: {nemo_file}; mode={mode}")
    # Pick the NeMo class according to the mode
    if mode == "rnnt":
        model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(nemo_file)
    elif mode == "ctc_char":
        # QuartzNet (char): no BPE tokenizer in cfg -> use EncDecCTCModel
        model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file)
    else:  # mode == "ctc" (BPE)
        try:
            model = nemo_asr.models.EncDecCTCModelBPE.restore_from(nemo_file)
        except Exception as e:
            # Cautious fallback to EncDecCTCModel if the BPE class fails
            print(f"[WARN] EncDecCTCModelBPE failed ({e}), falling back to EncDecCTCModel")
            model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file)
    model.to(DEVICE).eval()
    _cache[name] = model
    print(f"[OK] Model {name} loaded on {DEVICE}")
    return model
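# Usage sketch (downloads the checkpoint on first call; served from _cache after):
#   model = load_model("Soloba V1 (CTC)")
#   print(type(model).__name__)   # e.g. EncDecCTCModelBPE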
# ----------------------------
# AUDIO EXTRACTION & CLEANING
# ----------------------------
def extract_audio(video_path, out_wav):
    """Extract mono 16 kHz WAV audio using ffmpeg."""
    cmd = f'ffmpeg -hide_banner -loglevel error -y -i {shlex.quote(video_path)} -vn -ac 1 -ar 16000 -f wav {shlex.quote(out_wav)}'
    run_cmd(cmd)

def clean_audio(wav_path, target_sr=16000):
    """Load audio, ensure mono, resample to target_sr, peak-normalize, write cleaned WAV."""
    audio, sr = sf.read(wav_path)
    if audio.ndim == 2:
        audio = audio.mean(axis=1)
    if sr != target_sr:
        audio = librosa.resample(audio.astype(float), orig_sr=sr, target_sr=target_sr)
        sr = target_sr
    max_val = np.max(np.abs(audio)) if audio.size > 0 else 0.0
    if max_val > 1e-6:
        audio = audio / max_val * 0.9
    clean_path = str(Path(wav_path).with_name(Path(wav_path).stem + "_clean.wav"))
    sf.write(clean_path, audio, sr)
    return clean_path, audio, sr
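# Usage sketch (hypothetical paths; chains extraction and cleaning):
#   extract_audio("clip.mp4", "/tmp/audio.wav")
#   clean_wav, audio, sr = clean_audio("/tmp/audio.wav")   # sr == 16000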
# ----------------------------
# TRANSCRIPTION
# ----------------------------
def transcribe(model, wav_path):
    """Robust wrapper: call model.transcribe and normalize the output shape."""
    if not hasattr(model, "transcribe"):
        raise RuntimeError("The model does not support model.transcribe()")
    out = model.transcribe([wav_path])
    # The output can come in several shapes depending on the model class
    if isinstance(out, list):
        if len(out) == 0:
            return ""
        first = out[0]
        if isinstance(first, str):
            return first.strip()
        if hasattr(first, "text"):
            return first.text.strip()
        return str(first).strip()
    if hasattr(out, "text"):
        return out.text.strip()
    return str(out).strip()
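# Usage sketch (assumes a loaded model and a cleaned 16 kHz mono WAV):
#   text = transcribe(load_model("Soloba V1 (CTC)"), "/tmp/audio_clean.wav")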
# ----------------------------
# SUBTITLE / PACKING UTILITIES
# ----------------------------
def keep_bambara(words):
    """Keep words that look like Bambara: contain ɛ/ɔ/ŋ or at least two vowels."""
    res = []
    for w in words:
        wl = w.lower()
        if any(c in wl for c in ["ɛ", "ɔ", "ŋ"]) or sum(1 for c in wl if c in "aeiou") >= 2:
            res.append(w)
    return res
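# Illustration of the heuristic (hypothetical input):
#   keep_bambara(["baara", "dɔɔni", "n"]) -> ["baara", "dɔɔni"]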
MAX_CHARS = 45   # max characters per subtitle line
MIN_DUR = 0.3    # minimum segment duration (s)
MAX_DUR = 3.2    # maximum segment duration (s)
MAX_WORDS = 8    # max words per subtitle block
def wrap2(txt):
    """Split a long line into two roughly balanced lines at the space nearest the midpoint."""
    parts = textwrap.wrap(txt, MAX_CHARS)
    if len(parts) <= 1:
        return txt
    mid = len(txt) // 2
    left = txt.rfind(" ", 0, mid)
    right = txt.find(" ", mid)
    if left == -1 and right == -1:
        return txt  # no space to split on
    if left == -1:
        cut = right
    elif right == -1:
        cut = left
    else:
        cut = left if (mid - left) <= (right - mid) else right
    l1 = txt[:cut].strip()
    l2 = txt[cut:].strip()
    return l1 + "\n" + l2 if l2 else l1
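# Worked example (illustrative): a ~60-char sentence longer than MAX_CHARS is
# split at the space nearest its midpoint, yielding two balanced lines:
#   wrap2("ten words " * 6) -> two lines of roughly 30 chars joined by "\n"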
def pack(spans, total):
    """Clamp, merge, and split raw (start, end, text) spans into display-ready subtitles."""
    tmp = []
    for s, e, t in spans:
        s = max(0, min(s, total))
        e = max(0, min(e, total))
        if e <= s or not t.strip():
            continue
        tmp.append((s, e, t.strip()))
    # Merge segments that are too short or nearly contiguous
    merged = []
    for seg in tmp:
        if not merged:
            merged.append(seg)
            continue
        ps, pe, pt = merged[-1]
        s, e, t = seg
        if (e - s) < MIN_DUR or (s - pe) < 0.1:
            merged[-1] = (ps, max(pe, e), (pt + " " + t).strip())
        else:
            merged.append(seg)
    # Split merged segments into MAX_WORDS blocks, spreading time evenly
    out = []
    last_end = 0
    for s, e, t in merged:
        dur = e - s
        words = t.split()
        blocks = [" ".join(words[i:i+MAX_WORDS]) for i in range(0, len(words), MAX_WORDS)]
        step = dur / max(1, len(blocks))
        base = s
        for b in blocks:
            st = base
            en = min(base + step, e)
            base = en
            if en <= st:
                en = min(st + 0.05, total)
            txt = wrap2(b)
            if st < last_end:
                st = last_end + 1e-3
                en = max(en, st + 0.05)
            out.append((st, en, txt))
            last_end = en
    return out
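# Behaviour sketch (illustrative numbers): a too-short span that sits within
# 0.1 s of its neighbour is merged, then re-split into <= MAX_WORDS blocks:
#   pack([(0.0, 0.2, "a b"), (0.25, 2.0, "c d e")], total=2.0)
#   -> [(0.0, 2.0, "a b c d e")]   # one block, since 5 words < MAX_WORDS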
# ----------------------------
# VAD ALIGN (fallback alignment)
# ----------------------------
def align_vad(text, audio, sr, total_dur, top_db=28):
    """Fallback alignment: distribute words over voiced intervals found by energy-based VAD."""
    words = keep_bambara(text.split())
    total = total_dur
    if audio is None or len(audio) == 0 or not words:
        return pack([(0, total, " ".join(words[:MAX_WORDS]))], total)
    iv = librosa.effects.split(audio, top_db=top_db)
    if len(iv) == 0:
        return pack([(0, total, " ".join(words[:MAX_WORDS]))], total)
    spans = []
    L = sum(e - s for s, e in iv)
    idx = 0
    for s, e in iv:
        seg = e - s
        segt = seg / sr
        # Allot words proportionally to the interval's share of voiced samples
        k = max(1, int(round(len(words) * (seg / L))))
        chunk = words[idx:idx+k]
        idx += k
        if not chunk:
            continue
        lines = [chunk[i:i+MAX_WORDS] for i in range(0, len(chunk), MAX_WORDS)]
        step = max(MIN_DUR, min(MAX_DUR, segt / max(1, len(lines))))
        base = s / sr
        for j, ln in enumerate(lines):
            st = base + j * step
            en = base + (j + 1) * step
            spans.append((st, en, " ".join(ln)))
    return pack(spans, total)
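# Usage sketch (assumes a transcript plus the cleaned audio array):
#   subs = align_vad(text, audio, sr, ffprobe_duration("clip.mp4"))
#   # -> list of (start_s, end_s, text) tuples ready for burn()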
# ----------------------------
# SRT WRITING + BURN-IN (re-encode)
# ----------------------------
def burn(video_path, subs, output_path=None):
    """Write subs to a temporary SRT file and burn them into the video."""
    if output_path is None:
        output_path = "RobotsMali_Subtitled.mp4"
    tmp_fd, tmp_srt = tempfile.mkstemp(suffix=".srt")
    os.close(tmp_fd)
    def sec_to_srt(t):
        h = int(t // 3600)
        m = int((t % 3600) // 60)
        s = int(t % 60)
        ms = int((t - int(t)) * 1000)
        return f"{h:02}:{m:02}:{s:02},{ms:03}"
    with open(tmp_srt, "w", encoding="utf-8") as f:
        for i, (start, end, text) in enumerate(subs, 1):
            f.write(f"{i}\n{sec_to_srt(start)} --> {sec_to_srt(end)}\n{text}\n\n")
    # Re-encode (libx264) because we apply the subtitles filter
    vf = f"subtitles={shlex.quote(tmp_srt)}:force_style='Fontsize=22,PrimaryColour=&HFFFFFF&,OutlineColour=&H000000&'"
    cmd = f'ffmpeg -hide_banner -loglevel error -y -i {shlex.quote(video_path)} -vf {shlex.quote(vf)} -c:v libx264 -preset fast -crf 23 -c:a aac -b:a 192k {shlex.quote(output_path)}'
    try:
        run_cmd(cmd)
    finally:
        if os.path.exists(tmp_srt):
            os.remove(tmp_srt)
    return output_path
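# SRT timestamp format check: sec_to_srt(3661.5) == "01:01:01,500"
# Usage sketch (hypothetical inputs):
#   out = burn("clip.mp4", [(0.0, 2.5, "Aw ni ce")])   # -> "RobotsMali_Subtitled.mp4"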
# ----------------------------
# MAIN PIPELINE (V41)
# ----------------------------
def pipeline(video_input, model_name):
    """
    video_input : file path or Gradio dict (tmp_path)
    model_name  : key into MODELS
    """
    try:
        # Support the Gradio dict form (tmp_path)
        if isinstance(video_input, dict) and "tmp_path" in video_input:
            video_path = video_input["tmp_path"]
        else:
            video_path = video_input
        duration = ffprobe_duration(video_path)
        if duration is None:
            raise RuntimeError("Could not get the video duration via ffprobe")
        # Temporary files
        tmp_fd, tmp_wav = tempfile.mkstemp(suffix=".wav")
        os.close(tmp_fd)
        # Extraction + cleaning
        extract_audio(video_path, tmp_wav)
        clean_wav, audio, sr = clean_audio(tmp_wav)
        # Load model and transcribe
        model = load_model(model_name)
        text = transcribe(model, clean_wav)
        mode = MODELS[model_name][1]
        # Segmentation / alignment
        subs = None
        if mode == "rnnt":
            # RNNT: try logits-based segmentation via ctc_segmentation if available
            try:
                from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text
                words = keep_bambara(text.split())
                if not words:
                    return ("⚠️ No usable subtitles (text empty after filtering)", None)
                x = torch.tensor(audio).float().unsqueeze(0).to(DEVICE)
                ln = torch.tensor([x.shape[1]]).to(DEVICE)
                with torch.no_grad():
                    logits = model(input_signal=x, input_signal_length=ln)[0]
                # Heuristic mapping from frames to seconds
                time_per_frame = duration / max(1, logits.shape[1])
                # Build the character list
                try:
                    raw = model.tokenizer.vocab
                    vocab = list(raw.keys()) if isinstance(raw, dict) else list(raw)
                except Exception:
                    vocab = None
                cfg = CtcSegmentationParameters()
                if vocab:
                    cfg.char_list = vocab
                gt = prepare_text(cfg, words)[0]
                try:
                    timing, _, _ = ctc_segmentation(cfg, logits.detach().cpu().numpy()[0], gt)
                    spans = [(timing[i] * time_per_frame, timing[i+1] * time_per_frame, words[i]) for i in range(len(words) - 1)]
                    subs = pack(spans, duration)
                except AssertionError:
                    print("[WARN] Audio shorter than text -> fallback to VAD alignment")
                    subs = align_vad(text, audio, sr, duration)
            except Exception as e:
                print(f"[WARN] ctc_segmentation not available or failed ({e}) -> fallback VAD")
                subs = align_vad(text, audio, sr, duration)
        else:
            # CTC (BPE) and QuartzNet (char): no token timestamps here, so
            # VAD alignment is the pragmatic choice for both modes
            subs = align_vad(text, audio, sr, duration)
        if not subs:
            return ("⚠️ No usable subtitles (empty sub list)", None)
        out_video = burn(video_path, subs)
        return ("✅ Finished successfully", out_video)
    except Exception as e:
        traceback.print_exc()
        return (f"❌ Error — {str(e)}", None)
# ----------------------------
# GRADIO INTERFACE (optional)
# ----------------------------
with gr.Blocks(title="RobotsMali - Subtitling") as demo:
    gr.Markdown("RobotsMali — Subtitling")
    v = gr.Video(label="Video to subtitle")
    m = gr.Dropdown(list(MODELS.keys()), value="Soloba V1 (CTC)", label="ASR model")
    b = gr.Button("▶️ Generate")
    s = gr.Markdown()
    o = gr.Video(label="Subtitled video")
    b.click(pipeline, [v, m], [s, o])

demo.launch(share=True, debug=False)