Skip to content

Commit

Permalink
Use numbers types and fix Parameter cache
Browse files Browse the repository at this point in the history
  • Loading branch information
loganbvh committed Apr 12, 2024
1 parent 2bd0830 commit c12258f
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 36 deletions.
7 changes: 4 additions & 3 deletions tdgl/device/device.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import numbers
import os
import time
from contextlib import contextmanager, nullcontext
Expand Down Expand Up @@ -416,7 +417,7 @@ def scale(
if not (
isinstance(origin, tuple)
and len(origin) == 2
and all(isinstance(val, (int, float)) for val in origin)
and all(isinstance(val, numbers.Real) for val in origin)
):
raise TypeError("Origin must be a tuple of floats (x, y).")
self._warn_if_mesh_exist("scale()")
Expand Down Expand Up @@ -447,7 +448,7 @@ def rotate(self, degrees: float, origin: Tuple[float, float] = (0, 0)) -> "Devic
if not (
isinstance(origin, tuple)
and len(origin) == 2
and all(isinstance(val, (int, float)) for val in origin)
and all(isinstance(val, numbers.Real) for val in origin)
):
raise TypeError("Origin must be a tuple of floats (x, y).")
self._warn_if_mesh_exist("rotate()")
Expand Down Expand Up @@ -582,7 +583,7 @@ def _create_dimensionless_mesh(
create_submesh=True,
)

def mesh_stats_dict(self) -> Dict[str, Union[int, float, str]]:
def mesh_stats_dict(self) -> Dict[str, Union[numbers.Real, str]]:
"""Returns a dictionary of information about the mesh."""
edge_lengths = self.edge_lengths
areas = self.areas
Expand Down
58 changes: 34 additions & 24 deletions tdgl/parameter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import hashlib
import inspect
import operator
from numbers import Number
from typing import Callable, Optional, Union

import cloudpickle
Expand Down Expand Up @@ -67,9 +68,8 @@ class Parameter:
as a function of position coordinates x, y (and optionally z and time t).
Addition, subtraction, multiplication, and division
between multiple Parameters and/or real numbers (ints and floats)
is supported. The result of any of these operations is a
``CompositeParameter`` object.
between multiple Parameters and/or numbers is supported.
The result of any of these operations is a ``CompositeParameter`` object.
Args:
func: A callable/function that actually calculates the parameter's value.
Expand Down Expand Up @@ -155,11 +155,11 @@ def _to_tuple(items):

def _evaluate(
self,
x: Union[int, float, np.ndarray],
y: Union[int, float, np.ndarray],
z: Optional[Union[int, float, np.ndarray]] = None,
x: Union[Number, np.ndarray],
y: Union[Number, np.ndarray],
z: Optional[Union[Number, np.ndarray]] = None,
t: Optional[float] = None,
) -> Union[int, float, np.ndarray]:
) -> Union[Number, np.ndarray]:
kwargs = self.kwargs.copy()
if t is not None:
kwargs["t"] = t
Expand All @@ -173,18 +173,21 @@ def _evaluate(

def __call__(
self,
x: Union[int, float, np.ndarray],
y: Union[int, float, np.ndarray],
z: Optional[Union[int, float, np.ndarray]] = None,
x: Union[Number, np.ndarray],
y: Union[Number, np.ndarray],
z: Optional[Union[Number, np.ndarray]] = None,
t: Optional[float] = None,
) -> Union[int, float, np.ndarray]:
) -> Union[Number, np.ndarray]:
if self._use_cache:
cache_key = self._hash_args(x, y, z, t)
if cache_key not in self._cache:
self._cache[cache_key] = self._evaluate(x, y, z, t)
return self._cache[cache_key]
return self._evaluate(x, y, z, t)

def _clear_cache(self) -> None:
self._cache.clear()

def _get_argspec(self) -> _FakeArgSpec:
if self.kwargs:
kwargs, kwarg_values = list(zip(*self.kwargs.items()))
Expand Down Expand Up @@ -279,12 +282,11 @@ class CompositeParameter(Parameter):
(i.e. it computes a scalar or vector quantity as a function of
position coordinates x, y, z). A CompositeParameter object is created as
a result of mathematical operations between Parameters, CompositeParameters,
and/or real numbers.
and/or numbers.
Addition, subtraction, multiplication, division, and exponentiation
between Parameters, CompositeParameters and real numbers (ints and floats)
are supported. The result of any of these operations is a new
CompositeParameter object.
between ``Parameters``, ``CompositeParameters`` and numbers are supported.
The result of any of these operations is a new ``CompositeParameter`` object.
Args:
left: The object on the left-hand side of the operator.
Expand All @@ -302,11 +304,11 @@ class CompositeParameter(Parameter):

def __init__(
self,
left: Union[int, float, Parameter, "CompositeParameter"],
right: Union[int, float, Parameter, "CompositeParameter"],
left: Union[Number, Parameter, "CompositeParameter"],
right: Union[Number, Parameter, "CompositeParameter"],
operator_: Union[Callable, str],
):
valid_types = (int, float, complex, Parameter, CompositeParameter)
valid_types = (Number, Parameter, CompositeParameter)
if not isinstance(left, valid_types):
raise TypeError(
f"Left must be a number, Parameter, or CompositeParameter, "
Expand All @@ -317,7 +319,7 @@ def __init__(
f"Right must be a number, Parameter, or CompositeParameter, "
f"not {type(right)!r}."
)
if isinstance(left, (int, float)) and isinstance(right, (int, float)):
if isinstance(left, Number) and isinstance(right, Number):
raise TypeError(
"Either left or right must be a Parameter or CompositeParameter."
)
Expand All @@ -329,6 +331,7 @@ def __init__(
f"Unknown operator, {operator_!r}. "
f"Valid operators are {list(self.VALID_OPERATORS)!r}."
)
self._cache = {}
self.left = left
self.right = right
self.operator = operator_
Expand All @@ -342,13 +345,20 @@ def __init__(
if self.right._use_cache is None:
self.right._use_cache = True

Check warning on line 346 in tdgl/parameter.py

View check run for this annotation

Codecov / codecov/patch

tdgl/parameter.py#L346

Added line #L346 was not covered by tests

def _clear_cache(self) -> None:
self._cache.clear()
if isinstance(self.right._cache, Parameter):
self.right._clear_cache()

Check warning on line 351 in tdgl/parameter.py

View check run for this annotation

Codecov / codecov/patch

tdgl/parameter.py#L351

Added line #L351 was not covered by tests
if isinstance(self.left, Parameter):
self.left._clear_cache()

def __call__(
self,
x: Union[int, float, np.ndarray],
y: Union[int, float, np.ndarray],
z: Optional[Union[int, float, np.ndarray]] = None,
x: Union[Number, np.ndarray],
y: Union[Number, np.ndarray],
z: Union[Number, np.ndarray, None] = None,
t: Optional[float] = None,
) -> Union[int, float, np.ndarray]:
) -> Union[Number, np.ndarray]:
kwargs = dict() if t is None else dict(t=t)
values = []
for operand in (self.left, self.right):
Expand Down Expand Up @@ -413,7 +423,7 @@ def __setstate__(self, state):
class Constant(Parameter):
"""A Parameter whose value doesn't depend on position or time."""

def __init__(self, value: Union[int, float, complex], dimensions: int = 2):
def __init__(self, value: Number, dimensions: int = 2):
if dimensions not in (2, 3):
raise ValueError(f"Dimensions must be 2 or 3, got {dimensions}.")
if dimensions == 2:
Expand Down
5 changes: 3 additions & 2 deletions tdgl/solution/solution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import logging
import numbers
import operator
import os
import shutil
Expand Down Expand Up @@ -720,7 +721,7 @@ def field_at_position(
)
zs = positions[:, 2]
positions = positions[:, :2]
elif isinstance(zs, (int, float, np.generic)):
elif isinstance(zs, numbers.Real):
# constant zs
zs = zs * np.ones(len(positions))
zs = zs.squeeze()
Expand Down Expand Up @@ -820,7 +821,7 @@ def vector_potential_at_position(
)
zs = positions[:, 2]
positions = positions[:, :2]
elif isinstance(zs, (int, float, np.generic)):
elif isinstance(zs, numbers.Real):
# constant zs
zs = zs * np.ones(len(positions))
if not isinstance(zs, np.ndarray):
Expand Down
3 changes: 2 additions & 1 deletion tdgl/solver/runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import logging
import numbers
import os
import subprocess
import sys
Expand Down Expand Up @@ -153,7 +154,7 @@ def save_fixed_values(self, fixed_data: Dict[str, np.ndarray]) -> None:

def save_time_step(
self,
state: Dict[str, Union[int, float]],
state: Dict[str, numbers.Real],
data: Dict[str, np.ndarray],
running_state: Union[Dict[str, np.ndarray], None],
) -> None:
Expand Down
11 changes: 6 additions & 5 deletions tdgl/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import itertools
import logging
import math
import numbers
import os
from datetime import datetime
from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
Expand Down Expand Up @@ -216,9 +217,9 @@ def disorder_epsilon(r):

# Clear the Parameter caches
if isinstance(self.applied_vector_potential, Parameter):
self.applied_vector_potential._cache.clear()
self.applied_vector_potential._clear_cache()
if isinstance(self.disorder_epsilon, Parameter):
self.disorder_epsilon._cache.clear()
self.disorder_epsilon._clear_cache()

Check warning on line 222 in tdgl/solver/solver.py

View check run for this annotation

Codecov / codecov/patch

tdgl/solver/solver.py#L222

Added line #L222 was not covered by tests

# Find the current terminal sites.
self.terminal_info = device.terminal_info()
Expand Down Expand Up @@ -578,7 +579,7 @@ def get_induced_vector_potential(

def update(
self,
state: Dict[str, Union[int, float]],
state: Dict[str, numbers.Real],
running_state: RunningState,
dt: float,
*,
Expand Down Expand Up @@ -807,9 +808,9 @@ def solve(self) -> Optional[Solution]:

# Clear the Parameter caches
if isinstance(self.applied_vector_potential, Parameter):
self.applied_vector_potential._cache.clear()
self.applied_vector_potential._clear_cache()
if isinstance(self.disorder_epsilon, Parameter):
self.disorder_epsilon._cache.clear()
self.disorder_epsilon._clear_cache()

Check warning on line 813 in tdgl/solver/solver.py

View check run for this annotation

Codecov / codecov/patch

tdgl/solver/solver.py#L813

Added line #L813 was not covered by tests

solution = None
if data_was_generated:
Expand Down
3 changes: 2 additions & 1 deletion tdgl/visualization/snapshot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numbers
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union

import h5py
Expand Down Expand Up @@ -48,7 +49,7 @@ def generate_snapshots(
Returns:
The matplotlib figure and axes for each time in ``times``
"""
if isinstance(times, (int, float)):
if isinstance(times, numbers.Real):
times = [times]
if quantities is None:
quantities = Quantity.get_keys()
Expand Down

0 comments on commit c12258f

Please sign in to comment.