KJT custom op for 1d lengths input (pytorch#2774)
Summary:
X-link: pytorch/torchrec#2163

Pull Request resolved: pytorch#2774

# context
* move the `tensor.view(-1, stride)` call from Python into the operator (C++); see the sketch below
* make the PT2 compiler happy
* reference: D58948987
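
A minimal before/after sketch of the call pattern this enables; the shapes, dtypes, and the `import fbgemm_gpu` registration step are illustrative assumptions, not taken from the actual `jagged_tensor` call site:

```python
import torch

import fbgemm_gpu  # noqa: F401  (assumed to register the torch.ops.fbgemm ops)

# Illustrative inputs: T = 3 features, batch stride B = 4.
permute = torch.tensor([2, 0, 1], dtype=torch.int64)
stride = 4
lengths_1d = torch.randint(0, 3, (3 * stride,), dtype=torch.int64)  # flat [T * B]
values = torch.arange(int(lengths_1d.sum()), dtype=torch.float32)

# Before: reshape the flat lengths in Python, call the 2D op, flatten back.
len_2d, permuted_values, _ = torch.ops.fbgemm.permute_2D_sparse_data(
    permute, lengths_1d.view(-1, stride), values
)
permuted_lengths = len_2d.view(-1)

# After: the new op takes the 1D lengths plus the stride and performs the
# view(-1, stride) / flatten round-trip inside the C++ operator.
permuted_lengths, permuted_values, _ = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
    permute, lengths_1d, values, stride
)
```

The point of the change is that the reshape/flatten round-trip now lives inside the operator, so PT2 no longer has to trace the Python-side `view`.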

# notes
* not sure whether we should directly change the op call in `jagged_tensor`
* tested on CPU and GPU
* backward/autograd not tested; a forward consistency check is sketched below
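
A hedged sketch of the kind of CPU/GPU forward consistency check implied by the notes above; the function name, shapes, and dtypes are assumptions rather than the actual test code:

```python
import torch

import fbgemm_gpu  # noqa: F401  (assumed to register the torch.ops.fbgemm ops)


def check_input1d_matches_2d(device: str = "cpu") -> None:
    # Illustrative shapes/dtypes; the real tests may differ.
    T, B = 5, 3
    permute = torch.randperm(T, device=device)
    lengths = torch.randint(0, 4, (T * B,), dtype=torch.int64, device=device)
    values = torch.rand(int(lengths.sum()), device=device)
    weights = torch.rand_like(values)

    # Reference: existing 2D op on lengths reshaped in Python.
    ref_len, ref_val, ref_w = torch.ops.fbgemm.permute_2D_sparse_data(
        permute, lengths.view(-1, B), values, weights
    )
    # New op: 1D lengths plus the stride, reshape handled inside the operator.
    out_len, out_val, out_w = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
        permute, lengths, values, B, weights
    )

    torch.testing.assert_close(out_len, ref_len.view(-1))
    torch.testing.assert_close(out_val, ref_val)
    torch.testing.assert_close(out_w, ref_w)
```

A backward check could reuse the same inputs once the autograd wiring is exercised, for example by comparing `values`/`weights` gradients against the 2D op.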

Reviewed By: IvanKobzarev

Differential Revision: D58956327

fbshipit-source-id: 0b07fd96c704251358cf3f58b11f842f6add32ba
TroyGarden authored and facebook-github-bot committed Jun 26, 2024
1 parent 5fd8182 commit e47a82c
Showing 4 changed files with 134 additions and 1 deletion.
72 changes: 71 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -7,7 +7,7 @@
# pyre-strict

import math
- from typing import Callable, List, Optional, Tuple
+ from typing import Any, Callable, List, Optional, Tuple

[GitHub Actions run-lint (3.11) check failure on line 10 of fbgemm_gpu/fbgemm_gpu/sparse_ops.py: F401 'typing.Any' imported but unused]

import torch

@@ -71,6 +71,72 @@ def wrapper(f: Callable) -> Callable:
    return wrapper


def permute_2D_sparse_data_input1D_meta(
    permute: Tensor,
    lengths: Tensor,
    values: Tensor,
    stride: int,
    weights: Optional[Tensor] = None,
    permuted_lengths_sum: Optional[int] = None,
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
    """Shape-only (abstract/meta) implementation used by PT2 tracing.

    `lengths` is the flat [T * B] tensor (B == stride), so the permuted
    lengths keep that size. The size of the permuted values/weights is
    data-dependent; when `permuted_lengths_sum` is not provided, an unbacked
    dynamic size is allocated from the abstract-impl context.
    """
    torch._check(
        lengths.dim() == 1, lambda: f"expected lengths.dim() == 1, got {lengths.dim()}"
    )
    T = permute.numel()
    B = stride
    indices = values
    permuted_lengths = lengths.new_empty([T * B])
    permuted_indices_size = 0
    if permuted_lengths_sum is not None:
        permuted_indices_size = permuted_lengths_sum
    else:
        ctx = torch.library.get_ctx()
        permuted_indices_size = ctx.new_dynamic_size()
    # pyre-fixme
    permuted_indices = indices.new_empty(permuted_indices_size)
    permuted_weights = None
    if weights is not None:
        # pyre-fixme
        permuted_weights = weights.new_empty(permuted_indices_size)
    return permuted_lengths, permuted_indices, permuted_weights


# pyre-ignore
def permute_2D_sparse_data_input1D_setup_context(ctx, inputs, output):
    # Stash what the backward pass needs: the forward permutation, the
    # permuted (still 1D) lengths, and the batch stride.
    permute, lengths, values, stride, weights, permuted_lengths_sum = inputs
    permuted_lengths, permuted_values, permuted_weights = output
    ctx.permute = permute
    ctx.permuted_lengths = permuted_lengths
    ctx.stride = stride


def permute_2D_sparse_data_input1D_backward(
    ctx,  # pyre-ignore
    grad_lengths: torch.Tensor,
    grad_values: torch.Tensor,
    grad_weights: torch.Tensor,
) -> Tuple[None, Tensor, Tensor, None, Tensor, None]:
    # A permutation is undone by the inverse permutation: route the incoming
    # gradients back through the same op using invert_permute together with
    # the saved (already permuted) lengths and stride.
    inv_permute = torch.ops.fbgemm.invert_permute(ctx.permute)
    permuted_grad_lengths, permuted_grad_values, permuted_grad_weights = (
        torch.ops.fbgemm.permute_2D_sparse_data_input1D(
            inv_permute,
            ctx.permuted_lengths,
            grad_values,
            ctx.stride,
            grad_weights,
            None,
        )
    )
    # One slot per forward input (permute, lengths, values, stride, weights,
    # permuted_lengths_sum); only the tensor inputs receive gradients.
    return (
        None,
        permuted_grad_lengths,
        permuted_grad_values,
        None,
        permuted_grad_weights,
        None,
    )


def permute_2D_sparse_data_meta(
    permute: Tensor,
    lengths: Tensor,
@@ -919,6 +985,10 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
)

impl_abstract("fbgemm::permute_2D_sparse_data", permute_2D_sparse_data_meta)
impl_abstract(
    "fbgemm::permute_2D_sparse_data_input1D",
    permute_2D_sparse_data_input1D_meta,
)
impl_abstract("fbgemm::invert_permute", invert_permute_abstract)
impl_abstract("fbgemm::permute_1D_sparse_data", permute_1D_sparse_data_meta)
impl_abstract("fbgemm::masked_select_jagged_1d", masked_select_jagged_1d)
18 changes: 18 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -90,6 +90,24 @@ at::Tensor segment_sum_csr_cpu(
///@see You can find more info <a
/// href="https://www.doxygen.nl/manual/commands.html#cmdlink">here</a>

std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>>
permute_2D_sparse_data_input1D_cpu(
    const at::Tensor& permute,
    const at::Tensor& lengths,
    const at::Tensor& indices,
    const int64_t& stride,
    const std::optional<at::Tensor>& weights,
    const std::optional<int64_t>& permuted_lengths_sum);

std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>>
permute_2D_sparse_data_input1D_cuda(
    const at::Tensor& permute,
    const at::Tensor& lengths,
    const at::Tensor& indices,
    const int64_t& stride,
    const std::optional<at::Tensor>& weights,
    const std::optional<int64_t>& permuted_lengths_sum);

std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>>
permute_2D_sparse_data_cuda(
const at::Tensor& permute,
23 changes: 23 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -581,6 +581,24 @@ void _bucketize_sparse_features_cpu(
}
}

std::tuple<Tensor, Tensor, std::optional<Tensor>>
permute_2D_sparse_data_input1D_cpu(
    const Tensor& permute,
    const Tensor& lengths,
    const Tensor& indices,
    const int64_t& stride,
    const std::optional<Tensor>& weights,
    const std::optional<int64_t>& permuted_lengths_sum) {
  // Thin wrapper around the 2D op: view the flat [T * B] lengths as [T, B]
  // (B == stride), permute, then flatten the permuted lengths back so callers
  // keep working with a 1D lengths tensor.
  auto [permuted_lengths, permuted_indices, permuted_weights] =
      permute_2D_sparse_data_cpu(
          permute,
          lengths.view({-1, stride}),
          indices,
          weights,
          permuted_lengths_sum);
  return {permuted_lengths.view(-1), permuted_indices, permuted_weights};
}

std::tuple<Tensor, Tensor, std::optional<Tensor>> permute_2D_sparse_data_cpu(
const Tensor& permute,
const Tensor& lengths,
@@ -3032,6 +3050,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)",
{PT2_COMPLIANT_TAG});
m.def(
"permute_2D_sparse_data_input1D(Tensor permute, Tensor lengths, Tensor values, SymInt stride, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.def(
"permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)",
{PT2_COMPLIANT_TAG});
@@ -3142,6 +3162,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
"permute_sparse_data", fbgemm_gpu::permute_2D_sparse_data_cpu);
DISPATCH_TO_CPU(
"permute_2D_sparse_data", fbgemm_gpu::permute_2D_sparse_data_cpu);
DISPATCH_TO_CPU(
"permute_2D_sparse_data_input1D",
fbgemm_gpu::permute_2D_sparse_data_input1D_cpu);
DISPATCH_TO_CPU(
"permute_1D_sparse_data", fbgemm_gpu::permute_1D_sparse_data_cpu);
DISPATCH_TO_CPU("invert_permute", fbgemm_gpu::invert_permute_cpu);
22 changes: 22 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu
@@ -349,6 +349,24 @@ permute_sparse_features_cuda(
return {permuted_lengths, permuted_indices, permuted_weights};
}
DLL_PUBLIC std::tuple<Tensor, Tensor, std::optional<Tensor>>
permute_2D_sparse_data_input1D_cuda(
    const Tensor& permute,
    const Tensor& lengths,
    const Tensor& indices,
    const int64_t& stride,
    const std::optional<Tensor>& weights,
    const std::optional<int64_t>& permuted_lengths_sum) {
  // Same wrapper as the CPU path: reshape the 1D lengths to [T, B] inside the
  // operator, delegate to the existing 2D CUDA kernel, and flatten the
  // permuted lengths back to 1D.
  auto [permuted_lengths, permuted_indices, permuted_weights] =
      permute_2D_sparse_data_cuda(
          permute,
          lengths.view({-1, stride}),
          indices,
          weights,
          permuted_lengths_sum);
  return {permuted_lengths.view(-1), permuted_indices, permuted_weights};
}
} // namespace fbgemm_gpu
FBGEMM_OP_DISPATCH(
@@ -359,6 +377,10 @@ FBGEMM_OP_DISPATCH(
CUDA,
"permute_2D_sparse_data",
fbgemm_gpu::permute_2D_sparse_data_cuda);
FBGEMM_OP_DISPATCH(
CUDA,
"permute_2D_sparse_data_input1D",
fbgemm_gpu::permute_2D_sparse_data_input1D_cuda);
FBGEMM_OP_DISPATCH(
CUDA,
"permute_sparse_features",
