Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to reflect NDSL import format #13

Merged
merged 9 commits into from
Mar 20, 2024
24 changes: 15 additions & 9 deletions examples/notebook/test_functionality.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,21 @@
"from gt4py.cartesian.gtscript import PARALLEL, computation, interval\n",
"\n",
"from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ\n",
"from ndsl.comm.communicator import Communicator, CubedSphereCommunicator\n",
"from ndsl.dsl.stencil import StencilFactory, GridIndexing\n",
"from ndsl.initialization import SubtileGridSizer\n",
"from ndsl.initialization.allocator import QuantityFactory\n",
"from ndsl.quantity import Quantity\n",
"from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner\n",
"from ndsl.constants import X_DIM, Y_DIM, Z_DIM\n",
"from ndsl.dsl.stencil_config import CompilationConfig, StencilConfig\n",
"from ndsl.dsl.dace.wrapped_halo_exchange import WrappedHaloUpdater"
"from ndsl import (\n",
" CompilationConfig,\n",
" CubedSphereCommunicator,\n",
" CubedSpherePartitioner,\n",
" GridIndexing,\n",
" Quantity,\n",
" QuantityFactory,\n",
" StencilConfig,\n",
" StencilFactory,\n",
" SubtileGridSizer,\n",
" TilePartitioner,\n",
" WrappedHaloUpdater,\n",
")\n",
"from ndsl.typing import Communicator\n",
"from ndsl.constants import X_DIM, Y_DIM, Z_DIM"
]
},
{
Expand Down
42 changes: 25 additions & 17 deletions examples/standalone/runfile/acoustics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,21 @@
import yaml
from timing import collect_data_and_write_to_file

import ndsl.dsl
import ndsl.util as util
from ndsl.comm.null_comm import NullComm
from ndsl.dsl.dace.orchestration import DaceConfig
from ndsl.dsl.stencil import CompilationConfig
from ndsl.stencils.testing.grid import Grid
from pyFV3 import AcousticDynamics, DynamicalCoreConfig, TranslateDynCore
from ndsl import (
CompilationConfig,
CubedSphereCommunicator,
CubedSpherePartitioner,
DaceConfig,
NullComm,
StencilConfig,
StencilFactory,
TilePartitioner,
)
from ndsl.performance import Timer
from ndsl.stencils.testing import Grid
from pyFV3 import DynamicalCoreConfig
from pyFV3.stencils import AcousticDynamics
from pyFV3.testing import TranslateDynCore


try:
Expand Down Expand Up @@ -46,7 +54,7 @@ def initialize_serializer(data_directory: str, rank: int = 0) -> serialbox.Seria
def read_input_data(
grid: Grid,
namelist: DynamicalCoreConfig,
stencil_factory: ndsl.dsl.stencil.StencilFactory,
stencil_factory: StencilFactory,
serializer: serialbox.Serializer,
) -> Dict[str, Any]:
"""Uses the serializer to read the input data from disk"""
Expand Down Expand Up @@ -81,17 +89,17 @@ def get_state_from_input(
def set_up_communicator(
disable_halo_exchange: bool,
layout: Tuple[int, int],
) -> Tuple[Optional[MPI.Comm], Optional[util.CubedSphereCommunicator]]:
partitioner = util.CubedSpherePartitioner(util.TilePartitioner(layout))
) -> Tuple[Optional[MPI.Comm], Optional[CubedSphereCommunicator]]:
partitioner = CubedSpherePartitioner(TilePartitioner(layout))
if MPI is not None:
comm = MPI.COMM_WORLD
else:
comm = None
if not disable_halo_exchange:
assert comm is not None
cube_comm = util.CubedSphereCommunicator(comm, partitioner)
cube_comm = CubedSphereCommunicator(comm, partitioner)
else:
cube_comm = util.CubedSphereCommunicator(NullComm(0, 0), partitioner)
cube_comm = CubedSphereCommunicator(NullComm(0, 0), partitioner)
return comm, cube_comm


Expand All @@ -106,10 +114,10 @@ def get_experiment_name(
)["experiment_name"]


def initialize_timers() -> Tuple[util.Timer, util.Timer, List, List]:
total_timer = util.Timer()
def initialize_timers() -> Tuple[Timer, Timer, List, List]:
total_timer = Timer()
total_timer.start("total")
timestep_timer = util.Timer()
timestep_timer = Timer()
return total_timer, timestep_timer, [], []


Expand Down Expand Up @@ -148,13 +156,13 @@ def driver(
tile_nx=dycore_config.npx,
tile_nz=dycore_config.npz,
)
stencil_config = ndsl.dsl.stencil.StencilConfig(
stencil_config = StencilConfig(
compilation_config=CompilationConfig(
backend=backend, rebuild=False, validate_args=True
),
dace_config=dace_config,
)
stencil_factory = ndsl.dsl.stencil.StencilFactory(
stencil_factory = StencilFactory(
config=stencil_config,
grid_indexing=grid.grid_indexing,
)
Expand Down
3 changes: 1 addition & 2 deletions examples/standalone/runfile/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import f90nml
import gt4py.cartesian.config

import ndsl.dsl.stencil # noqa: F401
from ndsl.comm.null_comm import NullComm
from ndsl import NullComm
from pyFV3 import DynamicalCoreConfig


Expand Down
36 changes: 20 additions & 16 deletions examples/standalone/runfile/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,22 @@
# NOTE: we need to import dsl.stencil prior to
# ndsl.util, otherwise xarray precedes gt4py, causing
# very strange errors on some systems (e.g. daint)
import ndsl.dsl.stencil
import ndsl.util as util
from ndsl.comm.null_comm import NullComm
from ndsl.dsl import StencilFactory
from ndsl.dsl.dace.orchestration import DaceConfig
from ndsl import (
CompilationConfig,
CubedSphereCommunicator,
CubedSpherePartitioner,
DaceConfig,
NullComm,
StencilConfig,
StencilFactory,
TilePartitioner,
)
from ndsl.grid import DampingCoefficients, GridData, MetricTerms
from ndsl.stencils.testing import dataset_to_dict
from ndsl.stencils.testing.grid import Grid
from pyFV3 import DycoreState, DynamicalCore, DynamicalCoreConfig, TranslateFVDynamics
from ndsl.performance import Timer
from ndsl.stencils.testing import Grid, dataset_to_dict
from pyFV3 import DycoreState, DynamicalCore, DynamicalCoreConfig
from pyFV3.initialization.test_cases import init_baroclinic_state
from pyFV3.testing import TranslateFVDynamics


def parse_args() -> Namespace:
Expand Down Expand Up @@ -208,10 +214,8 @@ def setup_dycore(
dycore_config, mpi_comm, backend, is_baroclinic_test_case, data_dir
) -> Tuple[DynamicalCore, DycoreState, StencilFactory]:
# set up grid-dependent helper structures
partitioner = util.CubedSpherePartitioner(
util.TilePartitioner(dycore_config.layout)
)
communicator = util.CubedSphereCommunicator(mpi_comm, partitioner)
partitioner = CubedSpherePartitioner(TilePartitioner(dycore_config.layout))
communicator = CubedSphereCommunicator(mpi_comm, partitioner)
grid = Grid.from_namelist(dycore_config, mpi_comm.rank, backend)

dace_config = DaceConfig(
Expand All @@ -220,8 +224,8 @@ def setup_dycore(
tile_nx=dycore_config.npx,
tile_nz=dycore_config.npz,
)
stencil_config = ndsl.dsl.stencil.StencilConfig(
compilation_config=ndsl.dsl.stencil.CompilationConfig(
stencil_config = StencilConfig(
compilation_config=CompilationConfig(
backend=backend, rebuild=False, validate_args=False
),
dace_config=dace_config,
Expand Down Expand Up @@ -265,7 +269,7 @@ def setup_dycore(


if __name__ == "__main__":
timer = util.Timer()
timer = Timer()
timer.start("total")
with timer.clock("initialization"):
args = parse_args()
Expand Down Expand Up @@ -308,7 +312,7 @@ def setup_dycore(
hits_per_step = []
# we set up a specific timer for each timestep
# that is cleared after so we get individual statistics
timestep_timer = util.Timer()
timestep_timer = Timer()
for i in range(args.time_step - 1):
with timestep_timer.clock("mainloop"):
if rank == 0:
Expand Down
1 change: 0 additions & 1 deletion pyFV3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
DycoreState: Dataclass containing state of the dynamical core
DryConvectiveAdjustment: Sub-grid dry convective adjustment
DynamicalCore: The FV3 dynamical core
GeosDycoreWrapper: Interface to the dycore for the GEOS model
"""

__version__ = "0.2.0"
6 changes: 2 additions & 4 deletions pyFV3/dycore_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import xarray as xr

import ndsl.dsl.gt4py_utils as gt_utils
from ndsl.comm.communicator import Communicator
from ndsl import GridSizer, Quantity, QuantityFactory
from ndsl.constants import (
X_DIM,
X_INTERFACE_DIM,
Expand All @@ -14,10 +14,8 @@
Z_INTERFACE_DIM,
)
from ndsl.dsl.typing import Float
from ndsl.initialization.allocator import QuantityFactory
from ndsl.initialization.sizer import GridSizer
from ndsl.quantity import Quantity
from ndsl.restart._legacy_restart import open_restart
from ndsl.typing import Communicator


@dataclass()
Expand Down
5 changes: 2 additions & 3 deletions pyFV3/initialization/analytic_init.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from enum import Enum

from ndsl.comm.communicator import Communicator, CubedSphereCommunicator
from ndsl import CubedSphereCommunicator, MetaEnumStr, QuantityFactory
from ndsl.grid import GridData
from ndsl.initialization.allocator import QuantityFactory
from ndsl.utils import MetaEnumStr
from ndsl.typing import Communicator
from pyFV3.dycore_state import DycoreState


Expand Down
7 changes: 5 additions & 2 deletions pyFV3/initialization/init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@

import ndsl.constants as constants
from ndsl.dsl.typing import Float
from ndsl.grid import lon_lat_midpoint
from ndsl.grid.eta import SURFACE_PRESSURE, compute_eta, vertical_coordinate
from ndsl.grid.gnomonic import get_lonlat_vect, get_unit_vector_direction
from ndsl.grid.gnomonic import (
get_lonlat_vect,
get_unit_vector_direction,
lon_lat_midpoint,
)
from pyFV3.dycore_state import DycoreState


Expand Down
8 changes: 4 additions & 4 deletions pyFV3/initialization/test_cases/initialize_baroclinic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

import ndsl.constants as constants
import ndsl.dsl.gt4py_utils as utils
import pyFV3.initialization.init_utils as init_utils
from ndsl.comm.communicator import CubedSphereCommunicator
from ndsl.grid import GridData, great_circle_distance_lon_lat, lon_lat_midpoint
from ndsl.initialization.allocator import QuantityFactory
from ndsl import CubedSphereCommunicator, QuantityFactory
from ndsl.grid import GridData
from ndsl.grid.gnomonic import great_circle_distance_lon_lat, lon_lat_midpoint
from pyFV3.dycore_state import DycoreState
from pyFV3.initialization import init_utils


# maximum windspeed amplitude - close to windspeed of zonal-mean time-mean
Expand Down
5 changes: 2 additions & 3 deletions pyFV3/initialization/test_cases/initialize_tc.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import numpy as np

import ndsl.constants as constants
import pyFV3.initialization.init_utils as init_utils
from ndsl.comm.communicator import CubedSphereCommunicator
from ndsl import CubedSphereCommunicator, QuantityFactory
from ndsl.grid import GridData, great_circle_distance_lon_lat
from ndsl.initialization.allocator import QuantityFactory
from pyFV3.dycore_state import DycoreState
from pyFV3.initialization import init_utils


def _calculate_distance_from_tc_center(pe_v, ps_v, muv, calc, tc_properties):
Expand Down
4 changes: 1 addition & 3 deletions pyFV3/stencils/a2b_ord4.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@
sqrt,
)

from ndsl import GridIndexing, QuantityFactory, StencilFactory, orchestrate
from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from ndsl.dsl.dace.orchestration import orchestrate
from ndsl.dsl.stencil import GridIndexing, StencilFactory
from ndsl.dsl.typing import Float, FloatField, FloatFieldI, FloatFieldIJ
from ndsl.grid import GridData
from ndsl.initialization.allocator import QuantityFactory
from pyFV3.stencils.basic_operations import copy_defn


Expand Down
5 changes: 1 addition & 4 deletions pyFV3/stencils/c_sw.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@
region,
)

from ndsl import Quantity, QuantityFactory, StencilFactory, orchestrate
from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from ndsl.dsl.dace.orchestration import orchestrate
from ndsl.dsl.stencil import StencilFactory
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ
from ndsl.grid import GridData
from ndsl.initialization.allocator import QuantityFactory
from ndsl.quantity import Quantity
from ndsl.stencils import corners
from pyFV3.stencils.d2a2c_vect import DGrid2AGrid2CGridVectors

Expand Down
4 changes: 1 addition & 3 deletions pyFV3/stencils/d2a2c_vect.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import gt4py.cartesian.gtscript as gtscript
from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region

from ndsl import QuantityFactory, StencilFactory, orchestrate
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.dace.orchestration import orchestrate
from ndsl.dsl.stencil import StencilFactory
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ
from ndsl.grid import GridData
from ndsl.initialization.allocator import QuantityFactory
from ndsl.stencils import corners
from pyFV3.stencils.a2b_ord4 import a1, a2, lagrange_x_func, lagrange_y_func

Expand Down
7 changes: 2 additions & 5 deletions pyFV3/stencils/d_sw.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@
region,
)

import pyFV3.stencils.delnflux as delnflux
from ndsl import Quantity, QuantityFactory, StencilFactory, orchestrate
from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from ndsl.dsl.dace.orchestration import orchestrate
from ndsl.dsl.stencil import StencilFactory
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, FloatFieldK
from ndsl.grid import DampingCoefficients, GridData
from ndsl.initialization.allocator import QuantityFactory
from ndsl.quantity import Quantity
from pyFV3._config import DGridShallowWaterLagrangianDynamicsConfig
from pyFV3.stencils import delnflux
from pyFV3.stencils.d2a2c_vect import contravariant
from pyFV3.stencils.delnflux import DelnFluxNoSG
from pyFV3.stencils.divergence_damping import DivergenceDamping
Expand Down
5 changes: 2 additions & 3 deletions pyFV3/stencils/del2cubed.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of scope of this PR, but some part of me is wondering if we could fold the gt4py.cartesian.gtscript import in a ndsl.gtscript. The main reason would be to keep a single point of entry for users and remove the confusion around the frameworks

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed on having ndsl as the single point of entry for pyFV3 and pySHiELD.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logged in NOAA-GFDL/NDSL#29


import ndsl.stencils.corners as corners
from ndsl import QuantityFactory, StencilFactory, orchestrate
from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from ndsl.dsl.dace.orchestration import orchestrate
from ndsl.dsl.stencil import StencilFactory, get_stencils_with_varied_bounds
from ndsl.dsl.stencil import get_stencils_with_varied_bounds
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, cast_to_index3d
from ndsl.grid import DampingCoefficients
from ndsl.initialization.allocator import QuantityFactory
from pyFV3.stencils.basic_operations import copy_defn


Expand Down
6 changes: 2 additions & 4 deletions pyFV3/stencils/delnflux.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
import gt4py.cartesian.gtscript as gtscript
from gt4py.cartesian.gtscript import PARALLEL, computation, horizontal, interval, region

from ndsl import Quantity, QuantityFactory, StencilFactory, orchestrate
from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from ndsl.dsl.dace.orchestration import orchestrate
from ndsl.dsl.stencil import StencilFactory, get_stencils_with_varied_bounds
from ndsl.dsl.stencil import get_stencils_with_varied_bounds
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, FloatFieldK
from ndsl.grid import DampingCoefficients
from ndsl.initialization.allocator import QuantityFactory
from ndsl.quantity import Quantity


def calc_damp(damp_c: Quantity, da_min: Float, nord: Quantity) -> Quantity:
Expand Down
5 changes: 2 additions & 3 deletions pyFV3/stencils/divergence_damping.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@

import ndsl.stencils.corners as corners
import pyFV3.stencils.basic_operations as basic
from ndsl import Quantity, QuantityFactory, StencilFactory
from ndsl.constants import X_DIM, X_INTERFACE_DIM, Y_DIM, Y_INTERFACE_DIM, Z_DIM
from ndsl.dsl.dace.orchestration import dace_inhibitor, orchestrate
from ndsl.dsl.stencil import StencilFactory, get_stencils_with_varied_bounds
from ndsl.dsl.stencil import get_stencils_with_varied_bounds
from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ, FloatFieldK
from ndsl.grid import DampingCoefficients, GridData
from ndsl.initialization.allocator import QuantityFactory
from ndsl.quantity import Quantity
from pyFV3.stencils.a2b_ord4 import AGrid2BGridFourthOrder, doubly_periodic_a2b_ord4
from pyFV3.stencils.d2a2c_vect import contravariant

Expand Down
Loading
Loading