liumaolin committed on
Commit
3a0633a
·
1 Parent(s): b67c020

Integrate FunASR.

Browse files
transcribe/helpers/funasr.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import uuid
3
+ from logging import getLogger
4
+
5
+ import numpy as np
6
+ from funasr import AutoModel
7
+ import soundfile as sf
8
+
9
+ import config
10
+
11
+ logger = getLogger(__name__)
12
+
13
+
14
class FunASR:
    """Thin wrapper around FunASR's ``AutoModel`` for speech transcription.

    Loads the ``paraformer-zh`` ASR model together with a VAD model and a
    punctuation model, and optionally runs a warm-up pass so the first real
    request does not pay the cold-start cost.
    """

    def __init__(self, source_lange: str = 'en', warmup: bool = True) -> None:
        # NOTE(review): parameter keeps the original spelling "source_lange"
        # because callers (e.g. FunASRChinese) pass it by keyword.
        self.source_lange = source_lange

        self.model = AutoModel(
            model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc"
        )
        if warmup:
            self.warmup()

    def warmup(self, warmup_steps: int = 1):
        """Run ``warmup_steps`` inference passes over a bundled sample clip."""
        warmup_soundfile = f"{config.ASSERT_DIR}/jfk.flac"
        for _ in range(warmup_steps):
            self.model.generate(input=warmup_soundfile)

    def transcribe(self, audio_buffer: bytes, language):
        """Transcribe a raw audio buffer.

        Args:
            audio_buffer: raw little-endian float32 PCM samples
                (written out at 16 kHz, so 16 kHz input is assumed).
            language: accepted for interface compatibility with the Whisper
                helper; the loaded model is fixed, so it is currently unused.

        Returns:
            The FunASR result (a list), or an empty list on failure.
        """
        audio_frames = np.frombuffer(audio_buffer, dtype=np.float32)
        # NOTE(review): debug dump of every request to disk — this grows
        # without bound; consider removing or gating before production.
        sf.write(f'{config.ASSERT_DIR}/{time.time()}.wav', audio_frames, samplerate=16000)
        try:
            return self.model.generate(input=audio_frames)
        except Exception as e:
            # logger.exception preserves the traceback; logger.error(e)
            # logged only the message and discarded the stack.
            logger.exception("FunASR transcription failed: %s", e)
            return []
transcribe/pipelines/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
 
 
2
  from .pipe_translate import TranslatePipe, Translate7BPipe
3
- from .pipe_whisper import WhisperPipe, WhisperChinese
4
  from .pipe_vad import VadPipe
5
- from .base import MetaItem
 
 
1
 
2
+ from .base import MetaItem
3
  from .pipe_translate import TranslatePipe, Translate7BPipe
 
4
  from .pipe_vad import VadPipe
5
+ from .pipe_whisper import WhisperPipe, WhisperChinese
6
+ from .pipe_funasr import FunASRPipe
transcribe/pipelines/pipe_funasr.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unicodedata
2
+
3
+ from .base import MetaItem, BasePipe, Segment
4
+ from ..helpers.funasr import FunASR
5
+
6
+
7
class FunASRPipe(BasePipe):
    """Pipeline stage that transcribes audio with a shared FunASR model."""

    # Class-level singleton so all pipe instances share one loaded model.
    funasr = None

    @classmethod
    def init(cls):
        """Lazily create the shared FunASR model (idempotent)."""
        if cls.funasr is None:
            cls.funasr = FunASR()

    def _build_segments(self, items, get_text, get_start, get_end):
        """Normalize a list of result items into (segments, joined_text).

        ``get_text``/``get_start``/``get_end`` are accessors so the same
        loop serves both dict-shaped and attribute-shaped FunASR results.
        """
        segments = []
        texts = []
        for item in items:
            text = get_text(item)
            segments.append(Segment(
                t0=get_start(item),
                t1=get_end(item),
                text=self.filter_chinese_printable(text),
            ))
            texts.append(text)
        return segments, "".join(texts)

    def process(self, in_data: MetaItem) -> MetaItem:
        """Transcribe ``in_data.audio`` and fill segments/transcribe_content.

        FunASR may return a list of dicts (with 'text'/'start'/'end' keys),
        a plain string, or a list of objects carrying a ``text`` attribute;
        all three shapes are normalized here. The audio payload is cleared
        before the item is passed downstream.
        """
        audio_data = in_data.audio
        source_language = in_data.source_language
        result = self.funasr.transcribe(audio_data, source_language)

        # isinstance(dict) guard: the original `'text' in result[0]` raised
        # TypeError on objects without __contains__ and matched substrings
        # on plain strings.
        if (result and isinstance(result, list)
                and isinstance(result[0], dict) and 'text' in result[0]):
            segments, content = self._build_segments(
                result,
                lambda it: it.get('text', ''),
                lambda it: it.get('start', 0),
                lambda it: it.get('end', 0),
            )
            in_data.segments = segments
            in_data.transcribe_content = content
        elif isinstance(result, str):
            # Single plain-text result: one zero-length segment.
            in_data.transcribe_content = result
            in_data.segments = [
                Segment(t0=0, t1=0, text=self.filter_chinese_printable(result))
            ]
        elif result and isinstance(result, list) and hasattr(result[0], 'text'):
            # Object-shaped results; accept either start/end or t0/t1 names.
            segments, content = self._build_segments(
                result,
                lambda it: it.text,
                lambda it: getattr(it, 'start', 0) or getattr(it, 't0', 0),
                lambda it: getattr(it, 'end', 0) or getattr(it, 't1', 0),
            )
            in_data.segments = segments
            in_data.transcribe_content = content
        else:
            # Empty or unrecognized result shape: emit nothing.
            in_data.transcribe_content = ""
            in_data.segments = []

        in_data.audio = b""
        return in_data

    def filter_chinese_printable(self, s):
        """Drop control characters (Unicode category 'Cc') and trim whitespace."""
        # The original encode('utf-8')/decode round-trip was a no-op for any
        # valid str; iterate the string directly.
        return ''.join(ch for ch in s if unicodedata.category(ch) != 'Cc').strip()
67
+
68
+
69
class FunASRChinese(FunASRPipe):
    """FunASR transcription pipe preconfigured for Chinese input."""

    @classmethod
    def init(cls):
        # Create the shared model once, pinned to Chinese.
        if cls.funasr is None:
            cls.funasr = FunASR(source_lange='zh')
transcribe/pipelines/pipe_whisper.py CHANGED
@@ -1,19 +1,17 @@
1
-
2
  import unicodedata
 
3
  from .base import MetaItem, BasePipe, Segment
4
  from ..helpers.whisper import WhisperCPP
5
 
 
6
  class WhisperPipe(BasePipe):
7
  whisper = None
8
 
9
-
10
-
11
  @classmethod
12
  def init(cls):
13
  if cls.whisper is None:
14
  # cls.zh_whisper = WhisperCPP(source_lange='zh')
15
  cls.whisper = WhisperCPP()
16
-
17
 
18
  def process(self, in_data: MetaItem) -> MetaItem:
19
  audio_data = in_data.audio
@@ -32,7 +30,6 @@ class WhisperPipe(BasePipe):
32
  if unicodedata.category(char) != 'Cc': # 不可打印字符的分类为 'Cc'
33
  printable.append(char)
34
  return ''.join(printable).strip()
35
-
36
 
37
 
38
  class WhisperChinese(WhisperPipe):
 
 
1
  import unicodedata
2
+
3
  from .base import MetaItem, BasePipe, Segment
4
  from ..helpers.whisper import WhisperCPP
5
 
6
+
7
  class WhisperPipe(BasePipe):
8
  whisper = None
9
 
 
 
10
  @classmethod
11
  def init(cls):
12
  if cls.whisper is None:
13
  # cls.zh_whisper = WhisperCPP(source_lange='zh')
14
  cls.whisper = WhisperCPP()
 
15
 
16
  def process(self, in_data: MetaItem) -> MetaItem:
17
  audio_data = in_data.audio
 
30
  if unicodedata.category(char) != 'Cc': # 不可打印字符的分类为 'Cc'
31
  printable.append(char)
32
  return ''.join(printable).strip()
 
33
 
34
 
35
  class WhisperChinese(WhisperPipe):
transcribe/translatepipes.py CHANGED
@@ -1,19 +1,17 @@
1
- from transcribe.pipelines import WhisperPipe, TranslatePipe, MetaItem, WhisperChinese, Translate7BPipe
2
- import multiprocessing as mp
3
- import config
4
 
5
 
6
  class TranslatePipes:
7
  def __init__(self) -> None:
8
-
9
- # self.whisper_input_q = mp.Queue()
10
  # self.translate_input_q = mp.Queue()
11
  # self.result_queue = mp.Queue()
12
-
13
  # whisper 转录
14
  self._whisper_pipe_en = self._launch_process(WhisperPipe())
15
  self._whisper_pipe_zh = self._launch_process(WhisperChinese())
16
-
 
17
  # llm 翻译
18
  # self._translate_pipe = self._launch_process(TranslatePipe())
19
 
@@ -23,7 +21,7 @@ class TranslatePipes:
23
 
24
  # def reset(self):
25
  # self._vad_pipe.reset()
26
-
27
  def _launch_process(self, process_obj):
28
  process_obj.daemon = True
29
  process_obj.start()
@@ -31,56 +29,56 @@ class TranslatePipes:
31
 
32
  def wait_ready(self):
33
  self._whisper_pipe_zh.wait()
 
34
  self._whisper_pipe_en.wait()
35
  # self._translate_pipe.wait()
36
  # self._vad_pipe.wait()
37
  self._translate_7b_pipe.wait()
38
-
39
  def translate(self, text, src_lang, dst_lang) -> MetaItem:
40
  item = MetaItem(
41
  transcribe_content=text,
42
- source_language=src_lang,
43
  destination_language=dst_lang)
44
  self._translate_pipe.input_queue.put(item)
45
  return self._translate_pipe.output_queue.get()
46
-
47
 
48
  def translate_large(self, text, src_lang, dst_lang) -> MetaItem:
49
  item = MetaItem(
50
  transcribe_content=text,
51
- source_language=src_lang,
52
  destination_language=dst_lang)
53
  self._translate_7b_pipe.input_queue.put(item)
54
  return self._translate_7b_pipe.output_queue.get()
55
-
56
- def get_whisper_model(self, lang:str='en'):
57
  if lang == 'zh':
58
  return self._whisper_pipe_zh
59
  return self._whisper_pipe_en
60
-
61
 
62
- def transcrible(self, audio_buffer:bytes, src_lang: str) -> MetaItem:
63
- whisper_model = self.get_whisper_model(src_lang)
 
 
 
 
 
64
  item = MetaItem(audio=audio_buffer, source_language=src_lang)
65
- whisper_model.input_queue.put(item)
66
- return whisper_model.output_queue.get()
67
-
68
- def voice_detect(self, audio_buffer:bytes) -> MetaItem:
69
  item = MetaItem(source_audio=audio_buffer)
70
  self._vad_pipe.input_queue.put(item)
71
  return self._vad_pipe.output_queue.get()
72
 
73
-
74
 
75
  if __name__ == "__main__":
76
  import soundfile
 
77
  tp = TranslatePipes()
78
  # result = tp.translate("你好,今天天气怎么样?", src_lang="zh", dst_lang="en")
79
  mel, _, = soundfile.read("assets/jfk.flac")
80
  # result = tp.transcrible(mel, 'en')
81
  result = tp.voice_detect(mel)
82
  print(result)
83
-
84
-
85
-
86
-
 
1
+ from transcribe.pipelines import WhisperPipe, MetaItem, WhisperChinese, Translate7BPipe, FunASRPipe
 
 
2
 
3
 
4
  class TranslatePipes:
5
  def __init__(self) -> None:
6
+ # self.whisper_input_q = mp.Queue()
 
7
  # self.translate_input_q = mp.Queue()
8
  # self.result_queue = mp.Queue()
9
+
10
  # whisper 转录
11
  self._whisper_pipe_en = self._launch_process(WhisperPipe())
12
  self._whisper_pipe_zh = self._launch_process(WhisperChinese())
13
+ self._funasr_pipe = self._launch_process(FunASRPipe())
14
+
15
  # llm 翻译
16
  # self._translate_pipe = self._launch_process(TranslatePipe())
17
 
 
21
 
22
  # def reset(self):
23
  # self._vad_pipe.reset()
24
+
25
  def _launch_process(self, process_obj):
26
  process_obj.daemon = True
27
  process_obj.start()
 
29
 
30
  def wait_ready(self):
31
  self._whisper_pipe_zh.wait()
32
+ self._funasr_pipe.wait()
33
  self._whisper_pipe_en.wait()
34
  # self._translate_pipe.wait()
35
  # self._vad_pipe.wait()
36
  self._translate_7b_pipe.wait()
37
+
38
  def translate(self, text, src_lang, dst_lang) -> MetaItem:
39
  item = MetaItem(
40
  transcribe_content=text,
41
+ source_language=src_lang,
42
  destination_language=dst_lang)
43
  self._translate_pipe.input_queue.put(item)
44
  return self._translate_pipe.output_queue.get()
 
45
 
46
  def translate_large(self, text, src_lang, dst_lang) -> MetaItem:
47
  item = MetaItem(
48
  transcribe_content=text,
49
+ source_language=src_lang,
50
  destination_language=dst_lang)
51
  self._translate_7b_pipe.input_queue.put(item)
52
  return self._translate_7b_pipe.output_queue.get()
53
+
54
+ def get_whisper_model(self, lang: str = 'en'):
55
  if lang == 'zh':
56
  return self._whisper_pipe_zh
57
  return self._whisper_pipe_en
 
58
 
59
+ def get_transcription_model(self, lang: str = 'en'):
60
+ if lang == 'zh':
61
+ return self._funasr_pipe
62
+ return self._whisper_pipe_en
63
+
64
+ def transcrible(self, audio_buffer: bytes, src_lang: str) -> MetaItem:
65
+ transcription_model = self.get_transcription_model(src_lang)
66
  item = MetaItem(audio=audio_buffer, source_language=src_lang)
67
+ transcription_model.input_queue.put(item)
68
+ return transcription_model.output_queue.get()
69
+
70
+ def voice_detect(self, audio_buffer: bytes) -> MetaItem:
71
  item = MetaItem(source_audio=audio_buffer)
72
  self._vad_pipe.input_queue.put(item)
73
  return self._vad_pipe.output_queue.get()
74
 
 
75
 
76
if __name__ == "__main__":
    # Manual smoke test: run voice-activity detection on the sample clip.
    import soundfile

    pipes = TranslatePipes()
    # pipes.translate("你好,今天天气怎么样?", src_lang="zh", dst_lang="en")
    audio, _ = soundfile.read("assets/jfk.flac")
    # pipes.transcrible(audio, 'en')
    detection = pipes.voice_detect(audio)
    print(detection)