diff --git a/tools/test.py b/tools/test.py index f09e298..a36971e 100644 --- a/tools/test.py +++ b/tools/test.py @@ -140,8 +140,9 @@ def main(): # set cudnn_benchmark if cfg.get("cudnn_benchmark", False): torch.backends.cudnn.benchmark = True - - cfg.model.pretrained = None + # fix issue mentioned in https://github.com/microsoft/SoftTeacher/issues/111 + if "pretrained" in cfg.model: + cfg.model.pretrained = None if cfg.model.get("neck"): if isinstance(cfg.model.neck, list): for neck_cfg in cfg.model.neck: