daihui.zhang committed
Commit abe0fe2 · 1 Parent(s): d15b373
run_client.py CHANGED
@@ -4,6 +4,7 @@ client = TranscriptionClient(
     "localhost",
     9090,
     lang="zh",
+    dst_lang="en",
     save_output_recording=False,  # Only used for microphone input, False by Default
     output_recording_filename="./output_recording.wav",  # Only used for microphone input
     max_clients=4,
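
With these values, the companion change in transcribe/client.py below makes the connection URL carry the language pair as query parameters. A quick sketch of what gets sent, assuming the f-string shown in that diff:

    host, port = "localhost", 9090
    language, dst_lang = "zh", "en"
    print(f"ws://{host}:{port}?from={language}&to={dst_lang}")
    # ws://localhost:9090?from=zh&to=en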
transcribe/client.py CHANGED
@@ -29,6 +29,7 @@ class Client:
         log_transcription=True,
         max_clients=4,
         max_connection_time=600,
+        dst_lang='zh',
     ):
         """
         Initializes a Client instance for audio recording and streaming to a server.
@@ -56,12 +57,12 @@ class Client:
         self.log_transcription = log_transcription
         self.max_clients = max_clients
         self.max_connection_time = max_connection_time
-
+        self.dst_lang = dst_lang

         self.audio_bytes = None

         if host is not None and port is not None:
-            socket_url = f"ws://{host}:{port}"
+            socket_url = f"ws://{host}:{port}?from={self.language}&to={self.dst_lang}"
             self.client_socket = websocket.WebSocketApp(
                 socket_url,
                 on_open=lambda ws: self.on_open(ws),
@@ -657,10 +658,11 @@ class TranscriptionClient(TranscriptionTeeClient):
         max_clients=4,
         max_connection_time=600,
         mute_audio_playback=False,
+        dst_lang='en',
     ):
         self.client = Client(
             host, port, lang, log_transcription=log_transcription, max_clients=max_clients,
-            max_connection_time=max_connection_time
+            max_connection_time=max_connection_time, dst_lang=dst_lang
         )

         if save_output_recording and not output_recording_filename.endswith(".wav"):
@@ -671,5 +673,5 @@ class TranscriptionClient(TranscriptionTeeClient):
             [self.client],
             save_output_recording=save_output_recording,
             output_recording_filename=output_recording_filename,
-            mute_audio_playback=mute_audio_playback
+            mute_audio_playback=mute_audio_playback,
         )
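
The query string above is built by direct interpolation, which is fine for plain language codes; if values could ever contain reserved characters, urllib.parse.urlencode builds the same string with escaping. A minimal sketch, values illustrative:

    from urllib.parse import urlencode

    params = urlencode({"from": "zh", "to": "en"})
    socket_url = f"ws://localhost:9090?{params}"
    # ws://localhost:9090?from=zh&to=en

Note also that Client defaults dst_lang to 'zh' while TranscriptionClient defaults to 'en'; since TranscriptionClient always forwards its own value, the 'en' default is the one that takes effect unless Client is constructed directly.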
transcribe/transcription.py CHANGED
@@ -10,6 +10,7 @@ import numpy as np
 from .server import ServeClientBase
 from .whisper_llm_serve import PyWhiperCppServe
 from .vad import VoiceActivityDetector
+from urllib.parse import urlparse, parse_qsl
 from websockets.exceptions import ConnectionClosed
 from websockets.sync.server import serve

@@ -226,6 +227,11 @@ class TranscriptionServer:

         client.add_frames(frame_np)
         return True
+
+    def set_lang(self, websocket, src_lang, dst_lang):
+        client = self.client_manager.get_client(websocket)
+        if isinstance(client, PyWhiperCppServe):
+            client.set_lang(src_lang, dst_lang)

     def recv_audio(self,
                    websocket,
@@ -234,6 +240,12 @@ class TranscriptionServer:
         self.backend = backend
         if not self.handle_new_connection(websocket):
             return
+        query_parameters_dict = dict(parse_qsl(urlparse(websocket.request.path).query))
+        from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
+
+        if from_lang and to_lang:
+            self.set_lang(websocket, from_lang, to_lang)
+            logging.info(f"Source lang: {from_lang} -> Dst lang: {to_lang}")

         try:
             while not self.client_manager.is_client_timeout(websocket):
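
On the server side, websocket.request.path is assumed to carry the path plus query string of the HTTP upgrade request (as it does for websockets.sync.server connections), so the stdlib pair urlparse/parse_qsl recovers the language pair. A self-contained check of that parsing:

    from urllib.parse import urlparse, parse_qsl

    path = "/?from=zh&to=en"  # what the client-side URL change produces
    query_parameters_dict = dict(parse_qsl(urlparse(path).query))
    print(query_parameters_dict)              # {'from': 'zh', 'to': 'en'}
    print(query_parameters_dict.get('from'))  # 'zh'
    # A missing key yields None, which is why the `if from_lang and to_lang`
    # guard leaves the client's existing language settings untouched.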
transcribe/translator.py CHANGED
@@ -28,13 +28,12 @@ class QwenTranslator:
         message = self.to_message(prompt, src_lang, dst_lang)
         start_time = time.monotonic()
         output = self.llm.create_chat_completion(messages=message, temperature=0.9)
-        logger.info(f"LLM translate cose: {time.monotonic() - start_time:.2f}s.")
+        logger.info(f"LLM inference time: {time.monotonic() - start_time:.2f}s.")
         return output['choices'][0]['message']['content']

-    def __call__(self, prompt, max_tokens=512,*args, **kwargs):
+    def __call__(self, prompt, *args, **kwargs):
         return self.llm(
             prompt,
             *args,
-            max_tokens=max_tokens,
             **kwargs
         )
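
Dropping the hard-coded max_tokens=512 defers that choice to callers: __call__ now forwards *args/**kwargs straight to the underlying model, so a cap can still be passed per call. A hypothetical call site (the prompt and the 256 cap are illustrative; the constructor arguments follow the usage shown in whisper_llm_serve.py):

    translator = QwenTranslator(config.LLM_MODEL_PATH, config.LLM_SYS_PROMPT)
    completion = translator("你好，世界", max_tokens=256)  # forwarded via **kwargs to self.llm(...)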
transcribe/whisper_llm_serve.py CHANGED
@@ -2,6 +2,7 @@

 import soundfile
 from concurrent.futures import ProcessPoolExecutor as Pool
+import multiprocessing as mp
 import numpy as np
 from logging import getLogger
 from difflib import SequenceMatcher
@@ -10,7 +11,7 @@ import config
 import time
 import json
 import threading
-
+from functools import partial
 from .server import ServeClientBase
 from .translator import QwenTranslator
 from pywhispercpp.model import Model
@@ -109,7 +110,7 @@ class PywhisperInference:
     llm_model = None

     @classmethod
-    def initializer(cls, warmup=True):
+    def initializer(cls, event: mp.Event, warmup=True):
         models_dir = config.MODEL_DIR.as_posix()
         cls.whisper_model = Model(
             model=config.WHISPER_MODEL,
@@ -123,6 +124,7 @@ class PywhisperInference:

         # init llamacpp
         cls.llm_model = QwenTranslator(config.LLM_MODEL_PATH, config.LLM_SYS_PROMPT)
+        event.set()

     @classmethod
     def warmup(cls, warmup_steps=1):
@@ -170,20 +172,30 @@ class PyWhiperCppServe(ServeClientBase):
         self.frames_np = None
         self.sample_rate = 16000

+        self._ready_state = mp.Event()
         self._pool = Pool(
-            max_workers=1, initializer=PywhisperInference.initializer)
+            max_workers=1, initializer=partial(PywhisperInference.initializer, event=self._ready_state))

         logger.info('Create a process to process audio.')
         self.trans_thread = threading.Thread(target=self.speech_to_text)
         self.trans_thread.daemon = True
         self.trans_thread.start()
+        self.send_ready_state()

+    def send_ready_state(self):
+        while not self._ready_state.is_set():
+            time.sleep(0.1)
+
         self.websocket.send(json.dumps({
             "uid": self.client_uid,
             "message": self.SERVER_READY,
             "backend": "pywhispercpp"
         }))

+    def set_lang(self, src_lang, dst_lang):
+        self.language = src_lang
+        self.dst_lang = dst_lang
+
     def add_frames(self, frame_np):
         with self.lock:
             if self.frames_np is None:
@@ -206,7 +218,6 @@ class PyWhiperCppServe(ServeClientBase):
         transcribe_fut = self._pool.submit(
             PywhisperInference.inference, audio_buffer.tobytes(), self.language)
         segments = transcribe_fut.result()
-
         return segments

     def translate_text(self, text):
@@ -215,8 +226,6 @@ class PyWhiperCppServe(ServeClientBase):
         translate_fut = self._pool.submit(
             PywhisperInference.translate, text, self.language, self.dst_lang)
         return translate_fut.result()
-
-

     def _segments_split(self, segments, audio_buffer: np.ndarray):
         """Split the sequence at the first punctuation mark from the left into an observed segment and the remainder."""
@@ -278,15 +287,14 @@ class PyWhiperCppServe(ServeClientBase):
             # logger.info(f"[pywhispercpp:] Processing audio with duration: {len(audio_buffer)}")
             # segments = self.transcribe_audio(audio_buffer)
             try:
-                logger.info(f"[pywhispercpp:] Processing audio with duration: {len(audio_buffer)}")
+                logger.info(f"Processing audio with duration: {len(audio_buffer) / self.sample_rate:.2f}s")
                 segments = self.transcribe_audio(audio_buffer)
+                for item in self.handle_transcription_output(segments, audio_buffer):
+                    print(item)
             except KeyboardInterrupt:
                 break
             except Exception as e:
-                logger.error(f"[ERROR]: {e}")
-            else:
-                for item in self.handle_transcription_output(segments, audio_buffer):
-                    print(item)
+                logger.error(f"{e}")


@@ -306,7 +314,8 @@ class PyWhiperCppServe(ServeClientBase):
             message = self._segment_manager.segment
             seg_id = self._segment_manager.get_seg_id() - 1
             yield (seg_id, message, self.translate_text(message))
-            yield (seg_id + 1, self._segment_manager.string, self.translate_text(self._segment_manager.string))
+            if self._segment_manager.string.strip():
+                yield (seg_id + 1, self._segment_manager.string, self.translate_text(self._segment_manager.string))

         else:
             seg_id = self._segment_manager.get_seg_id()
@@ -326,7 +335,7 @@ class PyWhiperCppServe(ServeClientBase):
                 json.dumps(content)
             )
         except Exception as e:
-            logger.error(f"[ERROR]: Sending data to client: {e}")
+            logger.error(f"Sending data to client: {e}")

     def get_audio_chunk_for_processing(self):
         if self.frames_np.shape[0] >= self.sample_rate * 1:
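
The mp.Event threaded through functools.partial into the pool initializer is this commit's readiness handshake: the worker process loads the Whisper and LLM models, calls event.set(), and only then does send_ready_state let the SERVER_READY message go out. A standalone sketch of the same pattern, with illustrative names and a sleep standing in for model loading (note that ProcessPoolExecutor spawns workers lazily, so a first task must be submitted before the initializer runs):

    import time
    import multiprocessing as mp
    from concurrent.futures import ProcessPoolExecutor
    from functools import partial

    def initializer(event):
        time.sleep(0.5)  # stand-in for loading the whisper/LLM models
        event.set()      # signal the parent that this worker is ready

    if __name__ == "__main__":
        ready = mp.Event()
        pool = ProcessPoolExecutor(
            max_workers=1, initializer=partial(initializer, event=ready))
        pool.submit(time.sleep, 0)  # workers spawn on first submit
        while not ready.is_set():   # mirrors send_ready_state's polling loop
            time.sleep(0.1)
        print("worker ready")
        pool.shutdown()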