From c42ea73c7750d781c96b9322d9b9e3aa4f0945eb Mon Sep 17 00:00:00 2001 From: Junwha Hong Date: Tue, 15 Oct 2024 06:47:56 +0900 Subject: [PATCH] fix: prevent BOF #5734 in the last channel handling Signed-off-by: Junwha Hong --- src/layer/arm/shufflechannel_arm.cpp | 51 ++++++++++++++++++++++++++-- src/layer/x86/shufflechannel_x86.cpp | 51 ++++++++++++++++++++++++++-- 2 files changed, 96 insertions(+), 6 deletions(-) diff --git a/src/layer/arm/shufflechannel_arm.cpp b/src/layer/arm/shufflechannel_arm.cpp index 571e983166fe..096b6741d17d 100644 --- a/src/layer/arm/shufflechannel_arm.cpp +++ b/src/layer/arm/shufflechannel_arm.cpp @@ -117,7 +117,7 @@ int ShuffleChannel_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Opt ptr1 += 2; - for (int i = 0; i < size; i++) + for (int i = 0; i < size - 1; i++) { float32x4_t _p0 = vld1q_f32(ptr0); float32x4_t _p1 = vld1q_f32(ptr1); @@ -130,6 +130,21 @@ int ShuffleChannel_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Opt ptr1 += 4; outptr0 += 4; } + + for (int i = 0; i < 4; i++) + { + if (i % 2) + { + *outptr0 = *ptr1; + ptr1 += 1; + } + else + { + *outptr0 = *ptr0; + ptr0 += 1; + } + outptr0 += 1; + } } return 0; @@ -364,7 +379,7 @@ int ShuffleChannel_arm::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blo ptr1 += 4; - for (int i = 0; i < size; i++) + for (int i = 0; i < size - 1; i++) { uint16x4_t _p0 = vld1_u16(ptr0); uint16x4_t _p1 = vld1_u16(ptr1); @@ -378,6 +393,21 @@ int ShuffleChannel_arm::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blo ptr1 += 8; outptr0 += 8; } + + for (int i = 0; i < 8; i++) + { + if (i % 2) + { + *outptr0 = *ptr1; + ptr1 += 1; + } + else + { + *outptr0 = *ptr0; + ptr0 += 1; + } + outptr0 += 1; + } } return 0; @@ -598,7 +628,7 @@ int ShuffleChannel_arm::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blo ptr1 += 2; - for (int i = 0; i < size; i++) + for (int i = 0; i < size - 1; i++) { uint16x4_t _p0 = vld1_u16(ptr0); uint16x4_t _p1 = vld1_u16(ptr1); @@ -611,6 +641,21 @@ int ShuffleChannel_arm::forward_bf16s_fp16s(const Mat& bottom_blob, Mat& top_blo ptr1 += 4; outptr0 += 4; } + + for (int i = 0; i < 4; i++) + { + if (i % 2) + { + *outptr0 = *ptr1; + ptr1 += 1; + } + else + { + *outptr0 = *ptr0; + ptr0 += 1; + } + outptr0 += 1; + } } return 0; diff --git a/src/layer/x86/shufflechannel_x86.cpp b/src/layer/x86/shufflechannel_x86.cpp index b1bd9b4a7de7..cf124b163905 100644 --- a/src/layer/x86/shufflechannel_x86.cpp +++ b/src/layer/x86/shufflechannel_x86.cpp @@ -116,7 +116,7 @@ int ShuffleChannel_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Opt ptr1 += 8; - for (int i = 0; i < size; i++) + for (int i = 0; i < size - 1; i++) { __m256 _p0 = _mm256_loadu_ps(ptr0); __m256 _p1 = _mm256_loadu_ps(ptr1); @@ -134,6 +134,21 @@ int ShuffleChannel_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Opt ptr1 += 16; outptr += 16; } + + for (int i = 0; i < 16; i++) + { + if (i % 2) + { + *outptr = *ptr1; + ptr1 += 1; + } + else + { + *outptr = *ptr0; + ptr0 += 1; + } + outptr += 1; + } } return 0; @@ -372,7 +387,7 @@ int ShuffleChannel_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Opt ptr1 += 4; - for (int i = 0; i < size; i++) + for (int i = 0; i < size - 1; i++) { __m128 _p0 = _mm_loadu_ps(ptr0); __m128 _p1 = _mm_loadu_ps(ptr1); @@ -387,6 +402,21 @@ int ShuffleChannel_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Opt ptr1 += 8; outptr += 8; } + + for (int i = 0; i < 8; i++) + { + if (i % 2) + { + *outptr = *ptr1; + ptr1 += 1; + } + else + { + *outptr = *ptr0; + ptr0 += 1; + } + outptr += 1; + } } return 0; @@ -607,7 +637,7 @@ int ShuffleChannel_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Opt ptr1 += 2; - for (int i = 0; i < size; i++) + for (int i = 0; i < size - 1; i++) { __m128 _p0 = _mm_loadu_ps(ptr0); __m128 _p1 = _mm_loadu_ps(ptr1); @@ -620,6 +650,21 @@ int ShuffleChannel_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Opt ptr1 += 4; outptr += 4; } + + for (int i = 0; i < 4; i++) + { + if (i % 2) + { + *outptr = *ptr1; + ptr1 += 1; + } + else + { + *outptr = *ptr0; + ptr0 += 1; + } + outptr += 1; + } } return 0;