Skip to content

Commit

Permalink
Merge pull request #183 from st-tech/bugfix/fix_weight_estimator
Browse files Browse the repository at this point in the history
Stop updating weights while training the set matching model.
  • Loading branch information
nocotan authored Nov 24, 2021
2 parents f439653 + bbb3252 commit 7514980
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions benchmarks/set_matching_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def __init__(
if pretrained_weight:
with open(pretrained_weight, "rb") as f:
self.weight_estimator.load_state_dict(torch.load(f))
for param in self.weight_estimator.parameters():
param.requires_grad = False

def importance_logit(self, prob):
# prob = p(train|x) = 1 - p(test|x)
Expand Down

0 comments on commit 7514980

Please sign in to comment.