daihui.zhang commited on
Commit
99b58ae
·
1 Parent(s): b6e4de3

fix error in get seg id

Browse files
transcribe/strategy.py CHANGED
@@ -8,6 +8,7 @@ from typing import List, Tuple, Optional, Deque, Any, Iterator,Literal
8
  from config import SENTENCE_END_MARKERS, ALL_MARKERS,SENTENCE_END_PATTERN,REGEX_MARKERS, PAUSEE_END_PATTERN,SAMPLE_RATE
9
  import numpy as np
10
  from enum import Enum
 
11
  logger = logging.getLogger("TranscriptionStrategy")
12
 
13
 
@@ -83,12 +84,12 @@ class TranscriptChunk:
83
  def get_split_first_rest(self, mode: SplitMode):
84
  chunks = self.split_by(mode)
85
  fisrt_chunk = chunks[0] if chunks else self
86
- rest_chunks = chunks[1:] if chunks else []
87
  return fisrt_chunk, rest_chunks
88
 
89
  def puncation_numbers(self) -> int:
90
  """计算片段中标点符号的数量"""
91
- return sum(1 for seg in self.items if REGEX_MARKERS.search(seg.text))
92
 
93
  def length(self) -> int:
94
  """返回片段列表的长度"""
@@ -102,15 +103,15 @@ class TranscriptChunk:
102
  """比较当前片段与另一个片段的相似度"""
103
  if not chunk:
104
  return 0
105
- return self._calculate_similarity(self.join(), chunk.join())
 
 
 
106
 
107
  def has_punctuation(self) -> bool:
108
  return any(seg.is_punctuation() for seg in self.items)
109
 
110
  def get_buffer_index(self) -> int:
111
- logger.debug("==== Current cut item ====")
112
- logger.debug(f"{self.items[-1]}")
113
- logger.debug("==========================")
114
  return self.items[-1].buffer_index()
115
 
116
  def is_end_sentence(self) ->bool:
@@ -134,7 +135,9 @@ class TranscriptHistory:
134
  def lastest_chunk(self):
135
  """获取最后一个片段"""
136
  return self.history[-1]
137
-
 
 
138
 
139
  class TranscriptBuffer:
140
  """
@@ -147,10 +150,10 @@ class TranscriptBuffer:
147
  """
148
 
149
  def __init__(self):
150
- self._segments: List[str] = [] # 确认的完整段落
151
  self._sentences: List[str] = [] # 当前段落中的短句
152
  self._buffer: str = "" # 当前缓冲中的文本
153
- self._current_seg_id: int =0
154
 
155
  def get_seg_id(self) -> int:
156
  return self._current_seg_id
@@ -176,26 +179,31 @@ class TranscriptBuffer:
176
  end_of_sentence: 是否为句子结尾(如检测到句号)
177
  """
178
  if self._sentences:
179
- self._segments.append("".join(self._sentences))
180
- self._current_seg_id += 1
181
  self._sentences.clear()
182
 
183
 
184
  def update_and_commit(self, stable_string: str, remaining_string:str, is_end_sentence=False):
185
- self.update_pending_text(stable_string)
 
186
  if is_end_sentence:
187
- self.commit_paragraph(end_of_sentence=True)
 
 
188
  self.update_pending_text(remaining_string)
189
- # if len() >=20
190
- return True
 
 
191
  else:
 
192
  self.commit_line()
193
  self.update_pending_text(remaining_string)
194
- return False
195
 
196
 
197
  @property
198
- def paragraph(self) -> str:
199
  """当前短句组合"""
200
  return "".join(self._sentences)
201
 
@@ -211,7 +219,7 @@ class TranscriptBuffer:
211
 
212
  @property
213
  def current_not_commit_text(self) -> str:
214
- return self.paragraph + self.pending_text
215
 
216
 
217
 
@@ -230,15 +238,14 @@ class TranscriptStabilityAnalyzer:
230
 
231
  prev = self._transcript_history.previous_chunk()
232
  self._transcript_buffer.update_pending_text(current.join())
233
- if not prev:
234
  yield TranscriptResult(
235
  context=self._transcript_buffer.current_not_commit_text,
236
  seg_id=self._transcript_buffer.get_seg_id()
237
  )
238
  return
239
-
240
-
241
 
 
242
  if buffer_duration <= 12:
243
  yield from self._handle_short_buffer(current, prev)
244
  else:
@@ -248,16 +255,24 @@ class TranscriptStabilityAnalyzer:
248
  def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]:
249
  curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
250
  prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)
251
- core = curr_first.compare(prev_first)
252
- has_punctuation = curr_first.has_punctuation()
253
- logger.debug(f"Compare with rev score:{core},is end :{curr_first.is_end_sentence()}, has_punctuation: {has_punctuation}, current_first: {curr_first.join()},")
254
- if core >= 0.8:
255
- yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence())
256
- else:
257
- yield TranscriptResult(
258
- seg_id=self._transcript_buffer.get_seg_id(),
259
- context=self._transcript_buffer.current_not_commit_text
260
- )
 
 
 
 
 
 
 
 
261
 
262
 
263
  def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]:
@@ -279,10 +294,8 @@ class TranscriptStabilityAnalyzer:
279
  def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]:
280
  stable_str = stable_chunk.join() if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk)
281
  remaining_str = self.merge_chunks(remaining_chunks)
282
-
283
  frame_cut_index = stable_chunk[-1].get_buffer_index() if isinstance(stable_chunk, list) else stable_chunk.get_buffer_index()
284
- logger.debug(f"Current cut index: {frame_cut_index}, Stable string: {stable_str}, Remaining_str:{remaining_str}")
285
-
286
  prev_seg_id = self._transcript_buffer.get_seg_id()
287
  commit_paragraph = self._transcript_buffer.update_and_commit(stable_str, remaining_str, is_end_sentence)
288
  logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")
@@ -295,10 +308,15 @@ class TranscriptStabilityAnalyzer:
295
  context=self._transcript_buffer.latest_paragraph,
296
  is_end_sentence=True
297
  )
298
- # 如果还有挂起的文本
299
- if (current_not_commit_text := self._transcript_buffer.current_not_commit_text.strip()):
 
 
 
 
300
  yield TranscriptResult(
301
  seg_id=self._transcript_buffer.get_seg_id(),
302
  cut_index=frame_cut_index,
303
- context=current_not_commit_text
304
  )
 
 
8
  from config import SENTENCE_END_MARKERS, ALL_MARKERS,SENTENCE_END_PATTERN,REGEX_MARKERS, PAUSEE_END_PATTERN,SAMPLE_RATE
9
  import numpy as np
10
  from enum import Enum
11
+ from itertools import chain
12
  logger = logging.getLogger("TranscriptionStrategy")
13
 
14
 
 
84
  def get_split_first_rest(self, mode: SplitMode):
85
  chunks = self.split_by(mode)
86
  fisrt_chunk = chunks[0] if chunks else self
87
+ rest_chunks = chunks[1:] if chunks else None
88
  return fisrt_chunk, rest_chunks
89
 
90
  def puncation_numbers(self) -> int:
91
  """计算片段中标点符号的数量"""
92
+ return sum(1 for seg in self.items if seg.is_punctuation())
93
 
94
  def length(self) -> int:
95
  """返回片段列表的长度"""
 
103
  """比较当前片段与另一个片段的相似度"""
104
  if not chunk:
105
  return 0
106
+
107
+ score = self._calculate_similarity(self.join(), chunk.join())
108
+ logger.debug(f"Compare: {self.join()} vs {chunk.join()} : {score}")
109
+ return score
110
 
111
  def has_punctuation(self) -> bool:
112
  return any(seg.is_punctuation() for seg in self.items)
113
 
114
  def get_buffer_index(self) -> int:
 
 
 
115
  return self.items[-1].buffer_index()
116
 
117
  def is_end_sentence(self) ->bool:
 
135
  def lastest_chunk(self):
136
  """获取最后一个片段"""
137
  return self.history[-1]
138
+
139
+ def clear(self):
140
+ self.history.clear()
141
 
142
  class TranscriptBuffer:
143
  """
 
150
  """
151
 
152
  def __init__(self):
153
+ self._segments: List[str] = collections.deque(maxlen=2) # 确认的完整段落
154
  self._sentences: List[str] = [] # 当前段落中的短句
155
  self._buffer: str = "" # 当前缓冲中的文本
156
+ self._current_seg_id: int = 0
157
 
158
  def get_seg_id(self) -> int:
159
  return self._current_seg_id
 
179
  end_of_sentence: 是否为句子结尾(如检测到句号)
180
  """
181
  if self._sentences:
182
+ self._segments.appendleft("".join(self._sentences))
 
183
  self._sentences.clear()
184
 
185
 
186
  def update_and_commit(self, stable_string: str, remaining_string:str, is_end_sentence=False):
187
+
188
+ logger.debug(f"{self.__dict__}")
189
  if is_end_sentence:
190
+ self.update_pending_text(stable_string)
191
+ self.commit_line()
192
+ current_text_len = len(self.current_not_commit_text)
193
  self.update_pending_text(remaining_string)
194
+ if current_text_len >=20:
195
+ self.commit_paragraph()
196
+ self._current_seg_id += 1
197
+ return True
198
  else:
199
+ self.update_pending_text(stable_string)
200
  self.commit_line()
201
  self.update_pending_text(remaining_string)
202
+ return False
203
 
204
 
205
  @property
206
+ def un_commit_paragraph(self) -> str:
207
  """当前短句组合"""
208
  return "".join(self._sentences)
209
 
 
219
 
220
  @property
221
  def current_not_commit_text(self) -> str:
222
+ return self.un_commit_paragraph + self.pending_text
223
 
224
 
225
 
 
238
 
239
  prev = self._transcript_history.previous_chunk()
240
  self._transcript_buffer.update_pending_text(current.join())
241
+ if not prev: # 如果没有历史记录 那么就说明是新的语句 直接输出就行
242
  yield TranscriptResult(
243
  context=self._transcript_buffer.current_not_commit_text,
244
  seg_id=self._transcript_buffer.get_seg_id()
245
  )
246
  return
 
 
247
 
248
+ # yield from self._handle_short_buffer(current, prev)
249
  if buffer_duration <= 12:
250
  yield from self._handle_short_buffer(current, prev)
251
  else:
 
255
  def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]:
256
  curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
257
  prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)
258
+
259
+ # logger.debug("==== Current cut item ====")
260
+ # logger.debug(f"{curr.join()} ")
261
+ # logger.debug(f"{prev.join()}")
262
+ # logger.debug("==========================")
263
+
264
+ if curr_first and prev_first:
265
+
266
+ core = curr_first.compare(prev_first)
267
+ # has_punctuation = curr_first.has_punctuation()
268
+ if core >= 0.8:
269
+ yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence())
270
+ return
271
+
272
+ yield TranscriptResult(
273
+ seg_id=self._transcript_buffer.get_seg_id(),
274
+ context=self._transcript_buffer.current_not_commit_text
275
+ )
276
 
277
 
278
  def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]:
 
294
  def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]:
295
  stable_str = stable_chunk.join() if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk)
296
  remaining_str = self.merge_chunks(remaining_chunks)
 
297
  frame_cut_index = stable_chunk[-1].get_buffer_index() if isinstance(stable_chunk, list) else stable_chunk.get_buffer_index()
298
+
 
299
  prev_seg_id = self._transcript_buffer.get_seg_id()
300
  commit_paragraph = self._transcript_buffer.update_and_commit(stable_str, remaining_str, is_end_sentence)
301
  logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")
 
308
  context=self._transcript_buffer.latest_paragraph,
309
  is_end_sentence=True
310
  )
311
+ yield TranscriptResult(
312
+ seg_id=self._transcript_buffer.get_seg_id(),
313
+ # cut_index=frame_cut_index,
314
+ context=self._transcript_buffer.pending_text,
315
+ )
316
+ else:
317
  yield TranscriptResult(
318
  seg_id=self._transcript_buffer.get_seg_id(),
319
  cut_index=frame_cut_index,
320
+ context=self._transcript_buffer.current_not_commit_text,
321
  )
322
+
transcribe/utils.py CHANGED
@@ -8,6 +8,7 @@ from scipy.io.wavfile import write
8
 
9
  import av
10
  def log_block(key: str, value, unit=''):
 
11
  """格式化输出日志内容"""
12
  key_fmt = f"[ {key.ljust(25)}]" # 左对齐填充
13
  val_fmt = f"{value} {unit}".strip()
 
8
 
9
  import av
10
  def log_block(key: str, value, unit=''):
11
+ return
12
  """格式化输出日志内容"""
13
  key_fmt = f"[ {key.ljust(25)}]" # 左对齐填充
14
  val_fmt = f"{value} {unit}".strip()
transcribe/whisper_llm_serve.py CHANGED
@@ -193,7 +193,6 @@ class WhisperTranscriptionService(ServeClientBase):
193
 
194
  # 处理转录结果并发送到客户端
195
  for result in self._process_transcription_results(segments, audio_buffer):
196
- print(result)
197
  self._send_result_to_client(result)
198
 
199
  # except Exception as e:
@@ -217,7 +216,7 @@ class WhisperTranscriptionService(ServeClientBase):
217
  self._update_audio_buffer(cut_index)
218
 
219
  translated_context = self._translate_text(ana_result.context)
220
- log_block("Translated context:", translated_context)
221
  yield TransResult(
222
  seg_id=ana_result.seg_id,
223
  context=ana_result.context,
 
193
 
194
  # 处理转录结果并发送到客户端
195
  for result in self._process_transcription_results(segments, audio_buffer):
 
196
  self._send_result_to_client(result)
197
 
198
  # except Exception as e:
 
216
  self._update_audio_buffer(cut_index)
217
 
218
  translated_context = self._translate_text(ana_result.context)
219
+
220
  yield TransResult(
221
  seg_id=ana_result.seg_id,
222
  context=ana_result.context,