diff --git a/mmagic/models/archs/wrapper.py b/mmagic/models/archs/wrapper.py index 5b95a9d649..b985ac231d 100644 --- a/mmagic/models/archs/wrapper.py +++ b/mmagic/models/archs/wrapper.py @@ -183,5 +183,8 @@ def to( torch_device: Optional[Union[str, torch.device]] = None, torch_dtype: Optional[torch.dtype] = None, ): - self.model.to(torch_device, torch_dtype) + if torch_dtype is None: + self.model.to(torch_device) + else: + self.model.to(torch_device, torch_dtype) return self