-
Notifications
You must be signed in to change notification settings - Fork 0
/
my_mnist.py
74 lines (64 loc) · 2.74 KB
/
my_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
import torch.nn.functional as F
from preprocessing import *
import torch.optim as optim
class Net(nn.Module):
    """LeNet-style convolutional classifier for 1-channel 28x28 images (e.g. MNIST).

    Two conv -> ReLU -> 2x2 max-pool stages, then two fully connected layers
    producing scores for 10 classes.

    The attribute names ``conv1``, ``conv2``, ``fc1`` and ``fc2`` become the
    keys of ``state_dict()``; renaming them breaks loading saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        # 1 -> 20 channels, 5x5 kernel, stride 1: 28x28 -> 24x24 (12x12 after pooling).
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        # 20 -> 50 channels, 5x5 kernel, stride 1: 12x12 -> 8x8 (4x4 after pooling).
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # 4*4*50 flattened features; this size only matches 28x28 inputs.
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        """Return per-class log-probabilities.

        Args:
            x: float tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10) holding log-softmax scores, suitable
            as input to ``F.nll_loss``.

        Raises:
            RuntimeError: if the spatial size does not flatten to 4*4*50
                features per sample.
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, 2)
        # Flatten per sample so the batch dimension is preserved; an input of
        # the wrong spatial size then fails loudly in fc1 instead of being
        # silently regrouped into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
import os
def train_epoch(net, optimizer, data, target, batch_size):
    """Run one optimization pass over ``data`` in mini-batches.

    Every sample is used: the last batch may be smaller than ``batch_size``.

    Args:
        net: model returning log-probabilities of shape (batch, classes).
        optimizer: optimizer over ``net``'s parameters.
        data: input tensor whose first dimension indexes samples.
        target: class-index labels aligned with ``data`` along dimension 0.
        batch_size: maximum number of samples per optimizer step.

    Returns:
        Number of samples classified correctly during the pass. Predictions
        come from each batch's forward pass, i.e. before that batch's update.
    """
    net.train()
    correct = 0
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        labels = target[start:start + batch_size]
        optimizer.zero_grad()  # clear gradients from the previous step
        output = net(batch)
        loss = F.nll_loss(output, labels)
        loss.backward()
        optimizer.step()
        # Training-set accuracy, taken from this batch's forward pass.
        predicted = output.detach().argmax(dim=1)
        correct += (predicted == labels).sum().item()
    return correct


def evaluate_accuracy(net, images, labels, batch_size):
    """Return the fraction of ``images`` that ``net`` assigns to ``labels``.

    Runs in eval mode with autograd disabled and includes the final partial
    batch, so the numerator and denominator cover the same samples.

    Raises:
        ValueError: if ``images`` is empty.
    """
    total = images.size(0)
    if total == 0:
        raise ValueError("cannot compute accuracy on an empty set")
    net.eval()
    correct = 0
    with torch.no_grad():  # inference only: no autograd graph needed
        for start in range(0, total, batch_size):
            predicted = net(images[start:start + batch_size]).argmax(dim=1)
            correct += (predicted == labels[start:start + batch_size]).sum().item()
    return correct / total


if __name__ == "__main__":
    cuda = False  # set to True to use a GPU when one is available
    device = torch.device("cuda" if torch.cuda.is_available() and cuda else "cpu")
    net = Net().to(device)
    if os.path.exists("net"):
        # Reuse previously trained weights; map_location allows a checkpoint
        # saved from a GPU run to be loaded on a CPU-only machine.
        checkpoint = torch.load("net", map_location=device)
        net.load_state_dict(checkpoint)
    else:
        data = training_set_image()
        target = training_set_label()
        data, target = data.to(device), target.to(device)
        optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5)
        # NOTE(review): such large batches give only a few optimizer steps per
        # epoch (about 3 for a 60,000-sample set); consider a smaller value.
        batch_size = 29999
        for epoch in range(1, 5):
            match = train_epoch(net, optimizer, data, target, batch_size)
            print("Train Epoch: {} accuracy:{} ".format(epoch, match / len(data)))
        torch.save(net.state_dict(), "net")
    test = test_set_image()
    test_label = test_set_label()
    test, test_label = test.to(device), test_label.to(device)
    print("accuracy: ", evaluate_accuracy(net, test, test_label, batch_size=400))