implementation of fbgemm op - permute_multi_embedding (pytorch#2738)
Summary:
Pull Request resolved: pytorch#2738

X-link: pytorch/torchrec#2120

# context
* currently we have a working function `permute_pooled_embs_auto_grad` that does a full permute of KTs, including forward and backward
* it has several limitations: a) it has to be a full permute; duplicates are not supported; b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient; c) the function outputs a single KT, which then requires a split operation
* there was an attempt to support duplicated outputs, but the backward pass doesn't work
* this diff creates a new kernel (named `permute_multi_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support

# notes
* this diff focuses on the implementation and test of the operator
* performance analysis and benchmark are in the next diff

# operator example usage
* used in python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features); duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments to the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# call the op
outputs = torch.ops.fbgemm.permute_multi_embedding_internal_testing(
    values, permutes, in_lengths, out_lengths
)
```
* permutes
```
# each row represents a key (feature) permute move, which consists of the following parameters:
# [input_tensor_idx, output_tensor_idx, input_key_idx, output_key_idx, key_length, magic_jump]
permutes = tensor(
    [
        [0, 0, 0, 0, 3, 4],    # f1
        [1, 0, 0, 3, 5, 0],    # f3
        [0, 1, 3, 0, 4, 0],    # f2
        [1, 2, 5, 0, 6, 0],    # f4
        [0, 2, 0, 6, 3, -6],   # f1
        [2, 2, 0, 9, 8, 0],    # f6
        [0, 3, 0, 0, 3, -8],   # f1
        [1, 3, 11, 3, 7, 0],   # f5
    ]
)
```
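* for intuition, the per-row copy semantics can be mirrored in pure Python; the sketch below is illustrative only (`forward_reference` is a made-up helper name, not the actual kernel), and it assumes each permute row copies exactly one feature slice from an input KT to an output KT as described above
```
import torch

def forward_reference(values, permutes, out_lengths, batch_size):
    # allocate the output KTs up front (the real op does this on device,
    # which is why the output lengths must be known ahead of time)
    outputs = [
        torch.zeros(batch_size, n, device=values[0].device)
        for n in out_lengths.tolist()
    ]
    # each row: [in_tensor, out_tensor, in_start, out_start, length, magic_jump];
    # the forward pass is just one slice copy per row, duplicates included
    for in_t, out_t, in_s, out_s, length, _jump in permutes.tolist():
        outputs[out_t][:, out_s:out_s + length] = values[in_t][:, in_s:in_s + length]
    return outputs
```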
# details
1. from the above example usage, we can see that the operator takes in the following:
   a) `values`: List[torch.Tensor], which represents the input KTs
   b) `permutes`: torch.Tensor, which contains the permute information, explained below
   c) `output_lengths_list`: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on device ahead of time
   d) `in_lengths`: torch.Tensor, lengths of input tensors, which is on device
   e) `out_lengths`: torch.Tensor, lengths of output tensors, which is on device
2. the operator returns a list of tensors, which represents the permuted KTs
3. `permutes` is the most critical argument of this operator:
   a) it's a 2-D tensor
   b) each row represents a key (feature) permute move
   c) a permute move = [input_tensor_idx, output_tensor_idx, input_start_idx, output_start_idx, feature_length, magic_jump]
   d) the jump is used in the backward pass when a key (feature) from an input tensor is mapped to multiple places in the output tensors
4. the `magic_jump`:
   a) it's only used in the backward computation
   b) it's usually 0, which means no jump
   c) it's non-zero when there is a duplicate in the permute, i.e., the same feature appears more than once in the output
   d) the `magic_jump` is the index of the next occurrence of the very same feature in the permute sequence, with some modifications
   e) modification-1: `magic_jump` is positive when it's the first of its kind [Start]
   f) modification-2: `magic_jump` is negative (the negated next index) when it's not the first of its kind [Continue]
   g) modification-3: `magic_jump` is the negative value of the length of the permute sequence when it's the last of its kind [Stop]
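To make the [Start]/[Continue]/[Stop] encoding concrete, here is a pure-Python sketch of how a backward pass could walk a duplicate chain. This is inferred purely from the description above and is not the actual CUDA kernel; `backward_reference` is a made-up name, and the sketch assumes `grad_outputs` matches the forward outputs and `in_lengths` holds the per-input-KT lengths.
```
import torch

def backward_reference(grad_outputs, permutes, in_lengths, batch_size):
    P = permutes.size(0)
    grad_inputs = [
        torch.zeros(batch_size, n, device=grad_outputs[0].device)
        for n in in_lengths.tolist()
    ]
    for in_t, out_t, in_s, out_s, length, jump in permutes.tolist():
        if jump < 0:
            # [Continue]/[Stop] row: its gradient is picked up by the chain
            # walk started at the [Start] row of the same feature
            continue
        grad_inputs[in_t][:, in_s:in_s + length] = (
            grad_outputs[out_t][:, out_s:out_s + length]
        )
        # jump == 0: no duplicates, nothing more to do;
        # jump > 0 ([Start]): walk the chain, accumulating each duplicate
        nxt = jump
        while 0 < nxt < P:
            _, o_t, _, o_s, _, j = permutes[nxt].tolist()
            grad_inputs[in_t][:, in_s:in_s + length] += (
                grad_outputs[o_t][:, o_s:o_s + length]
            )
            # [Continue] rows store the negated index of the next duplicate;
            # the last duplicate ([Stop]) stores -P, which ends the walk
            nxt = -j
    return grad_inputs
```
In this sketch only the [Start] (or jump-free) row of each input feature writes its gradient slice, so every input slice is produced by exactly one row; presumably this is what lets the kernel accumulate gradients for duplicated features without atomic adds.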
Reviewed By: sryap

Differential Revision: D57055616

fbshipit-source-id: 16673d3a2eafab93b08d4ff3c43d54366966064a