Daporte commited on
Commit
4d494d9
·
verified ·
1 Parent(s): e92e60e

Create pipeline_utils

Browse files
Files changed (1) hide show
  1. pipeline_utils +120 -0
pipeline_utils ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import matplotlib.pyplot as plt
3
+ from typing import List
4
+ import numpy as np
5
+ from dataclasses import dataclass
6
+
7
+ @dataclass
8
+ class SpeakerStats:
9
+ f0_mean: float
10
+ f0_std: float
11
+ intensity_mean: float
12
+ intensity_std: float
13
+
14
+ @classmethod
15
+ def from_features(cls, f0_values: List[np.ndarray], intensity_values: List[np.ndarray]):
16
+
17
+ f0_arrays = [np.array(f0) for f0 in f0_values]
18
+ intensity_arrays = [np.array(i) for i in intensity_values]
19
+
20
+ f0_concat = np.concatenate([f0[f0 != 0] for f0 in f0_arrays])
21
+ intensity_concat = np.concatenate(intensity_arrays)
22
+
23
+
24
+ return cls(
25
+ f0_mean=float(np.mean(f0_concat)),
26
+ f0_std=float(np.std(f0_concat)),
27
+ intensity_mean=float(np.mean(intensity_concat)),
28
+ intensity_std=float(np.std(intensity_concat))
29
+ )
30
+
31
+ def compute_speaker_stats(dataset, speaker_column='speaker_id'):
32
+ """
33
+ Calculate speaker statistics from a preprocessed dataset.
34
+
35
+ Args:
36
+ dataset: HuggingFace dataset containing f0 and intensity features
37
+ speaker_column: Name of the speaker ID column (default: 'speaker')
38
+
39
+ Returns:
40
+ Dict[str, SpeakerStats]: Dictionary mapping speaker IDs to their statistics
41
+ """
42
+ speaker_features = {}
43
+
44
+ # Group features by speaker
45
+ for item in dataset:
46
+ speaker_id = item[speaker_column]
47
+ if speaker_id not in speaker_features:
48
+ speaker_features[speaker_id] = {'f0': [], 'intensity': []}
49
+
50
+ speaker_features[speaker_id]['f0'].append(item['f0'])
51
+ speaker_features[speaker_id]['intensity'].append(item['intensity'])
52
+
53
+ # Calculate stats per speaker
54
+ speaker_stats = {
55
+ spk: SpeakerStats.from_features(
56
+ feats['f0'],
57
+ feats['intensity']
58
+ )
59
+ for spk, feats in speaker_features.items()
60
+ }
61
+
62
+ return speaker_stats
63
+
64
+ def plot_reconstruction(result, sample_idx):
65
+ # Get F0 data
66
+ input_f0 = result['input_features']['f0_orig']
67
+ output_f0 = np.array(result['f0_recon'])
68
+
69
+ length = len(input_f0)
70
+ truncated_length = (length // 16) * 16
71
+
72
+ input_f0 = np.array(input_f0[:truncated_length])
73
+
74
+ # Get intensity data
75
+ input_intensity = np.array(result['input_features']['intensity_orig'][:truncated_length])
76
+ output_intensity = np.array(result['intensity_recon'])
77
+
78
+ time = np.arange(len(input_f0))
79
+
80
+ # Create figure with two subplots
81
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))
82
+
83
+ # Plot F0
84
+ ax1.plot(time, input_f0, label='Original F0', alpha=0.7)
85
+ ax1.plot(time, output_f0, label='Reconstructed F0', alpha=0.7)
86
+
87
+ # Highlight large differences in F0 (>20% of original)
88
+ f0_diff_percent = np.abs(input_f0 - output_f0) / (input_f0 + 1e-8) * 100 # Add small epsilon to avoid division by zero
89
+ large_diff_mask = (f0_diff_percent > 20)
90
+ if np.any(large_diff_mask):
91
+ ax1.fill_between(time, input_f0, output_f0,
92
+ where=large_diff_mask,
93
+ color='red', alpha=0.3,
94
+ label='Diff > 20%')
95
+
96
+ ax1.set_title(f'F0 Reconstruction (Sample {sample_idx})')
97
+ ax1.set_ylabel('Frequency (Hz)')
98
+ ax1.legend()
99
+
100
+ # Plot Intensity
101
+ ax2.plot(time, input_intensity, label='Original Intensity', alpha=0.7)
102
+ ax2.plot(time, output_intensity, label='Reconstructed Intensity', alpha=0.7)
103
+
104
+ # Highlight large differences in intensity (>20% of original)
105
+ intensity_diff_percent = np.abs(input_intensity - output_intensity) / (np.abs(input_intensity) + 1e-8) * 100
106
+ intensity_large_diff = intensity_diff_percent > 20
107
+ if np.any(intensity_large_diff):
108
+ ax2.fill_between(time, input_intensity, output_intensity,
109
+ where=intensity_large_diff,
110
+ color='red', alpha=0.3,
111
+ label='Diff > 20%')
112
+
113
+ ax2.set_title('Intensity Reconstruction')
114
+ ax2.set_ylabel('Intensity (dB)')
115
+ ax2.set_xlabel('Time (frames)')
116
+ ax2.legend()
117
+
118
+
119
+ plt.tight_layout()
120
+ return fig