Skip to content

Commit

Permalink
Merge branch 'feature/dsl_patterns'
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed Aug 13, 2024
2 parents 1fbc72a + 42d5e9b commit 6ec19a3
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 2 deletions.
4 changes: 2 additions & 2 deletions dsl_patterns/Do__get_top_of_the_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
"""

from gt4py.cartesian.gtscript import computation, interval, PARALLEL, FORWARD
from ndsl.boilerplate import get_factories_single_tile_orchestrated_cpu
from ndsl.boilerplate import get_factories_single_tile_numpy
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.typing import FloatField, FloatFieldIJ
from ndsl import StencilFactory, QuantityFactory, orchestrate
import numpy as np

domain = (3, 3, 4)

stcil_fctry, ijk_qty_fctry = get_factories_single_tile_orchestrated_cpu(
stcil_fctry, ijk_qty_fctry = get_factories_single_tile_numpy(
domain[0], domain[1], domain[2], 0
)

Expand Down
62 changes: 62 additions & 0 deletions dsl_patterns/Do__while_in_gt_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
""" While loops in stencil functions
Last update: 2024/07/24
Description: gt4py allows while loop patterns in stencils. This pattern is also allowed to be
used in `gtscript.function`.
"""

from gt4py.cartesian.gtscript import computation, interval, PARALLEL, function
from ndsl.boilerplate import get_factories_single_tile_numpy
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.typing import FloatField
from ndsl import StencilFactory, QuantityFactory, orchestrate
import numpy as np

domain = (3, 3, 4)

stcil_fctry, ijk_qty_fctry = get_factories_single_tile_numpy(
domain[0], domain[1], domain[2], 0
)


@function
def while_in_function(field: FloatField):
lev = 0
while field[0, 0, lev] < 4:
lev += 1
return lev


def stencil(in_field: FloatField, out_field: FloatField):
with computation(PARALLEL), interval(...):
out_field = while_in_function(in_field)


class Code:
def __init__(
self,
stencil_factory: StencilFactory,
qty_fctry: QuantityFactory,
):
orchestrate(obj=self, config=stencil_factory.config.dace_config)
self._tmp = qty_fctry.zeros([X_DIM, Y_DIM], "n/a")
self.stencil = stcil_fctry.from_dims_halo(
func=stencil,
compute_dims=[X_DIM, Y_DIM, Z_DIM],
)

def __call__(self, I: FloatField, O: FloatField):
self.stencil(I, O)


if __name__ == "__main__":
I = np.ones(domain[0] * domain[1] * domain[2], dtype=np.float64).reshape(domain)
I[:, :, domain[2] - 1] = 42
O = np.zeros(domain)

code = Code(stcil_fctry, ijk_qty_fctry)
code(I, O)

print(f"Input:\n{I}\n")
print(f"Output:\n{O}\n")
assert (O[0, 0, :] == [3.0, 2.0, 1.0, 0.0]).all()
90 changes: 90 additions & 0 deletions dsl_patterns/WIP__hybrid_index_2dout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
""" Access K as an absolute index, while retaining I/J as relative index and output to 2D.
Last update: 2024/07/25
Description: gt4py doesn't allow for direct indexing in K while doing relative indexing in the
other cartesian dimensions. We use a masking technique in stencil to copy the relative data.
Additionally a 2D output means we are barred to use PARALLEL and have to use FORWARD (a potential harsh limiter)
WARNING: this code returns O as a 2D Field
Fortran equivalent:
```fortran
do L=1,LM
do J=1,JM
do I=1,IM
...
PLmb(i,j, KLCL(I,J))
...
```
as seen in https://github.com/GEOS-ESM/GEOSgcm_GridComp/blob/db55c301840d98b788b0e17045510af726c0f555/GEOSagcm_GridComp/GEOSphysics_GridComp/GEOSmoist_GridComp/GEOS_GFDL_1M_InterfaceMod.F90#L589
"""

from gt4py.cartesian.gtscript import computation, interval, FORWARD
from ndsl.boilerplate import get_factories_single_tile_numpy
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.typing import FloatField, FloatFieldIJ, IntFieldIJ, IntField
from ndsl import StencilFactory, QuantityFactory, orchestrate
import numpy as np

domain = (3, 3, 4)
nhalo = 0
stcil_fctry, ijk_qty_fctry = get_factories_single_tile_numpy(
domain[0], domain[1], domain[2], nhalo
)


def stencil(
data_field: FloatField,
k_mask: FloatField,
k_index_desired: FloatFieldIJ,
out_field: FloatFieldIJ,
):
with computation(FORWARD), interval(...):
if k_mask == k_index_desired:
out_field = data_field


class Code:
def __init__(
self,
stencil_factory: StencilFactory,
qty_fctry: QuantityFactory,
):
orchestrate(obj=self, config=stencil_factory.config.dace_config)
self.stencil = stcil_fctry.from_dims_halo(
func=stencil,
compute_dims=[X_DIM, Y_DIM, Z_DIM],
)
self.O = qty_fctry.zeros([X_DIM, Y_DIM], "n/a")

def __call__(
self,
# out_field: FloatField,
k_mask: IntField,
k_index_desired: IntFieldIJ,
data_field: IntField,
):
self.stencil(data_field, k_mask, k_index_desired, self.O)


if __name__ == "__main__":
k_mask = ijk_qty_fctry.zeros([X_DIM, Y_DIM, Z_DIM], "n/a")

k_index_desired = ijk_qty_fctry.zeros([X_DIM, Y_DIM], "n/a")
k_index_desired.view[:, :] = np.random.randint(
0, domain[2], size=(domain[0], domain[1])
)

input_to_sample_from = ijk_qty_fctry.zeros([X_DIM, Y_DIM, Z_DIM], "n/a")
input_to_sample_from.view[:, :, :] = np.random.randint(800, 900, size=domain)

for i in range(0, domain[0]):
for j in range(0, domain[1]):
for k in range(0, domain[2]):
k_mask.view[i, j, k] = k

code = Code(stcil_fctry, ijk_qty_fctry)
code(k_mask, k_index_desired, input_to_sample_from)

print(f"Mask:\n{k_mask.view[:,:,:]}\n")
print(f"K Level Desired for each column:\n{k_index_desired.view[:,:]}\n")
print(f"Input to sample from:\n{input_to_sample_from.view[:,:,:]}\n")
print(f"Output:\n{code.O.view[:,:]}\n")

0 comments on commit 6ec19a3

Please sign in to comment.