Daporte committed
Commit e25340b · verified · 1 Parent(s): e7557d8

upload files from https://github.com/facebookresearch/speech-resynthesis

Files changed (8)
  1. models.py +38 -0
  2. modules/dist.py +108 -0
  3. modules/jukebox.py +178 -0
  4. modules/resnet.py +82 -0
  5. modules/vq.py +249 -0
  6. pipeline_utils.py +120 -0
  7. quantizer_config.py +167 -0
  8. utils.py +36 -0
models.py ADDED
@@ -0,0 +1,38 @@
+ # adapted from https://github.com/jik876/hifi-gan
+
+ from transformers.modeling_utils import PreTrainedModel
+
+ from quantizer_config import QuantizerConfig
+ from modules.jukebox import Encoder, Decoder
+ from modules.vq import Bottleneck
+
+
+ class Quantizer(PreTrainedModel):
+     config_class = QuantizerConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.config = config
+         self.encoder = Encoder(**config.f0_encoder_params)
+         self.vq = Bottleneck(**config.f0_vq_params)
+         self.decoder = Decoder(**config.f0_decoder_params)
+
+     def forward(self, **kwargs):
+         f0_h = self.encoder(kwargs['features'])
+
+         zs, f0_h_q, f0_commit_losses, f0_metrics = self.vq(f0_h)
+
+         f0 = self.decoder(f0_h_q)
+
+         return {
+             'f0': f0,
+             'commit_losses': f0_commit_losses,
+             'metrics': f0_metrics,
+             'codes': zs,
+             'hidden_states': f0_h_q
+         }
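
A minimal usage sketch for the wrapper above, assuming the default QuantizerConfig values (3 input channels, one level, temporal downsampling of 2**4, so the frame count must be a multiple of 16). The tensor sizes are illustrative, and a CUDA device is needed because modules/vq.py allocates its codebook with .cuda():

import torch

from models import Quantizer
from quantizer_config import QuantizerConfig

config = QuantizerConfig()        # library defaults; see quantizer_config.py below
model = Quantizer(config).cuda()  # the VQ codebook is created on CUDA at construction

# (batch, channels, frames); frames must be divisible by 2 ** 4 = 16
features = torch.randn(1, config.encoder_input_emb_width, 128).cuda()
out = model(features=features)
print(out['f0'].shape)            # reconstructed track: (1, 3, 128)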
modules/dist.py ADDED
@@ -0,0 +1,108 @@
+ # Adapted from https://github.com/openai/jukebox
+
+ from enum import Enum
+
+ import torch.distributed as dist
+
+
+ class ReduceOp(Enum):
+     SUM = 0
+     PRODUCT = 1
+     MIN = 2
+     MAX = 3
+
+     def ToDistOp(self):
+         return {
+             self.SUM: dist.ReduceOp.SUM,
+             self.PRODUCT: dist.ReduceOp.PRODUCT,
+             self.MIN: dist.ReduceOp.MIN,
+             self.MAX: dist.ReduceOp.MAX
+         }[self]
+
+
+ def is_available():
+     return dist.is_initialized()
+
+
+ def get_rank():
+     if is_available():
+         return _get_rank()
+     else:
+         return 0
+
+
+ def get_world_size():
+     if is_available():
+         return _get_world_size()
+     else:
+         return 1
+
+
+ def barrier():
+     if is_available():
+         return _barrier()
+     # else: do nothing
+
+
+ def all_gather(tensor_list, tensor):
+     if is_available():
+         return _all_gather(tensor_list, tensor)
+     else:
+         tensor_list[0] = tensor
+
+
+ def all_reduce(tensor, op=ReduceOp.SUM):
+     if is_available():
+         return _all_reduce(tensor, op)
+     # else: do nothing
+
+
+ def reduce(tensor, dst, op=ReduceOp.SUM):
+     if is_available():
+         return _reduce(tensor, dst, op)
+     # else: do nothing
+
+
+ def broadcast(tensor, src):
+     if is_available():
+         return _broadcast(tensor, src)
+     # else: do nothing
+
+
+ def init_process_group(backend, init_method):
+     if is_available():
+         return _init_process_group(backend, init_method)
+     # else: do nothing
+
+
+ def _get_rank():
+     return dist.get_rank()
+
+
+ def _barrier():
+     return dist.barrier()
+
+
+ def _get_world_size():
+     return dist.get_world_size()
+
+
+ def _all_gather(tensor_list, tensor):
+     return dist.all_gather(tensor_list, tensor)
+
+
+ def _all_reduce(tensor, op):
+     return dist.all_reduce(tensor, op.ToDistOp())
+
+
+ def _reduce(tensor, dst, op):
+     return dist.reduce(tensor, dst, op.ToDistOp())
+
+
+ def _broadcast(tensor, src):
+     return dist.broadcast(tensor, src)
+
+
+ def _init_process_group(backend, init_method):
+     return dist.init_process_group(backend, init_method)
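
These wrappers fall back to single-process behaviour whenever torch.distributed has no initialised process group, which lets the VQ code run unchanged on one GPU. A quick sketch of the fallback:

import torch

import modules.dist as dist

x = torch.ones(4)
dist.all_reduce(x)            # no process group, so this is a silent no-op
print(dist.get_rank())        # 0
print(dist.get_world_size())  # 1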
modules/jukebox.py ADDED
@@ -0,0 +1,178 @@
+ # Adapted from https://github.com/openai/jukebox
+
+ import numpy as np
+ import torch.nn as nn
+
+ from modules.resnet import Resnet1D
+
+
+ def assert_shape(x, exp_shape):
+     assert x.shape == exp_shape, f"Expected {exp_shape} got {x.shape}"
+
+
+ class EncoderConvBlock(nn.Module):
+     def __init__(self, input_emb_width, output_emb_width, down_t, stride_t, width, depth, m_conv,
+                  dilation_growth_rate=1, dilation_cycle=None, zero_out=False, res_scale=False):
+         super().__init__()
+         blocks = []
+         if isinstance(stride_t, (tuple, list)):
+             start = True
+             for s_t, d_t in zip(stride_t, down_t):
+                 if s_t % 2 == 0:
+                     filter_t, pad_t = s_t * 2, s_t // 2
+                 else:
+                     filter_t, pad_t = s_t * 2 + 1, s_t // 2 + 1
+                 if d_t > 0:
+                     for i in range(d_t):
+                         block = nn.Sequential(
+                             nn.Conv1d(input_emb_width if i == 0 and start else width, width, filter_t, s_t, pad_t),
+                             Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out, res_scale))
+                         blocks.append(block)
+                     start = False
+             block = nn.Conv1d(width, output_emb_width, 3, 1, 1)
+             blocks.append(block)
+         else:
+             filter_t, pad_t = stride_t * 2, stride_t // 2
+             if down_t > 0:
+                 for i in range(down_t):
+                     block = nn.Sequential(
+                         nn.Conv1d(input_emb_width if i == 0 else width, width, filter_t, stride_t, pad_t),
+                         Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out, res_scale))
+                     blocks.append(block)
+                 block = nn.Conv1d(width, output_emb_width, 3, 1, 1)
+                 blocks.append(block)
+         self.model = nn.Sequential(*blocks)
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class DecoderConvBlock(nn.Module):
+     def __init__(self, input_emb_width, output_emb_width, down_t, stride_t, width, depth, m_conv,
+                  dilation_growth_rate=1, dilation_cycle=None, zero_out=False, res_scale=False,
+                  reverse_decoder_dilation=False, checkpoint_res=False):
+         super().__init__()
+         blocks = []
+
+         if isinstance(stride_t, (tuple, list)):
+             block = nn.Conv1d(output_emb_width, width, 3, 1, 1)
+             blocks.append(block)
+             for k, (s_t, d_t) in enumerate(zip(stride_t, down_t)):
+                 if d_t > 0:
+                     if s_t % 2 == 0:
+                         filter_t, pad_t = s_t * 2, s_t // 2
+                     else:
+                         filter_t, pad_t = s_t * 2 + 1, s_t // 2 + 1
+                     end = k == len(stride_t) - 1
+                     for i in range(d_t):
+                         block = nn.Sequential(
+                             Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out=zero_out,
+                                      res_scale=res_scale, reverse_dilation=reverse_decoder_dilation,
+                                      checkpoint_res=checkpoint_res),
+                             nn.ConvTranspose1d(width, input_emb_width if i == (d_t - 1) and end else width,
+                                                filter_t, s_t, pad_t))
+                         blocks.append(block)
+         else:
+             if down_t > 0:
+                 filter_t, pad_t = stride_t * 2, stride_t // 2
+                 block = nn.Conv1d(output_emb_width, width, 3, 1, 1)
+                 blocks.append(block)
+                 for i in range(down_t):
+                     block = nn.Sequential(
+                         Resnet1D(width, depth, m_conv, dilation_growth_rate, dilation_cycle, zero_out=zero_out,
+                                  res_scale=res_scale, reverse_dilation=reverse_decoder_dilation,
+                                  checkpoint_res=checkpoint_res),
+                         nn.ConvTranspose1d(width, input_emb_width if i == (down_t - 1) else width, filter_t,
+                                            stride_t, pad_t))
+                     blocks.append(block)
+         self.model = nn.Sequential(*blocks)
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class Encoder(nn.Module):
+     def __init__(self, input_emb_width, output_emb_width, levels, downs_t, strides_t, **block_kwargs):
+         super().__init__()
+         self.input_emb_width = input_emb_width
+         self.output_emb_width = output_emb_width
+         self.levels = levels
+         self.downs_t = downs_t
+         self.strides_t = strides_t
+
+         block_kwargs_copy = dict(**block_kwargs)
+         if 'reverse_decoder_dilation' in block_kwargs_copy:
+             del block_kwargs_copy['reverse_decoder_dilation']
+         level_block = lambda level, down_t, stride_t: EncoderConvBlock(
+             input_emb_width if level == 0 else output_emb_width, output_emb_width, down_t, stride_t,
+             **block_kwargs_copy)
+         self.level_blocks = nn.ModuleList()
+         iterator = zip(list(range(self.levels)), downs_t, strides_t)
+         for level, down_t, stride_t in iterator:
+             self.level_blocks.append(level_block(level, down_t, stride_t))
+
+     def forward(self, x):
+         N, T = x.shape[0], x.shape[-1]
+         emb = self.input_emb_width
+         assert_shape(x, (N, emb, T))
+         xs = []
+
+         # 64, 32, ...
+         iterator = zip(list(range(self.levels)), self.downs_t, self.strides_t)
+         for level, down_t, stride_t in iterator:
+             level_block = self.level_blocks[level]
+             x = level_block(x)
+             if isinstance(stride_t, (tuple, list)):
+                 emb, T = self.output_emb_width, T // np.prod([s ** d for s, d in zip(stride_t, down_t)])
+             else:
+                 emb, T = self.output_emb_width, T // (stride_t ** down_t)
+             assert_shape(x, (N, emb, T))
+             xs.append(x)
+
+         return xs
+
+
+ class Decoder(nn.Module):
+     def __init__(self, input_emb_width, output_emb_width, levels, downs_t, strides_t, **block_kwargs):
+         super().__init__()
+         self.input_emb_width = input_emb_width
+         self.output_emb_width = output_emb_width
+         self.levels = levels
+         self.downs_t = downs_t
+         self.strides_t = strides_t
+
+         level_block = lambda level, down_t, stride_t: DecoderConvBlock(output_emb_width, output_emb_width,
+                                                                        down_t, stride_t, **block_kwargs)
+         self.level_blocks = nn.ModuleList()
+         iterator = zip(list(range(self.levels)), downs_t, strides_t)
+         for level, down_t, stride_t in iterator:
+             self.level_blocks.append(level_block(level, down_t, stride_t))
+
+         self.out = nn.Conv1d(output_emb_width, input_emb_width, 3, 1, 1)
+
+     def forward(self, xs, all_levels=True):
+         if all_levels:
+             assert len(xs) == self.levels
+         else:
+             assert len(xs) == 1
+         x = xs[-1]
+         N, T = x.shape[0], x.shape[-1]
+         emb = self.output_emb_width
+         assert_shape(x, (N, emb, T))
+
+         # 32, 64 ...
+         iterator = reversed(list(zip(list(range(self.levels)), self.downs_t, self.strides_t)))
+         for level, down_t, stride_t in iterator:
+             level_block = self.level_blocks[level]
+             x = level_block(x)
+             if isinstance(stride_t, (tuple, list)):
+                 emb, T = self.output_emb_width, T * np.prod([s ** d for s, d in zip(stride_t, down_t)])
+             else:
+                 emb, T = self.output_emb_width, T * (stride_t ** down_t)
+             assert_shape(x, (N, emb, T))
+             if level != 0 and all_levels:
+                 x = x + xs[level - 1]
+
+         x = self.out(x)
+         return x
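
A shape round-trip sketch for Encoder and Decoder, using the same hyperparameters as the QuantizerConfig defaults (one level, stride 2, four downsampling steps, so time is compressed by 2**4 = 16); the batch size and frame count are illustrative:

import torch

from modules.jukebox import Encoder, Decoder

kw = dict(width=32, depth=4, m_conv=1.0, dilation_growth_rate=3)
enc = Encoder(input_emb_width=3, output_emb_width=64, levels=1, downs_t=[4], strides_t=[2], **kw)
dec = Decoder(input_emb_width=3, output_emb_width=64, levels=1, downs_t=[4], strides_t=[2], **kw)

x = torch.randn(2, 3, 128)  # frame count must be divisible by 16
xs = enc(x)                 # one latent tensor per level
print(xs[0].shape)          # (2, 64, 8): time compressed 16x
print(dec(xs).shape)        # (2, 3, 128): back to the input shape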
modules/resnet.py ADDED
@@ -0,0 +1,82 @@
+ # Adapted from https://github.com/openai/jukebox
+
+ import math
+
+ import torch.nn as nn
+
+ import modules.dist as dist
+
+
+ class ResConvBlock(nn.Module):
+     def __init__(self, n_in, n_state):
+         super().__init__()
+         self.model = nn.Sequential(
+             nn.ReLU(),
+             nn.Conv2d(n_in, n_state, 3, 1, 1),
+             nn.ReLU(),
+             nn.Conv2d(n_state, n_in, 1, 1, 0),
+         )
+
+     def forward(self, x):
+         return x + self.model(x)
+
+
+ class Resnet(nn.Module):
+     def __init__(self, n_in, n_depth, m_conv=1.0):
+         super().__init__()
+         self.model = nn.Sequential(*[ResConvBlock(n_in, int(m_conv * n_in)) for _ in range(n_depth)])
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class ResConv1DBlock(nn.Module):
+     def __init__(self, n_in, n_state, dilation=1, zero_out=False, res_scale=1.0):
+         super().__init__()
+         padding = dilation
+         self.model = nn.Sequential(
+             nn.ReLU(),
+             nn.Conv1d(n_in, n_state, 3, 1, padding, dilation),
+             nn.ReLU(),
+             nn.Conv1d(n_state, n_in, 1, 1, 0),
+         )
+         if zero_out:
+             out = self.model[-1]
+             nn.init.zeros_(out.weight)
+             nn.init.zeros_(out.bias)
+         self.res_scale = res_scale
+
+     def forward(self, x):
+         return x + self.res_scale * self.model(x)
+
+
+ class Resnet1D(nn.Module):
+     def __init__(self, n_in, n_depth, m_conv=1.0, dilation_growth_rate=1, dilation_cycle=None, zero_out=False,
+                  res_scale=False, reverse_dilation=False, checkpoint_res=False):
+         super().__init__()
+
+         def _get_depth(depth):
+             if dilation_cycle is None:
+                 return depth
+             else:
+                 return depth % dilation_cycle
+
+         blocks = [ResConv1DBlock(n_in, int(m_conv * n_in),
+                                  dilation=dilation_growth_rate ** _get_depth(depth),
+                                  zero_out=zero_out,
+                                  res_scale=1.0 if not res_scale else 1.0 / math.sqrt(n_depth))
+                   for depth in range(n_depth)]
+         if reverse_dilation:
+             blocks = blocks[::-1]
+         self.checkpoint_res = checkpoint_res
+         if self.checkpoint_res == 1:
+             if dist.get_rank() == 0:
+                 print("Checkpointing convs")
+             self.blocks = nn.ModuleList(blocks)
+         else:
+             self.model = nn.Sequential(*blocks)
+
+     def forward(self, x):
+         if self.checkpoint_res == 1:
+             raise NotImplementedError("Checkpoint not implemented")
+         else:
+             return self.model(x)
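
A small sketch of the dilation schedule: with dilation_growth_rate=3 and n_depth=4, the stacked ResConv1DBlocks use dilations 1, 3, 9, 27, and because padding equals dilation every block preserves the sequence length (sizes are illustrative):

import torch

from modules.resnet import Resnet1D

net = Resnet1D(n_in=32, n_depth=4, dilation_growth_rate=3)
x = torch.randn(1, 32, 100)
print(net(x).shape)  # (1, 32, 100): the residual stack preserves the shape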
modules/vq.py ADDED
@@ -0,0 +1,249 @@
+ # Adapted from https://github.com/openai/jukebox
+
+ import numpy as np
+ import torch as t
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import modules.dist as dist
+
+
+ class BottleneckBlock(nn.Module):
+     def __init__(self, k_bins, emb_width, mu):
+         super().__init__()
+         self.k_bins = k_bins
+         self.emb_width = emb_width
+         self.mu = mu
+         self.reset_k()
+         self.threshold = 1.0
+
+     def reset_k(self):
+         self.init = False
+         self.k_sum = None
+         self.k_elem = None
+         self.register_buffer('k', t.zeros(self.k_bins, self.emb_width).cuda())
+
+     def _tile(self, x):
+         d, ew = x.shape
+         if d < self.k_bins:
+             n_repeats = (self.k_bins + d - 1) // d
+             std = 0.01 / np.sqrt(ew)
+             x = x.repeat(n_repeats, 1)
+             x = x + t.randn_like(x) * std
+         return x
+
+     def init_k(self, x):
+         mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
+         self.init = True
+         # init k_w using random vectors from x
+         y = self._tile(x)
+         _k_rand = y[t.randperm(y.shape[0])][:k_bins]
+         dist.broadcast(_k_rand, 0)
+         self.k = _k_rand
+         assert self.k.shape == (k_bins, emb_width)
+         self.k_sum = self.k
+         self.k_elem = t.ones(k_bins, device=self.k.device)
+
+     def restore_k(self, num_tokens=None, threshold=1.0):
+         mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
+         self.init = True
+         assert self.k.shape == (k_bins, emb_width)
+         self.k_sum = self.k.clone()
+         self.k_elem = t.ones(k_bins, device=self.k.device)
+         if num_tokens is not None:
+             expected_usage = num_tokens / k_bins
+             self.k_elem.data.mul_(expected_usage)
+             self.k_sum.data.mul_(expected_usage)
+         self.threshold = threshold
+
+     def update_k(self, x, x_l):
+         mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins
+         with t.no_grad():
+             # Calculate new centres
+             x_l_onehot = t.zeros(k_bins, x.shape[0], device=x.device)  # k_bins, N * L
+             x_l_onehot.scatter_(0, x_l.view(1, x.shape[0]), 1)
+
+             _k_sum = t.matmul(x_l_onehot, x)  # k_bins, w
+             _k_elem = x_l_onehot.sum(dim=-1)  # k_bins
+             y = self._tile(x)
+             _k_rand = y[t.randperm(y.shape[0])][:k_bins]
+
+             dist.broadcast(_k_rand, 0)
+             dist.all_reduce(_k_sum)
+             dist.all_reduce(_k_elem)
+
+             # Update centres
+             old_k = self.k
+             self.k_sum = mu * self.k_sum + (1. - mu) * _k_sum  # w, k_bins
+             self.k_elem = mu * self.k_elem + (1. - mu) * _k_elem  # k_bins
+             usage = (self.k_elem.view(k_bins, 1) >= self.threshold).float()
+             self.k = usage * (self.k_sum.view(k_bins, emb_width) / self.k_elem.view(k_bins, 1)) \
+                 + (1 - usage) * _k_rand
+             _k_prob = _k_elem / t.sum(_k_elem)  # x_l_onehot.mean(dim=-1) # prob of each bin
+             entropy = -t.sum(_k_prob * t.log(_k_prob + 1e-8))  # entropy, i.e. how diverse
+             used_curr = (_k_elem >= self.threshold).sum()
+             usage = t.sum(usage)
+             dk = t.norm(self.k - old_k) / np.sqrt(np.prod(old_k.shape))
+         return dict(entropy=entropy,
+                     used_curr=used_curr,
+                     usage=usage,
+                     dk=dk)
+
+     def preprocess(self, x):
+         # NCT -> NTC -> [NT, C]
+         x = x.permute(0, 2, 1).contiguous()
+         x = x.view(-1, x.shape[-1])  # x_en = (N * L, w), k_j = (w, k_bins)
+
+         if x.shape[-1] == self.emb_width:
+             prenorm = t.norm(x - t.mean(x)) / np.sqrt(np.prod(x.shape))
+         elif x.shape[-1] == 2 * self.emb_width:
+             x1, x2 = x[..., :self.emb_width], x[..., self.emb_width:]
+             prenorm = (t.norm(x1 - t.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
+                 t.norm(x2 - t.mean(x2)) / np.sqrt(np.prod(x2.shape)))
+
+             # Normalise
+             x = x1 + x2
+         else:
+             assert False, f"Expected {x.shape[-1]} to be (1 or 2) * {self.emb_width}"
+         return x, prenorm
+
+     def postprocess(self, x_l, x_d, x_shape):
+         # [NT, C] -> NTC -> NCT
+         N, T = x_shape
+         x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()
+         x_l = x_l.view(N, T)
+         return x_l, x_d
+
+     def quantise(self, x):
+         # Calculate latent code x_l
+         k_w = self.k.t()
+         distance = t.sum(x ** 2, dim=-1, keepdim=True) - 2 * t.matmul(x, k_w) \
+             + t.sum(k_w ** 2, dim=0, keepdim=True)  # (N * L, b)
+         min_distance, x_l = t.min(distance, dim=-1)
+         fit = t.mean(min_distance)
+         return x_l, fit
+
+     def dequantise(self, x_l):
+         x = F.embedding(x_l, self.k)
+         return x
+
+     def encode(self, x):
+         N, width, T = x.shape
+
+         # Preprocess.
+         x, prenorm = self.preprocess(x)
+
+         # Quantise
+         x_l, fit = self.quantise(x)
+
+         # Postprocess.
+         x_l = x_l.view(N, T)
+         return x_l
+
+     def decode(self, x_l):
+         N, T = x_l.shape
+         width = self.emb_width
+
+         # Dequantise
+         x_d = self.dequantise(x_l)
+
+         # Postprocess
+         x_d = x_d.view(N, T, width).permute(0, 2, 1).contiguous()
+         return x_d
+
+     def forward(self, x, update_k=True):
+         N, width, T = x.shape
+
+         # Preprocess
+         x, prenorm = self.preprocess(x)
+
+         # Init k if not inited
+         if update_k and not self.init:
+             self.init_k(x)
+
+         # Quantise and dequantise through bottleneck
+         x_l, fit = self.quantise(x)
+         x_d = self.dequantise(x_l)
+
+         # Update embeddings
+         if update_k and self.training:
+             update_metrics = self.update_k(x, x_l)
+         else:
+             update_metrics = {}
+
+         # Loss
+         commit_loss = t.norm(x_d.detach() - x) ** 2 / np.prod(x.shape)
+
+         # Passthrough
+         x_d = x + (x_d - x).detach()
+
+         # Postprocess
+         x_l, x_d = self.postprocess(x_l, x_d, (N, T))
+         return x_l, x_d, commit_loss, dict(fit=fit,
+                                            pn=prenorm,
+                                            **update_metrics)
+
+
+ class Bottleneck(nn.Module):
+     def __init__(self, l_bins, emb_width, mu, levels):
+         super().__init__()
+         self.levels = levels
+         level_block = lambda level: BottleneckBlock(l_bins, emb_width, mu)
+         self.level_blocks = nn.ModuleList()
+         for level in range(self.levels):
+             self.level_blocks.append(level_block(level))
+
+     def encode(self, xs):
+         zs = [level_block.encode(x) for (level_block, x) in zip(self.level_blocks, xs)]
+         return zs
+
+     def decode(self, zs, start_level=0, end_level=None):
+         if end_level is None:
+             end_level = self.levels
+         xs_quantised = [level_block.decode(z) for (level_block, z) in zip(self.level_blocks[start_level:end_level], zs)]
+         return xs_quantised
+
+     def forward(self, xs):
+         zs, xs_quantised, commit_losses, metrics = [], [], [], []
+         for level in range(self.levels):
+             level_block = self.level_blocks[level]
+             x = xs[level]
+             z, x_quantised, commit_loss, metric = level_block(x, update_k=self.training)
+             zs.append(z)
+             if not self.training:
+                 # Be extra paranoid and make sure the encoder weights can't
+                 # change from straight-through estimator
+                 x_quantised = x_quantised.detach()
+             xs_quantised.append(x_quantised)
+             commit_losses.append(commit_loss)
+             if self.training:
+                 metrics.append(metric)
+         return zs, xs_quantised, commit_losses, metrics
+
+
+ class NoBottleneckBlock(nn.Module):
+     def restore_k(self):
+         pass
+
+
+ class NoBottleneck(nn.Module):
+     def __init__(self, levels):
+         super().__init__()
+         self.level_blocks = nn.ModuleList()
+         self.levels = levels
+         for level in range(levels):
+             self.level_blocks.append(NoBottleneckBlock())
+
+     def encode(self, xs):
+         return xs
+
+     def decode(self, zs, start_level=0, end_level=None):
+         if end_level is None:
+             end_level = self.levels
+         return zs
+
+     def forward(self, xs):
+         zero = t.zeros(()).cuda()
+         commit_losses = [zero for _ in range(self.levels)]
+         metrics = [dict(entropy=zero, usage=zero, used_curr=zero, pn=zero, dk=zero) for _ in range(self.levels)]
+         return xs, xs, commit_losses, metrics
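
A sketch of the bottleneck in isolation, using the default configuration (320 codes of width 64, one level). Note that reset_k hard-codes .cuda() for the codebook buffer, so this needs a GPU; in train mode (the nn.Module default) the first forward pass initialises the codebook from the batch itself:

import torch

from modules.vq import Bottleneck

vq = Bottleneck(l_bins=320, emb_width=64, mu=0.99, levels=1).cuda()

h = [torch.randn(2, 64, 16).cuda()]      # one latent per level, (N, emb_width, T)
zs, h_q, commit_losses, metrics = vq(h)
print(zs[0].shape)                       # (2, 16): discrete code indices per frame
print(h_q[0].shape)                      # (2, 64, 16): straight-through quantised latents
print(float(commit_losses[0]))           # scalar commitment loss for the level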
pipeline_utils.py ADDED
@@ -0,0 +1,120 @@
+ from dataclasses import dataclass
+ from typing import List
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+
+
+ @dataclass
+ class SpeakerStats:
+     f0_mean: float
+     f0_std: float
+     intensity_mean: float
+     intensity_std: float
+
+     @classmethod
+     def from_features(cls, f0_values: List[np.ndarray], intensity_values: List[np.ndarray]):
+         f0_arrays = [np.array(f0) for f0 in f0_values]
+         intensity_arrays = [np.array(i) for i in intensity_values]
+
+         # drop unvoiced frames (f0 == 0) before computing F0 statistics
+         f0_concat = np.concatenate([f0[f0 != 0] for f0 in f0_arrays])
+         intensity_concat = np.concatenate(intensity_arrays)
+
+         return cls(
+             f0_mean=float(np.mean(f0_concat)),
+             f0_std=float(np.std(f0_concat)),
+             intensity_mean=float(np.mean(intensity_concat)),
+             intensity_std=float(np.std(intensity_concat))
+         )
+
+
+ def compute_speaker_stats(dataset, speaker_column='speaker_id'):
+     """
+     Calculate speaker statistics from a preprocessed dataset.
+
+     Args:
+         dataset: HuggingFace dataset containing f0 and intensity features
+         speaker_column: Name of the speaker ID column (default: 'speaker_id')
+
+     Returns:
+         Dict[str, SpeakerStats]: Dictionary mapping speaker IDs to their statistics
+     """
+     speaker_features = {}
+
+     # Group features by speaker
+     for item in dataset:
+         speaker_id = item[speaker_column]
+         if speaker_id not in speaker_features:
+             speaker_features[speaker_id] = {'f0': [], 'intensity': []}
+
+         speaker_features[speaker_id]['f0'].append(item['f0'])
+         speaker_features[speaker_id]['intensity'].append(item['intensity'])
+
+     # Calculate stats per speaker
+     speaker_stats = {
+         spk: SpeakerStats.from_features(feats['f0'], feats['intensity'])
+         for spk, feats in speaker_features.items()
+     }
+
+     return speaker_stats
+
+
+ def plot_reconstruction(result, sample_idx):
+     # Get F0 data; the model operates on lengths divisible by 16, so truncate
+     # both the original and reconstructed tracks to the same length
+     input_f0 = result['input_features']['f0_orig']
+     length = len(input_f0)
+     truncated_length = (length // 16) * 16
+
+     input_f0 = np.array(input_f0[:truncated_length])
+     output_f0 = np.array(result['f0_recon'][:truncated_length])
+
+     # Get intensity data
+     input_intensity = np.array(result['input_features']['intensity_orig'][:truncated_length])
+     output_intensity = np.array(result['intensity_recon'][:truncated_length])
+
+     time = np.arange(len(input_f0))
+
+     # Create figure with two subplots
+     fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))
+
+     # Plot F0
+     ax1.plot(time, input_f0, label='Original F0', alpha=0.7)
+     ax1.plot(time, output_f0, label='Reconstructed F0', alpha=0.7)
+
+     # Highlight large differences in F0 (>20% of original); a small epsilon
+     # avoids division by zero on unvoiced frames
+     f0_diff_percent = np.abs(input_f0 - output_f0) / (input_f0 + 1e-8) * 100
+     large_diff_mask = f0_diff_percent > 20
+     if np.any(large_diff_mask):
+         ax1.fill_between(time, input_f0, output_f0,
+                          where=large_diff_mask,
+                          color='red', alpha=0.3,
+                          label='Diff > 20%')
+
+     ax1.set_title(f'F0 Reconstruction (Sample {sample_idx})')
+     ax1.set_ylabel('Frequency (Hz)')
+     ax1.legend()
+
+     # Plot Intensity
+     ax2.plot(time, input_intensity, label='Original Intensity', alpha=0.7)
+     ax2.plot(time, output_intensity, label='Reconstructed Intensity', alpha=0.7)
+
+     # Highlight large differences in intensity (>20% of original)
+     intensity_diff_percent = np.abs(input_intensity - output_intensity) / (np.abs(input_intensity) + 1e-8) * 100
+     intensity_large_diff = intensity_diff_percent > 20
+     if np.any(intensity_large_diff):
+         ax2.fill_between(time, input_intensity, output_intensity,
+                          where=intensity_large_diff,
+                          color='red', alpha=0.3,
+                          label='Diff > 20%')
+
+     ax2.set_title('Intensity Reconstruction')
+     ax2.set_ylabel('Intensity (dB)')
+     ax2.set_xlabel('Time (frames)')
+     ax2.legend()
+
+     plt.tight_layout()
+     return fig
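
compute_speaker_stats only needs an iterable of dicts carrying f0, intensity, and a speaker column, so it can be sanity-checked on a plain list; all values below are made up:

from pipeline_utils import compute_speaker_stats

toy_dataset = [
    {'speaker_id': 'spk1', 'f0': [0.0, 110.0, 115.0], 'intensity': [60.0, 62.0, 61.0]},
    {'speaker_id': 'spk1', 'f0': [120.0, 0.0, 118.0], 'intensity': [59.0, 63.0, 60.0]},
    {'speaker_id': 'spk2', 'f0': [210.0, 205.0, 0.0], 'intensity': [70.0, 71.0, 69.0]},
]
stats = compute_speaker_stats(toy_dataset)
print(stats['spk1'].f0_mean)  # mean over voiced (non-zero) frames only: 115.75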
quantizer_config.py ADDED
@@ -0,0 +1,167 @@
+ from typing import List
+
+ from transformers import PretrainedConfig
+
+
+ class QuantizerConfig(PretrainedConfig):
+     model_type = "prosody_quantizer"
+
+     def __init__(
+         self,
+         # VQ parameters
+         l_bins: int = 320,
+         emb_width: int = 64,
+         mu: float = 0.99,
+         levels: int = 1,
+
+         # Encoder parameters
+         encoder_input_emb_width: int = 3,
+         encoder_output_emb_width: int = 64,
+         encoder_levels: int = 1,
+         encoder_downs_t: List[int] = [4],
+         encoder_strides_t: List[int] = [2],
+         encoder_width: int = 32,
+         encoder_depth: int = 4,
+         encoder_m_conv: float = 1.0,
+         encoder_dilation_growth_rate: int = 3,
+
+         # Decoder parameters
+         decoder_input_emb_width: int = 3,
+         decoder_output_emb_width: int = 64,
+         decoder_levels: int = 1,
+         decoder_downs_t: List[int] = [4],
+         decoder_strides_t: List[int] = [2],
+         decoder_width: int = 32,
+         decoder_depth: int = 4,
+         decoder_m_conv: float = 1.0,
+         decoder_dilation_growth_rate: int = 3,
+
+         # Training parameters
+         lambda_commit: float = 0.02,
+         f0_normalize: bool = True,
+         intensity_normalize: bool = True,
+         multispkr: str = "single",
+         f0_feats: bool = False,
+         f0_median: bool = False,
+
+         # Optional training hyperparameters
+         learning_rate: float = 0.0002,
+         adam_b1: float = 0.8,
+         adam_b2: float = 0.99,
+         lr_decay: float = 0.999,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+
+         # VQ parameters
+         self.l_bins = l_bins
+         self.emb_width = emb_width
+         self.mu = mu
+         self.levels = levels
+
+         # Encoder parameters
+         self.encoder_input_emb_width = encoder_input_emb_width
+         self.encoder_output_emb_width = encoder_output_emb_width
+         self.encoder_levels = encoder_levels
+         self.encoder_downs_t = encoder_downs_t
+         self.encoder_strides_t = encoder_strides_t
+         self.encoder_width = encoder_width
+         self.encoder_depth = encoder_depth
+         self.encoder_m_conv = encoder_m_conv
+         self.encoder_dilation_growth_rate = encoder_dilation_growth_rate
+
+         # Decoder parameters
+         self.decoder_input_emb_width = decoder_input_emb_width
+         self.decoder_output_emb_width = decoder_output_emb_width
+         self.decoder_levels = decoder_levels
+         self.decoder_downs_t = decoder_downs_t
+         self.decoder_strides_t = decoder_strides_t
+         self.decoder_width = decoder_width
+         self.decoder_depth = decoder_depth
+         self.decoder_m_conv = decoder_m_conv
+         self.decoder_dilation_growth_rate = decoder_dilation_growth_rate
+
+         # Training parameters
+         self.lambda_commit = lambda_commit
+         self.f0_normalize = f0_normalize
+         self.intensity_normalize = intensity_normalize
+         self.multispkr = multispkr
+         self.f0_feats = f0_feats
+         self.f0_median = f0_median
+
+         # Training hyperparameters
+         self.learning_rate = learning_rate
+         self.adam_b1 = adam_b1
+         self.adam_b2 = adam_b2
+         self.lr_decay = lr_decay
+
+     @property
+     def f0_vq_params(self):
+         return {
+             "l_bins": self.l_bins,
+             "emb_width": self.emb_width,
+             "mu": self.mu,
+             "levels": self.levels
+         }
+
+     @property
+     def f0_encoder_params(self):
+         return {
+             "input_emb_width": self.encoder_input_emb_width,
+             "output_emb_width": self.encoder_output_emb_width,
+             "levels": self.encoder_levels,
+             "downs_t": self.encoder_downs_t,
+             "strides_t": self.encoder_strides_t,
+             "width": self.encoder_width,
+             "depth": self.encoder_depth,
+             "m_conv": self.encoder_m_conv,
+             "dilation_growth_rate": self.encoder_dilation_growth_rate
+         }
+
+     @property
+     def f0_decoder_params(self):
+         return {
+             "input_emb_width": self.decoder_input_emb_width,
+             "output_emb_width": self.decoder_output_emb_width,
+             "levels": self.decoder_levels,
+             "downs_t": self.decoder_downs_t,
+             "strides_t": self.decoder_strides_t,
+             "width": self.decoder_width,
+             "depth": self.decoder_depth,
+             "m_conv": self.decoder_m_conv,
+             "dilation_growth_rate": self.decoder_dilation_growth_rate
+         }
+
+     @classmethod
+     def from_yaml(cls, yaml_path: str):
+         """Load config from a yaml file."""
+         import yaml
+         with open(yaml_path, 'r') as f:
+             config = yaml.safe_load(f)
+
+         # Convert yaml config to kwargs
+         kwargs = {
+             # VQ params
+             **{k: v for k, v in config['f0_vq_params'].items()},
+
+             # Encoder params
+             **{f"encoder_{k}": v for k, v in config['f0_encoder_params'].items()},
+
+             # Decoder params
+             **{f"decoder_{k}": v for k, v in config['f0_decoder_params'].items()},
+
+             # Training params
+             "lambda_commit": config.get('lambda_commit', 0.02),
+             "f0_normalize": config.get('f0_normalize', True),
+             "intensity_normalize": config.get('intensity_normalize', True),
+             "multispkr": config.get('multispkr', "single"),
+             "f0_feats": config.get('f0_feats', False),
+             "f0_median": config.get('f0_median', False),
+
+             # Training hyperparams
+             "learning_rate": config.get('learning_rate', 0.0002),
+             "adam_b1": config.get('adam_b1', 0.8),
+             "adam_b2": config.get('adam_b2', 0.99),
+             "lr_decay": config.get('lr_decay', 0.999),
+         }
+
+         return cls(**kwargs)
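
from_yaml expects f0_vq_params, f0_encoder_params, and f0_decoder_params as top-level sections, with the remaining keys flat. A minimal illustrative file, written inline here with values that mirror the constructor defaults (the file name quantizer.yaml is arbitrary):

from quantizer_config import QuantizerConfig

yaml_text = """\
f0_vq_params:
  l_bins: 320
  emb_width: 64
  mu: 0.99
  levels: 1
f0_encoder_params:
  input_emb_width: 3
  output_emb_width: 64
  levels: 1
  downs_t: [4]
  strides_t: [2]
  width: 32
  depth: 4
  m_conv: 1.0
  dilation_growth_rate: 3
f0_decoder_params:
  input_emb_width: 3
  output_emb_width: 64
  levels: 1
  downs_t: [4]
  strides_t: [2]
  width: 32
  depth: 4
  m_conv: 1.0
  dilation_growth_rate: 3
lambda_commit: 0.02
"""

with open('quantizer.yaml', 'w') as f:
    f.write(yaml_text)

config = QuantizerConfig.from_yaml('quantizer.yaml')
print(config.f0_encoder_params['width'])  # 32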
utils.py ADDED
@@ -0,0 +1,36 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Adapted from https://github.com/jik876/hifi-gan
+
+ import os
+
+ import torch
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ def load_checkpoint(filepath, device):
+     assert os.path.isfile(filepath)
+     print("Loading '{}'".format(filepath))
+     checkpoint_dict = torch.load(filepath, map_location=device)
+     print("Complete.")
+     return checkpoint_dict
+
+
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
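
Two of these helpers in action: get_padding computes the "same" padding for a dilated convolution, and AttrDict exposes dict keys as attributes (the hyperparameter names below are illustrative):

from utils import AttrDict, get_padding

print(get_padding(kernel_size=3, dilation=9))  # 9: keeps the output length equal to the input
h = AttrDict({'learning_rate': 2e-4, 'adam_b1': 0.8})
print(h.learning_rate)                         # 0.0002, via attribute access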