From 17e20b449c09be9941315a78703d72f9383070e0 Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Fri, 14 Jul 2023 09:54:58 -0700 Subject: [PATCH] Added more sanity checks with verbose messages to the slice_scatter kernel (#825) Summary: This would help debug runtime failures. Pull Request resolved: https://github.com/facebookincubator/AITemplate/pull/825 Reviewed By: yinghai, wushirong Differential Revision: D47395920 Pulled By: chenyang78 fbshipit-source-id: 434498fc5688fa74d8a60a1c6ef15b35b84e9f98 --- .../backend/common/tensor/slice_common.py | 68 ++++++++++++++++--- tests/unittest/ops/test_slice.py | 11 +++ 2 files changed, 70 insertions(+), 9 deletions(-) diff --git a/python/aitemplate/backend/common/tensor/slice_common.py b/python/aitemplate/backend/common/tensor/slice_common.py index a20f7dd3e..49961916a 100644 --- a/python/aitemplate/backend/common/tensor/slice_common.py +++ b/python/aitemplate/backend/common/tensor/slice_common.py @@ -291,6 +291,11 @@ break; } } + // We have a full slice for the entire input + if (flatten_index == -1) { + flatten_index = 0; + } + int64_t input_start_offset = compute_input_linear_index(input_strides, slice_start_indices, @@ -348,15 +353,36 @@ slice_meta_data.num_elems[input_idx] = 1; for ({{index_type}} i = 0; i < Rank; i++) { - assert(slice_start_indices[i] >= 0 && - slice_start_indices[i] <= input_shape[i]); - assert(slice_end_indices[i] >= 0 && slice_end_indices[i] <= input_shape[i]); - assert(slice_start_indices[i] <= slice_end_indices[i]); - - slice_meta_data.num_elems[input_idx] *= - slice_end_indices[i] - slice_start_indices[i]; - slice_meta_data.slice_start_indices[input_idx][i] = slice_start_indices[i]; - slice_meta_data.slice_end_indices[input_idx][i] = slice_end_indices[i]; + int64_t slice_start_idx = slice_start_indices[i]; + int64_t slice_end_idx = slice_end_indices[i]; + int64_t input_dim = input_shape[i]; + + if (!(slice_start_idx >= 0 && slice_start_idx <= input_dim)) { + throw std::runtime_error("invalid slice_start_idx: " + + std::to_string(slice_start_idx) + + ", input_dim: " + + std::to_string(input_dim) + + ", i: " + std::to_string(i)); + } + if (!(slice_end_idx >= 0 && slice_end_idx <= input_dim)) { + throw std::runtime_error("invalid slice_end_idx: " + + std::to_string(slice_end_idx) + + ", input_dim: " + + std::to_string(input_dim) + + ", i: " + std::to_string(i)); + } + if (slice_start_idx > slice_end_idx) { + throw std::runtime_error( + "expected slice_start_idx <= slice_end_idx but got slice_start_idx: " + + std::to_string(slice_start_idx) + + " and slice_end_idx: " + + std::to_string(slice_end_idx) + + ", i: " + std::to_string(i)); + } + + slice_meta_data.num_elems[input_idx] *= slice_end_idx - slice_start_idx; + slice_meta_data.slice_start_indices[input_idx][i] = slice_start_idx; + slice_meta_data.slice_end_indices[input_idx][i] = slice_end_idx; } slice_meta_data.dim_sizes[input_idx] = @@ -383,10 +409,20 @@ // meta data for placing sliced output scatter_meta_data.output_strides[Rank-1] = 1; + if (output_shape[Rank-1] < 0) { + throw std::runtime_error("invalid output_shape[Rank-1]: " + + std::to_string(output_shape[Rank-1]) + + ", Rank: " + std::to_string(Rank)); + } scatter_meta_data.output_shape[Rank-1] = output_shape[Rank-1]; for ({{index_type}} i = Rank - 2; i >= 0; i--) { scatter_meta_data.output_strides[i] = scatter_meta_data.output_strides[i+1] * output_shape[i+1]; + if (output_shape[i] < 0) { + throw std::runtime_error("invalid output_shape[i]: " + + std::to_string(output_shape[i]) + + ", i: " + std::to_string(i)); + } scatter_meta_data.output_shape[i] = output_shape[i]; } @@ -423,6 +459,11 @@ } } + if (max_num_elems <= 0) { + throw std::runtime_error("invalid max_num_elems: " + + std::to_string(max_num_elems)); + } + {{index_type}} m = max_num_elems % (ThreadsPerBlock * ElemsPerThread) != 0; {{index_type}} num_blocks_x = (max_num_elems / (ThreadsPerBlock * ElemsPerThread)) + m; @@ -465,6 +506,12 @@ std::vector slice_start_indices(rank); std::vector slice_end_indices(rank); for ({{index_type}} i = 0; i < rank; i++) { + if (input_shape[i] < 0) { + throw std::runtime_error("invalid input_shape: " + + std::to_string(input_shape[i]) + + ", i: " + + std::to_string(i)); + } slice_start_indices[i] = orig_slice_start_indices[i] < 0 ? input_shape[i] + orig_slice_start_indices[i]: orig_slice_start_indices[i]; @@ -549,6 +596,9 @@ if (scatter_dim >= rank) { throw std::runtime_error("scatter_dim must < rank!"); } + if (num_inputs < 1) { + throw std::runtime_error("num_inputs must be larger than 0!"); + } // clip slip start and end indices std::vector> slice_start_indices(num_inputs); diff --git a/tests/unittest/ops/test_slice.py b/tests/unittest/ops/test_slice.py index b3d5789c3..09ebb3775 100644 --- a/tests/unittest/ops/test_slice.py +++ b/tests/unittest/ops/test_slice.py @@ -75,6 +75,11 @@ def _run_dynamic_slice( self.test_count += 1 def test_dynamic_slice(self): + self._run_dynamic_slice( + input_shape=[10, 13], + start_indices=[None, None], + end_indices=[None, None], + ) self._run_dynamic_slice( input_shape=[1], start_indices=[0], @@ -224,6 +229,12 @@ def _run_batch_dynamic_slice( self.test_count += 1 def test_batch_dynamic_slice(self): + self._run_batch_dynamic_slice( + batch_sizes=[5, 20], + input_shape=[2, 3, 4], + start_indices=[None, None, None, None], + end_indices=[None, None, None, None], + ) self._run_batch_dynamic_slice( batch_sizes=[1, 1], input_shape=[1],