Translator / transcribe /strategy.py
Xin Zhang
[fix]: refactor.
7f191bc
raw
history blame
4.97 kB
from logging import getLogger
from difflib import SequenceMatcher
import collections
import config
import numpy as np
from itertools import chain

# Module-level logger. NOTE(review): the name "Stragegy" is misspelled
# ("Strategy"); kept as-is because external logging configuration may
# reference this exact logger name.
logger = getLogger("Stragegy")
class TripleTextBuffer:
    """Holds the most recent recognition texts (two by default) and reports
    when the text has stabilized enough that a punctuation position can be
    trusted."""

    def __init__(self, size=2):
        # Bounded history of (text, index) pairs; old entries fall off.
        self.history = collections.deque(maxlen=size)

    def add_entry(self, text, index):
        """Record one recognition result.

        text: the recognized text
        index: relative array index into the current buffer
        """
        self.history.append((text, index))

    def get_final_index(self, similarity_threshold=0.7):
        """Return the buffer index of a reliable punctuation position based
        on how little the text changed between the two stored recognitions;
        ``None`` when there is no stable position yet.

        On success the history is cleared so the next decision starts fresh.
        """
        if len(self.history) < 2:
            return None
        earlier_text, _ = self.history[0]
        later_text, later_index = self.history[1]
        if self.text_similarity(earlier_text, later_text) >= similarity_threshold:
            self.history.clear()
            return later_index
        return None

    @staticmethod
    def text_similarity(text1, text2):
        """Similarity ratio in [0, 1] between the two strings."""
        return SequenceMatcher(None, text1, text2).ratio()
class SegmentManager:
    """Accumulates recognized text at three levels: a pending temporary
    string, committed short sentences, and committed full segments
    (paragraphs)."""

    def __init__(self) -> None:
        self._commited_segments = []         # finalized paragraphs
        self._commited_short_sentences = []  # finalized short sentences
        self._temp_string = ""               # pending text until a sentence ends

    def handle(self, string):
        """Store *string* as the current temporary text; returns self so
        calls can be chained."""
        self._temp_string = string
        return self

    @property
    def short_sentence(self) -> str:
        """All committed short sentences joined into a single string."""
        return "".join(self._commited_short_sentences)

    @property
    def segment(self):
        """The most recently committed paragraph, or "" when none exists."""
        if not self._commited_segments:
            return ""
        return self._commited_segments[-1]

    def get_seg_id(self):
        """Number of committed paragraphs (also the id of the next one)."""
        return len(self._commited_segments)

    @property
    def string(self):
        """The current temporary (uncommitted) text."""
        return self._temp_string

    def commit_short_sentence(self):
        """Move the temporary string into the short-sentence list."""
        self._commited_short_sentences.append(self._temp_string)
        self._temp_string = ""

    def commit_segment(self):
        """Merge all short sentences into one paragraph and reset the list."""
        self._commited_segments.append(self.short_sentence)
        self._commited_short_sentences = []

    def commit(self, is_end_sentence=False):
        """Commit the temporary text as a short sentence; when the audio that
        produced it completed a full sentence, also roll the short sentences
        up into a paragraph."""
        self.commit_short_sentence()
        if is_end_sentence:
            self.commit_segment()
def segement_merge(segments):
    """Group *segments* into full sentences at punctuation boundaries.

    A group ends at any segment whose text contains a sentence-end or
    pause-end marker (from ``config``); a trailing unfinished group is
    kept as well.
    """
    boundary_marks = config.SENTENCE_END_MARKERS + config.PAUSE_END_MARKERS
    grouped = []
    current = []
    for segment in segments:
        current.append(segment)
        if any(mark in segment.text for mark in boundary_marks):
            grouped.append(current)
            current = []
    if current:
        grouped.append(current)
    return grouped
def segments_split(segments, audio_buffer: np.ndarray, sample_rate=16000):
    """Split the segment list at the left-most punctuation mark into a
    watched (left) part and the remainder.

    Returns ``(left_watch_idx, left_watch_sequences,
    right_watch_sequences, is_end)`` where ``left_watch_idx`` is the
    sample offset of the split point inside ``audio_buffer`` (0 when no
    marker was found).
    """
    left_watch_sequences = []
    left_watch_idx = 0
    right_watch_sequences = []
    is_end = False
    if (len(audio_buffer) / sample_rate) < 12:
        # Under 12 s: short-pause markers (e.g. comma) also count as
        # split points, in addition to sentence-end markers.
        markers = config.PAUSE_END_MARKERS + config.SENTENCE_END_MARKERS
        is_end = False
        for idx, seg in enumerate(segments):
            left_watch_sequences.append(seg)
            # Split at the first segment whose text ends with a marker.
            if seg.text and seg.text[-1] in markers:
                # NOTE(review): seg.t1 appears to be in centiseconds
                # (t1 / 100 -> seconds) — confirm against the ASR output.
                seg_index = int(seg.t1 / 100 * sample_rate)
                # rest_buffer_duration = (len(audio_buffer) - seg_index) / sample_rate
                # is_end = any(i in seg.text for i in config.SENTENCE_END_MARKERS)
                right_watch_sequences = segments[min(idx+1, len(segments)):]
                # if rest_buffer_duration >= 1.5:
                left_watch_idx = seg_index
                break
    return left_watch_idx, left_watch_sequences, right_watch_sequences, is_end
def sequences_split(segments, audio_buffer: np.ndarray, sample_rate=16000):
    """Split a long recognition into a committed part and a watch window.

    Groups *segments* into full sentences via ``segement_merge`` and, when
    more than two sentences exist, commits everything except the last two.

    Returns ``(left_watch_idx, left_watch_sequences,
    right_watch_sequences, is_end)`` where ``left_watch_idx`` is the
    sample offset in ``audio_buffer`` at which the committed part ends
    (0 when nothing is committed).
    """
    left_watch_sequences = []
    right_watch_sequences = []
    left_watch_idx = 0
    is_end = False
    sequences = segement_merge(segments)
    if len(sequences) > 2:
        logger.info(f"buffer clip via sequence, current length: {len(sequences)}")
        is_end = True
        # Materialize as lists (the original returned one-shot chain
        # iterators) so the results match segments_split() and can be
        # iterated more than once by callers.
        left_watch_sequences = list(chain.from_iterable(sequences[:-2]))
        right_watch_sequences = list(chain.from_iterable(sequences[-2:]))
        # The committed audio ends at the last segment of the last
        # committed sentence, i.e. the third-from-last sentence overall.
        last_segment = sequences[-3][-1]
        # NOTE(review): t1 appears to be in centiseconds
        # (t1 / 100 -> seconds) — confirm against the ASR output.
        left_watch_idx = int(last_segment.t1 / 100 * sample_rate)
    return left_watch_idx, left_watch_sequences, right_watch_sequences, is_end