Skip to content

Commit

Permalink
Added more sanity checks with verbose messages to the slice_scatter k…
Browse files Browse the repository at this point in the history
…ernel (facebookincubator#825)

Summary:
This would help debug runtime failures.

Pull Request resolved: facebookincubator#825

Reviewed By: yinghai, wushirong

Differential Revision: D47395920

Pulled By: chenyang78

fbshipit-source-id: 434498fc5688fa74d8a60a1c6ef15b35b84e9f98
  • Loading branch information
chenyang78 authored and facebook-github-bot committed Jul 14, 2023
1 parent b779b01 commit 17e20b4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 9 deletions.
68 changes: 59 additions & 9 deletions python/aitemplate/backend/common/tensor/slice_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,11 @@
break;
}
}
// We have a full slice for the entire input
if (flatten_index == -1) {
flatten_index = 0;
}
int64_t input_start_offset =
compute_input_linear_index<Rank>(input_strides,
slice_start_indices,
Expand Down Expand Up @@ -348,15 +353,36 @@
slice_meta_data.num_elems[input_idx] = 1;
for ({{index_type}} i = 0; i < Rank; i++) {
assert(slice_start_indices[i] >= 0 &&
slice_start_indices[i] <= input_shape[i]);
assert(slice_end_indices[i] >= 0 && slice_end_indices[i] <= input_shape[i]);
assert(slice_start_indices[i] <= slice_end_indices[i]);
slice_meta_data.num_elems[input_idx] *=
slice_end_indices[i] - slice_start_indices[i];
slice_meta_data.slice_start_indices[input_idx][i] = slice_start_indices[i];
slice_meta_data.slice_end_indices[input_idx][i] = slice_end_indices[i];
int64_t slice_start_idx = slice_start_indices[i];
int64_t slice_end_idx = slice_end_indices[i];
int64_t input_dim = input_shape[i];
if (!(slice_start_idx >= 0 && slice_start_idx <= input_dim)) {
throw std::runtime_error("invalid slice_start_idx: " +
std::to_string(slice_start_idx) +
", input_dim: " +
std::to_string(input_dim) +
", i: " + std::to_string(i));
}
if (!(slice_end_idx >= 0 && slice_end_idx <= input_dim)) {
throw std::runtime_error("invalid slice_end_idx: " +
std::to_string(slice_end_idx) +
", input_dim: " +
std::to_string(input_dim) +
", i: " + std::to_string(i));
}
if (slice_start_idx > slice_end_idx) {
throw std::runtime_error(
"expected slice_start_idx <= slice_end_idx but got slice_start_idx: " +
std::to_string(slice_start_idx) +
" and slice_end_idx: " +
std::to_string(slice_end_idx) +
", i: " + std::to_string(i));
}
slice_meta_data.num_elems[input_idx] *= slice_end_idx - slice_start_idx;
slice_meta_data.slice_start_indices[input_idx][i] = slice_start_idx;
slice_meta_data.slice_end_indices[input_idx][i] = slice_end_idx;
}
slice_meta_data.dim_sizes[input_idx] =
Expand All @@ -383,10 +409,20 @@
// meta data for placing sliced output
scatter_meta_data.output_strides[Rank-1] = 1;
if (output_shape[Rank-1] < 0) {
throw std::runtime_error("invalid output_shape[Rank-1]: " +
std::to_string(output_shape[Rank-1]) +
", Rank: " + std::to_string(Rank));
}
scatter_meta_data.output_shape[Rank-1] = output_shape[Rank-1];
for ({{index_type}} i = Rank - 2; i >= 0; i--) {
scatter_meta_data.output_strides[i] =
scatter_meta_data.output_strides[i+1] * output_shape[i+1];
if (output_shape[i] < 0) {
throw std::runtime_error("invalid output_shape[i]: " +
std::to_string(output_shape[i]) +
", i: " + std::to_string(i));
}
scatter_meta_data.output_shape[i] = output_shape[i];
}
Expand Down Expand Up @@ -423,6 +459,11 @@
}
}
if (max_num_elems <= 0) {
throw std::runtime_error("invalid max_num_elems: " +
std::to_string(max_num_elems));
}
{{index_type}} m = max_num_elems % (ThreadsPerBlock * ElemsPerThread) != 0;
{{index_type}} num_blocks_x =
(max_num_elems / (ThreadsPerBlock * ElemsPerThread)) + m;
Expand Down Expand Up @@ -465,6 +506,12 @@
std::vector<int64_t> slice_start_indices(rank);
std::vector<int64_t> slice_end_indices(rank);
for ({{index_type}} i = 0; i < rank; i++) {
if (input_shape[i] < 0) {
throw std::runtime_error("invalid input_shape: " +
std::to_string(input_shape[i]) +
", i: " +
std::to_string(i));
}
slice_start_indices[i] = orig_slice_start_indices[i] < 0 ?
input_shape[i] + orig_slice_start_indices[i]:
orig_slice_start_indices[i];
Expand Down Expand Up @@ -549,6 +596,9 @@
if (scatter_dim >= rank) {
throw std::runtime_error("scatter_dim must < rank!");
}
if (num_inputs < 1) {
throw std::runtime_error("num_inputs must be larger than 0!");
}
// clip slip start and end indices
std::vector<std::vector<int64_t>> slice_start_indices(num_inputs);
Expand Down
11 changes: 11 additions & 0 deletions tests/unittest/ops/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ def _run_dynamic_slice(
self.test_count += 1

def test_dynamic_slice(self):
self._run_dynamic_slice(
input_shape=[10, 13],
start_indices=[None, None],
end_indices=[None, None],
)
self._run_dynamic_slice(
input_shape=[1],
start_indices=[0],
Expand Down Expand Up @@ -224,6 +229,12 @@ def _run_batch_dynamic_slice(
self.test_count += 1

def test_batch_dynamic_slice(self):
self._run_batch_dynamic_slice(
batch_sizes=[5, 20],
input_shape=[2, 3, 4],
start_indices=[None, None, None, None],
end_indices=[None, None, None, None],
)
self._run_batch_dynamic_slice(
batch_sizes=[1, 1],
input_shape=[1],
Expand Down

0 comments on commit 17e20b4

Please sign in to comment.