From 9f97617a5ca2751a65368bd1d0b0a7149756f497 Mon Sep 17 00:00:00 2001 From: Henry Hu Date: Mon, 10 Jul 2023 18:57:41 -0700 Subject: [PATCH] Log max diff in AITTestCase (#817) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/817 Reviewed By: colinchan15 Differential Revision: D47283465 fbshipit-source-id: 28a945d8beb99e682f478567e3c944b4958f1187 --- fx2ait/fx2ait/tools/common_fx2ait.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fx2ait/fx2ait/tools/common_fx2ait.py b/fx2ait/fx2ait/tools/common_fx2ait.py index 67d0a9f3c..3d6fc3895 100644 --- a/fx2ait/fx2ait/tools/common_fx2ait.py +++ b/fx2ait/fx2ait/tools/common_fx2ait.py @@ -211,8 +211,12 @@ def run_test( out = map_aggregate( out, lambda output: output.permute(*permute_outputs) ) + out = out.cpu() + if out.numel() != 0: + max_diff = torch.max(torch.abs(out - ref)).item() + logger.info(f"Max diff = {max_diff}") torch.testing.assert_close( - out.cpu(), + out, ref, rtol=rtol, atol=atol,