import io
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import torch
from gtts import gTTS
from pydub import AudioSegment
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
)

from .config import get_user_dir

# Conversational partner model; a small instruct model keeps latency low.
QWEN_MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# Per-CEFR-level style instructions appended to the system prompt.
CONTROL_PROMPTS = {
    "A1": "Use extremely short, simple sentences and very basic vocabulary.",
    "A2": "Use simple sentences and common everyday vocabulary.",
    "B1": "Use moderately complex sentences and conversational vocabulary.",
    "B2": "Use natural, fluent sentences with richer vocabulary.",
    "C1": "Use complex, advanced sentences with nuanced expressions.",
    "C2": "Use highly sophisticated, near-native language and style.",
}

# Language names (as used by callers) mapped to gTTS language codes.
GTTS_LANG = {
    "english": "en",
    "spanish": "es",
    "german": "de",
    "russian": "ru",
    "japanese": "ja",
    "chinese": "zh-CN",  # gTTS expects "zh-CN"; the lowercase form is deprecated
    "korean": "ko",
    "french": "fr",
    "italian": "it",
}

# Lazily initialized singletons so the heavy models are loaded at most once.
_QWEN_TOKENIZER = None
_QWEN_MODEL = None
_WHISPER_PIPE = None


def load_partner_lm():
    """Load the Qwen conversational model and tokenizer once, then cache them."""
    global _QWEN_TOKENIZER, _QWEN_MODEL
    if _QWEN_MODEL is not None:
        return _QWEN_TOKENIZER, _QWEN_MODEL

    print("[conversation_core] loading:", QWEN_MODEL_NAME)

    tok = AutoTokenizer.from_pretrained(QWEN_MODEL_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        QWEN_MODEL_NAME,
        # fp16 on GPU halves memory; fall back to fp32 on CPU.
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
        trust_remote_code=True,
    )

    _QWEN_TOKENIZER = tok
    _QWEN_MODEL = model
    return tok, model


def load_whisper_pipe():
    """Load the Whisper ASR pipeline once, then cache it."""
    global _WHISPER_PIPE
    if _WHISPER_PIPE is not None:
        return _WHISPER_PIPE

    print("[conversation_core] loading Whisper pipeline…")
    _WHISPER_PIPE = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-small",
        device="cpu",
    )
    return _WHISPER_PIPE


@dataclass
class ConversationTurn:
    role: str  # "user" or "assistant"
    text: str


def clean_assistant_reply(text: str) -> str:
    """Strip role labels, list formatting, and AI self-references from a reply."""
    if not text:
        return ""

    # Drop leaked role labels such as "Assistant:" / "User:".
    text = re.sub(r"(?i)\bassistant\s*:\s*", "", text)
    text = re.sub(r"(?i)\buser\s*:\s*", "", text)

    # Remove bulleted or numbered list lines; replies should be plain prose.
    text = re.sub(r"(?m)^\s*[-•*]\s+.*$", "", text)
    text = re.sub(r"(?m)^\s*\d+\.\s+.*$", "", text)

    # Remove statements where the model talks about itself.
    identity_patterns = [
        r"(?i)i am an ai.*",
        r"(?i)i am a large language model.*",
        r"(?i)i was created.*",
        r"(?i)my name is .*",
    ]
    for p in identity_patterns:
        text = re.sub(p, "", text)

    # Collapse runs of whitespace left behind by the removals.
    text = re.sub(r"\s{2,}", " ", text)
    return text.strip()


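# Illustrative behaviour of clean_assistant_reply (example input invented
# for documentation; not executed):
#   clean_assistant_reply("Assistant: Hallo!  Wie geht es dir?")
#   -> "Hallo! Wie geht es dir?"

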
class ConversationManager:
    def __init__(
        self,
        target_language="german",
        native_language="english",
        cefr_level="B1",
        topic="general conversation",
    ):
        self.target_language = target_language.lower()
        self.native_language = native_language.lower()
        self.cefr_level = cefr_level.upper()
        self.topic = topic
        self.history: List[ConversationTurn] = []

        # Warm the model caches up front so the first reply is not slowed
        # down by downloads and initialization.
        load_partner_lm()
        load_whisper_pipe()

    def _build_system_prompt(self):
        base = (
            f"You are a friendly conversation partner speaking {self.target_language}. "
            f"Reply ONLY in {self.target_language}. "
            f"Adapt your language to CEFR level {self.cefr_level}. "
            f"{CONTROL_PROMPTS.get(self.cefr_level, '')} "
            f"Topic of conversation: {self.topic}. "
            "Give 1–3 short natural sentences and ALWAYS end with 1 follow-up question. "
            "Never mention AI, assistants, grammar explanations, or meta commentary."
        )
        return base

    def _generate_lm(self, user_text: str) -> str:
        tok, model = load_partner_lm()

        messages = [
            {"role": "system", "content": self._build_system_prompt()},
            {"role": "user", "content": user_text},
        ]

        prompt = tok.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        enc = tok(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            out = model.generate(
                **enc,
                max_new_tokens=160,
                temperature=0.8,
                top_p=0.95,
                repetition_penalty=1.15,
                do_sample=True,
                pad_token_id=tok.eos_token_id,
            )

        # Decode only the newly generated tokens; decoding the full sequence
        # would echo the system prompt and user message back into the reply.
        new_tokens = out[0][enc["input_ids"].shape[-1]:]
        raw = tok.decode(new_tokens, skip_special_tokens=True)

        return clean_assistant_reply(raw)

    def reply(self, user_text: str, input_lang="german"):
        # input_lang is accepted for interface compatibility but unused here;
        # generation is driven entirely by the system prompt's target language.
        self.history.append(ConversationTurn("user", user_text))

        assistant_text = self._generate_lm(user_text)
        self.history.append(ConversationTurn("assistant", assistant_text))

        explanation = self._generate_explanation(assistant_text)
        audio_bytes = self.text_to_speech(assistant_text)

        return {
            "reply_text": assistant_text,
            "explanation": explanation,
            "audio": audio_bytes,
        }

    def _generate_explanation(self, assistant_text: str) -> str:
        tok, model = load_partner_lm()

        prompt = (
            f"Rewrite the meaning of this {self.target_language} sentence "
            f"in ONE short {self.native_language} sentence:\n{assistant_text}"
        )

        # Wrap the instruction in the chat template; Qwen is an instruct model
        # and responds unreliably to bare completion prompts.
        text = tok.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        enc = tok(text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            out = model.generate(
                **enc,
                max_new_tokens=40,
                temperature=0.6,
                top_p=0.9,
                do_sample=True,  # required for temperature/top_p to take effect
                pad_token_id=tok.eos_token_id,
            )

        # Decode only the generated tokens instead of string-replacing the
        # prompt, which breaks whenever decoding alters whitespace.
        new_tokens = out[0][enc["input_ids"].shape[-1]:]
        cleaned = tok.decode(new_tokens, skip_special_tokens=True).strip()

        # Keep only the first sentence.
        parts = re.split(r"(?<=[.!?])\s+", cleaned)
        return parts[0].strip()

    def transcribe(self, audio_segment: AudioSegment, spoken_lang=None) -> Tuple[str, str, float]:
        """Transcribe speech using the Transformers Whisper pipeline."""
        pipe = load_whisper_pipe()

        # Whisper expects 16 kHz mono float32 in [-1, 1]; resample and rescale
        # from pydub's integer samples accordingly.
        seg = audio_segment.set_frame_rate(16000).set_channels(1)
        audio = np.array(seg.get_array_of_samples()).astype(np.float32)
        audio /= float(1 << (8 * seg.sample_width - 1))

        result = pipe({"raw": audio, "sampling_rate": 16000})
        text = result.get("text", "").strip()

        # The pipeline does not expose a confidence score, so report 1.0.
        return text, spoken_lang or "unknown", 1.0

    def text_to_speech(self, text: str) -> Optional[bytes]:
        if not text:
            return None
        try:
            lang = GTTS_LANG.get(self.target_language, "en")
            tts = gTTS(text=text, lang=lang)
            # Write the MP3 data to an in-memory buffer instead of a temp file.
            buf = io.BytesIO()
            tts.write_to_fp(buf)
            return buf.getvalue()
        except Exception:
            # gTTS needs network access; treat any failure as "no audio".
            return None
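

# --- Usage sketch (hypothetical; not part of the module API) ---
# A minimal manual smoke test for ConversationManager, assuming the model
# weights are downloadable and the package's `.config` module resolves.
# Because of the relative import, run it as a module, e.g.
# `python -m <package>.conversation_core` (package name is a placeholder).
if __name__ == "__main__":
    manager = ConversationManager(
        target_language="german",
        native_language="english",
        cefr_level="A2",
        topic="ordering food at a restaurant",
    )
    result = manager.reply("Hallo! Ich möchte eine Pizza bestellen.")
    print("reply:", result["reply_text"])
    print("explanation:", result["explanation"])
    print("audio bytes:", len(result["audio"] or b""))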