Update operators.py
loganbvh committed Sep 20, 2023
1 parent 6847245 commit ac9aad7
Showing 1 changed file with 60 additions and 2 deletions.
62 changes: 60 additions & 2 deletions tdgl/finite_volume/operators.py
@@ -3,6 +3,7 @@

import numpy as np
import scipy.sparse as sp
from scipy.sparse._sparsetools import csr_sample_offsets

try:
    import cupy  # type: ignore
@@ -15,6 +16,61 @@
from .mesh import Mesh


def _get_spmatrix_offsets(
    spmatrix: sp.spmatrix, i: np.ndarray, j: np.ndarray, n_samples: int
) -> tuple:
    """Calculates the sparse matrix offsets for a set of rows ``i`` and columns ``j``."""
    # See _set_many() at
    # https://github.com/scipy/scipy/blob/3f9a8c80e281e746225092621935b88c1ce68040/scipy/sparse/_compressed.py#L901
    spmatrix = spmatrix.asformat("csr", copy=False)
    i, j, M, N = spmatrix._prepare_indices(i, j)
    offsets = np.empty(n_samples, dtype=spmatrix.indices.dtype)
    ret = csr_sample_offsets(
        M, N, spmatrix.indptr, spmatrix.indices, n_samples, i, j, offsets
    )
    if ret == 1:
        # Duplicate entries were found; sum them and look up the offsets again.
        spmatrix.sum_duplicates()
        csr_sample_offsets(
            M, N, spmatrix.indptr, spmatrix.indices, n_samples, i, j, offsets
        )
    return offsets, (i, j, M, N)
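
A minimal usage sketch (illustrative only, not part of this diff; it relies on the same private SciPy helpers as the function above): the returned offsets index directly into ``spmatrix.data``, and ``-1`` marks sampled positions with no stored entry.

import numpy as np
import scipy.sparse as sp

A = sp.random(5, 5, density=0.4, format="csr", random_state=0)
rows = np.array([0, 1, 2])
cols = np.array([1, 2, 4])
offsets, (i, j, M, N) = _get_spmatrix_offsets(A, rows, cols, len(rows))
stored = offsets > -1
# Values currently stored at the sampled (row, col) positions.
existing_values = A.data[offsets[stored]]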


def _get_spmatrix_offsets_cupy(spmatrix, i, j):
    """Calculates the sparse matrix offsets for a set of rows ``i`` and columns ``j``."""
    # See _set_many() at
    # https://github.com/cupy/cupy/blob/5c32e40af32f6f9627e09d47ecfeb7e9281ccab2/cupyx/scipy/sparse/_compressed.py#L525
    i, j, M, N = spmatrix._prepare_indices(i, j)
    new_sp = csr_matrix(
        (
            cupy.arange(spmatrix.nnz, dtype=cupy.float32),
            spmatrix.indices,
            spmatrix.indptr,
        ),
        shape=(M, N),
    )
    offsets = new_sp._get_arrayXarray(i, j, not_found_val=-1).astype(cupy.int32).ravel()
    return offsets, (i, j, M, N)
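
A mirrored sketch for the GPU path (illustrative only, not part of this diff; it requires CuPy and a CUDA device, and assumes ``csr_matrix`` above is ``cupyx.scipy.sparse.csr_matrix`` imported in the guarded block near the top of the file):

import cupy
import numpy as np
import scipy.sparse as sp
from cupyx.scipy.sparse import csr_matrix

A_cpu = sp.random(5, 5, density=0.4, format="csr", random_state=0)
A_gpu = csr_matrix(A_cpu)  # copy the SciPy CSR matrix to the GPU
rows = cupy.asarray([0, 1, 2])
cols = cupy.asarray([1, 2, 4])
offsets, (i, j, M, N) = _get_spmatrix_offsets_cupy(A_gpu, rows, cols)
stored = offsets > -1  # -1 again marks positions with no stored entry
existing_values = A_gpu.data[offsets[stored]]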


def _spmatrix_set_many(spmatrix, i, j, x):
    """Sets entries ``(i, j)`` of a SciPy or CuPy sparse matrix to the values ``x``,
    inserting new nonzero entries where needed.
    """
    if sp.issparse(spmatrix):
        offsets, (i, j, M, N) = _get_spmatrix_offsets(spmatrix, i, j, len(x))
    else:
        offsets, (i, j, M, N) = _get_spmatrix_offsets_cupy(spmatrix, i, j)

    # Overwrite the entries that are already stored in the matrix.
    mask = offsets > -1
    spmatrix.data[offsets[mask]] = x[mask]
    if not mask.all():
        # only insertions remain
        mask = ~mask
        i = i[mask]
        i[i < 0] += M
        j = j[mask]
        j[j < 0] += N
        spmatrix._insert_many(i, j, x[mask])
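
A short sketch of the resulting call pattern (illustrative only, not part of this diff): ``_spmatrix_set_many(A, rows, cols, values)`` plays the role of ``A[rows, cols] = values`` for either a SciPy or a CuPy sparse matrix, overwriting entries that are already stored and inserting the ones that are not. A CPU-only example:

import numpy as np
import scipy.sparse as sp

A = sp.identity(4, format="csr")
rows = np.array([0, 1, 2])
cols = np.array([0, 1, 3])
vals = np.array([10.0, 20.0, 30.0])
_spmatrix_set_many(A, rows, cols, vals)
# (0, 0) and (1, 1) are overwritten in place; (2, 3) is inserted as a new nonzero.
assert A[2, 3] == 30.0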


def build_divergence(mesh: Mesh) -> sp.csr_array:
    """Build the divergence matrix that takes the divergence of a function living
    on the edges onto the sites.
@@ -317,7 +373,8 @@ def set_link_exponents(self, link_exponents: np.ndarray) -> None:
        values = self.gradient_weights * link_variables
        rows = self.gradient_link_rows
        cols = self.gradient_link_cols
-        self.psi_gradient[rows, cols] = values
+        _spmatrix_set_many(self.psi_gradient, rows, cols, values)
+        # self.psi_gradient[rows, cols] = values
        # Update Laplacian for psi
        areas = self.areas
        weights = self.laplacian_weights
@@ -336,7 +393,8 @@ def set_link_exponents(self, link_exponents: np.ndarray) -> None:
        else:
            rows = self.laplacian_link_rows
            cols = self.laplacian_link_cols
-            self.psi_laplacian[rows, cols] = values
+            _spmatrix_set_many(self.psi_laplacian, rows, cols, values)
+            # self.psi_laplacian[rows, cols] = values

    def get_supercurrent(self, psi: np.ndarray):
        """Compute the supercurrent on the edges.
