diff --git a/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp index 724d8ecd2a..c542294657 100644 --- a/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp @@ -53,14 +53,16 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - typedef typename bViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - + using vector_type = typename bViewType::value_type; const int m = b.extent(0), n = 1; static_assert(is_vector::value, "value type is not vector type"); +#if defined(KOKKOS_ARCH_AVX512XEON) static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, "AVX, AVX2 and AVX512 is supported"); +#else + static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported"); +#endif const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; // no error check @@ -116,14 +118,16 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - typedef typename bViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - + using vector_type = typename bViewType::value_type; const int m = b.extent(0), n = 1; static_assert(is_vector::value, "value type is not vector type"); +#if defined(KOKKOS_ARCH_AVX512XEON) static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, "AVX, AVX2 and AVX512 is supported"); +#else + static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported"); +#endif const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; // no error check @@ -179,14 +183,16 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - typedef typename bViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - + using vector_type = typename bViewType::value_type; const int m = b.extent(0), n = 1; static_assert(is_vector::value, "value type is not vector type"); +#if defined(KOKKOS_ARCH_AVX512XEON) static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, "AVX, AVX2 and AVX512 is supported"); +#else + static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported"); +#endif const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; // no error check @@ -242,14 +248,16 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - typedef typename bViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - + using vector_type = typename bViewType::value_type; const int m = b.extent(0), n = 1; static_assert(is_vector::value, "value type is not vector type"); +#if defined(KOKKOS_ARCH_AVX512XEON) static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, "AVX, AVX2 and AVX512 is supported"); +#else + static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported"); +#endif const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; // no error check @@ -305,14 +313,16 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - typedef typename bViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - + using vector_type = typename bViewType::value_type; const int m = b.extent(0), n = 1; static_assert(is_vector::value, "value type is not vector type"); +#if defined(KOKKOS_ARCH_AVX512XEON) static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, "AVX, AVX2 and AVX512 is supported"); +#else + static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported"); +#endif const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; // no error check @@ -368,14 +378,16 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - typedef typename bViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - + using vector_type = typename bViewType::value_type; const int m = b.extent(0), n = 1; static_assert(is_vector::value, "value type is not vector type"); +#if defined(KOKKOS_ARCH_AVX512XEON) static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, "AVX, AVX2 and AVX512 is supported"); +#else + static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported"); +#endif const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; // no error check