Merge pull request #67 from NOAA-GFDL/feature/translate_test_f32

Multi-modal metric for 32-bit float translate tests

FlorianDeconinck committed Sep 4, 2024
2 parents 82e5384 + 47396f8, commit 7f84b32

Showing 4 changed files with 340 additions and 220 deletions.
42 changes: 8 additions & 34 deletions ndsl/dsl/stencil.py
@@ -32,7 +32,7 @@
from ndsl.dsl.typing import Float, Index3D, cast_to_index3d
from ndsl.initialization.sizer import GridSizer, SubtileGridSizer
from ndsl.quantity import Quantity
-from ndsl.testing import comparison
+from ndsl.testing.comparison import LegacyMetric


try:
@@ -68,40 +68,14 @@ def report_difference(args, kwargs, args_copy, kwargs_copy, function_name, gt_id


def report_diff(arg: np.ndarray, numpy_arg: np.ndarray, label) -> str:
-    metric_err = comparison.compare_arr(arg, numpy_arg)
-    nans_match = np.logical_and(np.isnan(arg), np.isnan(numpy_arg))
-    n_points = np.product(arg.shape)
-    failures_14 = n_points - np.sum(
-        np.logical_or(
-            nans_match,
-            metric_err < 1e-14,
-        )
-    )
-    failures_10 = n_points - np.sum(
-        np.logical_or(
-            nans_match,
-            metric_err < 1e-10,
-        )
-    )
-    failures_8 = n_points - np.sum(
-        np.logical_or(
-            nans_match,
-            metric_err < 1e-8,
-        )
-    )
-    greatest_error = np.max(metric_err[~np.isnan(metric_err)])
-    if greatest_error == 0.0 and failures_14 == 0:
-        report = ""
-    else:
-        report = f"\n {label}: "
-        report += f"max_err={greatest_error}"
-        if failures_14 > 0:
-            report += f" 1e-14 failures: {failures_14}"
-        if failures_10 > 0:
-            report += f" 1e-10 failures: {failures_10}"
-        if failures_8 > 0:
-            report += f" 1e-8 failures: {failures_8}"
-    return report
+    metric = LegacyMetric(
+        reference_values=arg,
+        computed_values=numpy_arg,
+        eps=1e-13,
+        ignore_near_zero_errors=False,
+        near_zero=0,
+    )
+    return metric.__repr__()


@dataclasses.dataclass
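The refactor above replaces hand-rolled threshold counting with the shared metric objects. As used here and in the translate tests below, a metric is constructed from reference_values, computed_values, eps, ignore_near_zero_errors, and near_zero, exposes a boolean check, and prints its report via __repr__. A minimal sketch of that implied contract (the class name and the error formula are hypothetical; the real LegacyMetric and MultiModalFloatMetric live in ndsl.testing.comparison):

import numpy as np


class SketchMetric:
    """Hypothetical stand-in for LegacyMetric / MultiModalFloatMetric."""

    def __init__(
        self,
        reference_values: np.ndarray,
        computed_values: np.ndarray,
        eps: float,
        ignore_near_zero_errors: bool,
        near_zero: float,
    ):
        abs_err = np.abs(computed_values - reference_values)
        failures = abs_err >= eps
        if ignore_near_zero_errors:
            # Tolerate failures where the reference itself is near zero.
            failures &= np.abs(reference_values) >= near_zero
        self._failures = failures
        self.eps = eps
        # Callers read `check` after construction and fail with str(metric).
        self.check = not failures.any()

    def __repr__(self) -> str:
        return f"{int(self._failures.sum())} failing point(s) at eps={self.eps}"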
11 changes: 11 additions & 0 deletions ndsl/stencils/testing/conftest.py
@@ -82,6 +82,12 @@ def pytest_addoption(parser):
default="cubed-sphere",
help='Topology of the grid. "cubed-sphere" means a 6-faced grid, "doubly-periodic" means a 1 tile grid. Default to "cubed-sphere".',
)
parser.addoption(
"--multimodal_metric",
action="store_true",
default=False,
help="Use the multi-modal float metric. Default to False.",
)


def pytest_configure(config):
@@ -389,6 +395,11 @@ def failure_stride(pytestconfig):
    return int(pytestconfig.getoption("failure_stride"))


+@pytest.fixture()
+def multimodal_metric(pytestconfig):
+    return bool(pytestconfig.getoption("multimodal_metric"))


@pytest.fixture()
def grid(pytestconfig):
return pytestconfig.getoption("grid")
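With the option and fixture wired up, the multi-modal metric is opt-in per run; the legacy metric remains the default. An illustrative invocation (the test path is an assumption and depends on your checkout and data setup):

pytest ndsl/stencils/testing --multimodal_metric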
202 changes: 70 additions & 132 deletions ndsl/stencils/testing/test_translate.py
@@ -15,7 +15,7 @@
from ndsl.quantity import Quantity
from ndsl.restart._legacy_restart import RESTART_PROPERTIES
from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict
-from ndsl.testing.comparison import compare_scalar, success, success_array
+from ndsl.testing.comparison import LegacyMetric, MultiModalFloatMetric
from ndsl.testing.perturbation import perturb


@@ -32,92 +32,6 @@ def platform():
return "docker" if in_docker else "metal"


-def sample_wherefail(
-    computed_data,
-    ref_data,
-    eps,
-    print_failures,
-    failure_stride,
-    test_name,
-    ignore_near_zero_errors,
-    near_zero,
-    xy_indices=False,
-):
-    found_indices = np.where(
-        np.logical_not(
-            success_array(
-                computed_data, ref_data, eps, ignore_near_zero_errors, near_zero
-            )
-        )
-    )
-    computed_failures = computed_data[found_indices]
-    reference_failures = ref_data[found_indices]
-
-    # List all errors
-    return_strings = []
-    bad_indices_count = len(found_indices[0])
-    # Determine worst result
-    worst_metric_err = 0.0
-    for b in range(bad_indices_count):
-        full_index = [f[b] for f in found_indices]
-        metric_err = compare_scalar(computed_failures[b], reference_failures[b])
-        abs_err = abs(computed_failures[b] - reference_failures[b])
-        if print_failures and b % failure_stride == 0:
-            return_strings.append(
-                f"index: {full_index}, computed {computed_failures[b]}, "
-                f"reference {reference_failures[b]}, "
-                f"absolute diff {abs_err:.3e}, "
-                f"metric diff: {metric_err:.3e}"
-            )
-        if np.isnan(metric_err) or (metric_err > worst_metric_err):
-            worst_metric_err = metric_err
-            worst_full_idx = full_index
-            worst_abs_err = abs_err
-            computed_worst = computed_failures[b]
-            reference_worst = reference_failures[b]
-    # Summary and worst result
-    fullcount = len(ref_data.flatten())
-    return_strings.append(
-        f"Failed count: {bad_indices_count}/{fullcount} "
-        f"({round(100.0 * (bad_indices_count / fullcount), 2)}%),\n"
-        f"Worst failed index {worst_full_idx}\n"
-        f"\tcomputed:{computed_worst}\n"
-        f"\treference: {reference_worst}\n"
-        f"\tabsolute diff: {worst_abs_err:.3e}\n"
-        f"\tmetric diff: {worst_metric_err:.3e}\n"
-    )
-
-    if xy_indices:
-        if len(computed_data.shape) == 3:
-            axis = 2
-            any = np.any
-        elif len(computed_data.shape) == 4:
-            axis = (2, 3)
-            any = np.any
-        else:
-            axis = None
-
-            def any(array, axis):
-                return array
-
-        found_xy_indices = np.where(
-            any(
-                np.logical_not(
-                    success_array(
-                        computed_data, ref_data, eps, ignore_near_zero_errors, near_zero
-                    )
-                ),
-                axis=axis,
-            )
-        )
-
-        return_strings.append(
-            "failed horizontal indices:" + str(list(zip(*found_xy_indices)))
-        )
-
-    return "\n".join(return_strings)

def process_override(threshold_overrides, testobj, test_name, backend):
    override = threshold_overrides.get(test_name, None)
    if override is not None:
@@ -229,6 +143,7 @@ def test_sequential_savepoint(
    subtests,
    caplog,
    threshold_overrides,
+    multimodal_metric,
    xy_indices=True,
):
    if case.testobj is None:
@@ -257,7 +172,12 @@
        case.testobj.serialnames(case.testobj.in_vars["data_vars"])
        + case.testobj.in_vars["parameters"]
    )
-    input_data = {name: input_data[name] for name in input_names}
+    try:
+        input_data = {name: input_data[name] for name in input_names}
+    except KeyError as e:
+        raise KeyError(
+            f"Variable {e} was described in the translate test but cannot be found in the NetCDF"
+        )
    original_input_data = copy.deepcopy(input_data)
    # run python version of functionality
    output = case.testobj.compute(input_data)
@@ -273,23 +193,24 @@
        with subtests.test(varname=varname):
            failing_names.append(varname)
            output_data = gt_utils.asarray(output[varname])
-            assert success(
-                output_data,
-                ref_data,
-                case.testobj.max_error,
-                ignore_near_zero,
-                case.testobj.near_zero,
-            ), sample_wherefail(
-                output_data,
-                ref_data,
-                case.testobj.max_error,
-                print_failures,
-                failure_stride,
-                case.savepoint_name,
-                ignore_near_zero_errors=ignore_near_zero,
-                near_zero=case.testobj.near_zero,
-                xy_indices=xy_indices,
-            )
+            if multimodal_metric:
+                metric = MultiModalFloatMetric(
+                    reference_values=ref_data,
+                    computed_values=output_data,
+                    eps=case.testobj.max_error,
+                    ignore_near_zero_errors=ignore_near_zero,
+                    near_zero=case.testobj.near_zero,
+                )
+            else:
+                metric = LegacyMetric(
+                    reference_values=ref_data,
+                    computed_values=output_data,
+                    eps=case.testobj.max_error,
+                    ignore_near_zero_errors=ignore_near_zero,
+                    near_zero=case.testobj.near_zero,
+                )
+            if not metric.check:
+                pytest.fail(str(metric), pytrace=False)
            passing_names.append(failing_names.pop())
        ref_data_out[varname] = [ref_data]
    if len(failing_names) > 0:
@@ -307,8 +228,12 @@
            failing_names,
            out_filename,
        )
-    assert failing_names == [], f"only the following variables passed: {passing_names}"
-    assert len(passing_names) > 0, "No tests passed"
+    if failing_names != []:
+        pytest.fail(
+            f"Only the following variables passed: {passing_names}", pytrace=False
+        )
+    if len(passing_names) == 0:
+        pytest.fail("No tests passed")


def state_from_savepoint(serializer, savepoint, name_to_std_name):
@@ -354,6 +279,7 @@ def test_parallel_savepoint(
    caplog,
    threshold_overrides,
    grid,
+    multimodal_metric,
    xy_indices=True,
):
    if MPI.COMM_WORLD.Get_size() % 6 != 0:
@@ -410,23 +336,24 @@
        with subtests.test(varname=varname):
            failing_names.append(varname)
            output_data = gt_utils.asarray(output[varname])
-            assert success(
-                output_data,
-                ref_data[varname][0],
-                case.testobj.max_error,
-                ignore_near_zero,
-                case.testobj.near_zero,
-            ), sample_wherefail(
-                output_data,
-                ref_data[varname][0],
-                case.testobj.max_error,
-                print_failures,
-                failure_stride,
-                case.savepoint_name,
-                ignore_near_zero,
-                case.testobj.near_zero,
-                xy_indices,
-            )
+            if multimodal_metric:
+                metric = MultiModalFloatMetric(
+                    reference_values=ref_data[varname][0],
+                    computed_values=output_data,
+                    eps=case.testobj.max_error,
+                    ignore_near_zero_errors=ignore_near_zero,
+                    near_zero=case.testobj.near_zero,
+                )
+            else:
+                metric = LegacyMetric(
+                    reference_values=ref_data[varname][0],
+                    computed_values=output_data,
+                    eps=case.testobj.max_error,
+                    ignore_near_zero_errors=ignore_near_zero,
+                    near_zero=case.testobj.near_zero,
+                )
+            if not metric.check:
+                pytest.fail(str(metric), pytrace=False)
            passing_names.append(failing_names.pop())
    if len(failing_names) > 0:
        os.makedirs(OUTDIR, exist_ok=True)
@@ -447,8 +374,12 @@
            )
        except Exception as error:
            print(f"TestParallel SaveNetCDF Error: {error}")
-    assert failing_names == [], f"only the following variables passed: {passing_names}"
-    assert len(passing_names) > 0, "No tests passed"
+    if failing_names != []:
+        pytest.fail(
+            f"Only the following variables passed: {passing_names}", pytrace=False
+        )
+    if len(passing_names) == 0:
+        pytest.fail("No tests passed")


def save_netcdf(
@@ -464,6 +395,7 @@

    data_vars = {}
    for i, varname in enumerate(failing_names):
+        # Read in dimensions and attributes
        if hasattr(testobj, "outputs"):
            dims = [dim_name + f"_{i}" for dim_name in testobj.outputs[varname]["dims"]]
            attrs = {"units": testobj.outputs[varname]["units"]}
        else:
            dims = [
                f"dim_{varname}_{j}" for j in range(len(ref_data[varname][0].shape))
            ]
            attrs = {"units": "unknown"}
+
+        # Try to save inputs
        try:
-            data_vars[f"{varname}_in"] = xr.DataArray(
+            data_vars[f"{varname}_input"] = xr.DataArray(
                np.stack([in_data[varname] for in_data in inputs_list]),
                dims=("rank",) + tuple([f"{d}_in" for d in dims]),
                attrs=attrs,
            )
        except KeyError as error:
            print(f"No input data found for {error}")
-        data_vars[f"{varname}_ref"] = xr.DataArray(
+
+        # Reference, computed and error computation
+        data_vars[f"{varname}_reference"] = xr.DataArray(
            np.stack(ref_data[varname]),
            dims=("rank",) + tuple([f"{d}_out" for d in dims]),
            attrs=attrs,
        )
-        data_vars[f"{varname}_out"] = xr.DataArray(
+        data_vars[f"{varname}_computed"] = xr.DataArray(
            np.stack([output[varname] for output in output_list]),
            dims=("rank",) + tuple([f"{d}_out" for d in dims]),
            attrs=attrs,
        )
-        data_vars[f"{varname}_error"] = (
-            data_vars[f"{varname}_ref"] - data_vars[f"{varname}_out"]
-        )
-        data_vars[f"{varname}_error"].attrs = attrs
+        absolute_errors = (
+            data_vars[f"{varname}_reference"] - data_vars[f"{varname}_computed"]
+        )
+        data_vars[f"{varname}_absolute_error"] = absolute_errors
+        data_vars[f"{varname}_absolute_error"].attrs = attrs

    print(f"File saved to {out_filename}")
    xr.Dataset(data_vars=data_vars).to_netcdf(out_filename)
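With the renamed output variables (*_input, *_reference, *_computed, *_absolute_error), a saved failure artifact can be inspected directly with xarray. A sketch, where the file name and variable name are hypothetical examples:

import xarray as xr

ds = xr.open_dataset("translate-MySavepoint.nc")  # file written by save_netcdf
worst = abs(ds["qv_absolute_error"]).max()  # reference minus computed, across ranks
print(float(worst))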