diff --git a/fx2ait/fx2ait/tools/common_fx2ait.py b/fx2ait/fx2ait/tools/common_fx2ait.py index 67d0a9f3c..3d6fc3895 100644 --- a/fx2ait/fx2ait/tools/common_fx2ait.py +++ b/fx2ait/fx2ait/tools/common_fx2ait.py @@ -211,8 +211,12 @@ def run_test( out = map_aggregate( out, lambda output: output.permute(*permute_outputs) ) + out = out.cpu() + if out.numel() != 0: + max_diff = torch.max(torch.abs(out - ref)).item() + logger.info(f"Max diff = {max_diff}") torch.testing.assert_close( - out.cpu(), + out, ref, rtol=rtol, atol=atol,