Skip to content

Commit

Permalink
Merge pull request #78 from alexfikl/vectorize-connections
Browse files Browse the repository at this point in the history
Vectorize connections
  • Loading branch information
inducer committed Nov 13, 2020
2 parents 90a01d3 + f4a0385 commit 535dbce
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 14 deletions.
8 changes: 5 additions & 3 deletions meshmode/discretization/connection/chained.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np

from pytools import Record
from pytools.obj_array import obj_array_vectorized_n_args

import modepy as mp
from meshmode.discretization.connection.direct import \
Expand Down Expand Up @@ -61,11 +62,12 @@ def __init__(self, connections, from_discr=None):

self.connections = connections

def __call__(self, vec):
@obj_array_vectorized_n_args
def __call__(self, ary):
for cnx in self.connections:
vec = cnx(vec)
ary = cnx(ary)

return vec
return ary

# }}}

Expand Down
8 changes: 5 additions & 3 deletions meshmode/discretization/connection/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import loopy as lp
from pytools import memoize_in, keyed_memoize_method
from pytools.obj_array import obj_array_vectorized_n_args
from meshmode.array_context import ArrayContext, make_loopy_program


Expand Down Expand Up @@ -249,11 +250,15 @@ def full_resample_matrix(self, actx):

return make_direct_full_resample_matrix(actx, self)

@obj_array_vectorized_n_args
def __call__(self, ary):
from meshmode.dof_array import DOFArray
if not isinstance(ary, DOFArray):
raise TypeError("non-array passed to discretization connection")

if ary.shape != (len(self.from_discr.groups),):
raise ValueError("invalid shape of incoming resampling data")

actx = ary.array_context

@memoize_in(actx, (DirectDiscretizationConnection, "resample_by_mat_knl"))
Expand Down Expand Up @@ -311,9 +316,6 @@ def pick_knl():
else:
result = self.to_discr.zeros(actx, dtype=ary.entry_dtype)

if ary.shape != (len(self.from_discr.groups),):
raise ValueError("invalid shape of incoming resampling data")

for i_tgrp, (tgrp, cgrp) in enumerate(
zip(self.to_discr.groups, self.groups)):
for i_batch, batch in enumerate(cgrp.batches):
Expand Down
21 changes: 13 additions & 8 deletions meshmode/discretization/connection/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np

from pytools import keyed_memoize_method, memoize_in
from pytools.obj_array import obj_array_vectorized_n_args

import loopy as lp

Expand Down Expand Up @@ -115,11 +116,15 @@ def det(v):

return weights

def __call__(self, vec):
if not isinstance(vec, DOFArray):
@obj_array_vectorized_n_args
def __call__(self, ary):
if not isinstance(ary, DOFArray):
raise TypeError("non-array passed to discretization connection")

actx = vec.array_context
if ary.shape != (len(self.from_discr.groups),):
raise ValueError("invalid shape of incoming resampling data")

actx = ary.array_context

@memoize_in(actx, (L2ProjectionInverseDiscretizationConnection,
"conn_projection_knl"))
Expand All @@ -131,15 +136,15 @@ def kproj():
"""
for iel
<> element_dot = sum(idof_quad,
vec[from_element_indices[iel], idof_quad]
ary[from_element_indices[iel], idof_quad]
* basis[idof_quad] * weights[idof_quad])
result[to_element_indices[iel], ibasis] = \
result[to_element_indices[iel], ibasis] + element_dot
end
""",
[
lp.GlobalArg("vec", None,
lp.GlobalArg("ary", None,
shape=("n_from_elements", "n_from_nodes")),
lp.GlobalArg("result", None,
shape=("n_to_elements", "n_to_nodes")),
Expand Down Expand Up @@ -178,7 +183,7 @@ def keval():
weights = self._batch_weights(actx)

# perform dot product (on reference element) to get basis coefficients
c = self.to_discr.zeros(actx, dtype=vec.entry_dtype)
c = self.to_discr.zeros(actx, dtype=ary.entry_dtype)

for igrp, (tgrp, cgrp) in enumerate(
zip(self.to_discr.groups, self.conn.groups)):
Expand All @@ -195,15 +200,15 @@ def keval():
# saves on recreating the connection groups and batches.
actx.call_loopy(kproj(),
ibasis=ibasis,
vec=vec[sgrp.index],
ary=ary[sgrp.index],
basis=basis,
weights=weights[igrp, ibatch],
result=c[igrp],
from_element_indices=batch.to_element_indices,
to_element_indices=batch.from_element_indices)

# evaluate at unit_nodes to get the vector on to_discr
result = self.to_discr.zeros(actx, dtype=vec.entry_dtype)
result = self.to_discr.zeros(actx, dtype=ary.entry_dtype)
for igrp, grp in enumerate(self.to_discr.groups):
for ibasis, basis_fn in enumerate(grp.basis()):
basis = actx.from_numpy(
Expand Down

0 comments on commit 535dbce

Please sign in to comment.