Spaces:
Runtime error
Runtime error
Introduce uncertainty to word error with PER threshold
Browse files
wav2vecasr/MispronounciationDetector.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from pandas.core.construction import T
|
| 2 |
import torch
|
| 3 |
import jiwer
|
|
|
|
| 4 |
|
| 5 |
class MispronounciationDetector:
|
| 6 |
def __init__(self, l2_phoneme_recogniser, g2p, device):
|
|
@@ -8,18 +9,19 @@ class MispronounciationDetector:
|
|
| 8 |
self.g2p = g2p
|
| 9 |
self.device = device
|
| 10 |
|
| 11 |
-
def detect(self, audio, text):
|
| 12 |
l2_phones = self.phoneme_asr_model.get_l2_phoneme_sequence(audio)
|
|
|
|
| 13 |
native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
|
| 14 |
standardised_native_speaker_phones = self.phoneme_asr_model.standardise_g2p_phoneme_sequence(native_speaker_phones)
|
| 15 |
-
raw_info = self.get_mispronounciation_output(text, l2_phones, standardised_native_speaker_phones)
|
| 16 |
return raw_info
|
| 17 |
|
| 18 |
def get_native_speaker_phoneme_sequence(self, text):
|
| 19 |
phonemes = self.g2p(text)
|
| 20 |
return phonemes
|
| 21 |
|
| 22 |
-
def get_mispronounciation_output(self, text, pred_phones, org_label_phones):
|
| 23 |
"""
|
| 24 |
Aligns the predicted phones from the L2 speaker and the expected native speaker phone to get the errors
|
| 25 |
:param text: original words read by the user
|
|
@@ -101,7 +103,7 @@ class MispronounciationDetector:
|
|
| 101 |
# get mispronounced words based on if there are phoneme errors present in the phonemes of that word
|
| 102 |
aligned_word_error_output = ""
|
| 103 |
words = text.split(" ")
|
| 104 |
-
word_error_bool = self.get_mispronounced_words(error_bool)
|
| 105 |
wer = sum(word_error_bool) / len(words)
|
| 106 |
|
| 107 |
raw_info = {"ref":ref, "hyp": hyp, "per":per, "phoneme_errors": error_bool, "wer": wer, "words": words, "word_errors":word_error_bool}
|
|
@@ -109,16 +111,27 @@ class MispronounciationDetector:
|
|
| 109 |
return raw_info
|
| 110 |
|
| 111 |
|
| 112 |
-
def get_mispronounced_words(self, phoneme_error_bool):
|
| 113 |
# map mispronounced phones back to words that were mispronounce to get WER
|
| 114 |
word_error_bool = []
|
| 115 |
phoneme_error_bool.append("|")
|
| 116 |
word_phones = self.split_lst_by_delim(phoneme_error_bool, "|")
|
|
|
|
|
|
|
| 117 |
for phones in word_phones:
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
word_error_bool.append(True)
|
| 120 |
else:
|
| 121 |
word_error_bool.append(False)
|
|
|
|
| 122 |
return word_error_bool
|
| 123 |
|
| 124 |
|
|
|
|
| 1 |
from pandas.core.construction import T
|
| 2 |
import torch
|
| 3 |
import jiwer
|
| 4 |
+
import re
|
| 5 |
|
| 6 |
class MispronounciationDetector:
|
| 7 |
def __init__(self, l2_phoneme_recogniser, g2p, device):
|
|
|
|
| 9 |
self.g2p = g2p
|
| 10 |
self.device = device
|
| 11 |
|
| 12 |
+
def detect(self, audio, text, phoneme_error_threshold=0.25):
|
| 13 |
l2_phones = self.phoneme_asr_model.get_l2_phoneme_sequence(audio)
|
| 14 |
+
l2_phones = [re.sub(r'\d', "", phone_str) for phone_str in l2_phones] #g2p has no lexical stress
|
| 15 |
native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
|
| 16 |
standardised_native_speaker_phones = self.phoneme_asr_model.standardise_g2p_phoneme_sequence(native_speaker_phones)
|
| 17 |
+
raw_info = self.get_mispronounciation_output(text, l2_phones, standardised_native_speaker_phones, phoneme_error_threshold)
|
| 18 |
return raw_info
|
| 19 |
|
| 20 |
def get_native_speaker_phoneme_sequence(self, text):
|
| 21 |
phonemes = self.g2p(text)
|
| 22 |
return phonemes
|
| 23 |
|
| 24 |
+
def get_mispronounciation_output(self, text, pred_phones, org_label_phones, phoneme_error_threshold):
|
| 25 |
"""
|
| 26 |
Aligns the predicted phones from the L2 speaker and the expected native speaker phone to get the errors
|
| 27 |
:param text: original words read by the user
|
|
|
|
| 103 |
# get mispronounced words based on if there are phoneme errors present in the phonemes of that word
|
| 104 |
aligned_word_error_output = ""
|
| 105 |
words = text.split(" ")
|
| 106 |
+
word_error_bool = self.get_mispronounced_words(error_bool, phoneme_error_threshold)
|
| 107 |
wer = sum(word_error_bool) / len(words)
|
| 108 |
|
| 109 |
raw_info = {"ref":ref, "hyp": hyp, "per":per, "phoneme_errors": error_bool, "wer": wer, "words": words, "word_errors":word_error_bool}
|
|
|
|
| 111 |
return raw_info
|
| 112 |
|
| 113 |
|
| 114 |
+
def get_mispronounced_words(self, phoneme_error_bool, phoneme_error_threshold):
|
| 115 |
# map mispronounced phones back to words that were mispronounce to get WER
|
| 116 |
word_error_bool = []
|
| 117 |
phoneme_error_bool.append("|")
|
| 118 |
word_phones = self.split_lst_by_delim(phoneme_error_bool, "|")
|
| 119 |
+
|
| 120 |
+
# wrong only if percentage of phones that are wrong > phoneme error threshold
|
| 121 |
for phones in word_phones:
|
| 122 |
+
|
| 123 |
+
# get count of "s", "d", "a" in phones
|
| 124 |
+
error_count = 0
|
| 125 |
+
for phone in phones:
|
| 126 |
+
if phone == "s" or phone == "d" or phone == "a":
|
| 127 |
+
error_count += 1
|
| 128 |
+
|
| 129 |
+
# check if pass threshold
|
| 130 |
+
if error_count / len(phones) > phoneme_error_threshold:
|
| 131 |
word_error_bool.append(True)
|
| 132 |
else:
|
| 133 |
word_error_bool.append(False)
|
| 134 |
+
|
| 135 |
return word_error_bool
|
| 136 |
|
| 137 |
|