Skip to content

Commit

Permalink
Fix tests and respond to comments
Browse files Browse the repository at this point in the history
  • Loading branch information
oerc0122 committed Nov 4, 2024
1 parent fe7ad20 commit ce710bb
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
26 changes: 24 additions & 2 deletions janus_core/processing/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def compute_vaf(
index: SliceLike = (0, None, 1),
filter_atoms: MaybeSequence[MaybeSequence[Optional[int]]] = ((None),),
time_step: float = 1.0,
) -> NDArray[float64]:
) -> tuple[NDArray[float64], list[NDArray[float64]]]:
"""
Compute the velocity autocorrelation function (VAF) of `data`.
Expand Down Expand Up @@ -219,8 +219,30 @@ def compute_vaf(
Returns
-------
MaybeSequence[NDArray[float64]]
lags : numpy.ndarray
Lags at which the VAFs have been computed.
vafs : list[numpy.ndarray]
Computed VAF(s).
Notes
-----
`filter_atoms` is given as a series of sequences of atoms, where
each element in the series denotes a VAF subset to calculate and
each sequence determines the atoms (by index) to be included in that VAF.
E.g.
.. code-block: Python
# Species indices in cell
na = (1, 3, 5, 7)
cl = (2, 4, 6, 8)
compute_vaf(..., filter_atoms=(na, cl))
Would compute separate VAFs for each species.
By default, one VAF will be computed for all atoms in the structure.
"""
# Ensure if passed scalars they are turned into correct dimensionality
if not isinstance(filter_atoms, Sequence):
Expand Down
11 changes: 7 additions & 4 deletions tests/test_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,21 +179,21 @@ def test_vaf(tmp_path):
vaf_filter = ((3, 4), (1, 2, 3))

data = read(DATA_PATH / "lj-traj.xyz", index=":")
vaf = post_process.compute_vaf(data)
lags, vaf = post_process.compute_vaf(data)
expected = np.loadtxt(DATA_PATH / "vaf-lj.dat")

assert isinstance(vaf, list)
assert len(vaf) == 1
assert isinstance(vaf[0], np.ndarray)
assert vaf[0] == approx(expected, rel=1e-9)

vaf = post_process.compute_vaf(data, fft=True)
lags, vaf = post_process.compute_vaf(data, fft=True)

assert isinstance(vaf, list)
assert len(vaf) == 1
assert isinstance(vaf[0], np.ndarray)

vaf = post_process.compute_vaf(
lags, vaf = post_process.compute_vaf(
data, filter_atoms=vaf_filter, filenames=[tmp_path / name for name in vaf_names]
)

Expand All @@ -205,5 +205,8 @@ def test_vaf(tmp_path):
assert (tmp_path / name).exists()
expected = np.loadtxt(DATA_PATH / name)
written = np.loadtxt(tmp_path / name)
w_lag, w_vaf = written[:, 0], written[:, 1]

assert vaf[i] == approx(expected, rel=1e-9)
assert vaf[i] == approx(written, rel=1e-9)
assert lags == approx(w_lag, rel=1e-9)
assert vaf[i] == approx(w_vaf, rel=1e-9)

0 comments on commit ce710bb

Please sign in to comment.