From fa39a8f3a4a24c8ad3b993f3028c2ad51700ec6a Mon Sep 17 00:00:00 2001 From: Adeesh Kolluru <43401571+AdeeshKolluru@users.noreply.github.com> Date: Fri, 22 Mar 2024 13:12:44 -0400 Subject: [PATCH] run_relaxation minor fixes (#636) --- ocpmodels/common/utils.py | 2 +- ocpmodels/trainers/ocp_trainer.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index 53b497e32..8c47b0ab7 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -969,7 +969,7 @@ def check_traj_files(batch, traj_dir) -> bool: if traj_dir is None: return False traj_dir = Path(traj_dir) - traj_files = [traj_dir / f"{id}.traj" for id in batch[0].sid.tolist()] + traj_files = [traj_dir / f"{id}.traj" for id in batch.sid.tolist()] return all(fl.exists() for fl in traj_files) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 1ef82baf5..86bab1a9d 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -316,6 +316,7 @@ def _compute_loss(self, out, batch): target = batch[target_name] pred = out[target_name] + natoms = batch.natoms natoms = torch.repeat_interleave(natoms, natoms) @@ -580,7 +581,7 @@ def run_relaxations(self, split="val"): if check_traj_files( batch, self.config["task"]["relax_opt"].get("traj_dir", None) ): - logging.info(f"Skipping batch: {batch[0].sid.tolist()}") + logging.info(f"Skipping batch: {batch.sid.tolist()}") continue relaxed_batch = ml_relax(