I am writing a blog (I already presented it in this subreddit) and in my last post I did a performance analysis of MAML. I ran several experiments, basically trying both SGD and Adam at the meta-learning level with different meta-learning rates. To summarize: with Adam and a meta-LR of 1e-4 the meta-training is too unstable, while with a meta-LR of 1e-5 the curve looks better but it barely improves (the loss ends up depending far more on the initialization than on the training). Do you have any ideas on how to overcome this? I was thinking of adding Batch Normalization, but in meta-learning the "samples" are whole problems, and I'm not sure whether Batch Normalization even makes sense in that setting.
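To be concrete, these are the two meta-optimizer settings I am comparing (just an illustrative sketch with a placeholder model; in my actual script the model is the `Learner` defined below and the inner-loop `update_lr` stays at 0.01):

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 1)  # placeholder; in the real script this is the Learner defined below

# meta-level optimizers I compared (inner-loop update_lr = 0.01 in both cases):
metaoptimizer_unstable = optim.Adam(model.parameters(), lr=1e-4)  # meta-loss curve too unstable
metaoptimizer_slow = optim.Adam(model.parameters(), lr=1e-5)      # smoother, but barely improves on the initialization
```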
I'll attach images of the last loss curve (raw, smoothed, and smoothed + zoomed).
My code (it is also in the post and not strictly necessary for the question, just for reference):
```python
import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, RandomSampler, SubsetRandomSampler, BatchSampler
import torchvision
import matplotlib.pyplot as plt

omniglot_raw = torchvision.datasets.Omniglot(root="./dataset/omniglot", download=True,
                                             transform=torchvision.transforms.ToTensor())
alphabets = omniglot_raw._alphabets
characters = omniglot_raw._characters
num_alphabets = len(alphabets)
num_characters = len(characters)
class MetaSplit:
    def __init__(self, ratio, total_num_characters):
        self.alphabets = []
        self.num_characters = 0
        self.min_num_characters = total_num_characters * ratio
        self.num_problems = None

metasplits = {'metatrain': MetaSplit(0.7, num_characters),
              'metaval': MetaSplit(0.15, num_characters),
              'metatest': MetaSplit(0.15, num_characters)}
chars_per_alphabet = {alph: [char.split('/')[0] for char in characters].count(alph) for alph in alphabets}

random.shuffle(alphabets)
current_metasplit = 'metatrain'
switch_metasplit_from = {'metatrain': 'metaval', 'metaval': 'metatest'}
for alphabet in alphabets:
    if not metasplits[current_metasplit].num_characters < metasplits[current_metasplit].min_num_characters:
        current_metasplit = switch_metasplit_from[current_metasplit]
    metasplits[current_metasplit].alphabets.append(alphabet)
    metasplits[current_metasplit].num_characters += chars_per_alphabet[alphabet]

for metasplit in metasplits:
    metasplits[metasplit].num_problems = 1/2 * sum([chars_per_alphabet[alph]**2 - chars_per_alphabet[alph]
                                                    for alph in metasplits[metasplit].alphabets])
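# e.g. an alphabet with n characters defines n * (n - 1) / 2 two-character problems
# (n = 26 -> 26 * 25 / 2 = 325), which is what the sum above counts for each meta-split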
metabatch_size = 8
num_metabatches = int(metasplits['metatrain'].num_problems / metabatch_size)
class MetaLoader():
    """
    Iterable over meta-batches. Each meta-batch is a list of `metabatch_size` problems,
    and each problem is a dict with 'train', 'val' and 'test' DataLoaders built over the
    samples of two characters of the same alphabet.
    """
    def __init__(self, base_dataset, metabatch_size, batch_sizes,
                 chars_per_alphabet, problem_ratios):
        self.base_dataset = base_dataset
        self.metabatch_size = metabatch_size
        self.batch_sizes = batch_sizes
        self.chars_per_alph = chars_per_alphabet
        self.problem_ratios = problem_ratios
        self.problems_per_alph = {}
        self.num_problems = 0
        self.__load_quantitative_info__()
        self.metasampler = BatchSampler(RandomSampler(range(self.num_problems)),
                                        batch_size=self.metabatch_size,
                                        drop_last=True)

    def __load_quantitative_info__(self):
        for alphb in self.chars_per_alph:
            self.problems_per_alph[alphb] = int((self.chars_per_alph[alphb]**2 -
                                                 self.chars_per_alph[alphb]) / 2)
            self.num_problems += self.problems_per_alph[alphb]

    def __has_reached__(self, idx, ctr, current):
        return ctr + current > idx

    def __problem_idx_to_samples_idx__(self, problem_idx, alphb,
                                       prbs_on_prev_alphabets,
                                       chars_on_prev_alphabets):
        pb_idx_in_alph = problem_idx - prbs_on_prev_alphabets
        ichars_in_alphabet = (int(pb_idx_in_alph / self.chars_per_alph[alphb]),
                              pb_idx_in_alph % self.chars_per_alph[alphb])
        ichars = tuple([ich + chars_on_prev_alphabets
                        for ich in ichars_in_alphabet])
        # each Omniglot character has 20 samples stored contiguously in the dataset
        return [sample_idx for charidx in ichars
                for sample_idx in range(charidx * 20, (charidx + 1) * 20)]

    def __build_problem_loader_from_samples__(self, samples_idx):
        random.shuffle(samples_idx)
        train_val_frontier = int(len(samples_idx) * self.problem_ratios[0])
        val_test_frontier = int(train_val_frontier +
                                len(samples_idx) * self.problem_ratios[1])
        samples_idx_train = samples_idx[:train_val_frontier]
        samples_idx_val = samples_idx[train_val_frontier:val_test_frontier]
        samples_idx_test = samples_idx[val_test_frontier:]
        train_sampler = BatchSampler(SubsetRandomSampler(samples_idx_train),
                                     batch_size=self.batch_sizes['train'],
                                     drop_last=True)
        val_sampler = BatchSampler(SubsetRandomSampler(samples_idx_val),
                                   batch_size=self.batch_sizes['val'],
                                   drop_last=True)
        test_sampler = BatchSampler(SubsetRandomSampler(samples_idx_test),
                                    batch_size=self.batch_sizes['test'],
                                    drop_last=True)
        loaders = {'train': DataLoader(dataset=self.base_dataset,
                                       batch_sampler=train_sampler),
                   'val': DataLoader(dataset=self.base_dataset,
                                     batch_sampler=val_sampler),
                   'test': DataLoader(dataset=self.base_dataset,
                                      batch_sampler=test_sampler)}
        return loaders

    def __get_problem_loader__(self, problem_idx):
        pbs_ctr = 0
        chars_ctr = 0
        for alphb in self.chars_per_alph:
            if not self.__has_reached__(problem_idx, pbs_ctr,
                                        self.problems_per_alph[alphb]):
                pbs_ctr += self.problems_per_alph[alphb]
                chars_ctr += self.chars_per_alph[alphb]
            else:
                problem_samples_idx = self.__problem_idx_to_samples_idx__(
                    problem_idx, alphb, pbs_ctr, chars_ctr)
                return self.__build_problem_loader_from_samples__(
                    problem_samples_idx)

    def __iter__(self):
        for imetabatch, metabatch in enumerate(self.metasampler):
            problem_loaders = []
            for problem_idx in metabatch:
                problem_loaders.append(self.__get_problem_loader__(problem_idx))
            yield problem_loaders
chars_per_alphabet = {split: {alph: [char.split('/')[0] for char in characters].count(alph)
                              for alph in metasplits[split].alphabets}
                      for split in metasplits}

metatrain_loader = MetaLoader(base_dataset=omniglot_raw, metabatch_size=metabatch_size,
                              batch_sizes={'train': 8, 'val': 1, 'test': 1},
                              chars_per_alphabet=chars_per_alphabet['metatrain'],
                              problem_ratios=[0.75, 0.15, 0.1])
metaval_loader = MetaLoader(base_dataset=omniglot_raw, metabatch_size=metabatch_size,
                            batch_sizes={'train': 8, 'val': 1, 'test': 1},
                            chars_per_alphabet=chars_per_alphabet['metaval'],
                            problem_ratios=[0.75, 0.15, 0.1])
metatest_loader = MetaLoader(base_dataset=omniglot_raw, metabatch_size=1,
                             batch_sizes={'train': 8, 'val': 1, 'test': 1},
                             chars_per_alphabet=chars_per_alphabet['metatest'],
                             problem_ratios=[0.75, 0.15, 0.1])
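# Illustrative only: each iteration of a MetaLoader yields a list of `metabatch_size`
# problems, and each problem is a dict of DataLoaders over the 40 samples of its 2 characters:
#   metabatch = next(iter(metatrain_loader))
#   metabatch[0]['train']  # ~30 samples in batches of 8
#   metabatch[0]['val']    # ~6 samples in batches of 1
#   metabatch[0]['test']   # ~4 samples in batches of 1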
n_epochs = 15
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 10, 5)
        self.conv3 = nn.Conv2d(10, 12, 5)
        self.conv4 = nn.Conv2d(12, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 2 * 2, 10)
        self.fc2 = nn.Linear(10, 1)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = x.squeeze()
        return x
def process_labels(labels_raw, ref_label):
    # binary labels: 1 for the reference character, 0 for the other one
    return (labels_raw == ref_label).float()

def preprocess_inputs(inputs):
    return (1 - inputs) * 255

def make_step(model, outputs, labels, update_lr, in_weights):
    # one inner-loop (problem-level) update: new_weights = in_weights - update_lr * grads
    loss = criterion(outputs, labels)
    grads = torch.autograd.grad(loss, model.parameters())
    # grads are taken w.r.t. the meta-parameters and without create_graph=True,
    # so the meta-gradient is effectively first-order (FOMAML-style)
    out_weights = list(map(lambda p: p[1] - update_lr * p[0], zip(grads, in_weights)))
    # ((1 - outputs) < outputs) is just (outputs > 0.5)
    accuracy = (((1 - outputs) < outputs).float() == labels).sum() / outputs.shape[0]
    return out_weights, loss, accuracy

def update_model(model, new_weights, param_keys):
    # not used in the loops below; kept for reference
    for param, param_key in zip(new_weights, param_keys):
        model._modules[param_key[0]]._parameters[param_key[1]] = param
toy_metabatch = next(iter(metatrain_loader))
toy_problem_loader = toy_metabatch[0]['train']
toy_problem_loader_val = toy_metabatch[0]['val']
toy_problem_loader_test = toy_metabatch[0]['test']
class Learner(nn.Module):
    """
    MAML-style network whose parameters live in a flat ParameterList, so the forward
    pass can be run either with the meta-parameters or with a list of fast weights.
    """
    def __init__(self, config, imgc, imgsz):
        """
        :param config: network config, a list of (name, param_list) tuples
        :param imgc: 1 or 3 (input channels)
        :param imgsz: 28 or 84 (input image size)
        """
        super(Learner, self).__init__()
        self.config = config
        # this list contains all tensors that need to be optimized
        self.vars = nn.ParameterList()
        # running_mean and running_var of the batch-norm layers
        self.vars_bn = nn.ParameterList()

        for i, (name, param) in enumerate(self.config):
            if name == 'conv2d':
                # [ch_out, ch_in, kernelsz, kernelsz]
                w = nn.Parameter(torch.ones(*param[:4]))
                # gain=1 according to cbfinn's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == 'convt2d':
                # [ch_in, ch_out, kernelsz, kernelsz, stride, padding]
                w = nn.Parameter(torch.ones(*param[:4]))
                # gain=1 according to cbfinn's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # [ch_in, ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[1])))
            elif name == 'linear':
                # [ch_out, ch_in]
                w = nn.Parameter(torch.ones(*param))
                # gain=1 according to cbfinn's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == 'bn':
                # [ch_out]
                w = nn.Parameter(torch.ones(param[0]))
                self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
                # must set requires_grad=False
                running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
                running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
                self.vars_bn.extend([running_mean, running_var])
            elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d',
                          'flatten', 'reshape', 'leakyrelu', 'sigmoid']:
                continue
            else:
                raise NotImplementedError
    def extra_repr(self):
        info = ''
        for name, param in self.config:
            if name == 'conv2d':
                tmp = 'conv2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' \
                      % (param[1], param[0], param[2], param[3], param[4], param[5])
                info += tmp + '\n'
            elif name == 'convt2d':
                tmp = 'convTranspose2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' \
                      % (param[0], param[1], param[2], param[3], param[4], param[5])
                info += tmp + '\n'
            elif name == 'linear':
                tmp = 'linear:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'
            elif name == 'leakyrelu':
                tmp = 'leakyrelu:(slope:%f)' % (param[0])
                info += tmp + '\n'
            elif name == 'avg_pool2d':
                tmp = 'avg_pool2d:(k:%d, stride:%d, padding:%d)' % (param[0], param[1], param[2])
                info += tmp + '\n'
            elif name == 'max_pool2d':
                tmp = 'max_pool2d:(k:%d, stride:%d, padding:%d)' % (param[0], param[1], param[2])
                info += tmp + '\n'
            elif name in ['flatten', 'tanh', 'relu', 'upsample', 'reshape', 'sigmoid', 'use_logits', 'bn']:
                tmp = name + ':' + str(tuple(param))
                info += tmp + '\n'
            else:
                raise NotImplementedError
        return info
    def forward(self, x, vars=None, bn_training=True):
        """
        Runs the network with the given list of weights. It can also be called during
        fine-tuning; in that case we don't want to update running_mean/running_var.
        Although the weights/biases of the bn layers are updated, they are kept separate
        through the fast weights, so the initial theta parameters are not dirtied.
        :param x: [b, 1, 28, 28]
        :param vars: list of fast weights; if None, self.vars (the meta-parameters) is used
        :param bn_training: set False to not update the bn running statistics
        :return: x
        """
        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0
        for name, param in self.config:
            if name == 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'linear':
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
                idx += 2
                bn_idx += 2
            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                # with an empty param this just squeezes [b, 1] down to [b]
                x = x.view(x.size(0), *param)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name == 'tanh':
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'upsample':
                x = F.upsample_nearest(x, scale_factor=param[0])
            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])
            else:
                raise NotImplementedError

        # make sure every variable was used
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)
        return x
    def zero_grad(self, vars=None):
        """
        Zeroes the gradients of self.vars, or of the given list of fast weights.
        """
        with torch.no_grad():
            if vars is None:
                for p in self.vars:
                    if p.grad is not None:
                        p.grad.zero_()
            else:
                for p in vars:
                    if p.grad is not None:
                        p.grad.zero_()

    def parameters(self):
        """
        Overridden so that the optimizer gets self.vars directly
        (the default nn.Module.parameters() returns a generator).
        """
        return self.vars
net_config = [
    ('conv2d', [6, 1, 5, 5, 1, 0]),
    ('relu', [True]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [10, 6, 5, 5, 1, 0]),
    ('relu', [True]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [12, 10, 5, 5, 1, 0]),
    ('relu', [True]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [16, 12, 5, 5, 1, 0]),
    ('relu', [True]),
    ('max_pool2d', [2, 2, 0]),
    ('flatten', []),
    ('linear', [10, 64]),   # [ch_out, ch_in]: 16 * 2 * 2 = 64 features after the conv stack
    ('relu', [True]),
    ('linear', [1, 10]),
    ('sigmoid', []),
    ('reshape', [])         # [b, 1] -> [b]
]
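# Not something I have run: if I did try Batch Normalization (as mentioned at the top),
# with this Learner it would presumably just mean inserting ('bn', [ch_out]) entries
# after each conv layer, e.g. (hypothetical config, not used anywhere below):
#
# net_config_bn = [
#     ('conv2d', [6, 1, 5, 5, 1, 0]),   ('bn', [6]),  ('relu', [True]), ('max_pool2d', [2, 2, 0]),
#     ('conv2d', [10, 6, 5, 5, 1, 0]),  ('bn', [10]), ('relu', [True]), ('max_pool2d', [2, 2, 0]),
#     ('conv2d', [12, 10, 5, 5, 1, 0]), ('bn', [12]), ('relu', [True]), ('max_pool2d', [2, 2, 0]),
#     ('conv2d', [16, 12, 5, 5, 1, 0]), ('bn', [16]), ('relu', [True]), ('max_pool2d', [2, 2, 0]),
#     ('flatten', []), ('linear', [10, 64]), ('relu', [True]),
#     ('linear', [1, 10]), ('sigmoid', []), ('reshape', [])
# ]
#
# With bn_training=True, forward() then normalizes each batch with its own statistics
# (F.batch_norm with training=True), which is how MAML-style code usually handles BN.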
printlines = []

model = Learner(net_config, imgc=1, imgsz=28)
# NOTE: the network already ends in a 'sigmoid' layer, so BCEWithLogitsLoss applies a
# sigmoid a second time; nn.BCELoss() (or dropping the 'sigmoid' entry) would match it.
criterion = nn.BCEWithLogitsLoss()
update_lr = 0.01
meta_lr = 0.00001
n_epochs = 15
n_metaepochs = 2

metaoptimizer = optim.Adam(model.parameters(), lr=meta_lr)
for metaepoch in range(n_metaepochs):
    printlines.append('===============================')
    printlines.append(f'// Meta-Epoch {metaepoch + 1} //')
    printlines.append('===============================')
    print('===============================')
    print(f'// Meta-Epoch {metaepoch + 1} //')
    print('===============================')

    for mi, metabatch in enumerate(metatrain_loader, 0):  # Meta-step
        print(mi)
        printlines.append(f'{mi} updates at Meta-Level')
        print(f'{mi} updates at Meta-Level')
        running_loss = 0.0  # At each meta-step the loss is reset
        # No need to store the initial weights

        for pi, problem_loaders in enumerate(metabatch, 0):  # Problem in the meta-batch
            printlines.append(f'- Problem {pi + 1} -')
            print(f'- Problem {pi + 1} -')
            problem_loader = problem_loaders['train']
            problem_loader_val = problem_loaders['val']
            ref_label = None
            new_weights = model.parameters()

            for epoch in range(n_epochs):  # Epoch in the problem training
                printlines.append(f'Epoch {epoch + 1}')
                print(f'Epoch {epoch + 1}')
                val_loss = 0.0
                val_accuracy = 0.0

                for i, data in enumerate(problem_loader, 0):  # Step in the problem
                    inputs_raw, labels_raw = data
                    inputs = 1 - inputs_raw
                    outputs = model(inputs, new_weights)
                    if ref_label is None:
                        ref_label = labels_raw[0]  # On a new problem (1st step) fix the label mapping
                    labels = process_labels(labels_raw, ref_label)
                    new_weights, loss, accuracy = make_step(model, outputs, labels, update_lr, new_weights)
                    # The forward pass already uses the new (fast) weights, so there is no need
                    # to actually update the model at the Learning Level
                    printlines.append(f'Epoch {epoch + 1}, step {i + 1:5d}, Loss: {loss.item()}, Accuracy: {accuracy}')
                    print(f'Epoch {epoch + 1}, step {i + 1:5d}, Loss: {loss.item()}, Accuracy: {accuracy}')

                for iv, datav in enumerate(problem_loader_val):  # Full validation at the end of each epoch of a problem
                    inputs_rawv, labels_rawv = datav
                    inputsv = 1 - inputs_rawv
                    outputsv = model(inputsv, new_weights)
                    labelsv = process_labels(labels_rawv, ref_label)
                    lossv = criterion(outputsv, labelsv)  # Loss in a validation batch
                    val_loss += lossv.item()
                    val_accuracy += (((1 - outputsv) < outputsv).float() == labelsv).sum()
                printlines.append(f'Epoch {epoch + 1}, VALIDATION, Loss: {val_loss / (iv + 1)}, Accuracy: {val_accuracy / (iv + 1)}')  # Averaged over all validation batches of the problem
                print(f'Epoch {epoch + 1}, VALIDATION, Loss: {val_loss / (iv + 1)}, Accuracy: {val_accuracy / (iv + 1)}')

            running_loss += lossv  # After all epochs of a problem, its validation loss is accumulated
            # Again, no need to reset the model to the initial weights

        metastep_loss = running_loss / metabatch_size  # The accumulated validation losses of the meta-batch are averaged
        metaoptimizer.zero_grad()  # Gradient descent at the Meta-Level on the averaged validation loss
        metastep_loss.backward()
        metaoptimizer.step()

        if (mi + 1) % 1000 == 0:  # Meta-validation performed every 1000 meta-steps
            printlines.append('META-VALIDATION STEP:')
            print('META-VALIDATION STEP:')
            for mbvi, metabatch_val in enumerate(metaval_loader):  # Meta-validation meta-step
                if (mbvi + 1) % 10 == 0:
                    printlines.append(f'Validation step {mbvi + 1}')
                    print(f'Validation step {mbvi + 1}')
                for problem_loaders in metabatch_val:  # Problem in the meta-validation meta-batch
                    problem_loader = problem_loaders['train']
                    problem_loader_val = problem_loaders['val']
                    ref_label = None
                    new_weights = model.parameters()
                    for epoch in range(n_epochs):  # Epoch in the problem training
                        val_loss = 0.0
                        val_accuracy = 0.0
                        for i, data in enumerate(problem_loader, 0):  # Step in the problem
                            inputs_raw, labels_raw = data
                            inputs = 1 - inputs_raw
                            outputs = model(inputs, new_weights)
                            if ref_label is None:
                                ref_label = labels_raw[0]
                            labels = process_labels(labels_raw, ref_label)
                            new_weights, loss, accuracy = make_step(model, outputs, labels, update_lr, new_weights)
                        for iv, datav in enumerate(problem_loader_val):  # Full validation at the end of each epoch, as in meta-train
                            inputs_rawv, labels_rawv = datav
                            inputsv = 1 - inputs_rawv
                            outputsv = model(inputsv, new_weights)
                            labelsv = process_labels(labels_rawv, ref_label)
                            lossv = criterion(outputsv, labelsv)
                            val_loss += lossv.item()
                            val_accuracy += (((1 - outputsv) < outputsv).float() == labelsv).sum()
                    if (mbvi + 1) % 10 == 0:
                        printlines.append(f'Last epoch, VALIDATION, Loss: {val_loss / (iv + 1)}, Accuracy: {val_accuracy / (iv + 1)}')  # Meta-validation is only informative, so it is reported at the end of each problem (every 10 meta-steps)
                        print(f'Last epoch, VALIDATION, Loss: {val_loss / (iv + 1)}, Accuracy: {val_accuracy / (iv + 1)}')
            printlines.append('END OF META-VALIDATION STEP')
            print('END OF META-VALIDATION STEP')
```
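For reference, this is roughly how I log and smooth the meta-training loss to produce curves like the ones above; it is only a sketch (the `metastep_losses` list is not collected in the script above, it would have to be appended to right after `metaoptimizer.step()`, and the smoothing window is arbitrary):

```python
import matplotlib.pyplot as plt

# hypothetical: filled with metastep_loss.item() after every metaoptimizer.step()
metastep_losses = []

def smooth(values, window=50):
    # simple moving average, just to make the curve readable
    return [sum(values[max(0, i - window + 1):i + 1]) / (i - max(0, i - window + 1) + 1)
            for i in range(len(values))]

plt.plot(metastep_losses, alpha=0.3, label='raw')
plt.plot(smooth(metastep_losses), label='smoothed')
plt.xlabel('meta-step')
plt.ylabel('meta-training loss')
plt.legend()
plt.show()
```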