from collections import defaultdict
from dataclasses import dataclass
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np


def _mean_std(values: np.ndarray) -> Tuple[float, float]:
    """Return ``(mean, std)`` of *values* as Python floats, or ``(nan, nan)`` if empty."""
    if values.size == 0:
        # np.mean on an empty array also yields nan, but with a RuntimeWarning;
        # make the "no data" case explicit and silent.
        return float("nan"), float("nan")
    return float(np.mean(values)), float(np.std(values))


@dataclass
class SpeakerStats:
    """Summary statistics of one speaker's F0 and intensity contours."""

    f0_mean: float  # mean F0 over non-zero frames (nan if there are none)
    f0_std: float  # population std (ddof=0) of F0 over non-zero frames
    intensity_mean: float  # mean intensity over all frames (nan if there are none)
    intensity_std: float  # population std (ddof=0) of intensity over all frames

    @classmethod
    def from_features(cls, f0_values: List[np.ndarray],
                      intensity_values: List[np.ndarray]) -> "SpeakerStats":
        """
        Pool per-utterance contours and summarize them.

        Args:
            f0_values: One F0 sequence per utterance (array or array-like).
                Zero-valued frames are excluded from the F0 statistics
                (presumably they mark unvoiced frames).
            intensity_values: One intensity sequence per utterance; every frame
                is included.

        Returns:
            SpeakerStats with mean/std computed over all pooled frames. A field
            is nan when there is no frame to average (e.g. empty input, or an
            all-zero F0 contour).
        """
        # Flatten each utterance so arrays of any rank can be pooled.
        f0_frames = [np.ravel(np.asarray(f0)) for f0 in f0_values]
        voiced = [f0[f0 != 0] for f0 in f0_frames]
        intensity_frames = [np.ravel(np.asarray(i)) for i in intensity_values]

        # np.concatenate rejects an empty list, so fall back to an empty array.
        f0_concat = np.concatenate(voiced) if voiced else np.empty(0)
        intensity_concat = (np.concatenate(intensity_frames)
                            if intensity_frames else np.empty(0))

        f0_mean, f0_std = _mean_std(f0_concat)
        intensity_mean, intensity_std = _mean_std(intensity_concat)
        return cls(
            f0_mean=f0_mean,
            f0_std=f0_std,
            intensity_mean=intensity_mean,
            intensity_std=intensity_std,
        )


def compute_speaker_stats(dataset, speaker_column='speaker_id'):
    """
    Calculate per-speaker statistics from a preprocessed dataset.

    Args:
        dataset: Iterable of mapping-like rows (e.g. a HuggingFace dataset),
            each providing ``speaker_column``, ``'f0'`` and ``'intensity'``.
        speaker_column: Name of the speaker ID column (default: 'speaker_id').

    Returns:
        Dict[speaker_id, SpeakerStats]: statistics per speaker, in order of
        each speaker's first appearance in ``dataset``.
    """
    # Group every utterance's contours under its speaker ID.
    speaker_features = defaultdict(lambda: {'f0': [], 'intensity': []})
    for item in dataset:
        feats = speaker_features[item[speaker_column]]
        feats['f0'].append(item['f0'])
        feats['intensity'].append(item['intensity'])

    return {
        spk: SpeakerStats.from_features(feats['f0'], feats['intensity'])
        for spk, feats in speaker_features.items()
    }


def plot_reconstruction(result, sample_idx, *, frame_multiple=16,
                        f0_diff_threshold_pct=20, intensity_diff_threshold_db=10):
    """
    Plot original vs. reconstructed F0 and intensity for one sample.

    Args:
        result: Mapping providing ``result['input_features']['f0_orig']``,
            ``result['input_features']['intensity_orig']``,
            ``result['f0_recon']`` and ``result['intensity_recon']``.
        sample_idx: Sample index shown in the F0 subplot title.
        frame_multiple: The original series are truncated to the largest
            multiple of this many frames (based on the F0 length) before
            comparison. Presumably this is the model's downsampling factor
            — TODO confirm.
        f0_diff_threshold_pct: Frames whose F0 deviates from the original by
            more than this percentage are shaded red.
        intensity_diff_threshold_db: Frames whose absolute intensity
            difference exceeds this value are shaded red.

    Returns:
        matplotlib Figure with two stacked subplots (F0 on top, intensity below).

    Raises:
        ValueError: if a reconstructed series, or the truncated original
            intensity, does not have the truncated original F0 length.
    """
    f0_orig = result['input_features']['f0_orig']
    truncated_length = (len(f0_orig) // frame_multiple) * frame_multiple

    input_f0 = np.array(f0_orig[:truncated_length])
    input_intensity = np.array(
        result['input_features']['intensity_orig'][:truncated_length])
    output_f0 = np.array(result['f0_recon'])
    output_intensity = np.array(result['intensity_recon'])

    # Validate before creating the figure so a mismatch neither leaves an
    # orphan pyplot figure open nor fails deep inside matplotlib.
    for name, values in (('original intensity', input_intensity),
                         ('reconstructed F0', output_f0),
                         ('reconstructed intensity', output_intensity)):
        if len(values) != truncated_length:
            raise ValueError(
                f"{name} has {len(values)} frames, expected {truncated_length} "
                f"(original F0 truncated to a multiple of {frame_multiple})")

    time = np.arange(truncated_length)
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))

    # F0
    ax1.plot(time, input_f0, label='Original F0', alpha=0.7)
    ax1.plot(time, output_f0, label='Reconstructed F0', alpha=0.7)
    # The epsilon avoids division by zero. Frames where the original F0 is 0
    # but the reconstruction is not get a huge percentage and are always shaded.
    f0_diff_percent = np.abs(input_f0 - output_f0) / (input_f0 + 1e-8) * 100
    large_diff_mask = f0_diff_percent > f0_diff_threshold_pct
    if np.any(large_diff_mask):
        ax1.fill_between(time, input_f0, output_f0, where=large_diff_mask,
                         color='red', alpha=0.3,
                         label=f'Diff > {f0_diff_threshold_pct}%')
    ax1.set_title(f'F0 Reconstruction (Sample {sample_idx})')
    ax1.set_ylabel('Frequency (Hz)')
    # Pin the bottom only after every F0 artist is drawn: set_ylim() switches
    # y-autoscaling off, so calling it earlier clipped reconstructed values
    # that exceed the original's range.
    ax1.set_ylim(bottom=0)
    ax1.legend()

    # Intensity
    ax2.plot(time, input_intensity, label='Original Intensity', alpha=0.7)
    ax2.plot(time, output_intensity, label='Reconstructed Intensity', alpha=0.7)
    intensity_large_diff = (np.abs(input_intensity - output_intensity)
                            > intensity_diff_threshold_db)
    if np.any(intensity_large_diff):
        ax2.fill_between(time, input_intensity, output_intensity,
                         where=intensity_large_diff, color='red', alpha=0.3,
                         label=f'Diff > {intensity_diff_threshold_db} dB')
    ax2.set_title('Intensity Reconstruction')
    ax2.set_ylabel('Intensity (dB)')
    ax2.set_xlabel('Time (frames)')
    ax2.legend()

    plt.tight_layout()
    return fig