diff --git a/fx2ait/fx2ait/tools/common_fx2ait.py b/fx2ait/fx2ait/tools/common_fx2ait.py index 898a14234..e205e9097 100644 --- a/fx2ait/fx2ait/tools/common_fx2ait.py +++ b/fx2ait/fx2ait/tools/common_fx2ait.py @@ -216,8 +216,12 @@ def run_test( out = map_aggregate( out, lambda output: output.permute(*permute_outputs) ) + out = out.cpu() + if out.numel() != 0: + max_diff = torch.max(torch.abs(out - ref)) + logger.info(f"Max diff = {max_diff}") torch.testing.assert_close( - out.cpu(), + out, ref, rtol=rtol, atol=atol,