Skip to content

Commit

Permalink
permute_2D_sparse_data Autograd formula (pytorch#2629)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2629

Reland of D57625720 without dependency on torchrec

Adding permute_2D_sparse_data python formula for Inductor compilation.

Reviewed By: ezyang

Differential Revision: D57773001

fbshipit-source-id: bf73e5b79c0450cdcc123eb4630941d668beb1f9
  • Loading branch information
Ivan Kobzarev authored and facebook-github-bot committed May 24, 2024
1 parent 63ca6dc commit ab05ca9
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 26 deletions.
37 changes: 37 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,43 @@ def permute_2D_sparse_data_meta(
return permuted_lengths, permuted_indices, permuted_weights


@impl_abstract("fbgemm::invert_permute")
def invert_permute_abstract(permute: Tensor) -> Tensor:
return torch.empty_like(permute)


# pyre-ignore
def permute_2D_sparse_data_setup_context(ctx, inputs, output):
permute, lengths, values, weights, permuted_lengths_sum = inputs
permuted_lengths, permuted_values, permuted_weights = output
ctx.permute = permute
ctx.permuted_lengths = permuted_lengths


# pyre-ignore
def permute_2D_sparse_data_backward(ctx, grad_lengths, grad_values, grad_weights):
inv_permute = torch.ops.fbgemm.invert_permute(ctx.permute)
permuted_grad_lengths, permuted_grad_values, permuted_grad_weights = (
torch.ops.fbgemm.permute_2D_sparse_data(
inv_permute, ctx.permuted_lengths, grad_values, grad_weights
)
)
return (
None,
permuted_grad_lengths,
permuted_grad_values,
permuted_grad_weights,
None,
)


torch.library.register_autograd(
"fbgemm::permute_2D_sparse_data",
permute_2D_sparse_data_backward,
setup_context=permute_2D_sparse_data_setup_context,
)


@impl_abstract("fbgemm::permute_1D_sparse_data")
def permute_1D_sparse_data_meta(
permute: Tensor,
Expand Down
6 changes: 0 additions & 6 deletions fbgemm_gpu/test/sparse/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,6 @@ def extend_test_class(

additional_decorators = {
**(additional_decorators or {}),
**{
"test_pt2_compliant_tag_fbgemm_permute_2D_sparse_data": [
# This operator has been grandfathered in. We need to fix this test failure.
unittest.expectedFailure,
]
},
}

# Only generate tests for PyTorch 2.2+
Expand Down
22 changes: 2 additions & 20 deletions fbgemm_gpu/test/sparse/failures_dict.json
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,7 @@
"status": "xfail"
}
},
"fbgemm::invert_permute": {
"MiscOpsTest.test_aot_dispatch_dynamic__test_invert_permute": {
"comment": "",
"status": "xfail"
},
"MiscOpsTest.test_faketensor__test_invert_permute": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::invert_permute": {},
"fbgemm::pack_segments": {},
"fbgemm::permute102_baddbmm_permute102": {
"MiscOpsTest.test_aot_dispatch_dynamic__test_permute102_baddbmm_permute102": {
Expand All @@ -166,16 +157,7 @@
}
},
"fbgemm::permute_1D_sparse_data": {},
"fbgemm::permute_2D_sparse_data": {
"PermuteEmbeddingsTest.test_aot_dispatch_dynamic__test_permute_embeddings": {
"comment": "",
"status": "xfail"
},
"PermuteIndicesTest.test_aot_dispatch_dynamic__test_permute_indices": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::permute_2D_sparse_data": {},
"fbgemm::permute_sequence_embeddings": {
"PermuteEmbeddingsTest.test_aot_dispatch_dynamic__test_permute_embeddings": {
"comment": "",
Expand Down

0 comments on commit ab05ca9

Please sign in to comment.