# trainer.py
import os

import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image
class log_gaussian:
    """Callable loss: negative Gaussian log-likelihood of ``x`` under ``(mu, var)``.

    The per-element log-likelihood is summed over dimension 1, averaged over
    the remaining entries, and negated so that it can be minimised.
    """

    def __call__(self, x, mu, var):
        """Return the negated mean log-likelihood as a 0-dim tensor.

        Args:
            x: observed values; indexed as ``(batch, features)`` by ``sum(1)``.
            mu: predicted means, broadcastable against ``x``.
            var: predicted variances, broadcastable against ``x``.
        """
        # The 1e-6 terms keep log() and the division finite when var is ~0.
        log_normaliser = (var * (2 * np.pi) + 1e-6).log()
        scaled_sq_error = (x - mu) ** 2 / (var * 2.0 + 1e-6)
        log_likelihood = -0.5 * log_normaliser - scaled_sq_error
        return -log_likelihood.sum(1).mean()
class Trainer:
    """Training loop for a GAN with an auxiliary latent-code predictor on MNIST.

    How :meth:`train` uses the four networks:

    * ``G``  -- maps a latent tensor of shape ``(bs, 74, 1, 1)`` to images.
    * ``FE`` -- front end applied to both real and generated images.
    * ``D``  -- maps ``FE`` output to a probability scored with ``BCELoss``.
    * ``Q``  -- maps ``FE`` output to ``(logits, mu, var)``: class logits for
      the 10-way discrete code and Gaussian parameters for the 2 continuous
      codes.

    NOTE(review): this layout matches InfoGAN -- confirm against the model
    definitions, which are not part of this file.
    """

    def __init__(self, G, FE, D, Q):
        """Store the four networks and the fixed mini-batch size (100)."""
        self.G = G
        self.FE = FE
        self.D = D
        self.Q = Q
        self.batch_size = 100

    def _models_on_cuda(self):
        """Return True when the generator's parameters live on a CUDA device.

        Falls back to True -- the previously hard-coded behaviour -- when
        ``G`` exposes no parameters to inspect.
        """
        first_param = next(iter(self.G.parameters()), None)
        return True if first_param is None else first_param.is_cuda

    def _noise_sample(self, dis_c, con_c, noise, bs):
        """Refill the latent buffers in place and build the generator input.

        Args:
            dis_c: ``(bs, 10)`` buffer, overwritten with random one-hot rows.
            con_c: ``(bs, 2)`` buffer, overwritten with U(-1, 1) samples.
            noise: ``(bs, 62)`` buffer, overwritten with U(-1, 1) samples.
            bs: number of rows to sample.

        Returns:
            ``(z, idx)`` where ``z`` is ``[noise, dis_c, con_c]`` concatenated
            along dim 1 and viewed as ``(bs, 74, 1, 1)``, and ``idx`` is the
            numpy array of sampled class indices (one per row).
        """
        idx = np.random.randint(10, size=bs)
        c = np.zeros((bs, 10))
        c[range(bs), idx] = 1.0
        dis_c.data.copy_(torch.Tensor(c))
        con_c.data.uniform_(-1.0, 1.0)
        noise.data.uniform_(-1.0, 1.0)
        # 62 noise + 10 discrete + 2 continuous = 74 latent channels.
        z = torch.cat([noise, dis_c, con_c], 1).view(-1, 74, 1, 1)
        return z, idx

    def train(self):
        """Run 100 epochs of adversarial training on MNIST.

        Side effects: downloads MNIST into ``./dataset`` if needed, prints the
        discriminator/generator losses every 100 iterations, and at the same
        cadence writes two 10x10 sample grids to ``./tmp/c1.png`` and
        ``./tmp/c2.png`` (each sweeps one continuous code from -1 to 1 while
        the other is held at 0). ``./tmp`` is created when missing.

        Tensors are placed on the same device as ``G``'s parameters, so the
        loop runs on CUDA when the models were moved there and on CPU
        otherwise.
        """
        use_cuda = self._models_on_cuda()

        def place(obj):
            # Follow the models' device instead of unconditionally using CUDA.
            return obj.cuda() if use_cuda else obj

        # Reusable input buffers; resized to the actual batch every iteration.
        real_x = Variable(place(torch.FloatTensor(self.batch_size, 1, 28, 28)))
        label = Variable(place(torch.FloatTensor(self.batch_size, 1)),
                         requires_grad=False)
        dis_c = Variable(place(torch.FloatTensor(self.batch_size, 10)))
        con_c = Variable(place(torch.FloatTensor(self.batch_size, 2)))
        noise = Variable(place(torch.FloatTensor(self.batch_size, 62)))

        criterionD = place(nn.BCELoss())
        criterionQ_dis = place(nn.CrossEntropyLoss())
        criterionQ_con = log_gaussian()

        # FE is updated together with D; Q is updated together with G.
        optimD = optim.Adam([{'params':self.FE.parameters()}, {'params':self.D.parameters()}], lr=0.0002, betas=(0.5, 0.99))
        optimG = optim.Adam([{'params':self.G.parameters()}, {'params':self.Q.parameters()}], lr=0.001, betas=(0.5, 0.99))

        dataset = dset.MNIST('./dataset', transform=transforms.ToTensor(), download=True)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=1)

        # fixed random variables
        c = np.linspace(-1, 1, 10).reshape(1, -1)
        c = np.repeat(c, 10, 0).reshape(-1, 1)
        c1 = np.hstack([c, np.zeros_like(c)])
        c2 = np.hstack([np.zeros_like(c), c])

        idx = np.arange(10).repeat(10)
        one_hot = np.zeros((100, 10))
        one_hot[range(100), idx] = 1

        fix_noise = torch.Tensor(100, 62).uniform_(-1, 1)

        # Build the two preview inputs once, in their own tensors, rather than
        # copying them into the training buffers on every preview: those
        # buffers are resized to the current batch and need not have 100 rows.
        fix_noise_dev = place(fix_noise)
        fix_dis_c = place(torch.Tensor(one_hot))
        preview_z1 = Variable(torch.cat(
            [fix_noise_dev, fix_dis_c, place(torch.from_numpy(c1).float())], 1
        ).view(-1, 74, 1, 1))
        preview_z2 = Variable(torch.cat(
            [fix_noise_dev, fix_dis_c, place(torch.from_numpy(c2).float())], 1
        ).view(-1, 74, 1, 1))

        # save_image does not create missing directories.
        os.makedirs('./tmp', exist_ok=True)

        for epoch in range(100):
            for num_iters, batch_data in enumerate(dataloader, 0):

                # real part
                optimD.zero_grad()

                x, _ = batch_data

                bs = x.size(0)
                real_x.data.resize_(x.size())
                label.data.resize_(bs, 1)
                dis_c.data.resize_(bs, 10)
                con_c.data.resize_(bs, 2)
                noise.data.resize_(bs, 62)

                real_x.data.copy_(x)
                fe_out1 = self.FE(real_x)
                probs_real = self.D(fe_out1)
                label.data.fill_(1)
                loss_real = criterionD(probs_real, label)
                loss_real.backward()

                # fake part
                z, idx = self._noise_sample(dis_c, con_c, noise, bs)
                fake_x = self.G(z)
                # detach(): the discriminator step must not backprop into G.
                fe_out2 = self.FE(fake_x.detach())
                probs_fake = self.D(fe_out2)
                label.data.fill_(0)
                loss_fake = criterionD(probs_fake, label)
                loss_fake.backward()

                # Sum is only used for logging; gradients were already
                # accumulated by the two backward() calls above.
                D_loss = loss_real + loss_fake

                optimD.step()

                # G and Q part
                optimG.zero_grad()

                # Gradients also reach FE and D here, but only G and Q are
                # stepped; optimD.zero_grad() clears the rest next iteration.
                fe_out = self.FE(fake_x)
                probs_fake = self.D(fe_out)
                label.data.fill_(1.0)

                reconstruct_loss = criterionD(probs_fake, label)

                q_logits, q_mu, q_var = self.Q(fe_out)
                class_ = place(torch.LongTensor(idx))
                target = Variable(class_)
                dis_loss = criterionQ_dis(q_logits, target)
                con_loss = criterionQ_con(con_c, q_mu, q_var)*0.1

                G_loss = reconstruct_loss + dis_loss + con_loss
                G_loss.backward()
                optimG.step()

                if num_iters % 100 == 0:

                    print('Epoch/Iter:{0}/{1}, Dloss: {2}, Gloss: {3}'.format(
                        epoch, num_iters, D_loss.data.cpu().numpy(),
                        G_loss.data.cpu().numpy())
                    )

                    x_save = self.G(preview_z1)
                    save_image(x_save.data, './tmp/c1.png', nrow=10)

                    x_save = self.G(preview_z2)
                    save_image(x_save.data, './tmp/c2.png', nrow=10)