# Source: eureka_train.py (recovered from a GitHub file view: 124 lines, 100 loc, 4.73 KB).
"""
eureka_train.py
Zhiang Chen, Dec 25 2019
train mask rcnn with eureka data
"""
import transforms as T
from engine import train_one_epoch, evaluate
import utils
import torch
from data import Dataset
from model import get_model_instance_segmentation
def get_transform(train):
    """Return the transform pipeline applied to each (image, target) sample.

    Args:
        train: if truthy, a ``T.RandomHorizontalFlip(0.5)`` augmentation is
            appended after the tensor conversion.

    Returns:
        A ``T.Compose`` wrapping ``T.ToTensor()`` and, for training, the flip.
    """
    pipeline = [T.ToTensor()]
    if train:
        # Augmentation is only wanted for the training split.
        pipeline.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(pipeline)
if __name__ == '__main__':
    # Use the first GPU when CUDA is available, otherwise fall back to the CPU.
    # (Do not hard-code 'cuda:0' afterwards: that defeats the fallback and
    # crashes on CUDA-less machines.)
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # Six classes: background, nd, d0, d1, d2, d3.
    # The legacy binary setup used 3: background, non-damaged, damaged.
    num_classes = 6  # 3 or 6

    # Training samples get random horizontal flips; test samples are only
    # converted to tensors (see get_transform).
    dataset = Dataset("./datasets/Eureka/images/", "./datasets/Eureka/labels/",
                      get_transform(train=True), readsave=False, include_name=False)
    # NOTE(review): the test set reads annotations from "labels/", the same
    # directory used for training, while the legacy config (below) used
    # "labels_test/". Confirm this is intentional.
    dataset_test = Dataset("./datasets/Eureka/images_test/", "./datasets/Eureka/labels/",
                           get_transform(train=False), savePickle=False, readsave=False,
                           include_name=False)

    # Wrap both datasets in randomly permuted Subsets. This does NOT split
    # anything into train/test; the training loader reshuffles every epoch
    # anyway, so the permutation mainly randomizes the evaluation order.
    indices = torch.randperm(len(dataset)).tolist()
    dataset = torch.utils.data.Subset(dataset, indices)
    indices_test = torch.randperm(len(dataset_test)).tolist()
    dataset_test = torch.utils.data.Subset(dataset_test, indices_test)

    # Training and validation data loaders; utils.collate_fn batches the
    # variable-size detection samples.
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=True, num_workers=4,
        collate_fn=utils.collate_fn)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False, num_workers=4,
        collate_fn=utils.collate_fn)

    # Build the model; the meaning of the three extra positional arguments is
    # defined by get_model_instance_segmentation in model.py.
    mask_rcnn = get_model_instance_segmentation(num_classes, None, None, False)
    read_param = False  # set True to warm-start from a saved checkpoint
    if read_param:
        # map_location lets GPU-saved checkpoints load on a CPU-only machine.
        mask_rcnn.load_state_dict(
            torch.load("trained_param_eureka_aug_bin/epoch_0009.param", map_location=device))
    mask_rcnn.to(device)

    # SGD over trainable parameters only, with the LR divided by 10 every
    # 10 epochs.
    params = [p for p in mask_rcnn.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.01,
                                momentum=0.9, weight_decay=0.00001)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=10,
                                                   gamma=0.1)
    init_epoch = 0
    num_epochs = 32

    # The training loop used to be disabled by wrapping it in a string
    # literal; it is now gated by this flag so it stays syntax-checked.
    # Default False keeps the script in evaluation-only mode.
    run_training = False
    if run_training:
        for epoch in range(init_epoch, init_epoch + num_epochs):
            save_param = "trained_param_eureka_mult/epoch_{:04d}.param".format(epoch)
            # Train for one epoch; print_freq controls how often progress is printed.
            train_one_epoch(mask_rcnn, optimizer, data_loader, device, epoch, print_freq=500)
            lr_scheduler.step()
            evaluate(mask_rcnn, data_loader_test, device=device)
            # Checkpoint after each epoch.
            torch.save(mask_rcnn.state_dict(), save_param)
    else:
        # Evaluation-only: load the epoch-25 checkpoint and evaluate it.
        save_param = "trained_param_eureka_mult/epoch_{:04d}.param".format(25)
        mask_rcnn.load_state_dict(torch.load(save_param, map_location=device))
        evaluate(mask_rcnn, data_loader_test, device=device)

    # Legacy binary-damage configuration (previously kept as a dead string
    # literal), summarized for reference:
    #   num_classes = 3
    #   train: Dataset(images/, labels/, readsave=False), batch_size=2
    #   test:  Dataset(images_test/, labels_test/, readsave=False), batch_size=1
    #   model: get_model_instance_segmentation(num_classes), warm-started from
    #          "trained_param/epoch_0099.param"