# chatbot / app.py
# vuvanhung's picture
# Create app.py
# 67ca15b verified
import os
import io
import uuid
import time
import json
import logging
import tempfile
import threading
from flask import Flask, request, jsonify, send_file
from transformers import pipeline
from gtts import gTTS
from pydub import AudioSegment
# ================= CONFIG =================
# --- Hugging Face model identifiers ---
STT_MODEL = "openai/whisper-tiny"
LLM_MODEL = "google/flan-t5-base"

# --- Request limits ---
# NOTE(review): MAX_AUDIO_SECONDS is not referenced by the code visible in this
# file -- confirm whether a duration check was intended.
MAX_AUDIO_SECONDS = 10
MAX_TEXT_LEN = 200  # characters of answer text handed to TTS

# --- Temp-file housekeeping (all durations in seconds) ---
CLEANUP_INTERVAL = 5 * 60   # pause between cleanup sweeps
FILE_EXPIRE_TIME = 10 * 60  # age after which a response file is removed

# --- Response storage ---
# Directory that holds generated WAV answers; created at import time.
TEMP_AUDIO_DIR = "/tmp/audio"
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)
# ================= LOG =================
# Root logger: INFO and above, formatted as "timestamp | level | message".
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# ================= APP =================
# The Flask application; the temp-audio directory is also exposed through its
# config mapping under the same key name.
app = Flask(__name__)
app.config.update(TEMP_AUDIO_DIR=TEMP_AUDIO_DIR)
# ================= LOAD MODELS =================
# Both pipelines are built once at import time, on CPU, and reused by every
# request.
logger.info("Loading STT model...")
stt_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=STT_MODEL,
    device="cpu",
)

logger.info("Loading LLM model...")
llm_pipeline = pipeline(
    task="text2text-generation",
    model=LLM_MODEL,
    device="cpu",
)

logger.info("Models loaded successfully")
# ================= UTILS =================
def generate_tts_audio(text: str) -> bytes:
    """
    Generate WAV 16kHz mono audio from text.

    Newlines in the text are replaced by spaces and surrounding whitespace is
    stripped; an empty result falls back to a short default phrase. The text
    is then cut to MAX_TEXT_LEN characters, synthesized to MP3 with gTTS and
    converted to WAV with pydub.

    Args:
        text: Text to speak.

    Returns:
        The bytes of the WAV file, or b"" if synthesis or conversion fails.
        Failures are logged and never raised to the caller.
    """
    try:
        text = text.replace("\n", " ").strip()
        if not text:
            text = "I understand."
        text = text[:MAX_TEXT_LEN]
        logger.info(f"TTS: {text}")

        # Both intermediate files live in a private scratch directory that is
        # deleted when the ``with`` block exits, so nothing is left behind
        # when gTTS or the conversion raises.
        with tempfile.TemporaryDirectory() as workdir:
            mp3_path = os.path.join(workdir, "speech.mp3")
            wav_path = os.path.join(workdir, "speech.wav")

            tts = gTTS(text=text, lang="en")
            tts.save(mp3_path)

            audio = AudioSegment.from_file(mp3_path)
            audio = audio.set_frame_rate(16000).set_channels(1)
            audio.export(wav_path, format="wav")

            with open(wav_path, "rb") as f:
                return f.read()
    except Exception as e:
        # Best effort: callers treat an empty result as "TTS failed".
        logger.error(f"TTS error: {e}", exc_info=True)
        return b""
def cleanup_temp_files():
    """
    Periodically delete expired files from TEMP_AUDIO_DIR.

    Loops forever: each pass removes every regular file in TEMP_AUDIO_DIR
    whose modification time is more than FILE_EXPIRE_TIME seconds old, then
    sleeps for CLEANUP_INTERVAL seconds. The function never returns, so it is
    meant to be run in a background thread.

    Errors are logged as warnings and never propagate, so a failing pass
    cannot stop the loop.
    """
    while True:
        try:
            now = time.time()
            for filename in os.listdir(TEMP_AUDIO_DIR):
                path = os.path.join(TEMP_AUDIO_DIR, filename)
                # Handle failures per file so that one entry that vanished or
                # cannot be removed does not abort the rest of the sweep.
                try:
                    if not os.path.isfile(path):
                        continue
                    if now - os.path.getmtime(path) > FILE_EXPIRE_TIME:
                        os.remove(path)
                except FileNotFoundError:
                    # Disappeared between the listing and the check/removal;
                    # there is nothing left to clean up.
                    continue
                except OSError as e:
                    logger.warning(f"Cleanup error: {e}")
        except Exception as e:
            # Catch-all (e.g. the directory listing itself failed) so the
            # loop keeps running.
            logger.warning(f"Cleanup error: {e}")
        time.sleep(CLEANUP_INTERVAL)
# ================= ROUTES =================
@app.route("/health", methods=["GET"])
def health():
    """Report service status together with the configured model names."""
    payload = {
        "status": "ok",
        "stt": STT_MODEL,
        "llm": LLM_MODEL,
    }
    return jsonify(payload)
@app.route("/process_audio", methods=["POST"])
def process_audio():
    """
    Answer one spoken request: speech -> text -> LLM answer -> speech.

    Expects the recording as an uploaded file in the ``audio`` form field.

    Returns:
        On success, the synthesized answer as an ``audio/wav`` response.
        Otherwise a JSON body ``{"error": ...}`` with status 400 (no ``audio``
        file, or fewer than 1000 bytes uploaded) or 500 (TTS produced nothing,
        or any unexpected exception, which is logged).
    """
    try:
        if "audio" not in request.files:
            return jsonify({"error": "No audio file"}), 400

        audio_file = request.files["audio"]
        raw_audio = audio_file.read()
        if len(raw_audio) < 1000:
            return jsonify({"error": "Audio too short"}), 400

        # ================= STT =================
        logger.info("Running STT...")
        # The encoded bytes are handed over as-is and decoded by the pipeline.
        # Per the transformers ASR pipeline documentation the call takes no
        # ``sampling_rate`` keyword (a rate is only supplied together with a
        # raw array in dict input), so none is passed here.
        stt_result = stt_pipeline(raw_audio)
        user_text = stt_result.get("text", "").strip()
        logger.info(f"User said: {user_text}")
        if not user_text:
            # Nothing was recognized: fall back to a neutral prompt.
            user_text = "Hello"

        # ================= LLM =================
        logger.info("Running LLM...")
        llm_result = llm_pipeline(
            user_text,
            max_new_tokens=64,
            do_sample=False,
        )
        answer = llm_result[0]["generated_text"]
        logger.info(f"Answer: {answer}")

        # ================= TTS =================
        audio_response = generate_tts_audio(answer)
        if not audio_response:
            # generate_tts_audio signals failure with empty bytes.
            return jsonify({"error": "TTS failed"}), 500

        # Store the answer under a unique name in TEMP_AUDIO_DIR and send it
        # from there.
        file_id = str(uuid.uuid4())
        filepath = os.path.join(TEMP_AUDIO_DIR, f"{file_id}.wav")
        with open(filepath, "wb") as f:
            f.write(audio_response)

        return send_file(
            filepath,
            mimetype="audio/wav",
            as_attachment=False,
            download_name="response.wav",
        )
    except Exception as e:
        # Route boundary: log the details, return a generic error to the client.
        logger.error(f"Processing error: {e}", exc_info=True)
        return jsonify({"error": "Internal error"}), 500
# ================= STARTUP =================
if __name__ == "__main__":
    # Start the temp-file sweeper first; as a daemon thread it does not keep
    # the process alive once the server exits.
    cleaner = threading.Thread(target=cleanup_temp_files, daemon=True)
    cleaner.start()

    # Serve on all interfaces, port 7860, handling requests in threads.
    app.run(host="0.0.0.0", port=7860, threaded=True)