diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
index 501853b84..f913bc891 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -493,9 +493,12 @@ class GroupIndexSelectDim0GPUOp
 
     // 1) Add group_size Variable()'s for indices
     // c10::irange cannot be used in here as it
-    // triggers a build error of i being an unused variable
+    // triggers a build error of i being an unused variable.
+    // Add empty tensor with zero size here to make __torch_dispatch__ work for
+    // the backward op. Those empty tensors will be replaced with
+    // torch::autograd::Variable() outside of the op call.
     for (auto i = 0; i < group_size; i++) {
-      outputs.push_back(torch::autograd::Variable());
+      outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong)));
     }
 
     // Allocate Tensor for ptrs of grad output and input, and indices
@@ -615,6 +618,11 @@ class GroupIndexSelectDim0GPUOp
             "fbgemm::group_index_select_dim0_gpu_backward", "")
         .typed<decltype(group_index_select_dim0_gpu_backward)>();
     auto res = backward_op.call(grad_output_group, output_shape_group);
+    // 1) Add group_size Variable()'s for indices
+    // Replace all empty tensors with Variable(). This must be done after the
+    // op.call to make __torch_dispatch__ work for the backward op.
+    std::fill(
+        res.begin(), res.begin() + group_size, torch::autograd::Variable());
     // 3) Add 1 Variable() for group_size
     res.push_back({});
     return res;
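For context, a minimal standalone sketch of the placeholder pattern this diff applies. The indices inputs are non-differentiable, so autograd ultimately wants undefined Variables in their grad slots, but (per the diff's comments) undefined tensors do not survive the dispatcher path that __torch_dispatch__ hooks into; defined zero-size tensors stand in during the op call and are swapped back out afterwards. The function name backward_with_placeholders and its pass-through body are hypothetical stand-ins so the sketch compiles on its own; only the empty-tensor/std::fill pattern mirrors the diff.

// Sketch only, assuming libtorch; not fbgemm code.
#include <torch/torch.h>

#include <algorithm>
#include <vector>

std::vector<torch::Tensor> backward_with_placeholders(
    int64_t group_size,
    const std::vector<torch::Tensor>& grad_output_group) {
  std::vector<torch::Tensor> outputs;
  outputs.reserve(group_size + grad_output_group.size());

  // Defined zero-size stand-ins for the indices grad slots. Unlike a
  // default-constructed torch::autograd::Variable(), these have a real
  // TensorImpl, so the dispatcher (and __torch_dispatch__) can handle them.
  for (auto i = 0; i < group_size; i++) {
    outputs.push_back(at::empty({0}, at::TensorOptions().dtype(at::kLong)));
  }

  // The real code calls the registered backward op through the dispatcher
  // here; this sketch just forwards the incoming grads so it runs alone.
  for (const auto& g : grad_output_group) {
    outputs.push_back(g);
  }

  // After the op call has crossed the dispatcher, restore the undefined
  // Variables autograd expects for the non-differentiable indices inputs.
  std::fill(
      outputs.begin(),
      outputs.begin() + group_size,
      torch::autograd::Variable());
  return outputs;
}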