Xin Zhang committed
Commit 793447d · 1 Parent(s): 484b9cf

[fix]: test dynamic vad.

transcribe/helpers/vad_dynamic.py ADDED
@@ -0,0 +1,430 @@
+ from copy import deepcopy
+ from config import VAD_MODEL_PATH
+ # from silero_vad import load_silero_vad
+ import numpy as np
+ import onnxruntime
+
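+ # Thin ONNX Runtime wrapper around the Silero VAD model: keeps the recurrent
+ # state tensor and the trailing audio context between successive windows.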
+ class OnnxWrapper():
+
+     def __init__(self, path, force_onnx_cpu=False):
+         opts = onnxruntime.SessionOptions()
+         opts.inter_op_num_threads = 1
+         opts.intra_op_num_threads = 1
+
+         if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+             self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
+         else:
+             self.session = onnxruntime.InferenceSession(path, sess_options=opts)
+
+         self.reset_states()
+         self.sample_rates = [16000]
+
+     def _validate_input(self, x: np.ndarray, sr: int):
+         if x.ndim == 1:
+             x = x[None]
+         if x.ndim > 2:
+             raise ValueError(f"Too many dimensions for input audio chunk: {x.ndim}")
+
+         if sr != 16000 and (sr % 16000 == 0):
+             step = sr // 16000
+             x = x[:, ::step]
+             sr = 16000
+
+         if sr not in self.sample_rates:
+             raise ValueError(f"Supported sampling rates: {self.sample_rates} (or a multiple of 16000)")
+         if sr / x.shape[1] > 31.25:
+             raise ValueError("Input audio chunk is too short")
+
+         return x, sr
+
+     def reset_states(self, batch_size=1):
+         self._state = np.zeros((2, batch_size, 128)).astype(np.float32)
+         self._context = np.zeros(0)
+         self._last_sr = 0
+         self._last_batch_size = 0
+
+     def __call__(self, x, sr: int):
+
+         x, sr = self._validate_input(x, sr)
+         num_samples = 512 if sr == 16000 else 256
+
+         if x.shape[-1] != num_samples:
+             raise ValueError(
+                 f"Provided number of samples is {x.shape[-1]} (supported values: 256 for 8000 Hz, 512 for 16000 Hz)")
+
+         batch_size = x.shape[0]
+         context_size = 64 if sr == 16000 else 32
+
+         if not self._last_batch_size:
+             self.reset_states(batch_size)
+         if (self._last_sr) and (self._last_sr != sr):
+             self.reset_states(batch_size)
+         if (self._last_batch_size) and (self._last_batch_size != batch_size):
+             self.reset_states(batch_size)
+
+         if not len(self._context):
+             self._context = np.zeros((batch_size, context_size)).astype(np.float32)
+
+         x = np.concatenate([self._context, x], axis=1)
+         if sr in [8000, 16000]:
+             ort_inputs = {'input': x, 'state': self._state, 'sr': np.array(sr, dtype='int64')}
+             ort_outs = self.session.run(None, ort_inputs)
+             out, state = ort_outs
+             self._state = state
+         else:
+             raise ValueError(f"Unsupported sampling rate: {sr}")
+
+         self._context = x[..., -context_size:]
+         self._last_sr = sr
+         self._last_batch_size = batch_size
+
+         # out = torch.from_numpy(out)
+         return out
+
+     def audio_forward(self, audio: np.ndarray, sr: int):
+         outs = []
+         x, sr = self._validate_input(audio, sr)
+         self.reset_states()
+         num_samples = 512 if sr == 16000 else 256
+
+         if x.shape[1] % num_samples:
+             pad_num = num_samples - (x.shape[1] % num_samples)
+             x = np.pad(x, ((0, 0), (0, pad_num)), 'constant', constant_values=(0.0, 0.0))
+
+         for i in range(0, x.shape[1], num_samples):
+             wavs_batch = x[:, i:i + num_samples]
+             out_chunk = self.__call__(wavs_batch, sr)
+             outs.append(out_chunk)
+
+         stacked = np.concatenate(outs, axis=1)
+         return stacked
+
+
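+ # Streaming iterator over 512-sample windows. Emits {'start': ...} / {'end': ...}
+ # events; once continuous speech exceeds long_speech_threshold_s, the required
+ # trailing silence shrinks by adjusted_min_silence_factor so long utterances
+ # are cut sooner.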
+ class VADIteratorOnnx:
+     def __init__(self,
+                  threshold: float = 0.5,
+                  sampling_rate: int = 16000,
+                  min_silence_duration_ms: int = 100,
+                  max_speech_duration_s: float = float('inf'),
+                  long_speech_threshold_s: float = 6.0,  # new: long-speech threshold (seconds)
+                  adjusted_min_silence_factor: float = 0.5  # new: factor applied to the min-silence duration
+                  ):
+         self.model = OnnxWrapper(VAD_MODEL_PATH, True)
+         self.threshold = threshold
+         self.sampling_rate = sampling_rate
+
+         if sampling_rate not in [8000, 16000]:
+             raise ValueError('VADIterator does not support sampling rates other than [8000, 16000]')
+
+         self._original_min_silence_samples = sampling_rate * min_silence_duration_ms / 1000  # keep the original value
+         self.min_silence_samples = self._original_min_silence_samples  # value currently in use
+         self.adjusted_min_silence_samples = self._original_min_silence_samples * adjusted_min_silence_factor  # precomputed adjusted value
+         self.long_speech_threshold_samples = sampling_rate * long_speech_threshold_s  # long-speech threshold in samples
+
+         self.max_speech_samples = int(sampling_rate * max_speech_duration_s)
+         # self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+         self.reset_states()
+
+     def reset_states(self):
+
+         self.model.reset_states()
+         self.triggered = False
+         self.temp_end = 0
+         self.current_sample = 0
+         self.start = 0
+         self.speech_start_sample = 0  # new: sample index where the current speech run started
+         self.min_silence_samples = self._original_min_silence_samples  # restore the original value
+
+     def __call__(self, x: np.ndarray, return_seconds=False):
+         """
+         x: np.ndarray
+             audio chunk (see examples in repo)
+
+         return_seconds: bool (default - False)
+             whether to return timestamps in seconds (default - samples)
+         """
+
+         window_size_samples = 512 if self.sampling_rate == 16000 else 256
+         x = x[:window_size_samples]
+         if len(x) < window_size_samples:
+             x = np.pad(x, (0, window_size_samples - len(x)), 'constant', constant_values=0.0)
+
+         self.current_sample += window_size_samples
+
+         speech_prob = self.model(x, self.sampling_rate)[0, 0]
+         # print(f"{self.current_sample/self.sampling_rate:.2f}: {speech_prob}")
+
+         # --- dynamic adjustment ---
+         current_min_silence_samples_to_use = self._original_min_silence_samples
+         if self.triggered and self.speech_start_sample > 0:
+             current_speech_duration_samples = self.current_sample - self.speech_start_sample
+             if current_speech_duration_samples > self.long_speech_threshold_samples:
+                 # once continuous speech exceeds the threshold, use the adjusted (shorter) silence duration
+                 current_min_silence_samples_to_use = self.adjusted_min_silence_samples
+         # --- end dynamic adjustment ---
+
+         if (speech_prob >= self.threshold) and self.temp_end:
+             # speech resumed after a tentative silence; clear the tentative end point
+             self.temp_end = 0
+
+         if (speech_prob >= self.threshold) and not self.triggered:
+             # speech onset detected
+             self.triggered = True
+             speech_start = max(0, self.current_sample - window_size_samples)
+             self.start = speech_start
+             self.speech_start_sample = self.start  # record where speech started
+             # self.min_silence_samples = self._original_min_silence_samples  # reset in reset_states
+             return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
+
+         if (speech_prob >= self.threshold) and self.current_sample - self.start >= self.max_speech_samples:
+             # maximum speech length reached; force an end (if configured)
+             if self.temp_end:
+                 self.temp_end = 0
+             speech_end = self.current_sample  # use the current sample as the end point
+             self.triggered = False  # close the current segment
+             self.speech_start_sample = 0  # reset the speech-run start
+             # self.min_silence_samples = self._original_min_silence_samples  # reset in reset_states
+             # return the end event and reset start so a new segment can begin immediately
+             end_val = int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)
+             self.start = speech_end  # set start to the current end point; VadV2 resets start/end itself
+             return {'end': end_val}
+
+         if (speech_prob < self.threshold - 0.15) and self.triggered:
+             # possible silence detected
+             if not self.temp_end:
+                 self.temp_end = self.current_sample  # record the tentative end point
+             # judge against the (possibly adjusted) min-silence threshold computed above
+             if self.current_sample - self.temp_end < current_min_silence_samples_to_use:
+                 # silence not long enough yet; ignore
+                 return None
+             else:
+                 # silence long enough; confirm the end of speech
+                 speech_end = self.temp_end - window_size_samples  # the end point is the tentative end minus one window
+                 self.temp_end = 0
+                 self.triggered = False
+                 self.speech_start_sample = 0  # reset the speech-run start
+                 # self.min_silence_samples = self._original_min_silence_samples  # reset in reset_states
+                 return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
+
+         return None
+
+
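+ # Buffering layer on top of VADIteratorOnnx: accumulates raw audio, converts the
+ # iterator's sample positions to global offsets, pads segments with speech_pad_ms
+ # of context, and trims or drops leading silence to bound memory use.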
+ class VadV2:
+     def __init__(self,
+                  threshold: float = 0.5,
+                  sampling_rate: int = 16000,
+                  min_silence_duration_ms: int = 100,
+                  speech_pad_ms: int = 30,
+                  max_speech_duration_s: float = float('inf'),
+                  long_speech_threshold_s: float = 10.0,  # higher default so dynamic adjustment triggers less often
+                  adjusted_min_silence_factor: float = 0.6  # higher default so the adjustment is less aggressive
+                  ):
+         self.vad_iterator = VADIteratorOnnx(threshold, sampling_rate, min_silence_duration_ms, max_speech_duration_s,
+                                             long_speech_threshold_s, adjusted_min_silence_factor)
+         self.speech_pad_samples = int(sampling_rate * speech_pad_ms / 1000)
+         self.sampling_rate = sampling_rate
+         self.audio_buffer = np.array([], dtype=np.float32)
+         self.start = 0
+         self.end = 0
+         self.offset = 0
+         # checking that speech_pad_ms <= min_silence_duration_ms is good practice, but not enforced
+         # assert speech_pad_ms <= min_silence_duration_ms, "speech_pad_ms should be less than min_silence_duration_ms"
+         self.max_speech_samples = int(sampling_rate * max_speech_duration_s)
+
+         self.silence_chunk_size = 0
+         # silence threshold in windows, derived from the window size (roughly 2 seconds of silence)
+         self.silence_chunk_threshold = int(2.0 / (512 / self.sampling_rate))
+
+     def reset(self):
+         self.audio_buffer = np.array([], dtype=np.float32)
+         self.start = 0
+         self.end = 0
+         self.offset = 0
+         self.vad_iterator.reset_states()
+         self.silence_chunk_size = 0  # reset the silence counter
+
+     def __call__(self, x: np.ndarray = None):
+         if x is None:
+             # flush whatever audio remains in the buffer;
+             # flush when the VAD is still triggered, or a start was detected and the buffer is non-empty
+             if self.vad_iterator.triggered or (self.start > self.offset and len(self.audio_buffer) > 0):
+                 start_global = max(self.offset, self.start - self.speech_pad_samples)
+                 # the end point is the absolute end of the buffer
+                 end_global = self.offset + len(self.audio_buffer)
+
+                 # ensure start < end
+                 if start_global < end_global:
+                     start_ts = round(start_global / self.sampling_rate, 1)
+                     end_ts = round(end_global / self.sampling_rate, 1)
+
+                     # slice from the computed in-buffer index to the end of the buffer
+                     buffer_start_index = max(0, start_global - self.offset)
+                     audio_data = self.audio_buffer[buffer_start_index:]
+
+                     if len(audio_data) > 0:
+                         result = {
+                             "start": start_ts,
+                             "end": end_ts,
+                             "audio": audio_data,
+                         }
+                     else:
+                         result = None
+                 else:
+                     result = None  # start >= end: invalid segment
+             else:
+                 result = None  # no leftover audio to process
+             self.reset()  # reset state after flushing
+             return result
+
+         # append the new chunk to the buffer
+         self.audio_buffer = np.append(self.audio_buffer, deepcopy(x))
+
+         # run the VAD iterator on the new chunk
+         vad_result = self.vad_iterator(x)
+         if vad_result is not None:
+             self.silence_chunk_size = 0  # VAD activity: reset the silence counter
+             if 'start' in vad_result:
+                 # update start only when no segment is in progress
+                 # (self.start <= self.offset means the previous segment ended or never started)
+                 if self.start <= self.offset:
+                     self.start = vad_result['start'] + self.offset
+             if 'end' in vad_result:
+                 # update end only when a start has already been detected
+                 if self.start > self.offset:
+                     self.end = vad_result['end'] + self.offset
+         else:
+             # count silence only while the VAD is idle and no speech start is pending
+             if not self.vad_iterator.triggered and self.start <= self.offset:
+                 self.silence_chunk_size += 1
+
+         # --- buffer management ---
+         # 1. trim leading silence (when no speech start has been detected)
+         if self.start <= self.offset and not self.vad_iterator.triggered and len(self.audio_buffer) > self.speech_pad_samples:
+             # trim only when the VAD's internal state also confirms no speech
+             if self.vad_iterator.speech_start_sample == 0:
+                 clearable_length = len(self.audio_buffer) - self.speech_pad_samples
+                 self.offset += clearable_length
+                 self.audio_buffer = self.audio_buffer[clearable_length:]
+                 self.silence_chunk_size = 0  # reset the counter after trimming
+
+         # 2. drop the buffer after prolonged silence (when no speech start has been detected)
+         if self.start <= self.offset and not self.vad_iterator.triggered and self.silence_chunk_size >= self.silence_chunk_threshold:
+             clearable_length = len(self.audio_buffer)  # drop everything up to the current position
+             if clearable_length > 0:
+                 self.offset += clearable_length
+                 self.audio_buffer = np.array([], dtype=np.float32)  # empty the buffer
+             self.silence_chunk_size = 0  # reset the counter
+         # --- end buffer management ---
+
+         # --- segment extraction ---
+         segment_to_return = None
+         if self.end > self.start:
+             # a complete speech segment [start, end] was detected
+             start_global = max(self.offset, self.start - self.speech_pad_samples)
+             end_global = self.end + self.speech_pad_samples
+
+             # the extractable end point cannot exceed the end of the current buffer
+             effective_end_global = min(end_global, self.offset + len(self.audio_buffer))
+
+             # ensure start_global < effective_end_global
+             if start_global < effective_end_global:
+                 start_ts = round(start_global / self.sampling_rate, 1)
+                 # the timestamp uses the theoretical end_global
+                 end_ts = round(end_global / self.sampling_rate, 1)
+
+                 # compute the indices into the current audio_buffer
+                 buffer_start_index = max(0, start_global - self.offset)
+                 buffer_end_index = effective_end_global - self.offset
+
+                 if buffer_start_index < buffer_end_index:  # make sure the indices are valid
+                     audio_data = self.audio_buffer[buffer_start_index:buffer_end_index]
+
+                     # --- update buffer and offset ---
+                     # keep the data that follows the extracted segment
+                     keep_from_index = buffer_end_index
+
+                     if keep_from_index < len(self.audio_buffer):
+                         self.audio_buffer = self.audio_buffer[keep_from_index:]
+                         # *** key fix ***: the new offset is the global start of the retained buffer
+                         self.offset = effective_end_global
+                     else:
+                         # the extracted segment reached or passed the end of the buffer
+                         self.audio_buffer = np.array([], dtype=np.float32)
+                         self.offset = effective_end_global  # move the offset to the end of the buffer
+
+                     # reset start and end to look for the next segment,
+                     # beginning at the new offset
+                     self.start = self.offset
+                     self.end = self.offset
+
+                     segment_to_return = {
+                         "start": start_ts,
+                         "end": end_ts,
+                         "audio": audio_data,
+                     }
+                 else:
+                     # invalid indices, possibly caused by rapid start/end events or padding;
+                     # reset state carefully to avoid losing sync
+                     self.start = self.offset
+                     self.end = self.offset
+             else:
+                 # start_global >= effective_end_global: invalid, reset state
+                 self.start = self.offset
+                 self.end = self.offset
+
+         return segment_to_return
+
+
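+ # Facade used by the transcription service: maps prob_threshold / silence_s /
+ # cache_s onto VadV2 parameters and processes whole buffers in 512-sample chunks.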
+ class VadProcessor:
+     def __init__(
+             self,
+             prob_threshold=0.5,
+             silence_s=0.2,
+             cache_s=0.15,  # now mapped to VadV2's speech_pad_ms
+             sr=16000,
+             long_speech_threshold_s: float = 6.0,  # new: default long-speech threshold
+             adjusted_min_silence_factor: float = 0.5  # new: default adjustment factor
+     ):
+         self.prob_threshold = prob_threshold
+         # self.cache_s = cache_s  # cache_s is no longer used directly; it feeds speech_pad_ms instead
+         self.sr = sr
+         self.silence_s = silence_s  # feeds min_silence_duration_ms
+         self.speech_pad_s = cache_s  # interpret cache_s as the speech padding
+
+         # pass all parameters through to VadV2
+         self.vad = VadV2(
+             threshold=self.prob_threshold,
+             sampling_rate=self.sr,
+             min_silence_duration_ms=int(self.silence_s * 1000),
+             speech_pad_ms=int(self.speech_pad_s * 1000),
+             max_speech_duration_s=15,  # keep the original maximum-duration limit
+             long_speech_threshold_s=long_speech_threshold_s,  # forward the new parameter
+             adjusted_min_silence_factor=adjusted_min_silence_factor  # forward the new parameter
+         )
+
+     def process_audio(self, audio_buffer: np.ndarray):
+         audio = np.array([], np.float32)
+         chunk_size = 512  # chunk size the VAD model expects
+         for i in range(0, len(audio_buffer), chunk_size):
+             chunk = audio_buffer[i:i+chunk_size]
+             # if the last chunk is short, VADIteratorOnnx pads it internally
+             ret = self.vad(chunk)
+             if ret:
+                 audio = np.append(audio, ret['audio'])
+
+         # after processing, call vad(None) to flush any audio left in the buffer
+         final_ret = self.vad(None)
+         if final_ret:
+             audio = np.append(audio, final_ret['audio'])
+
+         return audio
+
+     # reset the VAD state so a VadProcessor instance can be reused
+     def reset(self):
+         self.vad.reset()
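
A minimal usage sketch for the new module (not from the commit itself; it assumes config.VAD_MODEL_PATH points to a valid Silero VAD ONNX export and that 16 kHz mono float32 audio is supplied):

    import numpy as np
    from transcribe.helpers.vad_dynamic import VadProcessor

    vad = VadProcessor(prob_threshold=0.5, silence_s=0.2, cache_s=0.15)

    # one second of silence around one second of audio; real speech would be used in practice
    silence = np.zeros(16000, dtype=np.float32)
    speech = (0.1 * np.random.randn(16000)).astype(np.float32)
    voiced = vad.process_audio(np.concatenate([silence, speech, silence]))

    print(f"kept {len(voiced)} of {3 * 16000} samples")
    vad.reset()  # optional here; the final flush in process_audio already resets state
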
transcribe/whisper_llm_serve.py CHANGED
@@ -16,6 +16,7 @@ from .translatepipes import TranslatePipes
  from .strategy import (
      TranscriptStabilityAnalyzer, TranscriptToken)
  from transcribe.helpers.vadprocessor import VadProcessor
+ # from transcribe.helpers.vad_dynamic import VadProcessor
  from transcribe.pipelines import MetaItem

  logger = getLogger("TranscriptionService")
@@ -59,9 +60,9 @@ class WhisperTranscriptionService:
          self.translate_thread = self._start_thread(self._transcription_processing_loop)
          self.frame_processing_thread = self._start_thread(self._frame_processing_loop)
          if language == "zh":
-             self._vad = VadProcessor(prob_threshold=0.8, silence_s=0.2, cache_s=0.15)
+             self._vad = VadProcessor(prob_threshold=0.5, silence_s=0.15, cache_s=0.12)
          else:
-             self._vad = VadProcessor(prob_threshold=0.7, silence_s=0.2, cache_s=0.15)
+             self._vad = VadProcessor(prob_threshold=0.5, silence_s=0.2, cache_s=0.15)
          self.row_number = 0
          # for test
          self._transcrible_time_cost = 0.
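
For the retuned zh branch, a quick arithmetic sketch of how the new dynamic-silence parameters interact (derived from the defaults in vad_dynamic.py, where VadProcessor forwards long_speech_threshold_s=6.0 and adjusted_min_silence_factor=0.5):

    sr = 16000
    min_silence = sr * 150 // 1000      # silence_s=0.15 -> 2400 samples (150 ms) normally ends a segment
    long_speech = sr * 6                # after 96000 samples (6 s) of continuous speech ...
    adjusted = int(min_silence * 0.5)   # ... only 1200 samples (75 ms) of silence are required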