Skip to content

Commit

Permalink
Added more sanity checks with verbose messages to the slice_scatter k…
Browse files Browse the repository at this point in the history
…ernel
  • Loading branch information
chenyang78 committed Jul 12, 2023
1 parent ecb705a commit 85ec51d
Showing 1 changed file with 54 additions and 9 deletions.
63 changes: 54 additions & 9 deletions python/aitemplate/backend/common/tensor/slice_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,15 +348,36 @@
slice_meta_data.num_elems[input_idx] = 1;
for ({{index_type}} i = 0; i < Rank; i++) {
assert(slice_start_indices[i] >= 0 &&
slice_start_indices[i] <= input_shape[i]);
assert(slice_end_indices[i] >= 0 && slice_end_indices[i] <= input_shape[i]);
assert(slice_start_indices[i] <= slice_end_indices[i]);
slice_meta_data.num_elems[input_idx] *=
slice_end_indices[i] - slice_start_indices[i];
slice_meta_data.slice_start_indices[input_idx][i] = slice_start_indices[i];
slice_meta_data.slice_end_indices[input_idx][i] = slice_end_indices[i];
int64_t slice_start_idx = slice_start_indices[i];
int64_t slice_end_idx = slice_end_indices[i];
int64_t input_dim = input_shape[i];
if (!(slice_start_idx >= 0 && slice_start_idx <= input_dim)) {
throw std::runtime_error("invalid slice_start_idx: " +
std::to_string(slice_start_idx) +
", input_dim: " +
std::to_string(input_dim) +
", i: " + std::to_string(i));
}
if (!(slice_end_idx >= 0 && slice_end_idx <= input_dim)) {
throw std::runtime_error("invalid slice_end_idx: " +
std::to_string(slice_end_idx) +
", input_dim: " +
std::to_string(input_dim) +
", i: " + std::to_string(i));
}
if (slice_start_idx > slice_end_idx) {
throw std::runtime_error(
"expected slice_start_idx <= slice_end_idx but got slice_start_idx: " +
std::to_string(slice_start_idx) +
" and slice_end_idx: " +
std::to_string(slice_end_idx) +
", i: " + std::to_string(i));
}
slice_meta_data.num_elems[input_idx] *= slice_end_idx - slice_start_idx;
slice_meta_data.slice_start_indices[input_idx][i] = slice_start_idx;
slice_meta_data.slice_end_indices[input_idx][i] = slice_end_idx;
}
slice_meta_data.dim_sizes[input_idx] =
Expand All @@ -383,10 +404,20 @@
// meta data for placing sliced output
scatter_meta_data.output_strides[Rank-1] = 1;
if (output_shape[Rank-1] < 0) {
throw std::runtime_error("invalid output_shape[Rank-1]: " +
std::to_string(output_shape[Rank-1]) +
", Rank: " + std::to_string(Rank));
}
scatter_meta_data.output_shape[Rank-1] = output_shape[Rank-1];
for ({{index_type}} i = Rank - 2; i >= 0; i--) {
scatter_meta_data.output_strides[i] =
scatter_meta_data.output_strides[i+1] * output_shape[i+1];
if (output_shape[i] < 0) {
throw std::runtime_error("invalid output_shape[i]: " +
std::to_string(output_shape[i]) +
", i: " + std::to_string(i));
}
scatter_meta_data.output_shape[i] = output_shape[i];
}
Expand Down Expand Up @@ -423,6 +454,11 @@
}
}
if (max_num_elems <= 0) {
throw std::runtime_error("invalid max_num_elems: " +
std::to_string(max_num_elems));
}
{{index_type}} m = max_num_elems % (ThreadsPerBlock * ElemsPerThread) != 0;
{{index_type}} num_blocks_x =
(max_num_elems / (ThreadsPerBlock * ElemsPerThread)) + m;
Expand Down Expand Up @@ -465,6 +501,12 @@
std::vector<int64_t> slice_start_indices(rank);
std::vector<int64_t> slice_end_indices(rank);
for ({{index_type}} i = 0; i < rank; i++) {
if (input_shape[i] < 0) {
throw std::runtime_error("invalid input_shape: " +
std::to_string(input_shape[i]) +
", i: " +
std::to_string(i));
}
slice_start_indices[i] = orig_slice_start_indices[i] < 0 ?
input_shape[i] + orig_slice_start_indices[i]:
orig_slice_start_indices[i];
Expand Down Expand Up @@ -549,6 +591,9 @@
if (scatter_dim >= rank) {
throw std::runtime_error("scatter_dim must < rank!");
}
if (num_inputs < 1) {
throw std::runtime_error("num_inputs must be larger than 0!");
}
// clip slip start and end indices
std::vector<std::vector<int64_t>> slice_start_indices(num_inputs);
Expand Down

0 comments on commit 85ec51d

Please sign in to comment.