diff --git a/ndsl/testing/comparison.py b/ndsl/testing/comparison.py
index 9812e06..d3ac0d6 100644
--- a/ndsl/testing/comparison.py
+++ b/ndsl/testing/comparison.py
@@ -39,12 +39,12 @@ def __init__(
     ):
         super().__init__(reference_values, computed_values)
         self.eps = eps
+        self._calculated_metric = np.empty_like(self.references)
         self.success = self._compute_errors(
             ignore_near_zero_errors,
             near_zero,
         )
         self.check = np.all(self.success)
-        self._calculated_metric = np.empty_like(self.references)
 
     def _compute_errors(
         self,
@@ -52,9 +52,10 @@ def _compute_errors(
         near_zero,
     ) -> npt.NDArray[np.bool_]:
         if self.references.dtype in (np.float64, np.int64, np.float32, np.int32):
-            denom = np.abs(self.references) + np.abs(self.computed)
+            denom = self.references
+            denom[self.references == 0] = self.computed[self.references == 0]
             self._calculated_metric = np.asarray(
-                2.0 * np.abs(self.computed - self.references) / denom
+                np.abs((self.computed - self.references) / denom)
             )
             self._calculated_metric[denom == 0] = 0.0
         elif self.references.dtype in (np.bool_, bool):
@@ -123,7 +124,7 @@ def __repr__(self) -> str:
                     f"{reference_failures[b]} {abs_errs[-1]:.3e} {metric_err:.3e}"
                 )
-                if np.isnan(metric_err) or (metric_err > worst_metric_err):
+                if np.isnan(metric_err) or (abs(metric_err) > abs(worst_metric_err)):
                     worst_metric_err = metric_err
                     worst_full_idx = full_index
                     worst_abs_err = abs_errs[-1]
@@ -249,7 +250,7 @@ def __repr__(self) -> str:
             f"All failures ({bad_indices_count}/{full_count}) ({failures_pct}%),\n",
             f"Index Computed Reference "
             f"Absolute E(<{self.absolute_eps:.2e}) "
-            f"Relative E(<{self.relative_fraction*100:.2e}%) "
+            f"Relative E(<{self.relative_fraction * 100:.2e}%) "
             f"ULP E(<{self.ulp_threshold})",
         ]
         # Summary and worst result
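
For reference, below is a minimal standalone sketch of the metric as this change defines it: the elementwise relative error |computed - reference| / denom, where the denominator is the reference value except where the reference is zero (the computed value is substituted there), and the error is defined as 0 where the denominator is also zero. The function name relative_error and the defensive .copy() are illustrative additions, not part of ndsl; the patched method operates on self.references and self.computed directly.

import numpy as np
import numpy.typing as npt


def relative_error(
    computed: npt.NDArray[np.float64],
    references: npt.NDArray[np.float64],
) -> npt.NDArray[np.float64]:
    """Elementwise |computed - reference| / denom, mirroring the updated metric."""
    # Copy so the substitution below does not mutate the caller's array.
    denom = references.copy()
    zero_ref = references == 0
    # Where the reference is zero, fall back to the computed value as denominator.
    denom[zero_ref] = computed[zero_ref]
    with np.errstate(divide="ignore", invalid="ignore"):
        metric = np.abs((computed - references) / denom)
    # Both values zero means exact agreement: report zero error.
    metric[denom == 0] = 0.0
    return metric


ref = np.array([1.0, 0.0, 0.0, 4.0])
com = np.array([1.1, 2.0, 0.0, 4.0])
print(relative_error(com, ref))  # approximately [0.1, 1.0, 0.0, 0.0]

Compared to the previous symmetric formulation, 2|computed - reference| / (|reference| + |computed|), this measures the error as a fraction of the reference value alone, so it reads directly against the relative threshold printed in __repr__.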