implementation of fbgemm op - permute_multi_embedding (pytorch#2738)
Summary:
Pull Request resolved: pytorch#2738

X-link: pytorch/torchrec#2120

# context
* currently we have a working function `permute_pooled_embs_auto_grad` that does a full permute of KTs (KeyedTensors), including forward and backward
* it has several limitations:
a) it has to be a full permute; duplicates are not supported;
b) in the main [use case](https://fburl.com/code/89od0rqm) the input KTs have to be torch.concat-ed first, which is not very efficient;
c) the function outputs a single KT, which then requires a split operation
* there has been an attempt to support duplicated outputs, but its backward pass doesn't work
* this diff creates a new kernel (named `permute_multi_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support

# notes
* this diff focuses on the implementation and tests of the operator
* performance analysis and benchmarks are in the next diff

# operator example usage
* used in python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments to the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# call the op with the generated arguments
outputs = torch.ops.fbgemm.permute_multi_embedding_internal_testing(
    values, permutes, in_lengths, out_lengths
)
```
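* the op is differentiable end to end; a minimal backward check, reusing the inputs above and assuming the op library is loaded:
```
# duplicated features (f1 appears in 3 output groups) receive summed gradients
loss = sum(out.sum() for out in outputs)
loss.backward()
print(values[0].grad.shape)  # torch.Size([2048, 352]) = batch x (f1 + f2)
```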
* permutes
```
# each row represents a key (feature) permute move, which consists of the following parameters:
# [input_tensor_idx, output_tensor_idx, input_key_idx, output_key_idx, key_length, magic_jump]
permutes = tensor(
            [
                [0, 0, 0, 0, 3, 4],  # f1
                [1, 0, 0, 3, 5, 0],  # f3
                [0, 1, 3, 0, 4, 0],  # f2
                [1, 2, 5, 0, 6, 0],  # f4
                [0, 2, 0, 6, 3, -6],  # f1
                [2, 2, 0, 9, 8, 0],  # f6
                [0, 3, 0, 0, 3, -8],  # f1
                [1, 3, 11, 3, 7, 0],  # f5
            ]
)
```
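* to make the row semantics concrete, here is a pure-Python reference of what the forward pass computes (an illustrative sketch, not the actual kernel; `permute_multi_embedding_ref` is a hypothetical name, and `magic_jump` is ignored in forward):
```
def permute_multi_embedding_ref(values, permutes, out_lengths):
    # allocate the output KTs up front (out_lengths as a plain list of widths)
    batch_size = values[0].size(0)
    outputs = [values[0].new_zeros(batch_size, n) for n in out_lengths]
    # each row copies one feature slice from an input KT into an output KT
    for in_t, out_t, in_start, out_start, length, _jump in permutes.tolist():
        outputs[out_t][:, out_start:out_start + length] = (
            values[in_t][:, in_start:in_start + length]
        )
    return outputs
```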

# details
1. from the above example usage, we can see that the operator takes in the following:
a) values: List[torch.Tensor], which represents the input KTs
b) permutes: torch.Tensor, which contains the permute information (explained below)
c) output_lengths_list: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on the device ahead of time
d) in_lengths: torch.Tensor, lengths of the input tensors, resident on the device
e) out_lengths: torch.Tensor, lengths of the output tensors, resident on the device
2. the operator returns a list of tensors, which represents the permuted KTs
3. `permutes` is the most critical argument in this operator:
a) it is a 2-D tensor
b) each row represents a key (feature) permute move
c) a permute move = [input_tensor_idx, output_tensor_idx, input_start_idx, output_start_idx, feature_length, magic_jump]
d) magic_jump is used in backward when a key (feature) from the input tensor is mapped to multiple places in the output tensors
4. the magic_jump
a) it's only used in the backward computation
b) it's usually 0, which means no jump
c) it's non-zero when there is a duplicate in the permute, e.g., the same feature appears more than once in the output
d) the `magic_jump` is the index of the next occurrence of the very same feature in the permute sequence, with some modifications:
e) modification-1: `magic_jump` is positive when it's the first of its kind [Start]
f) modification-2: `magic_jump` is negative when it's not the first of its kind [Continue]
g) modification-3: `magic_jump` is the negative value of the length of the permute sequence when it's the last of its kind [Stop] (the sketch in item 5 walks one such chain)
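5. to make the chain concrete, below is a pure-Python sketch of how backward could walk a Start/Continue/Stop chain and sum the gradients of duplicated features (an illustrative reference under the semantics above, not the actual kernel; the function name is hypothetical and `in_lengths` is a plain list of input widths here):
```
def permute_multi_embedding_backward_ref(grad_outputs, permutes, in_lengths):
    # grad_outputs: grads w.r.t. the output KTs; in_lengths: input KT widths
    batch_size = grad_outputs[0].size(0)
    grads = [grad_outputs[0].new_zeros(batch_size, n) for n in in_lengths]
    rows = permutes.tolist()
    for in_t, out_t, in_start, out_start, length, jump in rows:
        if jump < 0:
            continue  # [Continue]/[Stop] rows are folded into their [Start] row
        while True:
            # accumulate this output slice's grad into the input feature slice
            grads[in_t][:, in_start:in_start + length] += (
                grad_outputs[out_t][:, out_start:out_start + length]
            )
            if jump == 0 or abs(jump) >= len(rows):
                break  # no duplicate, or [Stop] reached
            # follow the chain to the next occurrence of the same feature
            _, out_t, _, out_start, _, jump = rows[abs(jump)]
    return grads
```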

Reviewed By: sryap

Differential Revision: D57055616

fbshipit-source-id: 16673d3a2eafab93b08d4ff3c43d54366966064a
TroyGarden authored and facebook-github-bot committed Jul 9, 2024
1 parent bbdd76c commit 87cfbdf
Showing 6 changed files with 574 additions and 0 deletions.
3 changes: 3 additions & 0 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -455,6 +455,8 @@ set(fbgemm_gpu_sources_static_cpu
    codegen/training/backward/embedding_backward_dense_host_cpu.cpp
    codegen/utils/embedding_bounds_check_host_cpu.cpp
    src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
    src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
    src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
    src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
    src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
    src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
@@ -547,6 +549,7 @@ if(NOT FBGEMM_CPU_ONLY)
    src/metric_ops/metric_ops.cu
    src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
    src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
    src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
    src/quantize_ops/quantize_bfloat16.cu
    src/quantize_ops/quantize_fp8_rowwise.cu
    src/quantize_ops/quantize_fused_8bit_rowwise.cu
9 changes: 9 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
@@ -19,16 +19,25 @@
torch.ops.load_library(
    "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
)
torch.ops.load_library(
    "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
)
try:
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
    )
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
    )
except OSError:
    # This is for forward compatibility (new torch.package + old backend)
    # We should be able to remove it after this diff is picked up by all backends
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu_cuda"
    )
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu_cuda"
    )
except OSError:
    pass

69 changes: 69 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h
@@ -0,0 +1,69 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/csrc/api/include/torch/types.h>
#include <torch/csrc/autograd/custom_function.h>

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/ops_utils.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

namespace fbgemm_gpu {

using Tensor = at::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::variable_list;

class PermuteMultiEmbeddingOp
    : public torch::autograd::Function<PermuteMultiEmbeddingOp> {
 public:
  static variable_list forward(
      AutogradContext* ctx,
      const at::TensorList& pooled_embs,
      const Tensor& permutes,
      const Tensor& in_shapes,
      const Tensor& out_shapes,
      const std::vector<int64_t>& out_lengths);

  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_output);
};

std::vector<Tensor> permute_multi_embedding_cpu(
    const at::TensorList& pooled_embs,
    const Tensor& permutes,
    const Tensor& in_shapes,
    const Tensor& out_shapes,
    const std::vector<int64_t>& out_lengths,
    const bool& reverse_permute);

std::vector<Tensor> permute_multi_embedding_meta(
    const at::TensorList& pooled_embs,
    const Tensor& permutes,
    const Tensor& in_shapes,
    const Tensor& out_shapes,
    const std::vector<int64_t>& out_lengths,
    const bool& reverse_permute);

std::vector<Tensor> permute_multi_embedding_gpu(
    const at::TensorList& pooled_embs,
    const Tensor& permutes,
    const Tensor& in_shapes,
    const Tensor& out_shapes,
    const std::vector<int64_t>& out_lengths,
    const bool& reverse_permute);
} // namespace fbgemm_gpu
77 changes: 77 additions & 0 deletions fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
@@ -0,0 +1,77 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fbgemm_gpu/permute_multi_embedding_function.h"
#include <cstdint>
#include <iostream>

namespace fbgemm_gpu {

using Tensor = at::Tensor;
using torch::autograd::AutogradContext;
using torch::autograd::variable_list;

variable_list PermuteMultiEmbeddingOp::forward(
    AutogradContext* ctx,
    const at::TensorList& pooled_embs,
    const Tensor& permutes,
    const Tensor& in_shapes,
    const Tensor& out_shapes,
    const std::vector<int64_t>& out_lengths) {
  ctx->saved_data["permutes"] = permutes;
  ctx->saved_data["in_shapes"] = in_shapes;
  ctx->saved_data["out_shapes"] = out_shapes;

  std::vector<int64_t> in_lengths;
  in_lengths.reserve(pooled_embs.size());
  for (auto i : c10::irange(pooled_embs.size())) {
    in_lengths.push_back(pooled_embs[i].size(1));
  }
  ctx->saved_data["in_lengths"] = in_lengths;

  /*
    select the correct dispatched (cpu/gpu) forward function;
    the cpu/gpu function needs to be registered in the dispatcher,
    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
  */
  const auto permute_op =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "")
          .typed<decltype(permute_multi_embedding_cpu)>();

  return permute_op.call(
      pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false);
}

variable_list PermuteMultiEmbeddingOp::backward(
    AutogradContext* ctx,
    variable_list grad_output) {
  const auto permutes = ctx->saved_data["permutes"].toTensor();
  const auto in_shapes = ctx->saved_data["in_shapes"].toTensor();
  const auto out_shapes = ctx->saved_data["out_shapes"].toTensor();
  const auto in_lengths = ctx->saved_data["in_lengths"].toIntVector();

  /*
    select the correct dispatched (cpu/gpu) backward function;
    the cpu/gpu function needs to be registered in the dispatcher,
    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
  */
  const auto permute_op =
      torch::Dispatcher::singleton()
          .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "")
          .typed<decltype(permute_multi_embedding_cpu)>();
  // backward runs the same op in reverse (reverse_permute = true),
  // swapping in_shapes/out_shapes relative to forward
  auto grad_input = permute_op.call(
      grad_output, permutes, out_shapes, in_shapes, in_lengths, true);
  grad_input.push_back(torch::autograd::Variable()); // permutes
  grad_input.push_back(torch::autograd::Variable()); // in_shapes
  grad_input.push_back(torch::autograd::Variable()); // out_shapes
  grad_input.push_back(torch::autograd::Variable()); // out_lengths
  return grad_input;
}

} // namespace fbgemm_gpu