Skip to content

Commit

Permalink
Add support for duplicate in permutations for permute_pooled_embs_spl…
Browse files Browse the repository at this point in the history
…it (pytorch#1940)

Summary:
Pull Request resolved: pytorch#1940

This diff builds on top of the previous diffs and adds support for duplicates to the permute_pooled_embs_split op.

Background
Currently permute_pooled_embs_split does not support duplicates in a permutation, this poses a problem with passing the same embeddings to multiple modules. This doc proposes a solution to allow duplicate subsets in the resultant permutation.
Details
The required implementation of permute_pooled_embs_split should support a subset being repeated. This is represented by having duplicates in the permute list. This also results in the output list size being greater than the input list.
Input: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Offset_dims: [0,  2,  5,  6, 10]
Permute: [3, 0, 2, 1, 3]
Output:  [6, 7, 8, 9, 0, 1, 5, 2, 3, 4, 6, 7, 8, 9]

Reviewed By: sryap

Differential Revision: D48305847

fbshipit-source-id: 4c82683b725592cad458e83596617a14f4c6e988
  • Loading branch information
AGZain authored and facebook-github-bot committed Sep 18, 2023
1 parent eb1103b commit f030bbc
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 7 deletions.
32 changes: 32 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embedding_ops_split.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,28 @@ at::Tensor permute_pooled_embs_split_cpu(
const at::Tensor& inv_offset_dim_list,
const at::Tensor& inv_permute_list);

// Implementation of permute_pooled_embs_split for GPU. This supports both the
// duplicate and non-duplicate cases with the allow_duplicates flag: when
// allow_duplicates is true, permute_list may repeat segment indices and the
// output may therefore be wider than the input (its width is taken from the
// last entry of inv_offset_dim_list).
///@ingroup permute-pooled-embs-gpu-impl
at::Tensor permute_pooled_embs_split_gpu_impl(
const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
const at::Tensor& offset_dim_list,
const at::Tensor& permute_list,
const at::Tensor& inv_offset_dim_list,
const at::Tensor& inv_permute_list,
const bool& allow_duplicates);

// Implementation of permute_pooled_embs_split for GPU for the duplicate
// permutations use case (permute_list may contain repeated segment indices,
// so the output can be wider than the input). This calls the
// permute_pooled_embs_split_gpu_impl function with allow_duplicates = true.
///@ingroup permute-duplicate-pooled-embs-gpu
at::Tensor permute_duplicate_pooled_embs_split_gpu(
const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
const at::Tensor& offset_dim_list,
const at::Tensor& permute_list,
const at::Tensor& inv_offset_dim_list,
const at::Tensor& inv_permute_list);

///@ingroup permute-pooled-embs-gpu
at::Tensor permute_pooled_embs_split_gpu(
const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
Expand All @@ -38,6 +60,16 @@ at::Tensor permute_pooled_embs_auto_grad_split_cpu(
const at::Tensor& inv_offset_dim_list,
const at::Tensor& inv_permute_list);

// Implementation of permute_pooled_embs_auto_grad_split for GPU for the
// duplicate permutations use case. Autograd-enabled counterpart of
// permute_duplicate_pooled_embs_split_gpu.
///@ingroup permute-duplicate-pooled-embs-gpu
at::Tensor permute_duplicate_pooled_embs_auto_grad_split_gpu(
const at::Tensor& pooled_embs,
const at::Tensor& offset_dim_list,
const at::Tensor& permute_list,
const at::Tensor& inv_offset_dim_list,
const at::Tensor& inv_permute_list);

///@ingroup permute-pooled-embs-gpu
at::Tensor permute_pooled_embs_auto_grad_split_gpu(
const at::Tensor& pooled_embs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Tensor permute_pooled_embs_gpu(
inv_permute_list,
false);
}

Tensor permute_pooled_embs_gpu_impl(
const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
const Tensor& offset_dim_list,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,43 @@ Tensor permute_pooled_embs_split_gpu(
const Tensor& permute_list,
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list) {
TORCH_CHECK(offset_dim_list.numel() == permute_list.numel() + 1);
TORCH_CHECK(offset_dim_list.numel() == inv_offset_dim_list.numel());

return permute_pooled_embs_split_gpu_impl(
pooled_embs,
offset_dim_list,
permute_list,
inv_offset_dim_list,
inv_permute_list,
false);
}

// GPU entry point for the duplicate-permutation use case: permute_list may
// repeat segment indices, so the output can be wider than the input.
// Forwards to permute_pooled_embs_split_gpu_impl with allow_duplicates = true.
Tensor permute_duplicate_pooled_embs_split_gpu(
    const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
    const Tensor& offset_dim_list,
    const Tensor& permute_list,
    const Tensor& inv_offset_dim_list,
    const Tensor& inv_permute_list) {
  TORCH_CHECK(offset_dim_list.numel() > 0);
  TORCH_CHECK(inv_offset_dim_list.numel() > 0);
  // With duplicates the output is sized by the permutation, not by the input:
  // inv_offset_dim_list holds one boundary per permuted segment plus the
  // leading 0, and its last entry is read by the impl as the output width.
  // Validate that, mirroring the size check done on the non-duplicate path.
  TORCH_CHECK(inv_offset_dim_list.numel() == permute_list.numel() + 1);

  return permute_pooled_embs_split_gpu_impl(
      pooled_embs,
      offset_dim_list,
      permute_list,
      inv_offset_dim_list,
      inv_permute_list,
      true);
}

Tensor permute_pooled_embs_split_gpu_impl(
const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
const Tensor& offset_dim_list,
const Tensor& permute_list,
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list,
const bool& allow_duplicates) {
// inv_permute_list is not being used so it's not checked here.
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list);
Expand All @@ -45,9 +82,14 @@ Tensor permute_pooled_embs_split_gpu(
TENSORS_ON_SAME_DEVICE(pooled_embs_contiguous, offset_dim_list);
TENSORS_ON_SAME_DEVICE(pooled_embs_contiguous, permute_list);
TENSORS_ON_SAME_DEVICE(pooled_embs_contiguous, inv_offset_dim_list);
TORCH_CHECK(offset_dim_list.numel() == permute_list.numel() + 1);
TORCH_CHECK(offset_dim_list.numel() == inv_offset_dim_list.numel());
Tensor permuted_pooled_embs = at::empty_like(pooled_embs_contiguous);

// Last index in inv_offset_dim_list contains the size of output.
// This will cause a D->H sync.
const int64_t permuted_embs_dim_sum =
allow_duplicates ? inv_offset_dim_list[-1].item<int64_t>() : dim_sum;
Tensor permuted_pooled_embs = at::empty(
{pooled_embs_contiguous.size(0), permuted_embs_dim_sum},
pooled_embs_contiguous.options());

// This kernel is moving D elements per warp.
// We are launching ( div_round_up(T, warp_per_block), B ) blocks.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

Tensor permute_pooled_embs_split_cpu(
Tensor permute_pooled_embs_split_cpu_impl(
const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
const Tensor& offset_dim_list,
const Tensor& permute_list,
const Tensor& inv_offset_dim_list,
const Tensor& inv_permute_list) {
const Tensor& inv_permute_list,
const bool& allow_duplicates) {
TORCH_CHECK(
offset_dim_list.scalar_type() == at::ScalarType::Long,
"offset_dim_list needs to have long/int64 type")
Expand All @@ -38,9 +39,10 @@ Tensor permute_pooled_embs_split_cpu(
"permute_list needs to have long/int64 type")
auto permute = permute_list.data_ptr<int64_t>();
const auto n = permute_list.numel();
const auto dims_size = allow_duplicates ? offset_dim_list.numel() : n;
std::vector<int64_t> dims;
dims.reserve(n - 1);
for (const auto i : c10::irange(1, n)) {
dims.reserve(dims_size - 1);
for (const auto i : c10::irange(1, dims_size)) {
dims.push_back(offset_dim_list[i].item<int64_t>());
}
auto ts = pooled_embs.tensor_split(dims, 1);
Expand All @@ -52,6 +54,36 @@ Tensor permute_pooled_embs_split_cpu(
return at::cat(permuted_ts, 1);
}

// CPU entry point for the non-duplicate use case: each segment appears at
// most once in permute_list, so the output width equals the input width.
Tensor permute_pooled_embs_split_cpu(
    const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
    const Tensor& offset_dim_list,
    const Tensor& permute_list,
    const Tensor& inv_offset_dim_list,
    const Tensor& inv_permute_list) {
  constexpr bool kAllowDuplicates = false;
  return permute_pooled_embs_split_cpu_impl(
      pooled_embs,
      offset_dim_list,
      permute_list,
      inv_offset_dim_list,
      inv_permute_list,
      kAllowDuplicates);
}

// CPU entry point for the duplicate-permutation use case: permute_list may
// repeat segment indices, so the output can be wider than the input.
Tensor permute_duplicate_pooled_embs_split_cpu(
    const Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
    const Tensor& offset_dim_list,
    const Tensor& permute_list,
    const Tensor& inv_offset_dim_list,
    const Tensor& inv_permute_list) {
  constexpr bool kAllowDuplicates = true;
  return permute_pooled_embs_split_cpu_impl(
      pooled_embs,
      offset_dim_list,
      permute_list,
      inv_offset_dim_list,
      inv_permute_list,
      kAllowDuplicates);
}

Tensor permute_pooled_embs_auto_grad_split_cpu(
const Tensor& pooled_embs,
const Tensor& offset_dim_list,
Expand All @@ -65,16 +97,42 @@ Tensor permute_pooled_embs_auto_grad_split_cpu(
inv_offset_dim_list,
inv_permute_list);
}

// Autograd-enabled CPU entry point for the duplicate-permutation use case.
// Dispatches through PermutePooledEmbsFunctionSplit templated on the
// duplicate-aware forward op.
Tensor permute_duplicate_pooled_embs_auto_grad_split_cpu(
    const Tensor& pooled_embs,
    const Tensor& offset_dim_list,
    const Tensor& permute_list,
    const Tensor& inv_offset_dim_list,
    const Tensor& inv_permute_list) {
  using Permuter =
      PermutePooledEmbsFunctionSplit<permute_duplicate_pooled_embs_split_cpu>;
  return Permuter::apply(
      pooled_embs,
      offset_dim_list,
      permute_list,
      inv_offset_dim_list,
      inv_permute_list);
}
} // namespace fbgemm_gpu

// Register operator schemas and bind their CPU kernels. All four ops share
// the same signature; the "duplicate" variants accept repeated segment
// indices in permute_list.
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  m.def(
      "permute_pooled_embs_split(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
  DISPATCH_TO_CPU(
      "permute_pooled_embs_split", fbgemm_gpu::permute_pooled_embs_split_cpu);
  // Duplicate-aware (non-autograd) variant.
  m.def(
      "permute_duplicate_pooled_embs_split(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
  DISPATCH_TO_CPU(
      "permute_duplicate_pooled_embs_split",
      fbgemm_gpu::permute_duplicate_pooled_embs_split_cpu);
  m.def(
      "permute_pooled_embs_auto_grad_split(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
  DISPATCH_TO_CPU(
      "permute_pooled_embs_auto_grad_split",
      fbgemm_gpu::permute_pooled_embs_auto_grad_split_cpu);
  // Duplicate-aware autograd variant.
  m.def(
      "permute_duplicate_pooled_embs_auto_grad_split(Tensor pooled_embs, Tensor offset_dim_list, Tensor permute_list, Tensor inv_offset_dim_list, Tensor inv_permute_list) -> Tensor");
  DISPATCH_TO_CPU(
      "permute_duplicate_pooled_embs_auto_grad_split",
      fbgemm_gpu::permute_duplicate_pooled_embs_auto_grad_split_cpu);
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,33 @@ Tensor permute_pooled_embs_auto_grad_split_gpu(
inv_permute_list);
}

// Autograd-enabled GPU entry point for the duplicate-permutation use case.
// Dispatches through PermutePooledEmbsFunctionSplit templated on the
// duplicate-aware forward op.
Tensor permute_duplicate_pooled_embs_auto_grad_split_gpu(
    const Tensor& pooled_embs,
    const Tensor& offset_dim_list,
    const Tensor& permute_list,
    const Tensor& inv_offset_dim_list,
    const Tensor& inv_permute_list) {
  using Permuter =
      PermutePooledEmbsFunctionSplit<permute_duplicate_pooled_embs_split_gpu>;
  return Permuter::apply(
      pooled_embs,
      offset_dim_list,
      permute_list,
      inv_offset_dim_list,
      inv_permute_list);
}
} // namespace fbgemm_gpu

// Bind the CUDA kernels to the schemas (the schemas themselves are defined in
// the CPU registration fragment).
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
  DISPATCH_TO_CUDA(
      "permute_pooled_embs_split", fbgemm_gpu::permute_pooled_embs_split_gpu);
  // Duplicate-aware (non-autograd) variant.
  DISPATCH_TO_CUDA(
      "permute_duplicate_pooled_embs_split",
      fbgemm_gpu::permute_duplicate_pooled_embs_split_gpu);
  DISPATCH_TO_CUDA(
      "permute_pooled_embs_auto_grad_split",
      fbgemm_gpu::permute_pooled_embs_auto_grad_split_gpu);
  // Duplicate-aware autograd variant.
  DISPATCH_TO_CUDA(
      "permute_duplicate_pooled_embs_auto_grad_split",
      fbgemm_gpu::permute_duplicate_pooled_embs_auto_grad_split_gpu);
}
88 changes: 88 additions & 0 deletions fbgemm_gpu/test/permute_pooled_embedding_split_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from itertools import accumulate
from typing import List, Tuple

import torch
import torch._dynamo

try:
# pyre-ignore[21]
from fbgemm_gpu import open_source # noqa: F401

# pyre-ignore[21]
from test_utils import gpu_unavailable
except Exception:
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_gpu"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_cpu"
)
from fbgemm_gpu.test.test_utils import gpu_unavailable

typed_gpu_unavailable: Tuple[bool, str] = gpu_unavailable


class PermutePooledEmbeddingSplitTest(unittest.TestCase):
    """Tests duplicate-permutation support in permute_pooled_embs_split ops."""

    def setUp(self) -> None:
        super().setUp()
        # Default device for building the op inputs; the test also re-runs
        # the op on CPU.
        self.device = "cuda"

    @unittest.skipIf(*typed_gpu_unavailable)
    def test_duplicate_permutations(self) -> None:
        # Per-segment widths of the pooled input (10 columns total). Segments
        # 0 and 3 appear twice in the permutation, so the permuted output is
        # wider than the input.
        embs_dims = [2, 3, 1, 4]
        permute = [3, 0, 2, 0, 1, 3]
        expected_result = [6, 7, 8, 9, 0, 1, 5, 0, 1, 2, 3, 4, 6, 7, 8, 9]
        output_len = len(expected_result)
        # Input row holds the values 0..9 so the output directly shows which
        # columns were selected. Use self.device instead of hard-coding cuda.
        pooled_embs = (
            torch.arange(10, dtype=torch.float32)
            .reshape(1, 10)
            .to(device=self.device)
        )

        _permute = torch.tensor(permute, device=self.device, dtype=torch.int64)
        _offset_dim_list = torch.tensor(
            [0] + list(accumulate(embs_dims)), device=self.device, dtype=torch.int64
        )
        # With duplicates the inverse permutation is not unique; as built
        # here, the last occurrence of a segment wins.
        inv_permute: List[int] = [0] * len(permute)
        for i, p in enumerate(permute):
            inv_permute[p] = i
        _inv_permute = torch.tensor(inv_permute, device=self.device, dtype=torch.int64)
        inv_embs_dims = [embs_dims[i] for i in permute]
        # Last entry equals the output width (16), which the duplicate path
        # uses to size its result.
        _inv_offset_dim_list = torch.tensor(
            [0] + list(accumulate(inv_embs_dims)),
            device=self.device,
            dtype=torch.int64,
        )

        # GPU path.
        result = torch.ops.fbgemm.permute_duplicate_pooled_embs_auto_grad_split(
            pooled_embs,
            _offset_dim_list.to(device=pooled_embs.device),
            _permute.to(device=pooled_embs.device),
            _inv_offset_dim_list.to(device=pooled_embs.device),
            _inv_permute.to(device=pooled_embs.device),
        )
        self.assertEqual(
            result.view(output_len).tolist(),
            expected_result,
        )

        # CPU path must produce the identical permutation.
        pooled_embs = pooled_embs.to(device="cpu")
        result = torch.ops.fbgemm.permute_duplicate_pooled_embs_auto_grad_split(
            pooled_embs,
            _offset_dim_list.to(device=pooled_embs.device),
            _permute.to(device=pooled_embs.device),
            _inv_offset_dim_list.to(device=pooled_embs.device),
            _inv_permute.to(device=pooled_embs.device),
        )
        self.assertEqual(
            result.view(output_len).tolist(),
            expected_result,
        )


# Allow running this test file directly (e.g. `python permute_pooled_embedding_split_test.py`).
if __name__ == "__main__":
    unittest.main()

0 comments on commit f030bbc

Please sign in to comment.