diff --git a/ocpmodels/modules/scaling/fit.py b/ocpmodels/modules/scaling/fit.py index 95c16f136..b3e1d4124 100644 --- a/ocpmodels/modules/scaling/fit.py +++ b/ocpmodels/modules/scaling/fit.py @@ -10,7 +10,6 @@ import torch.nn as nn from torch.nn.parallel.distributed import DistributedDataParallel -from ocpmodels.common.data_parallel import OCPDataParallel from ocpmodels.common.flags import flags from ocpmodels.common.utils import ( build_config, @@ -78,7 +77,7 @@ def main(*, num_batches: int = 16) -> None: # unwrap module from DP/DDP unwrapped_model = model while isinstance( - unwrapped_model, (DistributedDataParallel, OCPDataParallel) + unwrapped_model, DistributedDataParallel ): unwrapped_model = unwrapped_model.module assert isinstance(