david commited on
Commit
0c38083
·
1 Parent(s): dfb349e

update code for readability

Browse files
Files changed (4) hide show
  1. config.py +4 -0
  2. main.py +2 -2
  3. transcribe/strategy.py +246 -119
  4. transcribe/whisper_llm_serve.py +243 -195
config.py CHANGED
@@ -17,6 +17,10 @@ ASSERT_DIR = BASE_DIR / "assets"
17
  # 标点
18
  SENTENCE_END_MARKERS = ['.', '!', '?', '。', '!', '?', ';', ';', ':', ':']
19
  PAUSE_END_MARKERS = [',', ',', '、']
 
 
 
 
20
 
21
  sentence_end_chars = ''.join([re.escape(char) for char in SENTENCE_END_MARKERS])
22
  SENTENCE_END_PATTERN = re.compile(f'[{sentence_end_chars}]')
 
17
  # 标点
18
  SENTENCE_END_MARKERS = ['.', '!', '?', '。', '!', '?', ';', ';', ':', ':']
19
  PAUSE_END_MARKERS = [',', ',', '、']
20
+ # 合并所有标点
21
+ ALL_MARKERS = SENTENCE_END_MARKERS + PAUSE_END_MARKERS
22
+ # 构造正则表达式字符类
23
+ REGEX_MARKERS = re.compile(r'[' + re.escape(''.join(ALL_MARKERS)) + r']')
24
 
25
  sentence_end_chars = ''.join([re.escape(char) for char in SENTENCE_END_MARKERS])
26
  SENTENCE_END_PATTERN = re.compile(f'[{sentence_end_chars}]')
main.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
2
  from urllib.parse import urlparse, parse_qsl
3
- from transcribe.whisper_llm_serve import PyWhiperCppServe
4
  from uuid import uuid1
5
  from logging import getLogger
6
  import numpy as np
@@ -57,7 +57,7 @@ async def root():
57
  async def translate(websocket: WebSocket):
58
  query_parameters_dict = websocket.query_params
59
  from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
60
- client = PyWhiperCppServe(
61
  websocket,
62
  pipe,
63
  language="en",
 
1
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
2
  from urllib.parse import urlparse, parse_qsl
3
+ from transcribe.whisper_llm_serve import WhisperTranscriptionService
4
  from uuid import uuid1
5
  from logging import getLogger
6
  import numpy as np
 
57
  async def translate(websocket: WebSocket):
58
  query_parameters_dict = websocket.query_params
59
  from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
60
+ client = WhisperTranscriptionService(
61
  websocket,
62
  pipe,
63
  language="en",
transcribe/strategy.py CHANGED
@@ -1,153 +1,280 @@
1
-
2
-
3
- from logging import getLogger
4
- from difflib import SequenceMatcher
5
  import collections
6
- import config
7
- import numpy as np
8
  from itertools import chain
 
 
 
 
9
 
10
- logger = getLogger("Stragegy")
11
 
12
- class TripleTextBuffer:
13
- def __init__(self, size=2):
14
- self.history = collections.deque(maxlen=size)
15
 
16
- def add_entry(self, text, index):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  """
18
- text: 文本
19
- index: 当前buffer的相对下标 数组索引
 
 
 
20
  """
21
  self.history.append((text, index))
22
 
23
-
24
- def get_final_index(self, similarity_threshold=0.7):
25
- """根据文本变化,返回可靠的标点的buffer的位置下标"""
 
 
 
 
 
 
 
26
  if len(self.history) < 2:
27
  return None
28
 
29
- # 获取三次的文本
30
  text1, _ = self.history[0]
31
  text2, idx2 = self.history[1]
32
- # text3, idx3 = self.history[2]
33
 
34
- # 计算变化程度
35
- sim_12 = self.text_similarity(text1, text2)
36
- # print("比较: ", text1, text2," => ", sim_12)
37
- # sim_23 = self.text_similarity(text2, text3)
38
- if sim_12 >= similarity_threshold:
39
  self.history.clear()
40
  return idx2
41
  return None
42
 
43
  @staticmethod
44
- def text_similarity(text1, text2):
 
45
  return SequenceMatcher(None, text1, text2).ratio()
46
 
47
 
48
 
49
- class SegmentManager:
50
- def __init__(self) -> None:
51
- self._commited_segments = [] # 确定后的段落
52
- self._commited_short_sentences = [] # 确定后的序列
53
- self._temp_string = "" # 存储当前临时的文本字符串,直到以句号结尾
54
-
55
- def handle(self, string):
56
- self._temp_string = string
57
- return self
58
 
 
 
 
 
 
 
 
59
  @property
60
- def short_sentence(self) -> str:
61
- return "".join(self._commited_short_sentences)
62
-
 
63
  @property
64
- def segment(self):
65
- return self._commited_segments[-1] if len(self._commited_segments) > 0 else ""
66
-
67
- def get_seg_id(self):
68
- return len(self._commited_segments)
69
-
70
  @property
71
- def string(self):
72
- return self._temp_string
73
-
74
-
75
- def commit_short_sentence(self):
76
- """将临时字符串 提交到临时短句"""
77
- self._commited_short_sentences.append(self._temp_string)
78
- self._temp_string = ""
79
-
80
- def commit_segment(self):
81
- """将短句 合并 到长句中"""
82
- self._commited_segments.append(self.short_sentence)
83
- self._commited_short_sentences = []
84
-
85
- def commit(self, is_end_sentence=False):
 
 
 
 
 
 
86
  """
87
- 当需要切掉的音频部分的时候,将句子提交到短句队列中,并移除临时字符串
88
- 当完成一个整句的时候提交到段落中
 
 
89
  """
90
- self.commit_short_sentence()
91
- if is_end_sentence:
92
- self.commit_segment()
93
-
94
- def segement_merge(segments):
95
- """根据标点符号分整句"""
96
- sequences = []
97
- temp_seq = []
98
-
99
- for seg in segments:
100
- temp_seq.append(seg)
101
- if any([mk in seg.text for mk in config.SENTENCE_END_MARKERS]):
102
- sequences.append(temp_seq.copy())
103
- temp_seq = []
104
- if temp_seq:
105
- sequences.append(temp_seq)
106
- return sequences
107
-
108
- def segments_split(segments, audio_buffer: np.ndarray, sample_rate=16000):
109
- """根据左边第一个标点符号来将序列拆分成 观察段 剩余部分"""
110
- left_watch_sequences = []
111
- left_watch_idx = 0
112
- right_watch_sequences = []
113
- is_end = False
114
-
115
- if (len(audio_buffer) / sample_rate) < 12:
116
- # 低于12s 使用短句符号比如逗号作为判断依据
117
- markers = config.PAUSE_END_MARKERS + config.SENTENCE_END_MARKERS
118
- is_end = False
119
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  for idx, seg in enumerate(segments):
121
- # print('>>>>>>>>>>>>>>>> seg : ', seg)
122
- left_watch_sequences.append(seg)
123
  if seg.text and seg.text[-1] in markers:
124
- seg_index = int(seg.t1 / 100 * sample_rate)
125
- # rest_buffer_duration = (len(audio_buffer) - seg_index) / sample_rate
126
- is_end = config.SENTENCE_END_PATTERN.search(seg.text)
 
127
 
128
- right_watch_sequences = segments[min(idx+1, len(segments)):]
129
- # if rest_buffer_duration >= 1.5:
130
- left_watch_idx = seg_index
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  break
132
- return left_watch_idx, left_watch_sequences, right_watch_sequences, is_end
133
-
134
-
135
- def sequences_split(segments, audio_buffer: np.ndarray, sample_rate=16000):
136
- # 长句 保留最后两句即可
137
- left_watch_sequences = []
138
- right_watch_sequences = []
139
- left_watch_idx = 0
140
- is_end = False
141
- sequences = segement_merge(segments)
142
-
143
- if len(sequences) > 2:
144
- logger.info(f"buffer clip via sequence, current length: {len(sequences)}")
145
- is_end = True
146
- left_watch_sequences = chain(*sequences[:-2])
147
- right_watch_sequences = chain(*sequences[-2:])
148
- last_sequence_segment = sequences[-3]
149
- last_segment = last_sequence_segment[-1]
150
- left_watch_idx = int(last_segment.t1 / 100 * sample_rate)
151
- return left_watch_idx, left_watch_sequences, right_watch_sequences, is_end
152
-
153
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
 
 
 
2
  import collections
3
+ import logging
4
+ from difflib import SequenceMatcher
5
  from itertools import chain
6
+ from dataclasses import dataclass
7
+ from typing import List, Tuple, Optional, Deque, Any, Iterator
8
+ from config import SENTENCE_END_MARKERS, ALL_MARKERS,SENTENCE_END_PATTERN,REGEX_MARKERS
9
+ import numpy as np
10
 
11
+ logger = logging.getLogger("TranscriptionStrategy")
12
 
 
 
 
13
 
14
+ @dataclass
15
+ class TranscriptSegment:
16
+ """表示一个转录片段,包含文本和时间信息"""
17
+ text: str
18
+ t0: float # 开始时间(百分之一秒)
19
+ t1: float # 结束时间(百分之一秒)
20
+
21
+
22
+ class TextStabilityBuffer:
23
+ """
24
+ 通过比较连续文本样本的相似度来确定转录文本的稳定性。
25
+ 当连续样本的相似度超过阈值时,认为文本已稳定。
26
+ """
27
+ def __init__(self, max_history: int = 2):
28
+ self.history: Deque[Tuple[str, int]] = collections.deque(maxlen=max_history)
29
+
30
+ def add_entry(self, text: str, index: int) -> None:
31
  """
32
+ 添加新的文本和索引到历史记录中
33
+
34
+ Args:
35
+ text: 文本内容
36
+ index: 当前buffer的相对下标
37
  """
38
  self.history.append((text, index))
39
 
40
+ def get_stable_index(self, similarity_threshold: float = 0.7) -> Optional[int]:
41
+ """
42
+ 根据文本相似度,判断文本是否稳定,返回稳定文本的索引
43
+
44
+ Args:
45
+ similarity_threshold: 相似度阈值,超过此值认为文本稳定
46
+
47
+ Returns:
48
+ 稳定文本的索引,如果没有找到稳定文本则返回None
49
+ """
50
  if len(self.history) < 2:
51
  return None
52
 
 
53
  text1, _ = self.history[0]
54
  text2, idx2 = self.history[1]
 
55
 
56
+ similarity = self._calculate_similarity(text1, text2)
57
+
58
+ if similarity >= similarity_threshold:
 
 
59
  self.history.clear()
60
  return idx2
61
  return None
62
 
63
  @staticmethod
64
+ def _calculate_similarity(text1: str, text2: str) -> float:
65
+ """计算两段文本的相似度"""
66
  return SequenceMatcher(None, text1, text2).ratio()
67
 
68
 
69
 
70
+ class TranscriptionManager:
71
+ """
72
+ 管理转录文本的分级结构:临时字符串 -> 短句 -> 完整段落
 
 
 
 
 
 
73
 
74
+ |-- 已确认文本 --|-- 观察窗口 --|-- 新输入 --|
75
+ """
76
+ def __init__(self):
77
+ self._committed_segments: List[str] = [] # 确认的完整段落
78
+ self._committed_sentences: List[str] = [] # 确认的短句
79
+ self._temp_string: str = "" # 临时字符串缓冲
80
+
81
  @property
82
+ def current_sentence(self) -> str:
83
+ """当前已确认的短句组合"""
84
+ return "".join(self._committed_sentences)
85
+
86
  @property
87
+ def latest_segment(self) -> str:
88
+ """最新确认的完整段落"""
89
+ return self._committed_segments[-1] if self._committed_segments else ""
90
+
 
 
91
  @property
92
+ def segment_count(self) -> int:
93
+ """已确认的段落数量"""
94
+ return len(self._committed_segments)
95
+
96
+ @property
97
+ def sentence_length(self) -> int:
98
+ """当前短句的总字符长度"""
99
+ return sum(len(s) for s in self._committed_sentences)
100
+
101
+ def update_temp(self, text: str) -> 'TranscriptionManager':
102
+ """更新临时字符串"""
103
+ self._temp_string = text
104
+ return self
105
+
106
+ def commit_sentence(self) -> None:
107
+ """将临时字符串提交到短句列表"""
108
+ if self._temp_string:
109
+ self._committed_sentences.append(self._temp_string)
110
+ self._temp_string = ""
111
+
112
+ def commit_segment(self, is_end_of_sentence: bool = False) -> None:
113
  """
114
+ 提交当前内容到适当的层级
115
+
116
+ Args:
117
+ is_end_of_sentence: 是否为完整句子的结束
118
  """
119
+ self.commit_sentence()
120
+ if is_end_of_sentence and self._committed_sentences:
121
+ self._committed_segments.append(self.current_sentence)
122
+ self._committed_sentences = []
123
+
124
+ def get_all_text(self) -> str:
125
+ """获取所有已提交的文本"""
126
+ all_segments = self._committed_segments.copy()
127
+ if self.current_sentence:
128
+ all_segments.append(self.current_sentence)
129
+ if self._temp_string:
130
+ all_segments.append(self._temp_string)
131
+ return "\n".join(all_segments)
132
+
133
+
134
+ class TranscriptionSplitter:
135
+ """负责根据语音和文本特征拆分转录片段"""
136
+
137
+ @staticmethod
138
+ def group_by_sentences(segments: List[TranscriptSegment]) -> List[List[TranscriptSegment]]:
139
+ """将片段按照完整句子分组"""
140
+ sequences = []
141
+ temp_seq = []
142
+
143
+ for seg in segments:
144
+ temp_seq.append(seg)
145
+ if any(marker in seg.text for marker in SENTENCE_END_MARKERS):
146
+ sequences.append(temp_seq.copy())
147
+ temp_seq = []
148
+
149
+ if temp_seq:
150
+ sequences.append(temp_seq)
151
+ return sequences
152
+
153
+ @staticmethod
154
+ def split_by_punctuation(
155
+ segments: List[TranscriptSegment],
156
+ audio_buffer: np.ndarray,
157
+ sample_rate: int = 16000
158
+ ) -> Tuple[int, List[TranscriptSegment], List[TranscriptSegment], bool]:
159
+ """
160
+ 根据标点符号将片段分为左侧(已确认)和右侧(待确认)
161
+
162
+ Returns:
163
+ (分割索引, 左侧片段, 右侧片段, 是否为句子结束)
164
+ """
165
+ left_segments = []
166
+ right_segments = []
167
+ split_index = 0
168
+ is_sentence_end = False
169
+
170
+ # 短音频使用所有标点符号作为分割依据
171
+ buffer_duration = len(audio_buffer) / sample_rate
172
+ markers = ALL_MARKERS if buffer_duration < 12 else SENTENCE_END_MARKERS
173
+
174
  for idx, seg in enumerate(segments):
175
+ left_segments.append(seg)
 
176
  if seg.text and seg.text[-1] in markers:
177
+ split_index = int(seg.t1 / 100 * sample_rate)
178
+ is_sentence_end = bool(SENTENCE_END_PATTERN.search(seg.text))
179
+ right_segments = segments[min(idx+1, len(segments)):]
180
+ break
181
 
182
+ return split_index, left_segments, right_segments, is_sentence_end
183
+
184
+ @staticmethod
185
+ def split_by_sequences(
186
+ segments: List[TranscriptSegment],
187
+ audio_buffer: np.ndarray,
188
+ sample_rate: int = 16000
189
+ ) -> Tuple[int, Iterator[TranscriptSegment], Iterator[TranscriptSegment], bool]:
190
+ """
191
+ 对于长文本,按照句子组保留最新的两句
192
+
193
+ Returns:
194
+ (分割索引, 左侧片段, 右侧片段, 是否为句子结束)
195
+ """
196
+ sequences = TranscriptionSplitter.group_by_sentences(segments)
197
+
198
+ if len(sequences) > 2:
199
+ logger.info(f"Buffer clip via sequence, current length: {len(sequences)}")
200
+ left_segments = chain(*sequences[:-2])
201
+ right_segments = chain(*sequences[-2:])
202
+
203
+ # 确定切分点
204
+ last_sequence = sequences[-3]
205
+ last_segment = last_sequence[-1]
206
+ split_index = int(last_segment.t1 / 100 * sample_rate)
207
+
208
+ return split_index, left_segments, right_segments, True
209
+
210
+ return 0, iter([]), iter(segments), False
211
+
212
+
213
+ class TranscriptionStabilizer:
214
+ """
215
+ 转录结果稳定器,负责确认和管理转录片段
216
+ """
217
+ def __init__(self, sample_rate: int = 16000):
218
+ self.manager = TranscriptionManager()
219
+ self.stability_buffer = TextStabilityBuffer(max_history=2)
220
+ self.sample_rate = sample_rate
221
+
222
+ def process_segments(self, segments: List[TranscriptSegment]) -> Tuple[Optional[int], bool]:
223
+ """
224
+ 处理转录片段,确认稳定的文本
225
+
226
+ Args:
227
+ segments: 转录片段列表
228
+
229
+ Returns:
230
+ (音频分割点索引, 是否达到足够长度需要换行)
231
+ """
232
+ # 查找第一个包含标点的片段作为分割点
233
+ split_index = None
234
+ stable_segments = []
235
+
236
+ for idx, seg in enumerate(segments):
237
+ stable_segments.append(seg)
238
+ if REGEX_MARKERS.search(seg.text):
239
+ split_index = int(seg.t1 / 100 * self.sample_rate)
240
+ stable_idx = min(idx + 1, len(segments))
241
  break
242
+
243
+ if split_index: # 找到标点,确认标点前的内容
244
+ stable_text = self._join_segment_text(segments[:stable_idx])
245
+ self.manager.update_temp(stable_text).commit_sentence()
246
+
247
+ # 更新剩余文本
248
+ remaining_text = self._join_segment_text(segments[stable_idx:])
249
+ self.manager.update_temp(remaining_text)
250
+ else:
251
+ # 没有找到标点,全部作为临时文本
252
+ self.manager.update_temp(self._join_segment_text(segments))
253
+
254
+ # 检查是否达到换行标准
255
+ should_linebreak = self.manager.sentence_length >= 20
256
+
257
+ return split_index, should_linebreak
258
+
259
+ def check_stability(self, text: str, index: int) -> Optional[int]:
260
+ """
261
+ 检查文本是否稳定
262
+
263
+ Args:
264
+ text: 当前文本
265
+ index: 当前索引
266
+
267
+ Returns:
268
+ 如果文本稳定,返回稳定的索引;否则返回None
269
+ """
270
+ self.stability_buffer.add_entry(text, index)
271
+ return self.stability_buffer.get_stable_index()
272
+
273
+ def commit_segment(self, is_end_of_sentence: bool) -> None:
274
+ """提交转录片段"""
275
+ self.manager.commit_segment(is_end_of_sentence)
276
+
277
+ @staticmethod
278
+ def _join_segment_text(segments: List[TranscriptSegment], separator: str = "") -> str:
279
+ """连接多个片段的文本"""
280
+ return separator.join(seg.text for seg in segments)
transcribe/whisper_llm_serve.py CHANGED
@@ -1,261 +1,309 @@
1
-
2
-
3
- import numpy as np
4
- from logging import getLogger
5
  import asyncio
6
- from .utils import save_to_wave
7
- import time
8
  import json
9
- import threading
10
- from .server import ServeClientBase
11
  import queue
12
- import collections
 
 
 
 
 
 
13
  from api_model import TransResult, Message
14
- from .utils import log_block
 
15
  from .translatepipes import TranslatePipes
16
- from .strategy import TripleTextBuffer, SegmentManager, segments_split, sequences_split
17
-
18
- logger = getLogger("TranslatorApp")
19
 
 
20
 
21
 
22
- class PyWhiperCppServe(ServeClientBase):
23
-
24
- def __init__(self, websocket, pipe:TranslatePipes,language=None, dst_lang=None, client_uid=None,):
 
 
 
25
  super().__init__(client_uid, websocket)
26
- self.language = language
27
- self.dst_lang = dst_lang # 目标翻译语言
28
- # 设置观察字符串 对比上下次的文字来判断字符串的输出是否固定
29
- self._text_buffer = TripleTextBuffer()
30
- # 存储转录数据
31
- self._segment_manager = SegmentManager()
32
- self._translate_pipes = pipe
33
- self.lock = threading.Lock()
 
 
34
  self.frames_np = None
 
35
  self._frame_queue = queue.Queue()
36
- self.sample_rate = 16000
37
-
 
 
 
38
  self.send_ready_state()
 
 
39
  self._translate_thread_stop = threading.Event()
40
- self._frame_to_queue_thread_stop = threading.Event()
41
- self.translate_thread = self.run_in_thread(self.speech_to_text)
42
- self.frame_to_queue_thread = self.run_in_thread(self.get_frame_from_queue)
43
 
44
- self.text_sep = ""
 
 
 
 
 
45
 
46
- def run_in_thread(self, func):
47
- t = threading.Thread(target=func)
48
- t.daemon = True
49
- t.start()
50
- return t
51
 
52
- def send_ready_state(self):
 
53
  self.websocket.send(json.dumps({
54
  "uid": self.client_uid,
55
  "message": self.SERVER_READY,
56
- "backend": "pywhispercpp"
57
  }))
58
 
59
- def set_lang(self, src_lang, dst_lang):
60
- self.language = src_lang
61
- self.dst_lang = dst_lang
62
- self.text_sep = "" if self.language == "zh" else " "
 
63
 
64
- def add_frames(self, frame_np):
 
65
  self._frame_queue.put(frame_np)
66
 
67
- def vad_merge(self):
68
- with self.lock:
69
- frame = self.frames_np.copy()
70
- item = self._translate_pipes.voice_detect(frame.tobytes())
71
- frame_np = np.frombuffer(item.audio, dtype=np.float32)
72
- self.frames_np = frame_np.copy()
73
-
74
-
75
- def get_frame_from_queue(self,):
76
- while not self._frame_to_queue_thread_stop.is_set():
77
  try:
78
  frame_np = self._frame_queue.get(timeout=0.1)
79
  with self.lock:
80
  if self.frames_np is None:
81
  self.frames_np = frame_np.copy()
82
  else:
83
- self.frames_np = np.append(self.frames_np,frame_np)
84
  except queue.Empty:
85
  pass
86
-
87
 
88
- def update_audio_buffer(self, last_offset):
 
89
  with self.lock:
90
- self.frames_np = self.frames_np[last_offset:]
 
 
 
91
 
92
- def transcribe_audio(self, audio_buffer):
93
- """
94
- Transcribe the audio chunk and send the results to the client.
 
 
95
 
96
- Args:
97
- audio_buffer (np.array): The audio chunk to transcribe.
98
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
 
 
100
  log_block("Audio buffer length", f"{audio_buffer.shape[0]/self.sample_rate:.2f}", "s")
101
  start_time = time.perf_counter()
102
 
103
- item = self._translate_pipes.transcrible(audio_buffer.tobytes(), self.language)
104
- segments = item.segments
105
- log_block("Whisper transcrible out", f"{''.join(seg.text for seg in segments)}", "")
106
- log_block("Whisper transcrible time", f"{(time.perf_counter() - start_time):.3f}", "s")
107
-
 
108
  return segments
109
-
110
- def translate_text(self, text):
111
- """
112
- translate the text to dst lang"""
113
- # return "sample english"
114
- log_block("LLM translate input", f"{text}")
 
115
  start_time = time.perf_counter()
116
- ret = self._translate_pipes.translate(text, self.language, self.dst_lang)
117
- translated_text = ret.translate_content
118
- log_block("LLM translate time", f"{(time.perf_counter() - start_time):.3f}", "s")
119
- log_block("LLM translate out", f"{translated_text}")
 
 
 
120
  return translated_text
121
-
122
 
123
-
124
- def analysis_segments(self, segments, audio_buffer: np.ndarray):
125
- # 找到第一个标点符号作为锚点 左边为确认段,右边为观察段,
126
- # 当左边确认后,右边段才会进入观察
127
- # 当左边确认后,会从缓冲区中删除对应的buffer,减少下次输入的数据量
128
- left_watch_idx, left_watch_sequences, right_watch_sequences, is_end_sentence = segments_split(segments, audio_buffer)
129
- left_watch_string = self.text_sep.join(i.text for i in left_watch_sequences)
130
- right_watch_string = self.text_sep.join(i.text for i in right_watch_sequences)
 
 
 
 
 
 
131
 
132
- if left_watch_idx != 0:
133
- # 将观察字符串临时存储
134
- self._text_buffer.add_entry(left_watch_string, left_watch_idx)
135
- audio_cut_index = self._text_buffer.get_final_index()
136
- if audio_cut_index:
137
- return audio_cut_index, left_watch_string, right_watch_string, is_end_sentence
138
 
139
- # 整句消除 后两句之前的内容
140
- left_watch_idx, left_watch_sequences, right_watch_sequences, is_end_sentence = sequences_split(segments, audio_buffer)
141
- left_watch_string = self.text_sep.join(i.text for i in left_watch_sequences)
142
- right_watch_string = self.text_sep.join(i.text for i in right_watch_sequences)
143
- if left_watch_idx != 0:
144
- return left_watch_idx, left_watch_string, right_watch_string, is_end_sentence
145
 
146
- return None, left_watch_string, right_watch_string, is_end_sentence
 
 
 
 
 
 
147
 
148
- def speech_to_text(self):
149
- c = 0
150
  while not self._translate_thread_stop.is_set():
151
  if self.exit:
152
- logger.info("Exiting speech to text thread")
153
  break
154
-
155
- if self.frames_np is None :
156
- time.sleep(0.02) # wait for any audio to arrive
157
- logger.info("waiting for client data...")
 
158
  continue
159
-
160
- audio_buffer = self.get_audio_chunk_for_processing()
 
161
  if audio_buffer is None:
162
- time.sleep(0.02) # wait for any audio to arrive
163
  continue
164
-
165
- # c+= 1
166
- # name = f"dev-{c}.wav"
167
- # save_to_wave(name, audio_buffer)
168
- # try:
169
- logger.info(f"Audio buffer length: {len(audio_buffer) / self.sample_rate:.2f}s")
170
- segments = self.transcribe_audio(audio_buffer)
171
- for tran_result in self.handle_transcription_output(segments, audio_buffer):
172
- self.send_to_client(tran_result)
173
- # except KeyboardInterrupt:
174
- # break
175
- # except Exception as e:
176
- # logger.error(f"{e}")
177
- # if (time_delay := (1 - audio_duration)) > 0:
178
- # time.sleep(time_delay)
179
-
180
- def handle_transcription_output(self, segments, audio_buffer):
181
- texts = self.text_sep.join(i.text for i in segments)
182
- if not len(texts):
 
 
 
183
  return
184
- self._segment_manager.handle(texts)
185
- # 分析句子
186
- last_cut_index, left_string, right_string, is_end_sentence = self.analysis_segments(segments, audio_buffer)
187
- # print(last_cut_index, left_string, right_string, is_end_sentence)
188
- if last_cut_index:
189
- self.update_audio_buffer(last_cut_index)
190
- # 句子或者短句的提交
191
- log_block("Whisper string lock ", f"{left_string}",)
192
- self._segment_manager.handle(left_string).commit(is_end_sentence)
193
- self._segment_manager.handle(right_string)
194
 
195
- if is_end_sentence and last_cut_index:
196
- message = self._segment_manager.segment
197
- seg_id = self._segment_manager.get_seg_id() - 1
198
- # logger.info(f"{seg_id}, {message}")
199
- yield TransResult(
200
- seg_id=seg_id,
201
- context=message,
202
- from_=self.language,
203
- to=self.dst_lang,
204
- tran_content=self.translate_text(message),
205
- partial=False
206
- )
207
- if self._segment_manager.string.strip():
208
- message = self._segment_manager.string.strip()
209
- # logger.info(f"{seg_id + 1}, {message}")
210
- yield TransResult(
211
- seg_id=seg_id+1,
212
- context=self._segment_manager.string,
213
- from_=self.language,
214
- to=self.dst_lang,
215
- tran_content=self.translate_text(message),
216
- )
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  else:
219
- seg_id = self._segment_manager.get_seg_id()
220
- message = self._segment_manager.short_sentence + self._segment_manager.string
221
- # logger.info(f"{seg_id}, {message}")
 
222
  yield TransResult(
223
- seg_id=seg_id,
224
- context=message,
225
- from_=self.language,
226
- to=self.dst_lang,
227
- tran_content=self.translate_text(message),
 
228
  )
229
-
230
- def send_to_client(self, data:TransResult):
 
231
  try:
232
- coro = self.websocket.send_text(
233
- Message(result=data, request_id=self.client_uid).model_dump_json(by_alias=True)
234
- )
235
  asyncio.run(coro)
236
- except RuntimeError as e:
237
  self.stop()
238
- return
239
  except Exception as e:
240
- logger.error(e)
241
-
242
-
243
-
244
- def get_audio_chunk_for_processing(self):
245
- self.vad_merge()
246
- silence_audio = np.zeros((self.sample_rate+1000,), dtype=np.float32)
247
- frames = self.frames_np.copy()
248
- # 添加对非常短音频的处理
249
- if len(frames) <= 100:
250
- # 对于极短的音频段(<=100帧),直接返回空音频
251
- self.update_audio_buffer(len(frames))
252
- return None
253
- elif len(frames) < self.sample_rate:
254
- silence_audio[-len(frames):] = frames
255
- return silence_audio.copy()
256
- return frames.copy()
257
-
258
 
259
- def stop(self):
 
260
  self._translate_thread_stop.set()
261
- self._frame_to_queue_thread_stop.set()
 
 
 
 
 
 
1
  import asyncio
 
 
2
  import json
 
 
3
  import queue
4
+ import threading
5
+ import time
6
+ from logging import getLogger
7
+ from typing import List, Optional, Iterator, Tuple, Any
8
+
9
+ import numpy as np
10
+
11
  from api_model import TransResult, Message
12
+ from .server import ServeClientBase
13
+ from .utils import log_block, save_to_wave
14
  from .translatepipes import TranslatePipes
15
+ from .strategy import TextStabilityBuffer, TranscriptionManager, TranscriptionSplitter, TranscriptSegment
 
 
16
 
17
+ logger = getLogger("TranscriptionService")
18
 
19
 
20
+ class WhisperTranscriptionService(ServeClientBase):
21
+ """
22
+ Whisper语音转录服务类,处理音频流转录和翻译
23
+ """
24
+
25
+ def __init__(self, websocket, pipe: TranslatePipes, language=None, dst_lang=None, client_uid=None):
26
  super().__init__(client_uid, websocket)
27
+ self.source_language = language # 源语言
28
+ self.target_language = dst_lang # 目标翻译语言
29
+
30
+ # 转录结果稳定性管理
31
+ self._text_stability_buffer = TextStabilityBuffer()
32
+ self._transcription_manager = TranscriptionManager()
33
+ self._translate_pipe = pipe
34
+
35
+ # 音频处理相关
36
+ self.sample_rate = 16000
37
  self.frames_np = None
38
+ self.lock = threading.Lock()
39
  self._frame_queue = queue.Queue()
40
+
41
+ # 文本分隔符,根据语言设置
42
+ self.text_separator = self._get_text_separator(language)
43
+
44
+ # 发送就绪状态
45
  self.send_ready_state()
46
+
47
+ # 启动处理线程
48
  self._translate_thread_stop = threading.Event()
49
+ self._frame_processing_thread_stop = threading.Event()
50
+ self.translate_thread = self._start_thread(self._transcription_processing_loop)
51
+ self.frame_processing_thread = self._start_thread(self._frame_processing_loop)
52
 
53
+ def _start_thread(self, target_function) -> threading.Thread:
54
+ """启动守护线程执行指定函数"""
55
+ thread = threading.Thread(target=target_function)
56
+ thread.daemon = True
57
+ thread.start()
58
+ return thread
59
 
60
+ def _get_text_separator(self, language: str) -> str:
61
+ """根据语言返回适当的文本分隔符"""
62
+ return "" if language == "zh" else " "
 
 
63
 
64
+ def send_ready_state(self) -> None:
65
+ """发送服务就绪状态消息"""
66
  self.websocket.send(json.dumps({
67
  "uid": self.client_uid,
68
  "message": self.SERVER_READY,
69
+ "backend": "whisper_transcription"
70
  }))
71
 
72
+ def set_language(self, source_lang: str, target_lang: str) -> None:
73
+ """设置源语言和目标语言"""
74
+ self.source_language = source_lang
75
+ self.target_language = target_lang
76
+ self.text_separator = self._get_text_separator(source_lang)
77
 
78
+ def add_audio_frames(self, frame_np: np.ndarray) -> None:
79
+ """添加音频帧到处理队列"""
80
  self._frame_queue.put(frame_np)
81
 
82
+ def _frame_processing_loop(self) -> None:
83
+ """从队列获取音频帧并合并到缓冲区"""
84
+ while not self._frame_processing_thread_stop.is_set():
 
 
 
 
 
 
 
85
  try:
86
  frame_np = self._frame_queue.get(timeout=0.1)
87
  with self.lock:
88
  if self.frames_np is None:
89
  self.frames_np = frame_np.copy()
90
  else:
91
+ self.frames_np = np.append(self.frames_np, frame_np)
92
  except queue.Empty:
93
  pass
 
94
 
95
+ def _apply_voice_activity_detection(self) -> None:
96
+ """应用语音活动检测来优化音频缓冲区"""
97
  with self.lock:
98
+ if self.frames_np is not None:
99
+ frame = self.frames_np.copy()
100
+ processed_audio = self._translate_pipe.voice_detect(frame.tobytes())
101
+ self.frames_np = np.frombuffer(processed_audio.audio, dtype=np.float32).copy()
102
 
103
+ def _update_audio_buffer(self, offset: int) -> None:
104
+ """从音频缓冲区中移除已处理的部分"""
105
+ with self.lock:
106
+ if self.frames_np is not None and offset > 0:
107
+ self.frames_np = self.frames_np[offset:]
108
 
109
+ def _get_audio_for_processing(self) -> Optional[np.ndarray]:
110
+ """准备用于处理的音频块"""
111
+ # 应用VAD处理
112
+ self._apply_voice_activity_detection()
113
+
114
+ # 没有音频帧
115
+ if self.frames_np is None:
116
+ return None
117
+
118
+ frames = self.frames_np.copy()
119
+
120
+ # 音频过短时的处理
121
+ if len(frames) <= 100:
122
+ # 极短音频段,清空并返回None
123
+ self._update_audio_buffer(len(frames))
124
+ return None
125
+ elif len(frames) < self.sample_rate:
126
+ # 不足一秒的音频,补充静音
127
+ silence_audio = np.zeros((self.sample_rate + 1000,), dtype=np.float32)
128
+ silence_audio[-len(frames):] = frames
129
+ return silence_audio.copy()
130
+
131
+ return frames.copy()
132
 
133
+ def _transcribe_audio(self, audio_buffer: np.ndarray) -> List[TranscriptSegment]:
134
+ """转录音频并返回转录片段"""
135
  log_block("Audio buffer length", f"{audio_buffer.shape[0]/self.sample_rate:.2f}", "s")
136
  start_time = time.perf_counter()
137
 
138
+ result = self._translate_pipe.transcrible(audio_buffer.tobytes(), self.source_language)
139
+ segments = result.segments
140
+
141
+ log_block("Whisper transcription output", f"{''.join(seg.text for seg in segments)}", "")
142
+ log_block("Whisper transcription time", f"{(time.perf_counter() - start_time):.3f}", "s")
143
+
144
  return segments
145
+
146
+ def _translate_text(self, text: str) -> str:
147
+ """将文本翻译为目标语言"""
148
+ if not text.strip():
149
+ return ""
150
+
151
+ log_block("Translation input", f"{text}")
152
  start_time = time.perf_counter()
153
+
154
+ result = self._translate_pipe.translate(text, self.source_language, self.target_language)
155
+ translated_text = result.translate_content
156
+
157
+ log_block("Translation time", f"{(time.perf_counter() - start_time):.3f}", "s")
158
+ log_block("Translation output", f"{translated_text}")
159
+
160
  return translated_text
 
161
 
162
+ def _analyze_segments(self, segments: List[TranscriptSegment], audio_buffer: np.ndarray) -> Tuple[Optional[int], str, str, bool]:
163
+ """
164
+ 分析转录片段,确定稳定部分和需要继续观察的部分
165
+
166
+ Returns:
167
+ (分割索引, 左侧稳定文本, 右侧观察文本, 是否为句子结束)
168
+ """
169
+ # 尝试基于标点符号进行分割
170
+ left_idx, left_segments, right_segments, is_end = TranscriptionSplitter.split_by_punctuation(
171
+ segments, audio_buffer, self.sample_rate
172
+ )
173
+
174
+ left_text = self.text_separator.join(seg.text for seg in left_segments)
175
+ right_text = self.text_separator.join(seg.text for seg in right_segments)
176
 
177
+ # 如果找到分割点,检查左侧文本稳定性
178
+ if left_idx != 0:
179
+ self._text_stability_buffer.add_entry(left_text, left_idx)
180
+ stable_idx = self._text_stability_buffer.get_stable_index()
181
+ if stable_idx:
182
+ return stable_idx, left_text, right_text, is_end
183
 
184
+ # 如果基于标点的方法未找到稳定点,尝试基于句子序列的方法
185
+ left_idx, left_segments, right_segments, is_end = TranscriptionSplitter.split_by_sequences(
186
+ segments, audio_buffer, self.sample_rate
187
+ )
 
 
188
 
189
+ if left_idx != 0:
190
+ left_text = self.text_separator.join(seg.text for seg in left_segments)
191
+ right_text = self.text_separator.join(seg.text for seg in right_segments)
192
+ return left_idx, left_text, right_text, is_end
193
+
194
+ # 如果都没有找到分割点
195
+ return None, left_text, right_text, is_end
196
 
197
+ def _transcription_processing_loop(self) -> None:
198
+ """主转录处理循环"""
199
  while not self._translate_thread_stop.is_set():
200
  if self.exit:
201
+ logger.info("Exiting transcription thread")
202
  break
203
+
204
+ # 等待音频数据
205
+ if self.frames_np is None:
206
+ time.sleep(0.02)
207
+ logger.info("Waiting for audio data...")
208
  continue
209
+
210
+ # 获取音频块进行处理
211
+ audio_buffer = self._get_audio_for_processing()
212
  if audio_buffer is None:
213
+ time.sleep(0.02)
214
  continue
215
+
216
+ try:
217
+ logger.info(f"Processing audio buffer: {len(audio_buffer)/self.sample_rate:.2f}s")
218
+ segments = self._transcribe_audio(audio_buffer)
219
+
220
+ # 处理转录结果并发送到客户端
221
+ for result in self._process_transcription_results(segments, audio_buffer):
222
+ self._send_result_to_client(result)
223
+
224
+ except Exception as e:
225
+ logger.error(f"Error processing audio: {e}")
226
+
227
+ def _process_transcription_results(self, segments: List[TranscriptSegment], audio_buffer: np.ndarray) -> Iterator[TransResult]:
228
+ """
229
+ 处理转录结果,生成翻译结果
230
+
231
+ Returns:
232
+ TransResult对象的迭代器
233
+ """
234
+ # 合并所有片段的文本
235
+ full_text = self.text_separator.join(seg.text for seg in segments)
236
+ if not full_text:
237
  return
238
+
239
+ # 更新转录管理器中的临时文本
240
+ self._transcription_manager.update_temp(full_text)
 
 
 
 
 
 
 
241
 
242
+ # 分析片段,确定稳定部分和需要继续观察的部分
243
+ cut_index, stable_text, remaining_text, is_sentence_end = self._analyze_segments(segments, audio_buffer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
+ # 如果找到稳定的分割点
246
+ if cut_index:
247
+ # 更新音频缓冲区,移除已处理部分
248
+ self._update_audio_buffer(cut_index)
249
+
250
+ # 提交稳定的文本
251
+ log_block("Stable transcription", f"{stable_text}")
252
+ self._transcription_manager.update_temp(stable_text).commit_segment(is_sentence_end)
253
+ self._transcription_manager.update_temp(remaining_text)
254
+
255
+ # 如果是句子结束,发送完整句子的翻译结果
256
+ if is_sentence_end:
257
+ segment_text = self._transcription_manager.latest_segment
258
+ segment_id = self._transcription_manager.segment_count - 1
259
+
260
+ # 生成已确认句子的翻译结果
261
+ yield TransResult(
262
+ seg_id=segment_id,
263
+ context=segment_text,
264
+ from_=self.source_language,
265
+ to=self.target_language,
266
+ tran_content=self._translate_text(segment_text),
267
+ partial=False
268
+ )
269
+
270
+ # 如果还有剩余部分,生成临时翻译结果
271
+ if self._transcription_manager.current_sentence.strip():
272
+ yield TransResult(
273
+ seg_id=segment_id + 1,
274
+ context=self._transcription_manager.current_sentence,
275
+ from_=self.source_language,
276
+ to=self.target_language,
277
+ tran_content=self._translate_text(self._transcription_manager.current_sentence.strip()),
278
+ partial=True
279
+ )
280
  else:
281
+ # 没有找到稳定点,发送当前所有内容的临时翻译结果
282
+ segment_id = self._transcription_manager.segment_count
283
+ current_text = self._transcription_manager.current_sentence + self._transcription_manager.update_temp(remaining_text)._temp_string
284
+
285
  yield TransResult(
286
+ seg_id=segment_id,
287
+ context=current_text,
288
+ from_=self.source_language,
289
+ to=self.target_language,
290
+ tran_content=self._translate_text(current_text),
291
+ partial=True
292
  )
293
+
294
+ def _send_result_to_client(self, result: TransResult) -> None:
295
+ """发送翻译结果到客户端"""
296
  try:
297
+ message = Message(result=result, request_id=self.client_uid).model_dump_json(by_alias=True)
298
+ coro = self.websocket.send_text(message)
 
299
  asyncio.run(coro)
300
+ except RuntimeError:
301
  self.stop()
 
302
  except Exception as e:
303
+ logger.error(f"Error sending result to client: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
+ def stop(self) -> None:
306
+ """停止所有处理线程并清理资源"""
307
  self._translate_thread_stop.set()
308
+ self._frame_processing_thread_stop.set()
309
+ logger.info(f"Stopping transcription service for client: {self.client_uid}")