From bbb325216a254abc8193333ca3a2a6b07c37c19e Mon Sep 17 00:00:00 2001 From: tn1031 Date: Fri, 19 Nov 2021 10:13:56 +0900 Subject: [PATCH] fix --- benchmarks/set_matching_pytorch/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benchmarks/set_matching_pytorch/model.py b/benchmarks/set_matching_pytorch/model.py index 0982e2a..6b28e84 100644 --- a/benchmarks/set_matching_pytorch/model.py +++ b/benchmarks/set_matching_pytorch/model.py @@ -69,6 +69,8 @@ def __init__( if pretrained_weight: with open(pretrained_weight, "rb") as f: self.weight_estimator.load_state_dict(torch.load(f)) + for param in self.weight_estimator.parameters(): + param.requires_grad = False def importance_logit(self, prob): # prob = p(train|x) = 1 - p(test|x)