From 66efb7525784f4e6d3f7f74fd74480fa4c36143e Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Tue, 21 May 2024 09:25:52 -0700 Subject: [PATCH] fix two sparsenn failures (#2613) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2613 there were two failures in the dashboard: (1) `torch._custom_op.impl.get_ctx` appears to not exist (or get imported properly). From some grepping it looks like we pull from `torch.library.get_ctx` in most of the codebase (2) an error about branching on truthy numpy arrays in one of the sparsenn tests: I just coerced the numpy array into a boolean. Reviewed By: williamwen42 Differential Revision: D57050715 fbshipit-source-id: 46eb6f309ce1c3a033463bf06d5b139557bf5da9 --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 0ae295a6c..c06e5aa51 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -88,7 +88,7 @@ def permute_2D_sparse_data_meta( if permuted_lengths_sum is not None: permuted_indices_size = permuted_lengths_sum else: - ctx = torch._custom_op.impl.get_ctx() + ctx = torch.library.get_ctx() permuted_indices_size = ctx.new_dynamic_size() # pyre-fixme permuted_indices = indices.new_empty(permuted_indices_size) @@ -114,7 +114,7 @@ def permute_1D_sparse_data_meta( if permuted_lengths_sum is not None: permuted_indices_size = permuted_lengths_sum else: - ctx = torch._custom_op.impl.get_ctx() + ctx = torch.library.get_ctx() permuted_indices_size = ctx.new_dynamic_size() # pyre-fixme permuted_indices = indices.new_empty(permuted_indices_size)