Skip to content

Commit

Permalink
Simplify spmatrix _set_many (#42)
Browse files Browse the repository at this point in the history
* Simplify spmatrix _set_many
  • Loading branch information
loganbvh authored Oct 2, 2023
1 parent 49dfc15 commit d6d0440
Showing 1 changed file with 7 additions and 25 deletions.
32 changes: 7 additions & 25 deletions tdgl/finite_volume/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np
import scipy.sparse as sp
from scipy.sparse._sparsetools import csr_sample_offsets

try:
import cupy # type: ignore
Expand All @@ -16,25 +15,6 @@
from .mesh import Mesh


def _get_spmatrix_offsets(
spmatrix: sp.spmatrix, i: np.ndarray, j: np.ndarray, n_samples: int
) -> np.ndarray:
"""Calculates the sparse matrix offsets for a set of rows ``i`` and columns ``j``."""
# See _set_many() at
# https://github.com/scipy/scipy/blob/3f9a8c80e281e746225092621935b88c1ce68040/scipy/sparse/_compressed.py#L901
i, j, M, N = spmatrix._prepare_indices(i, j)
offsets = np.empty(n_samples, dtype=spmatrix.indices.dtype)
ret = csr_sample_offsets(
M, N, spmatrix.indptr, spmatrix.indices, n_samples, i, j, offsets
)
if ret == 1:
spmatrix.sum_duplicates()
csr_sample_offsets(
M, N, spmatrix.indptr, spmatrix.indices, n_samples, i, j, offsets
)
return offsets, (i, j, M, N)


def _get_spmatrix_offsets_cupy(spmatrix, i, j):
"""Calculates the sparse matrix offsets for a set of rows ``i`` and columns ``j``."""
# See _set_many() at
Expand All @@ -55,14 +35,16 @@ def _get_spmatrix_offsets_cupy(spmatrix, i, j):
def _spmatrix_set_many(spmatrix, i, j, x):
"""spmatrix.__setitem__()"""
if sp.issparse(spmatrix):
i, j = spmatrix._swap((i, j))
offsets, (i, j, M, N) = _get_spmatrix_offsets(spmatrix, i, j, len(x))
else:
i, j = spmatrix._swap(i, j)
offsets, (i, j, M, N) = _get_spmatrix_offsets_cupy(spmatrix, i, j)
spmatrix[i, j] = x
return

i, j = spmatrix._swap(i, j)
offsets, (i, j, M, N) = _get_spmatrix_offsets_cupy(spmatrix, i, j)

mask = offsets > -1
# update where possible
spmatrix.data[offsets[mask]] = x[mask]

if not mask.all():
# only insertions remain
mask = ~mask
Expand Down

0 comments on commit d6d0440

Please sign in to comment.