From 17d17fbea1126cc3c6e32f761a1598d47741dde8 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 6 Mar 2024 15:48:50 -0500 Subject: [PATCH 01/11] Added basic_operations.py from pyFV3 to ndsl/stencils --- ndsl/stencils/basic_operations.py | 46 +++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 ndsl/stencils/basic_operations.py diff --git a/ndsl/stencils/basic_operations.py b/ndsl/stencils/basic_operations.py new file mode 100644 index 00000000..55307c5c --- /dev/null +++ b/ndsl/stencils/basic_operations.py @@ -0,0 +1,46 @@ +import gt4py.cartesian.gtscript as gtscript +from gt4py.cartesian.gtscript import PARALLEL, computation, interval + +from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ + + +def copy_defn(q_in: FloatField, q_out: FloatField): + """Copy q_in to q_out. + + Args: + q_in: input field + q_out: output field + """ + with computation(PARALLEL), interval(...): + q_out = q_in + + +def adjustmentfactor_stencil_defn(adjustment: FloatFieldIJ, q_out: FloatField): + with computation(PARALLEL), interval(...): + q_out = q_out * adjustment + + +def set_value_defn(q_out: FloatField, value: Float): + with computation(PARALLEL), interval(...): + q_out = value + + +def adjust_divide_stencil(adjustment: FloatField, q_out: FloatField): + with computation(PARALLEL), interval(...): + q_out = q_out / adjustment + + +@gtscript.function +def sign(a, b): + asignb = abs(a) + if b > 0: + asignb = asignb + else: + asignb = -asignb + return asignb + + +@gtscript.function +def dim(a, b): + diff = a - b if a - b > 0 else 0 + return diff From 9f9e25c4df67fcacce435dc40663f91c2eba47d4 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 20 Mar 2024 16:07:22 -0400 Subject: [PATCH 02/11] Added rough draft of unit tests --- tests/test_basic_operations.py | 205 +++++++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) create mode 100644 tests/test_basic_operations.py diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py new file mode 100644 index 00000000..83b9f817 --- /dev/null +++ b/tests/test_basic_operations.py @@ -0,0 +1,205 @@ +from gt4py.storage import full, ones, zeros + +from ndsl import ( + CompilationConfig, + DaceConfig, + DaCeOrchestration, + GridIndexing, + RunMode, + StencilConfig, + StencilFactory, +) +from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ +from ndsl.stencils import basic_operations as basic + + +nx = 20 +ny = 20 +nz = 79 +nhalo = 3 +backend = "numpy" + +dace_config = DaceConfig( + communicator=None, backend=backend, orchestration=DaCeOrchestration.Python +) + +compilation_config = CompilationConfig( + backend=backend, + rebuild=True, + validate_args=True, + format_source=False, + device_sync=False, + run_mode=RunMode.BuildAndRun, + use_minimal_caching=False, +) + +stencil_config = StencilConfig( + compare_to_numpy=False, + compilation_config=compilation_config, + dace_config=dace_config, +) + +grid_indexing = GridIndexing( + domain=(nx, ny, nz), + n_halo=nhalo, + south_edge=True, + north_edge=True, + west_edge=True, + east_edge=True, +) + +stencil_factory = StencilFactory(config=stencil_config, grid_indexing=grid_indexing) + + +class Copy: + def __init__(self, stencil_factory: StencilFactory): + grid_indexing = stencil_factory.grid_indexing + self._copy_stencil = stencil_factory.from_origin_domain( + basic.copy_defn, + origin=grid_indexing.origin_compute(), + domain=grid_indexing.domain_compute(), + ) + + def __call__( + self, + f_in: FloatField, + f_out: FloatField, + ): + self._copy_stencil(f_in, f_out) + + +class AdjustmentFactor: + def __init__(self, stencil_factory: StencilFactory): + grid_indexing = stencil_factory.grid_indexing + self._adjustmentfactor_stencil = stencil_factory.from_origin_domain( + basic.adjustmentfactor_stencil_defn, + origin=grid_indexing.origin_compute(), + domain=grid_indexing.domain_compute(), + ) + + def __call__( + self, + factor: FloatFieldIJ, + f_out: FloatField, + ): + self._adjustmentfactor_stencil(factor, f_out) + + +class SetValue: + def __init__(self, stencil_factory: StencilFactory): + grid_indexing = stencil_factory.grid_indexing + self._set_value_stencil = stencil_factory.from_origin_domain( + basic.set_value_defn, + origin=grid_indexing.origin_compute(), + domain=grid_indexing.domain_compute(), + ) + + def __call__( + self, + f_out: FloatField, + value: Float, + ): + self._set_value_stencil(f_out, value) + + +class AdjustDivide: + def __init__(self, stencil_factory: StencilFactory): + grid_indexing = stencil_factory.grid_indexing + self._adjust_divide_stencil = stencil_factory.from_origin_domain( + basic.adjust_divide_stencil, + origin=grid_indexing.origin_compute(), + domain=grid_indexing.domain_compute(), + ) + + def __call__( + self, + factor: FloatField, + f_out: FloatField, + ): + self._adjust_divide_stencil(factor, f_out) + + +def test_copy(): + copy = Copy(stencil_factory) + + infield = zeros( + backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) + ) + + outfield = ones( + backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) + ) + + copy(f_in=infield, f_out=outfield) + + assert infield.all() == outfield.all() + + +def test_adjustmentfactor(): + adfact = AdjustmentFactor(stencil_factory) + + factorfield = ones( + backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo) + ) + + outfield = ones( + backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) + ) + + testfield = full( + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo), + fill_value=26.0, + ) + + adfact(factor=factorfield, f_out=outfield) + assert outfield.any() == testfield.any() + + +def test_setvalue(): + setvalue = SetValue(stencil_factory) + + outfield = zeros( + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + ) + + testfield = full( + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + fill_value=2.0, + ) + + setvalue(f_out=outfield, value=2.0) + + assert outfield.any() == testfield.any() + + +def test_adjustdivide(): + addiv = AdjustDivide(stencil_factory) + + factorfield = full( + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + fill_value=2.0, + ) + + outfield = ones( + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + ) + + testfield = full( + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo), + fill_value=13.0, + ) + + addiv(factor=factorfield, f_out=outfield) + assert outfield.any() == testfield.any() From a15accc632a0150aa1f38bb2f724ca874d4e1aa1 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 27 Mar 2024 15:17:35 -0400 Subject: [PATCH 03/11] Updated unit test --- tests/test_basic_operations.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index 83b9f817..f90d7942 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -13,10 +13,10 @@ from ndsl.stencils import basic_operations as basic -nx = 20 -ny = 20 -nz = 79 -nhalo = 3 +nx = 5 +ny = 5 +nz = 1 +nhalo = 0 backend = "numpy" dace_config = DaceConfig( @@ -122,35 +122,35 @@ def __call__( def test_copy(): copy = Copy(stencil_factory) - infield = zeros( + infield = ones( backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) ) - outfield = ones( + outfield = zeros( backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) ) copy(f_in=infield, f_out=outfield) - assert infield.all() == outfield.all() + assert infield.any() == outfield.any() def test_adjustmentfactor(): adfact = AdjustmentFactor(stencil_factory) - factorfield = ones( - backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo) + factorfield = full( + backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo), fill_value=2.0 ) - outfield = ones( - backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) + outfield = full( + backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), fill_value=2.0 ) testfield = full( backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo), - fill_value=26.0, + fill_value=4.0, ) adfact(factor=factorfield, f_out=outfield) @@ -188,18 +188,20 @@ def test_adjustdivide(): fill_value=2.0, ) - outfield = ones( + outfield = full( backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + fill_value=4.0, ) testfield = full( backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo), - fill_value=13.0, + fill_value=2.0, ) - + addiv(factor=factorfield, f_out=outfield) + assert outfield.any() == testfield.any() From 330e52d974cd5fe449fc4c693a8228a840994408 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 27 Mar 2024 15:19:38 -0400 Subject: [PATCH 04/11] Linting --- tests/test_basic_operations.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index f90d7942..1ca1dd43 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -139,11 +139,17 @@ def test_adjustmentfactor(): adfact = AdjustmentFactor(stencil_factory) factorfield = full( - backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo), fill_value=2.0 + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo), + fill_value=2.0, ) outfield = full( - backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), fill_value=2.0 + backend=backend, + dtype=Float, + shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + fill_value=2.0, ) testfield = full( @@ -201,7 +207,7 @@ def test_adjustdivide(): shape=(nx + 2 * nhalo, ny + 2 * nhalo), fill_value=2.0, ) - + addiv(factor=factorfield, f_out=outfield) assert outfield.any() == testfield.any() From 5e4daf4a59b463c6e112244bc02c2d56fc5972f8 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Thu, 28 Mar 2024 15:35:33 -0400 Subject: [PATCH 05/11] Updated docstrings of basic_operations and using Quantity instead of gtstorage for testing --- ndsl/stencils/basic_operations.py | 52 ++++++++++++++- tests/test_basic_operations.py | 103 +++++++++++++++--------------- 2 files changed, 102 insertions(+), 53 deletions(-) diff --git a/ndsl/stencils/basic_operations.py b/ndsl/stencils/basic_operations.py index 55307c5c..b46123a3 100644 --- a/ndsl/stencils/basic_operations.py +++ b/ndsl/stencils/basic_operations.py @@ -5,7 +5,8 @@ def copy_defn(q_in: FloatField, q_out: FloatField): - """Copy q_in to q_out. + """ + Copy q_in to q_out. Args: q_in: input field @@ -16,22 +17,62 @@ def copy_defn(q_in: FloatField, q_out: FloatField): def adjustmentfactor_stencil_defn(adjustment: FloatFieldIJ, q_out: FloatField): + """ + Multiplies every element of q_out + by every element of the adjustment + field over the interval, replacing + the elements of q_out by the result + of the multiplication. + + Args: + adjustment: adjustment field + q_out: output field + """ with computation(PARALLEL), interval(...): q_out = q_out * adjustment def set_value_defn(q_out: FloatField, value: Float): + """ + Sets every element of q_out to the + value specified by value argument. + + Args: + q_out: output field + value: NDSL Float type + """ with computation(PARALLEL), interval(...): q_out = value def adjust_divide_stencil(adjustment: FloatField, q_out: FloatField): + """ + Divides every element of q_out + by every element of the adjustment + field over the interval, replacing + the elements of q_out by the result + of the multiplication. + + Args: + adjustment: adjustment field + q_out: output field + """ with computation(PARALLEL), interval(...): q_out = q_out / adjustment @gtscript.function def sign(a, b): + """ + Defines asignb as the absolute value + of a, and checks if b is positive + or negative, assigning the analogus + sign value to asignb. asignb is returned + + Args: + a: A number + b: A number + """ asignb = abs(a) if b > 0: asignb = asignb @@ -42,5 +83,14 @@ def sign(a, b): @gtscript.function def dim(a, b): + """ + Performs a check on the difference + between the values in arguments + a and b. The variable diff is set + to the difference between a and b + when the difference is positive, + otherwise it is set to zero. The + function returns the diff variable. + """ diff = a - b if a - b > 0 else 0 return diff diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index 1ca1dd43..8fd87ad4 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -1,14 +1,16 @@ -from gt4py.storage import full, ones, zeros +import numpy as np from ndsl import ( CompilationConfig, DaceConfig, DaCeOrchestration, GridIndexing, + Quantity, RunMode, StencilConfig, StencilFactory, ) +from ndsl.constants import X_DIM, Y_DIM, Z_DIM from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ from ndsl.stencils import basic_operations as basic @@ -122,92 +124,89 @@ def __call__( def test_copy(): copy = Copy(stencil_factory) - infield = ones( - backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) + infield = Quantity( + data=np.zeros([5, 5, 1]), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - outfield = zeros( - backend=backend, dtype=Float, shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz) + outfield = Quantity( + data=np.ones([5, 5, 1]), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - copy(f_in=infield, f_out=outfield) + copy(f_in=infield.data, f_out=outfield.data) - assert infield.any() == outfield.any() + assert infield.data.any() == outfield.data.any() def test_adjustmentfactor(): adfact = AdjustmentFactor(stencil_factory) - factorfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo), - fill_value=2.0, + factorfield = Quantity( + data=np.full(shape=[5, 5], fill_value=2.0), + dims=[X_DIM, Y_DIM], + units="m", ) - outfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), - fill_value=2.0, + outfield = Quantity( + data=np.full(shape=[5, 5, 1], fill_value=2.0), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - testfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo), - fill_value=4.0, + testfield = Quantity( + data=np.full(shape=[5, 5, 1], fill_value=4.0), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - adfact(factor=factorfield, f_out=outfield) - assert outfield.any() == testfield.any() + adfact(factor=factorfield.data, f_out=outfield.data) + assert outfield.data.any() == testfield.data.any() def test_setvalue(): setvalue = SetValue(stencil_factory) - outfield = zeros( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), + outfield = Quantity( + data=np.zeros(shape=[5, 5, 1]), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - testfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), - fill_value=2.0, + testfield = Quantity( + data=np.full(shape=[5, 5, 1], fill_value=2.0), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - setvalue(f_out=outfield, value=2.0) + setvalue(f_out=outfield.data, value=2.0) - assert outfield.any() == testfield.any() + assert outfield.data.any() == testfield.data.any() def test_adjustdivide(): addiv = AdjustDivide(stencil_factory) - factorfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), - fill_value=2.0, + factorfield = Quantity( + data=np.full(shape=[5, 5, 1], fill_value=2.0), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - outfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo, nz), - fill_value=4.0, + outfield = Quantity( + data=np.full(shape=[5, 5, 1], fill_value=2.0), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - testfield = full( - backend=backend, - dtype=Float, - shape=(nx + 2 * nhalo, ny + 2 * nhalo), - fill_value=2.0, + testfield = Quantity( + data=np.full(shape=[5, 5, 1], fill_value=1.0), + dims=[X_DIM, Y_DIM, Z_DIM], + units="m", ) - addiv(factor=factorfield, f_out=outfield) + addiv(factor=factorfield.data, f_out=outfield.data) - assert outfield.any() == testfield.any() + assert outfield.data.any() == testfield.data.any() From 7f0d0f0c260eb15f353f970098803c33651d133a Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Fri, 29 Mar 2024 13:46:32 -0400 Subject: [PATCH 06/11] Updated logic of unit tests in test_basic_operations.py --- tests/test_basic_operations.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index 8fd87ad4..e761b8b5 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -15,9 +15,9 @@ from ndsl.stencils import basic_operations as basic -nx = 5 -ny = 5 -nz = 1 +nx = 20 +ny = 20 +nz = 79 nhalo = 0 backend = "numpy" @@ -125,88 +125,88 @@ def test_copy(): copy = Copy(stencil_factory) infield = Quantity( - data=np.zeros([5, 5, 1]), + data=np.zeros([20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) outfield = Quantity( - data=np.ones([5, 5, 1]), + data=np.ones([20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) copy(f_in=infield.data, f_out=outfield.data) - assert infield.data.any() == outfield.data.any() + assert (infield.data == outfield.data).any() def test_adjustmentfactor(): adfact = AdjustmentFactor(stencil_factory) factorfield = Quantity( - data=np.full(shape=[5, 5], fill_value=2.0), + data=np.full(shape=[20, 20], fill_value=2.0), dims=[X_DIM, Y_DIM], units="m", ) outfield = Quantity( - data=np.full(shape=[5, 5, 1], fill_value=2.0), + data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) testfield = Quantity( - data=np.full(shape=[5, 5, 1], fill_value=4.0), + data=np.full(shape=[20, 20, 79], fill_value=4.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) adfact(factor=factorfield.data, f_out=outfield.data) - assert outfield.data.any() == testfield.data.any() + assert (outfield.data == testfield.data).any() def test_setvalue(): setvalue = SetValue(stencil_factory) outfield = Quantity( - data=np.zeros(shape=[5, 5, 1]), + data=np.zeros(shape=[20, 20, 79]), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) testfield = Quantity( - data=np.full(shape=[5, 5, 1], fill_value=2.0), + data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) setvalue(f_out=outfield.data, value=2.0) - assert outfield.data.any() == testfield.data.any() + assert (outfield.data == testfield.data).any() def test_adjustdivide(): addiv = AdjustDivide(stencil_factory) factorfield = Quantity( - data=np.full(shape=[5, 5, 1], fill_value=2.0), + data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) outfield = Quantity( - data=np.full(shape=[5, 5, 1], fill_value=2.0), + data=np.full(shape=[20, 20, 79], fill_value=2.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) testfield = Quantity( - data=np.full(shape=[5, 5, 1], fill_value=1.0), + data=np.full(shape=[20, 20, 79], fill_value=1.0), dims=[X_DIM, Y_DIM, Z_DIM], units="m", ) addiv(factor=factorfield.data, f_out=outfield.data) - assert outfield.data.any() == testfield.data.any() + assert (outfield.data == testfield.data).any() From b08654231a233571c37fd247db1d68cb927168b2 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Fri, 29 Mar 2024 15:05:01 -0400 Subject: [PATCH 07/11] Using .all() instead of .any() --- tests/test_basic_operations.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_basic_operations.py b/tests/test_basic_operations.py index e761b8b5..0d707240 100644 --- a/tests/test_basic_operations.py +++ b/tests/test_basic_operations.py @@ -138,7 +138,7 @@ def test_copy(): copy(f_in=infield.data, f_out=outfield.data) - assert (infield.data == outfield.data).any() + assert (infield.data == outfield.data).all() def test_adjustmentfactor(): @@ -163,7 +163,7 @@ def test_adjustmentfactor(): ) adfact(factor=factorfield.data, f_out=outfield.data) - assert (outfield.data == testfield.data).any() + assert (outfield.data == testfield.data).all() def test_setvalue(): @@ -183,7 +183,7 @@ def test_setvalue(): setvalue(f_out=outfield.data, value=2.0) - assert (outfield.data == testfield.data).any() + assert (outfield.data == testfield.data).all() def test_adjustdivide(): @@ -209,4 +209,4 @@ def test_adjustdivide(): addiv(factor=factorfield.data, f_out=outfield.data) - assert (outfield.data == testfield.data).any() + assert (outfield.data == testfield.data).all() From 3bf2a48bea58580adf72bf7c6a2daf451f46a023 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sat, 13 Apr 2024 13:26:15 -0400 Subject: [PATCH 08/11] Allow pressure coeff to be read from ak & bk data --- ndsl/grid/eta.py | 54 ++++++++++++++++++++++++----------------- ndsl/grid/generation.py | 18 +++++++++++--- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/ndsl/grid/eta.py b/ndsl/grid/eta.py index 19663fde..4d7473e9 100644 --- a/ndsl/grid/eta.py +++ b/ndsl/grid/eta.py @@ -1,6 +1,7 @@ import math import os from dataclasses import dataclass +from typing import Optional, Tuple import numpy as np import xarray as xr @@ -28,8 +29,25 @@ class HybridPressureCoefficients: bk: np.ndarray +def _load_ak_bk_from_file(eta_file: str) -> Tuple[np.ndarray, np.ndarray]: + if eta_file == "None": + raise ValueError("eta file not specified") + if not os.path.isfile(eta_file): + raise ValueError("file " + eta_file + " does not exist") + + # read file into ak, bk arrays + data = xr.open_dataset(eta_file) + ak = data["ak"].values + bk = data["bk"].values + + return ak, bk + + def set_hybrid_pressure_coefficients( - km: int, eta_file: str + km: int, + eta_file: str, + ak_data: Optional[np.ndarray] = None, + bk_data: Optional[np.ndarray] = None, ) -> HybridPressureCoefficients: """ Sets the coefficients describing the hybrid pressure coordinates. @@ -44,25 +62,19 @@ def set_hybrid_pressure_coefficients( Returns: a HybridPressureCoefficients dataclass """ - - if eta_file == "None": - raise ValueError("eta file not specified") - if not os.path.isfile(eta_file): - raise ValueError("file " + eta_file + " does not exist") - - # read file into ak, bk arrays - data = xr.open_dataset(eta_file) - ak = data["ak"].values - bk = data["bk"].values + if ak_data is None or bk_data is None: + ak, bk = _load_ak_bk_from_file(eta_file) + else: + ak, bk = ak_data, bk_data # check size of ak and bk array is km+1 if ak.size - 1 != km: - raise ValueError(f"size of ak array is not equal to km={km}") + raise ValueError(f"size of ak array {ak.size} is not equal to km+1={km+1}") if bk.size - 1 != km: - raise ValueError(f"size of bk array is not equal to km={km}") + raise ValueError(f"size of bk array {ak.size} is not equal to km+1={km+1}") # check that the eta values computed from ak and bk are monotonically increasing - eta, etav = check_eta(ak, bk) + eta, etav = _check_eta(ak, bk) if not np.all(eta[:-1] <= eta[1:]): raise ValueError("ETA values are not monotonically increasing") @@ -75,12 +87,10 @@ def set_hybrid_pressure_coefficients( else: raise ValueError("bk must contain at least one 0.") - pressure_data = HybridPressureCoefficients(ks, ptop, ak, bk) - - return pressure_data + return HybridPressureCoefficients(ks, ptop, ak, bk) -def vertical_coordinate(eta_value): +def _vertical_coordinate(eta_value) -> np.ndarray: """ Equation (1) JRMS2006 computes eta_v, the auxiliary variable vertical coordinate @@ -88,15 +98,15 @@ def vertical_coordinate(eta_value): return (eta_value - ETA_0) * math.pi * 0.5 -def compute_eta(ak, bk): +def _compute_eta(ak, bk) -> Tuple[np.ndarray, np.ndarray]: """ Equation (1) JRMS2006 eta is the vertical coordinate and eta_v is an auxiliary vertical coordinate """ eta = 0.5 * ((ak[:-1] + ak[1:]) / SURFACE_PRESSURE + bk[:-1] + bk[1:]) - eta_v = vertical_coordinate(eta) + eta_v = _vertical_coordinate(eta) return eta, eta_v -def check_eta(ak, bk): - return compute_eta(ak, bk) +def _check_eta(ak, bk): + return _compute_eta(ak, bk) diff --git a/ndsl/grid/generation.py b/ndsl/grid/generation.py index 12275d7d..275db563 100644 --- a/ndsl/grid/generation.py +++ b/ndsl/grid/generation.py @@ -1,7 +1,7 @@ import dataclasses import functools import warnings -from typing import Tuple +from typing import Optional, Tuple import numpy as np @@ -238,6 +238,8 @@ def __init__( deglat: float = 15.0, extdgrid: bool = False, eta_file: str = "None", + ak: Optional[np.ndarray] = None, + bk: Optional[np.ndarray] = None, ): self._grid_type = grid_type self._dx_const = dx_const @@ -300,7 +302,7 @@ def __init__( self._ptop, self._ak, self._bk, - ) = self._set_hybrid_pressure_coefficients(eta_file) + ) = self._set_hybrid_pressure_coefficients(eta_file, ak, bk) self._ec1 = None self._ec2 = None self._ew1 = None @@ -2147,7 +2149,12 @@ def _compute_area_c_cartesian(self): area_cgrid_64.data[:, :] = self._dx_const * self._dy_const return quantity_cast_to_model_float(self.quantity_factory, area_cgrid_64) - def _set_hybrid_pressure_coefficients(self, eta_file): + def _set_hybrid_pressure_coefficients( + self, + eta_file, + ak_data: Optional[np.ndarray] = None, + bk_data: Optional[np.ndarray] = None, + ): ks = self.quantity_factory.zeros( [], "", @@ -2169,7 +2176,10 @@ def _set_hybrid_pressure_coefficients(self, eta_file): dtype=Float, ) pressure_coefficients = eta.set_hybrid_pressure_coefficients( - self._npz, eta_file + self._npz, + eta_file, + ak_data, + bk_data, ) ks = pressure_coefficients.ks ptop = pressure_coefficients.ptop From 759c85272a0f2963dc2021354e71c39fb2d8edc3 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 16 Apr 2024 10:48:12 -0400 Subject: [PATCH 09/11] Freezing mpi4py to 3.1.6 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d81523d1..d54db73d 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ def local_pkg(name: str, relative_path: str) -> str: requirements: List[str] = [ local_pkg("gt4py", "external/gt4py"), local_pkg("dace", "external/dace"), - "mpi4py", + "mpi4py==3.1.5", "cftime", "xarray", "f90nml>=1.1.0", From 6a8b67436a33e0af8b437162f1edeb10432c8539 Mon Sep 17 00:00:00 2001 From: Frank Malatino Date: Wed, 17 Apr 2024 11:30:55 -0400 Subject: [PATCH 10/11] Added reference to RC in setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d54db73d..bb24006a 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,6 @@ def local_pkg(name: str, relative_path: str) -> str: packages=find_namespace_packages(include=["ndsl", "ndsl.*"]), include_package_data=True, url="https://github.com/NOAA-GFDL/NDSL", - version="2024.03.01", + version="2024.04.00-RC", zip_safe=False, ) From 044155b99352921c3fdb204e1d23416409410248 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 17 Apr 2024 12:29:58 -0400 Subject: [PATCH 11/11] Revert public API changes done in PR34 --- ndsl/grid/eta.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ndsl/grid/eta.py b/ndsl/grid/eta.py index 4d7473e9..90db8c4a 100644 --- a/ndsl/grid/eta.py +++ b/ndsl/grid/eta.py @@ -90,7 +90,7 @@ def set_hybrid_pressure_coefficients( return HybridPressureCoefficients(ks, ptop, ak, bk) -def _vertical_coordinate(eta_value) -> np.ndarray: +def vertical_coordinate(eta_value) -> np.ndarray: """ Equation (1) JRMS2006 computes eta_v, the auxiliary variable vertical coordinate @@ -98,15 +98,15 @@ def _vertical_coordinate(eta_value) -> np.ndarray: return (eta_value - ETA_0) * math.pi * 0.5 -def _compute_eta(ak, bk) -> Tuple[np.ndarray, np.ndarray]: +def compute_eta(ak, bk) -> Tuple[np.ndarray, np.ndarray]: """ Equation (1) JRMS2006 eta is the vertical coordinate and eta_v is an auxiliary vertical coordinate """ eta = 0.5 * ((ak[:-1] + ak[1:]) / SURFACE_PRESSURE + bk[:-1] + bk[1:]) - eta_v = _vertical_coordinate(eta) + eta_v = vertical_coordinate(eta) return eta, eta_v def _check_eta(ak, bk): - return _compute_eta(ak, bk) + return compute_eta(ak, bk)