set_temperature() will not only change the temperature parameter (desired),
but also the statistics (running_mean, running_var) of every BatchNorm layer in the network (undesired).
I think the whole network should be frozen except for the temperature scaling.
how to fix
Set the BatchNorm momentum to 0 before calling set_temperature(). With momentum = 0, PyTorch updates the running statistics as running_stat = (1 - momentum) * running_stat + momentum * batch_stat, so the forward passes performed during calibration leave them unchanged.
(Snippet tested on another codebase, not this one.)
def freeze_batch_norm(self):
    # Set momentum to 0 so forward passes no longer update the running
    # statistics of any BatchNorm2d layer in the model.
    for module in self.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            module.momentum = 0
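For context, a minimal usage sketch of the intended call order (the wrapper class ModelWithTemperature, orig_model, and valid_loader are assumptions on my side, not something defined in this issue):

# Hypothetical usage: freeze the BatchNorm statistics first, then calibrate.
scaled_model = ModelWithTemperature(orig_model)
scaled_model.freeze_batch_norm()             # momentum = 0 on every BatchNorm2d
scaled_model.set_temperature(valid_loader)   # now only the temperature should change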
In case it could be useful to someone, here is my script to diff two model checkpoints
(auto-generated with LLMs, tested on this codebase):
import torch


def load_checkpoint(filepath):
    """Load a PyTorch model checkpoint onto the CPU."""
    return torch.load(filepath, map_location=torch.device('cpu'))


def normalize_key(key):
    """Normalize dictionary keys by removing specific prefixes."""
    prefix = "model."
    if key.startswith(prefix):
        return key[len(prefix):]
    return key


def compare_tensors(tensor1, tensor2, prefix, changed_layers):
    """Helper function to compare two tensors."""
    if not torch.equal(tensor1, tensor2):
        changed_layers.append(prefix)


def compare_dicts(dict1, dict2, prefix, changed_layers):
    """Recursively compare dictionaries that may contain tensors."""
    # Normalize the keys in both dictionaries
    normalized_dict1 = {normalize_key(k): v for k, v in dict1.items()}
    normalized_dict2 = {normalize_key(k): v for k, v in dict2.items()}
    for key in normalized_dict1.keys():
        if key in normalized_dict2:
            if isinstance(normalized_dict1[key], torch.Tensor) and isinstance(normalized_dict2[key], torch.Tensor):
                compare_tensors(normalized_dict1[key], normalized_dict2[key], f"{prefix}.{key}" if prefix else key, changed_layers)
            elif isinstance(normalized_dict1[key], dict) and isinstance(normalized_dict2[key], dict):
                compare_dicts(normalized_dict1[key], normalized_dict2[key], f"{prefix}.{key}" if prefix else key, changed_layers)
            else:
                print(f"Type mismatch at {prefix}.{key}")
        else:
            print(f"Key {key} found in the first model but not in the second at {prefix}")
    for key in normalized_dict2.keys():
        if key not in normalized_dict1:
            print(f"Key {key} found in the second model but not in the first at {prefix}")


def compare_models(checkpoint1, checkpoint2):
    """Compare two model checkpoints and return the layers whose tensors differ."""
    model1 = load_checkpoint(checkpoint1)
    model2 = load_checkpoint(checkpoint2)
    changed_layers = []
    compare_dicts(model1, model2, "", changed_layers)
    return changed_layers


# Paths to your model checkpoints
checkpoint_path1 = '/home/<yourpath>/model_with_temperature.pth'
checkpoint_path2 = '/home/<yourpath>/model_without_temp.pth'

# Compare the models
changed_layers = compare_models(checkpoint_path1, checkpoint_path2)
if changed_layers:
    print("Changed layers:")
    for layer in changed_layers:
        print(layer)
else:
    print("No changes found between the models.")