# transcribe/whisper_llm_serve.py
import json
import queue
import threading
import time
from logging import getLogger

import numpy as np

from api_model import TransResult, Message
from .server import ServeClientBase
from .strategy import TripleTextBuffer, SegmentManager, segments_split, sequences_split
from .translatepipes import TranslatePipes
from .utils import log_block

logger = getLogger("TranslatorApp")
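
# The transcription / translation pipelines are created once at import time and
# shared by every client connection; wait_ready() blocks until the underlying
# models have finished loading.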
translate_pipes = TranslatePipes()
translate_pipes.wait_ready()
logger.info("Pipeline is ready.")
class PyWhiperCppServe(ServeClientBase):
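    """Streaming transcription and translation client served over a websocket.

    Audio frames from the client are buffered, merged through VAD, transcribed
    with the pywhispercpp backend, split into confirmed and still-observed
    segments, and the confirmed text is translated into the target language
    before results are pushed back over the websocket.
    """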
def __init__(self, websocket, language=None, dst_lang=None, client_uid=None,):
super().__init__(client_uid, websocket)
self.language = language
        self.dst_lang = dst_lang  # target translation language
        # Observation buffer: compares the text of consecutive passes to decide whether the output has stabilized.
self._text_buffer = TripleTextBuffer()
        # Stores the transcribed segments.
self._segment_manager = SegmentManager()
self.lock = threading.Lock()
self.frames_np = None
self._frame_queue = queue.Queue()
self.sample_rate = 16000
self.send_ready_state()
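        # Two daemon threads: one runs the transcribe/translate loop, the other
        # drains incoming frames from the queue into the shared audio buffer.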
self.run_in_thread(self.speech_to_text)
self.run_in_thread(self.get_frame_from_queue)
        # Chinese output is joined without a separator; other languages are joined with spaces.
        self.text_sep = "" if self.language == "zh" else " "
def run_in_thread(self, func):
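        """Run ``func`` in a daemon thread so it exits with the main process."""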
t = threading.Thread(target=func)
t.daemon = True
t.start()
def send_ready_state(self):
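        """Tell the client that the pywhispercpp backend is ready to receive audio."""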
self.websocket.send(json.dumps({
"uid": self.client_uid,
"message": self.SERVER_READY,
"backend": "pywhispercpp"
}))
def set_lang(self, src_lang, dst_lang):
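        """Update the source and target languages for this client."""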
self.language = src_lang
self.dst_lang = dst_lang
def add_frames(self, frame_np):
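        """Queue a chunk of incoming float32 PCM frames from the client."""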
self._frame_queue.put(frame_np)
def vad_merge(self):
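        """Run VAD over the buffered audio and keep only the detected speech."""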
with self.lock:
frame = self.frames_np.copy()
item = translate_pipes.voice_detect(frame.tobytes())
if item.audio != b'':
frame_np = np.frombuffer(item.audio, dtype=np.float32)
self.frames_np = frame_np.copy()
    def get_frame_from_queue(self):
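        """Continuously drain queued frames into the shared audio buffer."""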
while True:
try:
frame_np = self._frame_queue.get(timeout=0.1)
with self.lock:
if self.frames_np is None:
self.frames_np = frame_np.copy()
else:
self.frames_np = np.append(self.frames_np,frame_np)
except queue.Empty:
pass
def update_audio_buffer(self, last_offset):
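        """Drop audio samples before ``last_offset`` once they have been confirmed."""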
with self.lock:
self.frames_np = self.frames_np[last_offset:]
def transcribe_audio(self, audio_buffer):
"""
Transcribe the audio chunk and send the results to the client.
Args:
audio_buffer (np.array): The audio chunk to transcribe.
"""
log_block("Audio buffer length", f"{audio_buffer.shape[0]/self.sample_rate:.2f}", "s")
start_time = time.perf_counter()
item = translate_pipes.transcrible(audio_buffer.tobytes(), self.language)
segments = item.segments
log_block("Whisper transcrible time", f"{(time.perf_counter() - start_time):.3f}", "s")
return segments
def translate_text(self, text):
"""
translate the text to dst lang"""
# return "sample english"
log_block("LLM translate input", f"{text}")
start_time = time.perf_counter()
ret = translate_pipes.translate(text, self.language, self.dst_lang)
translated_text = ret.translate_content
log_block("LLM translate time", f"{(time.perf_counter() - start_time):.3f}", "s")
log_block("LLM translate out", f"{translated_text}")
return translated_text
def analysis_segments(self, segments, audio_buffer: np.ndarray):
        # Use the first punctuation mark as an anchor: everything to its left is the
        # confirmed segment, everything to its right is the observation segment.
        # The right part only enters observation once the left part is confirmed.
        # Once confirmed, the matching audio is dropped from the buffer so the next
        # pass processes less data.
left_watch_idx, left_watch_sequences, right_watch_sequences, is_end_sentence = segments_split(segments, audio_buffer)
left_watch_string = self.text_sep.join(i.text for i in left_watch_sequences)
right_watch_string = self.text_sep.join(i.text for i in right_watch_sequences)
if left_watch_idx != 0:
            # Temporarily store the observed string until it is judged stable.
self._text_buffer.add_entry(left_watch_string, left_watch_idx)
audio_cut_index = self._text_buffer.get_final_index()
if audio_cut_index:
return audio_cut_index, left_watch_string, right_watch_string, is_end_sentence
        # Whole-sentence fallback: drop everything before the last two sentences.
left_watch_idx, left_watch_sequences, right_watch_sequences, is_end_sentence = sequences_split(segments, audio_buffer)
left_watch_string = self.text_sep.join(i.text for i in left_watch_sequences)
right_watch_string = self.text_sep.join(i.text for i in right_watch_sequences)
if left_watch_idx != 0:
return left_watch_idx, left_watch_string, right_watch_string, is_end_sentence
return None, left_watch_string, right_watch_string, is_end_sentence
def speech_to_text(self):
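        """Main loop: wait for buffered audio, transcribe it, and emit translation results."""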
while True:
if self.exit:
logger.info("Exiting speech to text thread")
break
if self.frames_np is None:
time.sleep(0.02) # wait for any audio to arrive
continue
audio_buffer = self.get_audio_chunk_for_processing()
            # Wait until at least two seconds of audio have accumulated.
            if audio_buffer.shape[0] < self.sample_rate * 2:
time.sleep(0.02)
continue
            try:
                logger.info(f"Audio buffer length: {len(audio_buffer) / self.sample_rate:.2f}s")
                segments = self.transcribe_audio(audio_buffer)
                for tran_result in self.handle_transcription_output(segments, audio_buffer):
                    self.send_to_client(tran_result)
            except KeyboardInterrupt:
                break
            except Exception as e:
                logger.error(f"Error in transcription loop: {e}")
def handle_transcription_output(self, segments, audio_buffer):
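        """Analyse the transcribed segments and yield ``TransResult`` objects.

        Completed sentences are committed and emitted as final results
        (``partial=False``); any remaining text is emitted as an in-progress
        result for the next segment.
        """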
texts = self.text_sep.join(i.text for i in segments)
if not len(texts):
return
self._segment_manager.handle(texts)
        # Split the transcription into confirmed / observed parts.
last_cut_index, left_string, right_string, is_end_sentence = self.analysis_segments(segments, audio_buffer)
# print(last_cut_index, left_string, right_string, is_end_sentence)
if last_cut_index:
self.update_audio_buffer(last_cut_index)
            # Commit the full sentence (or short clause) and start observing the remainder.
self._segment_manager.handle(left_string).commit(is_end_sentence)
self._segment_manager.handle(right_string)
if is_end_sentence and last_cut_index:
message = self._segment_manager.segment
seg_id = self._segment_manager.get_seg_id() - 1
# logger.info(f"{seg_id}, {message}")
yield TransResult(
seg_id=seg_id,
context=message,
from_=self.language,
to=self.dst_lang,
tran_content=self.translate_text(message),
partial=False
)
if self._segment_manager.string.strip():
message = self._segment_manager.string.strip()
# logger.info(f"{seg_id + 1}, {message}")
yield TransResult(
seg_id=seg_id+1,
context=self._segment_manager.string,
from_=self.language,
to=self.dst_lang,
tran_content=self.translate_text(message),
)
else:
seg_id = self._segment_manager.get_seg_id()
message = self._segment_manager.short_sentence + self._segment_manager.string
# logger.info(f"{seg_id}, {message}")
yield TransResult(
seg_id=seg_id,
context=message,
from_=self.language,
to=self.dst_lang,
tran_content=self.translate_text(message),
)
def send_to_client(self, data:TransResult):
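        """Wrap ``data`` in a ``Message`` and send it to the client over the websocket."""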
try:
self.websocket.send(
Message(result=data, request_id=self.client_uid).model_dump_json(by_alias=True)
)
except Exception as e:
logger.error(f"Sending data to client: {e}")
def get_audio_chunk_for_processing(self):
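        """Return a copy of the buffered audio; short buffers are VAD-merged and padded with leading silence to roughly one second."""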
        # At least one second is already buffered; use it as-is.
        if self.frames_np.shape[0] >= self.sample_rate * 1:
return self.frames_np.copy()
self.vad_merge()
        # Number of samples needed to pad the buffer up to one second.
        padding_length = self.sample_rate * 1 - len(self.frames_np)
        # Create silence padding (zeros), plus a 10 ms safety margin.
        silence = np.zeros(padding_length + int(0.01 * self.sample_rate), dtype=np.float32)
        # Prepend the silence padding to the original audio.
        padded_audio = np.concatenate([silence, self.frames_np])
return padded_audio.copy()
def cleanup(self):
        return super().cleanup()
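

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): a real deployment
    # passes a live websocket connection; here a stub that only logs outgoing
    # JSON stands in for it so the wiring can be exercised locally.
    class _StubWebsocket:
        def send(self, payload):
            logger.info("would send: %s", payload)

    client = PyWhiperCppServe(
        _StubWebsocket(), language="zh", dst_lang="en", client_uid="demo-client"
    )
    # Feed one second of silence as 16 kHz float32 PCM; the processing loop
    # waits for at least two seconds of audio before transcribing, so this
    # only exercises the buffering path.
    client.add_frames(np.zeros(16000, dtype=np.float32))
    time.sleep(2)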