diff --git a/model.py b/model.py index 8aba484..41f0c80 100644 --- a/model.py +++ b/model.py @@ -57,7 +57,7 @@ def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, linea classifier = nn.Sequential(*classifier) classifier.apply(weights_init_classifier) - classifier.linear_num = linear + self.linear_num = linear self.add_block = add_block self.classifier = classifier def forward(self, x):