Skip to content

Commit

Permalink
Use torch.all to check tensor equality (facebookincubator#1008)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#1008

When the tensor contains 0 or >=2 elements, this check will fail.

Reviewed By: muchulee8

Differential Revision: D57636627

fbshipit-source-id: ee668e63fc1944f89a6a85a9cb883fe1588ac4a4
  • Loading branch information
22quinn authored and facebook-github-bot committed May 22, 2024
1 parent d594370 commit 95697dc
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion fx2ait/fx2ait/fx2ait.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def get_attr(self, target, args, kwargs):
if ait_friendly_name in self._loaded_params:
existing_tensor = self._loaded_params[ait_friendly_name]
assert existing_tensor._attrs["dtype"] == ait_dtype
assert existing_tensor._attrs["data"].tensor == ait_val
assert torch.all(existing_tensor._attrs["data"].tensor == ait_val)
return existing_tensor

data = _TorchConstantTensorData(ait_val)
Expand Down

0 comments on commit 95697dc

Please sign in to comment.