import os
import re
import tempfile
import torch
import gradio as gr
from faster_whisper import BatchedInferencePipeline, WhisperModel
from pydub import AudioSegment, effects
from pyannote.audio import Pipeline as DiarizationPipeline
import opencc
import spaces  # zeroGPU support
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from termcolor import cprint
import time
import torchaudio
from pyannote.audio.pipelines.utils.hook import ProgressHook
# —————— Model Lists ——————
# Model ids are listed one per line and split on whitespace (ids contain no spaces).
# Order matters: the first entry is the UI's default selection.
WHISPER_MODELS = """
    SoybeanMilk/faster-whisper-Breeze-ASR-25
    asadfgglie/faster-whisper-large-v3-zh-TW
    deepdml/faster-whisper-large-v3-turbo-ct2
    guillaumekln/faster-whisper-tiny
    Systran/faster-whisper-large-v3
    XA9/Belle-faster-whisper-large-v3-zh-punct
    guillaumekln/faster-whisper-medium
    guillaumekln/faster-whisper-small
    guillaumekln/faster-whisper-base
    Luigi/whisper-small-zh_tw-ct2
""".split()
SENSEVOICE_MODELS = """
    FunAudioLLM/SenseVoiceSmall
    funasr/paraformer-zh
""".split()
# —————— Language Options ——————
# Language codes offered to Faster-Whisper; "auto" is a UI pseudo-option that
# the transcriber maps to language=None (automatic detection).
WHISPER_LANGUAGES = """
    zh
    af am ar as az ba be bg bn bo
    br bs ca cs cy da de el en es et
    eu fa fi fo fr gl gu ha haw he hi
    hr ht hu hy id is it ja jw ka kk
    km kn ko la lb ln lo lt lv mg mi
    mk ml mn mr ms mt my ne nl nn no
    oc pa pl ps pt ro ru sa sd si sk
    sl sn so sq sr su sv sw ta te tg
    th tk tl tr tt uk ur uz vi yi yo
    yue auto
""".split()
# Language options passed straight through to SenseVoice's generate(language=...).
SENSEVOICE_LANGUAGES = "zh yue en ja ko auto nospeech".split()
# —————— Caches ——————
# NOTE(review): not referenced anywhere else in this file (Faster-Whisper
# pipelines are cached in _fwhisper_models instead); candidate for removal.
whisper_pipes = {}
# FunASR AutoModel instances keyed by (model_id, device_str); populated by get_sense_model.
sense_models = {}
# Shared pyannote diarization pipeline; created lazily by get_diarization_pipe
# (which rebinds this name via `global`) and reused by both back-ends.
dar_pipe = None
# OpenCC converter with the 's2t' (Simplified -> Traditional Chinese) config,
# applied to every transcribed turn before it is displayed.
converter = opencc.OpenCC('s2t')
# —————— Diarization Formatter ——————
def format_diarization_html(snippets):
    """
    Render diarized transcript snippets as an HTML fragment.

    Each snippet is expected in the form "[SPEAKER] text", as produced by the
    streaming transcribers in this module. A snippet without a leading
    bracketed tag is treated as coming from an unnamed speaker.

    Rendering rules:
      - Snippets whose text is empty after stripping are skipped.
      - Each distinct speaker gets a colour from a fixed palette, cycling
        when there are more speakers than colours.
      - The "SPEAKER: " prefix is shown only when the speaker differs from
        the previously rendered line, and never for unnamed speakers.
      - Speaker labels and text are HTML-escaped, so transcribed text such
        as "a<b" is displayed literally instead of being parsed as markup.

    Args:
        snippets: iterable of snippet strings.
    Returns:
        str: "<div>...</div>" containing one "<p>" element per rendered line
        (just "<div></div>" when nothing is rendered).
    """
    from html import escape

    palette = ["#e74c3c", "#3498db", "#27ae60", "#e67e22", "#9b59b6", "#16a085", "#f1c40f"]
    speaker_colors = {}
    html_lines = []
    last_spk = None
    for s in snippets:
        if s.startswith("[") and "]" in s:
            spk, txt = s[1:].split("]", 1)
            spk, txt = spk.strip(), txt.strip()
        else:
            spk, txt = "", s.strip()
        # hide empty lines
        if not txt:
            continue
        # assign a palette colour the first time a speaker is seen
        if spk not in speaker_colors:
            speaker_colors[spk] = palette[len(speaker_colors) % len(palette)]
        color = speaker_colors[spk]
        # Label only on speaker change; an unnamed speaker would otherwise
        # render as ": text".
        if spk == last_spk or not spk:
            display = txt
        else:
            display = f"{spk}: {txt}"
        last_spk = spk
        html_lines.append(
            f"<p style='margin:0.2rem 0; color:{color};'>{escape(display)}</p>"
        )
    return "<div>" + "".join(html_lines) + "</div>"
# —————— Helpers ——————
# —————— Faster-Whisper Cache & Factory ——————
# Cache of batched Faster-Whisper pipelines keyed by (model_id, device).
_fwhisper_models: dict[tuple[str, str], BatchedInferencePipeline] = {}
def get_fwhisper_model(model_id: str, device: str) -> BatchedInferencePipeline:
    """
    Return a cached BatchedInferencePipeline wrapping WhisperModel(model_id).

    The pipeline is built on first use for each (model_id, device) pair and
    reused on later calls. The compute type is float16 for CUDA devices and
    int8 otherwise.

    Args:
        model_id: Faster-Whisper model identifier (e.g. an entry of WHISPER_MODELS).
        device: device string forwarded to WhisperModel; callers in this file
            pass "cpu" or "cuda".
    Returns:
        BatchedInferencePipeline: the cached pipeline (its .transcribe() is
        what the streaming transcriber calls).
    """
    key = (model_id, device)
    pipe = _fwhisper_models.get(key)
    if pipe is None:
        compute_type = "float16" if device.startswith("cuda") else "int8"
        model = WhisperModel(
            model_id,
            device=device,
            compute_type=compute_type,
        )
        pipe = BatchedInferencePipeline(model=model)
        _fwhisper_models[key] = pipe
    return pipe
def get_sense_model(model_id: str, device_str: str):
    """
    Return the FunASR AutoModel for (model_id, device_str), creating it on first use.

    New models are loaded from the Hugging Face hub ("hf") with the "fsmn-vad"
    voice-activity model attached, and are stored in the module-level
    `sense_models` cache so later calls reuse the same instance.
    """
    cache_key = (model_id, device_str)
    if cache_key in sense_models:
        return sense_models[cache_key]
    sense_models[cache_key] = AutoModel(
        model=model_id,
        vad_model="fsmn-vad",
        vad_kwargs={"max_single_segment_time": 300000},
        device=device_str,
        ban_emo_unk=False,
        hub="hf",
    )
    return sense_models[cache_key]
def get_diarization_pipe():
    """
    Return the shared pyannote speaker-diarization pipeline, loading it on first use.

    Tries "pyannote/speaker-diarization-3.1" first and falls back to
    "pyannote/speaker-diarization@2.1". The Hugging Face token is read from
    HF_TOKEN or HUGGINGFACE_TOKEN. When neither is set, ``use_auth_token=True``
    is passed so that the locally stored Hugging Face login is used.

    pyannote's ``Pipeline.from_pretrained`` may return ``None`` (after printing
    a message) instead of raising, e.g. when a gated checkpoint cannot be
    downloaded. A ``None`` result is treated as a failure. The fallback is then
    tried, and ``None`` is never cached or returned.

    Returns:
        The loaded pipeline, also stored in the module-level ``dar_pipe``.
    Raises:
        RuntimeError: if neither checkpoint produced a pipeline.
        Any exception raised while loading the fallback checkpoint propagates.
    """
    global dar_pipe
    if dar_pipe is not None:
        return dar_pipe

    token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
    auth = token or True
    primary = "pyannote/speaker-diarization-3.1"
    fallback = "pyannote/speaker-diarization@2.1"

    try:
        pipe = DiarizationPipeline.from_pretrained(primary, use_auth_token=auth)
    except Exception as e:
        print(f"Failed to load {primary}: {e}\nFalling back to {fallback}.")
        pipe = None
    else:
        if pipe is None:
            print(f"Failed to load {primary} (from_pretrained returned None).\nFalling back to {fallback}.")

    if pipe is None:
        pipe = DiarizationPipeline.from_pretrained(fallback, use_auth_token=auth)
    if pipe is None:
        raise RuntimeError(
            f"Could not load a speaker-diarization pipeline ({primary} or {fallback}). "
            "Check that HF_TOKEN is set and has access to the gated pyannote models."
        )

    dar_pipe = pipe
    return dar_pipe
# —————— Whisper Transcription ——————
def _transcribe_fwhisper_stream_common(
    model_id,
    language,
    audio_path,
    whisper_multilingual_en,
    enable_punct,
    backend,
    device,
    banner_text,
    banner_color
):
    """
    Core generator for diarized, streaming transcription with Faster-Whisper.

    Runs pyannote speaker diarization over the whole file, then transcribes
    each diarized turn separately. Consecutive turns by the same speaker are
    merged into one snippet. The accumulated transcript is yielded after every
    turn that produced text.

    Args:
        model_id: Whisper model identifier (passed to get_fwhisper_model).
        language: language code, or "auto" to let Whisper detect it.
        audio_path: path to the audio file.
        whisper_multilingual_en: forwarded as ``multilingual=`` to transcribe().
        enable_punct: when True, append "。" to a turn's text unless it already
            ends in terminal punctuation. Otherwise turns are separated by a space.
        backend: Faster-Whisper device string ("cpu" or "cuda").
        device: torch.device the diarizer (and, on CUDA, the waveform) is moved to.
        banner_text: label for the console banner (e.g. "CPU" or "CUDA").
        banner_color: termcolor colour for the console banner.
    Yields:
        ("", format_diarization_html(snippets)) after each turn with non-empty text.
    """
    # Terminal punctuation (full-width CJK and ASCII) that makes an extra "。" unnecessary.
    end_punct = re.compile(r"[。！？…～~.!?]+$")

    pipe = get_fwhisper_model(model_id, backend)
    cprint(f'Whisper (faster-whisper) using {banner_text} [stream]', banner_color)

    # Diarize the whole file first.
    diarizer = get_diarization_pipe()
    waveform, sample_rate = torchaudio.load(audio_path)
    if device.type == 'cuda':
        waveform = waveform.to(device)
    diarizer.to(device)
    with ProgressHook() as hook:
        diary = diarizer({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)

    # Decode the file once. Re-decoding it for every speaker turn made the
    # loop cost proportional to (turns x file length).
    full_audio = AudioSegment.from_file(audio_path)

    snippets = []
    for turn, _, speaker in diary.itertracks(yield_label=True):
        start_ms = int(turn.start * 1000)
        end_ms = int(turn.end * 1000)
        if end_ms <= start_ms:
            continue  # sub-millisecond turn: nothing to transcribe
        segment = effects.normalize(full_audio[start_ms:end_ms])

        fd, tmp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            # pydub's export() returns the file object it opened; close it so
            # the handle is not leaked.
            segment.export(tmp_path, format="wav").close()
            segments, _ = pipe.transcribe(
                tmp_path,
                beam_size=3,
                best_of=3,
                language=None if language == "auto" else language,
                vad_filter=True,
                batch_size=16,
                multilingual=whisper_multilingual_en,
            )
            # Segments come back as a generator; consume them before the temp
            # file is removed.
            raw_text = "".join(s.text for s in segments).strip()
        finally:
            os.unlink(tmp_path)

        text = converter.convert(raw_text)
        if not text:
            continue
        tag = f"[{speaker}]"
        if enable_punct and not end_punct.search(text):
            text = f'{text}。'
        else:
            text = f'{text} '
        if snippets and snippets[-1].startswith(tag):
            # Same speaker as the previous snippet: extend it.
            prev_text = snippets[-1].split('] ', 1)[1]
            snippets[-1] = f"{tag} {prev_text}{text}"
        else:
            snippets.append(f"{tag} {text}")
        yield "", format_diarization_html(snippets)
def _transcribe_fwhisper_cpu_stream(
    model_id,
    language,
    audio_path,
    whisper_multilingual_en,
    enable_punct
):
    """
    Stream diarized Faster-Whisper transcription on the CPU.

    Thin adapter over _transcribe_fwhisper_stream_common that fixes the
    backend to "cpu", the torch device to CPU and the console banner to a
    red "CPU" label.
    """
    cpu_settings = dict(
        backend="cpu",
        device=torch.device('cpu'),
        banner_text="CPU",
        banner_color="red",
    )
    yield from _transcribe_fwhisper_stream_common(
        model_id=model_id,
        language=language,
        audio_path=audio_path,
        whisper_multilingual_en=whisper_multilingual_en,
        enable_punct=enable_punct,
        **cpu_settings,
    )
@spaces.GPU
def _transcribe_fwhisper_gpu_stream(
    model_id,
    language,
    audio_path,
    whisper_multilingual_en,
    enable_punct
):
    """
    Stream diarized Faster-Whisper transcription on CUDA.

    Decorated with ``spaces.GPU`` for ZeroGPU scheduling. Delegates to
    _transcribe_fwhisper_stream_common with backend "cuda", a CUDA torch
    device and a green "CUDA" console banner.
    """
    gpu_settings = dict(
        backend="cuda",
        device=torch.device('cuda'),
        banner_text="CUDA",
        banner_color="green",
    )
    yield from _transcribe_fwhisper_stream_common(
        model_id=model_id,
        language=language,
        audio_path=audio_path,
        whisper_multilingual_en=whisper_multilingual_en,
        enable_punct=enable_punct,
        **gpu_settings,
    )
def transcribe_fwhisper_stream(model_id, language, audio_path, device_sel, whisper_multilingual_en, enable_punct):
    """
    Gradio entry point for Faster-Whisper: choose the GPU or CPU stream generator.

    The GPU path is used only when ``device_sel == "GPU"`` and CUDA is
    available; otherwise the CPU path runs. Each (raw, html) pair yielded by
    the chosen generator is passed through unchanged.
    """
    use_gpu = device_sel == "GPU" and torch.cuda.is_available()
    runner = _transcribe_fwhisper_gpu_stream if use_gpu else _transcribe_fwhisper_cpu_stream
    yield from runner(model_id, language, audio_path, whisper_multilingual_en, enable_punct)
# —————— SenseVoice Transcription ——————
def _transcribe_sense_stream_common(
    model_id: str,
    language: str,
    audio_path: str,
    enable_punct: bool,
    backend: str,
    device: torch.device,
    banner_text: str,
    banner_color: str
):
    """
    Core generator for diarized, streaming transcription with FunASR models.

    Runs pyannote speaker diarization over the whole file, then transcribes
    each diarized turn with ``model.generate``. Consecutive turns by the same
    speaker are merged into one snippet. The accumulated transcript is yielded
    after every diarized turn, including turns that produced no text.

    Args:
        model_id: FunASR model identifier (e.g. "FunAudioLLM/SenseVoiceSmall").
        language: language code passed to generate().
        audio_path: path to the audio file.
        enable_punct: passed as ``use_itn``. When True, "。" is appended to a
            turn unless it already ends in terminal punctuation. When False,
            all non-word, non-space characters are stripped from the text.
        backend: device spec for get_sense_model ("cpu" or "cuda:0").
        device: torch.device for the waveform and diarizer.
        banner_text: label for the console banner.
        banner_color: termcolor colour for the console banner.
    Yields:
        ("", format_diarization_html(snippets)) after each diarized turn.
    """
    # Terminal punctuation (full-width CJK and ASCII) that makes an extra "。" unnecessary.
    end_punct = re.compile(r"[。！？…～~.!?]+$")

    model = get_sense_model(model_id, backend)
    cprint(f'SenseVoiceSmall using {banner_text} [stream]', banner_color)

    # Diarize the whole file first.
    diarizer = get_diarization_pipe()
    diarizer.to(device)
    waveform, sample_rate = torchaudio.load(audio_path)
    if device.type == 'cuda':
        waveform = waveform.to(device)
    with ProgressHook() as hook:
        diary = diarizer({"waveform": waveform, "sample_rate": sample_rate}, hook=hook)

    # Decode the file once instead of once per speaker turn.
    full_audio = AudioSegment.from_file(audio_path)

    snippets = []
    # NOTE(review): one cache dict is shared by every generate() call in this
    # run. Confirm that is intended for non-streaming models.
    cache = {}
    for turn, _, speaker in diary.itertracks(yield_label=True):
        start_ms = int(turn.start * 1000)
        end_ms = int(turn.end * 1000)
        segment = full_audio[start_ms:end_ms]

        fd, tmp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        try:
            # pydub's export() returns the file object it opened; close it.
            segment.export(tmp_path, format="wav").close()
            try:
                segs = model.generate(
                    input=tmp_path,
                    cache=cache,
                    language=language,
                    use_itn=enable_punct,
                    batch_size_s=300
                )
            except Exception as e:
                # Best effort: report and skip this turn rather than abort the transcript.
                cprint(f'Error: {e}', 'red')
                segs = None
        finally:
            os.unlink(tmp_path)

        if segs:
            txt = rich_transcription_postprocess(segs[0]['text'])
            # Remove all punctuation if disabled
            if not enable_punct:
                txt = re.sub(r"[^\w\s]", "", txt)
            if txt:
                txt = converter.convert(txt)
                tag = f"[{speaker}]"
                if enable_punct and not end_punct.search(txt):
                    txt = f'{txt}。'
                else:
                    txt = f'{txt} '
                if snippets and snippets[-1].startswith(tag):
                    # Same speaker as the previous snippet: extend it.
                    prev_text = snippets[-1].split('] ', 1)[1]
                    snippets[-1] = f"{tag} {prev_text}{txt}"
                else:
                    snippets.append(f"{tag} {txt}")
        # Yield accumulated HTML after every turn.
        yield "", format_diarization_html(snippets)
def _transcribe_sense_cpu_stream(
    model_id: str,
    language: str,
    audio_path: str,
    enable_punct: bool
):
    """
    Stream diarized SenseVoice transcription on the CPU.

    Thin adapter over _transcribe_sense_stream_common with backend "cpu", a
    CPU torch device and a red "CPU" console banner.
    """
    cpu_settings = {
        "backend": "cpu",
        "device": torch.device('cpu'),
        "banner_text": "CPU",
        "banner_color": "red",
    }
    yield from _transcribe_sense_stream_common(
        model_id,
        language,
        audio_path,
        enable_punct,
        **cpu_settings
    )
@spaces.GPU(duration=120)
def _transcribe_sense_gpu_stream(
    model_id: str,
    language: str,
    audio_path: str,
    enable_punct: bool
):
    """
    Stream diarized SenseVoice transcription on CUDA.

    Decorated with ``spaces.GPU(duration=120)`` for ZeroGPU scheduling.
    Delegates to _transcribe_sense_stream_common with backend "cuda:0", a
    CUDA torch device and a green "CUDA" console banner.
    """
    gpu_settings = {
        "backend": "cuda:0",
        "device": torch.device('cuda'),
        "banner_text": "CUDA",
        "banner_color": "green",
    }
    yield from _transcribe_sense_stream_common(
        model_id,
        language,
        audio_path,
        enable_punct,
        **gpu_settings
    )
def transcribe_sense_steam(
    model_id: str,
    language: str,
    audio_path: str,
    enable_punct: bool,
    device_sel: str,
):
    """
    Gradio entry point for SenseVoice: choose the GPU or CPU stream generator.

    The GPU path is used only when ``device_sel == "GPU"`` and CUDA is
    available; otherwise the CPU path runs. Yields are passed through unchanged.
    """
    # NOTE: "steam" (sic) is the name the Gradio wiring references; keep it.
    wants_gpu = device_sel == "GPU"
    if not (wants_gpu and torch.cuda.is_available()):
        yield from _transcribe_sense_cpu_stream(model_id, language, audio_path, enable_punct)
        return
    yield from _transcribe_sense_gpu_stream(model_id, language, audio_path, enable_punct)
# —————— Gradio UI ——————
# Extra CSS for the diarized-transcript HTML panels (they carry elem_classes=["diar"]).
DEMO_CSS = """
.diar {
    padding: 0.5rem;
    color: #f1f1f1;
    font-family: monospace;
    font-size: 0.9rem;
}
"""
Demo = gr.Blocks(css=DEMO_CSS)
with Demo:
    gr.Markdown("## Faster-Whisper vs. SenseVoice")
    # Shared input for both back-ends; handlers receive a file path (type="filepath").
    audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
    # NOTE(review): example files are referenced by relative path; confirm they
    # ship alongside the app.
    examples = gr.Examples(
        examples=[["interview.mp3"], ["news.mp3"], ["meeting.mp3"]],
        inputs=[audio_input],
        label="Example Audio Files"
    )
    # ────────────────────────────────────────────────────────────────
    # 1) CONTROL PANELS (side-by-side: Faster-Whisper left, SenseVoice right)
    # "GPU" falls back to CPU in the dispatchers when CUDA is unavailable.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Faster-Whisper ASR")
            whisper_dd      = gr.Dropdown(choices=WHISPER_MODELS, value=WHISPER_MODELS[0], label="Whisper Model")
            whisper_lang    = gr.Dropdown(choices=WHISPER_LANGUAGES, value="auto",      label="Whisper Language")
            device_radio    = gr.Radio(choices=["GPU","CPU"], value="GPU", label="Device")
            whisper_punct_chk = gr.Checkbox(label="Enable Punctuation", value=True)
            # Forwarded to faster-whisper's transcribe(multilingual=...).
            whisper_multilingual_en = gr.Checkbox(label="Multilingual", value=False)
            btn_w           = gr.Button("Transcribe with Faster-Whisper")
        with gr.Column():
            gr.Markdown("### FunASR SenseVoice ASR")
            sense_dd         = gr.Dropdown(choices=SENSEVOICE_MODELS, value=SENSEVOICE_MODELS[0], label="SenseVoice Model")
            sense_lang       = gr.Dropdown(choices=SENSEVOICE_LANGUAGES, value="auto", label="SenseVoice Language")
            device_radio_s   = gr.Radio(choices=["GPU","CPU"], value="GPU",     label="Device")
            sense_punct_chk        = gr.Checkbox(label="Enable Punctuation", value=True)
            btn_s            = gr.Button("Transcribe with SenseVoice")
    # ────────────────────────────────────────────────────────────────
    # 2) SHARED TRANSCRIPT ROW (aligned side-by-side)
    # The hidden "Raw Transcript" textboxes receive the first value of each
    # yielded pair (always ""); the HTML panels receive the rendered transcript.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Faster-Whisper Output")
            out_w   = gr.Textbox(label="Raw Transcript", visible=False)
            out_w_d = gr.HTML(label="Diarized Transcript", elem_classes=["diar"])
        with gr.Column():
            gr.Markdown("### SenseVoice Output")
            out_s   = gr.Textbox(label="Raw Transcript", visible=False)
            out_s_d = gr.HTML(label="Diarized Transcript", elem_classes=["diar"])
    # ────────────────────────────────────────────────────────────────
    # 3) WIRING UP BUTTONS
    # Input order must match each handler's positional parameters:
    # transcribe_fwhisper_stream(model_id, language, audio_path, device_sel,
    #                            whisper_multilingual_en, enable_punct)
    btn_w.click(
        fn=transcribe_fwhisper_stream,
        inputs=[whisper_dd, whisper_lang, audio_input, device_radio, whisper_multilingual_en, whisper_punct_chk],
        outputs=[out_w, out_w_d]
    )
    # transcribe_sense_steam(model_id, language, audio_path, enable_punct, device_sel):
    # unlike the Whisper handler, the device selector comes last.
    btn_s.click(
        fn=transcribe_sense_steam,
        inputs=[sense_dd, sense_lang, audio_input, sense_punct_chk, device_radio_s],
        outputs=[out_s, out_s_d]
    )
if __name__ == "__main__":
    Demo.launch()