Skip to content

Commit

Permalink
fix two sparsenn failures (pytorch#2613)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2613

there were two failures in the dashboard:

(1) `torch._custom_op.impl.get_ctx` appears to not exist (or get imported properly). From some grepping it looks like we pull from `torch.library.get_ctx` in most of the codebase

(2) an error about branching on truthy numpy arrays in one of the sparsenn tests: I just coerced the numpy array into a boolean.

Reviewed By: williamwen42

Differential Revision: D57050715

fbshipit-source-id: 46eb6f309ce1c3a033463bf06d5b139557bf5da9
  • Loading branch information
bdhirsh authored and facebook-github-bot committed May 21, 2024
1 parent 221f1c3 commit 66efb75
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def permute_2D_sparse_data_meta(
if permuted_lengths_sum is not None:
permuted_indices_size = permuted_lengths_sum
else:
ctx = torch._custom_op.impl.get_ctx()
ctx = torch.library.get_ctx()
permuted_indices_size = ctx.new_dynamic_size()
# pyre-fixme
permuted_indices = indices.new_empty(permuted_indices_size)
Expand All @@ -114,7 +114,7 @@ def permute_1D_sparse_data_meta(
if permuted_lengths_sum is not None:
permuted_indices_size = permuted_lengths_sum
else:
ctx = torch._custom_op.impl.get_ctx()
ctx = torch.library.get_ctx()
permuted_indices_size = ctx.new_dynamic_size()
# pyre-fixme
permuted_indices = indices.new_empty(permuted_indices_size)
Expand Down

0 comments on commit 66efb75

Please sign in to comment.