from logging import getLogger
from difflib import SequenceMatcher
from itertools import chain
import collections

import numpy as np

import config

logger = getLogger("Strategy")


class TripleTextBuffer:
    def __init__(self, size=2):
        self.history = collections.deque(maxlen=size)

    def add_entry(self, text, index):
        """
        text: the recognized text
        index: relative index (array index) within the current buffer
        """
        self.history.append((text, index))

    def get_final_index(self, similarity_threshold=0.7):
        """Based on how the text has changed, return the buffer index of the punctuation that can be trusted."""
        if len(self.history) < 2:
            return None

        text1, _ = self.history[0]
        text2, idx2 = self.history[1]

        sim_12 = self.text_similarity(text1, text2)

        # If two consecutive texts are similar enough, the punctuation position is considered stable.
        if sim_12 >= similarity_threshold:
            self.history.clear()
            return idx2
        return None

    @staticmethod
    def text_similarity(text1, text2):
        return SequenceMatcher(None, text1, text2).ratio()
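

# A minimal usage sketch (illustrative only, not part of the original module):
# feed successive partial transcripts together with the index of their last
# punctuation mark; once two consecutive texts are similar enough, that index
# is returned and can be treated as a safe clip point.
#
#     buffer = TripleTextBuffer()
#     buffer.add_entry("你好，世界", 2)
#     buffer.add_entry("你好，世界啊", 2)
#     clip_index = buffer.get_final_index()  # -> 2 once similarity >= 0.7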


class SegmentManager:
    def __init__(self) -> None:
        self._commited_segments = []
        self._commited_short_sentences = []
        self._temp_string = ""

    def handle(self, string):
        self._temp_string = string
        return self

    @property
    def short_sentence(self) -> str:
        return "".join(self._commited_short_sentences)

    @property
    def segment(self):
        return self._commited_segments[-1] if self._commited_segments else ""

    def get_seg_id(self):
        return len(self._commited_segments)

    @property
    def string(self):
        return self._temp_string

    def commit_short_sentence(self):
        """Commit the temporary string to the short-sentence list."""
        self._commited_short_sentences.append(self._temp_string)
        self._temp_string = ""

    def commit_segment(self):
        """Merge the accumulated short sentences into a full segment."""
        self._commited_segments.append(self.short_sentence)
        self._commited_short_sentences = []

    def commit(self, is_end_sentence=False):
        """
        When part of the audio buffer is about to be clipped, commit the current
        sentence to the short-sentence queue and clear the temporary string.
        When a full sentence is finished, also commit it to the segment list.
        """
        self.commit_short_sentence()
        if is_end_sentence:
            self.commit_segment()
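

# A minimal usage sketch (illustrative only, not part of the original module):
#
#     manager = SegmentManager()
#     manager.handle("今天天气，").commit()                    # clip point: keep as a short sentence
#     manager.handle("真不错。").commit(is_end_sentence=True)  # sentence end: merge into a segment
#     manager.segment  # -> "今天天气，真不错。"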


def segement_merge(segments):
    """Group segments into full sentences based on punctuation markers."""
    sequences = []
    temp_seq = []

    for seg in segments:
        temp_seq.append(seg)
        if any(mk in seg.text for mk in config.SENTENCE_END_MARKERS + config.PAUSE_END_MARKERS):
            sequences.append(temp_seq.copy())
            temp_seq = []
    if temp_seq:
        sequences.append(temp_seq)
    return sequences


def segments_split(segments, audio_buffer: np.ndarray, sample_rate=16000):
    """Split the segments at the first punctuation marker from the left into a
    left 'watch' part and the remaining part."""
    left_watch_sequences = []
    left_watch_idx = 0
    right_watch_sequences = []
    is_end = False

    # Only look for a clip point while the buffer is still short (< 12 s).
    if (len(audio_buffer) / sample_rate) < 12:
        markers = config.PAUSE_END_MARKERS + config.SENTENCE_END_MARKERS
        is_end = False

        for idx, seg in enumerate(segments):
            left_watch_sequences.append(seg)
            if seg.text and seg.text[-1] in markers:
                # Convert the segment end time (in 1/100 s units) to a sample index.
                seg_index = int(seg.t1 / 100 * sample_rate)
                right_watch_sequences = segments[min(idx + 1, len(segments)):]
                left_watch_idx = seg_index
                break
    return left_watch_idx, left_watch_sequences, right_watch_sequences, is_end


def sequences_split(segments, audio_buffer: np.ndarray, sample_rate=16000):
    """Split the merged sentence sequences: once more than two full sentences have
    accumulated, everything except the last two can be clipped from the buffer."""
    left_watch_sequences = []
    right_watch_sequences = []
    left_watch_idx = 0
    is_end = False
    sequences = segement_merge(segments)

    if len(sequences) > 2:
        logger.info(f"buffer clip via sequence, current length: {len(sequences)}")
        is_end = True
        left_watch_sequences = chain(*sequences[:-2])
        right_watch_sequences = chain(*sequences[-2:])
        # The clip point is the end time of the last segment that gets clipped away.
        last_sequence_segment = sequences[-3]
        last_segment = last_sequence_segment[-1]
        left_watch_idx = int(last_segment.t1 / 100 * sample_rate)
    return left_watch_idx, left_watch_sequences, right_watch_sequences, is_end
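

# A minimal smoke test (illustrative sketch only, not part of the original module).
# The Segment namedtuple below is a hypothetical stand-in for the real segment
# objects, which are only assumed to expose `.text` and `.t1` (end time in 1/100 s).
# Running it still requires a `config` module whose PAUSE_END_MARKERS /
# SENTENCE_END_MARKERS include "，" and "。".
if __name__ == "__main__":
    from collections import namedtuple

    Segment = namedtuple("Segment", ["text", "t0", "t1"])

    segments = [
        Segment("今天天气，", 0, 150),
        Segment("真不错。", 150, 300),
        Segment("我们出去", 300, 420),
    ]
    audio = np.zeros(5 * 16000, dtype=np.float32)  # 5 s of silence

    idx, left, right, is_end = segments_split(segments, audio)
    print(idx, [s.text for s in left], [s.text for s in right], is_end)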