Skip to content

Commit

Permalink
Use matplotlib tri interpolators.
Browse files Browse the repository at this point in the history
This is much faster than scipy interpolators because we can use the existing Delaunay triangulation
  • Loading branch information
loganbvh committed Aug 3, 2023
1 parent 575fe21 commit 3b134be
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 72 deletions.
13 changes: 12 additions & 1 deletion tdgl/finite_volume/mesh.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union

import h5py
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.tri import Triangulation

from ..geometry import close_curve
from .edge_mesh import EdgeMesh
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(
self.dual_sites = dual_sites
self.edge_mesh = edge_mesh
self.voronoi_polygons = voronoi_polygons
self._triangulation: Optional[Triangulation] = None

@property
def x(self) -> np.ndarray:
Expand All @@ -71,6 +73,15 @@ def y(self) -> np.ndarray:
"""The y-coordinates of the mesh sites."""
return self.sites[:, 1]

@property
def triangulation(self) -> Triangulation:
"""Matplotlib triangulation of the mesh."""
if self._triangulation is None:
self._triangulation = Triangulation(
self.sites[:, 0], self.sites[:, 1], self.elements
)
return self._triangulation

def closest_site(self, xy: Tuple[float, float]) -> int:
"""Returns the index of the mesh site closest to ``(x, y)``.
Expand Down
26 changes: 14 additions & 12 deletions tdgl/solution/data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import dataclasses
import os
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union

import h5py
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
import numpy as np
from scipy import interpolate
from tqdm import tqdm

from ..finite_volume.mesh import Mesh
Expand Down Expand Up @@ -447,7 +447,7 @@ def get_current_through_paths(
solution_path: os.PathLike,
paths: Union[np.ndarray, List[np.ndarray]],
dataset: Optional[str] = None,
interp_method: str = "linear",
interp_method: Literal["linear", "cubic"] = "linear",
units: Optional[str] = None,
with_units: bool = True,
progress_bar: bool = True,
Expand Down Expand Up @@ -475,19 +475,18 @@ def get_current_through_paths(
solution = Solution.from_hdf5(solution_path)
device = solution.device
mesh = device.mesh
tri = mesh.triangulation
ureg = device.ureg

valid_methods = ("linear", "cubic")
if interp_method not in valid_methods:
raise ValueError(
f"Interpolation method must be one of {valid_methods} (got {interp_method})."
)
if interp_method == "linear":
interpolator = interpolate.LinearNDInterpolator
interp_kwargs = dict(fill_value=0)
else: # "cubic"
interpolator = interpolate.CloughTocher2DInterpolator
interp_kwargs = dict(fill_value=0)
interp_type = {
"linear": mtri.LinearTriInterpolator,
"cubic": mtri.CubicTriInterpolator,
}[interp_method]

valid_datasets = ("supercurrent", "normal_current", None)
if dataset not in valid_datasets:
Expand Down Expand Up @@ -516,7 +515,6 @@ def get_current_through_paths(

step_min, step_max = solution.data_range
times = solution.times
sites = device.points
raw_currents = [np.zeros_like(times) for _ in paths]
with h5py.File(solution_path, "r") as h5file:
for i in tqdm(
Expand All @@ -528,11 +526,15 @@ def get_current_through_paths(
else:
K = np.array(grp[dataset])
K = mesh.get_quantity_on_site(K)
K_interp = interpolator(sites, K, **interp_kwargs)
Kx_interp = interp_type(tri, K[:, 0])
Ky_interp = interp_type(tri, K[:, 1])
for j, (path, lengths, normals, ix) in enumerate(
zip(paths, edge_lengths, unit_normals, in_device)
):
K_path = K_interp(path)
Kx_path = Kx_interp(path[:, 0], path[:, 1]).data
Ky_path = Ky_interp(path[:, 0], path[:, 1]).data
K_path = np.array([Kx_path, Ky_path]).T
K_path[~np.isfinite(K_path).all(axis=1)] = 0
# Evaluate the sheet current at the edge centers
K_edge = (K_path[:-1] + K_path[1:]) / 2
K_dot_n = (K_edge * normals).sum(axis=1)
Expand Down
20 changes: 9 additions & 11 deletions tdgl/solution/plot_solution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Dict, List, Literal, Optional, Sequence, Tuple, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -77,7 +77,7 @@ def cross_section(
dataset_coords: np.ndarray,
dataset_values: np.ndarray,
cross_section_coords: Union[np.ndarray, Sequence[np.ndarray]],
interp_method: str = "linear",
interp_method: Literal["linear", "cubic"] = "linear",
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
"""Takes a cross-section of the specified dataset values along
a path given by the given dataset coordinates.
Expand All @@ -89,24 +89,22 @@ def cross_section(
cross_section_coords: A shape (m, 2) array of (x, y) coordinates specifying
the cross-section path (or a list of such arrays for multiple
cross sections).
interp_method: The interpolation method to use: "nearest", "linear", "cubic".
interp_method: The interpolation method to use: "linear" or "cubic".
Returns:
A list of coordinate arrays, a list of curvilinear coordinate (path) arrays,
and a list of cross section values.
"""
valid_methods = ("nearest", "linear", "cubic")
valid_methods = ("linear", "cubic")
if interp_method not in valid_methods:
raise ValueError(
f"Interpolation method must be one of {valid_methods} "
f"(got {interp_method})."
)
if interp_method == "nearest":
interpolator = interpolate.NearestNDInterpolator
elif interp_method == "linear":
interpolator = interpolate.LinearNDInterpolator
else: # "cubic"
interpolator = interpolate.CloughTocher2DInterpolator
interpolator = {
"linear": interpolate.LinearNDInterpolator,
"cubic": interpolate.CloughTocher2DInterpolator,
}[interp_method]

if not (isinstance(cross_section_coords, Sequence)):
cross_section_coords = [cross_section_coords]
Expand Down Expand Up @@ -641,7 +639,7 @@ def plot_current_through_paths(
solution_path: os.PathLike,
paths: Union[np.ndarray, List[np.ndarray]],
dataset: Optional[str] = None,
interp_method: str = "linear",
interp_method: Literal["linear", "cubic"] = "linear",
units: Optional[str] = None,
progress_bar: bool = True,
grid: bool = True,
Expand Down
87 changes: 41 additions & 46 deletions tdgl/solution/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
import shutil
from contextlib import nullcontext
from datetime import datetime
from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple, Union
from typing import Any, Callable, Dict, Literal, NamedTuple, Optional, Tuple, Union

import cloudpickle
import h5py
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
import numpy as np
import pint
from scipy import interpolate
Expand Down Expand Up @@ -336,7 +337,7 @@ def interp_current_density(
positions: np.ndarray,
*,
dataset: Union[str, None] = None,
method: str = "linear",
method: Literal["linear", "cubic"] = "linear",
units: Union[str, None] = None,
with_units: bool = False,
) -> np.ndarray:
Expand All @@ -352,8 +353,7 @@ def interp_current_density(
dataset: The dataset to interpolate. One of ``"supercurrent"``,
``"normal_current"``, or ``None``. If ``None``, then the total
sheet current density is used.
method: Interpolation method to use, ``"nearest"``, ``"linear"``,
or ``"cubic"``.
method: Interpolation method to use, ``"linear"`` or ``"cubic"``.
units: The desired units for the current density. Defaults to
``self.current_units / self.device.length_units``.
with_units: Whether to return a :class:`pint.Quantity` array
Expand All @@ -363,21 +363,6 @@ def interp_current_density(
The interpolated current density as an array of floats
or a :class:`pint.Quantity` array.
"""
valid_methods = ("nearest", "linear", "cubic")
if method not in valid_methods:
raise ValueError(
f"Interpolation method must be one of {valid_methods} (got {method})."
)
if method == "nearest":
interpolator = interpolate.NearestNDInterpolator
interp_kwargs = dict()
elif method == "linear":
interpolator = interpolate.LinearNDInterpolator
interp_kwargs = dict(fill_value=0)
else: # "cubic"
interpolator = interpolate.CloughTocher2DInterpolator
interp_kwargs = dict(fill_value=0)

if dataset is None:
J = self.current_density
elif dataset == "supercurrent":
Expand All @@ -389,11 +374,26 @@ def interp_current_density(

if units is None:
units = f"{self.current_units} / {self.device.length_units}"

valid_methods = ("linear", "cubic")
if method not in valid_methods:
raise ValueError(

Check warning on line 380 in tdgl/solution/solution.py

View check run for this annotation

Codecov / codecov/patch

tdgl/solution/solution.py#L380

Added line #L380 was not covered by tests
f"Interpolation method must be one of {valid_methods} (got {method})."
)
interp_type = {
"linear": mtri.LinearTriInterpolator,
"cubic": mtri.CubicTriInterpolator,
}[method]

positions = np.atleast_2d(positions)
xy = self.device.points
J_interp = interpolator(xy, J.to(units).magnitude, **interp_kwargs)
J = J_interp(positions)
J[~np.isfinite(J)] = 0
J = J.to(units).magnitude
tri = self.device.mesh.triangulation
Jx_interp = interp_type(tri, J[:, 0])
Jy_interp = interp_type(tri, J[:, 1])
Jx = Jx_interp(positions[:, 0], positions[:, 1]).data
Jy = Jy_interp(positions[:, 0], positions[:, 1]).data
J = np.array([Jx, Jy]).T
J[~np.isfinite(J).all(axis=1)] = 0
J[~self.device.contains_points(positions)] = 0
if with_units:
J = J * self.device.ureg(units)
Expand All @@ -402,43 +402,40 @@ def interp_current_density(
def interp_order_parameter(
self,
positions: np.ndarray,
method: str = "linear",
method: Literal["linear", "cubic"] = "linear",
) -> np.ndarray:
"""Interpolates the order parameter at unstructured coordinates.
Args:
positions: Shape ``(m, 2)`` array of x, y coordinates at which to evaluate
the order parameter.
method: Interpolation method to use, ``"nearest"``, ``"linear"``,
or ``"cubic"``.
method: Interpolation method to use, ``"linear"`` or ``"cubic"``.
Returns:
The interpolated order parameter.
"""
valid_methods = ("nearest", "linear", "cubic")
valid_methods = ("linear", "cubic")
if method not in valid_methods:
raise ValueError(
f"Interpolation method must be one of {valid_methods} (got {method})."
)
if method == "nearest":
interpolator = interpolate.NearestNDInterpolator
interp_kwargs = dict()
elif method == "linear":
interpolator = interpolate.LinearNDInterpolator
interp_kwargs = dict(fill_value=1)
else: # "cubic"
interpolator = interpolate.CloughTocher2DInterpolator
interp_kwargs = dict(fill_value=1)
interp_type = {
"linear": mtri.LinearTriInterpolator,
"cubic": mtri.CubicTriInterpolator,
}[method]
positions = np.atleast_2d(positions)
xy = self.device.points
tri = self.device.mesh.triangulation
psi = self.tdgl_data.psi
psi_interp = interpolator(xy, psi, **interp_kwargs)
return psi_interp(positions)
psi_interp_real = interp_type(tri, psi.real)
psi_interp_imag = interp_type(tri, psi.imag)
psi_real = psi_interp_real(positions[:, 0], positions[:, 1]).data
psi_imag = psi_interp_imag(positions[:, 0], positions[:, 1]).data
return psi_real + 1j * psi_imag

def polygon_fluxoid(
self,
polygon_points: Union[np.ndarray, Polygon],
interp_method: str = "linear",
interp_method: Literal["linear", "cubic"] = "linear",
units: str = "Phi_0",
with_units: bool = True,
) -> Fluxoid:
Expand Down Expand Up @@ -467,8 +464,7 @@ def polygon_fluxoid(
Args:
polygon_points: A shape ``(n, 2)`` array of ``(x, y)`` coordinates of
polygon vertices defining the closed region :math:`S`.
interp_method: Interpolation method to use, ``"nearest"``, ``"linear"``,
or ``"cubic"``.
interp_method: Interpolation method to use, ``"linear"`` or ``"cubic"``.
units: The desired units for the fluxoid.
with_units: Whether to return values as :class:`pint.Quantity` instances
with units attached.
Expand Down Expand Up @@ -526,7 +522,7 @@ def hole_fluxoid(
self,
hole_name: str,
points: Union[np.ndarray, None] = None,
interp_method: str = "linear",
interp_method: Literal["linear", "cubic"] = "linear",
units: str = "Phi_0",
with_units: bool = True,
) -> Fluxoid:
Expand All @@ -541,8 +537,7 @@ def hole_fluxoid(
points: The vertices of the polygon enclosing the hole. If None is given,
a polygon is generated using
:func:`tdgl.make_fluxoid_polygons`.
interp_method: Interpolation method to use, ``"nearest"``, ``"linear"``,
or ``"cubic"``.
interp_method: Interpolation method to use, ``"linear"`` or ``"cubic"``.
units: The desired units for the fluxoid.
with_units: Whether to return values as :class:`pint.Quantity` instances
with units attached.
Expand Down Expand Up @@ -600,7 +595,7 @@ def current_through_path(
self,
path_coords: np.ndarray,
dataset: Union[str, None] = None,
method: str = "linear",
method: Literal["linear", "cubic"] = "linear",
units: Union[str, None] = None,
with_units: bool = True,
) -> Union[float, pint.Quantity]:
Expand Down
1 change: 1 addition & 0 deletions tdgl/solver/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def current_func(t):
validate_terminal_currents(current_func, terminal_info, options)

# Construct finite-volume operators
logger.info("Constructing finite volume operators.")
terminal_psi = options.terminal_psi
operators = MeshOperators(
mesh,
Expand Down
2 changes: 1 addition & 1 deletion tdgl/test/test_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_save_and_load_solution(solution, tempdir):
@pytest.mark.parametrize("hole", ["hole1", "hole2", "invalid"])
@pytest.mark.parametrize("with_units", [False, True])
@pytest.mark.parametrize("units", ["Phi_0", "mT * um**2"])
@pytest.mark.parametrize("interp_method", ["linear", "cubic", "nearest"])
@pytest.mark.parametrize("interp_method", ["linear", "cubic"])
def test_hole_fluxoid(solution, hole, with_units, units, interp_method):
if hole == "invalid":
with pytest.raises(KeyError):
Expand Down
2 changes: 1 addition & 1 deletion tdgl/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version_info__ = (0, 3, 1)
__version_info__ = (0, 4, 0)
__version__ = ".".join(map(str, __version_info__))

0 comments on commit 3b134be

Please sign in to comment.