daihui.zhang commited on
Commit
e046f39
·
1 Parent(s): 6f13b8c

add buffer clip via sequence strategy

Browse files
transcribe/strategy.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from logging import getLogger
4
+ from difflib import SequenceMatcher
5
+ import collections
6
+ import config
7
+ import numpy as np
8
+ from itertools import chain
9
+
10
+ logger = getLogger("Stragegy")
11
+
12
+ class TripleTextBuffer:
13
+ def __init__(self, size=2):
14
+ self.history = collections.deque(maxlen=size)
15
+
16
+ def add_entry(self, text, index):
17
+ """
18
+ text: 文本
19
+ index: 当前buffer的相对下标 数组索引
20
+ """
21
+ self.history.append((text, index))
22
+
23
+
24
+ def get_final_index(self, similarity_threshold=0.7):
25
+ """根据文本变化,返回可靠的标点的buffer的位置下标"""
26
+ if len(self.history) < 2:
27
+ return None
28
+
29
+ # 获取三次的文本
30
+ text1, _ = self.history[0]
31
+ text2, idx2 = self.history[1]
32
+ # text3, idx3 = self.history[2]
33
+
34
+ # 计算变化程度
35
+ sim_12 = self.text_similarity(text1, text2)
36
+ # print("比较: ", text1, text2," => ", sim_12)
37
+ # sim_23 = self.text_similarity(text2, text3)
38
+ if sim_12 >= similarity_threshold:
39
+ self.history.clear()
40
+ return idx2
41
+ return None
42
+
43
+ @staticmethod
44
+ def text_similarity(text1, text2):
45
+ return SequenceMatcher(None, text1, text2).ratio()
46
+
47
+
48
+
49
+ class SegmentManager:
50
+ def __init__(self) -> None:
51
+ self._commited_segments = [] # 确定后的段落
52
+ self._commited_short_sentences = [] # 确定后的序列
53
+ self._temp_string = "" # 存储当前临时的文本字符串,直到以句号结尾
54
+
55
+ def handle(self, string):
56
+ self._temp_string = string
57
+ return self
58
+
59
+ @property
60
+ def short_sentence(self) -> str:
61
+ return "".join(self._commited_short_sentences)
62
+
63
+ @property
64
+ def segment(self):
65
+ return self._commited_segments[-1] if len(self._commited_segments) > 0 else ""
66
+
67
+ def get_seg_id(self):
68
+ return len(self._commited_segments)
69
+
70
+ @property
71
+ def string(self):
72
+ return self._temp_string
73
+
74
+
75
+ def commit_short_sentence(self):
76
+ """将临时字符串 提交到临时短句"""
77
+ self._commited_short_sentences.append(self._temp_string)
78
+ self._temp_string = ""
79
+
80
+ def commit_segment(self):
81
+ """将短句 合并 到长句中"""
82
+ self._commited_segments.append(self.short_sentence)
83
+ self._commited_short_sentences = []
84
+
85
+ def commit(self, is_end_sentence=False):
86
+ """
87
+ 当需要切掉的音频部分的时候,将句子提交到短句队列中,并移除临时字符串
88
+ 当完成一个整句的时候提交到段落中
89
+ """
90
+ self.commit_short_sentence()
91
+ if is_end_sentence:
92
+ self.commit_segment()
93
+
94
+ def segement_merge(segments):
95
+ """根据标点符号分整句"""
96
+ sequences = []
97
+ temp_seq = []
98
+
99
+ for seg in segments:
100
+ temp_seq.append(seg)
101
+ if any([mk in seg.text for mk in config.SENTENCE_END_MARKERS]):
102
+ sequences.append(temp_seq.copy())
103
+ temp_seq = []
104
+ if temp_seq:
105
+ sequences.append(temp_seq)
106
+ return sequences
107
+
108
+ def segments_split(segments, audio_buffer: np.ndarray, sample_rate=16000):
109
+ """根据左边第一个标点符号来将序列拆分成 观察段 和 剩余部分"""
110
+ left_watch_sequences = []
111
+ left_watch_idx = 0
112
+ right_watch_sequences = []
113
+ is_end = False
114
+
115
+ if (len(audio_buffer) / sample_rate) < 12:
116
+ # 低于12s 使用短句符号比如逗号作为判断依据
117
+ markers = config.PAUSE_END_MARKERS
118
+ is_end = False
119
+
120
+ for idx, seg in enumerate(segments):
121
+ left_watch_sequences.append(seg)
122
+ if seg.text in markers:
123
+ seg_index = int(seg.t1 / 100 * sample_rate)
124
+ rest_buffer_duration = (len(audio_buffer) - seg_index) / sample_rate
125
+ # is_end = any(i in seg.text for i in config.SENTENCE_END_MARKERS)
126
+ right_watch_sequences = segments[min(idx+1, len(segments)):]
127
+ if rest_buffer_duration >= 1.5:
128
+ left_watch_idx = seg_index
129
+ break
130
+ return left_watch_idx, left_watch_sequences, right_watch_sequences, is_end
131
+
132
+
133
+ def sequences_split(segments, audio_buffer: np.ndarray, sample_rate=16000):
134
+ # 长句 保留最后两句即可
135
+ left_watch_sequences = []
136
+ right_watch_sequences = []
137
+ left_watch_idx = 0
138
+ is_end = False
139
+ sequences = segement_merge(segments)
140
+
141
+ if len(sequences) > 2:
142
+ logger.info(f"buffer clip via sequence, current length: {len(sequences)}")
143
+ is_end = True
144
+ left_watch_sequences = chain(*sequences[:-2])
145
+ right_watch_sequences = chain(*sequences[-2:])
146
+ last_sequence_segment = sequences[-3]
147
+ last_segment = last_sequence_segment[-1]
148
+ left_watch_idx = int(last_segment.t1 / 100 * sample_rate)
149
+ return left_watch_idx, left_watch_sequences, right_watch_sequences, is_end
150
+
151
+
152
+
transcribe/whisper_llm_serve.py CHANGED
@@ -1,119 +1,31 @@
1
 
2
 
3
  import soundfile
4
- from concurrent.futures import ProcessPoolExecutor as PPool
5
  import multiprocessing as mp
6
  import numpy as np
7
  from logging import getLogger
8
- from difflib import SequenceMatcher
9
- import collections
10
  import config
11
  import time
12
  import json
13
  import threading
14
- from functools import partial
15
  from .server import ServeClientBase
16
- from .translator import QwenTranslator
17
- from .vad import VoiceActivityDetector
18
- from pywhispercpp.model import Model
19
- from queue import Queue
20
  from scipy.io.wavfile import write
21
  from api_model import TransResult, Message
22
  from .utils import log_block
23
  from .translatepipes import TranslatePipes
 
24
 
25
  logger = getLogger("TranslatorApp")
26
 
27
  translate_pipes = TranslatePipes()
28
  translate_pipes.wait_ready()
29
-
30
  logger.info("Pipeline is ready.")
31
 
32
  def save_to_wave(filename, data:np.ndarray, sample_rate=16000):
33
  write(filename, sample_rate, data)
34
 
35
- class TripleTextBuffer:
36
- def __init__(self, size=2):
37
- self.history = collections.deque(maxlen=size)
38
-
39
- def add_entry(self, text, index):
40
- """
41
- text: 文本
42
- index: 当前buffer的相对下标 数组索引
43
- """
44
- self.history.append((text, index))
45
-
46
-
47
- def get_final_index(self, similarity_threshold=0.7):
48
- """根据文本变化,返回可靠的标点的buffer的位置下标"""
49
- if len(self.history) < 2:
50
- return None
51
-
52
- # 获取三次的文本
53
- text1, _ = self.history[0]
54
- text2, idx2 = self.history[1]
55
- # text3, idx3 = self.history[2]
56
-
57
- # 计算变化程度
58
- sim_12 = self.text_similarity(text1, text2)
59
- # print("比较: ", text1, text2," => ", sim_12)
60
- # sim_23 = self.text_similarity(text2, text3)
61
- if sim_12 >= similarity_threshold:
62
- self.history.clear()
63
- return idx2
64
- return None
65
-
66
- @staticmethod
67
- def text_similarity(text1, text2):
68
- return SequenceMatcher(None, text1, text2).ratio()
69
-
70
-
71
-
72
- class SegmentManager:
73
- def __init__(self) -> None:
74
- self._commited_segments = [] # 确定后的段落
75
- self._commited_short_sentences = [] # 确定后的序列
76
- self._temp_string = "" # 存储当前临时的文本字符串,直到以句号结尾
77
-
78
- def handle(self, string):
79
- self._temp_string = string
80
- return self
81
-
82
- @property
83
- def short_sentence(self) -> str:
84
- return "".join(self._commited_short_sentences)
85
-
86
- @property
87
- def segment(self):
88
- return self._commited_segments[-1] if len(self._commited_segments) > 0 else ""
89
-
90
- def get_seg_id(self):
91
- return len(self._commited_segments)
92
-
93
- @property
94
- def string(self):
95
- return self._temp_string
96
-
97
-
98
- def commit_short_sentence(self):
99
- """将临时字符串 提交到临时短句"""
100
- self._commited_short_sentences.append(self._temp_string)
101
- self._temp_string = ""
102
-
103
- def commit_segment(self):
104
- """将短句 合并 到长句中"""
105
- self._commited_segments.append(self.short_sentence)
106
- self._commited_short_sentences = []
107
-
108
- def commit(self, is_end_sentence=False):
109
- """
110
- 当需要切掉的音频部分的时候,将句子提交到短句队列中,并移除临时字符串
111
- 当完成一个整句的时候提交到段落中
112
- """
113
- self.commit_short_sentence()
114
- if is_end_sentence:
115
- self.commit_segment()
116
-
117
 
118
 
119
  class PyWhiperCppServe(ServeClientBase):
@@ -127,15 +39,11 @@ class PyWhiperCppServe(ServeClientBase):
127
  self._text_buffer = TripleTextBuffer()
128
  # 存储转录数据
129
  self._segment_manager = SegmentManager()
130
- self._ready_state = mp.Event()
131
 
132
  self.lock = threading.Lock()
133
  self.frames_np = None
134
  self.sample_rate = 16000
135
- # self._audio_queue = Queue()
136
- # 进程初始化后再开始收音频
137
-
138
- logger.info('Create a process to process audio.')
139
  self.send_ready_state()
140
 
141
  self.trans_thread = threading.Thread(target=self.speech_to_text)
@@ -143,7 +51,7 @@ class PyWhiperCppServe(ServeClientBase):
143
  self.trans_thread.start()
144
 
145
  def send_ready_state(self):
146
- # self._ready_state.wait()
147
  self.websocket.send(json.dumps({
148
  "uid": self.client_uid,
149
  "message": self.SERVER_READY,
@@ -193,40 +101,14 @@ class PyWhiperCppServe(ServeClientBase):
193
  ret = translate_pipes.translate(text, self.language, self.dst_lang)
194
  log_block("LLM translate time", f"{(time.perf_counter() - start_time):.3f}", "s")
195
  return ret.translate_content
196
-
197
- def _segments_split(self, segments, audio_buffer: np.ndarray):
198
- """根据左边第一个标点符号来将序列拆分成 观察段 和 剩余部分"""
199
- left_watch_sequences = []
200
- left_watch_idx = 0
201
- right_watch_sequences = []
202
- is_end = False
203
-
204
- if (len(audio_buffer) / self.sample_rate) < 10:
205
- # 低于10s 使用短句符号比如逗号作为判断依据
206
- markers = config.PAUSE_END_MARKERS
207
- is_end = False
208
- else:
209
- # 使用句号 长句结尾符号作为判断
210
- markers = config.SENTENCE_END_MARKERS
211
- is_end = True
212
 
213
- for idx, seg in enumerate(segments):
214
- left_watch_sequences.append(seg)
215
- if seg.text in markers:
216
- seg_index = int(seg.t1 / 100 * self.sample_rate)
217
- rest_buffer_duration = (len(audio_buffer) - seg_index) / self.sample_rate
218
- # is_end = any(i in seg.text for i in config.SENTENCE_END_MARKERS)
219
- right_watch_sequences = segments[min(idx+1, len(segments)):]
220
- if rest_buffer_duration >= 1.5:
221
- left_watch_idx = seg_index
222
- break
223
- return left_watch_idx, left_watch_sequences, right_watch_sequences, is_end
224
 
225
  def analysis_segments(self, segments, audio_buffer: np.ndarray):
226
  # 找到第一个标点符号作为锚点 左边为确认段,右边为观察段,
227
  # 当左边确认后,右边段才会进入观察
228
  # 当左边确认后,会从缓冲区中删除对应的buffer,减少下次输入的数据量
229
- left_watch_idx, left_watch_sequences, right_watch_sequences, is_end_sentence = self._segments_split(segments, audio_buffer)
230
  left_watch_string = "".join(i.text for i in left_watch_sequences)
231
  right_watch_string = "".join(i.text for i in right_watch_sequences)
232
 
@@ -236,6 +118,14 @@ class PyWhiperCppServe(ServeClientBase):
236
  audio_cut_index = self._text_buffer.get_final_index()
237
  if audio_cut_index:
238
  return audio_cut_index, left_watch_string, right_watch_string, is_end_sentence
 
 
 
 
 
 
 
 
239
  return None, left_watch_string, right_watch_string, is_end_sentence
240
 
241
  def speech_to_text(self):
@@ -253,15 +143,15 @@ class PyWhiperCppServe(ServeClientBase):
253
  # c+= 1
254
  # name = f"dev-{c}.wav"
255
  # save_to_wave(name, audio_buffer)
256
- try:
257
- logger.info(f"Audio buffer length: {len(audio_buffer) / self.sample_rate:.2f}s")
258
- segments = self.transcribe_audio(audio_buffer)
259
- for tran_result in self.handle_transcription_output(segments, audio_buffer):
260
- self.send_to_client(tran_result)
261
- except KeyboardInterrupt:
262
- break
263
- except Exception as e:
264
- logger.error(f"{e}")
265
 
266
  def handle_transcription_output(self, segments, audio_buffer):
267
  texts = "".join(i.text for i in segments)
 
1
 
2
 
3
  import soundfile
 
4
  import multiprocessing as mp
5
  import numpy as np
6
  from logging import getLogger
7
+
 
8
  import config
9
  import time
10
  import json
11
  import threading
 
12
  from .server import ServeClientBase
13
+
 
 
 
14
  from scipy.io.wavfile import write
15
  from api_model import TransResult, Message
16
  from .utils import log_block
17
  from .translatepipes import TranslatePipes
18
+ from .strategy import TripleTextBuffer, SegmentManager, segments_split, sequences_split
19
 
20
  logger = getLogger("TranslatorApp")
21
 
22
  translate_pipes = TranslatePipes()
23
  translate_pipes.wait_ready()
 
24
  logger.info("Pipeline is ready.")
25
 
26
  def save_to_wave(filename, data:np.ndarray, sample_rate=16000):
27
  write(filename, sample_rate, data)
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  class PyWhiperCppServe(ServeClientBase):
 
39
  self._text_buffer = TripleTextBuffer()
40
  # 存储转录数据
41
  self._segment_manager = SegmentManager()
 
42
 
43
  self.lock = threading.Lock()
44
  self.frames_np = None
45
  self.sample_rate = 16000
46
+
 
 
 
47
  self.send_ready_state()
48
 
49
  self.trans_thread = threading.Thread(target=self.speech_to_text)
 
51
  self.trans_thread.start()
52
 
53
  def send_ready_state(self):
54
+
55
  self.websocket.send(json.dumps({
56
  "uid": self.client_uid,
57
  "message": self.SERVER_READY,
 
101
  ret = translate_pipes.translate(text, self.language, self.dst_lang)
102
  log_block("LLM translate time", f"{(time.perf_counter() - start_time):.3f}", "s")
103
  return ret.translate_content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
+
 
 
 
 
 
 
 
 
 
 
106
 
107
  def analysis_segments(self, segments, audio_buffer: np.ndarray):
108
  # 找到第一个标点符号作为锚点 左边为确认段,右边为观察段,
109
  # 当左边确认后,右边段才会进入观察
110
  # 当左边确认后,会从缓冲区中删除对应的buffer,减少下次输入的数据量
111
+ left_watch_idx, left_watch_sequences, right_watch_sequences, is_end_sentence = segments_split(segments, audio_buffer)
112
  left_watch_string = "".join(i.text for i in left_watch_sequences)
113
  right_watch_string = "".join(i.text for i in right_watch_sequences)
114
 
 
118
  audio_cut_index = self._text_buffer.get_final_index()
119
  if audio_cut_index:
120
  return audio_cut_index, left_watch_string, right_watch_string, is_end_sentence
121
+
122
+ # 整句消除 后两句之前的内容
123
+ left_watch_idx, left_watch_sequences, right_watch_sequences, is_end_sentence = sequences_split(segments, audio_buffer)
124
+ left_watch_string = "".join(i.text for i in left_watch_sequences)
125
+ right_watch_string = "".join(i.text for i in right_watch_sequences)
126
+ if left_watch_idx != 0:
127
+ return left_watch_idx, left_watch_string, right_watch_string, is_end_sentence
128
+
129
  return None, left_watch_string, right_watch_string, is_end_sentence
130
 
131
  def speech_to_text(self):
 
143
  # c+= 1
144
  # name = f"dev-{c}.wav"
145
  # save_to_wave(name, audio_buffer)
146
+ # try:
147
+ logger.info(f"Audio buffer length: {len(audio_buffer) / self.sample_rate:.2f}s")
148
+ segments = self.transcribe_audio(audio_buffer)
149
+ for tran_result in self.handle_transcription_output(segments, audio_buffer):
150
+ self.send_to_client(tran_result)
151
+ # except KeyboardInterrupt:
152
+ # break
153
+ # except Exception as e:
154
+ # logger.error(f"{e}")
155
 
156
  def handle_transcription_output(self, segments, audio_buffer):
157
  texts = "".join(i.text for i in segments)