diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index 71c7c01cd..0d3f79a80 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -455,6 +455,8 @@ set(fbgemm_gpu_sources_static_cpu
   codegen/training/backward/embedding_backward_dense_host_cpu.cpp
   codegen/utils/embedding_bounds_check_host_cpu.cpp
   src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
+  src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
+  src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
   src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
   src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
   src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
@@ -547,6 +549,7 @@ if(NOT FBGEMM_CPU_ONLY)
   src/metric_ops/metric_ops.cu
   src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
   src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
+  src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
   src/quantize_ops/quantize_bfloat16.cu
   src/quantize_ops/quantize_fp8_rowwise.cu
   src/quantize_ops/quantize_fused_8bit_rowwise.cu
diff --git a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
index a06214b98..e56bd343e 100644
--- a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
+++ b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
@@ -19,16 +19,25 @@ torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
     )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
+    )
     try:
         torch.ops.load_library(
             "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
         )
+        torch.ops.load_library(
+            "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
+        )
     except OSError:
         # This is for forward compatibility (new torch.package + old backend)
         # We should be able to remove it after this diff is picked up by all backend
         torch.ops.load_library(
             "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu_cuda"
         )
+        torch.ops.load_library(
+            "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu_cuda"
+        )
 except OSError:
     pass
diff --git a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h
new file mode 100644
index 000000000..fbda97d47
--- /dev/null
+++ b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/core/op_registration/op_registration.h>
+#include <torch/script.h>
+#include <torch/serialize/tensor.h>
+
+#include "fbgemm_gpu/dispatch_macros.h"
+#include "fbgemm_gpu/ops_utils.h"
+#include "fbgemm_gpu/sparse_ops_utils.h"
+
+namespace fbgemm_gpu {
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+class PermuteMultiEmbeddingOp
+    : public torch::autograd::Function<PermuteMultiEmbeddingOp> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      const at::TensorList& pooled_embs,
+      const Tensor& permutes,
+      const Tensor& in_shapes,
+      const Tensor& out_shapes,
+      const std::vector<int64_t>& out_lengths);
+
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output);
+};
+
+std::vector<Tensor> permute_multi_embedding_cpu(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const Tensor& in_shapes,
+    const Tensor& out_shapes,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute);
+
+std::vector<Tensor> permute_multi_embedding_meta(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const Tensor& in_shapes,
+    const Tensor& out_shapes,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute);
+
+std::vector<Tensor> permute_multi_embedding_gpu(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const Tensor& in_shapes,
+    const Tensor& out_shapes,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute);
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
new file mode 100644
index 000000000..2d3a2ed65
--- /dev/null
+++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm_gpu/permute_multi_embedding_function.h"
+#include <ATen/core/dispatch/Dispatcher.h>
+#include <c10/util/irange.h>
+
+namespace fbgemm_gpu {
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+variable_list PermuteMultiEmbeddingOp::forward(
+    AutogradContext* ctx,
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const Tensor& in_shapes,
+    const Tensor& out_shapes,
+    const std::vector<int64_t>& out_lengths) {
+  ctx->saved_data["permutes"] = permutes;
+  ctx->saved_data["in_shapes"] = in_shapes;
+  ctx->saved_data["out_shapes"] = out_shapes;
+
+  std::vector<int64_t> in_lengths;
+  in_lengths.reserve(pooled_embs.size());
+  for (auto i : c10::irange(pooled_embs.size())) {
+    in_lengths.push_back(pooled_embs[i].size(1));
+  }
+  ctx->saved_data["in_lengths"] = in_lengths;
+
+  /*
+    select the correct dispatched (cpu/gpu) forward function;
+    the cpu/gpu function needs to be registered in the dispatcher,
+    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
+   */
+  const auto permute_op =
+      torch::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "")
+          .typed<decltype(permute_multi_embedding_cpu)>();
+
+  return permute_op.call(
+      pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false);
+}
+
+variable_list PermuteMultiEmbeddingOp::backward(
+    AutogradContext* ctx,
+    variable_list grad_output) {
+  const auto permutes = ctx->saved_data["permutes"].toTensor();
+  const auto in_shapes = ctx->saved_data["in_shapes"].toTensor();
+  const auto out_shapes = ctx->saved_data["out_shapes"].toTensor();
+  const auto in_lengths = ctx->saved_data["in_lengths"].toIntVector();
+
+  /*
+    select the correct dispatched (cpu/gpu) backward function;
+    the cpu/gpu function needs to be registered in the dispatcher,
+    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
+   */
+  const auto permute_op =
+      torch::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "")
+          .typed<decltype(permute_multi_embedding_cpu)>();
+  auto grad_input = permute_op.call(
+      grad_output, permutes, out_shapes, in_shapes, in_lengths, true);
+  grad_input.push_back(torch::autograd::Variable()); // permutes
+  grad_input.push_back(torch::autograd::Variable()); // in_shapes
+  grad_input.push_back(torch::autograd::Variable()); // out_shapes
+  grad_input.push_back(torch::autograd::Variable()); // out_lengths
+  return grad_input;
+}
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
new file mode 100644
index 000000000..c5ea34f5e
--- /dev/null
+++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/Exceptions.h>
+#include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/script.h>
+#include <algorithm>
+#include <cstring>
+
+#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
+#include "fbgemm_gpu/fbgemm_tensor_accessor.h"
+#include "fbgemm_gpu/permute_multi_embedding_function.h"
+
+using Tensor = at::Tensor;
+
+namespace fbgemm_gpu {
+
+// Kernel for the permute_multi_embedding op.
+// This kernel moves D elements per warp.
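+//
+// Each row of the `permutes` tensor describes one feature move and is laid
+// out as [in_tensor, out_tensor, in_start, out_start, length, next]: the
+// indices of the source and destination tensors, the element offsets within
+// them, the number of elements to copy, and a link used by the backward pass
+// to accumulate gradients for features duplicated in the outputs.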
+template <typename scalar_t, bool reverse_permute>
+__global__ void permute_multi_embs_kernel(
+    const scalar_t** __restrict__ inputs,
+    const scalar_t** __restrict__ outputs,
+    const pta::PackedTensorAccessor32<int32_t, 2, at::RestrictPtrTraits>
+        permutes,
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        in_lengths,
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        out_lengths,
+    const int32_t batch_size,
+    const int32_t permute_size) {
+  // workers in a warp handle exactly one permute (of a feature/key)
+  const int32_t worker_id = threadIdx.x;
+  const int32_t permute_id = threadIdx.y + blockIdx.x * blockDim.x;
+  const int32_t batch_id = blockIdx.y + gridDim.y * blockIdx.z;
+  if (batch_id >= batch_size) {
+    return;
+  }
+  if (permute_id >= permute_size) {
+    return;
+  }
+
+  // parse permutes
+  int32_t in_tensor, out_tensor, in_start, out_start, length, next;
+  if (reverse_permute) {
+    out_tensor = permutes[permute_id][0];
+    in_tensor = permutes[permute_id][1];
+    out_start = permutes[permute_id][2];
+    in_start = permutes[permute_id][3];
+  } else {
+    in_tensor = permutes[permute_id][0];
+    out_tensor = permutes[permute_id][1];
+    in_start = permutes[permute_id][2];
+    out_start = permutes[permute_id][3];
+  }
+  length = permutes[permute_id][4];
+  next = permutes[permute_id][5];
+
+  if (worker_id >= length) {
+    return;
+  }
+  if (reverse_permute && next < 0) {
+    return;
+  }
+
+  // locate the batch_id
+  int32_t in_length = in_lengths[in_tensor];
+  scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
+  input_ptr += batch_id * in_length;
+
+  int32_t out_length = out_lengths[out_tensor];
+  scalar_t* output_ptr = (scalar_t*)outputs[out_tensor];
+  output_ptr += batch_id * out_length;
+
+  if (fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
+          &output_ptr[out_start]) &&
+      fbgemm_gpu::is_aligned<fbgemm_gpu::Vec4T<scalar_t>>(
+          &input_ptr[in_start])) {
+    constexpr int32_t vec_size = 4;
+    const int32_t loop_end = round_down(length, vec_size);
+    for (int32_t i = worker_id * vec_size; i < loop_end;
+         i += blockDim.x * vec_size) {
+      fbgemm_gpu::Vec4T<scalar_t>::copy(
+          &input_ptr[in_start + i], &output_ptr[out_start + i]);
+    }
+    // Use elementwise access for the last incomplete vector.
+    for (int32_t i = loop_end + worker_id; i < length; i += blockDim.x) {
+      output_ptr[out_start + i] = input_ptr[in_start + i];
+    }
+  } else { // Fallback if not aligned.
+    for (int32_t i = worker_id; i < length; i += blockDim.x) {
+      output_ptr[out_start + i] = input_ptr[in_start + i];
+    }
+  }
+
+  // for reverse_permute (backward) with next
+  while (reverse_permute && next > 0 && next < permute_size) {
+    in_tensor = permutes[next][1];
+    in_start = permutes[next][3];
+    length = permutes[next][4];
+    next = -permutes[next][5];
+
+    int32_t in_length = in_lengths[in_tensor];
+    scalar_t* input_ptr = (scalar_t*)inputs[in_tensor];
+    input_ptr += batch_id * in_length;
+
+    for (int32_t i = worker_id; i < length; i += blockDim.x) {
+      output_ptr[out_start + i] += input_ptr[in_start + i];
+    }
+  }
+}
+
+template <typename index_t>
+Tensor from_vec(const std::vector<index_t> input) {
+  const auto int_opts =
+      torch::TensorOptions().dtype(torch::kInt32).pinned_memory(true);
+  Tensor output = at::empty({static_cast<int64_t>(input.size())}, int_opts);
+  // Ensure that output is contiguous
+  TORCH_CHECK(output.is_contiguous());
+  std::memcpy(
+      output.data_ptr<index_t>(),
+      input.data(),
+      input.size() * sizeof(index_t));
+  return output;
+}
+
+template <typename scalar_t>
+Tensor tensors_ptr(const at::TensorList& tensors) {
+  auto size = tensors.size();
+  Tensor ptr_tensor = at::empty(
+      {static_cast<int64_t>(size * sizeof(scalar_t*))},
+      at::TensorOptions().dtype(tensors[0].scalar_type()).pinned_memory(true));
+
+  // Ensure that ptr_tensor is contiguous
+  TORCH_CHECK(ptr_tensor.is_contiguous());
+  auto tp = reinterpret_cast<scalar_t**>(ptr_tensor.data_ptr());
+  for (int32_t i = 0; i < tensors.size(); i++) {
+    tp[i] = tensors[i].data_ptr<scalar_t>();
+  }
+  // Ensure that ptr_tensor is contiguous
+  TORCH_CHECK(ptr_tensor.is_contiguous());
+  return ptr_tensor;
+}
+
+std::vector<Tensor> permute_multi_embedding_gpu(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const Tensor& in_shapes,
+    const Tensor& out_shapes,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute) {
+  int32_t num_of_input_tensors = in_shapes.size(0);
+  int32_t num_of_output_tensors = out_lengths.size();
+  int32_t batch_size = pooled_embs[0].size(0);
+  int32_t permute_size = permutes.size(0);
+
+  // check input tensors
+  std::vector<Tensor> inputs;
+  inputs.reserve(pooled_embs.size());
+  for (int32_t i = 0; i < num_of_input_tensors; i++) {
+    Tensor cont_tensor = pooled_embs[i].contiguous();
+    inputs.push_back(cont_tensor);
+    TENSORS_ON_SAME_DEVICE(cont_tensor, pooled_embs[i]);
+    TENSORS_ON_SAME_DEVICE(pooled_embs[i], pooled_embs[0]);
+    CUDA_DEVICE_GUARD(cont_tensor);
+  }
+  TORCH_CHECK(in_shapes.is_contiguous());
+  TORCH_CHECK(out_shapes.is_contiguous());
+
+  // initiate output tensors
+  std::vector<Tensor> outputs;
+  outputs.reserve(num_of_output_tensors);
+  for (int32_t i = 0; i < num_of_output_tensors; i++) {
+    Tensor output =
+        at::empty({batch_size, out_lengths[i]}, pooled_embs[0].options());
+    outputs.push_back(output);
+  }
+  auto device = pooled_embs[0].device();
+
+  // This kernel moves one feature/key per warp.
+  // We are launching (permute_size // warp_per_block, batch_size, ?) blocks.
+  // The grid z dimension is also used for batch_size in case it is
+  // greater than 65535.
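+  // For example, with kMaxThreads = 1024 and kWarpSize = 32 this yields 32
+  // warps per block; a batch size of 100,000 then maps to grid dimensions
+  // (ceil(permute_size / 32), 32768, 4), and any overshoot in the z dimension
+  // is filtered out by the batch_id bound check in the kernel.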
+  const int32_t warp_per_block =
+      fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize;
+  const int32_t max_grid_dim = 32768; // The CUDA maximum is 65535, not 1<<16
+  const dim3 block_dim(fbgemm_gpu::kWarpSize, warp_per_block);
+  const dim3 grid_dim(
+      (permute_size + warp_per_block - 1) / warp_per_block,
+      std::min(static_cast<int32_t>(batch_size), max_grid_dim),
+      (batch_size + max_grid_dim - 1) / max_grid_dim);
+
+  FBGEMM_DISPATCH_FLOATING_TYPES(
+      pooled_embs[0].scalar_type(), "permute_multi_embedding", [&] {
+        Tensor in_ptr = tensors_ptr<scalar_t>(inputs);
+        Tensor out_ptr = tensors_ptr<scalar_t>(outputs);
+        in_ptr = in_ptr.to(device, /*non_blocking=*/true);
+        out_ptr = out_ptr.to(device, /*non_blocking=*/true);
+        const auto permute_kernel = reverse_permute
+            ? permute_multi_embs_kernel<scalar_t, true>
+            : permute_multi_embs_kernel<scalar_t, false>;
+        const auto stream = at::cuda::getCurrentCUDAStream();
+#ifdef FBGEMM_GPU_MEMCHECK
+        const char* func_name = "permute_multi_embs_kernel";
+#endif
+        permute_kernel<<<grid_dim, block_dim, 0, stream>>>(
+            reinterpret_cast<const scalar_t**>(in_ptr.data_ptr()),
+            reinterpret_cast<const scalar_t**>(out_ptr.data_ptr()),
+            MAKE_PTA_WITH_NAME(func_name, permutes, int32_t, 2, 32),
+            MAKE_PTA_WITH_NAME(func_name, in_shapes, int32_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, out_shapes, int32_t, 1, 32),
+            batch_size,
+            permute_size);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+      });
+  return outputs;
+}
+
+} // namespace fbgemm_gpu
+
+FBGEMM_OP_DISPATCH(
+    CUDA,
+    "permute_multi_embedding_function",
+    fbgemm_gpu::permute_multi_embedding_gpu);
diff --git a/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
new file mode 100644
index 000000000..a0225e743
--- /dev/null
+++ b/fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
@@ -0,0 +1,186 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm_gpu/permute_multi_embedding_function.h"
+
+#include <cstring>
+
+namespace fbgemm_gpu {
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+std::vector<Tensor> permute_multi_embedding_cpu(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const Tensor& /* in_shapes */,
+    const Tensor& /* out_shapes */,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute) {
+  std::vector<Tensor> inputs;
+  inputs.reserve(pooled_embs.size());
+  for (auto i : c10::irange(pooled_embs.size())) {
+    Tensor cont_tensor = pooled_embs[i].contiguous();
+    inputs.push_back(cont_tensor);
+    TENSORS_ON_SAME_DEVICE(cont_tensor, pooled_embs[i]);
+    TENSORS_ON_SAME_DEVICE(pooled_embs[i], pooled_embs[0]);
+  }
+
+  int32_t B = pooled_embs[0].size(0);
+  std::vector<Tensor> outputs;
+  outputs.reserve(out_lengths.size());
+  for (const auto i : c10::irange(out_lengths.size())) {
+    outputs.push_back(at::empty({B, out_lengths[i]}, pooled_embs[0].options()));
+    TORCH_CHECK(outputs[i].is_contiguous());
+  }
+  int32_t in_tensor, out_tensor, in_start, out_start, length, jump;
+  for (const auto i : c10::irange(permutes.size(0))) {
+    if (reverse_permute) {
+      out_tensor = permutes[i][0].item<int32_t>();
+      in_tensor = permutes[i][1].item<int32_t>();
+      out_start = permutes[i][2].item<int32_t>();
+      in_start = permutes[i][3].item<int32_t>();
+      jump = permutes[i][5].item<int32_t>();
+    } else {
+      in_tensor = permutes[i][0].item<int32_t>();
+      out_tensor = permutes[i][1].item<int32_t>();
+      in_start = permutes[i][2].item<int32_t>();
+      out_start = permutes[i][3].item<int32_t>();
+    }
+    length = permutes[i][4].item<int32_t>();
+    if (reverse_permute && jump < 0) {
+      for (auto b : c10::irange(B)) {
+        for (const auto j : c10::irange(length)) {
+          outputs[out_tensor][b][j + out_start] +=
+              inputs[in_tensor][b][j + in_start];
+        }
+      }
+    } else {
+      for (auto b : c10::irange(B)) {
+        auto outp = static_cast<char*>(outputs[out_tensor][b].data_ptr()) +
+            out_start * pooled_embs[0].itemsize();
+        auto inp = static_cast<const char*>(inputs[in_tensor][b].data_ptr()) +
+            in_start * pooled_embs[0].itemsize();
+        std::memcpy(outp, inp, length * pooled_embs[0].itemsize());
+      }
+    }
+  }
+  return outputs;
+}
+
+std::vector<Tensor> permute_multi_embedding_meta(
+    const at::TensorList& pooled_embs,
+    const Tensor& /* permutes */,
+    const Tensor& /* in_shapes */,
+    const Tensor& /* out_shapes */,
+    const std::vector<int64_t>& out_lengths,
+    const bool& /* reverse_permute */) {
+  int32_t batch_size = pooled_embs[0].size(0);
+
+  std::vector<Tensor> outputs;
+  outputs.reserve(out_lengths.size());
+  for (const auto i : c10::irange(out_lengths.size())) {
+    outputs.push_back(
+        at::empty({batch_size, out_lengths[i]}, pooled_embs[0].options()));
+  }
+  return outputs;
+}
+
+/// @ingroup permute pooled embedding function group
+///
+/// @brief permute and regroup keyed tensors
+///
+/// We often need to regroup keyed tensors (KTs) in a batch. For example,
+/// suppose we have two KTs, A and B, where A contains the pooled embeddings of
+/// two features (keys) F1 and F2, and B contains the pooled embeddings of two
+/// features (keys) F3 and F4. Both KTs have the same batch size.
+///
+/// We want to permute and regroup the KTs so that in the new KTs, F1 and F3
+/// are grouped together, and F2 and F4 are grouped together.
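+///
+/// With the per-feature lengths used in the example below, the regrouped KTs
+/// have widths 128 + 64 = 192 (F1 and F3) and 128 + 32 = 160 (F2 and F4).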
+///
+/// **Example:**
+/// ```python
+/// # input arguments
+/// keys = [["F1", "F2"], ["F3", "F4"]]
+/// lengths = [[128, 128], [64, 32]]
+/// batch_size = 1024
+/// values = [torch.randn(batch_size, 256), torch.randn(batch_size, 96)]
+///
+/// # target output KTs
+/// groups = [["F1", "F3"], ["F2", "F4"]]
+///
+/// # generate permutes
+/// permutes, in_shapes, out_shapes, out_lengths = kt_regroup_permutes(
+///     keys, lengths, groups)
+///
+/// # permute and regroup
+/// permuted_values = permute_multi_embedding(
+///     values, permutes, in_shapes, out_shapes, out_lengths)
+/// ```
+///
+///
+/// @param pooled_embs list of tensors that form the KTs' values
+/// @param permutes a 2D tensor where each row describes one permute operation,
+/// i.e., how to move/copy one feature from an input KT to an output KT. The
+/// first column is the input tensor index and the second column is the output
+/// tensor index; the third column is the feature's offset in the input tensor
+/// and the fourth column is the feature's offset in the output tensor; the
+/// fifth column is the length of the feature in this permute, and the last
+/// column is a jump flag.
+/// @param in_shapes a 1D tensor with each element representing the length of
+/// an input KT.
+/// @param out_shapes a 1D tensor with each element representing the length of
+/// an output KT.
+/// @param out_lengths a 1D vector with each element representing the length of
+/// an output KT.
+///
+/// @return the values of the output KTs.
+///
+///
+/// @note This operator supports autograd, and duplications in the output KTs
+/// are supported, e.g., [["F1", "F3"], ["F2", "F4"], ["F1", "F3"]].
+///
+/// @warning When a feature is omitted from the output KTs, the gradient of
+/// that feature will not be set to 0.
+///
+std::vector<Tensor> permute_multi_embedding(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const Tensor& in_shapes,
+    const Tensor& out_shapes,
+    const std::vector<int64_t>& out_lengths) {
+  return PermuteMultiEmbeddingOp::apply(
+      pooled_embs, permutes, in_shapes, out_shapes, out_lengths);
+}
+
+} // namespace fbgemm_gpu
+
+TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  // register the forward function for internal (autograd) usage
+  m.def(
+      "permute_multi_embedding_function(Tensor[] pooled_embs, Tensor permutes, Tensor in_shapes, Tensor out_shapes, SymInt[] out_lengths, bool reverse=False) -> Tensor[]");
+
+  // register the main function for external usage
+  m.def(
+      "permute_multi_embedding(Tensor[] pooled_embs, Tensor permutes, Tensor in_shapes, Tensor out_shapes, SymInt[] out_lengths) -> Tensor[]");
+
+  // dispatch the forward function to CPU for internal (autograd) usage
+  DISPATCH_TO_CPU(
+      "permute_multi_embedding_function",
+      fbgemm_gpu::permute_multi_embedding_cpu);
+
+  // dispatch the forward function to Meta for internal (autograd) usage
+  DISPATCH_TO_META(
+      "permute_multi_embedding_function",
+      fbgemm_gpu::permute_multi_embedding_meta);
+
+  // dispatch the main function to Autograd for external usage
+  DISPATCH_TO_AUTOGRAD(
+      "permute_multi_embedding", fbgemm_gpu::permute_multi_embedding);
+
+  // dispatch the main function to CUDA for external usage
+  DISPATCH_TO_CUDA(
+      "permute_multi_embedding", fbgemm_gpu::permute_multi_embedding);
+}
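For reference, a minimal sketch of calling the newly registered operator through `torch.ops.fbgemm`, assuming the fbgemm_gpu ops are loaded (e.g., via the OSS `fbgemm_gpu` package) and that the `permutes` rows follow the `[in_tensor, out_tensor, in_start, out_start, length, jump]` layout consumed by the kernels above. The `permutes`, `in_shapes`, and `out_shapes` values are written by hand here; in practice they would come from a helper such as the `kt_regroup_permutes` referenced in the doc comment, which is not part of this diff.

```python
import torch
import fbgemm_gpu  # noqa: F401  # registers the fbgemm ops (OSS build assumed)

batch_size = 1024
# Two input KTs: A = [F1 (128), F2 (128)], B = [F3 (64), F4 (32)]
values = [torch.randn(batch_size, 256), torch.randn(batch_size, 96)]

# One row per feature move: [in_tensor, out_tensor, in_start, out_start, length, jump]
# jump = 0 is assumed here since no feature is duplicated in the outputs.
permutes = torch.tensor(
    [
        [0, 0, 0, 0, 128, 0],    # F1 -> output KT 0, offset 0
        [1, 0, 0, 128, 64, 0],   # F3 -> output KT 0, offset 128
        [0, 1, 128, 0, 128, 0],  # F2 -> output KT 1, offset 0
        [1, 1, 64, 128, 32, 0],  # F4 -> output KT 1, offset 128
    ],
    dtype=torch.int32,
)
in_shapes = torch.tensor([256, 96], dtype=torch.int32)    # widths of the input KTs
out_shapes = torch.tensor([192, 160], dtype=torch.int32)  # widths of the output KTs
out_lengths = [192, 160]

out = torch.ops.fbgemm.permute_multi_embedding(
    values, permutes, in_shapes, out_shapes, out_lengths
)
assert [tuple(t.shape) for t in out] == [(batch_size, 192), (batch_size, 160)]
assert torch.equal(out[0][:, :128], values[0][:, :128])  # F1 lands at the front of KT 0
```

Moving `values` and the three auxiliary tensors to a CUDA device would exercise the `permute_multi_embedding_ops.cu` path instead of the CPU kernel.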