forked from rosinality/vq-vae-2-pytorch
train_vqvae.py
import argparse

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from tqdm import tqdm

from vqvae import VQVAE
from scheduler import CycleScheduler
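# Usage (a minimal example; the dataset path is a placeholder and must point at an
# ImageFolder-style directory, i.e. images grouped into class subfolders):
#
#   python train_vqvae.py --size 256 --sched cycle /path/to/dataset
#
# The script writes reconstruction grids to sample/ and weights to checkpoint/;
# both directories are assumed to exist already, since the script does not create them.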
def train(epoch, loader, model, optimizer, scheduler, device):
    loader = tqdm(loader)

    criterion = nn.MSELoss()

    # Total loss = reconstruction MSE + weighted latent loss from the quantizer.
    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (img, label) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        # Running average of the reconstruction error over the epoch.
        mse_sum += recon_loss.item() * img.shape[0]
        mse_n += img.shape[0]

        lr = optimizer.param_groups[0]['lr']

        loader.set_description(
            (
                f'epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; '
                f'latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; '
                f'lr: {lr:.5f}'
            )
        )

        if i % 100 == 0:
            # Periodically save a grid of inputs next to their reconstructions.
            model.eval()

            sample = img[:sample_size]

            with torch.no_grad():
                out, _ = model(sample)

            utils.save_image(
                torch.cat([sample, out], 0),
                f'sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png',
                nrow=sample_size,
                normalize=True,
                range=(-1, 1),  # note: newer torchvision releases use `value_range=` instead of `range=`
            )

            model.train()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--epoch', type=int, default=560)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--sched', type=str)
    parser.add_argument('path', type=str)

    args = parser.parse_args()
    print(args)

    device = 'cuda'

    # Resize, center-crop, and normalize images to [-1, 1].
    transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dataset = datasets.ImageFolder(args.path, transform=transform)
    loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

    model = VQVAE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Optional cyclic learning-rate schedule.
    scheduler = None
    if args.sched == 'cycle':
        scheduler = CycleScheduler(
            optimizer, args.lr, n_iter=len(loader) * args.epoch, momentum=None
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)
        torch.save(model.state_dict(), f'checkpoint/vqvae_{str(i + 1).zfill(3)}.pt')
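# Loading a saved checkpoint afterwards (a minimal sketch, not part of the training
# script itself; 'checkpoint/vqvae_001.pt' is just an example filename produced by
# the loop above):
#
#   model = VQVAE().to(device)
#   model.load_state_dict(torch.load('checkpoint/vqvae_001.pt', map_location=device))
#   model.eval()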