daihui.zhang committed · Commit abe0fe2 · 1 Parent(s): d15b373

add llm

Browse files:
- run_client.py +1 -0
- transcribe/client.py +6 -4
- transcribe/transcription.py +12 -0
- transcribe/translator.py +2 -3
- transcribe/whisper_llm_serve.py +22 -13
run_client.py CHANGED

@@ -4,6 +4,7 @@ client = TranscriptionClient(
     "localhost",
     9090,
     lang="zh",
+    dst_lang="en",
     save_output_recording=False,  # Only used for microphone input, False by Default
     output_recording_filename="./output_recording.wav",  # Only used for microphone input
     max_clients=4,
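For reference, a minimal sketch of the resulting client call (it assumes `TranscriptionClient` is imported from `transcribe.client` as in `run_client.py`, and that calling the instance starts streaming; values are illustrative):

from transcribe.client import TranscriptionClient

# Stream Chinese speech and ask the server to translate it to English.
client = TranscriptionClient(
    "localhost",
    9090,
    lang="zh",        # source language for Whisper
    dst_lang="en",    # target language for the LLM translator (new in this commit)
    save_output_recording=False,
    output_recording_filename="./output_recording.wav",
    max_clients=4,
)
client()  # connects to ws://localhost:9090?from=zh&to=en and begins streaming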
transcribe/client.py CHANGED

@@ -29,6 +29,7 @@ class Client:
         log_transcription=True,
         max_clients=4,
         max_connection_time=600,
+        dst_lang='zh',
     ):
         """
         Initializes a Client instance for audio recording and streaming to a server.

@@ -56,12 +57,12 @@ class Client:
         self.log_transcription = log_transcription
         self.max_clients = max_clients
         self.max_connection_time = max_connection_time
-
+        self.dst_lang = dst_lang

         self.audio_bytes = None

         if host is not None and port is not None:
-            socket_url = f"ws://{host}:{port}"
+            socket_url = f"ws://{host}:{port}?from={self.language}&to={self.dst_lang}"
             self.client_socket = websocket.WebSocketApp(
                 socket_url,
                 on_open=lambda ws: self.on_open(ws),

@@ -657,10 +658,11 @@ class TranscriptionClient(TranscriptionTeeClient):
         max_clients=4,
         max_connection_time=600,
         mute_audio_playback=False,
+        dst_lang='en',
     ):
         self.client = Client(
             host, port, lang, log_transcription=log_transcription, max_clients=max_clients,
-            max_connection_time=max_connection_time
+            max_connection_time=max_connection_time, dst_lang=dst_lang
         )

         if save_output_recording and not output_recording_filename.endswith(".wav"):

@@ -671,5 +673,5 @@ class TranscriptionClient(TranscriptionTeeClient):
             [self.client],
             save_output_recording=save_output_recording,
             output_recording_filename=output_recording_filename,
-            mute_audio_playback=mute_audio_playback
+            mute_audio_playback=mute_audio_playback,
         )
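The language pair now travels in the WebSocket query string; a quick sketch of how the URL is assembled (assuming `self.language` holds the `lang` argument, as the f-string above implies):

host, port = "localhost", 9090
language, dst_lang = "zh", "en"   # lang / dst_lang as passed to Client

socket_url = f"ws://{host}:{port}?from={language}&to={dst_lang}"
assert socket_url == "ws://localhost:9090?from=zh&to=en"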
transcribe/transcription.py CHANGED

@@ -10,6 +10,7 @@ import numpy as np
 from .server import ServeClientBase
 from .whisper_llm_serve import PyWhiperCppServe
 from .vad import VoiceActivityDetector
+from urllib.parse import urlparse, parse_qsl
 from websockets.exceptions import ConnectionClosed
 from websockets.sync.server import serve

@@ -226,6 +227,11 @@ class TranscriptionServer:

         client.add_frames(frame_np)
         return True
+
+    def set_lang(self, websocket, src_lang, dst_lang):
+        client = self.client_manager.get_client(websocket)
+        if isinstance(client, PyWhiperCppServe):
+            client.set_lang(src_lang, dst_lang)

     def recv_audio(self,
                    websocket,

@@ -234,6 +240,12 @@ class TranscriptionServer:
         self.backend = backend
         if not self.handle_new_connection(websocket):
             return
+        query_parameters_dict = dict(parse_qsl(urlparse(websocket.request.path).query))
+        from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
+
+        if from_lang and to_lang:
+            self.set_lang(websocket, from_lang, to_lang)
+            logging.info(f"Source lang: {from_lang} -> Dst lang: {to_lang}")

         try:
             while not self.client_manager.is_client_timeout(websocket):
transcribe/translator.py CHANGED

@@ -28,13 +28,12 @@ class QwenTranslator:
         message = self.to_message(prompt, src_lang, dst_lang)
         start_time = time.monotonic()
         output = self.llm.create_chat_completion(messages=message, temperature=0.9)
-        logger.info(f"LLM
+        logger.info(f"LLM inference time: {time.monotonic() - start_time:.2f}s.")
         return output['choices'][0]['message']['content']

-    def __call__(self, prompt
+    def __call__(self, prompt, *args, **kwargs):
         return self.llm(
             prompt,
             *args,
-            max_tokens=max_tokens,
             **kwargs
         )

(The two removed lines above are truncated in the diff view; their full original text is not recoverable from this page.)
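`create_chat_completion` and the `output['choices'][0]['message']['content']` access match the llama-cpp-python chat API; a hedged sketch of the surrounding flow (the body of `to_message` is hypothetical, since this diff does not show it):

import time

def to_message(prompt, src_lang, dst_lang):
    # Hypothetical shape; the real to_message is defined elsewhere in QwenTranslator.
    return [
        {"role": "system", "content": f"Translate the user text from {src_lang} to {dst_lang}."},
        {"role": "user", "content": prompt},
    ]

message = to_message("你好", "zh", "en")
start_time = time.monotonic()
# output = llm.create_chat_completion(messages=message, temperature=0.9)
# text = output['choices'][0]['message']['content']
print(f"LLM inference time: {time.monotonic() - start_time:.2f}s.")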
transcribe/whisper_llm_serve.py CHANGED

@@ -2,6 +2,7 @@

 import soundfile
 from concurrent.futures import ProcessPoolExecutor as Pool
+import multiprocessing as mp
 import numpy as np
 from logging import getLogger
 from difflib import SequenceMatcher

@@ -10,7 +11,7 @@ import config
 import time
 import json
 import threading
-
+from functools import partial
 from .server import ServeClientBase
 from .translator import QwenTranslator
 from pywhispercpp.model import Model

@@ -109,7 +110,7 @@ class PywhisperInference:
     llm_model = None

     @classmethod
-    def initializer(cls, warmup=True):
+    def initializer(cls, event: mp.Event, warmup=True):
         models_dir = config.MODEL_DIR.as_posix()
         cls.whisper_model = Model(
             model=config.WHISPER_MODEL,

@@ -123,6 +124,7 @@ class PywhisperInference:

         # init llamacpp
         cls.llm_model = QwenTranslator(config.LLM_MODEL_PATH, config.LLM_SYS_PROMPT)
+        event.set()

     @classmethod
     def warmup(cls, warmup_steps=1):

@@ -170,20 +172,30 @@ class PyWhiperCppServe(ServeClientBase):
         self.frames_np = None
         self.sample_rate = 16000

+        self._ready_state = mp.Event()
         self._pool = Pool(
-            max_workers=1, initializer=PywhisperInference.initializer)
+            max_workers=1, initializer=partial(PywhisperInference.initializer, event=self._ready_state))

         logger.info('Create a process to process audio.')
         self.trans_thread = threading.Thread(target=self.speech_to_text)
         self.trans_thread.daemon = True
         self.trans_thread.start()
+        self.send_ready_state()

+    def send_ready_state(self):
+        # Block until the worker process has finished loading both models.
+        while not self._ready_state.is_set():
+            time.sleep(0.1)
+
         self.websocket.send(json.dumps({
             "uid": self.client_uid,
             "message": self.SERVER_READY,
             "backend": "pywhispercpp"
         }))

+    def set_lang(self, src_lang, dst_lang):
+        self.language = src_lang
+        self.dst_lang = dst_lang
+
     def add_frames(self, frame_np):
         with self.lock:
             if self.frames_np is None:

@@ -206,7 +218,6 @@ class PyWhiperCppServe(ServeClientBase):
         transcribe_fut = self._pool.submit(
             PywhisperInference.inference, audio_buffer.tobytes(), self.language)
         segments = transcribe_fut.result()
-
         return segments

     def translate_text(self, text):

@@ -215,8 +226,6 @@ class PyWhiperCppServe(ServeClientBase):
         translate_fut = self._pool.submit(
             PywhisperInference.translate, text, self.language, self.dst_lang)
         return translate_fut.result()
-
-

     def _segments_split(self, segments, audio_buffer: np.ndarray):
         """Split the sequence at the first punctuation mark from the left into an observed segment and the remainder."""

@@ -278,15 +287,14 @@ class PyWhiperCppServe(ServeClientBase):
             # logger.info(f"[pywhispercpp:] Processing audio with duration: {len(audio_buffer)}")
             # segments = self.transcribe_audio(audio_buffer)
             try:
-                logger.info(f"
+                logger.info(f"Processing audio with duration: {len(audio_buffer) / self.sample_rate:.2f}s")
                 segments = self.transcribe_audio(audio_buffer)
+                for item in self.handle_transcription_output(segments, audio_buffer):
+                    print(item)
             except KeyboardInterrupt:
                 break
             except Exception as e:
-                logger.error(f"
-            else:
-                for item in self.handle_transcription_output(segments, audio_buffer):
-                    print(item)
+                logger.error(f"{e}")

@@ -306,7 +314,8 @@ class PyWhiperCppServe(ServeClientBase):
                 message = self._segment_manager.segment
                 seg_id = self._segment_manager.get_seg_id() - 1
                 yield (seg_id, message, self.translate_text(message))
-
+                if self._segment_manager.string.strip():
+                    yield (seg_id + 1, self._segment_manager.string, self.translate_text(self._segment_manager.string))

             else:
                 seg_id = self._segment_manager.get_seg_id()

@@ -326,7 +335,7 @@ class PyWhiperCppServe(ServeClientBase):
                 json.dumps(content)
             )
         except Exception as e:
-            logger.error(f"
+            logger.error(f"Sending data to client: {e}")

     def get_audio_chunk_for_processing(self):
         if self.frames_np.shape[0] >= self.sample_rate * 1:
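The ready-state handshake added above reduces to this standalone pattern: the parent creates an `mp.Event`, hands it to the single worker via `functools.partial`, and polls `is_set()` before announcing `SERVER_READY`. A minimal sketch (model loading is faked with a sleep, and the no-op submit forces the lazily started worker to spawn):

import multiprocessing as mp
import time
from concurrent.futures import ProcessPoolExecutor
from functools import partial

def initializer(event):
    time.sleep(1.0)   # stand-in for loading the Whisper and LLM models
    event.set()       # tell the parent the worker is ready

def _noop():
    pass

if __name__ == "__main__":
    ready = mp.Event()
    pool = ProcessPoolExecutor(max_workers=1,
                               initializer=partial(initializer, event=ready))
    pool.submit(_noop)            # worker processes start lazily, on first submit
    while not ready.is_set():     # poll, as send_ready_state does
        time.sleep(0.1)
    print("SERVER_READY")         # safe to notify the client now
    pool.shutdown()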