import io
import json
import logging
import os
import tempfile
import threading
import time
import uuid

from flask import Flask, jsonify, request, send_file
from gtts import gTTS
from pydub import AudioSegment
from transformers import pipeline

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Directory that holds the generated response WAV files; created at import
# time so request handlers can write into it immediately.
TEMP_AUDIO_DIR = "/tmp/audio"
os.makedirs(TEMP_AUDIO_DIR, exist_ok=True)

# Hugging Face model identifiers for speech-to-text and text generation.
STT_MODEL = "openai/whisper-tiny"
LLM_MODEL = "google/flan-t5-base"

# NOTE(review): MAX_AUDIO_SECONDS is not referenced anywhere in the visible
# code -- confirm whether an input-length check is still intended.
MAX_AUDIO_SECONDS = 10
# Maximum number of characters of text handed to the TTS engine.
MAX_TEXT_LEN = 200

# Seconds the cleanup loop sleeps between sweeps of TEMP_AUDIO_DIR.
CLEANUP_INTERVAL = 300
# Age in seconds (by modification time) after which a file is deleted.
FILE_EXPIRE_TIME = 600

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------

# Configure the root logger once: INFO and above, one pipe-separated line
# per record (timestamp | level | message).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)

# Module-level logger used by everything below.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Flask application
# ---------------------------------------------------------------------------

app = Flask(__name__)
# Expose the audio directory through the app config as well as the
# module-level constant.
app.config["TEMP_AUDIO_DIR"] = TEMP_AUDIO_DIR

# ---------------------------------------------------------------------------
# Model loading (runs once at import time; both models are placed on the CPU)
# ---------------------------------------------------------------------------

logger.info("Loading STT model...")
stt_pipeline = pipeline(
    "automatic-speech-recognition",
    model=STT_MODEL,
    device="cpu",
)

logger.info("Loading LLM model...")
llm_pipeline = pipeline(
    "text2text-generation",
    model=LLM_MODEL,
    device="cpu",
)

logger.info("Models loaded successfully")

def generate_tts_audio(text: str) -> bytes:
    """Synthesize *text* to speech and return it as 16 kHz mono WAV bytes.

    Newlines are replaced by spaces and surrounding whitespace is stripped;
    empty text falls back to "I understand.". The result is truncated to
    MAX_TEXT_LEN characters, rendered to MP3 with gTTS (English) and then
    converted to WAV with pydub.

    Args:
        text: The text to speak.

    Returns:
        The contents of the WAV file, or ``b""`` if any step fails. Failures
        are logged with a traceback instead of being raised.
    """
    try:
        text = text.replace("\n", " ").strip() or "I understand."
        text = text[:MAX_TEXT_LEN]
        logger.info("TTS: %s", text)

        # Work inside a private scratch directory: it is deleted when the
        # ``with`` block exits, so the intermediate MP3/WAV files are cleaned
        # up even when synthesis or conversion raises.
        with tempfile.TemporaryDirectory() as workdir:
            mp3_path = os.path.join(workdir, "speech.mp3")
            wav_path = os.path.join(workdir, "speech.wav")

            gTTS(text=text, lang="en").save(mp3_path)

            audio = AudioSegment.from_file(mp3_path)
            audio = audio.set_frame_rate(16000).set_channels(1)
            # export() hands back the open output file; close it so the data
            # is flushed and the directory can be removed on every platform.
            audio.export(wav_path, format="wav").close()

            with open(wav_path, "rb") as wav_file:
                return wav_file.read()

    except Exception as e:
        # Best-effort contract: callers treat an empty result as "TTS failed".
        logger.error("TTS error: %s", e, exc_info=True)
        return b""

def cleanup_temp_files():
    """Delete expired files from TEMP_AUDIO_DIR forever; never returns.

    Every CLEANUP_INTERVAL seconds the directory is listed and each regular
    file whose modification time is more than FILE_EXPIRE_TIME seconds in
    the past is removed. Errors are logged as warnings and never propagate,
    so the loop keeps running. Intended to be run in a background thread.
    """
    while True:
        try:
            now = time.time()
            for filename in os.listdir(TEMP_AUDIO_DIR):
                path = os.path.join(TEMP_AUDIO_DIR, filename)
                # Handle failures per file so that one entry that cannot be
                # inspected or removed does not abort the rest of the sweep.
                try:
                    if not os.path.isfile(path):
                        continue
                    if now - os.path.getmtime(path) > FILE_EXPIRE_TIME:
                        os.remove(path)
                except OSError as e:
                    logger.warning("Cleanup error: %s", e)
        except Exception as e:
            # e.g. the directory itself could not be listed.
            logger.warning("Cleanup error: %s", e)

        time.sleep(CLEANUP_INTERVAL)

@app.route("/health", methods=["GET"])
def health():
    """Return a JSON status document naming the configured STT and LLM models."""
    return jsonify({
        "status": "ok",
        "stt": STT_MODEL,
        "llm": LLM_MODEL,
    })

@app.route("/process_audio", methods=["POST"])
def process_audio():
    """Transcribe an uploaded clip, generate a reply and return it as speech.

    Expects the upload in the ``audio`` field of ``request.files``. The audio
    is transcribed with the STT pipeline, the transcript (or "Hello" when it
    is empty) is passed to the LLM pipeline, and the generated text is
    rendered by ``generate_tts_audio``. The WAV is written to TEMP_AUDIO_DIR
    under a random UUID name and sent back as ``audio/wav``.

    Returns:
        The WAV response on success, otherwise a JSON ``{"error": ...}`` body
        with status 400 (field missing, or fewer than 1000 bytes uploaded)
        or 500 (TTS produced nothing, or any unexpected exception).
    """
    try:
        if "audio" not in request.files:
            return jsonify({"error": "No audio file"}), 400

        raw_audio = request.files["audio"].read()
        if len(raw_audio) < 1000:
            return jsonify({"error": "Audio too short"}), 400

        # --- Speech to text ---
        logger.info("Running STT...")
        # Raw bytes are decoded and resampled by the pipeline itself.
        # ``sampling_rate`` is not a call-time keyword of the ASR pipeline
        # (it only accompanies a raw array in dict input), so it is not
        # passed here.
        stt_result = stt_pipeline(raw_audio)

        user_text = stt_result.get("text", "").strip()
        logger.info("User said: %s", user_text)
        if not user_text:
            # Nothing recognised: fall back to a neutral prompt.
            user_text = "Hello"

        # --- Text generation (deterministic: sampling disabled) ---
        logger.info("Running LLM...")
        llm_result = llm_pipeline(
            user_text,
            max_new_tokens=64,
            do_sample=False,
        )
        answer = llm_result[0]["generated_text"]
        logger.info("Answer: %s", answer)

        # --- Text to speech ---
        audio_response = generate_tts_audio(answer)
        if not audio_response:
            return jsonify({"error": "TTS failed"}), 500

        # Persist under a unique name so concurrent requests never collide.
        file_id = str(uuid.uuid4())
        filepath = os.path.join(TEMP_AUDIO_DIR, f"{file_id}.wav")
        with open(filepath, "wb") as f:
            f.write(audio_response)

        return send_file(
            filepath,
            mimetype="audio/wav",
            as_attachment=False,
            download_name="response.wav",
        )

    except Exception as e:
        # Top-level boundary: log the traceback, hide details from the client.
        logger.error("Processing error: %s", e, exc_info=True)
        return jsonify({"error": "Internal error"}), 500

if __name__ == "__main__":
    # Run the temp-file sweeper as a daemon thread so it never blocks
    # interpreter shutdown.
    threading.Thread(target=cleanup_temp_files, daemon=True).start()

    # Listen on all interfaces, port 7860, serving requests in threads.
    app.run(
        host="0.0.0.0",
        port=7860,
        threaded=True,
    )