-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_fsvae.py
372 lines (296 loc) · 15.7 KB
/
main_fsvae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
import os
import os.path
import random
import numpy as np
import logging
import argparse
import pycuda.driver as cuda
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
import global_v as glv
from network_parser import parse
from datasets import load_dataset_snn
from utils import aboutCudaDevices
from utils import AverageMeter
from utils import CountMulAddSNN
import svae_models.fsvae as fsvae
from svae_models.snn_layers import LIFSpike
import metrics.inception_score as inception_score
import metrics.clean_fid as clean_fid
import metrics.autoencoder_fid as autoencoder_fid
# Module-level best-metric trackers (initial values: accuracy 0, loss 1000).
# Not referenced by any of the functions defined in this script.
max_accuracy, min_loss = 0, 1000
def add_hook(net):
    """Register one shared mul/add counting hook on every counted layer of ``net``.

    Counted layers are ``Conv3d``, ``Linear``, ``ConvTranspose3d`` and ``LIFSpike``
    modules. The same ``CountMulAddSNN`` instance is used as the forward hook for
    all of them.

    Returns:
        ``(counter, handles)``: the shared ``CountMulAddSNN`` hook object and the
        list of hook handles, so the caller can ``remove()`` them afterwards.
    """
    counted_types = (torch.nn.Conv3d, torch.nn.Linear, torch.nn.ConvTranspose3d, LIFSpike)
    counter = CountMulAddSNN()
    handles = [
        module.register_forward_hook(counter)
        for module in net.modules()
        if isinstance(module, counted_types)
    ]
    return counter, handles
def write_weight_hist(net, index):
    """Log a TensorBoard histogram of every named parameter of ``net`` at step ``index``.

    The tag is ``<stem>/<suffix>`` as split by ``os.path.splitext``, so the
    suffix keeps its leading dot (e.g. ``'enc.0.weight'`` -> ``'enc.0/.weight'``).
    Writes through the module-level ``writer`` created in ``__main__``.
    """
    for param_name, param in net.named_parameters():
        stem, suffix = os.path.splitext(param_name)
        writer.add_histogram(f'{stem}/{suffix}', param, index)
def train(network, trainloader, opti, epoch):
    """Train ``network`` for one epoch and log losses and latent statistics.

    Args:
        network: model called as ``network(spike_input, scheduled=...)`` returning
            ``(x_recon, q_z, p_z, sampled_z)``. It must expose
            ``loss_function_mmd`` / ``loss_function_kld`` returning a dict with
            ``'loss'``, ``'Reconstruction_Loss'`` and ``'Distance_Loss'``.
        trainloader: sized iterable of ``(real_img, labels)`` batches.
        opti: optimizer, zeroed and stepped once per batch.
        epoch: epoch index, used as the TensorBoard step and in saved image names.

    Returns:
        The epoch average of ``losses['loss']``.

    Raises:
        ValueError: if ``network_config['loss_func']`` is neither ``'mmd'`` nor
            ``'kld'``, or if ``trainloader`` yields no batches.

    Reads the module-level globals ``init_device``, ``network_config``, ``args``,
    ``dataset_name`` and ``writer`` that are set up in ``__main__``.
    """
    n_steps = glv.network_config['n_steps']
    max_epoch = glv.network_config['epochs']
    loss_meter = AverageMeter()
    recons_meter = AverageMeter()
    dist_meter = AverageMeter()
    # Running means over batches: mean_k = (x_k + k * mean_{k-1}) / (k + 1).
    mean_q_z = 0
    mean_p_z = 0
    mean_sampled_z = 0
    num_batches = 0
    img_dir = f'{args.project_save_path}/checkpoint/{dataset_name}/{args.name}/imgs/train/'

    network = network.train()
    # Labels are not needed for the VAE objective, so they are not moved to the device.
    for batch_idx, (real_img, _labels) in enumerate(trainloader):
        num_batches += 1
        opti.zero_grad()
        real_img = real_img.to(init_device, non_blocking=True)
        # direct spike input: repeat the static image along a new trailing time axis
        spike_input = real_img.unsqueeze(-1).repeat(1, 1, 1, 1, n_steps)  # (N, C, H, W, T)
        x_recon, q_z, p_z, sampled_z = network(
            spike_input, scheduled=network_config['scheduled'])  # sampled_z(B, C, 1, 1, T)
        if network_config['loss_func'] == 'mmd':
            losses = network.loss_function_mmd(real_img, x_recon, q_z, p_z)
        elif network_config['loss_func'] == 'kld':
            losses = network.loss_function_kld(real_img, x_recon, q_z, p_z)
        else:
            raise ValueError('unrecognized loss function')
        losses['loss'].backward()
        opti.step()

        loss_meter.update(losses['loss'].detach().cpu().item())
        recons_meter.update(losses['Reconstruction_Loss'].detach().cpu().item())
        dist_meter.update(losses['Distance_Loss'].detach().cpu().item())
        mean_q_z = (q_z.mean(0).detach().cpu() + batch_idx * mean_q_z) / (batch_idx + 1)  # (C,k,T)
        mean_p_z = (p_z.mean(0).detach().cpu() + batch_idx * mean_p_z) / (batch_idx + 1)  # (C,k,T)
        mean_sampled_z = (sampled_z.mean(0).detach().cpu() + batch_idx * mean_sampled_z) / (batch_idx + 1)  # (C,T)
        print(
            f'Train[{epoch}/{max_epoch}] [{batch_idx}/{len(trainloader)}] Loss: {loss_meter.avg}, RECONS: {recons_meter.avg}, DISTANCE: {dist_meter.avg}')

        # Save input/reconstruction images from the last batch of the epoch only.
        if batch_idx == len(trainloader) - 1:
            os.makedirs(img_dir, exist_ok=True)
            # (x + 1) / 2 rescales images for saving; presumably the data is in [-1, 1].
            torchvision.utils.save_image((real_img + 1) / 2, f'{img_dir}epoch{epoch}_input.png')
            torchvision.utils.save_image((x_recon + 1) / 2, f'{img_dir}epoch{epoch}_recons.png')
            writer.add_images('Train/input_img', (real_img + 1) / 2, epoch)
            writer.add_images('Train/recons_img', (x_recon + 1) / 2, epoch)

    # Without this guard the int-valued running means below would fail with an
    # unhelpful AttributeError on ``.mean()``.
    if num_batches == 0:
        raise ValueError('trainloader yielded no batches')

    logging.info(f"Train [{epoch}] Loss: {loss_meter.avg} ReconsLoss: {recons_meter.avg} DISTANCE: {dist_meter.avg}")
    writer.add_scalar('Train/loss', loss_meter.avg, epoch)
    writer.add_scalar('Train/recons_loss', recons_meter.avg, epoch)
    writer.add_scalar('Train/distance', dist_meter.avg, epoch)
    writer.add_scalar('Train/mean_q', mean_q_z.mean().item(), epoch)
    writer.add_scalar('Train/mean_p', mean_p_z.mean().item(), epoch)
    writer.add_image('Train/mean_sampled_z', mean_sampled_z.unsqueeze(0), epoch)
    writer.add_histogram('Train/mean_sampled_z_distribution', mean_sampled_z.sum(-1), epoch)
    mean_q_z = mean_q_z.permute(1, 0, 2)  # (k,C,T)
    mean_p_z = mean_p_z.permute(1, 0, 2)  # (k,C,T)
    # Pass ``epoch`` as the global step so each epoch's image is recorded separately.
    writer.add_image('Train/mean_q_z', mean_q_z.mean(0).unsqueeze(0), epoch)
    writer.add_image('Train/mean_p_z', mean_p_z.mean(0).unsqueeze(0), epoch)
    return loss_meter.avg
def test(network, testloader, epoch):
    """Evaluate ``network`` on ``testloader`` and log losses, latent stats and op counts.

    Counting hooks (see ``add_hook``) are attached to ``network`` for the
    duration of the evaluation. They are always removed afterwards, even if an
    error occurs.

    Args:
        network: model with the same call/loss interface used by ``train``.
        testloader: sized iterable of ``(real_img, labels)`` batches.
        epoch: epoch index, used as the TensorBoard step and in saved image names.

    Returns:
        The average of ``losses['loss']`` over the test set.

    Raises:
        ValueError: if ``network_config['loss_func']`` is neither ``'mmd'`` nor
            ``'kld'``, or if ``testloader`` yields no batches.

    Reads the module-level globals ``init_device``, ``network_config``, ``args``,
    ``dataset_name`` and ``writer`` that are set up in ``__main__``.
    """
    n_steps = glv.network_config['n_steps']
    max_epoch = glv.network_config['epochs']
    loss_meter = AverageMeter()
    recons_meter = AverageMeter()
    dist_meter = AverageMeter()
    # Running means over batches: mean_k = (x_k + k * mean_{k-1}) / (k + 1).
    mean_q_z = 0
    mean_p_z = 0
    mean_sampled_z = 0
    num_batches = 0
    img_dir = f'{args.project_save_path}/checkpoint/{dataset_name}/{args.name}/imgs/test/'

    # Hook the model under evaluation so the mul/add counts describe ``network``.
    count_mul_add, hook_handles = add_hook(network)
    network = network.eval()
    try:
        with torch.no_grad():
            for batch_idx, (real_img, _labels) in enumerate(testloader):
                num_batches += 1
                real_img = real_img.to(init_device, non_blocking=True)
                # direct spike input: repeat the static image along a new trailing time axis
                spike_input = real_img.unsqueeze(-1).repeat(1, 1, 1, 1, n_steps)  # (N,C,H,W,T)
                x_recon, q_z, p_z, sampled_z = network(spike_input, scheduled=network_config['scheduled'])
                if network_config['loss_func'] == 'mmd':
                    losses = network.loss_function_mmd(real_img, x_recon, q_z, p_z)
                elif network_config['loss_func'] == 'kld':
                    losses = network.loss_function_kld(real_img, x_recon, q_z, p_z)
                else:
                    raise ValueError('unrecognized loss function')
                mean_q_z = (q_z.mean(0).detach().cpu() + batch_idx * mean_q_z) / (batch_idx + 1)  # (C,k,T)
                mean_p_z = (p_z.mean(0).detach().cpu() + batch_idx * mean_p_z) / (batch_idx + 1)  # (C,k,T)
                mean_sampled_z = (sampled_z.mean(0).detach().cpu() + batch_idx * mean_sampled_z) / (batch_idx + 1)  # (C,T)
                loss_meter.update(losses['loss'].detach().cpu().item())
                recons_meter.update(losses['Reconstruction_Loss'].detach().cpu().item())
                dist_meter.update(losses['Distance_Loss'].detach().cpu().item())
                print(
                    f'Test[{epoch}/{max_epoch}] [{batch_idx}/{len(testloader)}] Loss: {loss_meter.avg}, RECONS: {recons_meter.avg}, DISTANCE: {dist_meter.avg}')

                # Save input/reconstruction images from the last batch only.
                if batch_idx == len(testloader) - 1:
                    os.makedirs(img_dir, exist_ok=True)
                    torchvision.utils.save_image((real_img + 1) / 2, f'{img_dir}epoch{epoch}_input.png')
                    torchvision.utils.save_image((x_recon + 1) / 2, f'{img_dir}epoch{epoch}_recons.png')
                    writer.add_images('Test/input_img', (real_img + 1) / 2, epoch)
                    writer.add_images('Test/recons_img', (x_recon + 1) / 2, epoch)
    finally:
        # Detach the counting hooks so they do not keep firing during later training.
        for handle in hook_handles:
            handle.remove()

    if num_batches == 0:
        raise ValueError('testloader yielded no batches')

    logging.info(f"Test [{epoch}] Loss: {loss_meter.avg} ReconsLoss: {recons_meter.avg} DISTANCE: {dist_meter.avg}")
    writer.add_scalar('Test/loss', loss_meter.avg, epoch)
    writer.add_scalar('Test/recons_loss', recons_meter.avg, epoch)
    writer.add_scalar('Test/distance', dist_meter.avg, epoch)
    writer.add_scalar('Test/mean_q', mean_q_z.mean().item(), epoch)
    writer.add_scalar('Test/mean_p', mean_p_z.mean().item(), epoch)
    # Per-batch averages of the accumulated multiply/add counts.
    writer.add_scalar('Test/mul', count_mul_add.mul_sum.item() / len(testloader), epoch)
    writer.add_scalar('Test/add', count_mul_add.add_sum.item() / len(testloader), epoch)
    writer.add_image('Test/mean_sampled_z', mean_sampled_z.unsqueeze(0), epoch)
    writer.add_histogram('Test/mean_sampled_z_distribution', mean_sampled_z.sum(-1), epoch)
    mean_q_z = mean_q_z.permute(1, 0, 2)  # (k,C,T)
    mean_p_z = mean_p_z.permute(1, 0, 2)  # (k,C,T)
    # Pass ``epoch`` as the global step so each epoch's image is recorded separately.
    writer.add_image('Test/mean_q_z', mean_q_z.mean(0).unsqueeze(0), epoch)
    writer.add_image('Test/mean_p_z', mean_p_z.mean(0).unsqueeze(0), epoch)
    return loss_meter.avg
def sample(network, epoch, batch_size=128):
    """Generate ``batch_size`` samples via ``network.sample`` and log/save them.

    Writes the images, the batch-mean of ``sampled_z`` and a histogram of that
    mean (summed over the last axis) to TensorBoard at step ``epoch``. It also
    saves the image grid under ``<project_save_path>/checkpoint/<dataset>/<name>/imgs/sample/``.
    """
    network = network.eval()
    with torch.no_grad():
        sampled_x, sampled_z = network.sample(batch_size)
    # (x + 1) / 2 rescales for display; presumably samples lie in [-1, 1].
    images = (sampled_x + 1) / 2
    z_mean = sampled_z.mean(0)
    writer.add_images('Sample/sample_img', images, epoch)
    writer.add_image('Sample/mean_sampled_z', z_mean.unsqueeze(0), epoch)
    writer.add_histogram('Sample/mean_sampled_z_distribution', z_mean.sum(-1), epoch)
    sample_dir = f'{args.project_save_path}/checkpoint/{dataset_name}/{args.name}/imgs/sample/'
    os.makedirs(sample_dir, exist_ok=True)
    torchvision.utils.save_image(images, f'{sample_dir}epoch{epoch}_sample.png')
def calc_inception_score(network, epoch, batch_size=256):
    """Compute the Inception Score of generated samples and log mean/std at ``epoch``.

    ``batch_times=10`` is used on every 5th epoch and on the final epoch, and
    ``batch_times=4`` otherwise, which keeps intermediate evaluations cheaper.
    """
    network = network.eval()
    is_full_eval = epoch % 5 == 0 or epoch == glv.network_config['epochs'] - 1
    batch_times = 10 if is_full_eval else 4
    with torch.no_grad():
        inception_mean, inception_std = inception_score.get_inception_score(
            network,
            device=init_device,
            batch_size=batch_size,
            batch_times=batch_times,
        )
    writer.add_scalar('Sample/inception_score_mean', inception_mean, epoch)
    writer.add_scalar('Sample/inception_score_std', inception_std, epoch)
def calc_clean_fid(network, epoch):
    """Compute the clean-FID score (``num_gen=5000``) and log it as ``Sample/FID``."""
    network = network.eval()
    dataset = glv.network_config['dataset']
    with torch.no_grad():
        fid_score = clean_fid.get_clean_fid_score(network, dataset, init_device, 5000)
    writer.add_scalar('Sample/FID', fid_score, epoch)
def calc_autoencoder_frechet_distance(network, epoch):
    """Compute the autoencoder-based Frechet distance and log it as ``Sample/AutoencoderDist``.

    The configured dataset name is translated to the key expected by
    ``autoencoder_fid.get_autoencoder_frechet_distance``, which is called with
    ``5000`` as its last argument.

    Raises:
        ValueError: if ``glv.network_config['dataset']`` is not one of
            ``MNIST``, ``FashionMNIST``, ``CelebA`` or ``CIFAR10``.
    """
    network = network.eval()
    dataset_keys = {
        "MNIST": 'mnist',
        "FashionMNIST": 'fashion',
        "CelebA": 'celeba',
        "CIFAR10": 'cifar10',
    }
    configured = glv.network_config['dataset']
    try:
        dataset = dataset_keys[configured]
    except KeyError:
        # Name the offending value instead of raising an empty ValueError.
        raise ValueError(
            f'unsupported dataset for autoencoder Frechet distance: {configured!r}') from None
    with torch.no_grad():
        fid_score = autoencoder_fid.get_autoencoder_frechet_distance(network, dataset, init_device, 5000)
    writer.add_scalar('Sample/AutoencoderDist', fid_score, epoch)
def seed_all(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with ``seed``.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning.
    Note: exporting ``PYTHONHASHSEED`` at runtime only affects child processes
    started afterwards; the current interpreter's hash seed is fixed at startup.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
if __name__ == '__main__':
    seed_all()
    parser = argparse.ArgumentParser()
    # When -name is omitted, the run name is derived from the dataset below.
    parser.add_argument('-name', default=None, type=str,
                        help='Run name (default: vanilla_<dataset>_1)')
    parser.add_argument('-config', action='store', dest='config', help='The path of config file')
    parser.add_argument('-checkpoint', action='store', dest='checkpoint',
                        help='The path of checkpoint, if use checkpoint')
    parser.add_argument('-device', type=int)
    parser.add_argument('-project_save_path', default='/data/zhan/FullySpikingVAE-master/', type=str)
    # argparse handles -h and reports bad arguments itself (usage + non-zero exit
    # status); catching its SystemExit would turn argument errors into exit code 0.
    args = parser.parse_args()
    if args.config is None:
        raise Exception('Unrecognized config file.')
    init_device = torch.device("cuda:0" if args.device is None else f"cuda:{args.device}")

    params = parse(args.config)
    network_config = params['Network']
    print(network_config)
    glv.init(network_config, [args.device])
    dataset_name = glv.network_config['dataset']
    data_path = glv.network_config['data_path']
    if args.name is None:
        args.name = f'vanilla_{dataset_name}_1'
    run_dir = f'{args.project_save_path}/checkpoint/{dataset_name}/{args.name}'
    os.makedirs(run_dir, exist_ok=True)

    # Configure logging BEFORE the first logging.* call. Module-level logging
    # functions auto-install a default stderr handler (level WARNING) when the
    # root logger has none, after which basicConfig() silently does nothing.
    logging.basicConfig(filename=f'{run_dir}.log', level=logging.INFO)
    logging.info("finish parsing settings")
    logging.info(network_config)
    writer = SummaryWriter(log_dir=f'{run_dir}/tb')

    # Check whether a GPU is available
    if not torch.cuda.is_available():
        raise Exception("only support gpu")
    cuda.init()
    c_device = aboutCudaDevices()
    print(c_device.info())
    print("selected device: ", args.device)

    logging.info("dataset loading...")
    # Loader functions are looked up by name so only the selected one is accessed.
    loader_names = {
        "MNIST": 'load_mnist',
        "FashionMNIST": 'load_fashionmnist',
        "CIFAR10": 'load_cifar10',
        "CelebA": 'load_celebA',
    }
    if dataset_name not in loader_names:
        raise Exception('Unrecognized dataset name.')
    data_path = os.path.expanduser(data_path)
    train_loader, test_loader = getattr(load_dataset_snn, loader_names[dataset_name])(data_path)
    logging.info("dataset loaded")

    if network_config['model'] == 'FSVAE':
        net = fsvae.FSVAE()
    elif network_config['model'] == 'FSVAE_large':
        net = fsvae.FSVAELarge()
    else:
        raise Exception('not defined model')
    net = net.to(init_device)
    if args.checkpoint is not None:
        # map_location lets a checkpoint saved on another device load onto init_device.
        checkpoint = torch.load(args.checkpoint, map_location=init_device)
        net.load_state_dict(checkpoint)

    optimizer = torch.optim.AdamW(net.parameters(),
                                  lr=glv.network_config['lr'],
                                  betas=(0.9, 0.999),
                                  weight_decay=0.001)
    best_loss = 1e8
    try:
        for e in range(glv.network_config['epochs']):
            write_weight_hist(net, e)
            if network_config['scheduled']:
                net.update_p(e, glv.network_config['epochs'])
                logging.info("update p")
            train_loss = train(net, train_loader, optimizer, e)
            test_loss = test(net, test_loader, e)
            # Always keep the latest weights; additionally keep the best-by-test-loss weights.
            torch.save(net.state_dict(), f'{run_dir}/checkpoint.pth')
            if test_loss < best_loss:
                best_loss = test_loss
                torch.save(net.state_dict(), f'{run_dir}/best.pth')
            sample(net, e, batch_size=128)
            # Optional (expensive) generative metrics:
            # calc_inception_score(net, e, batch_size=glv.network_config['sample_batch_size'])
            # calc_autoencoder_frechet_distance(net, e)
            # calc_clean_fid(net, e)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()