-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
82 lines (66 loc) · 2.72 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import time
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
from torchvision.utils import save_image
from model import StackedAutoEncoder
# Make sure the output directory for the saved image grids exists.
# exist_ok=True avoids the check-then-create race of os.path.exists() + os.mkdir()
# and makes re-running the script a no-op here.
os.makedirs('./imgs', exist_ok=True)
def to_img(x, channels=3, height=32, width=32):
    """Reshape a batch of image data into ``(batch, channels, height, width)``.

    The first dimension of ``x`` is kept as the batch dimension; all remaining
    elements of each sample are reinterpreted as one image.

    Args:
        x: Tensor whose first dimension is the batch and whose remaining
            elements total ``channels * height * width`` per sample.
        channels: Number of image channels (default 3).
        height: Image height (default 32).
        width: Image width (default 32).

    Returns:
        A tensor of shape ``(x.size(0), channels, height, width)``. It shares
        storage with ``x`` when the layout allows it, otherwise it is a copy.

    Raises:
        RuntimeError: If the number of elements in ``x`` does not match the
            requested shape.
    """
    # reshape() instead of view(): view() raises when the memory layout of x is
    # not view-compatible (e.g. after transpose/permute); reshape() copies then.
    return x.reshape(x.size(0), channels, height, width)
# Training hyper-parameters.
num_epochs = 1000
batch_size = 128

# Per-image preprocessing: brightness/contrast jitter (saturation and hue jitter
# are set to 0), then conversion to a tensor.
# NOTE: a RandomRotation(360) augmentation is currently disabled.
_augmentations = [
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0, hue=0),
    transforms.ToTensor(),
]
img_transform = transforms.Compose(_augmentations)

# CIFAR10 images are read from '../data/cifar10/' and served in shuffled
# mini-batches by 8 loader workers.
dataset = CIFAR10('../data/cifar10/', transform=img_transform)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

# The autoencoder is moved to the GPU; the script requires CUDA.
model = StackedAutoEncoder().cuda()
# Main loop: each epoch runs the model over the whole dataloader in train mode,
# fits a linear probe on the detached features, then reports reconstruction
# quality and feature statistics on the last batch of the epoch.
for epoch in range(num_epochs):
    if epoch % 10 == 0:
        # Test the quality of our features with a randomly initialized linear classifier.
        # The probe and its optimizer are recreated every 10 epochs and reused in between.
        # Assumes the flattened features have 512 * 16 values per sample.
        classifier = nn.Linear(512 * 16, 10).cuda()
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    model.train()
    epoch_start = time.time()
    correct = 0
    num_samples = 0
    for data in dataloader:
        img, target = data
        target = target.cuda()
        img = img.cuda()

        # The features are detached, and `optimizer` only holds the classifier's
        # parameters, so this loss never updates the autoencoder.
        # NOTE(review): the autoencoder presumably trains itself inside its
        # forward pass while in train mode -- confirm in model.py.
        features = model(img).detach()
        prediction = classifier(features.view(features.size(0), -1))
        loss = criterion(prediction, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Index of the highest logit is the predicted class.
        pred = prediction.detach().max(1, keepdim=True)[1]
        # .item() keeps `correct` a plain int so it prints as a number, not "tensor(N)".
        correct += pred.eq(target.view_as(pred)).sum().item()
        # Count the samples actually seen: the last batch may be smaller than
        # batch_size, so len(dataloader) * batch_size would overcount.
        num_samples += target.size(0)

    total_time = time.time() - epoch_start

    # Evaluate on the last batch of the epoch. In eval mode the model is called
    # for a (features, reconstruction) pair; no gradients are needed here.
    model.eval()
    img, _ = data
    img = img.cuda()
    with torch.no_grad():
        features, x_reconstructed = model(img)
        reconstruction_loss = torch.mean((x_reconstructed - img) ** 2).item()

    if epoch % 10 == 0:
        print("Saving epoch {}".format(epoch))
        orig = to_img(img.cpu())
        save_image(orig, './imgs/orig_{}.png'.format(epoch))
        pic = to_img(x_reconstructed.cpu())
        save_image(pic, './imgs/reconstruction_{}.png'.format(epoch))

    # Percentage of feature values that are exactly zero.
    sparsity = 100.0 * (features == 0.0).sum().item() / features.numel()

    print("Epoch {} complete\tTime: {:.4f}s\t\tLoss: {:.4f}".format(epoch, total_time, reconstruction_loss))
    print("Feature Statistics\tMean: {:.4f}\t\tMax: {:.4f}\t\tSparsity: {:.4f}%".format(
        features.mean().item(), features.max().item(), sparsity)
    )
    print("Linear classifier performance: {}/{} = {:.2f}%".format(correct, num_samples, 100.0 * correct / num_samples))
    print("="*80)

# Persist the trained autoencoder weights once all epochs have finished.
torch.save(model.state_dict(), './CDAE.pth')