Add arg checks to input_combine_with_length
Summary:
- Add arg checks to `input_combine_with_length`

Reviewed By: sryap

Differential Revision: D54138774

fbshipit-source-id: 26dce6bc8939124594aad0770601915100b11f21
q10 authored and facebook-github-bot committed Feb 26, 2024
1 parent 70a1c96 commit 90ba6e4
Showing 9 changed files with 350 additions and 220 deletions.
8 changes: 8 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
@@ -216,6 +216,14 @@ inline bool torch_tensor_empty_or_on_cpu_check(
      " and ", \
      (y).numel())

#define TENSOR_NUMEL_IS_GT(ten, num) \
  TORCH_CHECK( \
      (ten).numel() > (num), \
      "Tensor '" #ten "' must have more than " #num \
      " element(s). " \
      "Found ", \
      (ten).numel())

template <typename... Tensors>
std::string tensor_on_same_gpu_if_not_optional_check(
    const std::string& var_names_str,
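For illustration, here is a minimal Python sketch of the check this macro performs. The `tensor_numel_is_gt` helper and its `name` parameter are hypothetical, for exposition only; the real macro stringizes `ten` and `num` to build the message.

    import torch

    # Hypothetical Python analogue of TENSOR_NUMEL_IS_GT (exposition only).
    def tensor_numel_is_gt(ten: torch.Tensor, num: int, name: str = "ten") -> None:
        # Mirrors TORCH_CHECK((ten).numel() > (num), ...): raise if numel <= num.
        if not ten.numel() > num:
            raise RuntimeError(
                f"Tensor '{name}' must have more than {num} element(s). Found {ten.numel()}"
            )

    tensor_numel_is_gt(torch.tensor([1, 2, 3]), 0)    # passes: 3 > 0
    # tensor_numel_is_gt(torch.empty(0), 0, "indices")  # would raise RuntimeError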
22 changes: 18 additions & 4 deletions fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp
@@ -6,10 +6,6 @@
 * LICENSE file in the root directory of this source tree.
 */

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/input_combine.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
@@ -21,6 +17,10 @@
#include <c10/util/Exception.h>
#include <torch/script.h>

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/input_combine.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

using Tensor = at::Tensor;

namespace fbgemm_gpu {
@@ -204,9 +204,23 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_cpu(
    const std::vector<Tensor>& indices_list,
    const std::vector<Tensor>& lengths_list,
    const std::vector<Tensor>& per_sample_weights) {
  // The list sizes should be non-zero and match
  TORCH_CHECK_GT(indices_list.size(), 0);
  TORCH_CHECK_EQ(lengths_list.size(), indices_list.size());
  TORCH_CHECK_EQ(per_sample_weights.size(), indices_list.size());

  // Either the corresponding weights are provided for all indices tensors, or
  // none are provided for any of the indices tensors
  {
    const auto nonempty_weights = static_cast<size_t>(std::count_if(
        per_sample_weights.begin(),
        per_sample_weights.end(),
        [](const auto& ten) { return ten.numel() > 0; }));
    TORCH_CHECK(
        nonempty_weights == 0 || nonempty_weights == indices_list.size(),
        "Either all weights tensors should be empty, or all should be non-empty");
  }

  int64_t total_indices = 0;
  int64_t total_lengths = 0;
  bool need_weights = false;
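From the Python side, the new all-or-none weights rule reads as follows. This is a minimal sketch, assuming an fbgemm_gpu build with the op registered; tensor values are arbitrary.

    import torch
    import fbgemm_gpu  # noqa: F401  -- registers torch.ops.fbgemm.*

    indices = [torch.tensor([1, 2], dtype=torch.int32), torch.tensor([3, 4], dtype=torch.int32)]
    lengths = [torch.tensor([1, 1], dtype=torch.int32), torch.tensor([1, 1], dtype=torch.int32)]

    # OK: all weights tensors empty (unweighted).
    torch.ops.fbgemm.tbe_input_combine_with_length(
        indices, lengths, [torch.empty(0), torch.empty(0)])

    # OK: all weights tensors non-empty (one weight per index).
    torch.ops.fbgemm.tbe_input_combine_with_length(
        indices, lengths, [torch.ones(2), torch.ones(2)])

    # Rejected by the new check: weights for only some of the lists.
    try:
        torch.ops.fbgemm.tbe_input_combine_with_length(
            indices, lengths, [torch.ones(2), torch.empty(0)])
    except RuntimeError:
        pass  # "Either all weights tensors should be empty, or all should be non-empty"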
30 changes: 24 additions & 6 deletions fbgemm_gpu/src/input_combine_ops/input_combine_gpu.cpp
@@ -6,12 +6,11 @@
 * LICENSE file in the root directory of this source tree.
 */

#include "fbgemm_gpu/input_combine.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>
#include "fbgemm_gpu/input_combine.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

using Tensor = at::Tensor;

@@ -62,10 +61,24 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_gpu(
    const std::vector<Tensor>& indices_list,
    const std::vector<Tensor>& lengths_list,
    const std::vector<Tensor>& per_sample_weights) {
  // The list sizes should be non-zero and match
  TORCH_CHECK_GT(indices_list.size(), 0);
  TORCH_CHECK_EQ(lengths_list.size(), indices_list.size());
  TORCH_CHECK_EQ(per_sample_weights.size(), indices_list.size());

  // Either the corresponding weights are provided for all indices tensors, or
  // none are provided for any of the indices tensors
  {
    const auto nonempty_weights = static_cast<size_t>(std::count_if(
        per_sample_weights.begin(),
        per_sample_weights.end(),
        [](const auto& ten) { return ten.numel() > 0; }));
    TORCH_CHECK(
        nonempty_weights == 0 || nonempty_weights == indices_list.size(),
        "Either all weights tensors should be empty, or all should be non-empty");
  }

  const auto num_lists = indices_list.size();
  TORCH_CHECK_GT(num_lists, 0);
  TORCH_CHECK_EQ(lengths_list.size(), num_lists);
  TORCH_CHECK_EQ(per_sample_weights.size(), num_lists);
  const bool need_weights = std::any_of(
      per_sample_weights.begin(), per_sample_weights.end(), [](const auto& x) {
        return x.numel() > 0;
@@ -143,14 +156,19 @@
    }
    const auto& indices = indices_list[i];
    const auto& lengths = lengths_list[i];

    // Tensors are contiguous, on same device, and have same dtype
    TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(indices);
    TENSOR_CONTIGUOUS_AND_ON_CUDA_GPU(lengths);
    TENSORS_ON_SAME_DEVICE(indices, indices_0);
    TENSORS_ON_SAME_DEVICE(lengths, indices_0);
    TORCH_CHECK(indices.dtype() == c10::kInt || indices.dtype() == c10::kLong);
    TORCH_CHECK(lengths.dtype() == c10::kInt || lengths.dtype() == c10::kLong);
    // Dimensions must be 1
    TENSOR_NDIM_EQUALS(indices, 1);
    TENSOR_NDIM_EQUALS(lengths, 1);
    // Indices must be non-empty
    TENSOR_NUMEL_IS_GT(indices, 0);

    const auto indices_numel = indices.numel();
    const auto lengths_numel = lengths.numel();
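The per-tensor requirements the GPU path now enforces (contiguous, 1-D, int32/int64, same device, non-empty indices) can be summarized with a minimal passing call. This is a sketch, assuming a CUDA device is available and an fbgemm_gpu build with the op registered.

    import torch
    import fbgemm_gpu  # noqa: F401  -- registers torch.ops.fbgemm.*

    device = torch.device("cuda")
    # 1-D, contiguous, int32 (or int64), all on the same GPU; indices non-empty.
    indices = [torch.tensor([10, 20, 30], dtype=torch.int32, device=device)]
    lengths = [torch.tensor([2, 1], dtype=torch.int32, device=device)]
    weights = [torch.empty(0, dtype=torch.float, device=device)]  # all-empty: OK

    out_indices, out_lengths, out_weights = torch.ops.fbgemm.tbe_input_combine_with_length(
        indices, lengths, weights)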
6 changes: 6 additions & 0 deletions fbgemm_gpu/test/combine/__init__.py
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
79 changes: 79 additions & 0 deletions fbgemm_gpu/test/combine/bad_inputs_test.py
@@ -0,0 +1,79 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from hypothesis import given, settings

from .common import open_source

if open_source:
    # pyre-ignore[21]
    from test_utils import cpu_and_maybe_gpu, optests
else:
    from fbgemm_gpu.test.test_utils import cpu_and_maybe_gpu, optests


@optests.generate_opcheck_tests()
class BadInputsTest(unittest.TestCase):
    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @given(device=cpu_and_maybe_gpu())
    @settings(deadline=None)
    def test_tbe_input_combine_with_length_bad_args(self, device: torch.device) -> None:
        arg0_list = [
            [88, 55],
            [80, 29],
            [2, 85],
            [39, 51],
            [84, 35],
            [12, 6],
            [94, 43],
            [98, 59],
            [19, 68],
            [97, 89],
        ]
        arg0 = [torch.tensor(t, dtype=torch.int32, device=device) for t in arg0_list]

        arg1_list = [
            [1, 2],
            [1, 2],
            [1, 2],
            [1, 2],
            [1, 2],
            [1, 2],
            [1, 2],
            [1, 2],
            [1, 2],
            [1, 2],
        ]
        arg1 = [torch.tensor(t, dtype=torch.int32, device=device) for t in arg1_list]

        arg2_list = [
            [],
            [],
            [],
            [],
            [3.0, 3.0],
            [],
            [],
            [3.0, 3.0],
            [3.0, 3.0],
            [],
        ]
        arg2 = [torch.tensor(t, dtype=torch.float, device=device) for t in arg2_list]

        with self.assertRaises(RuntimeError):
            torch.ops.fbgemm.tbe_input_combine_with_length(
                arg0,
                arg1,
                arg2,
            )


if __name__ == "__main__":
    unittest.main()
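The test above trips the new all-or-none weights check because arg2 mixes empty and non-empty tensors. As a counterpoint, here is a hedged sketch of a "repaired" input (assuming an fbgemm_gpu build with the op registered; whether the call succeeds overall still depends on the remaining shape checks): supplying a weight per index for every list satisfies that particular check.

    import torch
    import fbgemm_gpu  # noqa: F401  -- registers torch.ops.fbgemm.*

    device = torch.device("cpu")
    # Ten lists; each has 2 indices and lengths that sum to 2 for consistency.
    arg0 = [torch.tensor([88, 55], dtype=torch.int32, device=device) for _ in range(10)]
    arg1 = [torch.tensor([1, 1], dtype=torch.int32, device=device) for _ in range(10)]
    # Every weights tensor is non-empty, so the all-or-none check is satisfied.
    arg2 = [torch.tensor([3.0, 3.0], dtype=torch.float, device=device) for _ in range(10)]

    combined = torch.ops.fbgemm.tbe_input_combine_with_length(arg0, arg1, arg2)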
126 changes: 126 additions & 0 deletions fbgemm_gpu/test/combine/common.py
@@ -0,0 +1,126 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors[56]

from typing import List, Optional, Tuple

import fbgemm_gpu
import torch
from fbgemm_gpu import sparse_ops # noqa: F401

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if not open_source:
    if torch.version.hip:
        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_hip")
    else:
        torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine")
    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu")


class TBEInputPrepareReference(torch.nn.Module):
    def __init__(self, include_last_offsets: List[bool]) -> None:
        super().__init__()
        self.include_last_offsets = include_last_offsets

    def forward(  # noqa C901
        self,
        indices_list: List[torch.Tensor],
        offsets_list: List[torch.Tensor],
        per_sample_weights_list: List[torch.Tensor],
        batch_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        size = 0
        assert len(indices_list) > 0
        assert len(indices_list) == len(offsets_list)
        assert len(indices_list) == len(per_sample_weights_list)
        assert len(indices_list) == len(self.include_last_offsets)
        for i in range(len(self.include_last_offsets)):
            size += indices_list[i].size(0)
            assert indices_list[i].dim() == 1
            assert offsets_list[i].dim() == 1
            if per_sample_weights_list[i].numel() > 0:
                assert per_sample_weights_list[i].dim() == 1
                assert indices_list[i].numel() == per_sample_weights_list[i].numel()
        combined_indices = torch.empty(
            size,
            dtype=torch.int32,
            device=indices_list[0].device,
        )
        torch.cat(indices_list, out=combined_indices)
        offsets_starts = torch.zeros(
            [len(offsets_list) + 1],
            dtype=offsets_list[0].dtype,
            device=offsets_list[0].device,
        )
        offsets_accs = torch.zeros(
            [len(offsets_list) + 1],
            dtype=offsets_list[0].dtype,
            device=offsets_list[0].device,
        )

        for i, include_last_offset in enumerate(self.include_last_offsets):
            if include_last_offset:
                offsets_starts[i + 1] = offsets_starts[i] + offsets_list[i].size(0) - 1
            else:
                offsets_starts[i + 1] = offsets_starts[i] + offsets_list[i].size(0)
            offsets_accs[i + 1] = offsets_accs[i] + indices_list[i].size(0)

        assert offsets_accs[-1] == combined_indices.size(0)
        combined_offsets_size: List[int] = (
            [int(offsets_starts[-1].item()) + 1]
            if batch_size is None
            else [batch_size * len(offsets_list) + 1]
        )
        combined_offsets = torch.zeros(
            combined_offsets_size,
            dtype=torch.int32,
            device=offsets_list[0].device,
        )
        if batch_size is None:
            for i in range(len(self.include_last_offsets)):
                combined_offsets[offsets_starts[i] : offsets_starts[i + 1]] = (
                    offsets_list[i][: offsets_starts[i + 1] - offsets_starts[i]]
                    + offsets_accs[i]
                )
        else:
            for i in range(len(self.include_last_offsets)):
                cur_start = batch_size * i
                combined_offsets[
                    cur_start : cur_start + offsets_starts[i + 1] - offsets_starts[i]
                ] = (
                    offsets_list[i][: offsets_starts[i + 1] - offsets_starts[i]]
                    + offsets_accs[i]
                )
                cur_start = cur_start + offsets_starts[i + 1] - offsets_starts[i]
                for j in range(batch_size - offsets_starts[i + 1] + offsets_starts[i]):
                    combined_offsets[cur_start + j] = (
                        indices_list[i].numel() + offsets_accs[i]
                    )
        combined_offsets[-1] = offsets_accs[-1]
        per_sample_weights: Optional[torch.Tensor] = None
        for i in range(len(self.include_last_offsets)):
            if per_sample_weights_list[i].size(0) > 0:
                per_sample_weights = torch.ones(
                    combined_indices.size(0),
                    dtype=per_sample_weights_list[i].dtype,
                    device=per_sample_weights_list[i].device,
                )
                break
        if per_sample_weights is not None:
            for i in range(len(self.include_last_offsets)):
                if per_sample_weights_list[i].size(0) > 0:
                    # fmt: off
                    per_sample_weights[offsets_accs[i] : offsets_accs[i + 1]] = (
                        per_sample_weights_list[i][:]
                    )
                    # fmt: on

        # indices and offsets are required to be int32 for TBE
        return combined_indices, combined_offsets, per_sample_weights
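For context, a hedged usage sketch of the reference module above (shapes are illustrative; `include_last_offsets` takes one flag per table, and the offsets here do not include a trailing last offset):

    import torch

    ref = TBEInputPrepareReference(include_last_offsets=[False, False])

    indices = [torch.tensor([1, 2, 3], dtype=torch.int32),
               torch.tensor([4, 5], dtype=torch.int32)]
    offsets = [torch.tensor([0, 2], dtype=torch.int32),
               torch.tensor([0, 1], dtype=torch.int32)]
    weights = [torch.empty(0), torch.empty(0)]  # unweighted input

    combined_indices, combined_offsets, combined_weights = ref(indices, offsets, weights)
    # combined_indices -> [1, 2, 3, 4, 5]
    # combined_offsets -> [0, 2, 3, 4, 5]
    # combined_weights -> None (no per-sample weights were supplied)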