|
|
|
|
|
""" |
|
|
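ROBOTSMALI — Sous-titrage Bambara

Bambara video subtitling: NeMo ASR transcription, subtitle timing and
ffmpeg burn-in, exposed through a Gradio interface.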
|
|
""" |
|
|
|
|
|
import os |
|
|
import shlex |
|
|
import subprocess |
|
|
import tempfile |
|
|
import traceback |
|
|
import random |
|
|
import textwrap |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import soundfile as sf |
|
|
import librosa |
|
|
from huggingface_hub import snapshot_download |
|
|
from nemo.collections import asr as nemo_asr |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run inference on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Seed every random number generator the script imports with the same value.
_SEED = 1234
random.seed(_SEED)
np.random.seed(_SEED)
torch.manual_seed(_SEED)

# Display name -> (Hugging Face repository id, model kind).
# The kind ("rnnt", "ctc" or "ctc_char") selects the NeMo class used to
# restore the checkpoint and the alignment strategy used by the pipeline.
MODELS = {
    "Soloni V1 (RNNT)": ("RobotsMali/soloni-114m-tdt-ctc-v1", "rnnt"),
    "Soloni V0 (RNNT)": ("RobotsMali/soloni-114m-tdt-ctc-v0", "rnnt"),
    "Soloba V1 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v1", "ctc"),
    "Soloba V0 (CTC)": ("RobotsMali/soloba-ctc-0.6b-v0", "ctc"),
    "QuartzNet V1 (CTC-char)": ("RobotsMali/stt-bm-quartznet15x5-v1", "ctc_char"),
    "QuartzNet V0 (CTC-char)": ("RobotsMali/stt-bm-quartznet15x5-v0", "ctc_char"),
}

# Loaded models keyed by display name, so each checkpoint is restored once.
_cache = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_cmd(cmd):
    """Run an external command and return its combined stdout/stderr text.

    Args:
        cmd: Either a command string, which is executed through the shell
            (the form the existing callers use), or a sequence of arguments,
            which is executed directly without a shell so that no quoting
            is needed.

    Returns:
        The captured output as text; stderr is merged into stdout.

    Raises:
        RuntimeError: If the command cannot be started or exits with a
            non-zero status. The message embeds the command and whatever
            output was captured.
    """
    print("RUN:", cmd)
    use_shell = isinstance(cmd, str)
    try:
        res = subprocess.run(
            cmd if use_shell else list(cmd),
            shell=use_shell,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
    except OSError as exc:
        # Without a shell a missing executable raises instead of returning
        # exit code 127; report it the same way as any other failure.
        raise RuntimeError(f"Commande échouée [{cmd}]\nOutput:\n{exc}") from exc
    if res.returncode != 0:
        raise RuntimeError(f"Commande échouée [{cmd}]\nOutput:\n{res.stdout}")
    return res.stdout
|
|
|
|
|
def ffprobe_duration(path):
    """Return the container duration of *path* in seconds, as reported by ffprobe.

    The command is run as an argument list (no shell), so the path needs no
    quoting and may be a ``str`` or any ``os.PathLike`` object.

    Args:
        path: Location of the media file.

    Returns:
        The duration as a float, or ``None`` when ffprobe cannot be started,
        exits with an error, or prints something that is not a number.
    """
    cmd = [
        "ffprobe",
        "-v", "error",
        "-select_streams", "v:0",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        os.fspath(path),
    ]
    try:
        out = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    except OSError as exc:
        # ffprobe is not installed or not executable: same outcome as a
        # failed probe, the caller decides what to do with None.
        print("ffprobe erreur:", exc)
        return None
    if out.returncode != 0:
        print("ffprobe erreur:", out.stderr)
        return None
    try:
        return float(out.stdout.strip())
    except ValueError:
        # Empty output or a placeholder such as "N/A".
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_model(name):
    """Return the NeMo model registered under *name*, restoring it on first use.

    The repository listed in ``MODELS`` is downloaded with
    ``snapshot_download``, the first ``.nemo`` file found in the snapshot
    folder is restored with the class matching the model kind
    (rnnt / ctc / ctc_char), and the model is moved to ``DEVICE``, switched
    to eval mode and stored in ``_cache``.

    Raises:
        KeyError: If *name* is not a key of ``MODELS``.
        FileNotFoundError: If the snapshot contains no ``.nemo`` file.
    """
    try:
        return _cache[name]
    except KeyError:
        pass  # not loaded yet

    repo, mode = MODELS[name]
    print(f"[LOAD] snapshot_download {repo} ...")
    folder = snapshot_download(repo, local_dir_use_symlinks=False)

    nemo_file = None
    for entry in os.listdir(folder):
        if entry.endswith(".nemo"):
            nemo_file = os.path.join(folder, entry)
            break
    if not nemo_file:
        raise FileNotFoundError(f"Aucun .nemo trouvé pour {name} dans {folder}")

    print(f"[LOAD] .nemo trouvé: {nemo_file}; mode={mode}")

    def restore_ctc_bpe():
        # Try the BPE class first; if restoring fails, fall back to the
        # plain CTC class.
        try:
            return nemo_asr.models.EncDecCTCModelBPE.restore_from(nemo_file)
        except Exception as e:
            print(f"[WARN] EncDecCTCModelBPE failed ({e}), fallback EncDecCTCModel")
            return nemo_asr.models.EncDecCTCModel.restore_from(nemo_file)

    if mode == "rnnt":
        model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from(nemo_file)
    elif mode == "ctc_char":
        model = nemo_asr.models.EncDecCTCModel.restore_from(nemo_file)
    else:
        model = restore_ctc_bpe()

    on_device = model.to(DEVICE)
    on_device.eval()
    _cache[name] = model
    print(f"[OK] Modèle {name} chargé sur {DEVICE}")
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_audio(video_path, out_wav):
    """Write the audio of *video_path* to *out_wav* as a mono 16 kHz WAV via ffmpeg.

    Both paths are shell-quoted before being placed in the command line.
    An existing *out_wav* is overwritten (``-y``). Failures surface as the
    ``RuntimeError`` raised by ``run_cmd``.
    """
    source = shlex.quote(video_path)
    target = shlex.quote(out_wav)
    parts = [
        "ffmpeg",
        "-hide_banner",
        "-loglevel error",
        "-y",
        f"-i {source}",
        "-vn",
        "-ac 1",
        "-ar 16000",
        "-f wav",
        target,
    ]
    run_cmd(" ".join(parts))
|
|
|
|
|
def clean_audio(wav_path, target_sr=16000):
    """Normalize a WAV file and save the result next to it.

    The audio is read, averaged down to one channel when it is
    two-dimensional, resampled to *target_sr* when the rates differ, and
    peak-normalized to 0.9 unless it is empty or essentially silent
    (peak <= 1e-6).

    Args:
        wav_path: Path of the input WAV file.
        target_sr: Sample rate of the cleaned audio.

    Returns:
        ``(clean_path, audio, sr)``: the path of the written
        ``<stem>_clean.wav`` file, the processed samples and their rate.
    """
    samples, rate = sf.read(wav_path)

    if samples.ndim == 2:
        samples = samples.mean(axis=1)

    if rate != target_sr:
        samples = librosa.resample(samples.astype(float), orig_sr=rate, target_sr=target_sr)
        rate = target_sr

    peak = 0.0
    if samples.size > 0:
        peak = np.max(np.abs(samples))
    if peak > 1e-6:
        samples = samples / peak * 0.9

    source = Path(wav_path)
    clean_path = str(source.with_name(source.stem + "_clean.wav"))
    sf.write(clean_path, samples, rate)
    return clean_path, samples, rate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transcribe(model, wav_path):
    """Transcribe one audio file and return the recognized text, stripped.

    ``model.transcribe`` is called with a one-element list and its result is
    reduced to a plain string whatever container it comes in: a list or
    tuple (possibly nested, e.g. a ``(best, all)`` pair of hypothesis
    lists), a bare string, or an object exposing a ``text`` attribute.

    Args:
        model: Object providing a ``transcribe(list_of_paths)`` method.
        wav_path: Path of the audio file to transcribe.

    Returns:
        The transcription, or ``""`` when the model returns nothing.

    Raises:
        RuntimeError: If *model* has no ``transcribe`` attribute.
    """
    if not hasattr(model, "transcribe"):
        raise RuntimeError("Le modèle ne supporte pas model.transcribe()")

    result = model.transcribe([wav_path])

    # Peel containers down to the first hypothesis; handles both a flat
    # list of results and a tuple of lists.
    while isinstance(result, (list, tuple)):
        if len(result) == 0:
            return ""
        result = result[0]

    if isinstance(result, str):
        return result.strip()
    if hasattr(result, "text"):
        # A hypothesis object may carry text=None; treat that as empty.
        text = result.text
        return "" if text is None else str(text).strip()
    return str(result).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def keep_bambara(words):
    """Return the words that pass the Bambara heuristic, in their original order.

    A word is kept when its lowercase form contains one of the letters
    ɛ, ɔ or ŋ, or at least two characters from "aeiou". The returned
    words keep their original casing.
    """
    special_letters = ("ɛ", "ɔ", "ŋ")
    vowels = "aeiou"

    def passes(word):
        lowered = word.lower()
        if any(letter in lowered for letter in special_letters):
            return True
        return sum(1 for ch in lowered if ch in vowels) >= 2

    return [word for word in words if passes(word)]
|
|
|
|
|
# Subtitle layout limits shared by the wrapping and timing helpers.
MAX_CHARS = 45  # wrap width (characters) handed to textwrap
MIN_DUR = 0.3   # seconds; shorter segments are merged into the previous one
MAX_DUR = 3.2   # seconds; upper bound of a line's time slot in VAD alignment
MAX_WORDS = 8   # words per subtitle block
|
|
|
|
|
def wrap2(txt, max_chars=None):
    """Break *txt* into at most two lines, splitting at the space nearest its middle.

    Text that ``textwrap`` fits on a single line of *max_chars* characters
    is returned unchanged. So is text with no space at all: it has no place
    to break, and the previous behavior of splitting off its last character
    was wrong.

    Args:
        txt: Subtitle text.
        max_chars: Wrap width; defaults to the module-level ``MAX_CHARS``.

    Returns:
        The text with at most one newline inserted.
    """
    limit = MAX_CHARS if max_chars is None else max_chars
    if len(textwrap.wrap(txt, limit)) <= 1:
        return txt

    mid = len(txt) // 2
    left = txt.rfind(" ", 0, mid)   # last space before the middle, -1 if none
    right = txt.find(" ", mid)      # first space from the middle on, -1 if none

    if left == -1 and right == -1:
        return txt
    if right == -1:
        cut = left
    elif left == -1:
        cut = right
    else:
        # Ties go to the left break.
        cut = left if (mid - left) <= (right - mid) else right

    halves = (txt[:cut].strip(), txt[cut:].strip())
    return "\n".join(part for part in halves if part)
|
|
|
|
|
def pack(spans, total):
    """Turn raw ``(start, end, text)`` spans into display-ready subtitle entries.

    Three passes are applied in order:

    1. Clamp every span to ``[0, total]`` and drop spans that are empty in
       time or in text.
    2. Merge a span into its predecessor when it lasts less than
       ``MIN_DUR`` or starts less than 0.1 s after the predecessor ends.
    3. Split each merged span into blocks of ``MAX_WORDS`` words that share
       its duration equally, wrap each block with ``wrap2``, and push a
       block's start forward when it would overlap the previous entry.

    Args:
        spans: Iterable of ``(start, end, text)`` tuples, times in seconds.
        total: Upper bound applied to the times, in seconds.

    Returns:
        A list of ``(start, end, text)`` tuples.
    """
    # Pass 1: clamp and filter.
    cleaned = []
    for start, end, text in spans:
        start = max(0, min(start, total))
        end = max(0, min(end, total))
        if end <= start or not text.strip():
            continue
        cleaned.append((start, end, text.strip()))

    # Pass 2: merge short or near-adjacent segments into the previous one.
    merged = []
    for segment in cleaned:
        if merged:
            prev_start, prev_end, prev_text = merged[-1]
            start, end, text = segment
            too_short = (end - start) < MIN_DUR
            too_close = (start - prev_end) < 0.1
            if too_short or too_close:
                merged[-1] = (prev_start, max(prev_end, end), (prev_text + " " + text).strip())
                continue
        merged.append(segment)

    # Pass 3: split into word blocks and assign non-overlapping times.
    result = []
    last_end = 0
    for start, end, text in merged:
        words = text.split()
        blocks = [
            " ".join(words[i:i + MAX_WORDS])
            for i in range(0, len(words), MAX_WORDS)
        ]
        step = (end - start) / max(1, len(blocks))
        cursor = start
        for block in blocks:
            block_start = cursor
            block_end = min(cursor + step, end)
            cursor = block_end
            if block_end <= block_start:
                # Degenerate slot: give it a minimal visible duration.
                block_end = min(block_start + 0.05, total)
            wrapped = wrap2(block)
            if block_start < last_end:
                block_start = last_end + 1e-3
                block_end = max(block_end, block_start + 0.05)
            result.append((block_start, block_end, wrapped))
            last_end = block_end
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def align_vad(text, audio, sr, total_dur, top_db=28):
    """Spread the transcript over the non-silent regions of the audio.

    The words kept by ``keep_bambara`` are distributed over the intervals
    returned by ``librosa.effects.split`` in proportion to each interval's
    length. Every interval gets at least one word while words remain, and
    the last interval receives whatever is left, so rounding never drops
    the end of the transcript. Within an interval the words are cut into
    lines of ``MAX_WORDS`` words, each given a slot between ``MIN_DUR`` and
    ``MAX_DUR`` seconds. The result is post-processed by ``pack``.

    When there is no audio, no usable word or no detected interval, the
    whole filtered transcript is laid out over ``[0, total_dur]``.

    Args:
        text: Transcript to align.
        audio: Audio samples, or ``None``.
        sr: Sample rate of *audio*, used to convert sample indices to seconds.
        total_dur: Duration in seconds used as the upper time bound.
        top_db: Silence threshold forwarded to ``librosa.effects.split``.

    Returns:
        The list of ``(start, end, text)`` subtitle entries built by ``pack``.
    """
    words = keep_bambara(text.split())
    total = total_dur

    # Fallback layout: keep every word; pack() cuts it into blocks itself.
    whole_transcript = [(0, total, " ".join(words))]

    if audio is None or len(audio) == 0 or not words:
        return pack(whole_transcript, total)

    intervals = librosa.effects.split(audio, top_db=top_db)
    if len(intervals) == 0:
        return pack(whole_transcript, total)

    voiced = sum(e - s for s, e in intervals)
    last = len(intervals) - 1
    spans = []
    idx = 0
    for n, (s, e) in enumerate(intervals):
        seg = e - s
        segt = seg / sr
        if n == last:
            # Absorb the words left unassigned by rounding.
            k = max(0, len(words) - idx)
        else:
            k = max(1, int(round(len(words) * (seg / voiced))))
        chunk = words[idx:idx + k]
        idx += k
        if not chunk:
            continue
        lines = [chunk[i:i + MAX_WORDS] for i in range(0, len(chunk), MAX_WORDS)]
        step = max(MIN_DUR, min(MAX_DUR, segt / max(1, len(lines))))
        base = s / sr
        for j, ln in enumerate(lines):
            spans.append((base + j * step, base + (j + 1) * step, " ".join(ln)))
    return pack(spans, total)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def burn(video_path, subs, output_path=None):
    """Burn subtitles into a video with ffmpeg and return the output path.

    The subtitles are written to a temporary SRT file, rendered onto the
    video with ffmpeg's ``subtitles`` filter (libx264 video, AAC audio), and
    the temporary file is removed whether or not any step succeeds.

    Args:
        video_path: Path of the source video.
        subs: Iterable of ``(start, end, text)`` tuples, times in seconds.
        output_path: Destination file; defaults to
            ``"RobotsMali_Subtitled.mp4"`` in the current directory.

    Returns:
        The path of the subtitled video.

    Raises:
        RuntimeError: Propagated from ``run_cmd`` when ffmpeg fails.
    """
    if output_path is None:
        output_path = "RobotsMali_Subtitled.mp4"

    def sec_to_srt(t):
        # Work in whole milliseconds: rounding once avoids float truncation
        # (0.3 s used to print as ",299") and carries correctly into the
        # seconds, minutes and hours fields.
        total_ms = max(0, int(round(float(t) * 1000)))
        total_s, ms = divmod(total_ms, 1000)
        total_m, s = divmod(total_s, 60)
        h, m = divmod(total_m, 60)
        return f"{h:02}:{m:02}:{s:02},{ms:03}"

    tmp_fd, tmp_srt = tempfile.mkstemp(suffix=".srt")
    try:
        # The write is inside the try so that a failure here does not leak
        # the temporary file either.
        with os.fdopen(tmp_fd, "w", encoding="utf-8") as f:
            for i, (start, end, text) in enumerate(subs, 1):
                f.write(f"{i}\n{sec_to_srt(start)} --> {sec_to_srt(end)}\n{text}\n\n")

        vf = f"subtitles={shlex.quote(tmp_srt)}:force_style='Fontsize=22,PrimaryColour=&HFFFFFF&,OutlineColour=&H000000&'"
        cmd = f'ffmpeg -hide_banner -loglevel error -y -i {shlex.quote(video_path)} -vf {shlex.quote(vf)} -c:v libx264 -preset fast -crf 23 -c:a aac -b:a 192k {shlex.quote(output_path)}'
        run_cmd(cmd)
    finally:
        if os.path.exists(tmp_srt):
            os.remove(tmp_srt)
    return output_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pipeline(video_input, model_name):
    """Transcribe a video with the selected model and burn the subtitles in.

    Steps: resolve the video path, probe its duration, extract and clean
    the audio, transcribe it, compute subtitle timings (ctc_segmentation
    for "rnnt" models with a VAD fallback, VAD alignment otherwise) and
    render the subtitled video. Intermediate WAV files are deleted on every
    exit path.

    Args:
        video_input: Path of the video, or a Gradio dict holding the path
            under the ``"tmp_path"`` key.
        model_name: Key of ``MODELS`` selecting the ASR model.

    Returns:
        A ``(status_message, output_video_path)`` tuple. The path is
        ``None`` whenever no subtitled video was produced; errors are
        reported through the message instead of being raised.
    """
    temp_files = []  # intermediate files to delete whatever the outcome
    try:
        if isinstance(video_input, dict) and "tmp_path" in video_input:
            video_path = video_input["tmp_path"]
        else:
            video_path = video_input

        if not video_path:
            # Button pressed without a video: say so instead of surfacing a
            # TypeError from the path-quoting code further down.
            return ("❌ Erreur — Aucune vidéo fournie", None)

        duration = ffprobe_duration(video_path)
        if duration is None:
            raise RuntimeError("Impossible d'obtenir la durée de la vidéo via ffprobe")

        tmp_fd, tmp_wav = tempfile.mkstemp(suffix=".wav")
        os.close(tmp_fd)
        temp_files.append(tmp_wav)

        extract_audio(video_path, tmp_wav)
        clean_wav, audio, sr = clean_audio(tmp_wav)
        temp_files.append(clean_wav)

        model = load_model(model_name)
        text = transcribe(model, clean_wav)
        mode = MODELS[model_name][1]

        subs = None
        if mode == "rnnt":
            try:
                from ctc_segmentation import ctc_segmentation, CtcSegmentationParameters, prepare_text

                words = keep_bambara(text.split())
                if not words:
                    return ("⚠️ Aucun sous-titre utilisable (texte vide après filtrage)", None)

                x = torch.tensor(audio).float().unsqueeze(0).to(DEVICE)
                ln = torch.tensor([x.shape[1]]).to(DEVICE)
                with torch.no_grad():
                    logits = model(input_signal=x, input_signal_length=ln)[0]

                # NOTE(review): assumes axis 1 of the forward output is the
                # frame axis and that it holds CTC log-probabilities. TODO
                # confirm for the hybrid RNNT model.
                time_per_frame = duration / max(1, logits.shape[1])

                try:
                    raw = model.tokenizer.vocab
                    vocab = list(raw.keys()) if isinstance(raw, dict) else list(raw)
                except Exception:
                    vocab = None

                cfg = CtcSegmentationParameters()
                if vocab:
                    cfg.char_list = vocab
                gt = prepare_text(cfg, words)[0]

                try:
                    timing, _, _ = ctc_segmentation(cfg, logits.detach().cpu().numpy()[0], gt)
                    # NOTE(review): the range stops before the last word, so
                    # it never gets a span. Confirm whether this is intended
                    # and how `timing` is indexed.
                    spans = [
                        (timing[i] * time_per_frame, timing[i + 1] * time_per_frame, words[i])
                        for i in range(len(words) - 1)
                    ]
                    subs = pack(spans, duration)
                except AssertionError:
                    print("[WARN] Audio shorter than text -> fallback to VAD alignment")
                    subs = align_vad(text, audio, sr, duration)
            except Exception as e:
                print(f"[WARN] ctc_segmentation not available or failed ({e}) -> fallback VAD")
                subs = align_vad(text, audio, sr, duration)
        else:
            # "ctc" and "ctc_char" models both use VAD alignment. A failure
            # is reported once by the handler below; retrying the identical
            # call, as the code did before, cannot succeed.
            subs = align_vad(text, audio, sr, duration)

        if not subs:
            return ("⚠️ Aucun sous-titre utilisable (sub list vide)", None)

        out_video = burn(video_path, subs)
        return ("✅ Terminé avec succès", out_video)

    except Exception as e:
        traceback.print_exc()
        return (f"❌ Erreur — {str(e)}", None)

    finally:
        for tmp in temp_files:
            try:
                os.remove(tmp)
            except OSError:
                # Best-effort cleanup: a file that is already gone or locked
                # must not mask the pipeline result.
                pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio interface: one video input and a model selector feed pipeline(),
# whose (status, video) result fills the status text and the output player.
with gr.Blocks(title="RobotsMali - Sous-titrage") as demo:
    gr.Markdown(" RobotsMali — Sous-titrage")
    v = gr.Video(label="Vidéo à sous-titrer")
    m = gr.Dropdown(choices=list(MODELS), value="Soloba V1 (CTC)", label="Modèle ASR")
    b = gr.Button("▶️ Générer")
    s = gr.Markdown()
    o = gr.Video(label="Vidéo sous-titrée")
    b.click(fn=pipeline, inputs=[v, m], outputs=[s, o])

# Start the app; share=True requests a public share link.
demo.launch(share=True, debug=False)
|
|
|
|
|
|