Skip to content

Commit

Permalink
Fix k_start + utest
Browse files Browse the repository at this point in the history
  • Loading branch information
Florian Deconinck committed Oct 7, 2024
1 parent ed4ddd4 commit 5b09a67
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
14 changes: 7 additions & 7 deletions ndsl/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,14 @@ def __init__(
):
unblock_waiting_tiles(MPI.COMM_WORLD)

self._timing_collector.build_info[
_stencil_object_name(self.stencil_object)
] = build_info
self._timing_collector.build_info[_stencil_object_name(self.stencil_object)] = (
build_info
)
field_info = self.stencil_object.field_info

self._field_origins: Dict[
str, Tuple[int, ...]
] = FrozenStencil._compute_field_origins(field_info, self.origin)
self._field_origins: Dict[str, Tuple[int, ...]] = (
FrozenStencil._compute_field_origins(field_info, self.origin)
)
"""mapping from field names to field origins"""

self._stencil_run_kwargs: Dict[str, Any] = {
Expand Down Expand Up @@ -737,7 +737,7 @@ def axis_offsets(
"local_js": gtscript.J[0] + self.jsc - origin[1],
"j_end": j_end,
"local_je": gtscript.J[-1] + self.jec - origin[1] - domain[1] + 1,
"k_start": self.origin[2],
"k_start": origin[2],
"k_end": self.origin[2] + domain[2] - 1,
}

Expand Down
17 changes: 17 additions & 0 deletions tests/dsl/test_stencil_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,23 @@ def test_get_stencils_with_varied_bounds_and_regions(backend: str):
np.testing.assert_array_equal(q_orig.data, q_ref.data)


def test_stencil_vertical_bounds(backend: str):
factory = get_stencil_factory(backend)
origins = [(3, 3, 0), (2, 2, 1)]
domains = [(1, 1, 3), (2, 2, 4)]
stencils = get_stencils_with_varied_bounds(
add_1_in_region_stencil,
origins,
domains,
stencil_factory=factory,
)

assert "k_start" in stencils[0].externals and stencils[0].externals["k_start"] == 0
assert "k_end" in stencils[0].externals and stencils[0].externals["k_end"] == 2
assert "k_start" in stencils[1].externals and stencils[1].externals["k_start"] == 1
assert "k_end" in stencils[1].externals and stencils[1].externals["k_end"] == 3


@pytest.mark.parametrize("enabled", [True, False])
def test_stencil_factory_numpy_comparison_from_dims_halo(enabled: bool):
backend = "numpy"
Expand Down

0 comments on commit 5b09a67

Please sign in to comment.