import torchaudio
import torch
from transformers import (
    WhisperProcessor,
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    AutoModelForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC
)
import numpy as np
import util
# Load the processors and models once at startup. Each entry also records whether
# the model decodes with CTC (vs. seq2seq generation) and whether its output is
# Arabic-script Uyghur.
models_info = {
    "OpenAI-Whisper": {
        "processor": WhisperProcessor.from_pretrained("openai/whisper-small", language="uzbek", task="transcribe"),
        "model": AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small"),
        "ctc_model": False,
        "arabic_script": False
    },
    "Meta-MMS": {
        "processor": AutoProcessor.from_pretrained("facebook/mms-1b-all", target_lang='uig-script_arabic'),
        "model": AutoModelForCTC.from_pretrained("facebook/mms-1b-all", target_lang='uig-script_arabic', ignore_mismatched_sizes=True),
        "ctc_model": True,
        "arabic_script": True
    },
    "Ixxan-FineTuned-Whisper": {
        "processor": AutoProcessor.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
        "model": AutoModelForSpeechSeq2Seq.from_pretrained("ixxan/whisper-small-uyghur-common-voice"),
        "ctc_model": False,
        "arabic_script": False
    },
    "Ixxan-FineTuned-MMS": {
        "processor": Wav2Vec2Processor.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin", target_lang='uig-script_latin'),
        "model": Wav2Vec2ForCTC.from_pretrained("ixxan/wav2vec2-large-mms-1b-uyghur-latin", target_lang='uig-script_latin'),
        "ctc_model": True,
        "arabic_script": False
    },
}
# def transcribe(audio_data, model_id) -> str:
#     if model_id == "Compare All Models":
#         return transcribe_all_models(audio_data)
#     else:
#         return transcribe_with_model(audio_data, model_id)

# def transcribe_all_models(audio_data) -> dict:
#     transcriptions = {}
#     for model_id in models_info.keys():
#         transcriptions[model_id] = transcribe_with_model(audio_data, model_id)
#     return transcriptions
def transcribe(audio_data, model_id) -> tuple:
    # Load user audio
    if isinstance(audio_data, tuple):
        # Microphone input: (sampling_rate, int16 numpy array)
        sampling_rate, audio_input = audio_data
        audio_input = (audio_input / 32768.0).astype(np.float32)
        audio_input = torch.from_numpy(audio_input)  # tensor needed for torchaudio resampling below
    elif isinstance(audio_data, str):
        # File upload: path to an audio file
        audio_input, sampling_rate = torchaudio.load(audio_data)
    else:
        return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data)), None

    # # Check audio duration
    # duration = audio_input.shape[1] / sampling_rate
    # if duration > 10:
    #     return f"<<ERROR: Audio duration ({duration:.2f}s) exceeds 10 seconds. Please upload a shorter audio clip for faster processing.>>", None

    model = models_info[model_id]["model"]
    processor = models_info[model_id]["processor"]
    target_sr = processor.feature_extractor.sampling_rate
    ctc_model = models_info[model_id]["ctc_model"]

    # Resample if needed
    if sampling_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sampling_rate, target_sr)
        audio_input = resampler(audio_input)
        sampling_rate = target_sr

    # Preprocess the audio input
    inputs = processor(audio_input.squeeze(), sampling_rate=sampling_rate, return_tensors="pt")

    # Move model and inputs to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Generate transcription
    with torch.no_grad():
        if ctc_model:
            # CTC models: greedy decode from the logits
            logits = model(**inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.batch_decode(predicted_ids)[0]
        else:
            # Seq2seq models (Whisper): autoregressive generation
            generated_ids = model.generate(inputs["input_features"], max_length=225)
            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Return both Arabic- and Latin-script versions of the transcription
    if models_info[model_id]["arabic_script"]:
        transcription_arabic = transcription
        transcription_latin = util.ug_arab_to_latn(transcription)
    else:  # Latin script output
        transcription_arabic = util.ug_latn_to_arab(transcription)
        transcription_latin = transcription

    print(model_id, transcription_arabic, transcription_latin)
    return transcription_arabic, transcription_latin
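
# --- Hypothetical usage sketch (not part of the original file) ---
# A minimal illustration of how transcribe() could be exposed as a demo UI, assuming
# the Space front-end is Gradio. The component choices and labels below are
# illustrative assumptions, not the Space's actual wiring.
if __name__ == "__main__":
    import gradio as gr  # assumed dependency for this sketch

    demo = gr.Interface(
        fn=transcribe,
        inputs=[
            # type="filepath" hands transcribe() a path string (its file-upload branch)
            gr.Audio(type="filepath", label="Uyghur speech"),
            gr.Dropdown(choices=list(models_info.keys()), label="Model"),
        ],
        outputs=[
            gr.Textbox(label="Transcription (Arabic script)"),
            gr.Textbox(label="Transcription (Latin script)"),
        ],
    )
    demo.launch()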