Skip to content

Commit

Permalink
remove symint from tbe_input_combine_with_length_abstract (pytorch#2336)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2336

as titled, the number of offsets is not dynamic, rather it is the sum of the number of elements in the concatenated offset tensors

additionally, make sure new tensors are created on the same device as the indices

Reviewed By: khabinov

Differential Revision: D53862184

fbshipit-source-id: 6d98ac9d47725a325fdd55b9ca877d6f75af3779
  • Loading branch information
bradleyhd authored and facebook-github-bot committed Feb 19, 2024
1 parent 13d6bea commit bc4a73e
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def tbe_input_combine_with_length_abstract(
torch._check(len(indices_list) == len(offsets_list))
torch._check(len(indices_list) == len(per_sample_weights))
total_indices = 0
total_offsets = 0
need_weight = False
for index, offset, weight in zip(indices_list, offsets_list, per_sample_weights):
torch._check(index.dtype == torch.int or index.dtype == torch.long)
Expand All @@ -184,20 +185,20 @@ def tbe_input_combine_with_length_abstract(
torch._check(index.is_contiguous())
torch._check(offset.is_contiguous())
total_indices = total_indices + index.numel()
total_offsets = total_offsets + offset.numel()
if weight.numel() > 0:
torch._check(weight.dim() == 1)
torch._check(weight.numel() == index.numel())
torch._check(weight.is_contiguous())
need_weight = True
total_offsets = torch.library.get_ctx().new_dynamic_size()
combined_indices = indices_list[0].new_empty([total_indices], dtype=torch.int)
combined_offsets = offsets_list[0].new_empty([total_offsets], dtype=torch.int)
if need_weight:
combined_weights = per_sample_weights[0].new_empty(
[total_indices], dtype=torch.float
)
else:
combined_weights = torch.empty(0)
combined_weights = torch.empty(0, device=indices_list[0].device)
return combined_indices, combined_offsets, combined_weights


Expand Down

0 comments on commit bc4a73e

Please sign in to comment.