import collections
import logging
from dataclasses import dataclass, field
from difflib import SequenceMatcher
from enum import Enum
from typing import Deque, Iterator, List, Optional

import wordninja

import config
from config import (
    ALL_MARKERS,
    PAUSEE_END_PATTERN,
    REGEX_MARKERS,
    SAMPLE_RATE,
    SENTENCE_END_PATTERN,
)
logger = logging.getLogger("TranscriptionStrategy")
class SplitMode(Enum):
PUNCTUATION = "punctuation"
PAUSE = "pause"
END = "end"
@dataclass
class TranscriptResult:
seg_id: int = 0
cut_index: int = 0
is_end_sentence: bool = False
context: str = ""
def partial(self):
return not self.is_end_sentence
@dataclass
class TranscriptToken:
"""表示一个转录片段,包含文本和时间信息"""
text: str # 转录的文本内容
t0: int # 开始时间(百分之一秒)
t1: int # 结束时间(百分之一秒)
    def is_punctuation(self):
        """Check whether the text contains a punctuation marker."""
        return REGEX_MARKERS.search(self.text.strip()) is not None
    def is_end(self):
        """Check whether the text is a sentence-ending marker."""
        return SENTENCE_END_PATTERN.search(self.text.strip()) is not None
    def is_pause(self):
        """Check whether the text is a pause marker."""
        return PAUSEE_END_PATTERN.search(self.text.strip()) is not None
    def buffer_index(self) -> int:
        """Convert the end timestamp (centiseconds) to an audio sample index."""
        return int(self.t1 / 100 * SAMPLE_RATE)
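# Example (sketch; assumes config.SAMPLE_RATE = 16000, typical for
# Whisper-style pipelines): a token ending at t1 = 250 centiseconds (2.5 s)
# maps to audio sample 40000:
#
#   tok = TranscriptToken(text="hello", t0=120, t1=250)
#   tok.buffer_index()  # int(250 / 100 * 16000) -> 40000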
@dataclass
class TranscriptChunk:
"""表示一组转录片段,支持分割和比较操作"""
separator: str = "" # 用于连接片段的分隔符
items: list[TranscriptToken] = field(default_factory=list) # 转录片段列表
@staticmethod
def _calculate_similarity(text1: str, text2: str) -> float:
"""计算两段文本的相似度"""
return SequenceMatcher(None, text1, text2).ratio()
    def split_by(self, mode: SplitMode) -> list['TranscriptChunk']:
        """Split the token list at the markers selected by the given mode."""
        if mode == SplitMode.PUNCTUATION:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_punctuation()]
        elif mode == SplitMode.PAUSE:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_pause()]
        elif mode == SplitMode.END:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_end()]
        else:
            raise ValueError(f"Unsupported mode: {mode}")
        # Shift each cut point one index right so the marker stays with the
        # preceding chunk.
        cut_points = [0] + sorted(i + 1 for i in indexes) + [len(self.items)]
        chunks = [
            TranscriptChunk(items=self.items[start:end], separator=self.separator)
            for start, end in zip(cut_points, cut_points[1:])
        ]
        return [ck for ck in chunks if not ck.only_punctuation()]
def get_split_first_rest(self, mode: SplitMode):
chunks = self.split_by(mode)
        first_chunk = chunks[0] if chunks else self
        rest_chunks = chunks[1:] if chunks else None
        return first_chunk, rest_chunks
def puncation_numbers(self) -> int:
"""计算片段中标点符号的数量"""
return sum(1 for seg in self.items if seg.is_punctuation())
    def length(self) -> int:
        """Return the number of tokens in this chunk."""
        return len(self.items)
    def join(self) -> str:
        """Join the token texts into a single string."""
        return self.separator.join(seg.text for seg in self.items)
    def compare(self, chunk: Optional['TranscriptChunk'] = None) -> float:
        """Return the similarity score between this chunk and another."""
        if not chunk:
            return 0.0
        return self._calculate_similarity(self.join(), chunk.join())
    def only_punctuation(self) -> bool:
        return all(seg.is_punctuation() for seg in self.items)
def has_punctuation(self) -> bool:
return any(seg.is_punctuation() for seg in self.items)
def get_buffer_index(self) -> int:
return self.items[-1].buffer_index()
    def is_end_sentence(self) -> bool:
        return self.items[-1].is_end()
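# Example (sketch; assumes "," and "." match REGEX_MARKERS in config):
# splitting at punctuation keeps each marker with its preceding chunk:
#
#   chunk = TranscriptChunk(separator="", items=[
#       TranscriptToken("Hello", 0, 40), TranscriptToken(",", 40, 45),
#       TranscriptToken(" how", 45, 80), TranscriptToken(" are", 80, 110),
#       TranscriptToken(" you", 110, 150), TranscriptToken(".", 150, 155),
#   ])
#   parts = chunk.split_by(SplitMode.PUNCTUATION)
#   [p.join() for p in parts]  # ["Hello,", " how are you."]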
class TranscriptHistory:
"""管理转录片段的历史记录"""
def __init__(self) -> None:
self.history = collections.deque(maxlen=2) # 存储最近的两个片段
def add(self, chunk: TranscriptChunk):
"""添加新的片段到历史记录"""
self.history.appendleft(chunk)
def previous_chunk(self) -> Optional[TranscriptChunk]:
"""获取上一个片段(如果存在)"""
return self.history[1] if len(self.history) == 2 else None
def lastest_chunk(self):
"""获取最后一个片段"""
return self.history[-1]
def clear(self):
self.history.clear()
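# Example (sketch; chunk_a and chunk_b stand for any TranscriptChunk
# instances): appendleft keeps the newest chunk at index 0.
#
#   history = TranscriptHistory()
#   history.add(chunk_a); history.add(chunk_b)
#   history.lastest_chunk() is chunk_b    # True
#   history.previous_chunk() is chunk_a   # True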
class TranscriptBuffer:
"""
管理转录文本的分级结构:临时字符串 -> 短句 -> 完整段落
|-- 已确认文本 --|-- 观察窗口 --|-- 新输入 --|
管理 pending -> line -> paragraph 的缓冲逻辑
"""
def __init__(self, source_lang:str, separator:str):
self._segments: List[str] = collections.deque(maxlen=2) # 确认的完整段落
self._sentences: List[str] = collections.deque() # 当前段落中的短句
self._buffer: str = "" # 当前缓冲中的文本
self._current_seg_id: int = 0
self.source_language = source_lang
self._separator = separator
def get_seg_id(self) -> int:
return self._current_seg_id
    @property
    def current_sentences_length(self) -> int:
        """Length of the queued sentences: words if a separator exists, else characters."""
        count = 0
        for item in self._sentences:
            if self._separator:
                count += len(item.split(self._separator))
            else:
                count += len(item)
        return count
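    # Example (sketch): with separator " " the length is counted in words;
    # with an empty separator (e.g., for Chinese) it is counted in characters:
    #
    #   buf = TranscriptBuffer(source_lang="en", separator=" ")
    #   buf.update_pending_text("hello world,")
    #   buf.commit_line()
    #   buf.current_sentences_length  # 2 (words)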
    def update_pending_text(self, text: str) -> None:
        """Replace the pending buffer text."""
        self._buffer = text
    def commit_line(self) -> None:
        """Commit the pending buffer as a sentence."""
        if self._buffer:
            self._sentences.append(self._buffer)
            self._buffer = ""
    def commit_paragraph(self) -> None:
        """Commit the queued sentences as a full paragraph (e.g., at end of sentence)."""
        count = 0
        current_sentences = []
        while self._sentences:  # (a "count < 20" cap on paragraph size is disabled)
            item = self._sentences.popleft()
            current_sentences.append(item)
            if self._separator:
                count += len(item.split(self._separator))
            else:
                count += len(item)
        if current_sentences:
            self._segments.append("".join(current_sentences))
            logger.debug("=== count to paragraph ===")
            logger.debug(f"push: {current_sentences}")
            logger.debug(f"rest: {self._sentences}")
    def rebuild(self, text):
        """Re-segment text (with separators stripped) into words plus punctuation."""
        output = self.split_and_join(text.replace(self._separator, ""))
        logger.debug("==== rebuild string ====")
        logger.debug(text)
        logger.debug(output)
        return output
    @staticmethod
    def split_and_join(text):
        """Split concatenated English text into words with wordninja, keeping markers attached."""
        tokens = []
        word_buf = ''
        for char in text:
            if char in ALL_MARKERS:
                if word_buf:
                    tokens.extend(wordninja.split(word_buf))
                    word_buf = ''
                tokens.append(char)
            else:
                word_buf += char
        if word_buf:
            tokens.extend(wordninja.split(word_buf))
        output = ''
        for i, token in enumerate(tokens):
            if i == 0 or token in ALL_MARKERS:
                # Markers attach directly to the preceding word; other words
                # get a single leading space.
                output += token
            else:
                output += ' ' + token
        return output
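    # Example (sketch; assumes "," is in config.ALL_MARKERS): wordninja
    # re-segments the glued-together words and punctuation stays attached:
    #
    #   TranscriptBuffer.split_and_join("helloworld,howareyou")
    #   # -> "hello world, how are you"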
    def update_and_commit(self, stable_strings: List[str], remaining_strings: List[str], is_end_sentence=False):
        if self.source_language == "en":
            stable_strings = [self.rebuild(i) for i in stable_strings]
            remaining_strings = [self.rebuild(i) for i in remaining_strings]
        remaining_string = "".join(remaining_strings)
        logger.debug(f"{self.__dict__}")
        if is_end_sentence:
            for stable_str in stable_strings:
                self.update_pending_text(stable_str)
                self.commit_line()
            current_text_len = (
                len(self.current_not_commit_text.split(self._separator))
                if self._separator
                else len(self.current_not_commit_text)
            )
            self.update_pending_text(remaining_string)
            if current_text_len >= config.TEXT_THREHOLD:
                self.commit_paragraph()
                self._current_seg_id += 1
                return True
        else:
            for stable_str in stable_strings:
                self.update_pending_text(stable_str)
                self.commit_line()
            self.update_pending_text(remaining_string)
        return False
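    # Example (sketch; config.TEXT_THREHOLD is assumed to be a small
    # word/character threshold): stable strings are committed as lines, and
    # with is_end_sentence=True a paragraph is flushed once the uncommitted
    # text reaches the threshold; only then does the call return True.
    #
    #   buf = TranscriptBuffer(source_lang="zh", separator="")
    #   buf.update_and_commit(["你好。"], ["今天"], is_end_sentence=True)
    #   # -> True if len("你好。") >= config.TEXT_THREHOLD, else False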
@property
    def un_commit_paragraph(self) -> str:
        """Concatenation of the sentences not yet committed to a paragraph."""
        return "".join(self._sentences)
    @property
    def pending_text(self) -> str:
        """Current pending buffer content."""
        return self._buffer
    @property
    def latest_paragraph(self) -> str:
        """Most recently confirmed paragraph."""
        return self._segments[-1] if self._segments else ""
    @property
    def current_not_commit_text(self) -> str:
        return self.un_commit_paragraph + self.pending_text
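# Example (sketch): text flows pending -> sentence -> paragraph. commit_line()
# queues the pending text; commit_paragraph() drains the queue into a
# confirmed paragraph, leaving newer pending text untouched:
#
#   buf = TranscriptBuffer(source_lang="en", separator=" ")
#   buf.update_pending_text("hello world,")
#   buf.commit_line()
#   buf.update_pending_text("how are you")
#   buf.current_not_commit_text   # "hello world,how are you"
#   buf.commit_paragraph()
#   buf.latest_paragraph          # "hello world,"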
class TranscriptStabilityAnalyzer:
def __init__(self, source_lang, separator) -> None:
        self._transcript_buffer = TranscriptBuffer(source_lang=source_lang, separator=separator)
self._transcript_history = TranscriptHistory()
self._separator = separator
logger.debug(f"Current separator: {self._separator}")
    def merge_chunks(self, chunks: List[TranscriptChunk]) -> List[str]:
        if not chunks:
            return [""]
        return [r.join() for r in chunks if r]
    def analysis(self, current: List[TranscriptToken], buffer_duration: float) -> Iterator[TranscriptResult]:
        current = TranscriptChunk(items=current, separator=self._separator)
        self._transcript_history.add(current)
        prev = self._transcript_history.previous_chunk()
        self._transcript_buffer.update_pending_text(current.join())
        if not prev:  # No history yet: this is a fresh utterance, emit it directly.
            yield TranscriptResult(
                context=self._transcript_buffer.current_not_commit_text,
                seg_id=self._transcript_buffer.get_seg_id()
            )
            return
        if buffer_duration <= 4:
            yield from self._handle_short_buffer(current, prev)
        else:
            yield from self._handle_long_buffer(current)
    def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]:
        curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
        prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)
        if curr_first and prev_first:
            score = curr_first.compare(prev_first)
            # A first segment that repeats across decodes (similarity >= 0.8)
            # and contains punctuation is stable enough to commit.
            if score >= 0.8 and curr_first.has_punctuation():
                yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence())
                return
        yield TranscriptResult(
            seg_id=self._transcript_buffer.get_seg_id(),
            context=self._transcript_buffer.current_not_commit_text
        )
    def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]:
        chunks = curr.split_by(SplitMode.PUNCTUATION)
        if len(chunks) > 1:
            stable, remaining = chunks[:-1], chunks[-1:]
            yield from self._yield_commit_results(
                stable, remaining, is_end_sentence=True  # temporarily hard-coded to True
            )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                context=self._transcript_buffer.current_not_commit_text
            )
def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]:
stable_str_list = [stable_chunk.join()] if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk)
remaining_str_list = self.merge_chunks(remaining_chunks)
frame_cut_index = stable_chunk[-1].get_buffer_index() if isinstance(stable_chunk, list) else stable_chunk.get_buffer_index()
prev_seg_id = self._transcript_buffer.get_seg_id()
commit_paragraph = self._transcript_buffer.update_and_commit(stable_str_list, remaining_str_list, is_end_sentence)
logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")
        if commit_paragraph:
            # A new paragraph was produced: emit it and start a new line.
            yield TranscriptResult(
                seg_id=prev_seg_id,
                cut_index=frame_cut_index,
                context=self._transcript_buffer.latest_paragraph,
                is_end_sentence=True
            )
if (context := self._transcript_buffer.current_not_commit_text.strip()):
yield TranscriptResult(
seg_id=self._transcript_buffer.get_seg_id(),
context=context,
)
else:
yield TranscriptResult(
seg_id=self._transcript_buffer.get_seg_id(),
cut_index=frame_cut_index,
context=self._transcript_buffer.current_not_commit_text,
)
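# Minimal usage sketch (assumes the config module provides the imported
# patterns plus ALL_MARKERS, SAMPLE_RATE, and TEXT_THREHOLD; the token texts
# and timestamps below are illustrative):
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    analyzer = TranscriptStabilityAnalyzer(source_lang="en", separator=" ")
    tokens = [
        TranscriptToken(text="hello", t0=0, t1=50),
        TranscriptToken(text="world.", t0=50, t1=100),
    ]
    # First call has no history, so the text is emitted directly as a partial.
    for result in analyzer.analysis(tokens, buffer_duration=1.0):
        print(result.seg_id, result.partial(), result.context)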