# main_celeba.py
# # built-in modules
import os
import argparse
from pprint import pformat
from collections import OrderedDict
import random
# # Torch modules
import torch
from torchvision import transforms, datasets
# # internal imports
from prelude import save_dicts, startup_folders, get_device, save_results_to_csv
from src.composer import CelebACrop
from src.model import AttentionModel
from src.utils import plot_all, plot_loss_all
from src.utils import build_loaders, get_n_parameters
from src.conductor import AttentionTrain
# reproducibility
torch.manual_seed(1984) # Posner & Cohen "Components of visual orienting." (1984)
random.seed(1984)
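# # note: seeding torch and random covers the RNGs used in this script; bitwise
# # reproducibility on GPU would additionally need torch.backends.cudnn.deterministic = True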
parser = argparse.ArgumentParser()
parser.add_argument('-n_epochs', type=int, default=32)
parser.add_argument('-batch_size', type=int, default=128)
parser.add_argument('-lr', type=float, default=1e-4)
parser.add_argument('-l2', type=float, default=5e-4)
parser.add_argument('-n_iter', type=int, default=1)
parser.add_argument('-exase', type=str, default="default")
parser.add_argument('-verbose', type=int, default=1)
argus = parser.parse_args()
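# # example invocation (all flags optional, defaults shown above):
# #   python main_celeba.py -n_epochs 32 -batch_size 128 -lr 1e-4 -exase default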
data_path = r"./data"
train_params = {
    "n_epochs": argus.n_epochs,
    "batch_size": argus.batch_size,
    "lr": argus.lr,
    "l2": argus.l2,
    "exase": str(argus.exase),
    "dir": r"./results",
    "milestones": [16, ],
    "gamma": 0.1,
    "max_grad_norm": 1.0,
}
model_params = {
    "in_dims": (3, 128, 128),  # input dimensions (channels, height, width)
    "n_classes": 2,  # number of classes
    "out_dim": 2,  # output dimension (could be larger than n_classes)
    "normalize": True,  # normalize input images
    "softness": 0.5,  # softness of the attention (scale)
    "channels": (3, 4, 8, 8, 16, 16, 32, 32),  # channels in the encoder
    "residuals": False,  # use residuals in the encoder
    "kernels": 3,  # kernel size
    "strides": 1,  # stride
    "paddings": 1,  # padding
    "conv_bias": True,  # bias in the convolutions
    "conv_norms": None,  # normalization in the encoder
    "conv_dropouts": 0.0,  # dropout in the encoder
    "conv_funs": torch.nn.ReLU(),  # activation function in the encoder
    "deconv_funs": torch.nn.Tanh(),  # activation function in the decoder
    "deconv_norms": None,  # normalization in the decoder
    "pools": 2,  # pooling in the encoder
    "rnn_dims": (32, 8, ),  # dimensions of the RNN (the first value is an FC layer, not an RNN)
    "rnn_bias": True,  # bias in the RNN
    "rnn_dropouts": 0.0,  # dropout in the RNN
    "rnn_funs": torch.nn.ReLU(),  # activation function in the RNN
    "n_tasks": 1,  # number of tasks
    "task_weight": False,  # use task embeddings for the decoder channels (multiplicative)
    "task_bias": False,  # use task embeddings for the decoder channels (additive)
    "task_funs": None,  # activation function for the task embeddings
    "rnn_to_fc": True,  # whether to use FC layers instead of the RNN layers
    "trans_fun": torch.nn.Identity(),  # activation function for the bridge
    "norm_mean": [0.5, 0.5, 0.5],  # mean for the normalization
    "norm_std": [1.0, 1.0, 1.0],  # std for the normalization
}
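# # with norm_mean=[0.5, 0.5, 0.5] and norm_std=[1.0, 1.0, 1.0], and assuming the
# # usual (x - mean) / std convention, inputs in [0, 1] are shifted to [-0.5, 0.5]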
# # each task specifies a composer, key, params, datasets, dataloaders, loss weights, loss slices, and a prompt flag
# # the loss weights cover the Cross-Entropy (CE) over iterations, the MSE for attention, and the CE for the last label
# # the first CE loss applies to the iterations selected by the loss slices
# # the second CE loss applies to the last label (after the last attention map is applied)
# # the loss slices determine which iterations contribute to the loss
tasks = OrderedDict({})
tasks["Celeb"] = {
"composer": CelebACrop, # composer (torch Dataset)
"key": 0, # key for the task
"params": {"n_iter": argus.n_iter, "hair_dir": None, "in_dims": model_params["in_dims"], "padding": 0, "noise": 0.25, "which": 0},
"datasets": [],
"dataloaders": [],
"loss_w": (0.0, 0.0, 1.0), # Loss weights (Cross-Entropy (CE), MSE for attention, CE last label)
"loss_s": (0, None), # Loss slices (CE, MSE for attention)
"has_prompt": False, # has prompt or not (only used for top-down Search)
}
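# # a second task would follow the same template; a hypothetical sketch below
# # (the "which" value and task name are placeholders, not part of this repo):
# tasks["Celeb2"] = {
#     "composer": CelebACrop,
#     "key": 1,  # keys must match the insertion order (asserted below)
#     "params": {"n_iter": argus.n_iter, "hair_dir": None, "in_dims": model_params["in_dims"], "padding": 0, "noise": 0.25, "which": 1},
#     "datasets": [],
#     "dataloaders": [],
#     "loss_w": (0.0, 0.0, 1.0),
#     "loss_s": (0, None),
#     "has_prompt": False,
# }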
model_params["n_tasks"] = len(tasks)
results_folder, logger = startup_folders(r"./results", name=f"{argus.exase}_")
for i, k in enumerate(tasks):
    assert tasks[k]["key"] == i, f"Key {tasks[k]['key']} must be equal to index {i}!"
(argus.verbose == 1) and logger.info(f"train_params\n {pformat(train_params)}")
(argus.verbose == 1) and logger.info(f"model_params\n {pformat(model_params)}")
(argus.verbose == 1) and logger.info(f"tasks\n {pformat(tasks)}")
# datasets and dataloaders
train_ds = datasets.CelebA(root=data_path, split='train', transform=transforms.ToTensor())
valid_ds = datasets.CelebA(root=data_path, split='valid', transform=transforms.ToTensor())
test_ds = datasets.CelebA(root=data_path, split='test', transform=transforms.ToTensor())
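# # note: torchvision's CelebA expects the archives under ./data/celeba and does not
# # download by default; passing download=True attempts the fetch, which is known to
# # fail when the hosting Google Drive quota is exceeded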
DeVice, num_workers, pin_memory = get_device()
for o in tasks:
    tasks[o]["datasets"].append(tasks[o]["composer"](train_ds, **tasks[o]["params"], kind="train"))
    tasks[o]["datasets"].append(tasks[o]["composer"](valid_ds, **tasks[o]["params"], kind="valid"))
    tasks[o]["datasets"].append(tasks[o]["composer"](test_ds, **tasks[o]["params"], kind="test"))
    tasks[o]["dataloaders"] = build_loaders(tasks[o]["datasets"], batch_size=train_params["batch_size"], num_workers=num_workers, pin_memory=pin_memory)
# model and optimizer...
model = AttentionModel(**model_params)
(argus.verbose == 1) and logger.info(model)
(argus.verbose == 1) and logger.info(f"Model has {get_n_parameters(model):,} parameters!")
optimizer = torch.optim.Adam(model.parameters(), lr=train_params["lr"], weight_decay=train_params["l2"])
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=train_params["milestones"], gamma=train_params["gamma"])
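# # with milestones=[16] and gamma=0.1, MultiStepLR cuts the learning rate by 10x
# # once at epoch 16 (e.g., the default 1e-4 becomes 1e-5 for the remaining epochs)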
conductor = AttentionTrain(model, optimizer, scheduler, tasks, logger, results_folder, train_params["max_grad_norm"], True)
# training...
plot_all(10, model, tasks, results_folder, "_pre", DeVice, logger, (argus.verbose == 1))
conductor.eval(DeVice)
conductor.train(train_params["n_epochs"], DeVice, True)
plot_loss_all(conductor, results_folder)
conductor.eval(DeVice)
plot_all(10, model, tasks, results_folder, "_post", DeVice, logger, False)
# saving...
(argus.verbose == 1) and logger.info("Saving results...")
save_dicts(tasks, results_folder, "tasks", logger)
save_dicts(train_params, results_folder, "train_params", logger)
save_dicts(model_params, results_folder, "model_params", logger)
torch.save(model.state_dict(), os.path.join(results_folder, "model.pth"))
torch.save(optimizer.state_dict(), os.path.join(results_folder, "optimizer.pth"))
for i, task in enumerate(tasks):
    save_results_to_csv(conductor.loss_records[i],
                        os.path.join(results_folder, f"loss_{task}.csv"),
                        ["labels", "masks", "last_label"], logger)
    save_results_to_csv(conductor.valid_records[i],
                        os.path.join(results_folder, f"valid_{task}.csv"),
                        ["CEi", "CEe", "PixErr", "AttAcc", "ClsAcc"], logger)
(argus.verbose == 1) and logger.info("Done!")