diff --git a/python/aitemplate/backend/common/split_common.py b/python/aitemplate/backend/common/split_common.py index 694b401e4..ded02558f 100644 --- a/python/aitemplate/backend/common/split_common.py +++ b/python/aitemplate/backend/common/split_common.py @@ -386,7 +386,7 @@ throw std::runtime_error("input is NULL!"); } for (int i = 0; i < real_num_splits; i++) { - if (!outputs[i]) { + if (split_sizes[i] && !outputs[i]) { throw std::runtime_error("NULL output found at: " + std::to_string(i)); } } diff --git a/tests/unittest/ops/test_split.py b/tests/unittest/ops/test_split.py index 12daae069..41d8bac1e 100644 --- a/tests/unittest/ops/test_split.py +++ b/tests/unittest/ops/test_split.py @@ -204,6 +204,7 @@ def test_split(self): self._run_split(input_shape=[2, 0, 4], split_size_or_sections=0, dim=-2) self._run_split(input_shape=[2, 0, 4], split_size_or_sections=2, dim=-1) self._run_split(input_shape=[2, 0, 7], split_size_or_sections=[2, 3, 2], dim=-1) + self._run_split(input_shape=[32, 8], split_size_or_sections=[8, 0, 0], dim=-1) def test_split_with_mask(self): self._run_split(