-
Notifications
You must be signed in to change notification settings - Fork 0
/
hifigan.py
457 lines (381 loc) · 14.1 KB
/
hifigan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
"""
HiFi-GAN model: generator and critics.
"""
import jax
import jax.numpy as jnp
import numpy as np
import pax
# Negative slope of the leaky-ReLU activations used by the residual blocks,
# the generator's upsampling stages and both critics.
LRELU_SLOPE = 0.1


# Source: https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/spectral_norm.py
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
# Licensed under the Apache License, Version 2.0
def _l2_normalize(x, axis=None, eps=1e-12):
"""Normalizes along dimension `axis` using an L2 norm.
This specialized function exists for numerical stability reasons.
Args:
x: An input ndarray.
axis: Dimension along which to normalize, e.g. `1` to separately normalize
vectors in a batch. Passing `None` views `t` as a flattened vector when
calculating the norm (equivalent to Frobenius norm).
eps: Epsilon to avoid dividing by zero.
Returns:
An array of the same shape as 'x' L2-normalized along 'axis'.
"""
return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)
class WeightNormConv(pax.Module):
    """Weight-normalized wrapper around a ``pax.Conv1D`` / ``pax.Conv1DTranspose``.

    The kernel actually used is ``g * w / ||w||``. ``w`` is the wrapped conv's
    weight, L2-normalized over every axis except the one ``g`` varies along.
    ``g`` is a gain with one entry per output feature, initialized to ones.

    Args:
        conv: the convolution to wrap; must be a ``pax.Conv1D`` or a
            ``pax.Conv1DTranspose``.

    Raises:
        TypeError: if ``conv`` is of any other type.
    """

    def __init__(self, conv):
        super().__init__()
        self.conv = conv
        if isinstance(conv, pax.Conv1D):
            # Gain broadcasts along the kernel's last axis.
            self.g = jnp.ones((1, 1, conv.out_features))
        elif isinstance(conv, pax.Conv1DTranspose):
            # Gain broadcasts along the kernel's middle axis.
            self.g = jnp.ones((1, conv.out_features, 1))
        else:
            # Fail fast: without `g` the module would only break when called.
            raise TypeError(
                "WeightNormConv expects a pax.Conv1D or pax.Conv1DTranspose, "
                f"got {type(conv).__name__}"
            )

    # Declares `g` as this module's parameter (see pax.parameters_method).
    parameters = pax.parameters_method("g")

    def get_weight(self):
        """Compute the weight-normalized kernel ``g * w / ||w||``."""
        assert self.g is not None, "Missing parameter `g`."
        # Normalize over every axis except the one `g` varies along.
        if isinstance(self.conv, pax.Conv1D):
            norm_axes = (0, 1)
        elif isinstance(self.conv, pax.Conv1DTranspose):
            norm_axes = (0, 2)
        else:
            raise TypeError(
                f"unsupported convolution type: {type(self.conv).__name__}"
            )
        weight = _l2_normalize(self.conv.weight, norm_axes)
        assert weight.shape == self.conv.weight.shape
        return self.g * weight

    def __call__(self, x):
        """Apply the convolution with the normalized kernel.

        Once `remove_weight_norm` has run, `g` is None and the baked-in
        kernel is used directly.
        """
        if self.g is None:
            return self.conv(x)
        return self.conv.replace(weight=self.get_weight())(x)

    def remove_weight_norm(self):
        """Bake ``g * w / ||w||`` into the conv and drop `g` for faster inference.

        Returns a new module. Calling this on a module whose weight norm was
        already removed returns that module unchanged.
        """
        if self.g is None:
            return self
        conv = self.conv.replace(weight=self.get_weight())
        return self.replace(conv=conv, g=None)
class SpectralNormConv(pax.Module):
    """Convolution whose kernel is divided by an estimate of its spectral norm.

    The kernel is viewed as a matrix of shape ``(-1, kernel.shape[-1])``.
    In training mode (``self.training``), ``sigma`` is re-estimated by running
    ``n_steps`` power-iteration steps from the stored vector ``u0``. The
    updated ``u0`` and ``sigma`` are then written back to the module. Outside
    training mode, the stored ``sigma`` is reused unchanged.

    Args:
        conv: the wrapped convolution; its ``weight`` is rescaled on every call.
        eps: epsilon forwarded to the L2 normalization in the power iteration.
        n_steps: number of power-iteration steps per training-mode call.
    """

    def __init__(self, conv: pax.Conv1D, eps: float = 1e-4, n_steps: int = 1):
        super().__init__()
        self.conv = conv
        self.eps = eps
        self.n_steps = n_steps
        # Random starting vector for the power iteration.
        self.u0 = jax.random.normal(pax.next_rng_key(), (1, conv.out_features))
        # Latest spectral-norm estimate; 1.0 until a training-mode call updates it.
        self.sigma = jnp.ones(())

    def get_weight(self):
        """Return the kernel divided by ``sigma`` (refreshed in training mode)."""
        kernel = self.conv.weight
        # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
        # Licensed under the Apache License, Version 2.0
        matrix = jnp.reshape(kernel, [-1, kernel.shape[-1]])
        if not self.training:
            return (matrix / self.sigma).reshape(kernel.shape)

        u_vec = self.u0
        for _ in range(self.n_steps):
            v_vec = _l2_normalize(
                jnp.matmul(u_vec, matrix.transpose([1, 0])), eps=self.eps
            )
            u_vec = _l2_normalize(jnp.matmul(v_vec, matrix), eps=self.eps)
        # The singular-vector estimates are treated as constants by autodiff.
        u_vec = jax.lax.stop_gradient(u_vec)
        v_vec = jax.lax.stop_gradient(v_vec)
        sigma = jnp.matmul(jnp.matmul(v_vec, matrix), jnp.transpose(u_vec))[0, 0]
        self.u0 = u_vec
        self.sigma = sigma
        return (matrix / sigma).reshape(kernel.shape)

    def __call__(self, x):
        """Run the wrapped convolution with the spectrally normalized kernel."""
        return self.conv.replace(weight=self.get_weight())(x)
def normalized_conv(
    input,
    output,
    kernel_size,
    stride,
    dilation=1,
    padding="SAME",
    group=1,
    spectral_norm=False,
):
    """Build a ``pax.Conv1D`` wrapped in a weight-normalization layer.

    Args:
        input: number of input features.
        output: number of output features.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        dilation: dilation rate, forwarded as ``rate``.
        padding: padding mode forwarded to ``pax.Conv1D``.
        group: forwarded as ``feature_group_count``.
        spectral_norm: if true wrap with :class:`SpectralNormConv`, otherwise
            with :class:`WeightNormConv`.

    Returns:
        The wrapped convolution. Weights are initialized from a normal
        distribution with standard deviation 0.01.
    """
    base_conv = pax.Conv1D(
        input,
        output,
        kernel_size,
        stride=stride,
        rate=dilation,
        padding=padding,
        feature_group_count=group,
        w_init=jax.nn.initializers.normal(0.01),
    )
    wrapper = SpectralNormConv if spectral_norm else WeightNormConv
    return wrapper(base_conv)
def conv_transpose(in_channel, out_channel, kernel_size, upsample_factor):
    """Build a weight-normalized, "VALID"-padded ``pax.Conv1DTranspose``.

    Args:
        in_channel: number of input features.
        out_channel: number of output features.
        kernel_size: transposed-convolution kernel size.
        upsample_factor: stride of the transposed convolution.

    Returns:
        A :class:`WeightNormConv` around the transposed convolution. Weights
        are initialized from a normal distribution with standard deviation 0.01.
    """
    init = jax.nn.initializers.normal(0.01)
    upsampler = pax.Conv1DTranspose(
        in_channel,
        out_channel,
        kernel_size,
        upsample_factor,
        padding="VALID",
        w_init=init,
    )
    return WeightNormConv(upsampler)
class ResBlock1(pax.Module):
    """HiFi-GAN "type 1" residual block with three stages.

    Each stage applies leaky-ReLU -> dilated conv -> leaky-ReLU -> conv. All
    convs use "VALID" padding, so every stage shortens axis 1. The skip path
    is therefore center-cropped to the stage output's length before the add.

    Args:
        channels: number of input and output features.
        kernel_size: kernel size of every convolution in the block.
        dilation: dilation of each stage's first convolution; exactly the
            first three entries are used.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        # First (dilated) convolution of each stage.
        self.convs1 = [
            normalized_conv(channels, channels, kernel_size, 1, dilation[i], "VALID")
            for i in range(3)
        ]
        # Second (undilated) convolution of each stage.
        self.convs2 = [
            normalized_conv(channels, channels, kernel_size, 1, 1, "VALID")
            for _ in range(3)
        ]

    def __call__(self, x):
        """Apply the three stages; the result is shorter than ``x`` along axis 1."""
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = jax.nn.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = jax.nn.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            # Center-crop the skip path to xt's length. An explicit end index
            # is used because `x[:, 0:-0]` would be empty when p == 0, and
            # `p:-p` mismatches shapes when the total shrink is odd.
            p = (x.shape[1] - xt.shape[1]) // 2
            x = xt + x[:, p : p + xt.shape[1], :]
        return x
class ResBlock2(pax.Module):
    """HiFi-GAN "type 2" residual block with two stages.

    Each stage applies leaky-ReLU -> dilated conv. The conv uses "VALID"
    padding, so the skip path is center-cropped to the conv output's length
    before the add.

    Args:
        channels: number of input and output features.
        kernel_size: kernel size of both convolutions.
        dilation: dilation of each stage's convolution; exactly the first two
            entries are used.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.convs = [
            normalized_conv(channels, channels, kernel_size, 1, dilation[i], "VALID")
            for i in range(2)
        ]

    def __call__(self, x):
        """Apply the two stages; the result is shorter than ``x`` along axis 1."""
        for c in self.convs:
            xt = jax.nn.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            # Explicit end index: `x[:, 0:-0]` would be empty when p == 0.
            p = (x.shape[1] - xt.shape[1]) // 2
            x = xt + x[:, p : p + xt.shape[1], :]
        return x
def _center_crop(x, length):
    """Keep ``length`` steps of axis 1 of ``x``, trimming both ends evenly.

    When the surplus is odd, the extra step is dropped from the end.
    """
    start = (x.shape[1] - length) // 2
    return x[:, start : start + length]


class Generator(pax.Module):
    """HiFi-GAN generator: mel frames in, waveform samples out.

    Pipeline:
    1. ``conv_pre``.
    2. For each upsampling stage: leaky-ReLU, a transposed conv followed by an
       edge crop, then the mean of ``num_kernels`` residual blocks.
    3. Leaky-ReLU, ``conv_post``, tanh, and a squeeze of the last axis.

    All convolutions use "VALID" padding, so the raw output is shorter than
    the ideal length. Use :meth:`compute_padding_values` to pad the input.

    Args:
        mel_dim: number of mel features (last axis of the input).
        resblock_kernel_sizes: kernel size of each residual block in a stage.
        upsample_rates: stride of each transposed convolution.
        upsample_kernel_sizes: kernel size of each transposed convolution.
        upsample_initial_channel: features after ``conv_pre``; halved by each
            upsampling stage.
        resblock_kind: ``"1"`` selects :class:`ResBlock1`, anything else
            selects :class:`ResBlock2`.
        resblock_dilation_sizes: dilations of each residual block in a stage.
    """

    def __init__(
        self,
        mel_dim,
        resblock_kernel_sizes,
        upsample_rates,
        upsample_kernel_sizes,
        upsample_initial_channel,
        resblock_kind,
        resblock_dilation_sizes,
    ):
        super().__init__()
        self.mel_dim = mel_dim
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = normalized_conv(
            mel_dim, upsample_initial_channel, 7, 1, 1, "VALID"
        )
        create_resblock = ResBlock1 if resblock_kind == "1" else ResBlock2
        self.ups = []
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            in_channel = upsample_initial_channel // (2**i)
            self.ups.append(
                pax.Sequential(
                    conv_transpose(in_channel, in_channel // 2, k, u),
                    # Bind this layer's rate now. A bare `lambda x: x[:, u:-u, :]`
                    # reads `u` when called, i.e. after the loop has finished,
                    # so every layer would crop by the *last* upsample rate.
                    lambda x, crop=u: x[:, crop:-crop, :],
                )
            )
        self.resblocks = []
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2**i) // 2
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(create_resblock(ch, k, d))
        # Features left after the last upsampling stage (no dependency on the
        # loop variable above, so this also works with zero upsampling stages).
        post_channels = upsample_initial_channel // (2 ** len(self.ups))
        self.conv_post = normalized_conv(post_channels, 1, 7, 1, 1, "VALID")

    def compute_padding_values(self, num_frame=1024):
        """Compute the input padding and output padding of the network.

        Returns ``(input_pad, output_pad)``:
        - ``input_pad`` is the number of frames to pad on each side of the input.
        - ``output_pad`` is the number of samples to drop from each end of the
          raw output.

        Both are derived from output shapes traced with ``jax.eval_shape``;
        no real computation is run.

        Usage:
        >>> net = Generator(...)
        >>> input_pad, output_pad = net.compute_padding_values()
        >>> x = jnp.pad(x, [(0, 0), (input_pad, input_pad), (0, 0)], mode="reflect")
        >>> y = net(x, remove_output_padding=False)
        >>> y = y[:, output_pad : y.shape[1] - output_pad]
        """
        assert num_frame % 2 == 0
        mel = np.empty((1, num_frame + 1, self.mel_dim))

        def fn(g, c):
            return g(c, remove_output_padding=False)

        y1 = jax.eval_shape(fn, self.eval(), mel)
        y2 = jax.eval_shape(fn, self.eval(), mel[:, :-1, :])
        # each frame generates "hop" values
        hop = y1.shape[1] - y2.shape[1]
        # compute the minimum even number of frames
        min_frame = num_frame - y2.shape[1] // hop // 2 * 2
        y3 = jax.eval_shape(fn, self.eval(), mel[:, :min_frame, :])
        # we need to remove the "output" of padding frames
        remain = y3.shape[1]
        assert remain % 2 == 0
        return min_frame // 2, remain // 2

    def __call__(self, x, remove_output_padding=True):
        """Synthesize samples from mel frames.

        Args:
            x: mel input whose last axis has ``mel_dim`` features.
            remove_output_padding: if true, drop ``output_pad`` samples (from
                :meth:`compute_padding_values`) from each end of the output.

        Returns:
            The tanh-bounded waveform with the trailing feature axis squeezed.
        """
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = jax.nn.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                residual = self.resblocks[i * self.num_kernels + j](x)
                if xs is None:
                    xs = residual
                else:
                    # Resblocks shrink axis 1 by different amounts; crop both
                    # to the shorter one (handles equal lengths and either order).
                    length = min(xs.shape[1], residual.shape[1])
                    xs = _center_crop(xs, length) + _center_crop(residual, length)
            x = xs / self.num_kernels
        # 0.01 matches PyTorch's default leaky-ReLU slope; jax.nn.leaky_relu's
        # default is also 0.01, so it is passed explicitly only for clarity.
        x = jax.nn.leaky_relu(x, 0.01)
        x = self.conv_post(x)
        x = jnp.tanh(x)
        x = jnp.squeeze(x, axis=-1)
        if remove_output_padding:
            p = self.compute_padding_values()[-1]
            # Explicit end index: `x[:, 0:-0]` would be empty when p == 0.
            x = x[:, p : x.shape[1] - p]
        return x
class CriticP(pax.Module):
    """HiFi-GAN period critic.

    The input is zero-padded on the right of axis 1 to a multiple of
    ``period``. It is then folded into ``period`` interleaved sub-sequences,
    which are stacked on the batch axis and scored by a stack of strided convs.

    Args:
        period: folding period.
        kernel_size: kernel size of the five strided convolutions.
        stride: stride of the five strided convolutions.
    """

    def __init__(self, period, kernel_size=5, stride=3):
        super().__init__()
        self.period = period
        widths = [1, 32, 128, 512, 1024, 1024]
        self.convs = [
            normalized_conv(c_in, c_out, kernel_size, stride)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.conv_post = normalized_conv(1024, 1, 3, 1, 1)

    def __call__(self, x: jnp.ndarray):
        """Score ``x`` and return ``(scores, fmap)``.

        ``x`` is unpacked as ``(batch, time, channels)``. ``scores`` is
        reshaped to ``(batch, -1)``. ``fmap`` holds the activation after each
        leaky-ReLU, followed by the ``conv_post`` output.
        """
        batch, length, channels = x.shape
        n_pad = (-length) % self.period
        if n_pad:
            x = jnp.pad(x, [(0, 0), (0, n_pad), (0, 0)])
            length += n_pad
        rows = length // self.period
        x = jnp.reshape(x, (batch, rows, self.period, channels))
        x = jnp.reshape(jnp.swapaxes(x, 1, 2), (batch * self.period, rows, channels))
        features = []
        for conv in self.convs:
            x = jax.nn.leaky_relu(conv(x), LRELU_SLOPE)
            features.append(x)
        x = self.conv_post(x)
        features.append(x)
        return jnp.reshape(x, (batch, -1)), features
class MultiPeriodCritic(pax.Module):
    """Multi-period critic: one :class:`CriticP` per folding period.

    Args:
        periods: folding periods of the sub-critics, in evaluation order.
            Defaults to ``(2, 3, 5, 7, 11)``.
    """

    def __init__(self, periods=(2, 3, 5, 7, 11)):
        super().__init__()
        self.critics = [CriticP(period) for period in periods]

    def __call__(self, y, y_hat):
        """Score real ``y`` and generated ``y_hat`` with every sub-critic.

        Returns:
            ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``: per-critic real scores,
            generated scores, real feature maps and generated feature maps.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for critic in self.critics:
            y_d_r, fmap_r = critic(y)
            y_d_g, fmap_g = critic(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class CriticS(pax.Module):
    """HiFi-GAN scale critic: a stack of (grouped) convolutions.

    Args:
        spectral_norm: if true, every convolution is spectrally normalized;
            otherwise every convolution is weight normalized.
    """

    def __init__(self, spectral_norm=False):
        super().__init__()
        # (in, out, kernel, stride, groups) for each hidden convolution.
        layer_specs = [
            (1, 128, 15, 1, 1),
            (128, 128, 41, 2, 4),
            (128, 256, 41, 2, 16),
            (256, 512, 41, 4, 16),
            (512, 1024, 41, 4, 16),
            (1024, 1024, 41, 1, 16),
            (1024, 1024, 5, 1, 1),
        ]
        self.convs = [
            normalized_conv(
                c_in, c_out, k, s, group=groups, spectral_norm=spectral_norm
            )
            for c_in, c_out, k, s, groups in layer_specs
        ]
        self.conv_post = normalized_conv(1024, 1, 3, 1, spectral_norm=spectral_norm)

    def __call__(self, x):
        """Score ``x`` and return ``(scores, fmap)``.

        ``scores`` is the ``conv_post`` output flattened to
        ``(x.shape[0], -1)``. ``fmap`` holds every hidden activation, followed
        by the ``conv_post`` output.
        """
        features = []
        for layer in self.convs:
            x = jax.nn.leaky_relu(layer(x), LRELU_SLOPE)
            features.append(x)
        x = self.conv_post(x)
        features.append(x)
        return jnp.reshape(x, (x.shape[0], -1)), features
class MultiScaleCritic(pax.Module):
    """Multi-scale critic: three :class:`CriticS` at successively coarser scales.

    The first critic is spectrally normalized and sees the inputs unchanged.
    Each later critic is weight normalized and sees the previous critic's
    inputs average-pooled once more (window 4, stride 2, "SAME").
    """

    def __init__(self):
        super().__init__()
        self.critics = [CriticS(spectral_norm=(idx == 0)) for idx in range(3)]
        self.meanpools = [
            lambda x: x,
            lambda x: pax.avg_pool(x, 4, 2, "SAME", -1),
            lambda x: pax.avg_pool(x, 4, 2, "SAME", -1),
        ]

    def __call__(self, y, y_hat):
        """Score real ``y`` and generated ``y_hat`` at every scale.

        Returns:
            ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``: per-critic real scores,
            generated scores, real feature maps and generated feature maps.
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for pool, critic in zip(self.meanpools, self.critics):
            # Pooling is cumulative: each scale pools the previous scale's input.
            y = pool(y)
            y_hat = pool(y_hat)
            score_r, fmap_r = critic(y)
            score_g, fmap_g = critic(y_hat)
            real_scores.append(score_r)
            real_fmaps.append(fmap_r)
            fake_scores.append(score_g)
            fake_fmaps.append(fmap_g)
        return real_scores, fake_scores, real_fmaps, fake_fmaps
def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss between real and generated critic feature maps.

    Args:
        fmap_r: per-critic lists of feature maps for real audio.
        fmap_g: per-critic lists of feature maps for generated audio.

    Returns:
        Twice the sum, over every paired feature map, of the mean absolute
        difference. Returns 0 for empty inputs.
    """
    total = sum(
        jnp.mean(jnp.abs(real_map - fake_map))
        for real_maps, fake_maps in zip(fmap_r, fmap_g)
        for real_map, fake_map in zip(real_maps, fake_maps)
    )
    return total * 2
def critic_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares critic loss.

    For each paired critic output (real ``dr``, generated ``dg``), the real
    term is ``mean((1 - dr)**2)`` and the generated term is ``mean(dg**2)``.

    Returns:
        ``(loss, r_losses, g_losses)``: the sum of all real and generated
        terms, plus the per-critic real and generated terms as lists.
    """
    pairs = list(zip(disc_real_outputs, disc_generated_outputs))
    r_losses = [jnp.mean((1 - dr) ** 2) for dr, _ in pairs]
    g_losses = [jnp.mean(dg**2) for _, dg in pairs]
    loss = sum(r + g for r, g in zip(r_losses, g_losses))
    return loss, r_losses, g_losses
def generator_loss(disc_outputs):
    """Least-squares adversarial loss for the generator.

    Returns:
        ``(loss, gen_losses)``: ``gen_losses`` holds ``mean((1 - dg)**2)``
        for each critic output ``dg``, and ``loss`` is their sum (0 when
        ``disc_outputs`` is empty).
    """
    gen_losses = [jnp.mean((1 - dg) ** 2) for dg in disc_outputs]
    return sum(gen_losses), gen_losses