daihui.zhang
commited on
Commit
·
5518c26
1
Parent(s):
80839d5
filter [] words
Browse files
transcribe/helpers/vadprocessor.py
CHANGED
|
@@ -113,6 +113,7 @@ class VADIteratorOnnx:
|
|
| 113 |
sampling_rate: int = 16000,
|
| 114 |
min_silence_duration_ms: int = 100,
|
| 115 |
max_speech_duration_s: float = float('inf'),
|
|
|
|
| 116 |
):
|
| 117 |
self.model = OnnxWrapper(VAD_MODEL_PATH, True)
|
| 118 |
self.threshold = threshold
|
|
@@ -123,7 +124,7 @@ class VADIteratorOnnx:
|
|
| 123 |
|
| 124 |
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
| 125 |
self.max_speech_samples = int(sampling_rate * max_speech_duration_s)
|
| 126 |
-
|
| 127 |
self.reset_states()
|
| 128 |
|
| 129 |
def reset_states(self):
|
|
@@ -158,7 +159,8 @@ class VADIteratorOnnx:
|
|
| 158 |
|
| 159 |
if (speech_prob >= self.threshold) and not self.triggered:
|
| 160 |
self.triggered = True
|
| 161 |
-
speech_start = max(0, self.current_sample - window_size_samples)
|
|
|
|
| 162 |
self.start = speech_start
|
| 163 |
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
|
| 164 |
|
|
@@ -174,7 +176,8 @@ class VADIteratorOnnx:
|
|
| 174 |
if self.current_sample - self.temp_end < self.min_silence_samples:
|
| 175 |
return None
|
| 176 |
else:
|
| 177 |
-
speech_end = self.temp_end - window_size_samples
|
|
|
|
| 178 |
self.temp_end = 0
|
| 179 |
self.triggered = False
|
| 180 |
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
|
|
|
|
| 113 |
sampling_rate: int = 16000,
|
| 114 |
min_silence_duration_ms: int = 100,
|
| 115 |
max_speech_duration_s: float = float('inf'),
|
| 116 |
+
speech_pad_ms: int = 30
|
| 117 |
):
|
| 118 |
self.model = OnnxWrapper(VAD_MODEL_PATH, True)
|
| 119 |
self.threshold = threshold
|
|
|
|
| 124 |
|
| 125 |
self.min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
| 126 |
self.max_speech_samples = int(sampling_rate * max_speech_duration_s)
|
| 127 |
+
self.speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
| 128 |
self.reset_states()
|
| 129 |
|
| 130 |
def reset_states(self):
|
|
|
|
| 159 |
|
| 160 |
if (speech_prob >= self.threshold) and not self.triggered:
|
| 161 |
self.triggered = True
|
| 162 |
+
# speech_start = max(0, self.current_sample - window_size_samples)
|
| 163 |
+
speech_start = max(0, self.current_sample - self.speech_pad_samples - window_size_samples)
|
| 164 |
self.start = speech_start
|
| 165 |
return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
|
| 166 |
|
|
|
|
| 176 |
if self.current_sample - self.temp_end < self.min_silence_samples:
|
| 177 |
return None
|
| 178 |
else:
|
| 179 |
+
# speech_end = self.temp_end - window_size_samples
|
| 180 |
+
speech_end = self.temp_end + self.speech_pad_samples - window_size_samples
|
| 181 |
self.temp_end = 0
|
| 182 |
self.triggered = False
|
| 183 |
return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
|
transcribe/helpers/whisper.py
CHANGED
|
@@ -52,7 +52,7 @@ class WhisperCPP:
|
|
| 52 |
initial_prompt=prompt,
|
| 53 |
language=language,
|
| 54 |
# token_timestamps=True,
|
| 55 |
-
|
| 56 |
# max_len=max_len
|
| 57 |
)
|
| 58 |
return output
|
|
|
|
| 52 |
initial_prompt=prompt,
|
| 53 |
language=language,
|
| 54 |
# token_timestamps=True,
|
| 55 |
+
split_on_word=True,
|
| 56 |
# max_len=max_len
|
| 57 |
)
|
| 58 |
return output
|
transcribe/utils.py
CHANGED
|
@@ -7,6 +7,51 @@ from scipy.io.wavfile import write
|
|
| 7 |
import config
|
| 8 |
import csv
|
| 9 |
import av
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
def log_block(key: str, value, unit=''):
|
| 11 |
if config.DEBUG:
|
| 12 |
return
|
|
|
|
| 7 |
import config
|
| 8 |
import csv
|
| 9 |
import av
|
| 10 |
+
import re
|
| 11 |
+
|
| 12 |
+
# Compile regex patterns once outside the loop for better performance
|
| 13 |
+
p_pattern = re.compile(r"(\s*\[.*?\])")
|
| 14 |
+
p_start_pattern = re.compile(r"(\s*\[.*)")
|
| 15 |
+
p_end_pattern = re.compile(r"(\s*.*\])")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def filter_words(res_word):
|
| 19 |
+
"""
|
| 20 |
+
Filter words according to specific bracket patterns.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
res_word: Iterable of word objects with a 'text' attribute
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
List of filtered word objects
|
| 27 |
+
"""
|
| 28 |
+
asr_results = []
|
| 29 |
+
skip_word = False
|
| 30 |
+
|
| 31 |
+
for word in res_word:
|
| 32 |
+
# Skip words that completely match the pattern
|
| 33 |
+
if p_pattern.match(word.text):
|
| 34 |
+
continue
|
| 35 |
+
|
| 36 |
+
# Mark the start of a section to skip
|
| 37 |
+
if p_start_pattern.match(word.text):
|
| 38 |
+
skip_word = True
|
| 39 |
+
continue
|
| 40 |
+
|
| 41 |
+
# Mark the end of a section to skip
|
| 42 |
+
if p_end_pattern.match(word.text) and skip_word:
|
| 43 |
+
skip_word = False
|
| 44 |
+
continue
|
| 45 |
+
|
| 46 |
+
# Skip words if we're in a skip section
|
| 47 |
+
if skip_word:
|
| 48 |
+
continue
|
| 49 |
+
|
| 50 |
+
# Add the word to results if it passed all filters
|
| 51 |
+
asr_results.append(word)
|
| 52 |
+
|
| 53 |
+
return asr_results
|
| 54 |
+
|
| 55 |
def log_block(key: str, value, unit=''):
|
| 56 |
if config.DEBUG:
|
| 57 |
return
|
transcribe/whisper_llm_serve.py
CHANGED
|
@@ -11,7 +11,7 @@ import config
|
|
| 11 |
import collections
|
| 12 |
from api_model import TransResult, Message, DebugResult
|
| 13 |
|
| 14 |
-
from .utils import log_block, save_to_wave, TestDataWriter
|
| 15 |
from .translatepipes import TranslatePipes
|
| 16 |
from .strategy import (
|
| 17 |
TranscriptStabilityAnalyzer, TranscriptToken)
|
|
@@ -191,6 +191,7 @@ class WhisperTranscriptionService:
|
|
| 191 |
meta_item = self._transcribe_audio(audio_buffer)
|
| 192 |
segments = meta_item.segments
|
| 193 |
logger.debug(f"Segments: {segments}")
|
|
|
|
| 194 |
if len(segments):
|
| 195 |
seg_text = self.text_separator.join(seg.text for seg in segments)
|
| 196 |
if self._temp_string:
|
|
|
|
| 11 |
import collections
|
| 12 |
from api_model import TransResult, Message, DebugResult
|
| 13 |
|
| 14 |
+
from .utils import log_block, save_to_wave, TestDataWriter, filter_words
|
| 15 |
from .translatepipes import TranslatePipes
|
| 16 |
from .strategy import (
|
| 17 |
TranscriptStabilityAnalyzer, TranscriptToken)
|
|
|
|
| 191 |
meta_item = self._transcribe_audio(audio_buffer)
|
| 192 |
segments = meta_item.segments
|
| 193 |
logger.debug(f"Segments: {segments}")
|
| 194 |
+
segments = filter_words(segments)
|
| 195 |
if len(segments):
|
| 196 |
seg_text = self.text_separator.join(seg.text for seg in segments)
|
| 197 |
if self._temp_string:
|