-
Notifications
You must be signed in to change notification settings - Fork 29
/
ProtoNet.py
119 lines (89 loc) · 4.83 KB
/
ProtoNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import higher
import typing
import os
from MLBaseClass import MLBaseClass
from _utils import get_cls_prototypes, euclidean_distance, train_val_split
from CommonModels import CNN, ResNet18
class ProtoNet(MLBaseClass):
    """Prototypical Networks (Snell et al., 2017) on top of the common meta-learning base class.

    "Adaptation" computes the per-class prototypes (mean embeddings) of the
    support set; prediction scores a query point by the negative Euclidean
    distance from its embedding to each prototype.
    """

    def __init__(self, config: dict) -> None:
        super().__init__(config=config)
        # ProtoNet has no separate hyper-network; kept as None so the
        # attribute interface matches MAML and VAMPIRE.
        self.hyper_net_class = None  # dummy to match with MAML and VAMPIRE

    def load_model(self, resume_epoch: int, eps_dataloader: torch.utils.data.DataLoader, **kwargs) -> dict:
        """Initialize or load the protonet and its optimizer.

        Args:
            resume_epoch: index of the checkpoint file to load; if None,
                falls back to self.config['resume_epoch']. A value <= 0
                means "train from scratch".
            eps_dataloader: episode dataloader; a single episode is consumed
                to materialize lazily-initialized modules before the
                parameters are counted.

        Returns: a dictionary consisting of
            hyper_net: the prototypical (embedding) network
            optimizer: Adam optimizer over the network's parameters
            f_base_net: alias of hyper_net (kept for interface parity)

        Raises:
            NotImplementedError: if config['network_architecture'] is
                neither 'CNN' nor 'ResNet18'.
        """
        model = dict.fromkeys(["hyper_net", "optimizer"])

        if resume_epoch is None:
            resume_epoch = self.config['resume_epoch']

        # build the embedding network (dim_output=None: embeddings, no classifier head)
        if self.config['network_architecture'] == 'CNN':
            model["hyper_net"] = CNN(
                dim_output=None,
                bn_affine=self.config['batchnorm'],
                stride_flag=self.config['strided']
            )
        elif self.config['network_architecture'] == 'ResNet18':
            model["hyper_net"] = ResNet18(
                dim_output=None,
                bn_affine=self.config['batchnorm']
            )
        else:
            raise NotImplementedError('Network architecture is unknown. Please implement it in the CommonModels.py.')

        # ---------------------------------------------------------------
        # run a dummy task to initialize lazy modules defined in base_net
        # ---------------------------------------------------------------
        for eps_data in eps_dataloader:
            # split episode data into train (support) and validation (query)
            split_data = train_val_split(eps_data=eps_data, k_shot=self.config['k_shot'])
            # one forward pass is enough to initialize lazy modules
            model["hyper_net"].forward(split_data['x_t'])
            break

        params = torch.nn.utils.parameters_to_vector(parameters=model["hyper_net"].parameters())
        print('Number of parameters of the base network = {0:,}.\n'.format(params.numel()))

        # move to device
        model["hyper_net"].to(self.config['device'])

        # optimizer
        model["optimizer"] = torch.optim.Adam(params=model["hyper_net"].parameters(), lr=self.config['meta_lr'])

        # load model if there is a saved checkpoint
        if resume_epoch > 0:
            # path to the saved file
            checkpoint_path = os.path.join(self.config['logdir'], 'Epoch_{0:d}.pt'.format(resume_epoch))
            # remap storages onto the configured device (CPU checkpoints can load on GPU and vice versa)
            saved_checkpoint = torch.load(
                f=checkpoint_path,
                map_location=lambda storage, loc: storage.cuda(self.config['device'].index) if self.config['device'].type == 'cuda' else storage
            )
            # load state dictionaries
            model["hyper_net"].load_state_dict(state_dict=saved_checkpoint['hyper_net_state_dict'])
            model["optimizer"].load_state_dict(state_dict=saved_checkpoint['opt_state_dict'])
            # the configured meta_lr overrides whatever lr was saved
            for param_group in model["optimizer"].param_groups:
                if param_group['lr'] != self.config['meta_lr']:
                    param_group['lr'] = self.config['meta_lr']

        model['f_base_net'] = model['hyper_net']

        return model

    def adaptation(self, x: torch.Tensor, y: torch.Tensor, model: dict) -> torch.Tensor:
        """Calculate the prototype of each class from the labeled support set.

        NOTE(review): the original return annotation was
        higher.patch._MonkeyPatchBase (copied from the MAML interface); this
        method actually returns the tensor of class prototypes, so the
        annotation is corrected here. Runtime behavior is unchanged.
        """
        z = model["hyper_net"].forward(x)  # embed data into the latent space
        cls_prototypes = get_cls_prototypes(x=z, y=y)
        return cls_prototypes

    def prediction(self, x: torch.Tensor, adapted_hyper_net: torch.Tensor, model: dict) -> torch.Tensor:
        """Return classification logits: the negative Euclidean distance to each prototype.

        Args:
            x: query inputs
            adapted_hyper_net: the class-prototype tensor produced by adaptation()
            model: dict holding "hyper_net", the embedding network
        """
        z = model["hyper_net"].forward(x)
        distance_matrix = euclidean_distance(matrixN=z, matrixM=adapted_hyper_net)
        logits = -distance_matrix  # smaller distance => larger logit
        return logits

    def validation_loss(self, x: torch.Tensor, y: torch.Tensor, adapted_hyper_net: torch.Tensor, model: dict) -> torch.Tensor:
        """Cross-entropy loss of the prototype-distance logits on the query set (x, y)."""
        logits = self.prediction(x=x, adapted_hyper_net=adapted_hyper_net, model=model)
        loss = torch.nn.functional.cross_entropy(input=logits, target=y)
        return loss

    def evaluation(self, x_t: torch.Tensor, y_t: torch.Tensor, x_v: torch.Tensor, y_v: torch.Tensor, model: dict) -> typing.Tuple[float, float]:
        """Evaluate one episode: adapt on the support set, score on the query set.

        Args:
            x_t, y_t: support (train) inputs and labels
            x_v, y_v: query (validation) inputs and labels
            model: dict holding "hyper_net", the embedding network

        Returns:
            (loss, accuracy) as Python floats; accuracy is a percentage in [0, 100].
        """
        class_prototypes = self.adaptation(x=x_t, y=y_t, model=model)
        logits = self.prediction(x=x_v, adapted_hyper_net=class_prototypes, model=model)
        loss = torch.nn.functional.cross_entropy(input=logits, target=y_v)
        accuracy = (logits.argmax(dim=1) == y_v).float().mean().item()
        return loss.item(), accuracy * 100