[Orchestrated] Signature & empty code issues #70

Open
FlorianDeconinck opened this issue Sep 5, 2024 · 3 comments

@FlorianDeconinck
Collaborator

Description

There is a series of bugs related to the stencil signature at the GT4Py level and to the symbols, scalars, and fields declared in the DaCe wrapper SDFG. All of these bugs live in the bridge between GT4Py and DaCe.

They fall into three groups:

  • Unused scalar parameters are lost in translation and end up as a "missing symbol" in the SDFG.
  • When code culling (because of regions) is so effective that the whole stencil body is removed, all of its parameters become unused. This is the same bug as above, reached through a slightly different code path.
  • To fix both bugs above, we can declare the parameters as "Scalar" instead of "Symbol", but this creates another bug: unused Fields are silently removed, which leads to the arguments being passed in the wrong order!
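
The third group can be modelled outside GT4Py and DaCe. The sketch below is pure Python, not GT4Py or DaCe code: bind_positional and bind_by_name are hypothetical helpers, and the parameter names follow unusued_field_stencil from the test file in the comments. Once other_field is pruned from the callee's signature while the caller still passes three positional arguments, result is bound to the unused field and the real output is silently dropped.

```python
from typing import Any, Dict, Sequence


def bind_positional(param_names: Sequence[str], args: Sequence[Any]) -> Dict[str, Any]:
    """Bind arguments to names purely by position (hypothetical helper).

    zip() silently drops surplus arguments, so no error is raised when the
    signature is shorter than the argument list.
    """
    return dict(zip(param_names, args))


def bind_by_name(
    full_names: Sequence[str], kept_names: Sequence[str], args: Sequence[Any]
) -> Dict[str, Any]:
    """Bind against the full (unpruned) signature, then keep only the pruned names."""
    if len(args) != len(full_names):
        raise TypeError(f"expected {len(full_names)} arguments, got {len(args)}")
    bound = dict(zip(full_names, args))
    return {name: bound[name] for name in kept_names}


# Modelled on unusued_field_stencil: other_field is never read by the body.
full_signature = ["field", "other_field", "result"]
pruned_signature = ["field", "result"]
call_args = ("x", "x_unused", "y")  # what the orchestrated caller passes

wrong = bind_positional(pruned_signature, call_args)
right = bind_by_name(full_signature, pruned_signature, call_args)
print(wrong)  # {'field': 'x', 'result': 'x_unused'}  <- output goes to the wrong field
print(right)  # {'field': 'x', 'result': 'y'}
```

In this model the fix is to map call-site arguments to names against the signature the caller sees, and only then discard the pruned names; whether the real bridge can do that where it builds the wrapper SDFG is an assumption of the sketch.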

Most of this behavior is linked to GT4Py's prune_unused_parameters pass, which runs at the very beginning of the GTIR pipeline. This is clearly not the intended design (passes should be pushed down to OIR or the backend IR), but it was done to work around some of these issues. Simply removing the prune pass (which could be done, since it gives little to no performance improvement) does not fix the problem.
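
To make the interaction concrete, here is a toy model of what a prune-unused-parameters pass does. It is a simplification, not GT4Py's implementation: ToyStencil and prune_unused are hypothetical, and a stencil is reduced to its parameter list plus the set of names its (possibly culled) body still references. The names follow unusued_parameter_stencil from the test file in the comments.

```python
from dataclasses import dataclass
from typing import FrozenSet, List


@dataclass(frozen=True)
class ToyStencil:
    """Hypothetical stand-in for a stencil: its signature and the names its body uses."""

    params: List[str]
    referenced: FrozenSet[str]


def prune_unused(stencil: ToyStencil) -> List[str]:
    """Keep only the parameters the body references, preserving declaration order."""
    return [p for p in stencil.params if p in stencil.referenced]


# Modelled on unusued_parameter_stencil: the scalar `weight` is never read.
unused_scalar = ToyStencil(["field", "result", "weight"], frozenset({"field", "result"}))
# Same signature, but region culling removed the whole body.
culled = ToyStencil(["field", "result", "weight"], frozenset())

print(prune_unused(unused_scalar))  # ['field', 'result']
print(prune_unused(culled))         # []
```

In both cases the orchestrated caller still passes all three arguments, so anything built from the pruned list no longer knows about weight (or about any parameter at all, for the culled stencil). Per the description above, that is the gap where the "missing symbol" error shows up.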

In the comments below we provide three examples that showcase the issues (either plain ndsl or relying on pyFV3), plus some patches that fix some of the bugs but create others.

To Reproduce
See the comments below.

@FlorianDeconinck
Collaborator Author

Test file to be dropped into ndsl (once the reference to pyFV3 is removed):

from ndsl.stencils.corners import fill_corners_dgrid_defn
from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu
from ndsl.constants import X_DIM, Y_DIM, Z_DIM, X_INTERFACE_DIM, Y_INTERFACE_DIM
from ndsl.dsl.typing import Float, FloatField
from ndsl import orchestrate, StencilFactory, DaceConfig
from gt4py.cartesian.gtscript import computation, PARALLEL, interval


class OrchestratedCorner:
    def __init__(self, stencil_factory: StencilFactory) -> None:
        orchestrate(
            obj=self,
            config=stencil_factory.config.dace_config
            or DaceConfig(communicator=None, backend=stencil_factory.backend),
        )
        origin, domain = stencil_factory.grid_indexing.get_origin_domain(
            dims=[X_DIM, Y_DIM, Z_DIM]
        )
        axes_offsets = stencil_factory.grid_indexing.axis_offsets(origin, domain)

        self.corner_stencil = stencil_factory.from_origin_domain(
            fill_corners_dgrid_defn,
            externals=axes_offsets,
            origin=origin,
            domain=domain,
        )

    def __call__(self, x, y):
        self.corner_stencil(x, x, y, y, 1.0)


def test_empty_corners():
    stencil_factory, quantity_factory = get_factories_single_tile_orchestrated_cpu(
        12, 12, 5, 0
    )
    # Pretend the tile touches no edge of the domain: every corner region is
    # culled, so the corner stencil ends up with an empty body.
    stencil_factory.grid_indexing.south_edge = False
    stencil_factory.grid_indexing.north_edge = False
    stencil_factory.grid_indexing.west_edge = False
    stencil_factory.grid_indexing.east_edge = False

    x = quantity_factory.empty(dims=[X_INTERFACE_DIM, Y_DIM, Z_DIM], units="n/a")
    y = quantity_factory.empty(dims=[X_DIM, Y_INTERFACE_DIM, Z_DIM], units="n/a")

    orch_corner = OrchestratedCorner(stencil_factory)

    orch_corner(x, y)


def unusued_parameter_stencil(
    field: FloatField,  # type: ignore
    result: FloatField,  # type: ignore
    weight: Float,  # type: ignore
):
    with computation(PARALLEL), interval(...):
        result = field[1, 0, 0] + field[0, 1, 0] + field[-1, 0, 0] + field[0, -1, 0]


class OrchestratedUnusedParameter:
    def __init__(self, stencil_factory: StencilFactory):
        orchestrate(
            obj=self,
            config=stencil_factory.config.dace_config
            or DaceConfig(communicator=None, backend=stencil_factory.backend),
        )
        origin, domain = stencil_factory.grid_indexing.get_origin_domain(
            dims=[X_DIM, Y_DIM, Z_DIM]
        )

        self.unused_stencil = stencil_factory.from_origin_domain(
            unusued_parameter_stencil,
            origin=origin,
            domain=domain,
        )

    def __call__(self, x, y):
        self.unused_stencil(x, y, 1.0)


def test_unused_parameters():
    stencil_factory, quantity_factory = get_factories_single_tile_orchestrated_cpu(
        12, 12, 5, 2
    )

    x = quantity_factory.empty(dims=[X_INTERFACE_DIM, Y_DIM, Z_DIM], units="n/a")
    y = quantity_factory.empty(dims=[X_DIM, Y_INTERFACE_DIM, Z_DIM], units="n/a")

    orch_unused = OrchestratedUnusedParameter(stencil_factory)

    orch_unused(x, y)


def unusued_field_stencil(
    field: FloatField,  # type: ignore
    other_field: FloatField,  # type: ignore
    result: FloatField,  # type: ignore
):
    with computation(PARALLEL), interval(...):
        result = field[1, 0, 0] + field[0, 1, 0] + field[-1, 0, 0] + field[0, -1, 0]


class OrchestratedunusedField:
    def __init__(self, stencil_factory: StencilFactory):
        orchestrate(
            obj=self,
            config=stencil_factory.config.dace_config
            or DaceConfig(communicator=None, backend=stencil_factory.backend),
        )
        origin, domain = stencil_factory.grid_indexing.get_origin_domain(
            dims=[X_DIM, Y_DIM, Z_DIM]
        )

        self.unused_stencil = stencil_factory.from_origin_domain(
            unusued_field_stencil,
            origin=origin,
            domain=domain,
        )

    def __call__(self, x, unused_field, y):
        self.unused_stencil(
            x,
            unused_field,
            y,
        )


def test_unused_field():
    stencil_factory, quantity_factory = get_factories_single_tile_orchestrated_cpu(
        12, 12, 5, 2
    )

    x = quantity_factory.empty(dims=[X_INTERFACE_DIM, Y_DIM, Z_DIM], units="n/a")
    x_unused = quantity_factory.empty(dims=[X_INTERFACE_DIM, Y_DIM, Z_DIM], units="n/a")
    y = quantity_factory.empty(dims=[X_DIM, Y_INTERFACE_DIM, Z_DIM], units="n/a")

    orch_unused = OrchestratedunusedField(stencil_factory)

    orch_unused(x, x_unused, y)


if __name__ == "__main__":
    test_unused_parameters()
    test_empty_corners()
    test_unused_field()

@FlorianDeconinck
Collaborator Author

Patches to be applied

diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py
index 7608fcd5..e4699b3b 100644
--- a/src/gt4py/cartesian/backend/dace_backend.py
+++ b/src/gt4py/cartesian/backend/dace_backend.py
@@ -234,6 +234,24 @@ def _sdfg_add_arrays_and_edges(
                     None,
                     dace.Memlet(name, subset=dace.subsets.Range(ranges)),
                 )
+        elif isinstance(array, dace.data.Scalar):
+            wrapper_sdfg.add_scalar(name, dtype=array.dtype, storage=array.storage)
+            if name in inputs:
+                state.add_edge(
+                    state.add_read(name),
+                    None,
+                    nsdfg,
+                    name,
+                    dace.Memlet(name),
+                )
+            if name in outputs:
+                state.add_edge(
+                    nsdfg,
+                    name,
+                    state.add_write(name),
+                    None,
+                    dace.Memlet(name),
+                )
 
 
 def _sdfg_specialize_symbols(wrapper_sdfg, domain: Tuple[int, ...]):
diff --git a/src/gt4py/cartesian/backend/dace_lazy_stencil.py b/src/gt4py/cartesian/backend/dace_lazy_stencil.py
index 2b3cf6fe..0c614ad8 100644
--- a/src/gt4py/cartesian/backend/dace_lazy_stencil.py
+++ b/src/gt4py/cartesian/backend/dace_lazy_stencil.py
@@ -15,6 +15,7 @@ from gt4py.cartesian.backend.dace_backend import SDFGManager
 from gt4py.cartesian.backend.dace_stencil_object import DaCeStencilObject, add_optional_fields
 from gt4py.cartesian.backend.module_generator import make_args_data_from_gtir
 from gt4py.cartesian.lazy_stencil import LazyStencil
+from gt4py.cartesian.gtc.passes.gtir_prune_unused_parameters import prune_unused_parameters
 
 
 if TYPE_CHECKING:
@@ -26,6 +27,7 @@ class DaCeLazyStencil(LazyStencil, SDFGConvertible):
         if "dace" not in builder.backend.name:
             raise ValueError("Trying to build a DaCeLazyStencil for non-dace backend.")
         super().__init__(builder=builder)
+        self.signature = []
 
     @property
     def field_info(self) -> Dict[str, Any]:
@@ -47,7 +49,8 @@ class DaCeLazyStencil(LazyStencil, SDFGConvertible):
     def __sdfg__(self, *args, **kwargs) -> dace.SDFG:
         sdfg_manager = SDFGManager(self.builder)
         args_data = make_args_data_from_gtir(self.builder.gtir_pipeline)
-        arg_names = [arg.name for arg in self.builder.gtir.api_signature]
+        assert self.signature != []
+        arg_names = self.signature
         assert args_data.domain_info is not None
         norm_kwargs = DaCeStencilObject.normalize_args(
             *args,
@@ -69,5 +72,9 @@ class DaCeLazyStencil(LazyStencil, SDFGConvertible):
         return {}
 
     def __sdfg_signature__(self) -> Tuple[Sequence[str], Sequence[str]]:
-        args = [arg.name for arg in self.builder.gtir.api_signature]
-        return (args, [])
+        if self.signature == []:
+            self.signature = [
+                str(p)
+                for p in self.builder.gtir_pipeline.apply([prune_unused_parameters]).param_names
+            ]
+        return (self.signature, [])
diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py
index dba6c5a7..9dd57290 100644
--- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py
+++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py
@@ -150,7 +150,7 @@ class OirSDFGBuilder(eve.NodeVisitor):
                     debuginfo=dace.DebugInfo(0),
                 )
             else:
-                ctx.sdfg.add_symbol(param.name, stype=data_type_to_dace_typeclass(param.dtype))
+                ctx.sdfg.add_scalar(param.name, dtype=data_type_to_dace_typeclass(param.dtype))
 
         for decl in node.declarations:
             dim_strs = [d for i, d in enumerate("IJK") if decl.dimensions[i]] + [
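
In the dace_lazy_stencil.py hunk above, the pruned signature is computed lazily in __sdfg_signature__ and cached in self.signature, while __sdfg__ asserts that the cache is already filled, so __sdfg__ only works if __sdfg_signature__ ran first. Using [] as the "not computed yet" marker would also collide with a legitimately empty pruned signature, which is what a fully culled stencil would produce. A minimal sketch of one way to avoid both problems, assuming the pruned names can be recomputed on demand: functools.cached_property computes the list once, on first access from either method. LazyStencilSketch and compute_pruned_names are hypothetical names, not GT4Py API.

```python
from functools import cached_property
from typing import Callable, List, Sequence, Tuple


class LazyStencilSketch:
    """Hypothetical stand-in for DaCeLazyStencil, reduced to the signature logic."""

    def __init__(self, compute_pruned_names: Callable[[], Sequence[str]]) -> None:
        self._compute_pruned_names = compute_pruned_names
        self.calls = 0  # counts how often the (expensive) pruning actually runs

    @cached_property
    def signature(self) -> List[str]:
        # Runs on first access only; an empty list is cached like any other value.
        self.calls += 1
        return [str(name) for name in self._compute_pruned_names()]

    def __sdfg__(self) -> List[str]:
        # The real method builds an SDFG; here it only returns the arg names it would use.
        return self.signature

    def __sdfg_signature__(self) -> Tuple[Sequence[str], Sequence[str]]:
        return (self.signature, [])


stencil = LazyStencilSketch(lambda: ["field", "result"])
print(stencil.__sdfg__())            # ['field', 'result']
print(stencil.__sdfg_signature__())  # (['field', 'result'], [])
print(stencil.calls)                 # 1
```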

@FlorianDeconinck
Collaborator Author

A working solution seems to be updating the symbol_mapping of the library node (StencilComputation) before it is expanded.

Branch under test: https://github.com/FlorianDeconinck/gt4py/tree/cartesian/fix/missing_parameter
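
A minimal sketch of the idea, assuming the library node carries a plain dict-like symbol_mapping from inner symbol names to outer expressions (as a DaCe nested SDFG does). complete_symbol_mapping is a hypothetical helper, not code from the branch above: before expansion, every scalar parameter of the original, unpruned signature that is missing from the mapping gets an identity entry, so unused or culled parameters stay declared.

```python
from typing import Dict, Iterable


def complete_symbol_mapping(
    symbol_mapping: Dict[str, str], signature: Iterable[str], field_names: Iterable[str]
) -> Dict[str, str]:
    """Return a copy of symbol_mapping with an identity entry for every scalar
    parameter of the full signature that is not already mapped.

    Fields are skipped: they are passed as data containers, not as symbols.
    """
    fields = set(field_names)
    completed = dict(symbol_mapping)  # do not mutate the caller's mapping
    for name in signature:
        if name not in fields and name not in completed:
            completed[name] = name  # inner symbol -> outer symbol of the same name
    return completed


# Modelled on unusued_parameter_stencil after pruning dropped `weight`.
mapping_after_pruning: Dict[str, str] = {}
completed = complete_symbol_mapping(
    mapping_after_pruning, ["field", "result", "weight"], ["field", "result"]
)
print(completed)  # {'weight': 'weight'}
```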
