diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
index f523773d2..7415e5c30 100644
--- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py
+++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -6543,6 +6543,164 @@ def check_weight_momentum(v: int) -> None:
         check_weight_momentum(0)
 
+    @unittest.skipIf(*gpu_unavailable)
+    @given(
+        T=st.integers(min_value=1, max_value=10),
+        D=st.integers(min_value=2, max_value=128),
+        B=st.integers(min_value=1, max_value=128),
+        log_E=st.integers(min_value=3, max_value=5),
+        L=st.integers(min_value=0, max_value=20),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
+    def test_unique_lxu_cache_lookup(
+        self,
+        T: int,
+        D: int,
+        B: int,
+        log_E: int,
+        L: int,
+    ) -> None:
+        E = int(10**log_E)
+
+        indices = to_device(
+            torch.randint(low=0, high=E, size=(T * L * B,)),
+            use_cpu=False,
+        ).long()
+        # One offset per bag: T tables x B batches, each bag of length L
+        offsets = to_device(
+            torch.tensor([0] + list(accumulate([L] * (T * B)))),
+            use_cpu=False,
+        ).long()
+
+        # Query the cache once per deduplicated index
+        def unique_lookup(
+            indices: Tensor,
+            offsets: Tensor,
+            cache_hash_size_cumsum: Tensor,
+            total_cache_hash_size: int,
+        ) -> Tuple[Tensor, Tensor]:
+            linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
+                cache_hash_size_cumsum,
+                indices,
+                offsets,
+            )
+
+            uniq_indices, uniq_indices_length, _ = torch.ops.fbgemm.get_unique_indices(
+                linear_cache_indices, total_cache_hash_size, compute_count=False
+            )
+
+            uniq_lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
+                uniq_indices,
+                lxu_cache_state,
+                total_cache_hash_size,
+                gather_cache_stats=False,
+                num_uniq_cache_indices=uniq_indices_length,
+            )
+
+            return uniq_lxu_cache_locations, uniq_indices_length
+
+        # Query the cache once per index, duplicates included
+        def duplicate_lookup(
+            indices: Tensor,
+            offsets: Tensor,
+            cache_hash_size_cumsum: Tensor,
+            total_cache_hash_size: int,
+        ) -> Tensor:
+            linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
+                cache_hash_size_cumsum,
+                indices,
+                offsets,
+            )
+
+            lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
+                linear_cache_indices,
+                lxu_cache_state,
+                total_cache_hash_size,
+                gather_cache_stats=False,
+            )
+            return lxu_cache_locations
+
+        # Allocate a cache that covers 20% of all embedding rows
+        cache_sets = int((E * T) * 0.2)
+        lxu_cache_state = torch.zeros(
+            cache_sets,
+            DEFAULT_ASSOC,
+            device="cuda",
+            dtype=torch.int64,
+        ).fill_(-1)
+
+        hash_sizes = torch.tensor([E] * T, dtype=torch.long, device="cuda")
+        cache_hash_size_cumsum = torch.ops.fbgemm.asynchronous_complete_cumsum(
+            hash_sizes
+        )
+        total_cache_hash_size = cache_hash_size_cumsum[-1].item()
+
+        linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
+            cache_hash_size_cumsum,
+            indices,
+            offsets,
+        )
+
+        # Emulate cache population
+        uniq_indices_cpu = linear_cache_indices.unique().cpu()
+        index_cache_set_map = uniq_indices_cpu.clone()
+        index_cache_set_map.apply_(
+            lambda x: torch.ops.fbgemm.lxu_cache_slot(x, cache_sets)
+        )
+        index_cache_set_map = index_cache_set_map.tolist()
+        uniq_indices_cpu = uniq_indices_cpu.tolist()
+
+        # Fill each set in insertion order; drop overflow beyond DEFAULT_ASSOC ways
+        slots = {}
+        for idx, c in zip(uniq_indices_cpu, index_cache_set_map):
+            if c not in slots:
+                slots[c] = 0
+            slot = slots[c]
+            if slot < DEFAULT_ASSOC:
+                lxu_cache_state[c][slot] = idx
+            slots[c] = slot + 1
+
+        # Run unique lookup
+        uniq_lookup_output, uniq_indices_length = unique_lookup(
+            indices, offsets, cache_hash_size_cumsum, total_cache_hash_size
+        )
+
+        # Run duplicate lookup
+        duplicate_lookup_output = duplicate_lookup(
+            indices, offsets, cache_hash_size_cumsum, total_cache_hash_size
+        )
+
+        # Start running validation
+
+        # Compute unique indices using PyTorch ops
+        sorted_linear_cache_indices, inverse_sorted_cache_indices = torch.sort(
+            linear_cache_indices
+        )
+        ref_uniq_cache_indices, cache_indices_counts = torch.unique_consecutive(
+            sorted_linear_cache_indices, return_inverse=False, return_counts=True
+        )
+
+        # Convert to lists
+        cache_indices_counts = cache_indices_counts.cpu().tolist()
+        uniq_lookup_output = uniq_lookup_output.cpu().tolist()
+
+        # Validate the number of unique cache indices
+        ref_num_uniq_indices = ref_uniq_cache_indices.numel()
+        assert ref_num_uniq_indices == uniq_indices_length.item()
+
+        # Expand the unique lookup results back to the original (duplicated) order
+        reshaped_uniq_lookup_output = uniq_lookup_output[:ref_num_uniq_indices]
+        sorted_lxu_cache_locations = to_device(
+            torch.tensor(
+                np.repeat(reshaped_uniq_lookup_output, cache_indices_counts),
+                dtype=duplicate_lookup_output.dtype,
+            ),
+            use_cpu=False,
+        )
+
+        _, cache_location_indices = torch.sort(inverse_sorted_cache_indices)
+
+        expanded_lxu_cache_locations = torch.index_select(
+            sorted_lxu_cache_locations, 0, cache_location_indices
+        )
+
+        assert torch.equal(expanded_lxu_cache_locations, duplicate_lookup_output)
+
 
 if __name__ == "__main__":
     unittest.main()
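
Note on the validation above: when num_uniq_cache_indices is passed, lxu_cache_lookup returns locations only for the deduplicated indices, so the test must expand that output back to per-sample order before comparing it against the plain duplicate lookup. Below is a minimal, self-contained sketch of that expand-and-unsort technique with made-up index/location values; torch.repeat_interleave stands in for the test's np.repeat over Python lists:

    import torch

    # Per-sample linearized cache indices, with duplicates (illustrative values).
    linear_cache_indices = torch.tensor([7, 3, 7, 5, 3, 3])

    # Sort, then collapse equal runs to get unique indices and their counts.
    sorted_indices, sort_order = torch.sort(linear_cache_indices)
    uniq_indices, counts = torch.unique_consecutive(sorted_indices, return_counts=True)

    # Pretend these are the cache locations returned for the unique indices 3, 5, 7.
    uniq_locations = torch.tensor([30, 50, 70])

    # Repeat each unique location by its count to recover sorted per-sample order,
    # then invert the sort permutation to restore the original sample order.
    sorted_locations = torch.repeat_interleave(uniq_locations, counts)
    _, inverse_order = torch.sort(sort_order)
    per_sample_locations = sorted_locations[inverse_order]

    # Matches what a per-sample (duplicate) lookup would return for [7, 3, 7, 5, 3, 3].
    assert per_sample_locations.tolist() == [70, 30, 70, 50, 30, 30]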