Skip to content

Commit

Permalink
disallow vector_length=8 if AVX512 is unavailable
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
Yuuichi Asahi committed Dec 10, 2024
1 parent 83c3d71 commit 78a2066
Showing 1 changed file with 30 additions and 18 deletions.
48 changes: 30 additions & 18 deletions batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
typedef typename bViewType::value_type vector_type;
// typedef typename vector_type::value_type value_type;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

static_assert(is_vector<vector_type>::value, "value type is not vector type");
#if defined(KOKKOS_ARCH_AVX512XEON)
static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8,
"AVX, AVX2 and AVX512 is supported");
#else
static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported");
#endif
const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX;

// no error check
Expand Down Expand Up @@ -116,14 +118,16 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
typedef typename bViewType::value_type vector_type;
// typedef typename vector_type::value_type value_type;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

static_assert(is_vector<vector_type>::value, "value type is not vector type");
#if defined(KOKKOS_ARCH_AVX512XEON)
static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8,
"AVX, AVX2 and AVX512 is supported");
#else
static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported");
#endif
const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX;

// no error check
Expand Down Expand Up @@ -179,14 +183,16 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
typedef typename bViewType::value_type vector_type;
// typedef typename vector_type::value_type value_type;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

static_assert(is_vector<vector_type>::value, "value type is not vector type");
#if defined(KOKKOS_ARCH_AVX512XEON)
static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8,
"AVX, AVX2 and AVX512 is supported");
#else
static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported");
#endif
const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX;

// no error check
Expand Down Expand Up @@ -242,14 +248,16 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
typedef typename bViewType::value_type vector_type;
// typedef typename vector_type::value_type value_type;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

static_assert(is_vector<vector_type>::value, "value type is not vector type");
#if defined(KOKKOS_ARCH_AVX512XEON)
static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8,
"AVX, AVX2 and AVX512 is supported");
#else
static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported");
#endif
const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX;

// no error check
Expand Down Expand Up @@ -305,14 +313,16 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
typedef typename bViewType::value_type vector_type;
// typedef typename vector_type::value_type value_type;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

static_assert(is_vector<vector_type>::value, "value type is not vector type");
#if defined(KOKKOS_ARCH_AVX512XEON)
static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8,
"AVX, AVX2 and AVX512 is supported");
#else
static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported");
#endif
const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX;

// no error check
Expand Down Expand Up @@ -368,14 +378,16 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
typedef typename bViewType::value_type vector_type;
// typedef typename vector_type::value_type value_type;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

static_assert(is_vector<vector_type>::value, "value type is not vector type");
#if defined(KOKKOS_ARCH_AVX512XEON)
static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8,
"AVX, AVX2 and AVX512 is supported");
#else
static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported");
#endif
const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX;

// no error check
Expand Down

0 comments on commit 78a2066

Please sign in to comment.