import collections
import logging
from difflib import SequenceMatcher
from itertools import chain
from dataclasses import dataclass, field
from typing import List, Tuple, Optional, Deque, Any, Iterator, Literal
from config import SENTENCE_END_MARKERS, ALL_MARKERS, SENTENCE_END_PATTERN, REGEX_MARKERS, PAUSEE_END_PATTERN, SAMPLE_RATE
from enum import Enum
import wordninja
import config
import re
logger = logging.getLogger("TranscriptionStrategy")


class SplitMode(Enum):
    PUNCTUATION = "punctuation"
    PAUSE = "pause"
    END = "end"



@dataclass
class TranscriptResult:
    seg_id: int = 0
    cut_index: int = 0
    is_end_sentence: bool = False
    context: str = ""

    def partial(self):
        return not self.is_end_sentence

@dataclass
class TranscriptToken:
    """表示一个转录片段,包含文本和时间信息"""
    text: str  # 转录的文本内容
    t0: int  # 开始时间(百分之一秒)
    t1: int  # 结束时间(百分之一秒)

    def is_punctuation(self):
        """Return True if the text contains a punctuation marker."""
        return REGEX_MARKERS.search(self.text.strip()) is not None

    def is_end(self):
        """Return True if the text marks the end of a sentence."""
        return SENTENCE_END_PATTERN.search(self.text.strip()) is not None

    def is_pause(self):
        """Return True if the text marks a pause."""
        return PAUSEE_END_PATTERN.search(self.text.strip()) is not None

    def buffer_index(self) -> int:
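        """Map the end time t1 (centiseconds) to an audio sample index.

        Illustrative example: t1 = 250 (2.5 s) at a 16 kHz SAMPLE_RATE gives
        index 40000; the actual rate is whatever config.SAMPLE_RATE defines.
        """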
        return int(self.t1 / 100 * SAMPLE_RATE)

@dataclass
class TranscriptChunk:
    """表示一组转录片段,支持分割和比较操作"""
    separator: str = ""  # 用于连接片段的分隔符
    items: list[TranscriptToken] = field(default_factory=list)  # 转录片段列表

    @staticmethod
    def _calculate_similarity(text1: str, text2: str) -> float:
        """计算两段文本的相似度"""
        return SequenceMatcher(None, text1, text2).ratio()

    def split_by(self, mode: SplitMode) -> list['TranscriptChunk']:
        """根据文本中的标点符号分割片段列表"""
        if mode == SplitMode.PUNCTUATION:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_punctuation()]
        elif mode == SplitMode.PAUSE:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_pause()]
        elif mode == SplitMode.END:
            indexes = [i for i, seg in enumerate(self.items) if seg.is_end()]
        else:
            raise ValueError(f"Unsupported mode: {mode}")
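        # Illustrative example (assuming "." is matched by REGEX_MARKERS): in
        # PUNCTUATION mode the tokens ["good", "morning", ".", "how"] split
        # into ["good", "morning", "."] and ["how"]; chunks containing only
        # punctuation are dropped below.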

        # Shift each cut point one index to the right so the delimiter stays with the preceding chunk.
        cut_points = [0] + sorted(i + 1 for i in indexes) + [len(self.items)]
        chunks = [
            TranscriptChunk(items=self.items[start:end], separator=self.separator)
            for start, end in zip(cut_points, cut_points[1:])
        ]
        return [ck for ck in chunks if not ck.only_punctuation()]


    def get_split_first_rest(self, mode: SplitMode):
        chunks = self.split_by(mode)
        first_chunk = chunks[0] if chunks else self
        rest_chunks = chunks[1:] if chunks else None
        return first_chunk, rest_chunks

    def puncation_numbers(self) -> int:
        """计算片段中标点符号的数量"""
        return sum(1 for seg in self.items if seg.is_punctuation())

    def length(self) -> int:
        """Return the number of tokens in the chunk."""
        return len(self.items)

    def join(self) -> str:
        """Join the token texts into a single string."""
        return self.separator.join(seg.text for seg in self.items)

    def compare(self, chunk: Optional['TranscriptChunk'] = None) -> float:
        """Return the similarity between this chunk and another chunk."""
        if not chunk:
            return 0.0

        score = self._calculate_similarity(self.join(), chunk.join())
        # logger.debug(f"Compare: {self.join()} vs {chunk.join()} : {score}")
        return score

    def only_punctuation(self) -> bool:
        return all(seg.is_punctuation() for seg in self.items)

    def has_punctuation(self) -> bool:
        return any(seg.is_punctuation() for seg in self.items)

    def get_buffer_index(self) -> int:
        return self.items[-1].buffer_index()

    def is_end_sentence(self) -> bool:
        return self.items[-1].is_end()


class TranscriptHistory:
    """Keeps the most recent transcription chunks."""

    def __init__(self) -> None:
        self.history = collections.deque(maxlen=2)  # keep only the two most recent chunks

    def add(self, chunk: TranscriptChunk):
        """Add a new chunk to the history."""
        self.history.appendleft(chunk)

    def previous_chunk(self) -> Optional[TranscriptChunk]:
        """Return the previous chunk, if one exists."""
        return self.history[1] if len(self.history) == 2 else None

    def lastest_chunk(self):
        """Return the last chunk in the history."""
        return self.history[-1]

    def clear(self):
        self.history.clear()

class TranscriptBuffer:
    """
    Manages the tiered structure of transcribed text: pending string -> sentence -> paragraph.

    |-- committed text --|-- observation window --|-- new input --|

    Handles the pending -> line -> paragraph buffering logic.
    """

    def __init__(self, source_lang: str, separator: str):
        self._segments: Deque[str] = collections.deque(maxlen=2)  # committed paragraphs
        self._sentences: Deque[str] = collections.deque()         # sentences of the current paragraph
        self._buffer: str = ""                                    # text still pending
        self._current_seg_id: int = 0
        self.source_language = source_lang
        self._separator = separator

    def get_seg_id(self) -> int:
        return self._current_seg_id

    @property
    def current_sentences_length(self) -> int:
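        """Word count (separator set) or character count of the buffered sentences.

        Illustrative: with separator " " and sentences ["hello world", "ok"]
        this is 3; with an empty separator (e.g. Chinese) characters are
        counted instead.
        """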
        count = 0
        for item in self._sentences:
            if self._separator:
                count += len(item.split(self._separator))
            else:
                count += len(item)
        return count

    def update_pending_text(self, text: str) -> None:
        """更新临时缓冲字符串"""
        self._buffer = text

    def commit_line(self) -> None:
        """Commit the pending buffer as a sentence."""
        if self._buffer:
            self._sentences.append(self._buffer)
            self._buffer = ""

    def commit_paragraph(self) -> None:
        """Commit the currently buffered sentences as a completed paragraph."""
        count = 0
        current_sentences = []
        while len(self._sentences):  # and count < 20:
            item = self._sentences.popleft()
            current_sentences.append(item)
            if self._separator:
                count += len(item.split(self._separator))
            else:
                count += len(item)
        if current_sentences:
            self._segments.append("".join(current_sentences))
        logger.debug(f"=== count to paragraph ===")
        logger.debug(f"push: {current_sentences}")
        logger.debug(f"rest: {self._sentences}")
        # if self._sentences:
        #     self._segments.append("".join(self._sentences))
        #     self._sentences.clear()

    def rebuild(self, text):
        output = self.split_and_join(text.replace(self._separator, ""))

        logger.debug("==== rebuild string ====")
        logger.debug(text)
        logger.debug(output)

        return output

    @staticmethod
    def split_and_join(text):
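        """Re-segment English text whose separators were stripped.

        wordninja splits concatenated words and punctuation markers are
        re-attached, so e.g. "helloworld." becomes "hello world. " (with a
        trailing space after the marker); the exact word boundaries depend on
        wordninja's dictionary.
        """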
        tokens = []
        word_buf = ''

        for char in text:
            if char in ALL_MARKERS:
                if word_buf:
                    tokens.extend(wordninja.split(word_buf))
                    word_buf = ''
                tokens.append(char)
            else:
                word_buf += char
        if word_buf:
            tokens.extend(wordninja.split(word_buf))

        output = ''
        for i, token in enumerate(tokens):
            if i == 0:
                output += token
            elif token in ALL_MARKERS:
                output += (token + " ")
            else:
                output += ' ' + token
        return output


    def update_and_commit(self, stable_strings: List[str], remaining_strings: List[str], is_end_sentence=False):
        if self.source_language == "en":
            stable_strings = [self.rebuild(i) for i in stable_strings]
            remaining_strings = [self.rebuild(i) for i in remaining_strings]
        remaining_string = "".join(remaining_strings)

        logger.debug(f"{self.__dict__}")
        if is_end_sentence:
            for stable_str in stable_strings:
                self.update_pending_text(stable_str)
                self.commit_line()

            current_text_len = (
                len(self.current_not_commit_text.split(self._separator))
                if self._separator
                else len(self.current_not_commit_text)
            )
            self.update_pending_text(remaining_string)
            if current_text_len >= config.TEXT_THREHOLD:
                self.commit_paragraph()
                self._current_seg_id += 1
                return True
        else:
            for stable_str in stable_strings:
                self.update_pending_text(stable_str)
                self.commit_line()
            self.update_pending_text(remaining_string)
        return False


    @property
    def un_commit_paragraph(self) -> str:
        """Sentences buffered for the current, not yet committed paragraph."""
        return "".join(self._sentences)

    @property
    def pending_text(self) -> str:
        """Current pending buffer content."""
        return self._buffer

    @property
    def latest_paragraph(self) -> str:
        """Most recently committed paragraph."""
        return self._segments[-1] if self._segments else ""

    @property
    def current_not_commit_text(self) -> str:
        return self.un_commit_paragraph + self.pending_text



class TranscriptStabilityAnalyzer:
    def __init__(self, source_lang, separator) -> None:
        self._transcript_buffer = TranscriptBuffer(source_lang=source_lang, separator=separator)
        self._transcript_history = TranscriptHistory()
        self._separator = separator
        logger.debug(f"Current separator: {self._separator}")

    def merge_chunks(self, chunks: List[TranscriptChunk]) -> List[str]:
        if not chunks:
            return [""]
        return [r.join() for r in chunks if r]


    def analysis(self, current: List[TranscriptToken], buffer_duration: float) -> Iterator[TranscriptResult]:
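        """Analyze the latest decoding pass and yield display-ready results.

        Short audio buffers (<= 4 s) are committed only when the leading
        chunk is stable across two consecutive passes; longer buffers are
        force-split at punctuation to keep latency bounded.
        """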
        current = TranscriptChunk(items=current, separator=self._separator)
        self._transcript_history.add(current)

        prev = self._transcript_history.previous_chunk()
        self._transcript_buffer.update_pending_text(current.join())
        if not prev:  # no history yet, so this is a fresh utterance; emit it directly
            yield TranscriptResult(
                context=self._transcript_buffer.current_not_commit_text,
                seg_id=self._transcript_buffer.get_seg_id()
            )
            return

        # yield from self._handle_short_buffer(current, prev)
        if buffer_duration <= 4:
            yield from self._handle_short_buffer(current, prev)
        else:
            yield from self._handle_long_buffer(current)


    def _handle_short_buffer(self, curr: TranscriptChunk, prev: TranscriptChunk) -> Iterator[TranscriptResult]:
        curr_first, curr_rest = curr.get_split_first_rest(SplitMode.PUNCTUATION)
        prev_first, _ = prev.get_split_first_rest(SplitMode.PUNCTUATION)

        # logger.debug("==== Current cut item ====")
        # logger.debug(f"{curr.join()} ")
        # logger.debug(f"{prev.join()}")
        # logger.debug("==========================")

        if curr_first and prev_first:
            score = curr_first.compare(prev_first)
            has_punctuation = curr_first.has_punctuation()
            if score >= 0.8 and has_punctuation:
                yield from self._yield_commit_results(curr_first, curr_rest, curr_first.is_end_sentence())
                return

        yield TranscriptResult(
            seg_id=self._transcript_buffer.get_seg_id(),
            context=self._transcript_buffer.current_not_commit_text
        )


    def _handle_long_buffer(self, curr: TranscriptChunk) -> Iterator[TranscriptResult]:
        chunks = curr.split_by(SplitMode.PUNCTUATION)
        if len(chunks) > 1:
            stable, remaining = chunks[:-1], chunks[-1:]
            # stable_str = self.merge_chunks(stable)
            # remaining_str = self.merge_chunks(remaining)
            yield from self._yield_commit_results(
                stable, remaining, is_end_sentence=True  # temporarily hard-coded to True
            )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                context=self._transcript_buffer.current_not_commit_text
            )


    def _yield_commit_results(self, stable_chunk, remaining_chunks, is_end_sentence: bool) -> Iterator[TranscriptResult]:
        stable_str_list = [stable_chunk.join()] if hasattr(stable_chunk, "join") else self.merge_chunks(stable_chunk)
        remaining_str_list = self.merge_chunks(remaining_chunks)
        frame_cut_index = stable_chunk[-1].get_buffer_index() if isinstance(stable_chunk, list) else stable_chunk.get_buffer_index()

        prev_seg_id = self._transcript_buffer.get_seg_id()
        commit_paragraph = self._transcript_buffer.update_and_commit(stable_str_list, remaining_str_list, is_end_sentence)
        logger.debug(f"current buffer: {self._transcript_buffer.__dict__}")

        if commit_paragraph:
            # a new paragraph was produced; start a new line
            yield TranscriptResult(
                seg_id=prev_seg_id,
                cut_index=frame_cut_index,
                context=self._transcript_buffer.latest_paragraph,
                is_end_sentence=True
            )
            if (context := self._transcript_buffer.current_not_commit_text.strip()):
                yield TranscriptResult(
                    seg_id=self._transcript_buffer.get_seg_id(),
                    context=context,
                )
        else:
            yield TranscriptResult(
                seg_id=self._transcript_buffer.get_seg_id(),
                cut_index=frame_cut_index,
                context=self._transcript_buffer.current_not_commit_text,
            )
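

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's API: it assumes the
    # sibling `config` module is importable and that "." is matched by its
    # punctuation patterns. Token times below are fabricated centiseconds.
    logging.basicConfig(level=logging.DEBUG)
    analyzer = TranscriptStabilityAnalyzer(source_lang="en", separator=" ")
    demo_tokens = [
        TranscriptToken(text="hello", t0=0, t1=50),
        TranscriptToken(text="world", t0=50, t1=100),
        TranscriptToken(text=".", t0=100, t1=110),
    ]
    for result in analyzer.analysis(demo_tokens, buffer_duration=1.0):
        print(result.seg_id, result.is_end_sentence, result.context)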