Hello, sir.

I tested your temperature scaling code in my project, which is based on inception_v3, but I got an error:
```python
import torch
from torch import nn, optim

# This function probably should live outside of this class, but whatever
def set_temperature(self, valid_loader):
    """
    Tune the temperature of the model (using the validation set).
    We're going to set it to optimize NLL.
    valid_loader (DataLoader): validation set loader
    """
    self.cuda()
    nll_criterion = nn.CrossEntropyLoss().cuda()
    ece_criterion = _ECELoss().cuda()  # _ECELoss is defined alongside this class in the original module

    # First: collect all the logits and labels for the validation set
    logits_list = []
    labels_list = []
    with torch.no_grad():
        for input, label in valid_loader:
            input = input.cuda()
            logits = self.model(input)
            logits_list.append(logits)
            labels_list.append(label)
        logits = torch.cat(logits_list).cuda()
        labels = torch.cat(labels_list).cuda()

    # Calculate NLL and ECE before temperature scaling
    before_temperature_nll = nll_criterion(logits, labels).item()
    before_temperature_ece = ece_criterion(logits, labels).item()
    print('Before temperature - NLL: %.3f, ECE: %.3f' % (before_temperature_nll, before_temperature_ece))

    # Next: optimize the temperature w.r.t. NLL
    optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=50)
    ...
```
```
RuntimeError: Calculated padded input size per channel: (4 x 4). Kernel size: (5 x 5). Kernel size can't be greater than actual input size
```
I traced the error in my code. It is raised at this line:

`logits = self.model(input)`

So is this code meant for DenseNet only?
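As a hedged illustration (not code from the repository): the calibration step itself only consumes logits and labels, so the technique is not tied to a particular architecture. The sketch below assumes plain Python lists of logits and integer labels, uses hypothetical function names, and swaps in a simple grid search over the temperature T in place of the LBFGS optimizer used in the code above.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def nll_with_temperature(logits_list, labels, temperature):
    """Mean negative log-likelihood after dividing every logit by T."""
    total = 0.0
    for logits, label in zip(logits_list, labels):
        scaled = [x / temperature for x in logits]
        total -= log_softmax(scaled)[label]
    return total / len(labels)

def fit_temperature(logits_list, labels, candidates=None):
    """Return the candidate T with the lowest validation NLL.

    Assumption: a grid search stands in for the LBFGS optimizer.
    """
    if candidates is None:
        candidates = [0.5 + 0.05 * i for i in range(91)]  # T from 0.5 to 5.0
    return min(candidates, key=lambda t: nll_with_temperature(logits_list, labels, t))

# Hypothetical overconfident validation outputs: 3 of 4 predictions are correct,
# but every prediction has a logit margin of 5.
val_logits = [[5.0, 0.0], [5.0, 0.0], [5.0, 0.0], [0.0, 5.0]]
val_labels = [0, 0, 0, 0]
best_t = fit_temperature(val_logits, val_labels)  # close to 5 / ln(3) ≈ 4.55
```

Because the sketch only sees logits, any classifier whose forward pass returns a logits tensor can be calibrated the same way; in the report above, the error is raised inside `self.model(input)`, before any temperature-related code runs.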
Thanks,