diff --git a/meshmode/discretization/connection/chained.py b/meshmode/discretization/connection/chained.py index 54bbd98ed..de9602db1 100644 --- a/meshmode/discretization/connection/chained.py +++ b/meshmode/discretization/connection/chained.py @@ -23,6 +23,7 @@ import numpy as np from pytools import Record +from pytools.obj_array import obj_array_vectorized_n_args import modepy as mp from meshmode.discretization.connection.direct import \ @@ -61,11 +62,12 @@ def __init__(self, connections, from_discr=None): self.connections = connections - def __call__(self, vec): + @obj_array_vectorized_n_args + def __call__(self, ary): for cnx in self.connections: - vec = cnx(vec) + ary = cnx(ary) - return vec + return ary # }}} diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index 55820fe03..44c9727d5 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -25,6 +25,7 @@ import loopy as lp from pytools import memoize_in, keyed_memoize_method +from pytools.obj_array import obj_array_vectorized_n_args from meshmode.array_context import ArrayContext, make_loopy_program @@ -249,11 +250,15 @@ def full_resample_matrix(self, actx): return make_direct_full_resample_matrix(actx, self) + @obj_array_vectorized_n_args def __call__(self, ary): from meshmode.dof_array import DOFArray if not isinstance(ary, DOFArray): raise TypeError("non-array passed to discretization connection") + if ary.shape != (len(self.from_discr.groups),): + raise ValueError("invalid shape of incoming resampling data") + actx = ary.array_context @memoize_in(actx, (DirectDiscretizationConnection, "resample_by_mat_knl")) @@ -311,9 +316,6 @@ def pick_knl(): else: result = self.to_discr.zeros(actx, dtype=ary.entry_dtype) - if ary.shape != (len(self.from_discr.groups),): - raise ValueError("invalid shape of incoming resampling data") - for i_tgrp, (tgrp, cgrp) in enumerate( zip(self.to_discr.groups, self.groups)): for i_batch, batch in enumerate(cgrp.batches): diff --git a/meshmode/discretization/connection/projection.py b/meshmode/discretization/connection/projection.py index fd3d96e51..85fe1e376 100644 --- a/meshmode/discretization/connection/projection.py +++ b/meshmode/discretization/connection/projection.py @@ -23,6 +23,7 @@ import numpy as np from pytools import keyed_memoize_method, memoize_in +from pytools.obj_array import obj_array_vectorized_n_args import loopy as lp @@ -115,11 +116,15 @@ def det(v): return weights - def __call__(self, vec): - if not isinstance(vec, DOFArray): + @obj_array_vectorized_n_args + def __call__(self, ary): + if not isinstance(ary, DOFArray): raise TypeError("non-array passed to discretization connection") - actx = vec.array_context + if ary.shape != (len(self.from_discr.groups),): + raise ValueError("invalid shape of incoming resampling data") + + actx = ary.array_context @memoize_in(actx, (L2ProjectionInverseDiscretizationConnection, "conn_projection_knl")) @@ -131,7 +136,7 @@ def kproj(): """ for iel <> element_dot = sum(idof_quad, - vec[from_element_indices[iel], idof_quad] + ary[from_element_indices[iel], idof_quad] * basis[idof_quad] * weights[idof_quad]) result[to_element_indices[iel], ibasis] = \ @@ -139,7 +144,7 @@ def kproj(): end """, [ - lp.GlobalArg("vec", None, + lp.GlobalArg("ary", None, shape=("n_from_elements", "n_from_nodes")), lp.GlobalArg("result", None, shape=("n_to_elements", "n_to_nodes")), @@ -178,7 +183,7 @@ def keval(): weights = self._batch_weights(actx) # perform dot product (on reference element) to get basis coefficients - c = self.to_discr.zeros(actx, dtype=vec.entry_dtype) + c = self.to_discr.zeros(actx, dtype=ary.entry_dtype) for igrp, (tgrp, cgrp) in enumerate( zip(self.to_discr.groups, self.conn.groups)): @@ -195,7 +200,7 @@ def keval(): # saves on recreating the connection groups and batches. actx.call_loopy(kproj(), ibasis=ibasis, - vec=vec[sgrp.index], + ary=ary[sgrp.index], basis=basis, weights=weights[igrp, ibatch], result=c[igrp], @@ -203,7 +208,7 @@ def keval(): to_element_indices=batch.from_element_indices) # evaluate at unit_nodes to get the vector on to_discr - result = self.to_discr.zeros(actx, dtype=vec.entry_dtype) + result = self.to_discr.zeros(actx, dtype=ary.entry_dtype) for igrp, grp in enumerate(self.to_discr.groups): for ibasis, basis_fn in enumerate(grp.basis()): basis = actx.from_numpy(