FarmerlineML's picture
Update app.py
ad17ccb verified
# Swahili Text‑to‑Speech Gradio App – MP3 Output (mobile‑friendly)
# -----------------------------------------------------------------
# - Generates clear Kiswahili speech and serves it **as an MP3 file** so that
#   iOS/Android browsers play it reliably. Uses a fine-tuned VITS checkpoint.
#
# - Dependencies (add to requirements.txt):
#   torch, transformers, gradio, numpy, scipy, pydub
#   (pydub also needs the ffmpeg binary available for MP3 encoding)
import os
import tempfile
import torch
import numpy as np
import gradio as gr
import scipy.io.wavfile as wavfile
from pydub import AudioSegment
from transformers import VitsModel, AutoTokenizer
# The tokenizer and the acoustic model are pulled from two separate repos.
MODEL_NAME = "FarmerlineML/swahili-tts-2025"  # tokenizer repo
MODEL_CHECKPOINT = "FarmerlineML/Swahili-tts-2025_part4"  # acoustic-model repo

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---------- Load model ------------------------------------------------------
# nn.Module.to() and nn.Module.eval() both return the module, so they chain.
model = VitsModel.from_pretrained(MODEL_CHECKPOINT).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Clear-speech inference parameters, applied once at start-up (no UI toggle).
# speaking_rate must stay > 0 to avoid a ZeroDivisionError during inference.
model.speaking_rate = 0.75
model.noise_scale_duration = 0.667
model.noise_scale = 0.7
# ---------- Helper ----------------------------------------------------------
def _wav_to_mp3(wave_np: np.ndarray, sr: int) -> str:
    """Encode a waveform as a temporary MP3 file and return the file's path.

    Args:
        wave_np: Waveform samples. ``int16`` arrays are written unchanged; any
            other dtype is assumed to be float in ``[-1, 1]`` (as produced by
            VITS). It is clipped to that range and then scaled to ``int16``.
        sr: Sample rate in Hz.

    Returns:
        Path of the MP3 file. It is created in the temp directory and is NOT
        deleted here, so the caller (Gradio) can serve it.

    Raises:
        Whatever scipy/pydub/ffmpeg raise if writing or encoding fails. The
        intermediate WAV file is removed in that case too.
    """
    if wave_np.dtype != np.int16:
        # Clip before casting. Model output can overshoot [-1, 1] slightly, and
        # out-of-range values cast to int16 wrap around, producing loud clicks.
        wave_np = (np.clip(wave_np, -1.0, 1.0) * 32767).astype(np.int16)

    # Reserve a unique path, then close the handle before writing by name.
    # Re-opening a still-open NamedTemporaryFile by name fails on Windows.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tf:
        wav_path = tf.name

    # splitext swaps only the final extension. str.replace would also rewrite
    # a ".wav" that happens to appear elsewhere in the temp path.
    mp3_path = os.path.splitext(wav_path)[0] + ".mp3"
    try:
        wavfile.write(wav_path, sr, wave_np)
        # Convert to mp3 via pydub (requires ffmpeg - available in Spaces base img).
        # export() hands back the opened output file; close it explicitly.
        exported = AudioSegment.from_wav(wav_path).export(
            mp3_path, format="mp3", bitrate="64k"
        )
        exported.close()
    finally:
        os.remove(wav_path)  # never leave the intermediate WAV behind
    return mp3_path
# ---------- TTS endpoint ----------------------------------------------------
def tts_generate(text: str):
    """Synthesize Swahili speech for *text* and return the path of an MP3 file.

    Args:
        text: Swahili input text from the Gradio textbox. Leading and trailing
            whitespace is stripped before tokenization.

    Returns:
        Path to the generated MP3 file. Returns None, so the audio output stays
        empty, when there is nothing to speak:
        - the input is empty or whitespace-only, or
        - the tokenizer yields zero tokens for it.
    """
    if not text or not text.strip():
        return None
    inputs = tokenizer(text.strip(), return_tensors="pt").to(device)
    # Guard against an empty token sequence, which the model cannot synthesize.
    # NOTE(review): presumably this happens when every character is outside
    # the tokenizer's vocabulary. Confirm with the tokenizer in use.
    if inputs["input_ids"].shape[-1] == 0:
        return None
    with torch.no_grad():
        waveform = model(**inputs).waveform[0].cpu().numpy()
    return _wav_to_mp3(waveform, model.config.sampling_rate)
# ---------- UI --------------------------------------------------------------
# Sample sentences offered below the input box. Gradio expects one list of
# input values per example row, so each sentence is wrapped in its own list.
_EXAMPLE_SENTENCES = (
    "zao kusaidia kuondoa umaskini na kujenga kampeni za mwamko wa virusi vya ukimwi amezitembelea",
    "Kidole hiki ni tofauti na vidole vingine kwa sababu mwelekeo wake ni wa pekee.",
    "hivyo imekuwa msingi wa teknolojia yote ya umeme hasa nyaya za kila aina",
    "kumekuwa na majadiliano mengi juu ya usahihi wa ripoti hizi za madeni",
    "na kusaga ulipoanzia baada ya kumaliza masomo ndugu ruge mutahaba ndipo sasa mwishoni mwa",
    "Soko la Kariakoo huwa na watu wengi siku za Jumamosi.",
    "Tafadhali hakikisha umefunga mlango kabla ya kuondoka.",
    "Watoto walicheza mpira uwanjani hadi jua lilipotua.",
)
examples = [[sentence] for sentence in _EXAMPLE_SENTENCES]

# Input: a multi-line textbox. Output: the MP3 path returned by tts_generate,
# played automatically once it is ready.
_text_input = gr.Textbox(
    lines=3,
    placeholder="Enter Swahili text here",
    label="Enter Swahili text here",
)
_audio_output = gr.Audio(type="filepath", label="Audio", autoplay=True)

# cache_examples=True pre-computes the audio for every example above.
demo = gr.Interface(
    fn=tts_generate,
    inputs=_text_input,
    outputs=_audio_output,
    title="Swahili Text‑to‑Speech",
    description="Enter Swahili text and click **Submit** to play the audio",
    examples=examples,
    cache_examples=True,
)
if __name__ == "__main__":
    # Start the Gradio app only when this file is run directly, not on import.
    demo.launch()