Fix CPU tbe_input_combine to support cases where value tensors have padding. (#3162)

Summary:
X-link: facebookresearch/FBGEMM#260

Pull Request resolved: #3162

The current CPU implementation of tbe_input_combine assumes there is no padding in the value tensors. However, in some cases there is, and the combined tensor then has padding interleaved between the actual values, which causes failures in downstream PyTorch operators.
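To illustrate the idea behind the fix, here is a minimal standalone sketch in plain C++ (std::vector instead of ATen tensors; the function name cat_with_trimmed_padding is hypothetical and not part of this diff): each value tensor is copied only up to its terminating index, and everything past that point is treated as padding and appended after all valid values.

#include <cstdint>
#include <iostream>
#include <vector>

// Concatenate `tensors`, copying only the first `terminating_idx[i]` values
// of each input; the remaining entries are treated as padding and are
// appended after all valid values, zero-filling up to `total_num`.
std::vector<int32_t> cat_with_trimmed_padding(
    const std::vector<std::vector<int64_t>>& tensors,
    const std::vector<int64_t>& terminating_idx,
    int64_t total_num) {
  std::vector<int32_t> combined;
  combined.reserve(total_num);
  std::vector<int32_t> paddings;
  for (size_t i = 0; i < tensors.size(); ++i) {
    const auto valid = terminating_idx[i];
    for (int64_t j = 0; j < static_cast<int64_t>(tensors[i].size()); ++j) {
      (j < valid ? combined : paddings)
          .push_back(static_cast<int32_t>(tensors[i][j]));
    }
  }
  combined.insert(combined.end(), paddings.begin(), paddings.end());
  combined.resize(total_num, 0);  // zero-fill if there was not enough padding
  return combined;
}

int main() {
  // {1, 2, -1} has one trailing padding slot (valid length 2). Without
  // trimming, the -1 would sit between the two tensors' values; with
  // trimming it moves to the tail and the output is: 1 2 3 4 -1
  auto out = cat_with_trimmed_padding({{1, 2, -1}, {3, 4}}, {2, 2}, 5);
  for (auto v : out) std::cout << v << ' ';
}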

Differential Revision: D63034520

fbshipit-source-id: 3ed0704d8222c56ecc1896832b4f5ed5695309d6
Shiguang Wang authored and facebook-github-bot committed Sep 23, 2024
1 parent af8ecb0 commit 3760759
Showing 2 changed files with 261 additions and 16 deletions.
134 changes: 118 additions & 16 deletions fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp
@@ -29,31 +29,65 @@ namespace fbgemm_gpu {
void _cat_int_tensors_out(
Tensor& combined_tensors,
const std::vector<Tensor>& tensor_list,
int64_t total_num) {
int64_t total_num,
bool to_trim_padding = false,
const std::vector<int64_t>& indices_terminating_idx =
std::vector<int64_t>()) {
if (to_trim_padding) {
// We need to define the terminating idx for each indices tensor
TORCH_CHECK(indices_terminating_idx.size() == tensor_list.size());
}
at::native::resize_(combined_tensors, {total_num});
auto* combined_tensors_data_ptr =
combined_tensors.mutable_data_ptr<int32_t>();
size_t idx = 0;

for (const auto& tensor : tensor_list) {
// Keep the original padding values and append them at the end
std::vector<int64_t> paddings;
paddings.reserve(total_num);

for (size_t i = 0; i < tensor_list.size(); ++i) {
const auto& tensor = tensor_list[i];
AT_DISPATCH_INDEX_TYPES(tensor.scalar_type(), "tbe_cat_inputs_", [&] {
// Necessary to use data_ptr. Checked in caller, but let's
// be safe in case somebody changes that.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.is_contiguous());
auto* indices_data_ptr = tensor.const_data_ptr<index_t>();
const auto numel = tensor.numel();
for (auto j = 0; j < numel; j++) {
auto numel = tensor.numel();
if (to_trim_padding) {
const auto terminating_idx = indices_terminating_idx.at(i);
numel = terminating_idx > 0 && terminating_idx < numel ? terminating_idx
: numel;
}
size_t j = 0;
for (; j < numel; j++) {
combined_tensors_data_ptr[idx++] =
static_cast<int32_t>(indices_data_ptr[j]);
}
for (; j < tensor.numel(); j++) {
paddings.push_back(indices_data_ptr[j]);
}
});
}

// Append the original padding values at the end
int i = 0;
while (idx < total_num) {
if (i < paddings.size()) [[likely]] {
combined_tensors_data_ptr[idx++] = paddings[i++];
} else {
combined_tensors_data_ptr[idx++] = 0;
}
}
}

Tensor _cat_int_tensors(
const std::vector<Tensor>& tensor_list,
int64_t total_num,
bool use_pin_memory) {
bool use_pin_memory,
bool to_trim_padding = false,
const std::vector<int64_t>& indices_terminating_idx =
std::vector<int64_t>()) {
// Using int type to maintain original behavior
// in https://fburl.com/code/h2lwews2
auto combined_tensors = at::empty(
@@ -63,7 +97,12 @@ Tensor _cat_int_tensors(
.device(tensor_list[0].device())
.pinned_memory(use_pin_memory));

_cat_int_tensors_out(combined_tensors, tensor_list, total_num);
_cat_int_tensors_out(
combined_tensors,
tensor_list,
total_num,
to_trim_padding,
indices_terminating_idx);
return combined_tensors;
}

@@ -104,37 +143,61 @@ void _cat_per_sample_weights_list_out(
Tensor& out,
const std::vector<Tensor>& per_sample_weights,
const std::vector<Tensor>& indices_list,
int64_t total_num) {
int64_t total_num,
bool to_trim_padding = false,
const std::vector<int64_t>& indices_terminating_idx =
std::vector<int64_t>()) {
if (to_trim_padding) {
// We need to define the terminating idx for each indices tensor
TORCH_CHECK(indices_terminating_idx.size() == indices_list.size());
}
at::native::resize_(out, {total_num});
out.fill_(1.);

auto* out_weights_ptr = out.mutable_data_ptr<float>();

for (size_t i = 0; i < per_sample_weights.size(); i++) {
auto element_size = per_sample_weights[i].numel();
auto actual_indices_size = indices_list[i].numel();
if (to_trim_padding) {
element_size = element_size > indices_terminating_idx.at(i)
? indices_terminating_idx.at(i)
: element_size;
actual_indices_size = actual_indices_size > indices_terminating_idx.at(i)
? indices_terminating_idx.at(i)
: actual_indices_size;
}
if (element_size != 0) {
memcpy(
out_weights_ptr,
per_sample_weights[i].data_ptr<float>(),
element_size * sizeof(float));
}
out_weights_ptr += indices_list[i].numel();
out_weights_ptr += actual_indices_size;
}
}
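
For intuition, a minimal sketch (plain vectors, illustrative only; cat_weights_trimmed is a hypothetical name) of how the trimmed per-sample weights line up with the trimmed indices: positions left unused keep the fill value 1.0 at the tail of the combined buffer, and the output cursor advances by the trimmed length rather than the padded numel.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<float> cat_weights_trimmed(
    const std::vector<std::vector<float>>& weights,
    const std::vector<int64_t>& terminating_idx,
    int64_t total_num) {
  std::vector<float> out(total_num, 1.0f);  // mirrors out.fill_(1.) above
  size_t pos = 0;
  for (size_t i = 0; i < weights.size(); ++i) {
    const auto valid = static_cast<size_t>(terminating_idx[i]);
    for (size_t j = 0; j < valid && j < weights[i].size(); ++j) {
      out[pos + j] = weights[i][j];
    }
    pos += valid;  // advance by the trimmed length, not the padded numel
  }
  return out;
}

int main() {
  // Table 0 weights {0.5, 0.25, 9.0} carry one padding slot (valid length 2),
  // so the padded 9.0 is dropped and the last slot keeps the fill value:
  // output is 0.5 0.25 2 3 1
  auto out = cat_weights_trimmed({{0.5f, 0.25f, 9.0f}, {2.0f, 3.0f}}, {2, 2}, 5);
  for (auto w : out) std::cout << w << ' ';
}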

Tensor _cat_per_sample_weights_list(
const std::vector<Tensor>& per_sample_weights,
const std::vector<Tensor>& indices_list,
int64_t total_num,
bool use_pin_memory) {
bool use_pin_memory,
bool to_trim_padding = false,
const std::vector<int64_t>& indices_terminating_idx =
std::vector<int64_t>()) {
auto combined_weights = at::empty(
{0},
at::TensorOptions()
.dtype(c10::kFloat)
.device(per_sample_weights[0].device())
.pinned_memory(use_pin_memory));
_cat_per_sample_weights_list_out(
combined_weights, per_sample_weights, indices_list, total_num);
combined_weights,
per_sample_weights,
indices_list,
total_num,
to_trim_padding,
indices_terminating_idx);
return combined_weights;
}

@@ -155,6 +218,24 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_cpu(
bool need_weights = false;
bool pin_memory = false;

// We only enable this feature when all the elements of `include_last_offsets`
// are True and at least one indices tensor has padding
bool indices_tensor_has_padding = false;
for (size_t i = 0; i < indices_list.size(); i++) {
if (indices_list[i].numel() > offsets_list[i][-1].item().toLong()) {
indices_tensor_has_padding = true;
break;
}
}
auto to_trim_padding =
indices_tensor_has_padding && include_last_offsets.all().item<bool>();
// If the indices tensors have padding, we need to determine the boundary,
// i.e. the terminating idx, to properly combine the TBE inputs.
// `indices_terminating_idx` is a list of the terminating idx for each indices
// tensor
std::vector<int64_t> indices_terminating_idx;
indices_terminating_idx.reserve(indices_list.size());

for (size_t i = 0; i < indices_list.size(); i++) {
TORCH_CHECK(
indices_list[i].dtype() == c10::kInt ||
@@ -167,6 +248,14 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_cpu(
TORCH_CHECK(indices_list[i].is_contiguous());
TORCH_CHECK(offsets_list[i].is_contiguous());
total_indices += indices_list[i].numel();
if (to_trim_padding) {
// When the offsets tensor includes the last offset, we respect that value,
// and the last offset should be less than (when there is padding) or equal
// to the number of elements in the indices tensor
TORCH_CHECK_LE(
offsets_list[i][-1].item().toLong(), indices_list[i].numel());
indices_terminating_idx.push_back(offsets_list[i][-1].item().toLong());
}
auto num_offset =
offsets_list[i].numel() - (include_last_offsets_acc[i] ? 1 : 0);
total_offsets += num_offset == 0 ? 1 : num_offset;
Expand All @@ -179,8 +268,12 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_cpu(
}
}

auto combined_indices =
_cat_int_tensors(indices_list, total_indices, pin_memory);
auto combined_indices = _cat_int_tensors(
indices_list,
total_indices,
pin_memory,
to_trim_padding,
indices_terminating_idx);

auto combined_offsets = at::empty(
{total_offsets},
@@ -208,17 +301,26 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_cpu(
offset + static_cast<int32_t>(offsets_data_ptr[j]);
}

offset += static_cast<int32_t>(indices_list[i].numel());
if (to_trim_padding) {
offset += static_cast<int32_t>(offsets_list[i][-1].item().toInt());
} else {
offset += static_cast<int32_t>(indices_list[i].numel());
}
combined_offsets_data_ptr[offsets_acc_idx++] = offset;
});
}

if (need_weights) {
return {
std::move(combined_indices),
std::move(combined_offsets),
combined_indices,
combined_offsets,
_cat_per_sample_weights_list(
per_sample_weights, indices_list, total_indices, pin_memory)};
per_sample_weights,
indices_list,
total_indices,
pin_memory,
to_trim_padding,
indices_terminating_idx)};
}
return {combined_indices, combined_offsets, at::empty({0})};
}
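
For reference, a hedged usage sketch (not part of this commit) of how a padded indices tensor combines after this change; it assumes the fbgemm_gpu/input_combine.h header that declares tbe_input_combine_cpu and the libtorch C++ tensor factory API.

#include <iostream>
#include <torch/torch.h>
#include "fbgemm_gpu/input_combine.h"  // assumed header declaring the op

int main() {
  // Table 0: no padding (last offset == numel).
  auto i0 = torch::tensor({1, 2, 3, 4}, torch::kInt32);
  auto o0 = torch::tensor({0, 2, 4}, torch::kInt32);
  // Table 1: two trailing padding slots (numel == 6, last offset == 4).
  auto i1 = torch::tensor({5, 6, 7, 8, -1, -1}, torch::kInt32);
  auto o1 = torch::tensor({0, 2, 4}, torch::kInt32);
  auto include_last_offsets = torch::ones({2}, torch::kBool);

  auto [indices, offsets, weights] = fbgemm_gpu::tbe_input_combine_cpu(
      {i0, i1},
      {o0, o1},
      {torch::empty({0}), torch::empty({0})},  // no per-sample weights
      include_last_offsets);

  // Expected with this fix: indices == [1, 2, 3, 4, 5, 6, 7, 8, -1, -1]
  // (table 1's padding is moved to the tail) and offsets == [0, 2, 4, 6, 8].
  // Before the fix, that padding could sit between tables' values once more
  // tables followed, breaking downstream PyTorch operators.
  std::cout << indices << "\n" << offsets << std::endl;
}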
143 changes: 143 additions & 0 deletions (second changed file; contents not loaded)
