Commit 0d0aace: 2024-09-24 nightly release (012a658)
pytorchbot committed Sep 24, 2024
1 parent d6be525 commit 0d0aace
Showing 5 changed files with 301 additions and 24 deletions.
@@ -72,8 +72,14 @@ static const std::unordered_map<
fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
{{128, 7168, 8192},
fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{1024, 7168, 8192},
fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5},
{{2048, 7168, 8192},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{4096, 7168, 8192},
fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{{8192, 7168, 8192},
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
// Support for decode across batch sizes for [8192, 3584]
{{16, 8192, 3584},
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
@@ -83,6 +89,14 @@ static const std::unordered_map<
fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
{{128, 8192, 3584},
fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{1024, 8192, 3584},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{2048, 8192, 3584},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{4096, 8192, 3584},
fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{{8192, 8192, 3584},
fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
// Llama 405B Decode Shapes.
// Support for decode across batch sizes for [13312, 6656].
{{16, 13312, 6656},
@@ -102,8 +116,14 @@
fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{128, 13312, 16384},
fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{2048, 13312, 16384},
{{1024, 13312, 16384},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{2048, 13312, 16384},
fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{{4096, 13312, 16384},
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{{8192, 13312, 16384},
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
// Support for decode across batch sizes for [16384, 6656].
{{16, 16384, 6656},
fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1},
@@ -113,8 +133,14 @@
fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{128, 16384, 6656},
fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{1024, 16384, 6656},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{2048, 16384, 6656},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{4096, 16384, 6656},
fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{{8192, 16384, 6656},
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
// Support for decode across batch sizes for [16384, 16384].
{{16, 16384, 16384},
fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2},
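The table above maps exact {M, N, K} problem shapes to hand-tuned fp8 rowwise GEMM kernel instances; this commit extends coverage from small decode batch sizes (16 to 128) up through larger batches (1024 to 8192) for these Llama weight shapes. Below is a minimal sketch of the shape-keyed dispatch pattern. The names (`RowwiseKernel`, `kernel_shape_map`, `select_kernel`) are illustrative assumptions, not the actual FBGEMM symbols, and a `std::map` stands in for the real `std::unordered_map` to avoid spelling out a custom hash:

```cpp
#include <array>
#include <cstdint>
#include <functional>
#include <map>

// Stand-in for a tuned kernel instance (the real entries are template
// instantiations such as fp8_rowwise_256x256x256x128_..._intrawave_v3).
using RowwiseKernel = std::function<void()>;

// Keyed on the exact {M, N, K} problem shape; values are tuned instances.
static const std::map<std::array<int64_t, 3>, RowwiseKernel> kernel_shape_map = {
    // {{M, N, K}, tuned_kernel}, as in the table above.
};

RowwiseKernel select_kernel(
    int64_t M, int64_t N, int64_t K, const RowwiseKernel& generic_kernel) {
  // Exact-shape hit: use the tuned instance; otherwise fall back to a
  // shape-agnostic default.
  auto it = kernel_shape_map.find({M, N, K});
  return it != kernel_shape_map.end() ? it->second : generic_kernel;
}
```

The long identifiers encode each instance's tile configuration (block and warp tiling plus the pipeline schedule, e.g. `intrawave_v3`), which is why adjacent shapes can map to quite different instances.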
134 changes: 118 additions & 16 deletions fbgemm_gpu/src/input_combine_ops/input_combine_cpu.cpp
@@ -29,31 +29,65 @@ namespace fbgemm_gpu {
void _cat_int_tensors_out(
Tensor& combined_tensors,
const std::vector<Tensor>& tensor_list,
int64_t total_num) {
int64_t total_num,
bool to_trim_padding = false,
const std::vector<int64_t>& indices_terminating_idx =
std::vector<int64_t>()) {
if (to_trim_padding) {
// We need to define the terminating idx for each indices tensor
TORCH_CHECK(indices_terminating_idx.size() == tensor_list.size());
}
at::native::resize_(combined_tensors, {total_num});
auto* combined_tensors_data_ptr =
combined_tensors.mutable_data_ptr<int32_t>();
size_t idx = 0;

for (const auto& tensor : tensor_list) {
// Keep the original padding values; they are appended again at the end
std::vector<int64_t> paddings;
paddings.reserve(total_num);

for (size_t i = 0; i < tensor_list.size(); ++i) {
const auto& tensor = tensor_list[i];
AT_DISPATCH_INDEX_TYPES(tensor.scalar_type(), "tbe_cat_inputs_", [&] {
// Necessary to use data_ptr. Checked in caller, but let's
// be safe in case somebody changes that.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tensor.is_contiguous());
auto* indices_data_ptr = tensor.const_data_ptr<index_t>();
const auto numel = tensor.numel();
for (auto j = 0; j < numel; j++) {
auto numel = tensor.numel();
if (to_trim_padding) {
const auto terminating_idx = indices_terminating_idx.at(i);
numel = terminating_idx > 0 && terminating_idx < numel ? terminating_idx
: numel;
}
size_t j = 0;
for (; j < numel; j++) {
combined_tensors_data_ptr[idx++] =
static_cast<int32_t>(indices_data_ptr[j]);
}
for (; j < tensor.numel(); j++) {
paddings.push_back(indices_data_ptr[j]);
}
});
}

// Append the original padding values at the end
int i = 0;
while (idx < total_num) {
if (i < paddings.size()) [[likely]] {
combined_tensors_data_ptr[idx++] = paddings[i++];
} else {
combined_tensors_data_ptr[idx++] = 0;
}
}
}
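A worked illustration of the trim-and-repad behavior above, operating on plain vectors rather than `at::Tensor` (a hedged sketch; the values and the standalone `main` are made up for demonstration):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Two indices tensors; values past each terminating idx are padding.
  std::vector<std::vector<int32_t>> tensors = {{5, 6, 9, 9}, {7, 8, 9}};
  std::vector<int64_t> terminating_idx = {2, 2};
  const int64_t total_num = 7;  // sum of numel across both tensors

  std::vector<int32_t> combined;
  std::vector<int32_t> paddings;
  for (size_t i = 0; i < tensors.size(); ++i) {
    const int64_t cut = terminating_idx[i];
    for (int64_t j = 0; j < static_cast<int64_t>(tensors[i].size()); ++j) {
      // Valid entries go to the front; padding values are set aside.
      (j < cut ? combined : paddings).push_back(tensors[i][j]);
    }
  }
  // Re-append the original padding values, then zero-fill any remainder.
  combined.insert(combined.end(), paddings.begin(), paddings.end());
  combined.resize(total_num, 0);
  for (int32_t v : combined) {
    std::cout << v << ' ';  // prints: 5 6 7 8 9 9 9
  }
}
```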

Tensor _cat_int_tensors(
const std::vector<Tensor>& tensor_list,
int64_t total_num,
bool use_pin_memory) {
bool use_pin_memory,
bool to_trim_padding = false,
const std::vector<int64_t>& indices_terminating_idx =
std::vector<int64_t>()) {
// Using int type to maintain original behavior
// in https://fburl.com/code/h2lwews2
auto combined_tensors = at::empty(
@@ -63,7 +97,12 @@ Tensor _cat_int_tensors(
.device(tensor_list[0].device())
.pinned_memory(use_pin_memory));

_cat_int_tensors_out(combined_tensors, tensor_list, total_num);
_cat_int_tensors_out(
combined_tensors,
tensor_list,
total_num,
to_trim_padding,
indices_terminating_idx);
return combined_tensors;
}
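Note that the new `to_trim_padding` and `indices_terminating_idx` parameters default to `false` and an empty vector, so existing callers of `_cat_int_tensors` and `_cat_int_tensors_out` keep the original concatenation behavior.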

@@ -104,37 +143,61 @@ void _cat_per_sample_weights_list_out(
Tensor& out,
const std::vector<Tensor>& per_sample_weights,
const std::vector<Tensor>& indices_list,
int64_t total_num) {
int64_t total_num,
bool to_trim_padding = false,
const std::vector<int64_t>& indices_terminating_idx =
std::vector<int64_t>()) {
if (to_trim_padding) {
// We need to define the terminating idx for each indices tensor
TORCH_CHECK(indices_terminating_idx.size() == indices_list.size());
}
at::native::resize_(out, {total_num});
out.fill_(1.);

auto* out_weights_ptr = out.mutable_data_ptr<float>();

for (size_t i = 0; i < per_sample_weights.size(); i++) {
auto element_size = per_sample_weights[i].numel();
auto actual_indices_size = indices_list[i].numel();
if (to_trim_padding) {
element_size = element_size > indices_terminating_idx.at(i)
? indices_terminating_idx.at(i)
: element_size;
actual_indices_size = actual_indices_size > indices_terminating_idx.at(i)
? indices_terminating_idx.at(i)
: actual_indices_size;
}
if (element_size != 0) {
memcpy(
out_weights_ptr,
per_sample_weights[i].data_ptr<float>(),
element_size * sizeof(float));
}
out_weights_ptr += indices_list[i].numel();
out_weights_ptr += actual_indices_size;
}
}
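The weights path mirrors the indices path: each table's weights are copied only up to that table's terminating idx, and the write cursor advances by the trimmed indices count so the weights stay aligned with the trimmed indices. A simplified sketch under the same assumptions (plain vectors, hypothetical helper name, trimming always on):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<float> cat_weights_trimmed(
    const std::vector<std::vector<float>>& weights,
    const std::vector<int64_t>& indices_numel,
    const std::vector<int64_t>& terminating_idx,
    int64_t total_num) {
  std::vector<float> out(total_num, 1.0f);  // absent weights default to 1.0
  float* cursor = out.data();
  for (size_t i = 0; i < weights.size(); ++i) {
    const int64_t n_copy = std::min<int64_t>(
        static_cast<int64_t>(weights[i].size()), terminating_idx[i]);
    if (n_copy > 0) {
      std::memcpy(cursor, weights[i].data(), n_copy * sizeof(float));
    }
    // Advance by the trimmed indices count, not the raw weights length,
    // so the next table's weights line up with its trimmed indices.
    cursor += std::min<int64_t>(indices_numel[i], terminating_idx[i]);
  }
  return out;
}
```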

Tensor _cat_per_sample_weights_list(
const std::vector<Tensor>& per_sample_weights,
const std::vector<Tensor>& indices_list,
int64_t total_num,
bool use_pin_memory) {
bool use_pin_memory,
bool to_trim_padding = false,
const std::vector<int64_t>& indices_terminating_idx =
std::vector<int64_t>()) {
auto combined_weights = at::empty(
{0},
at::TensorOptions()
.dtype(c10::kFloat)
.device(per_sample_weights[0].device())
.pinned_memory(use_pin_memory));
_cat_per_sample_weights_list_out(
combined_weights, per_sample_weights, indices_list, total_num);
combined_weights,
per_sample_weights,
indices_list,
total_num,
to_trim_padding,
indices_terminating_idx);
return combined_weights;
}

@@ -155,6 +218,24 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_cpu(
bool need_weights = false;
bool pin_memory = false;

// We only enable this feature when all the elements of `include_last_offsets`
// are True and at least one indices tensor has padding
bool indices_tensor_has_padding = false;
for (size_t i = 0; i < indices_list.size(); i++) {
if (indices_list[i].numel() > offsets_list[i][-1].item().toLong()) {
indices_tensor_has_padding = true;
break;
}
}
auto to_trim_padding =
indices_tensor_has_padding && include_last_offsets.all().item<bool>();
// When the indices tensors have padding, we need to determine the boundary,
// i.e. the terminating idx, to properly combine the TBE inputs.
// `indices_terminating_idx` is a list of the terminating idx for each indices
// tensor
std::vector<int64_t> indices_terminating_idx;
indices_terminating_idx.reserve(indices_list.size());

for (size_t i = 0; i < indices_list.size(); i++) {
TORCH_CHECK(
indices_list[i].dtype() == c10::kInt ||
@@ -167,6 +248,14 @@
TORCH_CHECK(indices_list[i].is_contiguous());
TORCH_CHECK(offsets_list[i].is_contiguous());
total_indices += indices_list[i].numel();
if (to_trim_padding) {
// When the offsets tensor has the last offset, we respect this value,
// and the last offset value should be less than (in case there is
// padding) or equal to the number of elements in the indices tensor
TORCH_CHECK_LE(
offsets_list[i][-1].item().toLong(), indices_list[i].numel());
indices_terminating_idx.push_back(offsets_list[i][-1].item().toLong());
}
auto num_offset =
offsets_list[i].numel() - (include_last_offsets_acc[i] ? 1 : 0);
total_offsets += num_offset == 0 ? 1 : num_offset;
@@ -179,8 +268,12 @@
}
}

auto combined_indices =
_cat_int_tensors(indices_list, total_indices, pin_memory);
auto combined_indices = _cat_int_tensors(
indices_list,
total_indices,
pin_memory,
to_trim_padding,
indices_terminating_idx);

auto combined_offsets = at::empty(
{total_offsets},
@@ -208,17 +301,26 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_cpu(
offset + static_cast<int32_t>(offsets_data_ptr[j]);
}

offset += static_cast<int32_t>(indices_list[i].numel());
if (to_trim_padding) {
offset += static_cast<int32_t>(offsets_list[i][-1].item().toInt());
} else {
offset += static_cast<int32_t>(indices_list[i].numel());
}
combined_offsets_data_ptr[offsets_acc_idx++] = offset;
});
}

if (need_weights) {
return {
std::move(combined_indices),
std::move(combined_offsets),
combined_indices,
combined_offsets,
_cat_per_sample_weights_list(
per_sample_weights, indices_list, total_indices, pin_memory)};
per_sample_weights,
indices_list,
total_indices,
pin_memory,
to_trim_padding,
indices_terminating_idx)};
}
return {combined_indices, combined_offsets, at::empty({0})};
}
10 changes: 6 additions & 4 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -2843,8 +2843,9 @@ Tensor pack_segments_forward_cpu(
TORCH_CHECK(
t_in.dtype() == at::ScalarType::Float ||
t_in.dtype() == at::ScalarType::Double ||
t_in.dtype() == at::ScalarType::Half,
"t_in must be of type float or double or half");
t_in.dtype() == at::ScalarType::Half ||
t_in.dtype() == at::ScalarType::BFloat16,
"t_in must be of type float, double, half, or bfloat16");
TORCH_CHECK_GT(max_length, 0);

const auto t_in_cont = t_in.expect_contiguous();
@@ -2911,8 +2912,9 @@ Tensor pack_segments_backward_cpu(
TORCH_CHECK(
data.dtype() == at::ScalarType::Float ||
data.dtype() == at::ScalarType::Double ||
data.dtype() == at::ScalarType::Half,
"data must be of type float or double or half");
data.dtype() == at::ScalarType::Half ||
data.dtype() == at::ScalarType::BFloat16,
"data must be of type float, double, half, or bfloat16");
TORCH_CHECK(
max_length == data.sizes()[1],
"max_length should be equal to the second dimension of the packed segments");
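With BFloat16 now admitted by the dtype checks above, a kernel can cover all four accepted types through a single ATen dispatch macro. A generic sketch of that pattern (an illustrative helper, not the actual `pack_segments` kernel):

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Zero-fills a tensor of any of the four accepted floating dtypes:
// float, double, half, or bfloat16.
void fill_zero_example(at::Tensor& t) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      t.scalar_type(),
      "fill_zero_example",
      [&] {
        auto* data = t.data_ptr<scalar_t>();
        for (int64_t i = 0; i < t.numel(); ++i) {
          data[i] = static_cast<scalar_t>(0);
        }
      });
}
```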