diff --git a/fx2ait/fx2ait/fx2ait.py b/fx2ait/fx2ait/fx2ait.py index 94d25c531..76d06de00 100644 --- a/fx2ait/fx2ait/fx2ait.py +++ b/fx2ait/fx2ait/fx2ait.py @@ -354,7 +354,7 @@ def get_attr(self, target, args, kwargs): if ait_friendly_name in self._loaded_params: existing_tensor = self._loaded_params[ait_friendly_name] assert existing_tensor._attrs["dtype"] == ait_dtype - assert existing_tensor._attrs["data"].tensor == ait_val + assert torch.all(existing_tensor._attrs["data"].tensor == ait_val) return existing_tensor data = _TorchConstantTensorData(ait_val)