Add unit test for unique cache lookup (pytorch#2160)
Summary:
Pull Request resolved: pytorch#2160

As title

Differential Revision: D51600570

fbshipit-source-id: 3004dad5cfecd72c28067e6f9185bc131952bb01
sryap authored and facebook-github-bot committed Nov 28, 2023
1 parent 18af2b2 commit ca1da75
Showing 1 changed file with 158 additions and 0 deletions.
158 changes: 158 additions & 0 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -6543,6 +6543,164 @@ def check_weight_momentum(v: int) -> None:

check_weight_momentum(0)

@unittest.skipIf(*gpu_unavailable)
@given(
T=st.integers(min_value=1, max_value=10),
D=st.integers(min_value=2, max_value=128),
B=st.integers(min_value=1, max_value=128),
log_E=st.integers(min_value=3, max_value=5),
L=st.integers(min_value=0, max_value=20),
)
@settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
def test_unique_lxu_cache_lookup(
self,
T: int,
D: int,
B: int,
log_E: int,
L: int,
) -> None:
E = int(10**log_E)

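        # Generate T * B bags of L random indices each, plus CSR-style
        # offsets marking where each bag starts.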
indices = to_device(
torch.randint(low=0, high=E, size=(T * L * B,)),
use_cpu=False,
).long()
offsets = to_device(
            torch.tensor([0] + list(accumulate([L] * (T * B)))),
use_cpu=False,
).long()

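        # Deduplicated path: linearize the indices, reduce them to their
        # unique values, and look up only those. Returns the per-unique-index
        # cache locations and the unique count (a device-side tensor).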
def unique_lookup(
indices: Tensor,
offsets: Tensor,
cache_hash_size_cumsum: Tensor,
total_cache_hash_size: int,
) -> Tuple[Tensor, Tensor]:
linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
cache_hash_size_cumsum,
indices,
offsets,
)

uniq_indices, uniq_indices_length, _ = torch.ops.fbgemm.get_unique_indices(
linear_cache_indices, total_cache_hash_size, compute_count=False
)

uniq_lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
uniq_indices,
lxu_cache_state,
total_cache_hash_size,
gather_cache_stats=False,
num_uniq_cache_indices=uniq_indices_length,
)

return uniq_lxu_cache_locations, uniq_indices_length

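        # Reference path: look up every linearized index directly, duplicates
        # included.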
def duplicate_lookup(
indices: Tensor,
offsets: Tensor,
cache_hash_size_cumsum: Tensor,
total_cache_hash_size: int,
) -> Tensor:
linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
cache_hash_size_cumsum,
indices,
offsets,
)

lxu_cache_locations = torch.ops.fbgemm.lxu_cache_lookup(
linear_cache_indices,
lxu_cache_state,
total_cache_hash_size,
gather_cache_stats=False,
)
return lxu_cache_locations

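        # Size the cache at 20% of all embedding rows; -1 marks an empty way
        # in each DEFAULT_ASSOC-way set.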
cache_sets = int((E * T) * 0.2)
lxu_cache_state = torch.zeros(
cache_sets,
DEFAULT_ASSOC,
device="cuda",
dtype=torch.int64,
).fill_(-1)

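        # All T tables have E rows; the complete (zero-prepended) cumulative
        # sum provides the per-table offsets used to map indices into a
        # single linear index space.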
hash_sizes = torch.tensor([E] * T, dtype=torch.long, device="cuda")
cache_hash_size_cumsum = torch.ops.fbgemm.asynchronous_complete_cumsum(
hash_sizes
)
total_cache_hash_size = cache_hash_size_cumsum[-1].item()

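        # Linearize once up front; the same linearized indices drive the
        # cache population below and the reference computation in the
        # validation.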
linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
cache_hash_size_cumsum,
indices,
offsets,
)

# Emulate cache population
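        # Fill each set in arrival order; once a set holds DEFAULT_ASSOC
        # entries, later indices mapping to it stay uncached (a lookup miss).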
uniq_indices_cpu = linear_cache_indices.unique().cpu()
index_cache_set_map = uniq_indices_cpu.clone()
index_cache_set_map.apply_(
lambda x: torch.ops.fbgemm.lxu_cache_slot(x, cache_sets)
)
index_cache_set_map = index_cache_set_map.tolist()
uniq_indices_cpu = uniq_indices_cpu.tolist()

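        # Next free way in each cache set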
slots = {}
for idx, c in zip(uniq_indices_cpu, index_cache_set_map):
if c not in slots:
slots[c] = 0
slot = slots[c]
if slot < DEFAULT_ASSOC:
lxu_cache_state[c][slot] = idx
slots[c] = slot + 1

# Run unique lookup
uniq_lookup_output, uniq_indices_length = unique_lookup(
indices, offsets, cache_hash_size_cumsum, total_cache_hash_size
)

# Run duplicate lookup
duplicate_lookup_output = duplicate_lookup(
indices, offsets, cache_hash_size_cumsum, total_cache_hash_size
)

# Start running validation

# Compute unique indices using PyTorch ops
sorted_linear_cache_indices, inverse_sorted_cache_indices = torch.sort(
linear_cache_indices
)
ref_uniq_cache_indices, cache_indices_counts = torch.unique_consecutive(
sorted_linear_cache_indices, return_inverse=False, return_counts=True
)

# Convert to lists
cache_indices_counts = cache_indices_counts.cpu().tolist()
uniq_lookup_output = uniq_lookup_output.cpu().tolist()

# Validate the number of unique cache indices
ref_num_uniq_indices = ref_uniq_cache_indices.numel()
assert ref_num_uniq_indices == uniq_indices_length.item()

        # Only the first ref_num_uniq_indices entries of the unique lookup
        # output are valid; expand them back to one location per original
        # index by repeating each location by its duplicate count
reshaped_uniq_lookup_output = uniq_lookup_output[:ref_num_uniq_indices]
sorted_lxu_cache_locations = to_device(
torch.tensor(
np.repeat(reshaped_uniq_lookup_output, cache_indices_counts),
dtype=duplicate_lookup_output.dtype,
),
use_cpu=False,
)

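        # Argsort of the sort permutation yields its inverse, mapping each
        # original position to its rank in the sorted order.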
_, cache_location_indices = torch.sort(inverse_sorted_cache_indices)

expanded_lxu_cache_locations = torch.index_select(
sorted_lxu_cache_locations, 0, cache_location_indices
)

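        # The re-expanded unique lookup must match the duplicate lookup
        # element for element.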
assert torch.equal(expanded_lxu_cache_locations, duplicate_lookup_output)


if __name__ == "__main__":
unittest.main()
