Required changes for kernels (#3165)
Summary:
Pull Request resolved: #3165

X-link: facebookresearch/FBGEMM#259

Adding small changes to kernels for CompiledAutograd support.

Adding `static constexpr bool is_traceable = true;` to the kernels, making some kernels take tensors instead of doubles, and unrolling the input shapes on GroupIndexSelectDim0GPUOp from a single vector into per-index entries in the ctx saved_data dict, all to help enable CompiledAutograd.
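For reference, a minimal sketch of the pattern being applied, using a hypothetical `ScaleOp` that is not part of this PR: a C++ custom autograd `Function` opts into CompiledAutograd tracing by declaring `is_traceable`, and passes scalars as tensors so the saved state stays traceable.

```cpp
#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::Function;
using torch::autograd::variable_list;

// Hypothetical example op (y = x * scale), illustrating the flag this
// PR adds to the FBGEMM kernels below.
class ScaleOp : public Function<ScaleOp> {
 public:
  // Marks this node's backward as safe for CompiledAutograd to trace.
  static constexpr bool is_traceable = true;

  static variable_list forward(
      AutogradContext* ctx,
      const at::Tensor& x,
      const at::Tensor& scale) { // a tensor rather than a double, so it traces
    ctx->save_for_backward({scale});
    return {x * scale};
  }

  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    // Gradient w.r.t. x only; no gradient for scale in this sketch.
    return {grad_outputs[0] * saved[0], at::Tensor()};
  }
};

// Usage: auto y = ScaleOp::apply(x, scale)[0];
```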

Reviewed By: Microve

Differential Revision: D63151913
flaviotruzzi committed Sep 24, 2024
1 parent 012a658 commit 3972d03
Showing 5 changed files with 24 additions and 3 deletions.
@@ -53,6 +53,7 @@ Tensor batch_index_select_dim0_codegen_backward_cuda(
 class BatchIndexSelectDim0GPUOp
     : public torch::autograd::Function<BatchIndexSelectDim0GPUOp> {
  public:
+  static constexpr bool is_traceable = true;
   static torch::autograd::variable_list forward_impl(
       Tensor inputs,
       Tensor indices,
@@ -30,6 +30,7 @@ using torch::autograd::variable_list;
 class PermuteMultiEmbeddingOp
     : public torch::autograd::Function<PermuteMultiEmbeddingOp> {
  public:
+  static constexpr bool is_traceable = true;
   static variable_list forward(
       AutogradContext* ctx,
       const at::TensorList& pooled_embs,
@@ -27,6 +27,7 @@ class PermutePooledEmbsFunctionSplit
     : public torch::autograd::Function<
           PermutePooledEmbsFunctionSplit<permute_pooled_embs_op>> {
  public:
+  static constexpr bool is_traceable = true;
   static Variable forward(
       AutogradContext* ctx,
       const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
1 change: 1 addition & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -66,6 +66,7 @@ namespace fbgemm_gpu {
 // Needed this to support backward pass.
 class PackSegments : public torch::autograd::Function<PackSegments> {
  public:
+  static constexpr bool is_traceable = true;
   static torch::autograd::variable_list forward(
       torch::autograd::AutogradContext* ctx,
       const Tensor& t_in,
23 changes: 20 additions & 3 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
@@ -216,6 +216,7 @@ group_index_select_dim0_unpack(
 class GroupIndexSelectDim0GPUOp
     : public torch::autograd::Function<GroupIndexSelectDim0GPUOp> {
  public:
+  static constexpr bool is_traceable = true;
   // need to combine input_group and indices_group into one tensor list
   // to get this working with autograd.
   static torch::autograd::variable_list forward_impl(
@@ -435,13 +436,17 @@ class GroupIndexSelectDim0GPUOp
         input_shape_group.end(), input_shape.begin(), input_shape.end());
     }
 
+    for (unsigned long i = 0; i < input_shape_group.size(); i++) {
+      ctx->saved_data["input_shape_group_" + std::to_string(i)] =
+          input_shape_group[i];
+    }
+
     // save indices, args_tensor, saved_data
     auto saved_tensors = std::vector<Tensor>(indices_group);
     saved_tensors.insert(
         saved_tensors.end(), result.cbegin() + group_size, result.cend());
     saved_tensors.push_back(input_group[0]);
     ctx->save_for_backward(saved_tensors);
-    ctx->saved_data["input_shape_group"] = input_shape_group;
 
     return result;
   }
@@ -609,8 +614,20 @@ class GroupIndexSelectDim0GPUOp

     auto saved_tensors = ctx->get_saved_variables();
     TORCH_CHECK(saved_tensors.size() == group_size + 3);
-    auto output_shape_group =
-        ctx->saved_data["input_shape_group"].toSymIntVector();
+
+    std::vector<c10::SymInt> output_shape_group;
+    int i = 0;
+    while (true) {
+      auto el = ctx->saved_data.find("input_shape_group_" + std::to_string(i));
+
+      if (el == ctx->saved_data.end()) {
+        break;
+      }
+
+      output_shape_group.emplace_back(el->second.toSymInt());
+      i++;
+    }
+
     grad_output_group.insert(
         grad_output_group.end(), saved_tensors.begin(), saved_tensors.end());
     static auto backward_op =
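For readability, here is the shape save/restore pattern the last diff adopts, reduced to a standalone sketch. The `save_shapes` / `restore_shapes` helpers and the plain `std::map` stand-in for `ctx->saved_data` (a flat hash map of `std::string` to `at::IValue` in the real `AutogradContext`) are illustrative, not code from this PR: forward unrolls the vector into one `input_shape_group_<i>` entry per element, and backward rebuilds it by probing indexed keys until one is missing.

```cpp
#include <map>
#include <string>
#include <vector>

#include <ATen/core/ivalue.h>
#include <c10/core/SymInt.h>

// Stand-in for ctx->saved_data; the key pattern matches the diff above.
using SavedData = std::map<std::string, at::IValue>;

// Forward side: unroll the shape vector into one entry per element.
void save_shapes(
    SavedData& saved_data,
    const std::vector<c10::SymInt>& input_shape_group) {
  for (size_t i = 0; i < input_shape_group.size(); ++i) {
    saved_data["input_shape_group_" + std::to_string(i)] =
        input_shape_group[i];
  }
}

// Backward side: rebuild the vector by probing indexed keys until a miss.
std::vector<c10::SymInt> restore_shapes(const SavedData& saved_data) {
  std::vector<c10::SymInt> output_shape_group;
  for (int i = 0;; ++i) {
    auto it = saved_data.find("input_shape_group_" + std::to_string(i));
    if (it == saved_data.end()) {
      break;
    }
    output_shape_group.emplace_back(it->second.toSymInt());
  }
  return output_shape_group;
}
```

Per the summary above, storing one scalar-typed entry per key, instead of a single vector-valued entry, keeps each saved_data value in a form that helps CompiledAutograd, at the cost of a linear key probe in backward.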
