diff --git a/dsl_patterns/Do__get_top_of_the_column.py b/dsl_patterns/Do__get_top_of_the_column.py index c0d2ccf..d0ceab8 100644 --- a/dsl_patterns/Do__get_top_of_the_column.py +++ b/dsl_patterns/Do__get_top_of_the_column.py @@ -17,7 +17,7 @@ """ from gt4py.cartesian.gtscript import computation, interval, PARALLEL, FORWARD -from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu +from ndsl.boilerplate import get_factories_single_tile_numpy from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.typing import FloatField, FloatFieldIJ from ndsl import StencilFactory, QuantityFactory, orchestrate @@ -25,7 +25,7 @@ domain = (3, 3, 4) -stcil_fctry, ijk_qty_fctry = get_factories_single_tile_orchestrated_cpu( +stcil_fctry, ijk_qty_fctry = get_factories_single_tile_numpy( domain[0], domain[1], domain[2], 0 ) diff --git a/dsl_patterns/Do__while_in_gt_functions.py b/dsl_patterns/Do__while_in_gt_functions.py new file mode 100644 index 0000000..696dbfd --- /dev/null +++ b/dsl_patterns/Do__while_in_gt_functions.py @@ -0,0 +1,62 @@ +""" While loops in stencil functions + +Last update: 2024/07/24 +Description: gt4py allows while loop patterns in stencils. This pattern is also allowed to be + used in `gtscript.function`. +""" + +from gt4py.cartesian.gtscript import computation, interval, PARALLEL, function +from ndsl.boilerplate import get_factories_single_tile_numpy +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.typing import FloatField +from ndsl import StencilFactory, QuantityFactory, orchestrate +import numpy as np + +domain = (3, 3, 4) + +stcil_fctry, ijk_qty_fctry = get_factories_single_tile_numpy( + domain[0], domain[1], domain[2], 0 +) + + +@function +def while_in_function(field: FloatField): + lev = 0 + while field[0, 0, lev] < 4: + lev += 1 + return lev + + +def stencil(in_field: FloatField, out_field: FloatField): + with computation(PARALLEL), interval(...): + out_field = while_in_function(in_field) + + +class Code: + def __init__( + self, + stencil_factory: StencilFactory, + qty_fctry: QuantityFactory, + ): + orchestrate(obj=self, config=stencil_factory.config.dace_config) + self._tmp = qty_fctry.zeros([X_DIM, Y_DIM], "n/a") + self.stencil = stcil_fctry.from_dims_halo( + func=stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + + def __call__(self, I: FloatField, O: FloatField): + self.stencil(I, O) + + +if __name__ == "__main__": + I = np.ones(domain[0] * domain[1] * domain[2], dtype=np.float64).reshape(domain) + I[:, :, domain[2] - 1] = 42 + O = np.zeros(domain) + + code = Code(stcil_fctry, ijk_qty_fctry) + code(I, O) + + print(f"Input:\n{I}\n") + print(f"Output:\n{O}\n") + assert (O[0, 0, :] == [3.0, 2.0, 1.0, 0.0]).all() diff --git a/dsl_patterns/WIP__hybrid_index_2dout.py b/dsl_patterns/WIP__hybrid_index_2dout.py new file mode 100644 index 0000000..78d0a23 --- /dev/null +++ b/dsl_patterns/WIP__hybrid_index_2dout.py @@ -0,0 +1,90 @@ +""" Access K as an absolute index, while retaining I/J as relative index and output to 2D. + +Last update: 2024/07/25 +Description: gt4py doesn't allow for direct indexing in K while doing relative indexing in the + other cartesian dimensions. We use a masking technique in stencil to copy the relative data. + Additionally a 2D output means we are barred to use PARALLEL and have to use FORWARD (a potential harsh limiter) + WARNING: this code returns O as a 2D Field +Fortran equivalent: +```fortran +do L=1,LM + do J=1,JM + do I=1,IM + ... + PLmb(i,j, KLCL(I,J)) + ... +``` +as seen in https://github.com/GEOS-ESM/GEOSgcm_GridComp/blob/db55c301840d98b788b0e17045510af726c0f555/GEOSagcm_GridComp/GEOSphysics_GridComp/GEOSmoist_GridComp/GEOS_GFDL_1M_InterfaceMod.F90#L589 +""" + +from gt4py.cartesian.gtscript import computation, interval, FORWARD +from ndsl.boilerplate import get_factories_single_tile_numpy +from ndsl.constants import X_DIM, Y_DIM, Z_DIM +from ndsl.dsl.typing import FloatField, FloatFieldIJ, IntFieldIJ, IntField +from ndsl import StencilFactory, QuantityFactory, orchestrate +import numpy as np + +domain = (3, 3, 4) +nhalo = 0 +stcil_fctry, ijk_qty_fctry = get_factories_single_tile_numpy( + domain[0], domain[1], domain[2], nhalo +) + + +def stencil( + data_field: FloatField, + k_mask: FloatField, + k_index_desired: FloatFieldIJ, + out_field: FloatFieldIJ, +): + with computation(FORWARD), interval(...): + if k_mask == k_index_desired: + out_field = data_field + + +class Code: + def __init__( + self, + stencil_factory: StencilFactory, + qty_fctry: QuantityFactory, + ): + orchestrate(obj=self, config=stencil_factory.config.dace_config) + self.stencil = stcil_fctry.from_dims_halo( + func=stencil, + compute_dims=[X_DIM, Y_DIM, Z_DIM], + ) + self.O = qty_fctry.zeros([X_DIM, Y_DIM], "n/a") + + def __call__( + self, + # out_field: FloatField, + k_mask: IntField, + k_index_desired: IntFieldIJ, + data_field: IntField, + ): + self.stencil(data_field, k_mask, k_index_desired, self.O) + + +if __name__ == "__main__": + k_mask = ijk_qty_fctry.zeros([X_DIM, Y_DIM, Z_DIM], "n/a") + + k_index_desired = ijk_qty_fctry.zeros([X_DIM, Y_DIM], "n/a") + k_index_desired.view[:, :] = np.random.randint( + 0, domain[2], size=(domain[0], domain[1]) + ) + + input_to_sample_from = ijk_qty_fctry.zeros([X_DIM, Y_DIM, Z_DIM], "n/a") + input_to_sample_from.view[:, :, :] = np.random.randint(800, 900, size=domain) + + for i in range(0, domain[0]): + for j in range(0, domain[1]): + for k in range(0, domain[2]): + k_mask.view[i, j, k] = k + + code = Code(stcil_fctry, ijk_qty_fctry) + code(k_mask, k_index_desired, input_to_sample_from) + + print(f"Mask:\n{k_mask.view[:,:,:]}\n") + print(f"K Level Desired for each column:\n{k_index_desired.view[:,:]}\n") + print(f"Input to sample from:\n{input_to_sample_from.view[:,:,:]}\n") + print(f"Output:\n{code.O.view[:,:]}\n")