Skip to content

Commit

Permalink
Feature/doubly_periodic_dycore (#24)
Browse files Browse the repository at this point in the history
* initial commit, first version of d2a2c_vect

* doubly periodic implementation for a2b_ord4

* doubly-periodic implementations of update_dwinds_physics and updatedzc

* fixing domains, initial dp xppm, yppm, xyp, ytp, divergence_corner, c_sw

* d_sw, smag_corner initial doubly periodic config done

* removed asserts, initial doubly periodic grid should be supported?

* c2l and some config cleanup

* maybe this will work for the driver?

* add umax to grid_config

* updating namelist, adding test config for driver init

* debugging driver init with dp grid

* fix varname

* rework grid type to be in grid config

* test fixes

* bugfixes

* fixing dp a2b

* need to disable a2b_ord4 test for gridtype 4, exploring more of d2a2c

* workaround for d2a2c on dp domain

* remove breakpoint

* add attrs to divergence damping

* correcting types

* small cleanup

* changing type enforcement on communicators, mocking single rank exchange for c2l

* prolly not gonna push, making one rank tests work

* why test no work

* undo silly

* reconfigure tests for doubly periodic domains

* fixing replace issue

* linting

* a2b fix

* fixing definition for a2b doubly periodic stencil

* trying explicit dp_a2b in nh_p_grad

* Revert "fixing definition for a2b doubly periodic stencil" --doesnt work

This reverts commit 68a86ec.

* actually reverting changes

* fixing size in a2b

* type fix for delnflux

* messing with corner copies

* didn't work

* re-adding nord round fix

* update history

* fixing physics/dycore interface grid type handling

* update util history

* updating one more call

* initial review cleanup

* try to undo gt4py change again

* undoing stencil changes rq

* update calls in notebooks

* updating logs and documentation

* fixing serialized initialization test

* updating explainer for dpa2b
  • Loading branch information
oelbert authored Oct 11, 2023
1 parent 1b91c76 commit e0e7e90
Show file tree
Hide file tree
Showing 59 changed files with 1,601 additions and 1,078 deletions.
37 changes: 25 additions & 12 deletions driver/pace/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@

# TODO: move update_atmos_state into pace.driver
from pace.stencils import update_atmos_state
from pace.util.communicator import CubedSphereCommunicator
from pace.util.communicator import (
Communicator,
CubedSphereCommunicator,
TileCommunicator,
)
from pace.util.logging import pace_log

from . import diagnostics
Expand Down Expand Up @@ -90,6 +94,7 @@ class DriverConfig:
nz: int
layout: Tuple[int, int]
dt_atmos: float
grid_type: Optional[int] = 0
grid_config: GridInitializerSelector = dataclasses.field(
default_factory=lambda: GridInitializerSelector(
type="generated", config=GeneratedGridConfig()
Expand Down Expand Up @@ -158,7 +163,7 @@ def apply_tendencies(self) -> bool:

def get_grid(
self,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
quantity_factory: Optional[pace.util.QuantityFactory] = None,
) -> Tuple[
pace.util.grid.DampingCoefficients,
Expand Down Expand Up @@ -187,7 +192,7 @@ def get_grid(

def get_driver_state(
self,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand All @@ -213,7 +218,7 @@ def get_driver_state(
if stencil_factory is None:
grid_indexing = (
pace.dsl.stencil.GridIndexing.from_sizer_and_communicator(
sizer=sizer, cube=communicator
sizer=sizer, comm=communicator
)
)
stencil_factory = pace.dsl.StencilFactory(
Expand Down Expand Up @@ -407,11 +412,19 @@ def __init__(
if self.config.performance_config.collect_communication
else None
)
communicator = CubedSphereCommunicator.from_layout(
comm=self.comm,
layout=self.config.layout,
timer=comm_timer,
)
communicator: Communicator
if self.config.grid_type <= 3:
communicator = CubedSphereCommunicator.from_layout(
comm=self.comm,
layout=self.config.layout,
timer=comm_timer,
)
else:
communicator = TileCommunicator.from_layout(
comm=self.comm,
layout=self.config.layout,
timer=comm_timer,
)
self._update_driver_config_with_communicator(communicator)

if self.config.stencil_config.compilation_config.run_mode == RunMode.Build:
Expand Down Expand Up @@ -547,7 +560,7 @@ def exit_instead_of_build(self):
pace_log.info("initialization of the object done")

def _update_driver_config_with_communicator(
self, communicator: CubedSphereCommunicator
self, communicator: Communicator
) -> None:
dace_config = DaceConfig(
communicator=communicator,
Expand Down Expand Up @@ -710,7 +723,7 @@ def log_subtile_location(partitioner: pace.util.TilePartitioner, rank: int):

def _setup_factories(
config: DriverConfig,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
stencil_compare_comm,
) -> Tuple[pace.util.QuantityFactory, pace.dsl.StencilFactory]:
"""
Expand Down Expand Up @@ -738,7 +751,7 @@ def _setup_factories(
)

grid_indexing = pace.dsl.stencil.GridIndexing.from_sizer_and_communicator(
sizer=sizer, cube=communicator
sizer=sizer, comm=communicator
)
quantity_factory = pace.util.QuantityFactory.from_backend(
sizer, backend=config.stencil_config.compilation_config.backend
Expand Down
14 changes: 7 additions & 7 deletions driver/pace/driver/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pace.stencils
import pace.util.grid
from pace.stencils.testing import TranslateGrid
from pace.util import CubedSphereCommunicator, QuantityFactory
from pace.util import Communicator, QuantityFactory
from pace.util.grid import (
DampingCoefficients,
DriverGridData,
Expand All @@ -35,7 +35,7 @@ class GridInitializer(abc.ABC):
def get_grid(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
) -> Tuple[DampingCoefficients, DriverGridData, GridData]:
...

Expand All @@ -62,7 +62,7 @@ def register(cls, type_name):
def get_grid(
self,
quantity_factory: QuantityFactory,
communicator: CubedSphereCommunicator,
communicator: Communicator,
) -> Tuple[DampingCoefficients, DriverGridData, GridData]:
return self.config.get_grid(
quantity_factory=quantity_factory, communicator=communicator
Expand Down Expand Up @@ -103,7 +103,7 @@ class GeneratedGridConfig(GridInitializer):
def get_grid(
self,
quantity_factory: QuantityFactory,
communicator: CubedSphereCommunicator,
communicator: Communicator,
) -> Tuple[DampingCoefficients, DriverGridData, GridData]:
metric_terms = MetricTerms(
quantity_factory=quantity_factory,
Expand Down Expand Up @@ -157,7 +157,7 @@ def _f90_namelist(self) -> f90nml.Namelist:
def _namelist(self) -> Namelist:
return Namelist.from_f90nml(self._f90_namelist)

def _serializer(self, communicator: pace.util.CubedSphereCommunicator):
def _serializer(self, communicator: pace.util.Communicator):
import serialbox

serializer = serialbox.Serializer(
Expand All @@ -169,7 +169,7 @@ def _serializer(self, communicator: pace.util.CubedSphereCommunicator):

def _get_serialized_grid(
self,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
backend: str,
) -> pace.stencils.testing.grid.Grid: # type: ignore
ser = self._serializer(communicator)
Expand All @@ -181,7 +181,7 @@ def _get_serialized_grid(
def get_grid(
self,
quantity_factory: QuantityFactory,
communicator: CubedSphereCommunicator,
communicator: Communicator,
) -> Tuple[DampingCoefficients, DriverGridData, GridData]:
backend = quantity_factory.zeros(
dims=[pace.util.X_DIM, pace.util.Y_DIM], units="unknown"
Expand Down
20 changes: 10 additions & 10 deletions driver/pace/driver/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def start_time(self) -> datetime:
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down Expand Up @@ -73,7 +73,7 @@ def start_time(self) -> datetime:
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down Expand Up @@ -105,7 +105,7 @@ class AnalyticInit(Initializer):
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down Expand Up @@ -148,7 +148,7 @@ class RestartInit(Initializer):
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down Expand Up @@ -197,7 +197,7 @@ def start_time(self) -> datetime:
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down Expand Up @@ -246,7 +246,7 @@ def _namelist(self) -> Namelist:

def _get_serialized_grid(
self,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
backend: str,
) -> pace.stencils.testing.grid.Grid: # type: ignore
ser = self._serializer(communicator)
Expand All @@ -255,7 +255,7 @@ def _get_serialized_grid(
).python_grid()
return grid

def _serializer(self, communicator: pace.util.CubedSphereCommunicator):
def _serializer(self, communicator: pace.util.Communicator):
import serialbox

serializer = serialbox.Serializer(
Expand All @@ -268,7 +268,7 @@ def _serializer(self, communicator: pace.util.CubedSphereCommunicator):
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand All @@ -295,7 +295,7 @@ def get_driver_state(

def _initialize_dycore_state(
self,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
backend: str,
) -> fv3core.DycoreState:
grid = self._get_serialized_grid(communicator=communicator, backend=backend)
Expand Down Expand Up @@ -345,7 +345,7 @@ class PredefinedStateInit(Initializer):
def get_driver_state(
self,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down
4 changes: 2 additions & 2 deletions driver/pace/driver/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def load_state_from_restart(
grid_data: pace.util.grid.GridData,
) -> "DriverState":
comm = driver_config.comm_config.get_comm()
communicator = pace.util.CubedSphereCommunicator.from_layout(
communicator = pace.util.Communicator.from_layout(
comm=comm, layout=driver_config.layout
)
sizer = pace.util.SubtileGridSizer.from_tile_params(
Expand Down Expand Up @@ -172,7 +172,7 @@ def _restart_driver_state(
path: str,
rank: int,
quantity_factory: pace.util.QuantityFactory,
communicator: pace.util.CubedSphereCommunicator,
communicator: pace.util.Communicator,
damping_coefficients: pace.util.grid.DampingCoefficients,
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
Expand Down
4 changes: 2 additions & 2 deletions dsl/pace/dsl/caches/cache_location.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pace.dsl.caches.codepath import FV3CodePath
from pace.util import CubedSpherePartitioner
from pace.util import Partitioner


def identify_code_path(
rank: int,
partitioner: CubedSpherePartitioner,
partitioner: Partitioner,
) -> FV3CodePath:
if partitioner.layout == (1, 1) or partitioner.layout == [1, 1]:
return FV3CodePath.All
Expand Down
8 changes: 4 additions & 4 deletions dsl/pace/dsl/dace/dace_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pace.dsl.caches.codepath import FV3CodePath
from pace.dsl.gt4py_utils import is_gpu_backend
from pace.util._optional_imports import cupy as cp
from pace.util.communicator import CubedSphereCommunicator, CubedSpherePartitioner
from pace.util.communicator import Communicator, Partitioner


# This can be turned on to revert compilation for orchestration
Expand All @@ -19,7 +19,7 @@
DEACTIVATE_DISTRIBUTED_DACE_COMPILE = False


def _is_corner(rank: int, partitioner: CubedSpherePartitioner) -> bool:
def _is_corner(rank: int, partitioner: Partitioner) -> bool:
if partitioner.tile.on_tile_bottom(rank):
if partitioner.tile.on_tile_left(rank):
return True
Expand Down Expand Up @@ -55,7 +55,7 @@ def _smallest_rank_middle(x: int, y: int, layout: Tuple[int, int]):

def _determine_compiling_ranks(
config: "DaceConfig",
partitioner: CubedSpherePartitioner,
partitioner: Partitioner,
) -> bool:
"""
We try to map every layout to a 3x3 layout which MPI ranks
Expand Down Expand Up @@ -149,7 +149,7 @@ def __call__(self):
class DaceConfig:
def __init__(
self,
communicator: Optional[CubedSphereCommunicator],
communicator: Optional[Communicator],
backend: str,
tile_nx: int = 0,
tile_nz: int = 0,
Expand Down
4 changes: 2 additions & 2 deletions dsl/pace/dsl/dace/wrapped_halo_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Optional

from pace.dsl.dace.orchestration import dace_inhibitor
from pace.util.communicator import CubedSphereCommunicator
from pace.util.communicator import Communicator
from pace.util.halo_updater import HaloUpdater


Expand All @@ -21,7 +21,7 @@ def __init__(
state,
qty_x_names: List[str],
qty_y_names: List[str] = None,
comm: Optional[CubedSphereCommunicator] = None,
comm: Optional[Communicator] = None,
) -> None:
self._updater = updater
self._state = state
Expand Down
10 changes: 5 additions & 5 deletions dsl/pace/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,18 +595,18 @@ def domain(self, domain):

@classmethod
def from_sizer_and_communicator(
cls, sizer: pace.util.GridSizer, cube: pace.util.CubedSphereCommunicator
cls, sizer: pace.util.GridSizer, comm: pace.util.Communicator
) -> "GridIndexing":
# TODO: if this class is refactored to split off the *_edge booleans,
# this init routine can be refactored to require only a GridSizer
domain = cast(
Tuple[int, int, int],
sizer.get_extent([pace.util.X_DIM, pace.util.Y_DIM, pace.util.Z_DIM]),
)
south_edge = cube.tile.partitioner.on_tile_bottom(cube.rank)
north_edge = cube.tile.partitioner.on_tile_top(cube.rank)
west_edge = cube.tile.partitioner.on_tile_left(cube.rank)
east_edge = cube.tile.partitioner.on_tile_right(cube.rank)
south_edge = comm.tile.partitioner.on_tile_bottom(comm.rank)
north_edge = comm.tile.partitioner.on_tile_top(comm.rank)
west_edge = comm.tile.partitioner.on_tile_left(comm.rank)
east_edge = comm.tile.partitioner.on_tile_right(comm.rank)
return cls(
domain=domain,
n_halo=sizer.n_halo,
Expand Down
Loading

0 comments on commit e0e7e90

Please sign in to comment.