From 95697dcf50e1025213bb4053fce22e8c472ef71b Mon Sep 17 00:00:00 2001 From: Quinn Zhu Date: Tue, 21 May 2024 21:13:24 -0700 Subject: [PATCH] Use torch.all to check tensor equality (#1008) Summary: Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/1008 When the tensor contains 0 or >=2 elements, this check will fail. Reviewed By: muchulee8 Differential Revision: D57636627 fbshipit-source-id: ee668e63fc1944f89a6a85a9cb883fe1588ac4a4 --- fx2ait/fx2ait/fx2ait.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fx2ait/fx2ait/fx2ait.py b/fx2ait/fx2ait/fx2ait.py index 94d25c531..76d06de00 100644 --- a/fx2ait/fx2ait/fx2ait.py +++ b/fx2ait/fx2ait/fx2ait.py @@ -354,7 +354,7 @@ def get_attr(self, target, args, kwargs): if ait_friendly_name in self._loaded_params: existing_tensor = self._loaded_params[ait_friendly_name] assert existing_tensor._attrs["dtype"] == ait_dtype - assert existing_tensor._attrs["data"].tensor == ait_val + assert torch.all(existing_tensor._attrs["data"].tensor == ait_val) return existing_tensor data = _TorchConstantTensorData(ait_val)