-
Notifications
You must be signed in to change notification settings - Fork 102
/
model.py
443 lines (360 loc) · 15.9 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
from typing import Optional, Tuple
import torch as T
import torch.nn as nn
import torch.nn.functional as F
from ioblocks import GaussianMixtureIOLayer, FSQ
from transformer import Stack, ShapeRotator, Block as PerfBlock, GPTOutput, CACHE_FILL_VALUE, FFNN, Norm
from tokenizer import make_tokenizer
from utils import si_module, exists, isnt, tqdm0, print0, default, print0_colored
from utils import load_ckpt
@si_module
class LatentQuantizer(nn.Module):
    """Project latent frames and quantize them with a compressor module.

    ``forward`` runs an optional input projection, an ``FFNN`` and then the
    compressor built from ``compressor_config``.  When ``known_latent`` is given,
    it instead maps those indices back to codes via ``compressor.indices_to_codes``.

    NOTE(review): assumes ``@si_module`` makes nested ``*.Config`` objects callable
    constructors (``c.compressor_config()``) — confirm in ``utils``.
    """

    class Config:
        compressor_config: Optional[FSQ.Config] = None
        dim: Optional[int] = None
        ff_dim: Optional[int] = None
        input_dim: Optional[int] = None  # None -> no input projection (nn.Identity)
        from_pretrained: Optional[Tuple[str, str]] = None  # forwarded as load_ckpt(*from_pretrained)

    def __init__(self, c: Config):
        super().__init__()
        # The compressor is always constructed from compressor_config below; a checkpoint
        # only supplies weights, so the config is required in both branches.
        assert exists(c.compressor_config), f'hmm {c}'
        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)

        self.compressor = c.compressor_config()
        self.ffnn = FFNN(c.dim, c.ff_dim)
        self.input = nn.Linear(c.input_dim, c.dim) if exists(c.input_dim) else nn.Identity()

        if exists(c.from_pretrained):
            self.load_state_dict(checkpoint)

    @T.no_grad()
    def forward(self, x, return_latent=False, known_latent=None):
        """Quantize ``x``, or decode ``known_latent`` indices back to codes.

        Args:
            x: (B, S, D) input; ignored when ``known_latent`` is given.
            return_latent: when True, also return the compressor's token indices.
            known_latent: previously produced indices to convert back into codes.

        Returns:
            The quantized codes, or ``(codes, tokens)`` when ``return_latent`` is True.
        """
        if exists(known_latent):
            return self.compressor.indices_to_codes(known_latent)
        x = self.input(x)
        x = self.ffnn(x)
        x, tokens = self.compressor(x)
        if return_latent:
            return x, tokens
        return x
@si_module
class TransformerVAE(nn.Module):
    """Transformer that resynthesizes latent frames, conditioned on quantized latents.

    Inputs are embedded by ``io`` (and ``io2`` for the second stream when ``split``).
    At layer ``plex_layer`` the quantized latents, shifted left by ``plex_roll``
    frames, are projected and added to the residual stream.  In split mode the stream
    forks there into two branches (``x1``/``x2``) that run through the remaining
    layers separately, each with its own KV-cache slot.

    NOTE(review): assumes ``@si_module`` exposes the Config as ``self.c`` and makes
    nested configs callable constructors — confirm in ``utils``.
    """

    class Config:
        io_config: Optional[GaussianMixtureIOLayer.Config] = None
        stack_config: Optional[Stack.Config] = None
        quantizer_config: Optional[LatentQuantizer.Config] = None
        plex_layer: Optional[int] = None  # layer where quantized latents are injected; None -> layers // 2
        plex_roll: int = 1
        split: bool = True
        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()
        # All three sub-configs are used unconditionally below; a checkpoint only supplies weights.
        assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'hmm {c}'
        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)

        self.io = c.io_config()
        self.stack = c.stack_config()

        # Honour an explicit plex_layer (it used to be ignored); keep the old midpoint default.
        self.plex_layer = default(c.plex_layer, c.stack_config.layers//2)
        assert 0 <= self.plex_layer < c.stack_config.layers, f'plex_layer {self.plex_layer} out of range for {c.stack_config.layers} layers'
        self.plex_roll = c.plex_roll
        self.plex_dim = c.quantizer_config.dim
        assert self.plex_dim is not None and c.stack_config.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}'
        self.plex_projection = nn.Linear(self.plex_dim, c.stack_config.dim)
        self.out_norm = Norm(c.stack_config.dim)

        if c.split:
            self.io2 = c.io_config()
            self.plex_projection2 = nn.Linear(self.plex_dim, c.stack_config.dim)
            # forward() uses self.io.output for both branches, so io2's output heads are dropped;
            # io2 still provides input() and temp_sample().
            self.io2.fc_loc = None
            self.io2.fc_scale = None
            self.io2.fc_weight = None

        kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
        head_dim = c.stack_config.dim // c.stack_config.n_head
        # In split mode every layer from plex_layer onward runs once per branch and needs
        # its own cache slot, hence the extra (layers - plex_layer) entries.
        self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0)
        self.cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim]
        self.cache = [None] * self.cache_num_layers

        if exists(c.from_pretrained):
            result = self.load_state_dict(checkpoint, strict=False)
            print0_colored(result, 'yellow')

        # Built after load_state_dict — presumably so its weights come from its own
        # from_pretrained checkpoint rather than this one.
        self.quantizer = c.quantizer_config().eval()
        # Module.requires_grad_ freezes every parameter; assigning `.requires_grad` did nothing.
        self.quantizer.requires_grad_(False)

    @T.no_grad()
    def quantize(self, x):
        """Quantize ``x`` (one tensor, or one per half of the last dim when split) under bf16 autocast."""
        if self.c.split:
            x1, x2 = x.chunk(2, dim=-1)
            with T.autocast(device_type='cuda', dtype=T.bfloat16):
                quantized1 = self.quantizer(x1)
                quantized2 = self.quantizer(x2)
            return quantized1, quantized2
        else:
            with T.autocast(device_type='cuda', dtype=T.bfloat16):
                return self.quantizer(x)

    @T.no_grad()
    def untokenize(self, token_data):
        """Convert quantizer token indices back into codes."""
        return self.quantizer(None, known_latent=token_data)

    def init_cache(self, bsize, device, dtype, length:int=None):
        """Allocate a KV cache filled with CACHE_FILL_VALUE, layer axis first.

        ``length`` overrides the configured sequence length when given.
        """
        cache_shape = self.cache_shape.copy()
        cache_shape[1] = length or cache_shape[1]
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)

    def deinit_cache(self):
        """Drop the KV cache (one ``None`` slot per cached layer)."""
        self.cache = [None] * self.cache_num_layers

    @T.no_grad()
    def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None):
        """Run the resynthesizer over ``data``.

        Args:
            data: latent frames; in split mode the two streams are the halves of the last dim.
            next_tokens: token indices for the frame after the last one (a pair when split).
                They are decoded and written into the final slot of the rolled quantized latents.
            temps: when None, return the raw output(s); otherwise sample with
                ``io.temp_sample`` and return only the last frame.

        Returns:
            ``out`` / ``(out1, out2)`` when ``temps`` is None, else the sampled next frame
            (both streams concatenated on the last dim in split mode).
        """
        if self.c.split:
            x1, x2 = data.chunk(2, dim=-1)
            x = self.io.input(x1) + self.io2.input(x2)
        else:
            x = self.io.input(data)

        cache_idx = 0
        for l, layer in enumerate(self.stack.layers):
            if l == self.plex_layer:
                # Roll left along the sequence axis so position i is conditioned on frame
                # i + plex_roll; the wrapped-around tail is overwritten when next_tokens is given.
                if self.c.split:
                    plex1, plex2 = self.quantize(data)
                    plex1 = T.roll(plex1, -self.c.plex_roll, dims=1)
                    plex2 = T.roll(plex2, -self.c.plex_roll, dims=1)
                    if exists(next_tokens):
                        plex1[:, -1:] = self.untokenize(next_tokens[0])
                        plex2[:, -1:] = self.untokenize(next_tokens[1])
                    x1 = x + self.plex_projection(plex1)
                    x2 = x + self.plex_projection2(plex2)
                else:
                    plex = self.quantize(data)
                    plex = T.roll(plex, -self.c.plex_roll, dims=1)
                    if exists(next_tokens):
                        plex[:, -1:] = self.untokenize(next_tokens)
                    x = x + self.plex_projection(plex)

            if l < self.plex_layer:
                x = layer(x, kv=self.cache[l])
            else:
                if self.c.split:
                    # Each branch gets its own cache slot after plex_layer.
                    x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx])
                    cache_idx += 1
                    x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx])
                    cache_idx += 1
                else:
                    x = layer(x, kv=self.cache[l])

        with T.autocast(device_type='cuda', dtype=T.bfloat16):
            if self.c.split:
                x1, x2 = self.out_norm(x1), self.out_norm(x2)
                out1, out2 = self.io.output(x1), self.io.output(x2)
            else:
                x = self.out_norm(x)
                out = self.io.output(x)

        if isnt(temps):
            if self.c.split:
                return out1, out2
            else:
                return out
        else:
            if self.c.split:
                next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :]
                next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :]
                next_data = T.cat([next_data1, next_data2], dim=-1)
                return next_data
            else:
                next_data = self.io.temp_sample(out, temps)[:, -1:, :]
                return next_data
@si_module
class HertzDevModel(nn.Module):
    """Autoregressive token model over audio latents, paired with a TransformerVAE resynthesizer.

    ``forward`` maps latent frames to token logits (one head per stream when
    ``split``).  ``next_latent`` samples a token from the last position and asks the
    resynthesizer to turn it into the next latent frame.  ``tokenize``/``untokenize``
    convert between audio and latents with the tokenizer from ``make_tokenizer``.
    While ``use_audio_cache`` is set, they also keep a rolling left-context window.

    The code assumes 2000 audio samples per latent frame (see the ``// 2000`` /
    ``* 2000`` arithmetic below).

    NOTE(review): assumes ``@si_module`` exposes the Config as ``self.c`` (it is read
    inside ``__init__``) — confirm in ``utils``.
    """

    class Config:
        dim: int
        vocab_size: int
        stack_config: Optional[Stack.Config] = None
        latent_size: int = 32  # feature width of a single latent stream
        split: bool = True  # two streams, concatenated along the last dim
        quantizer_config: Optional[LatentQuantizer.Config] = None
        resynthesizer_config: Optional[TransformerVAE.Config] = None
        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()
        # The layer stack is always built from stack_config; a checkpoint only supplies weights.
        assert (exists(c.stack_config)), f'hmm {c}'
        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)

        self.input = nn.Linear(c.latent_size, c.dim)
        if self.c.split:
            self.input2 = nn.Linear(c.latent_size, c.dim)

        self.shape_rotator = ShapeRotator(c.stack_config.dim//c.stack_config.n_head, c.stack_config.seq_len, theta=c.stack_config.theta)
        self.layers = nn.ModuleList([
            PerfBlock(
                dim=c.stack_config.dim,
                layer_id=l,
                n_head=c.stack_config.n_head,
                kv_heads=c.stack_config.kv_heads,
                ff_dim=c.stack_config.ff_dim,
                eps=c.stack_config.eps,
                shape_rotator=self.shape_rotator,
            ) for l in range(c.stack_config.layers)
        ])

        self.output = GPTOutput(c.dim, c.vocab_size)
        if self.c.split:
            self.output2 = GPTOutput(c.dim, c.vocab_size)

        self.cache = [None] * c.stack_config.layers
        self.kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
        self.head_dim = c.stack_config.dim // c.stack_config.n_head

        if exists(c.from_pretrained):
            result = self.load_state_dict(checkpoint, strict=False)
            print0_colored(result, 'yellow')

        self.resynthesizer = c.resynthesizer_config().eval()
        # Module.requires_grad_ freezes every parameter; assigning `.requires_grad` did nothing.
        self.resynthesizer.requires_grad_(False)

        self.audio_tokenizer = make_tokenizer(device='cpu')
        self.audio_cache = None
        self.audio_latent_cache = None
        self.use_audio_cache = False

    @staticmethod
    def _take_last(x: T.Tensor, n: int, dim: int) -> T.Tensor:
        """Return the last ``n`` entries of ``x`` along ``dim``; empty when ``n <= 0``.

        Plain ``x[..., -n:]`` returns the WHOLE tensor when ``n == 0``.
        """
        size = x.shape[dim]
        n = max(0, min(n, size))
        return x.narrow(dim, size - n, n)

    @T.no_grad()
    def tokenize(self, audio_data):
        """Encode audio into latent frames, returning only frames for the new samples.

        ``audio_data`` is indexed as (batch, channels, samples); two channels are encoded
        separately and concatenated on the last dim.  When caching is active, up to
        6*16_000 trailing samples (presumably 6 s at 16 kHz) are prepended as left context.
        """
        new_frames = audio_data.shape[-1] // 2000
        if exists(self.audio_cache):
            audio_data = T.cat([self.audio_cache, audio_data], dim=-1)
            self.audio_cache = audio_data[..., -(6*16_000):]
        elif self.use_audio_cache:
            self.audio_cache = audio_data[..., -(6*16_000):]

        if audio_data.shape[1] == 2:
            enc_ch1 = self.audio_tokenizer.latent_from_data(audio_data[:, 0:1])
            enc_ch2 = self.audio_tokenizer.latent_from_data(audio_data[:, 1:2])
            latents = T.cat([enc_ch1, enc_ch2], dim=-1)
        else:
            latents = self.audio_tokenizer.latent_from_data(audio_data)
        return self._take_last(latents, new_frames, dim=1)

    @T.no_grad()
    def untokenize(self, token_data):
        """Decode latent frames into audio, returning only samples for the new frames.

        ``token_data`` is indexed as (batch, frames, features).  A last dim of
        ``2 * latent_size`` is treated as two streams and decoded per stream, concatenated
        on dim 1.  When caching is active, up to 6*8 trailing frames are prepended as left context.
        """
        new_frames = token_data.shape[1]  # measured before the cache is prepended
        if exists(self.audio_latent_cache):
            token_data = T.cat([self.audio_latent_cache, token_data], dim=1)
            self.audio_latent_cache = token_data[:, -(6*8):]
        elif self.use_audio_cache:
            self.audio_latent_cache = token_data[:, -(6*8):]

        latent_size = self.c.latent_size
        if token_data.shape[-1] == 2*latent_size:
            # Streams were concatenated on the last dim, so split there (not on the frame axis).
            dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[..., :latent_size])
            dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[..., latent_size:])
            audio = T.cat([dec_ch1, dec_ch2], dim=1)
        else:
            audio = self.audio_tokenizer.data_from_latent(token_data)
        return self._take_last(audio, new_frames*2000, dim=-1)

    def init_cache(self, bsize, device, dtype, length:int=None):
        """Allocate KV caches for this model and the resynthesizer, and enable audio caching."""
        cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim]
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
        self.resynthesizer.init_cache(bsize, device, dtype, length)
        self.use_audio_cache = True

    def deinit_cache(self):
        """Drop all KV and audio caches and disable audio caching."""
        self.cache = [None] * len(self.layers)
        self.resynthesizer.deinit_cache()
        self.audio_cache = None
        self.audio_latent_cache = None
        self.use_audio_cache = False

    @T.no_grad()
    def forward(self, data):
        """Return token logits for every position (a pair of heads when split)."""
        if self.c.split:
            x1, x2 = data.chunk(2, dim=-1)
            x = self.input(x1) + self.input2(x2)
        else:
            x = self.input(data)
        for l, layer in enumerate(self.layers):
            x = layer(x, kv=self.cache[l])
        if self.c.split:
            return self.output(x), self.output2(x)
        else:
            return self.output(x)

    @T.no_grad()
    def next_audio_from_audio(self, audio_data: T.Tensor, temps=(0.8, (0.5, 0.1))):
        """Tokenize ``audio_data``, predict one latent frame and decode its last 2000 samples.

        In split mode only the second stream (``[..., latent_size:]``) is decoded.
        """
        latents_in = self.tokenize(audio_data)
        next_latents = self.next_latent(latents_in, temps)
        # Only strip the first stream when two are present; in non-split mode the
        # old `[..., latent_size:]` slice produced an empty tensor.
        if next_latents.shape[-1] == 2*self.c.latent_size:
            next_latents = next_latents[..., self.c.latent_size:]
        return self.untokenize(next_latents)[..., -2000:]

    @T.no_grad()
    def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))):
        """Sample next token(s) at temperature ``temps[0]``; resynthesize with ``temps[1]``."""
        if self.c.split:
            logits1, logits2 = self.forward(model_input)
            next_logits1 = logits1[:, -1]
            next_logits2 = logits2[:, -1]
            next_token1 = F.softmax(next_logits1 / temps[0], dim=-1).multinomial(1)
            next_token2 = F.softmax(next_logits2 / temps[0], dim=-1).multinomial(1)
            next_input = self.resynthesizer(model_input, next_tokens=(next_token1, next_token2), temps=temps[1])
        else:
            logits = self.forward(model_input)
            next_logits = logits[:, -1]
            next_token = F.softmax(next_logits / temps[0], dim=-1).multinomial(1)
            next_input = self.resynthesizer(model_input, next_tokens=next_token, temps=temps[1])
        return next_input

    @T.no_grad()
    def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor:
        """
        only accepts latent-space data.

        Extends ``data`` frame by frame (``gen_len`` frames, default ``data.shape[1]``),
        capped at ``stack_config.seq_len`` total frames, and returns prompt + generation.
        Caches are always released afterwards, even if generation raises.
        """
        if use_cache:
            self.init_cache(data.shape[0], data.device, T.bfloat16)
        try:
            next_input = generated = data
            target_len = min(data.shape[1] + default(gen_len, data.shape[1]), self.c.stack_config.seq_len)
            for _ in tqdm0(range(data.shape[1], target_len)):
                # With a KV cache only the newest frame is fed; otherwise the full sequence.
                model_input = next_input if use_cache else generated
                next_input = self.next_latent(model_input, temps)
                generated = T.cat([generated, next_input], dim=1)
        finally:
            if use_cache:
                self.deinit_cache()
        return generated
def get_hertz_dev_config(is_split=True, use_pure_audio_ablation=False):
    """Build the ``HertzDevModel.Config`` for one of the released checkpoint pairs.

    ``is_split`` takes precedence: when True, ``use_pure_audio_ablation`` is ignored.
    Otherwise ``use_pure_audio_ablation`` picks between the two non-split
    model checkpoints.  The first checkpoint of each pair loads the resynthesizer
    (``TransformerVAE``), the second loads ``HertzDevModel`` itself.
    """
    whip_ckpt = ('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7')
    if is_split:
        resynth_ckpt = ('inference_care_50000', 'e4ff4fe5c7e9f066410d2a5673b7a935')
        model_ckpt = ('inference_scion_54000', 'cb8bc484423922747b277ebc2933af5d')
    elif use_pure_audio_ablation:
        resynth_ckpt = whip_ckpt
        model_ckpt = ('inference_syrup_110000', '353c48f553f1706824c11f3bb6a049e9')
    else:
        resynth_ckpt = whip_ckpt
        model_ckpt = ('inference_caraway_112000', 'fcb8368ef8ebf7712f3e31e6856da580')

    fsq_config = FSQ.Config(
        levels=[8, 8, 8, 8, 8],
        dim=2048,
        num_codebooks=1,
        keep_num_codebooks_dim=None,
        scale=None,
        allowed_dtypes=['float32', 'float64', 'bfloat16'],
        channel_first=False,
        projection_has_bias=True,
        return_indices=True,
        force_quantization_f32=True,
        use_rms=False,
    )
    quantizer_config = LatentQuantizer.Config(
        from_pretrained=('inference_volcano_3', 'd42bf674022c5f84b051d5d7794f6169'),
        compressor_config=fsq_config,
        dim=2048,
        ff_dim=8192,
        input_dim=32,
    )

    resynth_io = GaussianMixtureIOLayer.Config(
        latent_dim=32,
        dim=4096,
        num_components=8,
    )
    resynth_stack = Stack.Config(
        layers=8,
        dim=4096,
        seq_len=8192,
        n_head=16,
        ff_dim=11008,
        kv_heads=16,
        eps=1e-5,
        theta=10_000,
    )
    resynthesizer_config = TransformerVAE.Config(
        io_config=resynth_io,
        stack_config=resynth_stack,
        quantizer_config=quantizer_config,
        plex_layer=None,
        plex_roll=1,
        split=is_split,
        from_pretrained=resynth_ckpt,
    )

    model_stack = Stack.Config(
        layers=32,
        dim=4096,
        seq_len=2048,
        n_head=32,
        ff_dim=None,
        kv_heads=None,
        eps=1e-5,
        theta=10_000,
    )
    return HertzDevModel.Config(
        dim=4096,
        vocab_size=32_768,
        stack_config=model_stack,
        quantizer_config=quantizer_config,
        resynthesizer_config=resynthesizer_config,
        split=is_split,
        from_pretrained=model_ckpt,
    )