# Translator - transcribe/strategy.py
import collections
import logging
from dataclasses import dataclass, field
from difflib import SequenceMatcher
from enum import Enum
from typing import Deque, Iterator, List, Optional

import wordninja

import config
from config import (ALL_MARKERS, PAUSEE_END_PATTERN, REGEX_MARKERS,
                    SAMPLE_RATE, SENTENCE_END_PATTERN)

logger = logging.getLogger("TranscriptionStrategy")
class SplitMode(Enum):
PUNCTUATION = "punctuation"
PAUSE = "pause"
END = "end"
@dataclass
class TranscriptResult:
    """One unit of analyzer output: a partial update or a finished sentence."""
    seg_id: int = 0
    cut_index: int = 0
    is_end_sentence: bool = False
    context: str = ""

    def partial(self) -> bool:
        return not self.is_end_sentence
@dataclass
class TranscriptToken:
"""表示一个转录片段,包含文本和时间信息"""
text: str # 转录的文本内容
t0: int # 开始时间(百分之一秒)
t1: int # 结束时间(百分之一秒)
def is_punctuation(self):
"""检查文本是否包含标点符号"""
return REGEX_MARKERS.search(self.text.strip()) is not None
def is_end(self):
"""检查文本是否为句子结束标记"""
return SENTENCE_END_PATTERN.search(self.text.strip()) is not None
def is_pause(self):
"""检查文本是否为暂停标记"""
return PAUSEE_END_PATTERN.search(self.text.strip()) is not None
def buffer_index(self) -> int:
return int(self.t1 / 100 * SAMPLE_RATE)
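
# Illustrative example (values hypothetical): with SAMPLE_RATE = 16000, a token
# ending at t1 = 230 (2.30 s) maps to buffer index int(230 / 100 * 16000) = 36800.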
@dataclass
class TranscriptChunk:
"""表示一组转录片段,支持分割和比较操作"""
separator: str = "" # 用于连接片段的分隔符
items: list[TranscriptToken] = field(default_factory=list) # 转录片段列表
@staticmethod
def _calculate_similarity(text1: str, text2: str) -> float:
"""计算两段文本的相似度"""
return SequenceMatcher(None, text1, text2).ratio()
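
    # For reference: SequenceMatcher(None, "hello world", "hello word").ratio()
    # is roughly 0.95, so near-identical successive decodes score close to 1.0;
    # the analyzer's 0.8 stability threshold below builds on this behaviour.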
    def split_by(self, mode: SplitMode) -> List['TranscriptChunk']:
        """Split the token list at the markers selected by *mode*."""
        if mode == SplitMode.PUNCTUATION:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_punctuation()]
        elif mode == SplitMode.PAUSE:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_pause()]
        elif mode == SplitMode.END:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_end()]
        else:
            raise ValueError(f"Unsupported mode: {mode}")
        # Shift each cut point one index forward so that the marker token
        # stays with the chunk that precedes it.
        cut_points = [0] + sorted(i + 1 for i in indexes) + [len(self.items)]
        chunks = [
            TranscriptChunk(items=self.items[start:end], separator=self.separator)
            for start, end in zip(cut_points, cut_points[1:])
        ]
        return [ck for ck in chunks if not ck.only_punctuation()]
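
    # Example (hypothetical tokens, assuming "." is matched by REGEX_MARKERS):
    # splitting tokens "hello", "world", ".", "next" by SplitMode.PUNCTUATION
    # yields ["hello", "world", "."] and ["next"]; the marker token stays with
    # the chunk it terminates.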
    def get_split_first_rest(self, mode: SplitMode):
        """Return the first chunk of the split along with the remaining chunks."""
        chunks = self.split_by(mode)
        first_chunk = chunks[0] if chunks else self
        rest_chunks = chunks[1:] if chunks else None
        return first_chunk, rest_chunks
    def puncation_numbers(self) -> int:
        """Count the punctuation markers in the chunk."""
        return sum(1 for seg in self.items if seg.is_punctuation())

    def length(self) -> int:
        """Return the number of tokens in the chunk."""
        return len(self.items)

    def join(self) -> str:
        """Join the token texts into one string."""
        return self.separator.join(seg.text for seg in self.items)

    def compare(self, chunk: Optional['TranscriptChunk'] = None) -> float:
        """Return the similarity between this chunk and another."""
        if not chunk:
            return 0.0
        return self._calculate_similarity(self.join(), chunk.join())

    def only_punctuation(self) -> bool:
        """Return True if every token is punctuation."""
        return all(seg.is_punctuation() for seg in self.items)

    def has_punctuation(self) -> bool:
        """Return True if any token is punctuation."""
        return any(seg.is_punctuation() for seg in self.items)

    def get_buffer_index(self) -> int:
        """Return the sample-buffer index of the last token."""
        return self.items[-1].buffer_index()

    def is_end_sentence(self) -> bool:
        """Return True if the last token ends a sentence."""
        return self.items[-1].is_end()
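
# Usage sketch (hypothetical data): two successive decodes of the same audio
# window can be compared for stability, e.g.
#   prev = TranscriptChunk(items=[...], separator=" ")
#   curr = TranscriptChunk(items=[...], separator=" ")
#   curr.compare(prev)  # >= 0.8 is treated as "stable" by the analyzer below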
class TranscriptHistory:
    """Keeps the most recent transcript chunks for stability comparison."""

    def __init__(self) -> None:
        self.history: Deque[TranscriptChunk] = collections.deque(maxlen=2)  # the two most recent chunks

    def add(self, chunk: TranscriptChunk):
        """Add a new chunk; appendleft keeps the newest at index 0."""
        self.history.appendleft(chunk)

    def previous_chunk(self) -> Optional[TranscriptChunk]:
        """Return the previous chunk, if there is one."""
        return self.history[1] if len(self.history) == 2 else None
    def lastest_chunk(self):
        """Return the most recent chunk (index 0, since chunks are added with appendleft)."""
        return self.history[0]
def clear(self):
self.history.clear()
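
# Ordering note (follows from appendleft above): after add(a); add(b),
# history == deque([b, a]), so lastest_chunk() is b and previous_chunk() is a.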
class TranscriptBuffer:
    """
    Manages the tiered flow of transcribed text:
    pending string -> sentence (line) -> full paragraph

        |-- committed text --|-- observation window --|-- new input --|
    """

    def __init__(self, source_lang: str, separator: str):
        self._segments: Deque[str] = collections.deque(maxlen=2)  # confirmed full paragraphs
        self._sentences: Deque[str] = collections.deque()  # sentences of the current paragraph
        self._buffer: str = ""  # text currently pending
        self._current_seg_id: int = 0
        self.source_language = source_lang
        self._separator = separator
def get_seg_id(self) -> int:
return self._current_seg_id
    @property
    def current_sentences_length(self) -> int:
        """Length of the queued sentences: words when a separator exists, else characters."""
        count = 0
        for item in self._sentences:
            if self._separator:
                count += len(item.split(self._separator))
            else:
                count += len(item)
        return count
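
    # Example (hypothetical): with separator " ", ["hello world", "again"]
    # counts 3 words; with separator "" (e.g. for Chinese), "你好世界" counts
    # 4 characters.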
    def update_pending_text(self, text: str) -> None:
        """Replace the pending buffer string."""
        self._buffer = text

    def commit_line(self) -> None:
        """Commit the pending buffer as a sentence."""
        if self._buffer:
            self._sentences.append(self._buffer)
        self._buffer = ""
    def commit_paragraph(self) -> None:
        """Commit the queued sentences as one full paragraph (e.g. at end of sentence)."""
        count = 0
        current_sentences = []
        while len(self._sentences):  # and count < 20:
            item = self._sentences.popleft()
            current_sentences.append(item)
            if self._separator:
                count += len(item.split(self._separator))
            else:
                count += len(item)
        if current_sentences:
            self._segments.append("".join(current_sentences))
            logger.debug("=== count to paragraph ===")
            logger.debug(f"push: {current_sentences}")
            logger.debug(f"rest: {self._sentences}")
    def rebuild(self, text):
        """Re-segment English text whose words the decoder glued together."""
        output = self.split_and_join(text.replace(self._separator, ""))
        logger.debug("==== rebuild string ====")
        logger.debug(text)
        logger.debug(output)
        return output
    @staticmethod
    def split_and_join(text):
        """Split concatenated words with wordninja, re-inserting separators
        while keeping punctuation attached to the preceding word."""
tokens = []
word_buf = ''
for char in text:
if char in ALL_MARKERS:
if word_buf:
tokens.extend(wordninja.split(word_buf))
word_buf = ''
tokens.append(char)
else:
word_buf += char
if word_buf:
tokens.extend(wordninja.split(word_buf))
        output = ''
        for i, token in enumerate(tokens):
            if i == 0:
                output += token
            elif token in ALL_MARKERS:
                # attach the marker directly to the preceding word; the next
                # word supplies its own leading space, avoiding a double space
                output += token
            else:
                output += ' ' + token
        return output
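
    # Example (illustrative, assuming "," and "." are in ALL_MARKERS):
    # split_and_join("helloworld,nicetomeetyou.") -> "hello world, nice to meet you."
    # wordninja recovers the word boundaries; markers attach to the word on the left.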
    def update_and_commit(self, stable_strings: List[str], remaining_strings: List[str], is_end_sentence=False) -> bool:
        """Commit stable strings as lines; start a new paragraph once enough text accumulates."""
        if self.source_language == "en":
            stable_strings = [self.rebuild(i) for i in stable_strings]
            remaining_strings = [self.rebuild(i) for i in remaining_strings]
        remaining_string = "".join(remaining_strings)
        logger.debug(f"{self.__dict__}")
        if is_end_sentence:
            for stable_str in stable_strings:
                self.update_pending_text(stable_str)
                self.commit_line()
            if self._separator:
                current_text_len = len(self.current_not_commit_text.split(self._separator))
            else:
                current_text_len = len(self.current_not_commit_text)
            self.update_pending_text(remaining_string)
            if current_text_len >= config.TEXT_THREHOLD:
                self.commit_paragraph()
                self._current_seg_id += 1
                return True
        else:
            for stable_str in stable_strings:
                self.update_pending_text(stable_str)
                self.commit_line()
            self.update_pending_text(remaining_string)
        return False
    @property
    def un_commit_paragraph(self) -> str:
        """The queued sentences joined together."""
        return "".join(self._sentences)

    @property
    def pending_text(self) -> str:
        """The current pending buffer content."""
        return self._buffer

    @property
    def latest_paragraph(self) -> str:
        """The most recently confirmed paragraph."""
        return self._segments[-1] if self._segments else ""

    @property
    def current_not_commit_text(self) -> str:
        return self.un_commit_paragraph + self.pending_text
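
# Hand-driven walk-through (hypothetical values) of the pending -> line ->
# paragraph flow; in production update_and_commit() drives these steps:
#   buf = TranscriptBuffer(source_lang="zh", separator="")
#   buf.update_pending_text("你好")   # text enters the pending buffer
#   buf.commit_line()                 # pending buffer -> sentence queue
#   buf.commit_paragraph()            # sentence queue -> confirmed paragraph
#   buf.latest_paragraph              # -> "你好"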
class TranscriptStabilityAnalyzer:
    def __init__(self, source_lang, separator) -> None:
        self._transcript_buffer = TranscriptBuffer(source_lang=source_lang, separator=separator)
        self._transcript_history = TranscriptHistory()
        self._separator = separator
        logger.debug(f"Current separator: {self._separator}")

    def merge_chunks(self, chunks: List[TranscriptChunk]) -> List[str]:
        """Join each chunk into a string; an empty input yields [""]."""
        if not chunks:
            return [""]
        return [r.join() for r in chunks if r]
    def analysis(self, current: List[TranscriptToken], buffer_duration: float) -> Iterator[TranscriptResult]:
        current = TranscriptChunk(items=current, separator=self._separator)
        self._transcript_history.add(current)
        prev = self._transcript_history.previous_chunk()
        self._transcript_buffer.update_pending_text(current.join())
        if not prev:  # no history yet: a fresh utterance, emit it directly
            yield TranscriptResult(
                context=self._transcript_buffer.current_not_commit_text,
                seg_id=self._transcript_buffer.get_seg_id()
            )
            return
        if buffer_duration <= 4:
            yield from self._handle_short_buffer(current, prev)
        else:
            yield from self._handle_long_buffer(current)
    def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]:
        curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
        prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)
        if curr_first and prev_first:
            score = curr_first.compare(prev_first)
            has_punctuation = curr_first.has_punctuation()
            if score >= 0.8 and has_punctuation:
                yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence())
                return
        yield TranscriptResult(
            seg_id=self._transcript_buffer.get_seg_id(),
            context=self._transcript_buffer.current_not_commit_text
        )
    def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]:
        chunks = curr.split_by(SplitMode.PUNCTUATION)
        if len(chunks) > 1:
            stable, remaining = chunks[:-1], chunks[-1:]
            yield from self._yield_commit_results(
                stable, remaining, is_end_sentence=True  # hard-coded to True for now
            )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                context=self._transcript_buffer.current_not_commit_text
            )
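
    # Example (hypothetical): a long buffer split into three chunks, say
    # "how are you?" / "I'm fine," / "thanks" (no trailing marker), gives
    # stable = the first two chunks and remaining = the last one.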
    def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]:
        stable_str_list = [stable_chunk.join()] if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk)
        remaining_str_list = self.merge_chunks(remaining_chunks)
        frame_cut_index = stable_chunk[-1].get_buffer_index() if isinstance(stable_chunk, list) else stable_chunk.get_buffer_index()
        prev_seg_id = self._transcript_buffer.get_seg_id()
        commit_paragraph = self._transcript_buffer.update_and_commit(stable_str_list, remaining_str_list, is_end_sentence)
        logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")
        if commit_paragraph:
            # A new paragraph was produced: emit it as a finished sentence,
            # then emit any still-uncommitted text under the new seg_id.
            yield TranscriptResult(
                seg_id=prev_seg_id,
                cut_index=frame_cut_index,
                context=self._transcript_buffer.latest_paragraph,
                is_end_sentence=True
            )
            if (context := self._transcript_buffer.current_not_commit_text.strip()):
                yield TranscriptResult(
                    seg_id=self._transcript_buffer.get_seg_id(),
                    context=context,
                )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                cut_index=frame_cut_index,
                context=self._transcript_buffer.current_not_commit_text,
            )
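

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): it assumes the constants and
    # patterns imported from config above, e.g. that REGEX_MARKERS and
    # SENTENCE_END_PATTERN both match ".". Token timestamps are hypothetical.
    logging.basicConfig(level=logging.DEBUG)
    analyzer = TranscriptStabilityAnalyzer(source_lang="en", separator=" ")
    tokens = [
        TranscriptToken(text="hello", t0=0, t1=50),
        TranscriptToken(text="world.", t0=50, t1=100),
    ]
    # First pass: no history yet, so only a partial result is emitted.
    for result in analyzer.analysis(tokens, buffer_duration=1.0):
        print(result)
    # Second pass with an identical decode: the first chunk compares as stable
    # (similarity >= 0.8 with punctuation present), so it gets committed.
    for result in analyzer.analysis(tokens, buffer_duration=2.0):
        print(result)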