Skip to content

Commit

Permalink
- Required changes for kernels
Browse files Browse the repository at this point in the history
Summary:
Adding small changes to kernels for CompiledAutograd support.

Adding `static constexpr bool is_traceable = true;` on kernels, making some kernels to use tensors instead of double and unrolling input shapes on GroupIndexSelectDim0GPUOp from vector into the ctx dict to help enablement of CompiledAutograd.

Differential Revision: D63151913
  • Loading branch information
flaviotruzzi committed Sep 23, 2024
1 parent af8ecb0 commit 955855c
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Tensor batch_index_select_dim0_codegen_backward_cuda(
class BatchIndexSelectDim0GPUOp
: public torch::autograd::Function<BatchIndexSelectDim0GPUOp> {
public:
static constexpr bool is_traceable = true;
static torch::autograd::variable_list forward_impl(
Tensor inputs,
Tensor indices,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using torch::autograd::variable_list;
class PermuteMultiEmbeddingOp
: public torch::autograd::Function<PermuteMultiEmbeddingOp> {
public:
static constexpr bool is_traceable = true;
static variable_list forward(
AutogradContext* ctx,
const at::TensorList& pooled_embs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class PermutePooledEmbsFunctionSplit
: public torch::autograd::Function<
PermutePooledEmbsFunctionSplit<permute_pooled_embs_op>> {
public:
static constexpr bool is_traceable = true;
static Variable forward(
AutogradContext* ctx,
const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)]
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ namespace fbgemm_gpu {
// Needed this to support backward pass.
class PackSegments : public torch::autograd::Function<PackSegments> {
public:
static constexpr bool is_traceable = true;
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& t_in,
Expand Down
24 changes: 21 additions & 3 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ group_index_select_dim0_unpack(
class GroupIndexSelectDim0GPUOp
: public torch::autograd::Function<GroupIndexSelectDim0GPUOp> {
public:
static constexpr bool is_traceable = true;
// need to combine input_group and indices_group into one tensor list
// to get this working with autograd.
static torch::autograd::variable_list forward_impl(
Expand Down Expand Up @@ -435,13 +436,18 @@ class GroupIndexSelectDim0GPUOp
input_shape_group.end(), input_shape.begin(), input_shape.end());
}

for (unsigned long i = 0; i < input_shape_group.size(); i++) {
ctx->saved_data["input_shape_group_" + std::to_string(i)] =
input_shape_group[i];
}

// save indices, args_tensor, saved_data
auto saved_tensors = std::vector<Tensor>(indices_group);
saved_tensors.insert(
saved_tensors.end(), result.cbegin() + group_size, result.cend());
saved_tensors.push_back(input_group[0]);
ctx->save_for_backward(saved_tensors);
ctx->saved_data["input_shape_group"] = input_shape_group;
// ctx->saved_data["input_shape_group"] = input_shape_group;

return result;
}
Expand Down Expand Up @@ -609,8 +615,20 @@ class GroupIndexSelectDim0GPUOp

auto saved_tensors = ctx->get_saved_variables();
TORCH_CHECK(saved_tensors.size() == group_size + 3);
auto output_shape_group =
ctx->saved_data["input_shape_group"].toSymIntVector();

std::vector<c10::SymInt> output_shape_group;
int i = 0;
while (true) {
auto el = ctx->saved_data.find("input_shape_group_" + std::to_string(i));

if (el == ctx->saved_data.end()) {
break;
}

output_shape_group.emplace_back(el->second.toSymInt());
i++;
}

grad_output_group.insert(
grad_output_group.end(), saved_tensors.begin(), saved_tensors.end());
static auto backward_op =
Expand Down

0 comments on commit 955855c

Please sign in to comment.