import gradio as gr import torch import torch.nn.functional as F import librosa import numpy as np from transformers import Wav2Vec2Processor import json # Import your model architecture import torch.nn as nn from transformers import Wav2Vec2Model, Wav2Vec2Config import torch import torch.nn as nn import torch.nn.functional as F from transformers import Wav2Vec2Model class Wav2Vec2MispronunciationDetectorV3(nn.Module): def __init__(self, vocab_size, pad_token_id, blank_token_id, config): super().__init__() # Pretrained wav2vec2 encoder self.encoder = Wav2Vec2Model.from_pretrained(config.pretrained_model) self.encoder.feature_extractor._freeze_parameters() # freeze conv layers # Freeze all encoder parameters (optional) if config.freeze_encoder: for param in self.encoder.parameters(): param.requires_grad = False # Phoneme embedding + phoneme BiLSTM encoder self.phoneme_embedding = nn.Embedding(vocab_size, config.hidden_size) self.phoneme_encoder = nn.LSTM( input_size=config.hidden_size, hidden_size=config.hidden_size // 2, num_layers=1, bidirectional=True, batch_first=True ) # Cross-attention layer self.cross_attention = nn.MultiheadAttention( embed_dim=config.hidden_size, num_heads=8, batch_first=True ) # Layer norms self.ln_audio = nn.LayerNorm(config.hidden_size) self.ln_cross = nn.LayerNorm(config.hidden_size) # Projection layer + dropout self.projector = nn.Linear(config.hidden_size, config.hidden_size) nn.init.xavier_uniform_(self.projector.weight) nn.init.zeros_(self.projector.bias) self.dropout = nn.Dropout(p=0.1) # BiLSTM self.bilstm = nn.LSTM( input_size=config.hidden_size, hidden_size=config.lstm_hidden_size, num_layers=1, bidirectional=True, batch_first=True ) # Classifier self.classifier = nn.Linear(config.lstm_hidden_size * 2, vocab_size) nn.init.xavier_uniform_(self.classifier.weight) nn.init.zeros_(self.classifier.bias) # CTC loss self.ctc_loss = nn.CTCLoss(blank=blank_token_id, zero_infinity=True) self.pad_token_id = pad_token_id def _init_weights(self): """Initialize weights for better training stability""" nn.init.xavier_uniform_(self.fusion.weight) nn.init.zeros_(self.fusion.bias) nn.init.xavier_uniform_(self.classifier.weight) nn.init.zeros_(self.classifier.bias) def forward(self, input_values, phoneme_ref_ids, attention_mask=None): """ Args: input_values: Audio features [batch, time] phoneme_ref_ids: Target phoneme IDs [batch, seq_len] attention_mask: Audio attention mask [batch, time] Returns: logits: [batch, time, vocab_size] attention_weights: cross-attention weights """ # 1. Extract audio features audio_features = self.encoder( input_values, attention_mask=attention_mask ).last_hidden_state # [batch, time, hidden] audio_features = self.ln_audio(audio_features) # 2. Embed and contextualize phonemes phoneme_embed = self.phoneme_embedding(phoneme_ref_ids) phoneme_embed, _ = self.phoneme_encoder(phoneme_embed) # [batch, seq_len, hidden] # 3. Cross-attention: audio attends to phonemes attended_features, attention_weights = self.cross_attention( query=audio_features, key=phoneme_embed, value=phoneme_embed ) # 4. Residual + layer norm x = self.ln_cross(audio_features + attended_features) # 5. Project + dropout x = self.dropout(F.relu(self.projector(x))) # 6. BiLSTM x, _ = self.bilstm(x) # 7. Classifier logits = self.classifier(x) output = { 'logits': logits, 'attention_weights': attention_weights } return output class TransformerBlock(nn.Module): def __init__(self, hidden_size, num_heads, ff_dim, dropout=0.1): super().__init__() self.attn = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_heads, batch_first=True) self.ln1 = nn.LayerNorm(hidden_size) self.ff = nn.Sequential( nn.Linear(hidden_size, ff_dim), nn.ReLU(), nn.Linear(ff_dim, hidden_size) ) self.ln2 = nn.LayerNorm(hidden_size) self.dropout = nn.Dropout(dropout) def forward(self, x, context=None): if context is not None: attn_out, attn_weights = self.attn(query=x, key=context, value=context) else: attn_out, attn_weights = self.attn(query=x, key=x, value=x) x = self.ln1(x + self.dropout(attn_out)) ff_out = self.ff(x) x = self.ln2(x + self.dropout(ff_out)) return x, attn_weights class Wav2Vec2MispronunciationDetector(nn.Module): def __init__(self, vocab_size, pad_token_id, blank_token_id, pretrained_model, hidden_size=1024, num_transformer_layers=2): super().__init__() self.encoder = Wav2Vec2Model.from_pretrained(pretrained_model) self.encoder.feature_extractor._freeze_parameters() self.phoneme_embedding = nn.Embedding(vocab_size, hidden_size) self.phoneme_encoder = nn.TransformerEncoder( nn.TransformerEncoderLayer( d_model=hidden_size, nhead=8, dim_feedforward=hidden_size * 4, batch_first=True, dropout=0.1 ), num_layers=num_transformer_layers ) self.fusion_blocks = nn.ModuleList([ TransformerBlock(hidden_size=hidden_size, num_heads=8, ff_dim=hidden_size*4) for _ in range(num_transformer_layers) ]) self.projector = nn.Linear(hidden_size, hidden_size) nn.init.xavier_uniform_(self.projector.weight) nn.init.zeros_(self.projector.bias) self.dropout = nn.Dropout(0.1) self.classifier = nn.Linear(hidden_size, vocab_size) nn.init.xavier_uniform_(self.classifier.weight) nn.init.zeros_(self.classifier.bias) self.ctc_loss = nn.CTCLoss(blank=blank_token_id, zero_infinity=True) self.pad_token_id = pad_token_id def forward(self, input_values, phoneme_ref_ids, attention_mask=None): audio_features = self.encoder(input_values, attention_mask=attention_mask).last_hidden_state audio_features = F.layer_norm(audio_features, audio_features.shape[-1:]) phoneme_embed = self.phoneme_embedding(phoneme_ref_ids) phoneme_embed = self.phoneme_encoder(phoneme_embed) attn_weights_list = [] x = audio_features for block in self.fusion_blocks: x, attn_weights = block(x, context=phoneme_embed) attn_weights_list.append(attn_weights) x = self.dropout(F.relu(self.projector(x))) logits = self.classifier(x) return { 'logits': logits, 'attention_weights': attn_weights_list } # Global variables device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = None processor = None # Sample sentences with reference phonemes SAMPLE_SENTENCES = { "الحَمدُ لِلَّهِ رَبِّ العالَمينَ": { "text": "الحَمدُ لِلَّهِ رَبِّ العالَمينَ", "phonemes": "< a l H a m d u l i ll a h i r a bb i l E aa l a m ii n a" }, "وَلَمّا سَكَتَ عَن موسَى الغَضَبُ أَخَذَ الأَلواح": { "text": "وَلَمّا سَكَتَ عَن موسَى الغَضَبُ أَخَذَ الأَلواح", "phonemes": "w a l a mm aa s a k a t a E a n m uu s aa l g A D A b u < a x A * a l < a l w aa H" }, "لَنْ أَكْتُبَ رِسَالَةً بِالْإِنْجِلِيزِيَّةِ": { "text": "لَنْ أَكْتُبَ رِسَالَةً بِالْإِنْجِلِيزِيَّة", "phonemes": "l a n < a k t u b a r i s aa l a b i l < i n j i l ii z ii y a" } } def load_model(): """Load the model and processor from HuggingFace Hub.""" global model, processor try: print("Loading config...") config = Wav2Vec2Config.from_pretrained("Haitam03/Wav2Vec2-mispronunciation-detector") print(config) print("Loading processor...") processor = Wav2Vec2Processor.from_pretrained("Haitam03/Wav2Vec2-mispronunciation-detector") print("Loading model...") vocab_size = len(processor.tokenizer) # model = Wav2Vec2MispronunciationDetector( # vocab_size=vocab_size, # pad_token_id=processor.tokenizer.pad_token_id, # blank_token_id=processor.tokenizer.pad_token_id, # pretrained_model="elgeish/wav2vec2-large-xlsr-53-arabic", # hidden_size=1024, # num_transformer_layers=2 # ) model = Wav2Vec2MispronunciationDetectorV3( vocab_size=len(processor.tokenizer), pad_token_id=processor.tokenizer.pad_token_id, blank_token_id=processor.tokenizer.pad_token_id, config=config ) # Load weights model_weights = torch.hub.load_state_dict_from_url( "https://huggingface.co/Haitam03/Wav2Vec2-mispronunciation-detector/resolve/main/best_mode_full_data.pt", map_location=device ) model.load_state_dict(model_weights) model = model.to(device) model.eval() print("✓ Model loaded successfully!") return True except Exception as e: print(f"Error loading model: {e}") return False def compare_phonemes(predicted, reference): """Compare predicted and reference phonemes and highlight differences.""" pred_phones = predicted.strip().split() ref_phones = reference.strip().split() # Align and compare results = [] errors = 0 max_len = max(len(pred_phones), len(ref_phones)) for i in range(max_len): ref = ref_phones[i] if i < len(ref_phones) else "___" pred = pred_phones[i] if i < len(pred_phones) else "___" if ref == pred: results.append(f"✓ {ref}") else: results.append(f"✗ {ref} → {pred}") errors += 1 return results, errors def predict_pronunciation(audio, reference_text, reference_phonemes): """Predict pronunciation from audio.""" if audio is None: return "❌ **Error:** Please record or upload audio first!" if not reference_phonemes or not reference_phonemes.strip(): return "❌ **Error:** Please provide the reference phonemes!" try: # Load audio sampling_rate = 16000 # Gradio 'numpy' type gives a tuple (sr, data) sr, audio_data = audio if sr != sampling_rate: # Resample if needed audio_data = librosa.resample(audio_data.astype(np.float32), orig_sr=sr, target_sr=sampling_rate) # Ensure audio is 1D if audio_data.ndim > 1: audio_data = audio_data.mean(axis=1) # Process audio input_values = processor( audio_data, sampling_rate=sampling_rate, return_tensors="pt" ).input_values.to(device) # Process reference phonemes with processor.as_target_processor(): ref_phoneme_ids = processor(reference_phonemes).input_ids ref_phoneme_ids = torch.tensor([ref_phoneme_ids]).to(device) # Predict with torch.no_grad(): output = model(input_values, ref_phoneme_ids) logits = output['logits'] pred_ids = torch.argmax(logits, dim=-1) # Use batch_decode with group_tokens=False to better handle CTC blanks predicted_phonemes = processor.batch_decode(pred_ids, group_tokens=True)[0] # Clean up potential artifacts predicted_phonemes = predicted_phonemes.replace(processor.tokenizer.pad_token, "").strip() # Compare results comparison, num_errors = compare_phonemes(predicted_phonemes, reference_phonemes) # Calculate accuracy ref_phone_list = reference_phonemes.strip().split() total_phonemes = len(ref_phone_list) accuracy = ((total_phonemes - num_errors) / total_phonemes * 100) if total_phonemes > 0 else 0 # Format results result_text = f"## 📊 Results\n\n" result_text += f"**Accuracy:** {accuracy:.1f}%\n\n" result_text += f"**Errors:** {num_errors} / {total_phonemes}\n\n" result_text += f"---\n\n" result_text += f"### Reference Sentence:\n{reference_text}\n\n" result_text += f"### Reference Phonemes:\n`{reference_phonemes}`\n\n" result_text += f"### Your Pronunciation:\n`{predicted_phonemes}`\n\n" result_text += f"---\n\n" result_text += f"### Detailed Comparison:\n\n" for item in comparison: result_text += f"- {item}\n" return result_text except Exception as e: return f"❌ **An unexpected error occurred:** {str(e)}" def load_sample_sentence(choice): """Load sample sentence and phonemes.""" if choice in SAMPLE_SENTENCES: sample = SAMPLE_SENTENCES[choice] return sample["text"], sample["phonemes"] return "", "" # Load model on startup print("Initializing model...") load_model() # Create Gradio interface with gr.Blocks(title="Arabic Pronunciation Checker", theme=gr.themes.Soft()) as demo: gr.Markdown( """ # 🎤 Arabic Pronunciation Checker Record yourself reading an Arabic sentence and get instant feedback on your pronunciation! ## How to use: 1. Choose a sample sentence or enter your own 2. Provide the reference phonemes (IPA format) 3. Record yourself reading the sentence 4. Click "Check Pronunciation" to see your results """ ) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 📝 Input") sample_dropdown = gr.Dropdown( choices=list(SAMPLE_SENTENCES.keys()), label="Sample Sentences (Optional)", info="Select a pre-loaded sentence", value=list(SAMPLE_SENTENCES.keys())[0] # set default dropdown value ) reference_text = gr.Textbox( label="Reference Sentence (Arabic)", placeholder="Enter the sentence in Arabic...", lines=2, value=SAMPLE_SENTENCES[list(SAMPLE_SENTENCES.keys())[0]]["text"] # default sentence ) reference_phonemes = gr.Textbox( label="Reference Phonemes", placeholder="Enter phonemes separated by spaces (e.g., a s s a l aa m u)", lines=3, info="Use IPA phonetic symbols", value=SAMPLE_SENTENCES[list(SAMPLE_SENTENCES.keys())[0]]["phonemes"] # default phonemes ) audio_input = gr.Audio( sources=["microphone"], type="numpy", label="Record Your Pronunciation" ) check_btn = gr.Button("🎯 Check Pronunciation", variant="primary", size="lg") with gr.Column(scale=1): gr.Markdown("### 📊 Results") output_text = gr.Markdown(label="Results") gr.Markdown( """ --- ### 💡 Tips: - Speak clearly and at a normal pace - Make sure you're in a quiet environment - The model works best with clear audio recordings ### ℹ️ About: This tool uses a Wav2Vec2-based model fine-tuned for Arabic pronunciation detection. Model: [Haitam03/Wav2Vec2-mispronunciation-detector](https://huggingface.co/Haitam03/Wav2Vec2-mispronunciation-detector) """ ) # Event handlers sample_dropdown.change( fn=load_sample_sentence, inputs=[sample_dropdown], outputs=[reference_text, reference_phonemes] ) check_btn.click( fn=predict_pronunciation, inputs=[audio_input, reference_text, reference_phonemes], outputs=[output_text] ) # Launch if __name__ == "__main__": demo.launch()