diff --git a/janus_core/helpers/post_process.py b/janus_core/helpers/post_process.py index ebcca5d2..a750ad32 100644 --- a/janus_core/helpers/post_process.py +++ b/janus_core/helpers/post_process.py @@ -189,7 +189,7 @@ def compute_vaf( use_velocities: bool = False, fft: bool = False, index: SliceLike = (0, None, 1), - filter_atoms: MaybeSequence[MaybeSequence[int]] = ((),), + filter_atoms: MaybeSequence[MaybeSequence[Optional[int]]] = ((None),), ) -> NDArray[float64]: """ Compute the velocity autocorrelation function (VAF) of `data`. @@ -209,7 +209,7 @@ def compute_vaf( index : SliceLike Images to analyze as `start`, `stop`, `step`. Default is all images. - filter_atoms : MaybeSequence[MaybeSequence[int]] + filter_atoms : MaybeSequence[MaybeSequence[Optional[int]]] Compute the VAF averaged over subsets of the system. Default is all atoms. @@ -238,9 +238,7 @@ def compute_vaf( data = data[slice(*index)] if use_velocities: - momenta = np.asarray( - [datum.get_momenta() / datum.get_masses() for datum in data] - ) + momenta = np.asarray([datum.get_velocities() for datum in data]) else: momenta = np.asarray([datum.get_momenta() for datum in data]) @@ -249,7 +247,7 @@ def compute_vaf( # If filter_atoms not specified use all atoms filter_atoms = [ - atom if atom and atom[0] else range(n_atoms) for atom in filter_atoms + atom if atom[0] is not None else range(n_atoms) for atom in filter_atoms ] used_atoms = {atom for atoms in filter_atoms for atom in atoms} @@ -259,7 +257,9 @@ def compute_vaf( np.asarray( [ [ - np.correlate(momenta[:, j, i], momenta[:, j, i], "same") + np.correlate(momenta[:, j, i], momenta[:, j, i], "full")[ + n_steps - 1 : + ] for i in range(3) ] for j in used_atoms diff --git a/tests/data/vaf-lj-1-2-3.dat b/tests/data/vaf-lj-1-2-3.dat new file mode 100644 index 00000000..81aaba52 --- /dev/null +++ b/tests/data/vaf-lj-1-2-3.dat @@ -0,0 +1,11 @@ +1.0066870750448105 +0.6418190953234829 +0.3563911120413356 +0.3051070730422549 +0.23009251840934272 +0.26906890724754023 +0.24509942703971466 +0.15062229045327677 +-0.23061920240560663 +-0.42160528271693426 +-0.20524096931167723 diff --git a/tests/data/vaf-lj-3-4.dat b/tests/data/vaf-lj-3-4.dat new file mode 100644 index 00000000..03d62c99 --- /dev/null +++ b/tests/data/vaf-lj-3-4.dat @@ -0,0 +1,11 @@ +1.1263716900873755 +0.8820158202526532 +0.784347065477728 +0.6806780688236092 +0.5265986467643831 +0.5295303067488114 +0.5150663396964206 +0.5188617483216265 +0.3239983242603667 +0.19859295298597962 +0.058499204536948346 diff --git a/tests/data/vaf-lj.dat b/tests/data/vaf-lj.dat new file mode 100644 index 00000000..154db564 --- /dev/null +++ b/tests/data/vaf-lj.dat @@ -0,0 +1,11 @@ +0.940392953490707 +0.581286162528787 +0.3474623957116591 +0.19574532179574347 +0.08404080467605056 +0.02723301143121412 +0.03320823810966539 +-0.0004815372548760061 +-0.05545244283596119 +-0.06580258912853633 +-0.09743930723331182 diff --git a/tests/test_post_process.py b/tests/test_post_process.py index bba91146..f394d448 100644 --- a/tests/test_post_process.py +++ b/tests/test_post_process.py @@ -4,6 +4,7 @@ from ase.io import read import numpy as np +from pytest import approx from typer.testing import CliRunner from janus_core.calculations.md import NVE @@ -172,14 +173,19 @@ def test_rdf_by_elements(): assert (np.isclose(expected_peaks[element], rdf[0][peaks])).all() -def test_vaf(): +def test_vaf(tmp_path): """Test vaf will run.""" + vaf_names = ("vaf-lj-3-4.dat", "vaf-lj-1-2-3.dat") + vaf_filter = ((3, 4), (1, 2, 3)) + data = read(DATA_PATH / "lj-traj.xyz", index=":") vaf = post_process.compute_vaf(data) + expected = np.loadtxt(DATA_PATH / "vaf-lj.dat") assert isinstance(vaf, list) assert len(vaf) == 1 assert isinstance(vaf[0], np.ndarray) + assert vaf[0] == approx(expected, rel=1e-9) vaf = post_process.compute_vaf(data, fft=True) @@ -187,8 +193,17 @@ def test_vaf(): assert len(vaf) == 1 assert isinstance(vaf[0], np.ndarray) - vaf = post_process.compute_vaf(data, filter_atoms=((3, 4), (1, 2, 3))) + vaf = post_process.compute_vaf( + data, filter_atoms=vaf_filter, filenames=[tmp_path / name for name in vaf_names] + ) assert isinstance(vaf, list) assert len(vaf) == 2 assert isinstance(vaf[0], np.ndarray) + + for i, name in enumerate(vaf_names): + assert (tmp_path / name).exists() + expected = np.loadtxt(DATA_PATH / name) + written = np.loadtxt(tmp_path / name) + assert vaf[i] == approx(expected, rel=1e-9) + assert vaf[i] == approx(written, rel=1e-9)