forked from pytorch/ignite
-
Notifications
You must be signed in to change notification settings - Fork 0
/
baseline_resnet50.py
105 lines (79 loc) · 2.71 KB
/
baseline_resnet50.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Basic training configuration
import os
from functools import partial
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
import torch.distributed as dist
from torchvision.models.resnet import resnet50
import albumentations as A
from albumentations.pytorch import ToTensorV2 as ToTensor
from dataflow.dataloaders import get_train_val_loaders
from dataflow.transforms import denormalize
# ##############################
# Global configs
# ##############################
# Run-level settings
seed = 19
device = "cuda"
debug = False

# Per-process data loading: both values apply to each local rank, not the whole job
batch_size = 64  # batch size per local rank
num_workers = 10  # num_workers per local rank

# Input resolutions: training uses smaller random crops than the center crops used for validation
train_crop_size = 224
val_crop_size = 320

# When enabled, the time spent preparing batches is measured and reported before
# training starts; presumably measured over `benchmark_dataflow_num_iters`
# iterations -- confirm in the training script.
benchmark_dataflow = True
benchmark_dataflow_num_iters = 100

# Mixed-precision optimization level (Apex-AMP-style "O*" naming, presumably -- verify consumer)
fp16_opt_level = "O2"
# NOTE(review): presumably validation runs every `val_interval` epochs -- confirm in training loop
val_interval = 2
# ##############################
# Setup Dataflow
# ##############################
# Validate explicitly instead of using `assert`: asserts are stripped under
# `python -O`, which would leave only an uninformative KeyError below.
if "DATASET_PATH" not in os.environ:
    raise RuntimeError(
        "Environment variable DATASET_PATH is not set; it must point to the dataset "
        "location expected by get_train_val_loaders"
    )
data_path = os.environ["DATASET_PATH"]

# Per-channel normalization statistics (the standard ImageNet RGB mean/std values).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Training augmentation: random-scale crop, flip, cutout-style dropout and color jitter,
# followed by normalization and conversion to a tensor.
train_transforms = A.Compose(
    [
        A.RandomResizedCrop(train_crop_size, train_crop_size, scale=(0.08, 1.0)),
        A.HorizontalFlip(),
        A.CoarseDropout(max_height=32, max_width=32),
        A.HueSaturationValue(),
        A.Normalize(mean=mean, std=std),
        ToTensor(),
    ]
)

# Resize to 256/224 of the crop size before center-cropping, following
# https://github.com/facebookresearch/FixRes/blob/b27575208a7c48a3a6e0fa9efb57baa4021d1305/imnet_resnet50_scratch/transforms.py#L76
val_resize_size = int((256 / 224) * val_crop_size)
val_transforms = A.Compose(
    [
        # NOTE(review): A.Resize(h, w) resizes to an exact square, so non-square images are
        # stretched; the linked FixRes code may resize the shorter side instead (cf.
        # A.SmallestMaxSize) -- confirm which behavior is intended.
        A.Resize(val_resize_size, val_resize_size),
        A.CenterCrop(val_crop_size, val_crop_size),
        A.Normalize(mean=mean, std=std),
        ToTensor(),
    ]
)

# Three loaders: training, validation, and a training-set loader for evaluation.
# Both train and val use distributed samplers, so batch_size/num_workers are per rank.
train_loader, val_loader, train_eval_loader = get_train_val_loaders(
    data_path,
    train_transforms=train_transforms,
    val_transforms=val_transforms,
    batch_size=batch_size,
    num_workers=num_workers,
    val_batch_size=batch_size,
    pin_memory=True,
    train_sampler="distributed",
    val_sampler="distributed",
)

# Image denormalization function to plot predictions with images
img_denormalize = partial(denormalize, mean=mean, std=std)
# ##############################
# Setup Model
# ##############################
# ResNet-50 trained from scratch. Calling the builder with no arguments yields a randomly
# initialized network on every torchvision version (the default is `pretrained=False` before
# 0.13 and `weights=None` from 0.13 on), without the deprecation warning that passing
# `pretrained=...` triggers on torchvision >= 0.13.
model = resnet50()
# ##############################
# Setup Solver
# ##############################
# Training objective over class logits.
criterion = nn.CrossEntropyLoss()

# Epoch budget. The LR milestones below are converted from epochs to iterations, so the
# scheduler is presumably stepped once per iteration -- confirm in the training loop.
num_epochs = 105
le = len(train_loader)  # batches per epoch, as reported by the training loader

# Linear LR scaling: 0.1 per 256 samples of global batch (per-rank batch size x world size).
base_lr = 0.1 * (batch_size * dist.get_world_size() / 256.0)
optimizer = optim.SGD(
    model.parameters(),
    lr=base_lr,
    momentum=0.9,
    weight_decay=1e-4,
)

# Step decay: multiply the LR by 0.1 at epochs 30, 60, 90 and 100 (expressed in iterations).
lr_scheduler = lrs.MultiStepLR(
    optimizer,
    milestones=[epoch * le for epoch in (30, 60, 90, 100)],
    gamma=0.1,
)