forked from layumi/Person_reID_baseline_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepare_static.py
115 lines (90 loc) · 3.67 KB
/
prepare_static.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# -*- coding: utf-8 -*-
from __future__ import print_function, division
import argparse
import torch
from torchvision import datasets, transforms
import time
import os
version = torch.__version__

######################################################################
# Options
# --------
# Command-line flags for the script. In the code visible here only
# --data_dir, --train_all and --batchsize are read back from ``opt``;
# --color_jitter is accepted but not consulted.
parser = argparse.ArgumentParser(description='Training')
parser.add_argument(
    '--data_dir',
    type=str,
    default='/home/zzd/Market/pytorch',
    help='training dir path',
)
parser.add_argument(
    '--train_all',
    action='store_true',
    help='use all training data',
)
parser.add_argument(
    '--color_jitter',
    action='store_true',
    help='use color jitter in training',
)
parser.add_argument(
    '--batchsize',
    type=int,
    default=128,
    help='batchsize',
)

# Parsed once at module level; data_dir is the root that holds the
# per-split image folders.
opt = parser.parse_args()
data_dir = opt.data_dir
######################################################################
# Load Data
# ---------
#
# Training pipeline: a fixed resize followed by tensor conversion.
# The random-resized-crop, random-crop, horizontal-flip and Normalize steps
# are left disabled in this script -- presumably so the statistics are
# measured on deterministic, un-normalised images; confirm before
# re-enabling any of them.
transform_train_list = [
    transforms.Resize((288, 144), interpolation=3),  # 3 == Image.BICUBIC
    transforms.ToTensor(),
]

# Validation pipeline: resize, tensor conversion, then per-channel
# normalisation with the constants given below.
transform_val_list = [
    transforms.Resize(size=(256, 128), interpolation=3),  # 3 == Image.BICUBIC
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]

# Echo the active training pipeline so the run log records it.
print(transform_train_list)

data_transforms = dict(
    train=transforms.Compose(transform_train_list),
    val=transforms.Compose(transform_val_list),
)
# Folder suffix: 'train_all' is read instead of 'train' when --train_all
# is given.
train_all = '_all' if opt.train_all else ''

# One ImageFolder per split, each paired with its transform pipeline.
image_datasets = {
    'train': datasets.ImageFolder(
        os.path.join(data_dir, 'train' + train_all),
        data_transforms['train'],
    ),
    'val': datasets.ImageFolder(
        os.path.join(data_dir, 'val'),
        data_transforms['val'],
    ),
}

# Shuffled, multi-worker loaders keyed by split name.
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batchsize,
        shuffle=True,
        num_workers=16,
    )
    for split, dataset in image_datasets.items()
}

# Number of images in each split.
dataset_sizes = {
    split: len(dataset) for split, dataset in image_datasets.items()
}

class_names = image_datasets['train'].classes
use_gpu = torch.cuda.is_available()
######################################################################
# prepare_dataset
# ------------------
#
# The function below makes a single pass over the training images and
# prints two per-channel statistics:
#
# - the mean pixel value
# - the standard deviation, computed per image and then averaged
#
# No model is trained and no learning-rate scheduler is involved; the
# printed values are presumably meant to be used as Normalize() constants.
def prepare_model(loader=None):
    """Compute per-channel image statistics over one pass of a data loader.

    For every image the per-channel mean and the per-channel standard
    deviation (``torch.std`` over the flattened pixels) are computed; both
    are then averaged over all images seen.  Note that the second value is
    therefore the *average per-image* standard deviation, not the standard
    deviation of all pixels pooled together.

    Args:
        loader: Iterable yielding ``(inputs, labels)`` batches, where
            ``inputs`` is a 4-D tensor unpacked as
            (batch, channel, height, width) with three channels.  Labels
            are ignored.  Defaults to the module-level
            ``dataloaders['train']``.

    Returns:
        A ``(mean, std)`` tuple of length-3 tensors.  Both are also printed.

    Raises:
        ValueError: If the loader yields no images.
    """
    if loader is None:
        loader = dataloaders['train']

    since = time.time()

    # Banner and closing message are kept byte-for-byte from the earlier
    # single-epoch loop so the script's log output does not change.
    print('Epoch {}/{}'.format(0, 0))
    print('-' * 10)

    mean_sum = torch.zeros(3)
    std_sum = torch.zeros(3)
    num_images = 0

    for inputs, _ in loader:
        now_batch_size, c, h, w = inputs.shape
        # Flatten the spatial dimensions so each (image, channel) pair is
        # reduced over all of its pixels at once.
        pixels = inputs.reshape(now_batch_size, c, h * w)
        mean_sum += pixels.mean(dim=2).sum(dim=0)
        std_sum += pixels.std(dim=2).sum(dim=0)
        # Count what was actually iterated instead of trusting a separately
        # stored dataset size.
        num_images += now_batch_size

    if num_images == 0:
        raise ValueError('loader yielded no images')

    mean = mean_sum / num_images
    std = std_sum / num_images
    print(mean)
    print(std)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    return mean, std


# Guarded so that DataLoader worker processes started with the "spawn"
# method, which re-import this file, do not re-enter the statistics loop.
if __name__ == '__main__':
    prepare_model()