|
|
|
|
|
import collections |
|
|
import logging |
|
|
from difflib import SequenceMatcher |
|
|
from itertools import chain |
|
|
from dataclasses import dataclass, field |
|
|
from typing import List, Tuple, Optional, Deque, Any, Iterator,Literal |
|
|
from config import SENTENCE_END_MARKERS, ALL_MARKERS,SENTENCE_END_PATTERN,REGEX_MARKERS, PAUSEE_END_PATTERN,SAMPLE_RATE |
|
|
from enum import Enum |
|
|
import wordninja |
|
|
import config |
|
|
import re |
|
|
# Module-level logger; uses a fixed name rather than __name__ — presumably so
# several modules can share the "TranscriptionStrategy" channel (confirm).
logger = logging.getLogger("TranscriptionStrategy")
|
|
|
|
|
|
|
|
class SplitMode(Enum):
    """Strategy for choosing the cut points when splitting a token chunk."""

    # Cut after any punctuation marker token.
    PUNCTUATION = "punctuation"
    # Cut after pause-marker tokens.
    PAUSE = "pause"
    # Cut after sentence-end marker tokens.
    END = "end"
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TranscriptResult:
    """One emitted transcription event.

    Attributes:
        seg_id: id of the paragraph/segment this text belongs to.
        cut_index: audio-buffer sample index where stable text ends
            (0 when no cut is requested).
        is_end_sentence: True when this result closes a sentence/paragraph.
        context: the text payload of the event.
    """

    seg_id: int = 0
    cut_index: int = 0
    is_end_sentence: bool = False
    context: str = ""

    def partial(self) -> bool:
        """Return True while the result is interim (sentence not closed)."""
        return not self.is_end_sentence
|
|
|
|
|
@dataclass
class TranscriptToken:
    """A single transcription token together with its timing window."""

    text: str  # the token's text content
    t0: int    # start time (same unit as t1 — presumably 1/100 s; confirm)
    t1: int    # end time; converted to a sample offset by buffer_index()

    def is_punctuation(self) -> bool:
        """Return True when the token text contains any punctuation marker."""
        stripped = self.text.strip()
        return bool(REGEX_MARKERS.search(stripped))

    def is_end(self) -> bool:
        """Return True when the token text marks the end of a sentence."""
        stripped = self.text.strip()
        return bool(SENTENCE_END_PATTERN.search(stripped))

    def is_pause(self) -> bool:
        """Return True when the token text marks a pause."""
        stripped = self.text.strip()
        return bool(PAUSEE_END_PATTERN.search(stripped))

    def buffer_index(self) -> int:
        """Sample offset of the token's end: t1 is treated as 1/100-second units."""
        return int(self.t1 / 100 * SAMPLE_RATE)
|
|
|
|
|
@dataclass
class TranscriptChunk:
    """A group of transcript tokens supporting splitting and similarity comparison."""

    # Separator used when joining token texts ("" for unspaced languages).
    separator: str = ""
    # Forward reference quoted so the annotation need not resolve at class creation.
    items: list["TranscriptToken"] = field(default_factory=list)

    @staticmethod
    def _calculate_similarity(text1: str, text2: str) -> float:
        """Return the SequenceMatcher similarity ratio (0.0..1.0) of two strings."""
        return SequenceMatcher(None, text1, text2).ratio()

    def split_by(self, mode: "SplitMode") -> list['TranscriptChunk']:
        """Split the token list at tokens matching *mode*.

        The cut happens AFTER each matching token, so a marker stays at the
        end of its chunk.  Chunks consisting only of punctuation (including
        empty chunks) are dropped.

        Raises:
            ValueError: when *mode* is not a supported SplitMode.
        """
        # Dispatch table built at call time so the class imports standalone.
        checks = {
            SplitMode.PUNCTUATION: lambda tok: tok.is_punctuation(),
            SplitMode.PAUSE: lambda tok: tok.is_pause(),
            SplitMode.END: lambda tok: tok.is_end(),
        }
        if mode not in checks:
            raise ValueError(f"Unsupported mode: {mode}")
        is_cut = checks[mode]

        indexes = [i for i, seg in enumerate(self.items) if is_cut(seg)]
        # Endpoints 0 and len(items) guarantee full coverage of the list.
        cut_points = [0] + sorted(i + 1 for i in indexes) + [len(self.items)]
        chunks = [
            TranscriptChunk(items=self.items[start:end], separator=self.separator)
            for start, end in zip(cut_points, cut_points[1:])
        ]
        return [ck for ck in chunks if not ck.only_punctuation()]

    def get_split_first_rest(self, mode: "SplitMode"):
        """Return ``(first_chunk, rest_chunks)`` after splitting by *mode*.

        Falls back to ``(self, None)`` when the split yields no chunks.
        """
        chunks = self.split_by(mode)
        first_chunk = chunks[0] if chunks else self   # fixed local typo: fisrt_chunk
        rest_chunks = chunks[1:] if chunks else None
        return first_chunk, rest_chunks

    def puncation_numbers(self) -> int:
        """Count punctuation tokens.  (Misspelled name kept: external callers may rely on it.)"""
        return sum(1 for seg in self.items if seg.is_punctuation())

    def length(self) -> int:
        """Number of tokens in this chunk."""
        return len(self.items)

    def join(self) -> str:
        """Join all token texts with the configured separator."""
        return self.separator.join(seg.text for seg in self.items)

    def compare(self, chunk: Optional['TranscriptChunk'] = None) -> float:
        """Similarity of this chunk's text vs *chunk*'s; 0.0 when *chunk* is None."""
        if not chunk:
            # Was `return 0` — return a float to honour the declared type.
            return 0.0
        return self._calculate_similarity(self.join(), chunk.join())

    def only_punctuation(self) -> bool:
        """True when every token is punctuation (vacuously True for an empty chunk)."""
        return all(seg.is_punctuation() for seg in self.items)

    def has_punctuation(self) -> bool:
        """True when at least one token is punctuation."""
        return any(seg.is_punctuation() for seg in self.items)

    def get_buffer_index(self) -> int:
        """Audio-buffer sample index of the last token.  Assumes non-empty items."""
        return self.items[-1].buffer_index()

    def is_end_sentence(self) -> bool:
        """True when the last token ends a sentence.  Assumes non-empty items."""
        return self.items[-1].is_end()
|
|
|
|
|
|
|
|
class TranscriptHistory:
    """Keeps the two most recent transcript chunks for cross-pass comparison."""

    def __init__(self) -> None:
        # Newest chunk sits at index 0; the deque drops older entries itself.
        self.history = collections.deque(maxlen=2)

    def add(self, chunk: "TranscriptChunk"):
        """Push the newest chunk onto the front of the history."""
        self.history.appendleft(chunk)

    def previous_chunk(self) -> Optional["TranscriptChunk"]:
        """Chunk from the pass before the latest one, or None until two exist."""
        if len(self.history) == 2:
            return self.history[1]
        return None

    def lastest_chunk(self):
        """Return ``history[-1]``.

        NOTE(review): with appendleft, index -1 is the OLDEST retained entry,
        not the newest — confirm this is the intended semantics.
        """
        return self.history[-1]

    def clear(self):
        """Forget all retained chunks."""
        self.history.clear()
|
|
|
|
|
class TranscriptBuffer:
    """
    Manage transcribed text as a hierarchy: pending string -> sentence -> paragraph.

    |-- committed text --|-- observation window --|-- new input --|

    Implements the pending -> line -> paragraph buffering logic.
    """

    def __init__(self, source_lang:str, separator:str):
        # Confirmed paragraphs; maxlen=2 keeps only the two most recent.
        self._segments: Deque[str] = collections.deque(maxlen=2)
        # Sentences committed via commit_line() but not yet folded into a paragraph.
        self._sentences: Deque[str] = collections.deque()
        # The raw pending (unconfirmed) text.
        self._buffer: str = ""
        # Monotonic id of the paragraph currently being accumulated.
        self._current_seg_id: int = 0
        self.source_language = source_lang
        # Word separator — presumably " " for spaced languages and "" otherwise; confirm.
        self._separator = separator

    def get_seg_id(self) -> int:
        """Return the id of the segment currently being accumulated."""
        return self._current_seg_id

    @property
    def current_sentences_length(self) -> int:
        """Total size of committed-but-unparagraphed sentences.

        Counted in words when a separator is configured, otherwise in characters.
        """
        count = 0
        for item in self._sentences:
            if self._separator:
                count += len(item.split(self._separator))
            else:
                count += len(item)
        return count

    def update_pending_text(self, text: str) -> None:
        """Replace the pending (unconfirmed) buffer string."""
        self._buffer = text

    def commit_line(self,) -> None:
        """Promote the pending buffer to a committed sentence, then clear it."""
        if self._buffer:
            self._sentences.append(self._buffer)
        self._buffer = ""

    def commit_paragraph(self) -> None:
        """
        Drain all committed sentences into one confirmed paragraph.

        Appends the concatenated text to ``self._segments``; does nothing
        when there are no committed sentences.
        """

        # NOTE(review): `count` is accumulated but never read — dead code?
        count = 0
        current_sentences = []
        while len(self._sentences):
            item = self._sentences.popleft()
            current_sentences.append(item)
            if self._separator:
                count += len(item.split(self._separator))
            else:
                count += len(item)
        if current_sentences:
            self._segments.append("".join(current_sentences))
            logger.debug(f"=== count to paragraph ===")
            logger.debug(f"push: {current_sentences}")
            logger.debug(f"rest: {self._sentences}")

    def rebuild(self, text):
        """Strip separators from *text* and re-segment it via split_and_join.

        Only invoked for English input (see update_and_commit).
        """
        output = self.split_and_join(
            text.replace(
                self._separator, ""))

        logger.debug("==== rebuild string ====")
        logger.debug(text)
        logger.debug(output)

        return output

    @staticmethod
    def split_and_join(text):
        """Segment a run-together string into words and re-join with spaces.

        Characters in ALL_MARKERS act as punctuation: they flush the current
        letter run through wordninja and are kept as their own token; on
        re-join, punctuation attaches to the preceding word with a trailing
        space (e.g. "helloworld." -> "hello world. ").
        """
        tokens = []
        word_buf = ''

        for char in text:
            if char in ALL_MARKERS:
                # Flush accumulated letters, then keep the marker itself.
                if word_buf:
                    tokens.extend(wordninja.split(word_buf))
                    word_buf = ''
                tokens.append(char)
            else:
                word_buf += char
        if word_buf:
            tokens.extend(wordninja.split(word_buf))

        # Re-join: markers glue to the previous token plus a trailing space;
        # ordinary words are space-separated.
        output = ''
        for i, token in enumerate(tokens):
            if i == 0:
                output += token
            elif token in ALL_MARKERS:
                output += (token + " ")
            else:
                output += ' ' + token
        return output

    def update_and_commit(self, stable_strings: List[str], remaining_strings:List[str], is_end_sentence=False):
        """Commit stabilized strings and refresh the pending text.

        Args:
            stable_strings: strings considered stable; each is committed as a line.
            remaining_strings: still-changing strings; become the new pending text.
            is_end_sentence: whether the stable part closes a sentence.

        Returns:
            True when a paragraph was committed (segment id advanced);
            False otherwise.  NOTE(review): when is_end_sentence is True but
            the length threshold is not met, the method falls off the end and
            returns None (falsy) — confirm callers only truth-test the result.
        """
        if self.source_language == "en":
            # English ASR output may lack word spacing; rebuild it.
            stable_strings = [self.rebuild(i) for i in stable_strings]
            remaining_strings =[self.rebuild(i) for i in remaining_strings]
        remaining_string = "".join(remaining_strings)

        logger.debug(f"{self.__dict__}")
        if is_end_sentence:
            for stable_str in stable_strings:
                self.update_pending_text(stable_str)
                self.commit_line()

            # Length of all un-paragraphed text: words if separator, else chars.
            current_text_len = len(self.current_not_commit_text.split(self._separator)) if self._separator else len(self.current_not_commit_text)

            self.update_pending_text(remaining_string)
            if current_text_len >= config.TEXT_THREHOLD:
                self.commit_paragraph()
                self._current_seg_id += 1
                return True
        else:
            for stable_str in stable_strings:
                self.update_pending_text(stable_str)
                self.commit_line()
            self.update_pending_text(remaining_string)
            return False

    @property
    def un_commit_paragraph(self) -> str:
        """Committed sentences not yet folded into a paragraph, concatenated."""
        return "".join([i for i in self._sentences])

    @property
    def pending_text(self) -> str:
        """The current unconfirmed buffer content."""
        return self._buffer

    @property
    def latest_paragraph(self) -> str:
        """The most recently confirmed paragraph, or "" when none exists."""
        return self._segments[-1] if self._segments else ""

    @property
    def current_not_commit_text(self) -> str:
        """Everything not yet part of a confirmed paragraph (sentences + pending)."""
        return self.un_commit_paragraph + self.pending_text
|
|
|
|
|
|
|
|
|
|
|
class TranscriptStabilityAnalyzer:
    """Compares successive decode passes to decide which text is stable enough to commit."""

    def __init__(self, source_lang, separator) -> None:
        self._transcript_buffer = TranscriptBuffer(source_lang=source_lang,separator=separator)
        self._transcript_history = TranscriptHistory()
        self._separator = separator
        logger.debug(f"Current separator: {self._separator}")

    def merge_chunks(self, chunks: List[TranscriptChunk]) -> List[str]:
        # Annotation fixed: the method returns a list of strings, not a str.
        """Join each chunk to its text; returns [""] for an empty/None input."""
        if not chunks:
            return [""]
        output = list(r.join() for r in chunks if r)
        return output

    def analysis(self, current: List[TranscriptToken], buffer_duration: float) -> Iterator[TranscriptResult]:
        # Annotation fixed: callers pass the raw token list; it is wrapped below.
        """Analyze the newest token list and yield TranscriptResult events.

        Args:
            current: raw token list from the latest decode pass.
            buffer_duration: seconds of audio currently buffered.
        """
        current = TranscriptChunk(items=current, separator=self._separator)
        self._transcript_history.add(current)

        prev = self._transcript_history.previous_chunk()
        self._transcript_buffer.update_pending_text(current.join())
        if not prev:
            # First pass: nothing to compare against yet; emit interim text only.
            yield TranscriptResult(
                context=self._transcript_buffer.current_not_commit_text,
                seg_id=self._transcript_buffer.get_seg_id()
            )
            return

        # Short buffers: stability is judged by comparing with the previous
        # pass.  Long buffers: force a split to keep latency bounded.
        if buffer_duration <= 4:
            yield from self._handle_short_buffer(current, prev)
        else:
            yield from self._handle_long_buffer(current)

    def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]:
        """Commit the first sub-chunk when it repeats across two passes."""
        curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
        prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)

        if curr_first and prev_first:
            # Stable when >= 0.8 similar to the previous pass AND it contains
            # punctuation (split_by cuts after markers, so any marker is final).
            core = curr_first.compare(prev_first)
            has_punctuation = curr_first.has_punctuation()
            if core >= 0.8 and has_punctuation:
                yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence())
                return

        # Not stable yet: emit current interim text.
        yield TranscriptResult(
            seg_id=self._transcript_buffer.get_seg_id(),
            context=self._transcript_buffer.current_not_commit_text
        )

    def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]:
        """Force-commit all but the last sub-chunk when the buffer grows long."""
        chunks = curr.split_by(SplitMode.PUNCTUATION)
        if len(chunks) > 1:
            stable, remaining = chunks[:-1], chunks[-1:]

            yield from self._yield_commit_results(
                stable, remaining, is_end_sentence=True
            )
        else:
            # No punctuation to cut at: emit interim text only.
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                context=self._transcript_buffer.current_not_commit_text
            )

    def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]:
        """Push stable text into the buffer and yield the resulting events.

        *stable_chunk* may be a single TranscriptChunk (short-buffer path) or
        a list of chunks (long-buffer path); both call sites are handled here.
        """
        stable_str_list = [stable_chunk.join()] if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk)
        remaining_str_list = self.merge_chunks(remaining_chunks)
        # Sample index at the end of the stable text — lets the caller trim
        # the audio buffer up to the committed region.
        frame_cut_index = stable_chunk[-1].get_buffer_index() if isinstance(stable_chunk, list) else stable_chunk.get_buffer_index()

        prev_seg_id = self._transcript_buffer.get_seg_id()
        commit_paragraph = self._transcript_buffer.update_and_commit(stable_str_list, remaining_str_list, is_end_sentence)
        logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")

        if commit_paragraph:

            # A paragraph closed: emit it under the id it was accumulated with.
            yield TranscriptResult(
                seg_id=prev_seg_id,
                cut_index=frame_cut_index,
                context=self._transcript_buffer.latest_paragraph,
                is_end_sentence=True
            )
            # Any leftover text starts the next segment as an interim result.
            if (context := self._transcript_buffer.current_not_commit_text.strip()):
                yield TranscriptResult(
                    seg_id=self._transcript_buffer.get_seg_id(),
                    context=context,
                )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                cut_index=frame_cut_index,
                context=self._transcript_buffer.current_not_commit_text,
            )
|
|
|
|
|
|