Fix non-contiguous tensor problem in jagged_index_select (pytorch#2060)
Summary:
Pull Request resolved: pytorch#2060

Before this diff, the `jagged_index_select` kernels took raw pointers as
arguments, which requires the input tensors to be contiguous.  However,
the `jagged_index_select` operator did not ensure that the tensors were
contiguous before extracting the raw pointers and passing them to the
kernels, causing a correctness issue.  This diff replaces the raw
pointer arguments with PyTorch's `PackedTensorAccessor`, which handles
non-contiguous tensor accesses automatically.  For the tensors whose raw
pointers are still used, the operator now ensures that they are
contiguous before use.
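
As an illustration, below is a minimal, self-contained sketch of the accessor
pattern this diff adopts -- not the actual FBGEMM kernels; `copy_rows_kernel`,
`copy_rows`, and the float-only dispatch are hypothetical simplifications.
The kernel receives `PackedTensorAccessor` arguments, which carry sizes and
strides, so indexing stays correct even when a tensor is non-contiguous.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

template <typename scalar_t>
__global__ void copy_rows_kernel(
    at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> output,
    const at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> input) {
  const auto num_cols = input.size(1);
  for (int64_t col = threadIdx.x; col < num_cols; col += blockDim.x) {
    // Accessor indexing applies the stored strides; no manual pointer math.
    output[blockIdx.x][col] = input[blockIdx.x][col];
  }
}

void copy_rows(const at::Tensor& input, at::Tensor& output) {
  // The real operators dispatch over more dtypes; float-only keeps the sketch short.
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "copy_rows", [&] {
    copy_rows_kernel<scalar_t>
        <<<dim3(input.size(0)),
           dim3(128),
           0,
           at::cuda::getCurrentCUDAStream()>>>(
            output.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>(),
            input.packed_accessor64<scalar_t, 2, at::RestrictPtrTraits>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

With raw `data_ptr<scalar_t>()` arguments, the same kernel would compute row
offsets by hand assuming a contiguous row-major layout and silently read the
wrong elements for strided inputs, which is the bug this diff fixes.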

Reviewed By: choudharydhruv

Differential Revision: D49937274

fbshipit-source-id: dcdc751191ae17e3697b99d30145c67ab470a218
sryap authored and facebook-github-bot committed Oct 5, 2023
1 parent 2b04682 commit 8f7d8c7
Showing 3 changed files with 74 additions and 46 deletions.
55 changes: 32 additions & 23 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_index_add_2d_forward.cu
@@ -14,23 +14,24 @@ namespace fbgemm_gpu {

template <typename index_t, typename offset_t, typename scalar_t>
__global__ __launch_bounds__(kMaxThreads) void jagged_index_add_2d_kernel(
scalar_t* output,
const scalar_t* values,
const offset_t* input_offsets,
const index_t* indices,
const offset_t* output_offsets,
const int64_t num_input_rows,
const int64_t num_dense_input_rows,
const int64_t num_cols) {
at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> output,
const at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> values,
const at::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits>
input_offsets,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
const at::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits>
output_offsets,
const int64_t num_dense_input_rows) {
__shared__ int smem[1];
for (offset_t dense_input_offset = blockIdx.x;
dense_input_offset < num_dense_input_rows;
dense_input_offset += gridDim.x) {
// Binary search
// TODO: use multiple threads to do bin search to reduce number of steps
if (threadIdx.x == 0) {
const auto num_input_rows = indices.size(0);
binary_search_range(
smem, input_offsets, dense_input_offset, num_input_rows);
smem, &input_offsets[0], dense_input_offset, num_input_rows);
}
__syncthreads();

@@ -46,14 +47,11 @@ __global__ __launch_bounds__(kMaxThreads) void jagged_index_add_2d_kernel(
const offset_t output_offset =
(index == 0 ? 0 : output_offsets[index - 1]) + rel_index;

// Shift buffers
const scalar_t* values_ = values + dense_input_offset * num_cols;
scalar_t* output_ = output + output_offset * num_cols;

// TODO: Avoid using atomicAdd (because it could lead to the numerical
// indeterminism issue)
const auto num_cols = output.size(1);
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
gpuAtomicAdd(&output_[i], values_[i]);
gpuAtomicAdd(&output[output_offset][i], values[dense_input_offset][i]);
}
}
}
@@ -85,7 +83,6 @@ Tensor jagged_index_add_2d_forward_cuda(
device_guard.set_index(values.get_device());

auto num_cols = values.size(1);
const int64_t num_input_rows = indices.numel();

const int64_t max_num_blocks = 1024; // Arbitrarily set to this number for now
const int64_t max_num_threads = kMaxThreads;
@@ -94,6 +91,9 @@
Tensor output = at::zeros({num_output_rows, num_cols}, values.options());

if (num_blocks > 0) {
// input_offsets has to be contiguous since it is passed to
// binary_search_range which accepts raw pointers
const auto input_offsets_contig = input_offsets.expect_contiguous();
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
@@ -109,14 +109,23 @@
dim3(num_cols),
0,
at::cuda::getCurrentCUDAStream()>>>(
output.data_ptr<scalar_t>(),
values.data_ptr<scalar_t>(),
input_offsets.data_ptr<int64_t>(),
indices.data_ptr<index_t>(),
output_offsets.data_ptr<int64_t>(),
num_input_rows,
num_dense_input_rows,
num_cols);
output.packed_accessor64<
scalar_t,
2,
at::RestrictPtrTraits>(),
values.packed_accessor64<
scalar_t,
2,
at::RestrictPtrTraits>(),
input_offsets_contig->packed_accessor32<
int64_t,
1,
at::RestrictPtrTraits>(),
indices
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
output_offsets
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
num_dense_input_rows);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
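
As the comment added in this file notes, offsets tensors that are still handed
to binary_search_range as raw pointers must be made contiguous first.  A minimal
sketch of that pattern -- `consume_offsets` and `call_with_raw_pointer` are
hypothetical stand-ins, and int64 offsets are assumed -- where
`expect_contiguous()` returns a `c10::MaybeOwned<Tensor>` that borrows the
tensor when it is already contiguous and otherwise owns a contiguous copy:

#include <ATen/ATen.h>

// Hypothetical helper that, like binary_search_range, reads raw memory.
void consume_offsets(const int64_t* offsets, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    (void)offsets[i];
  }
}

void call_with_raw_pointer(const at::Tensor& input_offsets) {
  // Keep the MaybeOwned<Tensor> alive for as long as the raw pointer is used.
  const auto input_offsets_contig = input_offsets.expect_contiguous();
  consume_offsets(
      input_offsets_contig->data_ptr<int64_t>(),
      input_offsets_contig->numel());
}
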
55 changes: 32 additions & 23 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_index_select_2d_forward.cu
@@ -14,23 +14,24 @@ namespace fbgemm_gpu {

template <typename index_t, typename offset_t, typename scalar_t>
__global__ __launch_bounds__(kMaxThreads) void jagged_index_select_2d_kernel(
scalar_t* output,
const scalar_t* input,
const offset_t* input_offsets,
const index_t* indices,
const offset_t* output_offsets,
const int64_t num_output_rows,
const int64_t num_dense_output_rows,
const int64_t num_cols) {
at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> output,
const at::PackedTensorAccessor64<scalar_t, 2, at::RestrictPtrTraits> input,
const at::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits>
input_offsets,
const at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
const at::PackedTensorAccessor32<offset_t, 1, at::RestrictPtrTraits>
output_offsets,
const int64_t num_dense_output_rows) {
__shared__ int smem[1];
for (offset_t dense_output_offset = blockIdx.x;
dense_output_offset < num_dense_output_rows;
dense_output_offset += gridDim.x) {
// Binary search
// TODO: use multiple threads to do bin search to reduce number of steps
if (threadIdx.x == 0) {
const auto num_output_rows = indices.size(0);
binary_search_range(
smem, output_offsets, dense_output_offset, num_output_rows);
smem, &output_offsets[0], dense_output_offset, num_output_rows);
}
__syncthreads();

@@ -46,12 +47,9 @@ __global__ __launch_bounds__(kMaxThreads) void jagged_index_select_2d_kernel(
const offset_t input_offset =
(index == 0 ? 0 : input_offsets[index - 1]) + rel_index;

// Shift buffers
scalar_t* output_ = output + dense_output_offset * num_cols;
const scalar_t* input_ = input + input_offset * num_cols;

const auto num_cols = input.size(1);
for (int i = threadIdx.x; i < num_cols; i += blockDim.x) {
output_[i] = input_[i];
output[dense_output_offset][i] = input[input_offset][i];
}
}
}
@@ -81,7 +79,6 @@ Tensor jagged_index_select_2d_forward_cuda(
device_guard.set_index(values.get_device());

auto num_cols = values.size(1);
const int64_t num_output_rows = indices.numel();

const int64_t max_num_blocks = 1024; // Arbitrarily set to this number for now
const int64_t max_num_threads = kMaxThreads;
@@ -91,6 +88,9 @@
at::empty({num_dense_output_rows, num_cols}, values.options());

if (num_blocks > 0) {
// output_offsets has to be contiguous since it is passed to
// binary_search_range which accepts raw pointers
const auto output_offsets_contig = output_offsets.expect_contiguous();
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
@@ -106,14 +106,23 @@
dim3(num_cols),
0,
at::cuda::getCurrentCUDAStream()>>>(
output.data_ptr<scalar_t>(),
values.data_ptr<scalar_t>(),
input_offsets.data_ptr<int64_t>(),
indices.data_ptr<index_t>(),
output_offsets.data_ptr<int64_t>(),
num_output_rows,
num_dense_output_rows,
num_cols);
output.packed_accessor64<
scalar_t,
2,
at::RestrictPtrTraits>(),
values.packed_accessor64<
scalar_t,
2,
at::RestrictPtrTraits>(),
input_offsets
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
indices
.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
output_offsets_contig->packed_accessor32<
int64_t,
1,
at::RestrictPtrTraits>(),
num_dense_output_rows);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
10 changes: 10 additions & 0 deletions fbgemm_gpu/test/jagged_tensor_ops_test.py
@@ -1827,6 +1827,7 @@ def jagged_index_select_2d_ref(
else st.just(False)
if (gpu_available and TEST_WITH_ROCM)
else st.just(True),
check_non_contiguous=st.booleans(),
)
@settings(max_examples=20, deadline=None, verbosity=Verbosity.verbose)
def test_jagged_index_select_2d(
@@ -1838,6 +1839,7 @@ def test_jagged_index_select_2d(
index_dtype: torch.dtype,
jagged_tensor_dtype: torch.dtype,
use_cpu: bool,
check_non_contiguous: bool,
) -> None:
device = torch.device("cpu" if use_cpu else "cuda")
is_float = jagged_tensor_dtype in [torch.float, torch.half, torch.bfloat16]
@@ -1873,6 +1875,10 @@ def test_jagged_index_select_2d(
)
values_ref = values.detach().clone()

if check_non_contiguous:
values = values.as_strided(values.shape, (1, values.shape[0]))
values_ref = values_ref.as_strided(values.shape, (1, values.shape[0]))

# Only float tensors can require grad
if is_float:
values.requires_grad = True
@@ -1891,6 +1897,10 @@ def test_jagged_index_select_2d(
grad = torch.rand_like(output)
grad_ref = grad.detach().clone()

if check_non_contiguous:
grad = grad.as_strided(grad.shape, (1, grad.shape[0]))
grad_ref = grad_ref.as_strided(grad.shape, (1, grad.shape[0]))

output.backward(grad)
output_ref.backward(grad_ref)

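
For reference, a minimal C++ analogue (an illustration only, not part of this
PR; the test above does the same thing in Python) of the check_non_contiguous
path: as_strided with swapped strides produces a non-contiguous view over the
same storage, which is the kind of input the accessor-based kernels must now
handle.

#include <ATen/ATen.h>
#include <cassert>

void non_contiguous_view_example() {
  at::Tensor values = at::rand({8, 4});
  // Same sizes, column-major strides: element (i, j) now maps to storage[i + j * 8].
  at::Tensor nc = values.as_strided(values.sizes(), {1, values.size(0)});
  // No data is copied; nc shares storage with values but is not contiguous.
  assert(!nc.is_contiguous());
}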
