Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Velocity, ShearStress, StressDiagonal, observables #285

Open
wants to merge 37 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a719731
Adds component Mixin and Velocity osbervable
harveydevereux Aug 8, 2024
33ec8f5
Support SliceLike as atoms in Velocity
harveydevereux Sep 13, 2024
71afe8f
Move Observable into observables.py
harveydevereux Oct 17, 2024
626fbe2
Remove getter
harveydevereux Oct 22, 2024
d706c92
USe annotations
harveydevereux Oct 22, 2024
772f198
Create SliceLike utils
harveydevereux Oct 22, 2024
d0367cd
Clarify values
harveydevereux Oct 23, 2024
e5ce55d
Remove args and kwargs
harveydevereux Nov 1, 2024
a1a889b
typo
harveydevereux Nov 1, 2024
0726349
Manually set __module__
harveydevereux Nov 1, 2024
c30d923
move to fix pre-commit
harveydevereux Nov 1, 2024
ffeca39
Apply suggestions from code review
harveydevereux Nov 1, 2024
518736d
abstract method, fix averaging
harveydevereux Nov 1, 2024
3c486b5
fix import
harveydevereux Nov 4, 2024
143d082
fix import
harveydevereux Nov 4, 2024
34f57eb
ignore unfound refs
harveydevereux Nov 4, 2024
b41034f
move import
harveydevereux Nov 4, 2024
2e82eba
rebase for vaf lags
harveydevereux Nov 4, 2024
773598c
Apply suggestions from code review
harveydevereux Nov 4, 2024
e2a46a6
Rename builtins, multi-line error msg
harveydevereux Nov 4, 2024
514d81b
remove uneeded property
harveydevereux Nov 4, 2024
18e5f35
restore spacing
harveydevereux Nov 5, 2024
aba3aea
Update developer guide
harveydevereux Nov 5, 2024
8c0e9c5
fix error msg
harveydevereux Nov 5, 2024
0f5b648
fix len_for
harveydevereux Nov 5, 2024
97444de
CorrelationKwargs import Observable directly
harveydevereux Nov 5, 2024
ef9d3bd
Simplify Stress __call__
harveydevereux Nov 5, 2024
2e7a95d
Fix typing, use | in janus_types
harveydevereux Nov 6, 2024
bc3b6ad
fix atoms_slice, parse in Observable
harveydevereux Nov 6, 2024
4f1c2c2
split slc_len in test_selector_len
harveydevereux Nov 6, 2024
03ce113
minimal exclusions, fix typing
harveydevereux Nov 6, 2024
5c36b3d
expand on doc
harveydevereux Nov 6, 2024
874d849
Add slicelike validator
harveydevereux Nov 8, 2024
098a4da
Apply suggestions from code review
harveydevereux Nov 8, 2024
93f75ef
Remove value_count, update dev guide
harveydevereux Nov 8, 2024
d1d7e47
Check atoms is Atoms, clearer exception
harveydevereux Nov 8, 2024
6c30712
remove uneeded
harveydevereux Nov 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@
("py:class", "Architectures"),
("py:class", "Devices"),
("py:class", "MaybeSequence"),
("py:class", "SliceLike"),
("py:class", "PathLike"),
("py:class", "Atoms"),
("py:class", "Calculator"),
Expand Down
81 changes: 74 additions & 7 deletions docs/source/developer_guide/tutorial.rst
ElliottKasoar marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -187,20 +187,87 @@ Alternatively, using ``tox``::
Adding a new Observable
=======================

Additional built-in observable quantities may be added for use by the ``janus_core.processing.correlator.Correlation`` class. These should conform to the ``__call__`` signature of ``janus_core.helpers.janus_types.Observable``. For a user this can be accomplished by writing a function, or class also implementing a commensurate ``__call__``.
A ``janus_core.processing.observables.Observable`` abstracts obtaining a quantity derived from ``Atoms``. They may be used as kernels for input into analysis such as a correlation.

Built-in observables are collected within the ``janus_core.processing.observables`` module. For example the ``janus_core.processing.observables.Stress`` observable allows a user to quickly setup a given correlation of stress tensor components (with and without the ideal gas contribution). An observable for the ``xy`` component is obtained without the ideal gas contribution as:
Additional built-in observable quantities may be added for use by the ``janus_core.processing.correlator.Correlation`` class. These should extend ``janus_core.processing.observables.Observable`` and are implemented within the ``janus_core.processing.observables`` module.

The abstract method ``__call__`` should be implemented to obtain the values of the observed quantity from an ``Atoms`` object. When used as part of a ``janus_core.processing.correlator.Correlation``, each value will be correlated and the results averaged.

As an example of building a new ``Observable`` consider the ``janus_core.processing.observables.Stress`` built-in. The following steps may be taken:

1. Defining the observable.
---------------------------

The stress tensor may be computed on an atoms object using ``Atoms.get_stress``. A user may wish to obtain a particular component, or perhaps only compute the stress on some subset of ``Atoms``. For example during a ``janus_core.calculations.md.MolecularDynamics`` run a user may wish to correlate only the off-diagonal components (shear stress), computed across all atoms.

2. Writing the ``__call__`` method.
-----------------------------------

In the call method we can use the base ``janus_core.processing.observables.Observable``'s optional atom selector ``atoms_slice`` to first define the subset of atoms to compute the stress for:

.. code-block:: python

Stress("xy", False)
def __call__(self, atoms: Atoms) -> list[float]:
sliced_atoms = atoms[self.atoms_slice]
# must be re-attached after slicing for get_stress
sliced_atoms.calc = atoms.calc

A new built-in observables can be implemented by a class with the method:
Next the stresses may be obtained from:

.. code-block:: python

def __call__(self, atoms: Atoms, *args, **kwargs) -> float
stresses = (
sliced_atoms.get_stress(
include_ideal_gas=self.include_ideal_gas, voigt=True
)
/ units.GPa
)

Finally, to facilitate handling components in a symbolic way, ``janus_core.processing.observables.ComponentMixin`` exists to parse ``str`` symbolic components to ``int`` indices by defining a suitable mapping. For the stress tensor (and the format of ``Atoms.get_stress``) a suitable mapping is defined in ``janus_core.processing.observables.Stress``'s ``__init__`` method:

The ``__call__`` should contain all the logic for obtaining some ``float`` value from an ``Atoms`` object, alongside optional positional arguments and kwargs. The args and kwargs are set by a user when specifying correlations for a ``janus_core.calculations.md.MolecularDynamics`` run. See also ``janus_core.helpers.janus_types.CorrelationKwargs``. These are set at the instantiation of the ``janus_core.calculations.md.MolecularDynamics`` object and are not modified. These could be used e.g. to specify an observable calculated only from one atom's data.
.. code-block:: python

ComponentMixin.__init__(
self,
components={
"xx": 0,
"yy": 1,
"zz": 2,
"yz": 3,
"zy": 3,
"xz": 4,
"zx": 4,
"xy": 5,
"yx": 5,
},
)

This then concludes the ``__call__`` method for ``janus_core.processing.observables.Stress`` by using ``janus_core.processing.observables.ComponentMixin``'s
pre-calculated indices:

.. code-block:: python

return stesses[self._indices]

The combination of the above means a user may obtain, say, the ``xy`` and ``zy`` stress tensor components over odd-indexed atoms by calling the following observable on an ``Atoms``:

.. code-block:: python

s = Stress(components=["xy", "zy"], atoms_slice=(0, None, 2))


Since usually total system stresses are required we can define two built-ins to handle the shear and hydrostatic stresses like so:

.. code-block:: python

StressHydrostatic = Stress(components=["xx", "yy", "zz"])
StressShear = Stress(components=["xy", "yz", "zx"])

Where by default ``janus_core.processing.observables.Observable``'s ``atoms_slice`` is ``slice(0, None, 1)``, which expands to all atoms in an ``Atoms``.

For comparison the ``janus_core.processing.observables.Velocity`` built-in's ``__call__`` not only returns atom velocity for the requested components, but also returns them for every tracked atom i.e:

.. code-block:: python

``janus_core.processing.observables.Stress`` includes a constructor to take a symbolic component, e.g. ``"xx"`` or ``"yz"``, and determine the index required from ``ase.Atoms.get_stress`` on instantiation for ease of use.
def __call__(self, atoms: Atoms) -> list[float]:
return atoms.get_velocities()[self.atoms_slice, :][:, self._indices].flatten()
5 changes: 4 additions & 1 deletion janus_core/calculations/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,10 @@ def _restart_file(self) -> str:
def _parse_correlations(self) -> None:
"""Parse correlation kwargs into Correlations."""
if self.correlation_kwargs:
self._correlations = [Correlation(**cor) for cor in self.correlation_kwargs]
self._correlations = [
Correlation(n_atoms=self.n_atoms, **cor)
for cor in self.correlation_kwargs
]
else:
self._correlations = ()

Expand Down
37 changes: 6 additions & 31 deletions janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,16 @@
from enum import Enum
import logging
from pathlib import Path, PurePath
from typing import (
IO,
Literal,
Optional,
Protocol,
TypedDict,
TypeVar,
Union,
runtime_checkable,
)
from typing import IO, TYPE_CHECKING, Literal, Optional, TypedDict, TypeVar, Union

from ase import Atoms
from ase.eos import EquationOfState
import numpy as np
from numpy.typing import NDArray

if TYPE_CHECKING:
from janus_core.processing.observables import Observable

# General

T = TypeVar("T")
Expand Down Expand Up @@ -86,32 +80,13 @@ class PostProcessKwargs(TypedDict, total=False):
vaf_output_file: PathLike | None


@runtime_checkable
class Observable(Protocol):
"""Signature for correlation observable getter."""

def __call__(self, atoms: Atoms, *args, **kwargs) -> float:
"""
Call the getter.

Parameters
----------
atoms : Atoms
Atoms object to extract values from.
*args : tuple
Additional positional arguments passed to getter.
**kwargs : dict
Additional kwargs passed getter.
"""


class CorrelationKwargs(TypedDict, total=True):
"""Arguments for on-the-fly correlations <ab>."""

#: observable a in <ab>, with optional args and kwargs
a: Observable | tuple[Observable, tuple, dict]
a: Observable
#: observable b in <ab>, with optional args and kwargs
b: Observable | tuple[Observable, tuple, dict]
b: Observable
#: name used for correlation in output
name: str
#: blocks used in multi-tau algorithm
Expand Down
87 changes: 86 additions & 1 deletion janus_core/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
)
from rich.style import Style

from janus_core.helpers.janus_types import MaybeSequence, PathLike
from janus_core.helpers.janus_types import (
MaybeSequence,
PathLike,
SliceLike,
StartStopStep,
)


class FileNameMixin(ABC): # noqa: B024 (abstract-base-class-without-abstract-method)
Expand Down Expand Up @@ -409,3 +414,83 @@ def track_progress(sequence: Sequence | Iterable, description: str) -> Iterable:

with progress:
yield from progress.track(sequence, description=description)


def validate_slicelike(maybe_slicelike: SliceLike) -> None:
"""
Raise an exception if slc is not a valid SliceLike.

Parameters
----------
maybe_slicelike : SliceLike
Candidate to test.

Raises
------
ValueError
If maybe_slicelike is not SliceLike.
"""
if isinstance(maybe_slicelike, (slice, range, int)):
return
if isinstance(maybe_slicelike, tuple) and len(maybe_slicelike) == 3:
start, stop, step = maybe_slicelike
if (
(start is None or isinstance(start, int))
and (stop is None or isinstance(stop, int))
and isinstance(step, int)
):
return

raise ValueError(f"{maybe_slicelike} is not a valid SliceLike")


def slicelike_to_startstopstep(index: SliceLike) -> StartStopStep:
"""
Standarize `SliceLike`s into tuple of `start`, `stop`, `step`.

Parameters
----------
index : SliceLike
`SliceLike` to standardize.

Returns
-------
StartStopStep
Standardized `SliceLike` as `start`, `stop`, `step` triplet.
"""
validate_slicelike(index)
if isinstance(index, int):
if index == -1:
return (index, None, 1)
return (index, index + 1, 1)
ElliottKasoar marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(index, (slice, range)):
return (index.start, index.stop, index.step)

return index
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth checking if index is a StartStopStep (or any other valid input)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean throwing an error if it is not actually a SliceLike?

i.e. these cases are no good but "run"

>>> slicelike_to_startstopstep([1,2])
[1, 2]
>>> slicelike_to_startstopstep((None, None, None))
(None, None, None)

its probably a good move

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could do

def validate_slicelike(maybe_slicelike: SliceLike):
    # ...
    if isinstance(maybe_slicelike, (slice, range, int)):
        return
    if isinstance(maybe_slicelike, tuple) and len(maybe_slicelike) == 3:
        start, stop, step = maybe_slicelike
        if (
            isinstance(start, Optional[int])
            and isinstance(stop, Optional[int])
            and isinstance(step, int)
        ):
            return

    raise ValueError(f"{maybe_slicelike} is not a valid SliceLike")

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that sort of thing would be great



def selector_len(slc: SliceLike | list, selectable_length: int) -> int:
"""
Calculate the length of a selector applied to an indexable of a given length.

Parameters
----------
slc : Union[SliceLike, list]
The applied SliceLike or list for selection.
selectable_length : int
The length of the selectable object.

Returns
-------
int
Length of the result of applying slc.
"""
if isinstance(slc, int):
return 1
if isinstance(slc, list):
return len(slc)
start, stop, step = slicelike_to_startstopstep(slc)
if stop is None:
stop = selectable_length
return len(range(start, stop, step))
67 changes: 38 additions & 29 deletions janus_core/processing/correlator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ase import Atoms
import numpy as np

from janus_core.helpers.janus_types import Observable
from janus_core.processing.observables import Observable


class Correlator:
Expand Down Expand Up @@ -179,10 +179,12 @@ class Correlation:

Parameters
----------
a : tuple[Observable, dict]
Getter for a and kwargs.
b : tuple[Observable, dict]
Getter for b and kwargs.
n_atoms : int
Number of possible atoms to track.
a : Observable
Observable for a.
b : Observable
Observable for b.
name : str
Name of correlation.
blocks : int
Expand All @@ -197,8 +199,10 @@ class Correlation:

def __init__(
self,
a: Observable | tuple[Observable, tuple, dict],
b: Observable | tuple[Observable, tuple, dict],
*,
n_atoms: int,
a: Observable,
b: Observable,
name: str,
blocks: int,
points: int,
Expand All @@ -210,10 +214,12 @@ def __init__(

Parameters
----------
a : tuple[Observable, tuple, dict]
Getter for a and kwargs.
b : tuple[Observable, tuple, dict]
Getter for b and kwargs.
n_atoms : int
Number of possible atoms to track.
a : Observable
Observable for a.
b : Observable
Observable for b.
name : str
Name of correlation.
blocks : int
Expand All @@ -226,19 +232,13 @@ def __init__(
Frequency to update the correlation, md steps.
"""
self.name = name
if isinstance(a, tuple):
self._get_a, self._a_args, self._a_kwargs = a
else:
self._get_a = a
self._a_args, self._a_kwargs = (), {}

if isinstance(b, tuple):
self._get_b, self._b_args, self._b_kwargs = b
else:
self._get_b = b
self._b_args, self._b_kwargs = (), {}
self.blocks = blocks
self.points = points
self.averaging = averaging
self._get_a = a
self._get_b = b

self._correlator = Correlator(blocks=blocks, points=points, averaging=averaging)
self._correlators = None
self._update_frequency = update_frequency

@property
Expand All @@ -262,14 +262,20 @@ def update(self, atoms: Atoms) -> None:
atoms : Atoms
Atoms object to observe values from.
"""
self._correlator.update(
self._get_a(atoms, *self._a_args, **self._a_kwargs),
self._get_b(atoms, *self._b_args, **self._b_kwargs),
)
value_pairs = zip(self._get_a(atoms), self._get_b(atoms))
if self._correlators is None:
self._correlators = [
Correlator(
blocks=self.blocks, points=self.points, averaging=self.averaging
)
for _ in range(len(self._get_a(atoms)))
]
for corr, values in zip(self._correlators, value_pairs):
corr.update(*values)

def get(self) -> tuple[Iterable[float], Iterable[float]]:
"""
Get the correlation value and lags.
Get the correlation value and lags, averaging over atoms if applicable.

Returns
-------
Expand All @@ -278,7 +284,10 @@ def get(self) -> tuple[Iterable[float], Iterable[float]]:
lags : Iterable[float]]
The correlation lag times t'.
"""
return self._correlator.get()
if self._correlators:
_, lags = self._correlators[0].get()
return np.mean([cor.get()[0] for cor in self._correlators], axis=0), lags
return [], []

def __str__(self) -> str:
"""
Expand Down
Loading