Skip to content

Commit

Permalink
Add lags to vafs (#318)
Browse files Browse the repository at this point in the history
* Add lags to vafs

---------

Co-authored-by: ElliottKasoar <[email protected]>
  • Loading branch information
oerc0122 and ElliottKasoar authored Nov 4, 2024
1 parent 15a6b49 commit 8678f94
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 13 deletions.
51 changes: 42 additions & 9 deletions janus_core/processing/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ def compute_vaf(
fft: bool = False,
index: SliceLike = (0, None, 1),
filter_atoms: MaybeSequence[MaybeSequence[Optional[int]]] = ((None),),
) -> NDArray[float64]:
time_step: float = 1.0,
) -> tuple[NDArray[float64], list[NDArray[float64]]]:
"""
Compute the velocity autocorrelation function (VAF) of `data`.
Expand All @@ -212,11 +213,36 @@ def compute_vaf(
filter_atoms : MaybeSequence[MaybeSequence[Optional[int]]]
Compute the VAF averaged over subsets of the system.
Default is all atoms.
time_step : float
Time step for scaling lags to align with input data.
Default is 1 (i.e. no scaling).
Returns
-------
MaybeSequence[NDArray[float64]]
lags : numpy.ndarray
Lags at which the VAFs have been computed.
vafs : list[numpy.ndarray]
Computed VAF(s).
Notes
-----
`filter_atoms` is given as a series of sequences of atoms, where
each element in the series denotes a VAF subset to calculate and
each sequence determines the atoms (by index) to be included in that VAF.
E.g.
.. code-block: Python
# Species indices in cell
na = (1, 3, 5, 7)
cl = (2, 4, 6, 8)
compute_vaf(..., filter_atoms=(na, cl))
Would compute separate VAFs for each species.
By default, one VAF will be computed for all atoms in the structure.
"""
# Ensure if passed scalars they are turned into correct dimensionality
if not isinstance(filter_atoms, Sequence):
Expand Down Expand Up @@ -270,17 +296,24 @@ def compute_vaf(

vafs /= n_steps - np.arange(n_steps)

lags = np.arange(n_steps) * time_step

if fft:
vafs = np.fft.fft(vafs, axis=0)

vafs = [
np.average([vafs[used_atoms[i]] for i in atoms], axis=0)
for atoms in filter_atoms
]
lags = np.fft.fftfreq(n_steps, time_step)

vafs = (
lags,
[
np.average([vafs[used_atoms[i]] for i in atoms], axis=0)
for atoms in filter_atoms
],
)

if filenames:
for filename, vaf in zip(filenames, vafs):
for vaf, filename in zip(vafs[1], filenames):
with open(filename, "w", encoding="utf-8") as out_file:
print(*vaf, file=out_file, sep="\n")
for lag, dat in zip(lags, vaf):
print(lag, dat, file=out_file)

return vafs
11 changes: 7 additions & 4 deletions tests/test_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,21 +179,21 @@ def test_vaf(tmp_path):
vaf_filter = ((3, 4), (1, 2, 3))

data = read(DATA_PATH / "lj-traj.xyz", index=":")
vaf = post_process.compute_vaf(data)
lags, vaf = post_process.compute_vaf(data)
expected = np.loadtxt(DATA_PATH / "vaf-lj.dat")

assert isinstance(vaf, list)
assert len(vaf) == 1
assert isinstance(vaf[0], np.ndarray)
assert vaf[0] == approx(expected, rel=1e-9)

vaf = post_process.compute_vaf(data, fft=True)
lags, vaf = post_process.compute_vaf(data, fft=True)

assert isinstance(vaf, list)
assert len(vaf) == 1
assert isinstance(vaf[0], np.ndarray)

vaf = post_process.compute_vaf(
lags, vaf = post_process.compute_vaf(
data, filter_atoms=vaf_filter, filenames=[tmp_path / name for name in vaf_names]
)

Expand All @@ -205,5 +205,8 @@ def test_vaf(tmp_path):
assert (tmp_path / name).exists()
expected = np.loadtxt(DATA_PATH / name)
written = np.loadtxt(tmp_path / name)
w_lag, w_vaf = written[:, 0], written[:, 1]

assert vaf[i] == approx(expected, rel=1e-9)
assert vaf[i] == approx(written, rel=1e-9)
assert lags == approx(w_lag, rel=1e-9)
assert vaf[i] == approx(w_vaf, rel=1e-9)

0 comments on commit 8678f94

Please sign in to comment.