Skip to content

Commit

Permalink
post_process uses full correlation, fix filter_atoms (#284)
Browse files Browse the repository at this point in the history
* post_process uses full correlation mode

Fix filter_atoms

* Test vaf value and writing

---------

Co-authored-by: Jacob Wilkins <[email protected]>
  • Loading branch information
harveydevereux and oerc0122 authored Aug 23, 2024
1 parent 9a50c37 commit 8175ce7
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 9 deletions.
14 changes: 7 additions & 7 deletions janus_core/helpers/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def compute_vaf(
use_velocities: bool = False,
fft: bool = False,
index: SliceLike = (0, None, 1),
filter_atoms: MaybeSequence[MaybeSequence[int]] = ((),),
filter_atoms: MaybeSequence[MaybeSequence[Optional[int]]] = ((None),),
) -> NDArray[float64]:
"""
Compute the velocity autocorrelation function (VAF) of `data`.
Expand All @@ -209,7 +209,7 @@ def compute_vaf(
index : SliceLike
Images to analyze as `start`, `stop`, `step`.
Default is all images.
filter_atoms : MaybeSequence[MaybeSequence[int]]
filter_atoms : MaybeSequence[MaybeSequence[Optional[int]]]
Compute the VAF averaged over subsets of the system.
Default is all atoms.
Expand Down Expand Up @@ -238,9 +238,7 @@ def compute_vaf(
data = data[slice(*index)]

if use_velocities:
momenta = np.asarray(
[datum.get_momenta() / datum.get_masses() for datum in data]
)
momenta = np.asarray([datum.get_velocities() for datum in data])
else:
momenta = np.asarray([datum.get_momenta() for datum in data])

Expand All @@ -249,7 +247,7 @@ def compute_vaf(

# If filter_atoms not specified use all atoms
filter_atoms = [
atom if atom and atom[0] else range(n_atoms) for atom in filter_atoms
atom if atom[0] is not None else range(n_atoms) for atom in filter_atoms
]

used_atoms = {atom for atoms in filter_atoms for atom in atoms}
Expand All @@ -259,7 +257,9 @@ def compute_vaf(
np.asarray(
[
[
np.correlate(momenta[:, j, i], momenta[:, j, i], "same")
np.correlate(momenta[:, j, i], momenta[:, j, i], "full")[
n_steps - 1 :
]
for i in range(3)
]
for j in used_atoms
Expand Down
11 changes: 11 additions & 0 deletions tests/data/vaf-lj-1-2-3.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
1.0066870750448105
0.6418190953234829
0.3563911120413356
0.3051070730422549
0.23009251840934272
0.26906890724754023
0.24509942703971466
0.15062229045327677
-0.23061920240560663
-0.42160528271693426
-0.20524096931167723
11 changes: 11 additions & 0 deletions tests/data/vaf-lj-3-4.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
1.1263716900873755
0.8820158202526532
0.784347065477728
0.6806780688236092
0.5265986467643831
0.5295303067488114
0.5150663396964206
0.5188617483216265
0.3239983242603667
0.19859295298597962
0.058499204536948346
11 changes: 11 additions & 0 deletions tests/data/vaf-lj.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
0.940392953490707
0.581286162528787
0.3474623957116591
0.19574532179574347
0.08404080467605056
0.02723301143121412
0.03320823810966539
-0.0004815372548760061
-0.05545244283596119
-0.06580258912853633
-0.09743930723331182
19 changes: 17 additions & 2 deletions tests/test_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from ase.io import read
import numpy as np
from pytest import approx
from typer.testing import CliRunner

from janus_core.calculations.md import NVE
Expand Down Expand Up @@ -172,23 +173,37 @@ def test_rdf_by_elements():
assert (np.isclose(expected_peaks[element], rdf[0][peaks])).all()


def test_vaf():
def test_vaf(tmp_path):
"""Test vaf will run."""
vaf_names = ("vaf-lj-3-4.dat", "vaf-lj-1-2-3.dat")
vaf_filter = ((3, 4), (1, 2, 3))

data = read(DATA_PATH / "lj-traj.xyz", index=":")
vaf = post_process.compute_vaf(data)
expected = np.loadtxt(DATA_PATH / "vaf-lj.dat")

assert isinstance(vaf, list)
assert len(vaf) == 1
assert isinstance(vaf[0], np.ndarray)
assert vaf[0] == approx(expected, rel=1e-9)

vaf = post_process.compute_vaf(data, fft=True)

assert isinstance(vaf, list)
assert len(vaf) == 1
assert isinstance(vaf[0], np.ndarray)

vaf = post_process.compute_vaf(data, filter_atoms=((3, 4), (1, 2, 3)))
vaf = post_process.compute_vaf(
data, filter_atoms=vaf_filter, filenames=[tmp_path / name for name in vaf_names]
)

assert isinstance(vaf, list)
assert len(vaf) == 2
assert isinstance(vaf[0], np.ndarray)

for i, name in enumerate(vaf_names):
assert (tmp_path / name).exists()
expected = np.loadtxt(DATA_PATH / name)
written = np.loadtxt(tmp_path / name)
assert vaf[i] == approx(expected, rel=1e-9)
assert vaf[i] == approx(written, rel=1e-9)

0 comments on commit 8175ce7

Please sign in to comment.