Batchnorm params are modified, they should not ! #36

Open
joihn opened this issue Jun 12, 2024 · 0 comments
joihn commented Jun 12, 2024

summary

set_temperature() not only changes the temperature (desired),
it also changes the buffers (running_mean, running_var, num_batches_tracked) of every BatchNorm layer in the network (undesired).
I think the whole network should be frozen, except for the temperature parameter.
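
The effect can be seen in isolation with a small sketch. This is not the code from this repo, only a single BatchNorm2d layer with made-up input; the helper names `buffers_snapshot` and `changed_buffers` are hypothetical. It assumes the layers are in training mode while set_temperature() runs its forward passes, which I have not verified in this codebase. In training mode PyTorch updates the BatchNorm buffers on every forward pass, even inside torch.no_grad().

```python
import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm2d(3)
# Made-up batch with non-zero mean and non-unit variance
x = 2.0 * torch.randn(8, 3, 4, 4) + 5.0

def buffers_snapshot(module):
    """Copy the module's buffers so later in-place updates can be detected."""
    return {name: buf.clone() for name, buf in module.named_buffers()}

def changed_buffers(module, before):
    """Return the sorted names of buffers that differ from the snapshot."""
    return sorted(
        name
        for name, buf in module.named_buffers()
        if not torch.equal(buf, before[name])
    )

# Training mode: buffers are written even though no gradients are computed.
bn.train()
before = buffers_snapshot(bn)
with torch.no_grad():
    bn(x)
train_changed = changed_buffers(bn, before)

# Eval mode: the forward pass reads the running statistics without writing them.
bn.eval()
before = buffers_snapshot(bn)
with torch.no_grad():
    bn(x)
eval_changed = changed_buffers(bn, before)

print("changed in train mode:", train_changed)
print("changed in eval mode:", eval_changed)
```

If that assumption holds here, calling eval() on the wrapped model before the forward passes would be another way to keep the BatchNorm buffers untouched.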

how to fix

Set the momentum of every BatchNorm layer to 0 before calling set_temperature()
(snippet tested on another codebase, not this one)

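    import torch

    def freeze_batch_norm(self):
        """Stop BatchNorm2d layers from updating running_mean / running_var."""
        for module in self.modules():
            if isinstance(module, torch.nn.BatchNorm2d):
                # With momentum 0 the running statistics keep their old values
                module.momentum = 0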
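
Why momentum 0 works: PyTorch documents the running-statistic update as `new = (1 - momentum) * running + momentum * batch_stat`. Below is a minimal pure-Python sketch of that formula. The function name `update_running_stat` and the numbers are made up for illustration.

```python
def update_running_stat(running, batch_stat, momentum):
    """One running-statistic update, following the formula in the PyTorch docs."""
    return (1.0 - momentum) * running + momentum * batch_stat

running_mean = 0.25  # value stored in the checkpoint (made up)
batch_mean = 3.0     # mean of the current batch (made up)

# Default momentum of 0.1: the stored value drifts towards the batch mean.
default_update = update_running_stat(running_mean, batch_mean, momentum=0.1)

# Momentum 0: the batch statistic gets zero weight, so the stored value is kept.
frozen_update = update_running_stat(running_mean, batch_mean, momentum=0.0)

print(default_update)
print(frozen_update)
```

Two caveats, as far as I can tell from the PyTorch source: num_batches_tracked is still incremented in training mode even with momentum 0, and the layers still normalize with the batch statistics rather than the stored ones. Putting the model in eval mode avoids both.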

In case it is useful to someone: my script to diff two model checkpoints

Auto-generated with an LLM,
tested on this codebase

import torch

def load_checkpoint(filepath):
    """Load a PyTorch model checkpoint."""
    return torch.load(filepath, map_location=torch.device('cpu'))

def normalize_key(key):
    """Normalize dictionary keys by removing specific prefixes."""
    prefix = "model."
    if key.startswith(prefix):
        return key[len(prefix):]
    return key

def compare_tensors(tensor1, tensor2, prefix, changed_layers):
    """Helper function to compare two tensors."""
    if not torch.equal(tensor1, tensor2):
        changed_layers.append(prefix)

def compare_dicts(dict1, dict2, prefix, changed_layers):
    """Recursively compare dictionaries that may contain tensors."""
    # Normalize the keys in both dictionaries
    normalized_dict1 = {normalize_key(k): v for k, v in dict1.items()}
    normalized_dict2 = {normalize_key(k): v for k, v in dict2.items()}

    for key in normalized_dict1.keys():
        if key in normalized_dict2:
            if isinstance(normalized_dict1[key], torch.Tensor) and isinstance(normalized_dict2[key], torch.Tensor):
                compare_tensors(normalized_dict1[key], normalized_dict2[key], f"{prefix}.{key}" if prefix else key, changed_layers)
            elif isinstance(normalized_dict1[key], dict) and isinstance(normalized_dict2[key], dict):
                compare_dicts(normalized_dict1[key], normalized_dict2[key], f"{prefix}.{key}" if prefix else key, changed_layers)
            else:
                # Neither a tensor pair nor a dict pair: the values are not compared
                print(f"Type mismatch or non-tensor entry at {prefix}.{key}")
        else:
            print(f"Key {key} found in the first model but not in the second at {prefix}")

    for key in normalized_dict2.keys():
        if key not in normalized_dict1:
            print(f"Key {key} found in the second model but not in the first at {prefix}")

def compare_models(checkpoint1, checkpoint2):
    """Compare two model checkpoints."""
    model1 = load_checkpoint(checkpoint1)
    model2 = load_checkpoint(checkpoint2)

    changed_layers = []
    compare_dicts(model1, model2, "", changed_layers)

    return changed_layers

# Paths to your model checkpoints
checkpoint_path1 = '/home/<yourpath>/model_with_temperature.pth'
checkpoint_path2 = '/home/<yourpath>/model_without_temp.pth'

# Compare the models
changed_layers = compare_models(checkpoint_path1, checkpoint_path2)
if changed_layers:
    print("Changed layers:")
    for layer in changed_layers:
        print(layer)
else:
    print("No changes found between the models.")

output of above script

Key temperature found in the first model but not in the second at 
Changed layers:
features.denseblock1.denselayer1.norm1.running_mean
features.denseblock1.denselayer1.norm1.running_var
features.denseblock1.denselayer1.norm1.num_batches_tracked
features.denseblock1.denselayer1.norm2.running_mean
features.denseblock1.denselayer1.norm2.running_var
features.denseblock1.denselayer1.norm2.num_batches_tracked
features.denseblock1.denselayer2.norm1.running_mean
features.denseblock1.denselayer2.norm1.running_var
features.denseblock1.denselayer2.norm1.num_batches_tracked
features.denseblock1.denselayer2.norm2.running_mean
features.denseblock1.denselayer2.norm2.running_var
features.denseblock1.denselayer2.norm2.num_batches_tracked
features.denseblock1.denselayer3.norm1.running_mean
features.denseblock1.denselayer3.norm1.running_var
features.denseblock1.denselayer3.norm1.num_batches_tracked
features.denseblock1.denselayer3.norm2.running_mean
features.denseblock1.denselayer3.norm2.running_var
features.denseblock1.denselayer3.norm2.num_batches_tracked
features.denseblock1.denselayer4.norm1.running_mean
features.denseblock1.denselayer4.norm1.running_var
features.denseblock1.denselayer4.norm1.num_batches_tracked
features.denseblock1.denselayer4.norm2.running_mean
features.denseblock1.denselayer4.norm2.running_var
features.denseblock1.denselayer4.norm2.num_batches_tracked
features.denseblock1.denselayer5.norm1.running_mean
features.denseblock1.denselayer5.norm1.running_var
features.denseblock1.denselayer5.norm1.num_batches_tracked
features.denseblock1.denselayer5.norm2.running_mean
features.denseblock1.denselayer5.norm2.running_var
features.denseblock1.denselayer5.norm2.num_batches_tracked
features.denseblock1.denselayer6.norm1.running_mean
features.denseblock1.denselayer6.norm1.running_var
features.denseblock1.denselayer6.norm1.num_batches_tracked
features.denseblock1.denselayer6.norm2.running_mean
features.denseblock1.denselayer6.norm2.running_var
features.denseblock1.denselayer6.norm2.num_batches_tracked
features.transition1.norm.running_mean
features.transition1.norm.running_var
features.transition1.norm.num_batches_tracked
features.denseblock2.denselayer1.norm1.running_mean
features.denseblock2.denselayer1.norm1.running_var
features.denseblock2.denselayer1.norm1.num_batches_tracked
features.denseblock2.denselayer1.norm2.running_mean
features.denseblock2.denselayer1.norm2.running_var
features.denseblock2.denselayer1.norm2.num_batches_tracked
features.denseblock2.denselayer2.norm1.running_mean
features.denseblock2.denselayer2.norm1.running_var
features.denseblock2.denselayer2.norm1.num_batches_tracked
features.denseblock2.denselayer2.norm2.running_mean
features.denseblock2.denselayer2.norm2.running_var
features.denseblock2.denselayer2.norm2.num_batches_tracked
features.denseblock2.denselayer3.norm1.running_mean
features.denseblock2.denselayer3.norm1.running_var
features.denseblock2.denselayer3.norm1.num_batches_tracked
features.denseblock2.denselayer3.norm2.running_mean
features.denseblock2.denselayer3.norm2.running_var
features.denseblock2.denselayer3.norm2.num_batches_tracked
features.denseblock2.denselayer4.norm1.running_mean
features.denseblock2.denselayer4.norm1.running_var
features.denseblock2.denselayer4.norm1.num_batches_tracked
features.denseblock2.denselayer4.norm2.running_mean
features.denseblock2.denselayer4.norm2.running_var
features.denseblock2.denselayer4.norm2.num_batches_tracked
features.denseblock2.denselayer5.norm1.running_mean
features.denseblock2.denselayer5.norm1.running_var
features.denseblock2.denselayer5.norm1.num_batches_tracked
features.denseblock2.denselayer5.norm2.running_mean
features.denseblock2.denselayer5.norm2.running_var
features.denseblock2.denselayer5.norm2.num_batches_tracked
features.denseblock2.denselayer6.norm1.running_mean
features.denseblock2.denselayer6.norm1.running_var
features.denseblock2.denselayer6.norm1.num_batches_tracked
features.denseblock2.denselayer6.norm2.running_mean
features.denseblock2.denselayer6.norm2.running_var
features.denseblock2.denselayer6.norm2.num_batches_tracked
features.transition2.norm.running_mean
features.transition2.norm.running_var
features.transition2.norm.num_batches_tracked
features.denseblock3.denselayer1.norm1.running_mean
features.denseblock3.denselayer1.norm1.running_var
features.denseblock3.denselayer1.norm1.num_batches_tracked
features.denseblock3.denselayer1.norm2.running_mean
features.denseblock3.denselayer1.norm2.running_var
features.denseblock3.denselayer1.norm2.num_batches_tracked
features.denseblock3.denselayer2.norm1.running_mean
features.denseblock3.denselayer2.norm1.running_var
features.denseblock3.denselayer2.norm1.num_batches_tracked
features.denseblock3.denselayer2.norm2.running_mean
features.denseblock3.denselayer2.norm2.running_var
features.denseblock3.denselayer2.norm2.num_batches_tracked
features.denseblock3.denselayer3.norm1.running_mean
features.denseblock3.denselayer3.norm1.running_var
features.denseblock3.denselayer3.norm1.num_batches_tracked
features.denseblock3.denselayer3.norm2.running_mean
features.denseblock3.denselayer3.norm2.running_var
features.denseblock3.denselayer3.norm2.num_batches_tracked
features.denseblock3.denselayer4.norm1.running_mean
features.denseblock3.denselayer4.norm1.running_var
features.denseblock3.denselayer4.norm1.num_batches_tracked
features.denseblock3.denselayer4.norm2.running_mean
features.denseblock3.denselayer4.norm2.running_var
features.denseblock3.denselayer4.norm2.num_batches_tracked
features.denseblock3.denselayer5.norm1.running_mean
features.denseblock3.denselayer5.norm1.running_var
features.denseblock3.denselayer5.norm1.num_batches_tracked
features.denseblock3.denselayer5.norm2.running_mean
features.denseblock3.denselayer5.norm2.running_var
features.denseblock3.denselayer5.norm2.num_batches_tracked
features.denseblock3.denselayer6.norm1.running_mean
features.denseblock3.denselayer6.norm1.running_var
features.denseblock3.denselayer6.norm1.num_batches_tracked
features.denseblock3.denselayer6.norm2.running_mean
features.denseblock3.denselayer6.norm2.running_var
features.denseblock3.denselayer6.norm2.num_batches_tracked
features.norm_final.running_mean
features.norm_final.running_var
features.norm_final.num_batches_tracked
@joihn joihn changed the title Batchnorm are getting modified, they should not ! Batchnorm params are modified, they should not ! Jun 12, 2024