From 4a510062155e1da675ddc057b91fd3c714d28b3e Mon Sep 17 00:00:00 2001 From: lbluque Date: Wed, 21 Aug 2024 18:16:01 -0700 Subject: [PATCH] write initial and final structure if save_full is false --- .../core/common/relaxation/optimizers/lbfgs_torch.py | 5 ++++- tests/core/common/test_lbfgs_torch.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py b/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py index 47764c513..f9ea84ce4 100644 --- a/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py +++ b/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py @@ -92,6 +92,7 @@ def run(self, fmax, steps): iteration = 0 max_forces = self.optimizable.get_max_forces(apply_constraint=True) logging.info("Step Fmax(eV/A)") + while iteration < steps and not self.optimizable.converged( forces=None, fmax=self.fmax, max_forces=max_forces ): @@ -99,7 +100,9 @@ def run(self, fmax, steps): f"{iteration} " + " ".join(f"{x:0.3f}" for x in max_forces.tolist()) ) - if self.trajectories is not None and self.save_full is True: + if self.trajectories is not None and ( + self.save_full is True or iteration == 0 + ): self.write() self.step(iteration) diff --git a/tests/core/common/test_lbfgs_torch.py b/tests/core/common/test_lbfgs_torch.py index 67810cd00..1ca10d145 100644 --- a/tests/core/common/test_lbfgs_torch.py +++ b/tests/core/common/test_lbfgs_torch.py @@ -53,7 +53,9 @@ def test_lbfgs_write_trajectory(save_full_traj, steps, batch, calculator, tmp_pa traj_files = list(tmp_path.glob("*.traj")) assert len(traj_files) == len(batch) - traj_length = 0 if steps == 0 else steps + 1 if save_full_traj else 1 + traj_length = ( + 0 if steps == 0 else steps + 1 if save_full_traj else 2 + ) # first and final frame for file in traj_files: traj = read(file, ":") assert len(traj) == traj_length