Using the Learner object for my project, Loss not behaving at its best #72

Metabloggism opened this issue Feb 9, 2023 · 0 comments

I am writing a blog (I already presented it in this subreddit), and in my latest post I did a performance analysis of MAML. I ran several experiments, trying both SGD and Adam at the Meta-Learning level with different (Meta-)LRs. In summary: with Adam and LR = 1e-4 the training is too unstable, while with LR = 1e-5 the curve looks better but barely improves (the loss ends up depending far more on the initialization). Do you have ideas on how to overcome this issue? I could apply some Batch Normalization, but in Meta-Learning the samples are whole problems, and I'm not sure whether Batch Normalization works in that setting.

I'll add images of the resulting loss curve (raw, smoothed, and smoothed+scaled).

[Image: Raw]

[Image: Smoothed]

[Image: Smoothed+scaled]

My code (also in the post; not strictly necessary to read for the issue, just for support):

import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, RandomSampler, SubsetRandomSampler, BatchSampler
import torchvision
import matplotlib.pyplot as plt

omniglot_raw = torchvision.datasets.Omniglot(root="./dataset/omniglot", download=True, transform=torchvision.transforms.ToTensor())


# torchvision's Omniglot keeps these lists in private attributes
alphabets = omniglot_raw._alphabets
characters = omniglot_raw._characters


num_alphabets = len(alphabets)
num_characters = len(characters)

class MetaSplit:
  """Bookkeeping for one meta-split: its alphabets and character/problem counts."""
  def __init__(self, ratio, total_num_characters):
    self.alphabets = []
    self.num_characters = 0
    self.min_num_characters = total_num_characters * ratio
    self.num_problems = None

metasplits = {'metatrain': MetaSplit(0.7, num_characters),
              'metaval': MetaSplit(0.15, num_characters),
              'metatest': MetaSplit(0.15, num_characters)}

chars_per_alphabet = {alph: [char.split('/')[0] for char in characters].count(alph) for alph in alphabets}

random.shuffle(alphabets)

current_metasplit = 'metatrain'
switch_metasplit_from = {'metatrain': 'metaval', 'metaval': 'metatest', 'metatest': 'metatest'}  # metatest maps to itself, so the last split absorbs any remaining alphabets

for alphabet in alphabets:
  if metasplits[current_metasplit].num_characters >= metasplits[current_metasplit].min_num_characters:
    current_metasplit = switch_metasplit_from[current_metasplit]
  metasplits[current_metasplit].alphabets.append(alphabet)
  metasplits[current_metasplit].num_characters += chars_per_alphabet[alphabet]

for metasplit in metasplits:
  metasplits[metasplit].num_problems = 1/2 * sum([chars_per_alphabet[alph]**2 - chars_per_alphabet[alph] for alph in metasplits[metasplit].alphabets])
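
For intuition, the formula above counts the unordered pairs of distinct characters within each alphabet; a quick numeric check:

# Worked example of the problem count: an alphabet with n characters
# contributes (n**2 - n) / 2 unordered pairs of distinct characters
n = 26
assert (n**2 - n) / 2 == 325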

metabatch_size = 8
num_metabatches = int(metasplits['metatrain'].num_problems / metabatch_size)

class MetaLoader():
    """
    Iterates over meta-batches of character-pair problems built on top of the
    base Omniglot dataset; each problem is yielded as a dict of 'train', 'val'
    and 'test' DataLoaders.
    """
    def __init__(self, base_dataset, metabatch_size, batch_sizes,
                 chars_per_alphabet, problem_ratios):
        self.base_dataset = base_dataset
        self.metabatch_size = metabatch_size
        self.batch_sizes = batch_sizes
        self.chars_per_alph = chars_per_alphabet
        self.problem_ratios = problem_ratios
        self.problems_per_alph = {}
        self.num_problems = 0
        self.__load_quantitative_info__()
        self.metasampler = BatchSampler(RandomSampler(range(self.num_problems)),
                                        batch_size=self.metabatch_size,
                                        drop_last=True)
    
    def __load_quantitative_info__(self):
        for alphb in self.chars_per_alph:
            self.problems_per_alph[alphb] = int((self.chars_per_alph[alphb]**2 - 
                                                self.chars_per_alph[alphb]) / 2)
            self.num_problems += self.problems_per_alph[alphb]
    
    def __has_reached__(self, idx, ctr, current):
        return ctr + current > idx
    
    def __problem_idx_to_samples_idx__(self, problem_idx, alphb, 
                                       prbs_on_prev_alphabets, 
                                       chars_on_prev_alphabets):
        pb_idx_in_alph = problem_idx - prbs_on_prev_alphabets
        ichars_in_alphabet = (int(pb_idx_in_alph / self.chars_per_alph[alphb]), 
                                pb_idx_in_alph % self.chars_per_alph[alphb])
        ichars = tuple([ich + chars_on_prev_alphabets \
                        for ich in ichars_in_alphabet])
        # each Omniglot character has exactly 20 samples, stored contiguously
        return [sample_idx for charidx in ichars
                for sample_idx in range(charidx * 20, (charidx + 1) * 20)]
    
    def __build_problem_loader_from_samples__(self, samples_idx):

        random.shuffle(samples_idx)

        train_val_frontier = int(len(samples_idx) * self.problem_ratios[0])
        val_test_frontier = int(train_val_frontier + 
                                len(samples_idx) * self.problem_ratios[1])
        
        samples_idx_train = samples_idx[:train_val_frontier]
        samples_idx_val = samples_idx[train_val_frontier:val_test_frontier]
        samples_idx_test = samples_idx[val_test_frontier:]

        train_sampler = BatchSampler(SubsetRandomSampler(samples_idx_train), 
                                     batch_size=self.batch_sizes['train'], 
                                     drop_last=True)
        val_sampler = BatchSampler(SubsetRandomSampler(samples_idx_val), 
                                   batch_size=self.batch_sizes['val'], 
                                   drop_last=True)
        test_sampler = BatchSampler(SubsetRandomSampler(samples_idx_test), 
                                    batch_size=self.batch_sizes['test'], 
                                    drop_last=True)
        loaders = {'train': DataLoader(dataset=self.base_dataset, 
                                       batch_sampler=train_sampler),
                   'val': DataLoader(dataset=self.base_dataset, 
                                       batch_sampler=val_sampler),
                   'test': DataLoader(dataset=self.base_dataset, 
                                       batch_sampler=test_sampler)}
        return loaders

        
    def __get_problem_loader__(self, problem_idx):
        pbs_ctr = 0
        chars_ctr = 0
        for alphb in self.chars_per_alph:
            if not self.__has_reached__(problem_idx, pbs_ctr, 
                                        self.problems_per_alph[alphb]):
                pbs_ctr += self.problems_per_alph[alphb]
                chars_ctr += self.chars_per_alph[alphb]
            else:
                problem_samples_idx = self.__problem_idx_to_samples_idx__(
                    problem_idx, alphb, pbs_ctr, chars_ctr)
                return self.__build_problem_loader_from_samples__(
                    problem_samples_idx)

    def __iter__(self):
        for metabatch in self.metasampler:
            problem_loaders = []
            for problem_idx in metabatch:
                problem_loaders.append(self.__get_problem_loader__(problem_idx))
            yield problem_loaders

chars_per_alphabet = {split: {alph: [char.split('/')[0] for char in characters].count(alph) for alph in metasplits[split].alphabets} for split in metasplits}

metatrain_loader = MetaLoader(base_dataset=omniglot_raw, metabatch_size=metabatch_size, batch_sizes={'train': 8, 'val': 1, 'test': 1}, chars_per_alphabet=chars_per_alphabet['metatrain'], problem_ratios=[0.75, 0.15, 0.1])
metaval_loader = MetaLoader(base_dataset=omniglot_raw, metabatch_size=metabatch_size, batch_sizes={'train': 8, 'val': 1, 'test': 1}, chars_per_alphabet=chars_per_alphabet['metaval'], problem_ratios=[0.75, 0.15, 0.1])
metatest_loader = MetaLoader(base_dataset=omniglot_raw, metabatch_size=1, batch_sizes={'train': 8, 'val': 1, 'test': 1}, chars_per_alphabet=chars_per_alphabet['metatest'], problem_ratios=[0.75, 0.15, 0.1])

n_epochs = 15

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 10, 5)
        self.conv3 = nn.Conv2d(10, 12, 5)
        self.conv4 = nn.Conv2d(12, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 2 * 2, 10)
        self.fc2 = nn.Linear(10, 1)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = x.squeeze()
        return x
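
For reference, the 16 * 2 * 2 input size of fc1 follows from Omniglot's 105x105 images passing through the four conv(5x5)/pool(2x2) stages:

# Spatial size through SimpleNet, starting from 105x105 Omniglot images:
# conv1 (5x5): 105 -> 101, pool: 50
# conv2 (5x5):  50 ->  46, pool: 23
# conv3 (5x5):  23 ->  19, pool:  9
# conv4 (5x5):   9 ->   5, pool:  2   => 16 channels * 2 * 2 = 64 features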


def process_labels(labels_raw, ref_label):
    # Binary target: 1 where the raw label matches the reference character
    return (labels_raw == ref_label).float()

def preprocess_inputs(inputs):
    # Invert and rescale the images (Omniglot characters are dark-on-light)
    return (1 - inputs) * 255
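
A quick illustration of the label mapping (the tensor values here are made up):

# process_labels turns raw character labels into binary targets against ref_label
toy_labels = torch.tensor([3, 7, 3, 12])
print(process_labels(toy_labels, ref_label=3))  # tensor([1., 0., 1., 0.])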

def make_step(model, outputs, labels, update_lr, in_weights):
    # One inner-loop update: differentiate the loss w.r.t. the current fast
    # weights and take a plain SGD step on them; create_graph=True keeps the
    # update inside the graph so the meta-optimizer can backpropagate through it
    loss = criterion(outputs, labels)
    grads = torch.autograd.grad(loss, in_weights, create_graph=True)
    out_weights = [w - update_lr * g for g, w in zip(grads, in_weights)]
    # outputs are sigmoid probabilities, so threshold at 0.5
    accuracy = ((outputs > 0.5).float() == labels).sum() / outputs.shape[0]
    return out_weights, loss, accuracy

def update_model(model, new_weights, param_keys):
    # Write a list of fast weights back into the model's parameter slots
    for param, param_key in zip(new_weights, param_keys):
        model._modules[param_key[0]]._parameters[param_key[1]] = param
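
For clarity, this is how the helper composes into one functional adaptation step (a sketch: `inputs` and `labels` stand for one batch of a problem, and `model` is the Learner defined further below):

# One problem-level step with functional (fast) weights; the model itself is
# never mutated, only the weight list that gets threaded through make_step
fast_weights = list(model.parameters())
outputs = model(inputs, fast_weights)
fast_weights, loss, accuracy = make_step(model, outputs, labels, update_lr, fast_weights)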

# Pull one metabatch to sanity-check the loaders
toy_metabatch = next(iter(metatrain_loader))
toy_problem_loader = toy_metabatch[0]['train']
toy_problem_loader_val = toy_metabatch[0]['val']
toy_problem_loader_test = toy_metabatch[0]['test']

class Learner(nn.Module):
    """
    Builds a network from a config list of (name, param) tuples and supports
    forward passes with an externally supplied list of fast weights, as needed
    for MAML-style inner-loop updates.
    """

    def __init__(self, config, imgc, imgsz):
        """
        :param config: network config, type: list of (string, list)
        :param imgc: number of input channels, 1 or 3
        :param imgsz: input image size, e.g. 28 or 84
        """
        super(Learner, self).__init__()


        self.config = config

        # this list contains all tensors that need to be optimized
        self.vars = nn.ParameterList()
        # running_mean and running_var
        self.vars_bn = nn.ParameterList()

        for name, param in self.config:
            if name == 'conv2d':
                # [ch_out, ch_in, kernelsz, kernelsz]
                w = nn.Parameter(torch.ones(*param[:4]))
                # gain=1 according to cbfinn's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'convt2d':
                # [ch_in, ch_out, kernelsz, kernelsz, stride, padding]
                w = nn.Parameter(torch.ones(*param[:4]))
                # gain=1 according to cbfinn's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # [ch_in, ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[1])))

            elif name == 'linear':
                # [ch_out, ch_in]
                w = nn.Parameter(torch.ones(*param))
                # gain=1 according to cbfinn's implementation
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

            elif name == 'bn':
                # [ch_out]
                w = nn.Parameter(torch.ones(param[0]))
                self.vars.append(w)
                # [ch_out]
                self.vars.append(nn.Parameter(torch.zeros(param[0])))

                # running statistics must not receive gradients
                running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
                running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
                self.vars_bn.extend([running_mean, running_var])

            elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d',
                          'flatten', 'reshape', 'leakyrelu', 'sigmoid']:
                continue
            else:
                raise NotImplementedError

    def extra_repr(self):
        info = ''

        for name, param in self.config:
            if name == 'conv2d':
                tmp = 'conv2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' \
                      % (param[1], param[0], param[2], param[3], param[4], param[5])
                info += tmp + '\n'

            elif name == 'convt2d':
                tmp = 'convTranspose2d:(ch_in:%d, ch_out:%d, k:%dx%d, stride:%d, padding:%d)' \
                      % (param[0], param[1], param[2], param[3], param[4], param[5])
                info += tmp + '\n'

            elif name == 'linear':
                tmp = 'linear:(in:%d, out:%d)' % (param[1], param[0])
                info += tmp + '\n'

            elif name == 'leakyrelu':
                tmp = 'leakyrelu:(slope:%f)' % (param[0])
                info += tmp + '\n'

            elif name == 'avg_pool2d':
                tmp = 'avg_pool2d:(k:%d, stride:%d, padding:%d)' % (param[0], param[1], param[2])
                info += tmp + '\n'
            elif name == 'max_pool2d':
                tmp = 'max_pool2d:(k:%d, stride:%d, padding:%d)' % (param[0], param[1], param[2])
                info += tmp + '\n'
            elif name in ['flatten', 'tanh', 'relu', 'upsample', 'reshape', 'sigmoid', 'use_logits', 'bn']:
                tmp = name + ':' + str(tuple(param))
                info += tmp + '\n'
            else:
                raise NotImplementedError

        return info



    def forward(self, x, vars=None, bn_training=True):
        """
        This function can also be called during finetuning; in that case we don't want to update
        running_mean/running_var. Though the bn weights/bias are updated, they are kept separate
        as fast weights. To avoid updating running_mean/running_var, set bn_training=False;
        the bn weight/bias will still be adapted without dirtying the initial theta parameters,
        since they travel through the fast weights.
        :param x: image batch [b, c, h, w] (here [b, 1, 105, 105] for raw Omniglot)
        :param vars: optional list of fast weights; defaults to self.vars
        :param bn_training: set False to freeze the batch-norm running statistics
        :return: x
        """

        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0

        for name, param in self.config:
            if name == 'conv2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'convt2d':
                w, b = vars[idx], vars[idx + 1]
                x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == 'linear':
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == 'bn':
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx + 1]
                x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
                idx += 2
                bn_idx += 2

            elif name == 'flatten':
                x = x.view(x.size(0), -1)
            elif name == 'reshape':
                # [b, 8] => [b, 2, 2, 2]
                x = x.view(x.size(0), *param)
            elif name == 'relu':
                x = F.relu(x, inplace=param[0])
            elif name == 'leakyrelu':
                x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
            elif name == 'tanh':
                x = torch.tanh(x)
            elif name == 'sigmoid':
                x = torch.sigmoid(x)
            elif name == 'upsample':
                x = F.interpolate(x, scale_factor=param[0], mode='nearest')
            elif name == 'max_pool2d':
                x = F.max_pool2d(x, param[0], param[1], param[2])
            elif name == 'avg_pool2d':
                x = F.avg_pool2d(x, param[0], param[1], param[2])
            else:
                raise NotImplementedError

        # make sure variable is used properly
        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)


        return x


    def zero_grad(self, vars=None):
        """
        Zero the gradients of self.vars, or of an explicitly supplied list of fast weights.
        :param vars: optional list of parameters to zero instead of self.vars
        :return: None
        """
        with torch.no_grad():
            if vars is None:
                for p in self.vars:
                    if p.grad is not None:
                        p.grad.zero_()
            else:
                for p in vars:
                    if p.grad is not None:
                        p.grad.zero_()

    def parameters(self):
        """
        override this function since initial parameters will return with a generator.
        :return:
        """
        return self.vars

net_config = [
        ('conv2d', [6, 1, 5, 5, 1, 0]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [10, 6, 5, 5, 1, 0]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [12, 10, 5, 5, 1, 0]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [16, 12, 5, 5, 1, 0]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
        ('flatten', []),
        ('linear', [10, 64]),
        ('relu', [True]),
        ('linear', [1, 10]),
        ('sigmoid', []),
        ('reshape', [])
    ]
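
Regarding the Batch Normalization question above: the Learner already understands a 'bn' entry (parameterized by [ch_out], with running statistics kept in vars_bn and controlled by bn_training). A hypothetical variant of the config with BN after each conv could look like the sketch below; whether it behaves well at the Meta-Learning level is exactly the open question.

# Hypothetical config adding a 'bn' layer after every conv (not used in the runs above)
net_config_bn = [
        ('conv2d', [6, 1, 5, 5, 1, 0]),
        ('bn', [6]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [10, 6, 5, 5, 1, 0]),
        ('bn', [10]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [12, 10, 5, 5, 1, 0]),
        ('bn', [12]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [16, 12, 5, 5, 1, 0]),
        ('bn', [16]),
        ('relu', [True]),
        ('max_pool2d', [2, 2, 0]),
        ('flatten', []),
        ('linear', [10, 64]),
        ('relu', [True]),
        ('linear', [1, 10]),
        ('sigmoid', []),
        ('reshape', [])
    ]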

printlines = []

model = Learner(net_config, imgc=1, imgsz=28)
criterion = nn.BCELoss()  # the network already ends in a sigmoid, so plain BCELoss (BCEWithLogitsLoss would apply a second sigmoid)
update_lr = 0.01
meta_lr = 0.00001
n_epochs = 15
n_metaepochs = 2

metaoptimizer = optim.Adam(model.parameters(), lr=meta_lr)

for metaepoch in range(n_metaepochs):

    printlines.append('===============================')
    printlines.append(f'//           Meta-Epoch {metaepoch + 1}       //')    
    printlines.append('===============================')
    print('===============================')
    print(f'//           Meta-Epoch {metaepoch + 1}       //')    
    print('===============================')

    for mi, metabatch in enumerate(metatrain_loader, 0):  #  Meta-step
        print(mi)
        printlines.append(f'{mi} updates at Meta-Level')
        print(f'{mi} updates at Meta-Level')

        running_loss = 0.0  #  At each meta-step, the loss is reset

        # No need to store initial weights

        for pi, problem_loaders in enumerate(metabatch, 0):  #  Problem in the meta-batch

            printlines.append(f'- Problem {pi + 1} -')
            print(f'- Problem {pi + 1} -')

            problem_loader = problem_loaders['train']
            problem_loader_val = problem_loaders['val']
            ref_label = None

            new_weights = model.parameters()

            for epoch in range(n_epochs):  #  Epoch in the problem training

                printlines.append(f'Epoch {epoch + 1}')
                print(f'Epoch {epoch + 1}')

                val_loss = 0.0
                val_accuracy = 0.0

                for i, data in enumerate(problem_loader, 0):  #  Step in the problem

                    inputs_raw, labels_raw = data
                    inputs = 1 - inputs_raw
                    outputs = model(inputs, new_weights)
                    if ref_label is None:
                        ref_label = labels_raw[0]   #  On a new problem (1st step) adjust label mapping
                    labels = process_labels(labels_raw, ref_label)

                    new_weights, loss, accuracy = make_step(model, outputs, labels, update_lr, new_weights)

                    #  As the prediction is intrinsically done with the new weights, no need to actually update the model at the Learning Level

                    printlines.append(f'[Epoch {epoch + 1}, step {i + 1:5d}] Loss: {loss.item()}, Accuracy: {accuracy}')
                    print(f'[Epoch {epoch + 1}, step {i + 1:5d}] Loss: {loss.item()}, Accuracy: {accuracy}')

                for iv, datav in enumerate(problem_loader_val):  #  At the end of the training process in an epoch of a problem we compute a whole validation

                    inputs_rawv, labels_rawv = datav
                    inputsv = 1 - inputs_rawv
                    outputsv = model(inputsv, new_weights)
                    labelsv = process_labels(labels_rawv, ref_label)

                    lossv = criterion(outputsv, labelsv)  #  Loss in a validation batch
                    val_loss += lossv.item()
                    val_accuracy += (((1 - outputsv) < outputsv).float() == labelsv).sum()

                printlines.append(f'[Epoch {epoch + 1}, VALIDATION] Loss: {val_loss / (iv + 1)}, Accuracy: {val_accuracy / (iv + 1)}')  #  Loss and accuracy averaged over all validation batches in the problem, displayed after the full validation pass
                print(f'[Epoch {epoch + 1}, VALIDATION] Loss: {val_loss / (iv + 1)}, Accuracy: {val_accuracy / (iv + 1)}')

            running_loss += lossv  #  After all epochs (all training process) in a single problem the validation loss is added

            # Again, no need to update the model to the initial weights 
        
        metastep_loss = running_loss / metabatch_size  #  The added validation losses of all problems in the metabatch are averaged

        metaoptimizer.zero_grad()  #  We perform gradient descent at the Meta-Level over the averaged validation loss
        metastep_loss.backward()
        metaoptimizer.step()

        if (mi + 1) % 1000 == 0:  #  Meta-validation performed every 1000 meta-steps

            printlines.append('META-VALIDATION STEP:')
            print('META-VALIDATION STEP:')

            for mbvi, metabatch_val in enumerate(metaval_loader):  #  Meta-validation meta-step

                if (mbvi + 1) % 10 == 0:

                    printlines.append(f'Validation step {mbvi + 1}')
                    print(f'Validation step {mbvi + 1}')

                for problem_loaders in metabatch_val:  #  Problem in the meta-validation meta-batch

                    problem_loader = problem_loaders['train']
                    problem_loader_val = problem_loaders['val']
                    ref_label = None
                    new_weights = model.parameters()

                    for epoch in range(n_epochs):  #  Epoch in the problem training

                        val_loss = 0.0
                        val_accuracy = 0.0

                        for i, data in enumerate(problem_loader, 0):  #  Step in the problem
                            
                            inputs_raw, labels_raw = data
                            inputs = 1 - inputs_raw
                            outputs = model(inputs, new_weights)  # adapt with the fast weights, as in meta-training
                            if ref_label is None:
                                ref_label = labels_raw[0]
                            labels = process_labels(labels_raw, ref_label)

                            new_weights, loss, accuracy = make_step(model, outputs, labels, update_lr, new_weights)

                        #    printlines.append(f'Epoch {epoch + 1}, step {i + 1:5d}], Loss: {loss.item()}, Accuracy: {accuracy}')

                        for iv, datav in enumerate(problem_loader_val):  #  At the end of the training process in an epoch of a problem we compute a whole validation, as in Meta-Train

                            inputs_rawv, labels_rawv = datav
                            inputsv = 1 - inputs_rawv
                            outputsv = model(inputsv, new_weights)
                            labelsv = process_labels(labels_rawv, ref_label)
                            
                            lossv = criterion(outputsv, labelsv)
                            val_loss += lossv.item()
                            val_accuracy += (((1 - outputsv) < outputsv).float() == labelsv).sum()

                    
                    if (mbvi + 1) % 10 == 0:

                        printlines.append(f'[Last epoch, VALIDATION] Loss: {val_loss / (iv + 1)}, Accuracy: {val_accuracy / (iv + 1)}')  # Meta-validation is purely informative, so we report this at the end of each problem (every 10 steps)
                        print(f'[Last epoch, VALIDATION] Loss: {val_loss / (iv + 1)}, Accuracy: {val_accuracy / (iv + 1)}')

            printlines.append('END OF META-VALIDATION STEP')
            print('END OF META-VALIDATION STEP')
