Skip to content

Commit

Permalink
Fix pyre error in _generate_baseline_single_dict_feature (#1441)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1441

Fix the pyre error in `_generate_baseline_single_dict_feature`

The issue was that the device parameter didn't have a type annotation. It appears that the correct type is torch.device.

Also changed the call to `.to()` by adding the `device=` named parameter. If a named parameter is not used in this case, torch may assume this is a dtype (see: https://pytorch.org/docs/stable/generated/torch.Tensor.to.html). Adding the named parameter also clarifies the intent.

Reviewed By: jjuncho

Differential Revision: D65914877

fbshipit-source-id: 5521105bac0e06c4205a4b7e6d2b75847b392416
  • Loading branch information
jsawruk authored and facebook-github-bot committed Nov 14, 2024
1 parent 11fea3d commit 49b1ed4
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions captum/attr/_models/pytext.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,7 @@ def _generate_baseline_single_word_feature(self, device) -> torch.Tensor:

def _generate_baseline_single_dict_feature(
self,
# pyre-fixme[2]: Parameter `device` has no type specified.
device,
device: torch.device,
) -> Tuple[torch.Tensor, ...]:
r"""Generate dict features based on Assistant's case study by using
sia_transformer:
Expand Down Expand Up @@ -197,7 +196,7 @@ def _generate_baseline_single_dict_feature(
]
)
.unsqueeze(0)
.to(device)
.to(device=device)
)
gazetteer_feat_weights = (
torch.tensor(gazetteer_feat_weights).unsqueeze(0).to(device)
Expand Down

0 comments on commit 49b1ed4

Please sign in to comment.