From af95c079b760c01439e5d03248221d1cc21d51da Mon Sep 17 00:00:00 2001 From: Oliver Elbert Date: Tue, 11 Jun 2024 11:28:46 -0400 Subject: [PATCH 1/4] adding new feature for 2d stencil indexing --- ndsl/dsl/stencil.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index 75ef28e..ff2d9c8 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -787,6 +787,34 @@ def get_origin_domain( domain[i] += 2 * n return tuple(origin), tuple(domain) + def get_2d_origin_domain( + self, + dims: Sequence[str], + halos: Sequence[int] = tuple(), + klevel: int = 0, + ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """ + Get the origin and domain for a computation that occurs on the lowest klevel over a certain grid + configuration (given by dims) and a certain number of halo points. + + Args: + dims: dimension names, using dimension constants from ndsl.constants + halos: number of halo points for each dimension, defaults to zero + klevel: the vertical level of the domain, defaults to zero + + Returns: + origin: origin of the computation + domain: shape of the computation + """ + origin = self._origin_from_dims(dims) + origin[2] = klevel + domain = list(self._sizer.get_extent(dims)) + domain[2] = 1 + for i, n in enumerate(halos): + origin[i] -= n + domain[i] += 2 * n + return tuple(origin), tuple(domain) + def _origin_from_dims(self, dims: Iterable[str]) -> List[int]: return_origin = [] for dim in dims: From bd24acafec818f9c43906ea4eedfc30934c9f31f Mon Sep 17 00:00:00 2001 From: Oliver Elbert Date: Tue, 11 Jun 2024 16:19:02 -0400 Subject: [PATCH 2/4] added test, updated method --- ndsl/dsl/stencil.py | 17 ++++----------- tests/dsl/test_stencil_factory.py | 35 +++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 13 deletions(-) diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index ff2d9c8..0cb025a 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -787,10 +787,8 @@ def get_origin_domain( domain[i] += 2 * n return tuple(origin), tuple(domain) - def get_2d_origin_domain( + def get_2d_compute_origin_domain( self, - dims: Sequence[str], - halos: Sequence[int] = tuple(), klevel: int = 0, ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: """ @@ -798,22 +796,15 @@ def get_2d_origin_domain( configuration (given by dims) and a certain number of halo points. Args: - dims: dimension names, using dimension constants from ndsl.constants - halos: number of halo points for each dimension, defaults to zero klevel: the vertical level of the domain, defaults to zero Returns: origin: origin of the computation domain: shape of the computation """ - origin = self._origin_from_dims(dims) - origin[2] = klevel - domain = list(self._sizer.get_extent(dims)) - domain[2] = 1 - for i, n in enumerate(halos): - origin[i] -= n - domain[i] += 2 * n - return tuple(origin), tuple(domain) + origin = (self.isc, self.jsc, klevel) + domain = (self.iec + 1 - self.isc, self.jec + 1 - self.jsc, 1) + return (origin, domain) def _origin_from_dims(self, dims: Iterable[str]) -> List[int]: return_origin = [] diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index ac189ad..ef0048d 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -173,6 +173,41 @@ def test_stencil_factory_numpy_comparison_from_origin_domain(enabled: bool): assert isinstance(stencil, CompareToNumpyStencil) else: assert isinstance(stencil, FrozenStencil) + + +@pytest.mark.parametrize("enabled", [True, False]) +def test_stencil_factory_numpy_comparison_from_origin_domain_2d(enabled: bool): + backend = "numpy" + dace_config = DaceConfig(communicator=None, backend=backend) + config = StencilConfig( + compilation_config=CompilationConfig( + backend=backend, + rebuild=False, + validate_args=False, + format_source=False, + device_sync=False, + ), + compare_to_numpy=enabled, + dace_config=dace_config, + ) + indexing = GridIndexing( + domain=(12, 12, 79), + n_halo=3, + south_edge=True, + north_edge=True, + west_edge=True, + east_edge=True, + ) + dims = ["X_DIM", "Y_DIM", "Z_DIM"] + origin, domain = indexing.get_2d_compute_origin_domain(klevel=1) + factory = StencilFactory(config=config, grid_indexing=indexing) + stencil = factory.from_origin_domain( + func=copy_stencil, origin=origin, domain=domain + ) + if enabled: + assert isinstance(stencil, CompareToNumpyStencil) + else: + assert isinstance(stencil, FrozenStencil) def test_stencil_factory_numpy_comparison_runs_without_exceptions(): From 81ffb4a0f2882cf6ba7a3257e4e6a362fb4fe02e Mon Sep 17 00:00:00 2001 From: Oliver Elbert Date: Tue, 11 Jun 2024 16:21:10 -0400 Subject: [PATCH 3/4] unused line --- tests/dsl/test_stencil_factory.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index ef0048d..a21d7fd 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -198,7 +198,6 @@ def test_stencil_factory_numpy_comparison_from_origin_domain_2d(enabled: bool): west_edge=True, east_edge=True, ) - dims = ["X_DIM", "Y_DIM", "Z_DIM"] origin, domain = indexing.get_2d_compute_origin_domain(klevel=1) factory = StencilFactory(config=config, grid_indexing=indexing) stencil = factory.from_origin_domain( From 049d56fcc7652cc62060a4038eb835baa6b208f8 Mon Sep 17 00:00:00 2001 From: Oliver Elbert Date: Tue, 11 Jun 2024 16:42:18 -0400 Subject: [PATCH 4/4] lint --- tests/dsl/test_stencil_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index a21d7fd..b281bfb 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -173,7 +173,7 @@ def test_stencil_factory_numpy_comparison_from_origin_domain(enabled: bool): assert isinstance(stencil, CompareToNumpyStencil) else: assert isinstance(stencil, FrozenStencil) - + @pytest.mark.parametrize("enabled", [True, False]) def test_stencil_factory_numpy_comparison_from_origin_domain_2d(enabled: bool):