forked from pytorch/FBGEMM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for duplicate in permutations for permute_pooled_embs_spl…
…it (pytorch#1940) Summary: Pull Request resolved: pytorch#1940 This diff builds on top of the previous diffs and adds support for duplicates to the permute_pooled_embs_split op. Background Currently permute_pooled_embs_split does not support duplicates in a permutation, this poses a problem with passing the same embeddings to multiple modules. This doc proposes a solution to allow duplicate subsets in the resultant permutation. Details The required implementation of permute_pooled_embs_split should support a subset being repeated. This is represented by having duplicates in the permute list. This also results in the output list size being greater than the input list. Input: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] Offset_dims: [0, 2, 5, 6, 10] Permute: [3, 0, 2, 1, 3] Output: [6, 7, 8, 9, 0, 1, 5, 2, 3, 4, 6, 7, 8, 9] Reviewed By: sryap Differential Revision: D48305847 fbshipit-source-id: 4c82683b725592cad458e83596617a14f4c6e988
- Loading branch information
1 parent
eb1103b
commit f030bbc
Showing
6 changed files
with
249 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#!/usr/bin/env fbpython | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
from itertools import accumulate | ||
from typing import List, Tuple | ||
|
||
import torch | ||
import torch._dynamo | ||
|
||
# Resolve the `gpu_unavailable` skip-condition for both build environments:
# the open-source package exposes it via a local test_utils module, while the
# internal (fbcode) build needs the split permute op libraries loaded
# explicitly before importing from fbgemm_gpu.test.test_utils.
try:
    # pyre-ignore[21]
    from fbgemm_gpu import open_source  # noqa: F401

    # pyre-ignore[21]
    from test_utils import gpu_unavailable
except Exception:
    # Internal build: load the GPU and CPU op libraries by buck target so the
    # torch.ops.fbgemm.* split-permute operators are registered.
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_gpu"
    )
    torch.ops.load_library(
        "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_split_cpu"
    )
    from fbgemm_gpu.test.test_utils import gpu_unavailable

# Re-annotate for pyre: gpu_unavailable is a (should_skip, reason) pair
# suitable for unpacking into unittest.skipIf.
typed_gpu_unavailable: Tuple[bool, str] = gpu_unavailable
|
||
|
||
class PermutePooledEmbeddingSplitTest(unittest.TestCase):
    """Tests for ``permute_duplicate_pooled_embs_auto_grad_split``.

    The op permutes concatenated pooled-embedding segments and, unlike the
    non-duplicate variant, allows the same source segment to appear more
    than once in the permutation — so the output can be wider than the
    input.
    """

    def setUp(self) -> None:
        super().setUp()
        # Primary device for test tensors; the CPU path is exercised by
        # moving the input to the host before the second op call.
        self.device = "cuda"

    def _permute_and_check(
        self,
        pooled_embs: torch.Tensor,
        offset_dim_list: torch.Tensor,
        permute: torch.Tensor,
        inv_offset_dim_list: torch.Tensor,
        inv_permute: torch.Tensor,
        expected_result: List[float],
    ) -> None:
        """Run the op on ``pooled_embs``'s device and compare the flattened
        output against ``expected_result``."""
        device = pooled_embs.device
        result = torch.ops.fbgemm.permute_duplicate_pooled_embs_auto_grad_split(
            pooled_embs,
            offset_dim_list.to(device=device),
            permute.to(device=device),
            inv_offset_dim_list.to(device=device),
            inv_permute.to(device=device),
        )
        # Derive the flattened length from the expectation instead of
        # hard-coding it.
        self.assertEqual(
            result.view(len(expected_result)).tolist(),
            expected_result,
        )

    @unittest.skipIf(*typed_gpu_unavailable)
    def test_duplicate_permutations(self) -> None:
        """Permutation [3, 0, 2, 0, 1, 3] repeats segments 0 and 3, so the
        16-value output is wider than the 10-value input. Both the GPU and
        CPU implementations must agree with the expected layout."""
        embs_dims = [2, 3, 1, 4]
        permute = [3, 0, 2, 0, 1, 3]
        expected_result = [6, 7, 8, 9, 0, 1, 5, 0, 1, 2, 3, 4, 6, 7, 8, 9]
        # Use self.device rather than a hard-coded device string so the
        # input device stays consistent with the rest of the test setup.
        pooled_embs = torch.Tensor([range(10)]).to(device=self.device)

        _permute = torch.tensor(permute, device=self.device, dtype=torch.int64)
        _offset_dim_list = torch.tensor(
            [0] + list(accumulate(embs_dims)), device=self.device, dtype=torch.int64
        )
        # Inverse mapping from source segment to output position.
        # NOTE(review): with duplicates in `permute`, later occurrences
        # overwrite earlier ones here — presumably the backward pass only
        # needs one output position per source segment; confirm against
        # the op's contract.
        inv_permute: List[int] = [0] * len(permute)
        for out_pos, src_seg in enumerate(permute):
            inv_permute[src_seg] = out_pos
        _inv_permute = torch.tensor(inv_permute, device=self.device, dtype=torch.int64)
        # Output segment sizes follow the (possibly duplicated) permutation,
        # hence the output offsets end at 16 while the input ends at 10.
        inv_embs_dims = [embs_dims[i] for i in permute]
        _inv_offset_dim_list = torch.tensor(
            [0] + list(accumulate(inv_embs_dims)),
            device=self.device,
            dtype=torch.int64,
        )

        # GPU path.
        self._permute_and_check(
            pooled_embs,
            _offset_dim_list,
            _permute,
            _inv_offset_dim_list,
            _inv_permute,
            expected_result,
        )

        # CPU path: identical inputs moved to the host.
        self._permute_and_check(
            pooled_embs.to(device="cpu"),
            _offset_dim_list,
            _permute,
            _inv_offset_dim_list,
            _inv_permute,
            expected_result,
        )
|
||
|
||
# Allow running this test file directly (e.g. `python <file>.py`).
if __name__ == "__main__":
    unittest.main()