Skip to content

Commit

Permalink
Fix non-deterministic temporaries by using zeros everywhere instead…
Browse files Browse the repository at this point in the history
… of `empty`
  • Loading branch information
FlorianDeconinck committed Aug 23, 2023
1 parent 2f9bbe9 commit 8f6ba7c
Show file tree
Hide file tree
Showing 12 changed files with 22 additions and 36 deletions.
4 changes: 1 addition & 3 deletions driver/pace/driver/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def get_grid(
quantity_factory: QuantityFactory,
communicator: CubedSphereCommunicator,
) -> Tuple[DampingCoefficients, DriverGridData, GridData]:

metric_terms = MetricTerms(
quantity_factory=quantity_factory,
communicator=communicator,
Expand Down Expand Up @@ -184,8 +183,7 @@ def get_grid(
quantity_factory: QuantityFactory,
communicator: CubedSphereCommunicator,
) -> Tuple[DampingCoefficients, DriverGridData, GridData]:

backend = quantity_factory.empty(
backend = quantity_factory.zeros(
dims=[pace.util.X_DIM, pace.util.Y_DIM], units="unknown"
).gt4py_backend

Expand Down
5 changes: 1 addition & 4 deletions driver/pace/driver/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def get_driver_state(
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
) -> DriverState:

dycore_state = tc_init.init_tc_state(
grid_data=grid_data,
quantity_factory=quantity_factory,
Expand Down Expand Up @@ -323,7 +322,7 @@ def get_driver_state(
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
) -> DriverState:
backend = quantity_factory.empty(
backend = quantity_factory.zeros(
dims=[pace.util.X_DIM, pace.util.Y_DIM], units="unknown"
).gt4py_backend

Expand All @@ -348,7 +347,6 @@ def _initialize_dycore_state(
communicator: pace.util.CubedSphereCommunicator,
backend: str,
) -> fv3core.DycoreState:

grid = self._get_serialized_grid(communicator=communicator, backend=backend)

ser = self._serializer(communicator)
Expand Down Expand Up @@ -401,7 +399,6 @@ def get_driver_state(
driver_grid_data: pace.util.grid.DriverGridData,
grid_data: pace.util.grid.GridData,
) -> DriverState:

return DriverState(
dycore_state=self.dycore_state,
physics_state=self.physics_state,
Expand Down
4 changes: 2 additions & 2 deletions fv3core/pace/fv3core/testing/translate_dyncore.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def compute_parallel(self, inputs, communicator):
grid_data.ptop = inputs["ptop"]
self._base.make_storage_data_input_vars(inputs)
state = DycoreState.init_zeros(quantity_factory=self.grid.quantity_factory)
wsd: pace.util.Quantity = self.grid.quantity_factory.empty(
wsd: pace.util.Quantity = self.grid.quantity_factory.zeros(
dims=[pace.util.X_DIM, pace.util.Y_DIM],
units="unknown",
)
Expand All @@ -152,7 +152,7 @@ def compute_parallel(self, inputs, communicator):
state[name].data[selection] = value
else:
setattr(state, name, value)
phis: pace.util.Quantity = self.grid.quantity_factory.empty(
phis: pace.util.Quantity = self.grid.quantity_factory.zeros(
dims=[pace.util.X_DIM, pace.util.Y_DIM],
units="m",
)
Expand Down
2 changes: 1 addition & 1 deletion fv3core/tests/savepoint/translate/translate_remapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def compute_from_storage(self, inputs):
inputs["wsd"] = wsd_2d
inputs["q_cld"] = inputs["tracers"]["qcld"]
inputs["last_step"] = bool(inputs["last_step"])
pfull = self.grid.quantity_factory.empty([Z_DIM], units="Pa")
pfull = self.grid.quantity_factory.zeros([Z_DIM], units="Pa")
pfull.data[:] = pfull.np.asarray(inputs.pop("pfull"))
l_to_e_obj = LagrangianToEulerian(
self.stencil_factory,
Expand Down
3 changes: 1 addition & 2 deletions stencils/pace/stencils/testing/parallel_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@


class ParallelTranslate:

max_error = TranslateFortranData2Py.max_error
near_zero = TranslateFortranData2Py.near_zero
compute_grid_option = False
Expand Down Expand Up @@ -192,7 +191,7 @@ def state_from_inputs(self, inputs: dict, grid=None) -> dict:
for name, properties in self.inputs.items():
standard_name = properties.get("name", name)
if len(properties["dims"]) > 0:
state[standard_name] = grid.quantity_factory.empty(
state[standard_name] = grid.quantity_factory.zeros(
properties["dims"], properties["units"], dtype=inputs[name].dtype
)
input_slice = _serialize_slice(
Expand Down
7 changes: 3 additions & 4 deletions stencils/pace/stencils/testing/temporaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@ def _assert_same_temporaries(dict1: dict, dict2: dict) -> List[str]:
attr2 = dict2[attr]
if isinstance(attr1, np.ndarray):
try:
np.testing.assert_almost_equal(
attr1, attr2, err_msg=f"{attr} not equal"
)
except AssertionError:
assert np.allclose(attr1, attr2, equal_nan=True)
except AssertionError as e:
print(e)
differences.append(attr)
else:
sub_differences = _assert_same_temporaries(attr1, attr2)
Expand Down
10 changes: 5 additions & 5 deletions util/pace/util/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def _get_gather_recv_quantity(
) -> Quantity:
"""Initialize a Quantity for use when receiving global data during gather"""
recv_quantity = Quantity(
send_metadata.np.empty(global_extent, dtype=send_metadata.dtype),
send_metadata.np.zeros(global_extent, dtype=send_metadata.dtype),
dims=send_metadata.dims,
units=send_metadata.units,
origin=tuple([0 for dim in send_metadata.dims]),
Expand All @@ -182,7 +182,7 @@ def _get_scatter_recv_quantity(
) -> Quantity:
"""Initialize a Quantity for use when receiving subtile data during scatter"""
recv_quantity = Quantity(
send_metadata.np.empty(shape, dtype=send_metadata.dtype),
send_metadata.np.zeros(shape, dtype=send_metadata.dtype),
dims=send_metadata.dims,
units=send_metadata.units,
gt4py_backend=send_metadata.gt4py_backend,
Expand All @@ -206,7 +206,7 @@ def gather(
result: Optional[Quantity]
if self.rank == constants.ROOT_RANK:
with array_buffer(
send_quantity.np.empty,
send_quantity.np.zeros,
(self.partitioner.total_ranks,) + tuple(send_quantity.extent),
dtype=send_quantity.data.dtype,
) as recvbuf:
Expand Down Expand Up @@ -745,7 +745,7 @@ def _get_gather_recv_quantity(
# needs to change the quantity dimensions since we add a "tile" dimension,
# unlike for tile scatter/gather which retains the same dimensions
recv_quantity = Quantity(
metadata.np.empty(global_extent, dtype=metadata.dtype),
metadata.np.zeros(global_extent, dtype=metadata.dtype),
dims=(constants.TILE_DIM,) + metadata.dims,
units=metadata.units,
origin=(0,) + tuple([0 for dim in metadata.dims]),
Expand All @@ -767,7 +767,7 @@ def _get_scatter_recv_quantity(
# needs to change the quantity dimensions since we remove a "tile" dimension,
# unlike for tile scatter/gather which retains the same dimensions
recv_quantity = Quantity(
metadata.np.empty(shape, dtype=metadata.dtype),
metadata.np.zeros(shape, dtype=metadata.dtype),
dims=metadata.dims[1:],
units=metadata.units,
gt4py_backend=metadata.gt4py_backend,
Expand Down
5 changes: 1 addition & 4 deletions util/pace/util/grid/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def quantity_cast_to_model_float(
quantity_factory: util.QuantityFactory, qty_64: util.Quantity
) -> util.Quantity:
"""Copy & cast from 64-bit float to model precision if need be"""
qty = quantity_factory.empty(qty_64.dims, qty_64.units, dtype=Float)
qty = quantity_factory.zeros(qty_64.dims, qty_64.units, dtype=Float)
qty.data[:] = qty_64.data[:]
return qty

Expand Down Expand Up @@ -1530,7 +1530,6 @@ def rdyc(self) -> util.Quantity:
)

def _init_dgrid(self):

grid_mirror_ew = self.quantity_factory.zeros(
self._grid_dims,
"radians",
Expand Down Expand Up @@ -1751,7 +1750,6 @@ def _compute_dxdy(self):
return dx, dy

def _compute_dxdy_agrid(self):

dx_agrid_64 = self.quantity_factory.zeros(
[util.X_DIM, util.Y_DIM],
"m",
Expand Down Expand Up @@ -2149,7 +2147,6 @@ def _calculate_more_trig_terms(self, cos_sg, sin_sg):
)

def _init_cell_trigonometry(self):

cosa_u_64 = self.quantity_factory.zeros(
[util.X_INTERFACE_DIM, util.Y_DIM],
"",
Expand Down
6 changes: 3 additions & 3 deletions util/pace/util/grid/gnomonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,9 @@ def _mirror_latlon(lon1, lat1, lon2, lat2, lon0, lat0, np):
pdot = p0[0] * nb[0] + p0[1] * nb[1] + p0[2] * nb[2]
pp = p0 - np.multiply(2.0, pdot) * nb

lon3 = np.empty((1, 1))
lat3 = np.empty((1, 1))
pp3 = np.empty((3, 1, 1))
lon3 = np.zeros((1, 1))
lat3 = np.zeros((1, 1))
pp3 = np.zeros((3, 1, 1))
pp3[:, 0, 0] = pp
_cart_to_latlon(1, pp3, lon3, lat3, np)

Expand Down
6 changes: 2 additions & 4 deletions util/pace/util/grid/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def from_restart(
but no fv_core.res.nc in restart data file."""
)

ak = quantity_factory.empty([Z_INTERFACE_DIM], units="Pa")
bk = quantity_factory.empty([Z_INTERFACE_DIM], units="")
ak = quantity_factory.zeros([Z_INTERFACE_DIM], units="Pa")
bk = quantity_factory.zeros([Z_INTERFACE_DIM], units="")
with fs.open(ak_bk_data_file, "rb") as f:
ds = xr.open_dataset(f).isel(Time=0).drop_vars("Time")
ak.view[:] = ds["ak"].values
Expand Down Expand Up @@ -322,7 +322,6 @@ def __init__(

@classmethod
def new_from_metric_terms(cls, metric_terms: MetricTerms):

horizontal_data = HorizontalGridData.new_from_metric_terms(metric_terms)
vertical_data = VerticalGridData.new_from_metric_terms(metric_terms)
contravariant_data = ContravariantGridData.new_from_metric_terms(metric_terms)
Expand Down Expand Up @@ -701,7 +700,6 @@ def new_from_grid_variables(
es1: pace.util.Quantity,
ew2: pace.util.Quantity,
) -> "DriverGridData":

try:
vlon1, vlon2, vlon3 = split_quantity_along_last_dim(vlon)
vlat1, vlat2, vlat3 = split_quantity_along_last_dim(vlat)
Expand Down
4 changes: 1 addition & 3 deletions util/pace/util/halo_data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _build_flatten_indices(
"""

# Have to go down to numpy to leverage indices calculation
arr_indices = np.empty(shape, dtype=np.int32, order="C")[slices]
arr_indices = np.zeros(shape, dtype=np.int32, order="C")[slices]

# Get offset from first index
offset_dims = []
Expand Down Expand Up @@ -875,7 +875,6 @@ def _opt_unpack_scalar(self, quantities: List[Quantity]):

# Use private stream
with self._get_stream(cu_kernel_args.stream):

# Launch kernel
blocks = 128
grid_x = (info_x._unpack_buffer_size // blocks) + 1
Expand Down Expand Up @@ -942,7 +941,6 @@ def _opt_unpack_vector(

# Use private stream
with self._get_stream(cu_kernel_args.stream):

# Buffer sizes
edge_size = info_x._unpack_buffer_size + info_y._unpack_buffer_size

Expand Down
2 changes: 1 addition & 1 deletion util/pace/util/initialization/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def from_array(
That numpy array must correspond to the correct shape and extent
for the given dims.
"""
base = self.empty(
base = self.zeros(
dims=dims,
units=units,
dtype=data.dtype,
Expand Down

0 comments on commit 8f6ba7c

Please sign in to comment.