From 9accffdcab6e0266f5e926a9a7c8779fc8c0dbe9 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Fri, 6 Dec 2024 22:28:02 +0900 Subject: [PATCH 1/8] Add missing specialization of Serial Trsv Signed-off-by: Yuuichi Asahi --- .../impl/KokkosBatched_Trsv_Serial_Impl.hpp | 251 +++++++++++++++--- .../KokkosBatched_Trsv_Serial_Internal.hpp | 81 +++--- batched/dense/src/KokkosBatched_Trsv_Decl.hpp | 1 + 3 files changed, 253 insertions(+), 80 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp index d7db47375a..87d32258b8 100644 --- a/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp @@ -17,27 +17,37 @@ #define KOKKOSBATCHED_TRSV_SERIAL_IMPL_HPP /// \author Kyungjoo Kim (kyukim@sandia.gov) +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) #include "KokkosBatched_Util.hpp" #include "KokkosBatched_Trsv_Serial_Internal.hpp" namespace KokkosBatched { +namespace Impl { +template +KOKKOS_INLINE_FUNCTION static int checkTrsvInput([[maybe_unused]] const AViewType &A, + [[maybe_unused]] const bViewType &b) { +#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) + static_assert(Kokkos::is_view_v, "KokkosBatched::trsv: AViewType is not a Kokkos::View."); + static_assert(Kokkos::is_view_v, "KokkosBatched::trsv: bViewType is not a Kokkos::View."); + static_assert(AViewType::rank == 2, "KokkosBatched::trsv: AViewType must have rank 2."); + static_assert(bViewType::rank == 1, "KokkosBatched::trsv: bViewType must have rank 1."); + + const int lda = A.extent(0), n = A.extent(1); + if (lda < Kokkos::max(1, n)) { + Kokkos::printf( + "KokkosBatched::trsv: leading dimension of A must not be smaller than " + "max(1, n): " + "lda = %d, n = %d\n", + lda, n); + return 1; + } +#endif + return 0; +} +} // namespace Impl -/// -/// Serial Impl -/// =========== - -/// -/// Implemented: -/// L/NT, U/NT, L/T, U/T -/// -/// Not yet implemented -/// L/CT, U/CT - -/// -/// L/NT -/// - +//// Lower non-transpose //// #if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) template @@ -76,8 +86,13 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - return SerialTrsvInternalLower::invoke(ArgDiag::use_unit_diag, A.extent(0), alpha, A.data(), - A.stride_0(), A.stride_1(), b.data(), b.stride_0()); + // Quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsvInput(A, b); + if (info) return info; + return KokkosBatched::Impl::SerialTrsvInternalLower::invoke( + ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0)); } }; @@ -85,15 +100,17 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - return SerialTrsvInternalLower::invoke(ArgDiag::use_unit_diag, A.extent(0), alpha, A.data(), - A.stride_0(), A.stride_1(), b.data(), b.stride_0()); + // Quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsvInput(A, b); + if (info) return info; + return KokkosBatched::Impl::SerialTrsvInternalLower::invoke( + ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0)); } }; -/// -/// L/T -/// - +//// Lower transpose //// #if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) template @@ -132,8 +149,13 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - return SerialTrsvInternalUpper::invoke(ArgDiag::use_unit_diag, A.extent(1), alpha, A.data(), - A.stride_1(), A.stride_0(), b.data(), b.stride_0()); + // Quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsvInput(A, b); + if (info) return info; + return KokkosBatched::Impl::SerialTrsvInternalUpper::invoke( + ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); } }; @@ -141,15 +163,80 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - return SerialTrsvInternalUpper::invoke(ArgDiag::use_unit_diag, A.extent(1), alpha, A.data(), - A.stride_1(), A.stride_0(), b.data(), b.stride_0()); + // Quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsvInput(A, b); + if (info) return info; + return KokkosBatched::Impl::SerialTrsvInternalUpper::invoke( + ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); + } +}; + +//// Lower conjugate-transpose //// +#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ + defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) +template +struct SerialTrsv { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { + typedef typename bViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + const int m = b.extent(0), n = 1; + + static_assert(is_vector::value, "value type is not vector type"); + static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1) { + mkl_dtrsm_compact(MKL_COL_MAJOR, MKL_LEFT, MKL_LOWER, MKL_CONJTRANS, + ArgDiag::use_unit_diag ? MKL_UNIT : MKL_NONUNIT, m, n, alpha, (const double *)A.data(), + A.stride_0(), (double *)b.data(), b.stride_0(), format, (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1) { + mkl_dtrsm_compact(MKL_ROW_MAJOR, MKL_LEFT, MKL_LOWER, MKL_CONJTRANS, + ArgDiag::use_unit_diag ? MKL_UNIT : MKL_NONUNIT, m, n, alpha, (const double *)A.data(), + A.stride_0(), (double *)b.data(), b.stride_0(), format, (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; + } +}; +#endif + +template +struct SerialTrsv { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { + // Quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsvInput(A, b); + if (info) return info; + return KokkosBatched::Impl::SerialTrsvInternalUpper::invoke( + ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); } }; -/// -/// U/NT -/// +template +struct SerialTrsv { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { + // Quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsvInput(A, b); + if (info) return info; + return KokkosBatched::Impl::SerialTrsvInternalUpper::invoke( + ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); + } +}; +//// Upper non-transpose //// #if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) template @@ -188,8 +275,13 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - return SerialTrsvInternalUpper::invoke(ArgDiag::use_unit_diag, A.extent(0), alpha, A.data(), - A.stride_0(), A.stride_1(), b.data(), b.stride_0()); + // Quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsvInput(A, b); + if (info) return info; + return KokkosBatched::Impl::SerialTrsvInternalUpper::invoke( + ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0)); } }; @@ -197,15 +289,17 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - return SerialTrsvInternalUpper::invoke(ArgDiag::use_unit_diag, A.extent(0), alpha, A.data(), - A.stride_0(), A.stride_1(), b.data(), b.stride_0()); + // Quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsvInput(A, b); + if (info) return info; + return KokkosBatched::Impl::SerialTrsvInternalUpper::invoke( + ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0)); } }; -/// -/// U/T -/// - +//// Upper transpose //// #if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) template @@ -244,8 +338,13 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - return SerialTrsvInternalLower::invoke(ArgDiag::use_unit_diag, A.extent(1), alpha, A.data(), - A.stride_1(), A.stride_0(), b.data(), b.stride_0()); + // Quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsvInput(A, b); + if (info) return info; + return KokkosBatched::Impl::SerialTrsvInternalLower::invoke( + ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); } }; @@ -253,8 +352,76 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - return SerialTrsvInternalLower::invoke(ArgDiag::use_unit_diag, A.extent(1), alpha, A.data(), - A.stride_1(), A.stride_0(), b.data(), b.stride_0()); + // Quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsvInput(A, b); + if (info) return info; + return KokkosBatched::Impl::SerialTrsvInternalLower::invoke( + ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); + } +}; + +//// Upper conjugate-transpose //// +#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \ + defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__) +template +struct SerialTrsv { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { + typedef typename bViewType::value_type vector_type; + // typedef typename vector_type::value_type value_type; + + const int m = b.extent(0), n = 1; + + static_assert(is_vector::value, "value type is not vector type"); + static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, + "AVX, AVX2 and AVX512 is supported"); + const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; + + // no error check + int r_val = 0; + if (A.stride_0() == 1) { + mkl_dtrsm_compact(MKL_COL_MAJOR, MKL_LEFT, MKL_UPPER, MKL_CONJTRANS, + ArgDiag::use_unit_diag ? MKL_UNIT : MKL_NONUNIT, m, n, alpha, (const double *)A.data(), + A.stride_0(), (double *)b.data(), b.stride_0(), format, (MKL_INT)vector_type::vector_length); + } else if (A.stride_1() == 1) { + mkl_dtrsm_compact(MKL_ROW_MAJOR, MKL_LEFT, MKL_UPPER, MKL_CONJTRANS, + ArgDiag::use_unit_diag ? MKL_UNIT : MKL_NONUNIT, m, n, alpha, (const double *)A.data(), + A.stride_0(), (double *)b.data(), b.stride_0(), format, (MKL_INT)vector_type::vector_length); + } else { + r_val = -1; + } + return r_val; + } +}; +#endif + +template +struct SerialTrsv { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { + // Quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsvInput(A, b); + if (info) return info; + return KokkosBatched::Impl::SerialTrsvInternalLower::invoke( + ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); + } +}; + +template +struct SerialTrsv { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { + // Quick return if possible + if (A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkTrsvInput(A, b); + if (info) return info; + return KokkosBatched::Impl::SerialTrsvInternalLower::invoke( + ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0)); } }; diff --git a/batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp index 861c72eec9..424f1680e1 100644 --- a/batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp @@ -17,6 +17,7 @@ #define KOKKOSBATCHED_TRSV_SERIAL_INTERNAL_HPP /// \author Kyungjoo Kim (kyukim@sandia.gov) +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) #include "KokkosBatched_Util.hpp" @@ -26,6 +27,7 @@ #include "KokkosBlas2_serial_gemv_internal.hpp" namespace KokkosBatched { +namespace Impl { /// /// Serial Internal Impl @@ -38,19 +40,18 @@ namespace KokkosBatched { template struct SerialTrsvInternalLower { template - KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const int m, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int m, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, /**/ ValueType *KOKKOS_RESTRICT b, const int bs0); }; template <> template -KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke(const bool use_unit_diag, const int m, - const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const int as0, const int as1, - /**/ ValueType *KOKKOS_RESTRICT b, - const int bs0) { +KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke( + const bool use_unit_diag, const bool do_conj, const int m, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + /**/ ValueType *KOKKOS_RESTRICT b, const int bs0) { const ScalarType one(1.0), zero(0.0); if (alpha == zero) @@ -62,16 +63,20 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invok for (int p = 0; p < m; ++p) { const int iend = m - p - 1; - const ValueType *KOKKOS_RESTRICT a21 = iend ? A + (p + 1) * as0 + p * as1 : NULL; + const ValueType *KOKKOS_RESTRICT a21 = iend ? A + (p + 1) * as0 + p * as1 : nullptr; - ValueType *KOKKOS_RESTRICT beta1 = b + p * bs0, *KOKKOS_RESTRICT b2 = iend ? beta1 + bs0 : NULL; + ValueType *KOKKOS_RESTRICT beta1 = b + p * bs0, *KOKKOS_RESTRICT b2 = iend ? beta1 + bs0 : nullptr; // with KOKKOS_RESTRICT a compiler assumes that the pointer is not // accessed by others op(/=) uses this pointer and changes the associated // values, which brings a compiler problem - if (!use_unit_diag) *beta1 = *beta1 / A[p * as0 + p * as1]; + if (!use_unit_diag) + *beta1 = (do_conj ? *beta1 / Kokkos::ArithTraits::conj(A[p * as0 + p * as1]) + : *beta1 / A[p * as0 + p * as1]); - for (int i = 0; i < iend; ++i) b2[i * bs0] -= a21[i * as0] * (*beta1); + for (int i = 0; i < iend; ++i) + b2[i * bs0] -= + (do_conj ? Kokkos::ArithTraits::conj(a21[i * as0]) * (*beta1) : a21[i * as0] * (*beta1)); } } return 0; @@ -79,12 +84,10 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invok template <> template -KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke(const bool use_unit_diag, const int m, - const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const int as0, const int as1, - /**/ ValueType *KOKKOS_RESTRICT b, - const int bs0) { +KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke( + const bool use_unit_diag, const bool /*do_conj*/, const int m, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + /**/ ValueType *KOKKOS_RESTRICT b, const int bs0) { const ScalarType one(1.0), zero(0.0), minus_one(-1.0); constexpr int mbAlgo = Algo::Trsv::Blocked::mb(); @@ -105,7 +108,7 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke( // trsm update const ValueType *KOKKOS_RESTRICT Ap = A + p * as0 + p * as1; - /**/ ValueType *KOKKOS_RESTRICT bp = b + p * bs0; + /**/ ValueType *KOKKOS_RESTRICT bp = b + p * bs0; if (use_unit_diag) trsm_u.serial_invoke(Ap, pb, 1, bp); @@ -127,19 +130,18 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke( template struct SerialTrsvInternalUpper { template - KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const int m, const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int m, + const ScalarType alpha, const ValueType *KOKKOS_RESTRICT A, const int as0, + const int as1, /**/ ValueType *KOKKOS_RESTRICT b, const int bs0); }; template <> template -KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invoke(const bool use_unit_diag, const int m, - const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const int as0, const int as1, - /**/ ValueType *KOKKOS_RESTRICT b, - const int bs0) { +KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invoke( + const bool use_unit_diag, const bool do_conj, const int m, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + /**/ ValueType *KOKKOS_RESTRICT b, const int bs0) { const ScalarType one(1.0), zero(0.0); if (alpha == zero) @@ -152,15 +154,19 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invok for (int p = (m - 1); p >= 0; --p) { const int iend = p; - const ValueType *KOKKOS_RESTRICT a01 = A + p * as1; - /**/ ValueType *KOKKOS_RESTRICT beta1 = b + p * bs0; + const ValueType *KOKKOS_RESTRICT a01 = A + p * as1; + /**/ ValueType *KOKKOS_RESTRICT beta1 = b + p * bs0; // with KOKKOS_RESTRICT a compiler assumes that the pointer is not // accessed by others op(/=) uses this pointer and changes the associated // values, which brings a compiler problem - if (!use_unit_diag) *beta1 = *beta1 / A[p * as0 + p * as1]; + if (!use_unit_diag) + *beta1 = (do_conj ? *beta1 / Kokkos::ArithTraits::conj(A[p * as0 + p * as1]) + : *beta1 / A[p * as0 + p * as1]); - for (int i = 0; i < iend; ++i) b0[i * bs0] -= a01[i * as0] * (*beta1); + for (int i = 0; i < iend; ++i) + b0[i * bs0] -= + (do_conj ? Kokkos::ArithTraits::conj(a01[i * as0]) * (*beta1) : a01[i * as0] * (*beta1)); } } return 0; @@ -168,12 +174,10 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invok template <> template -KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invoke(const bool use_unit_diag, const int m, - const ScalarType alpha, - const ValueType *KOKKOS_RESTRICT A, - const int as0, const int as1, - /**/ ValueType *KOKKOS_RESTRICT b, - const int bs0) { +KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invoke( + const bool use_unit_diag, const bool /*do_conj*/, const int m, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1, + /**/ ValueType *KOKKOS_RESTRICT b, const int bs0) { const ScalarType one(1.0), zero(0.0), minus_one(-1.0); constexpr int mbAlgo = Algo::Trsm::Blocked::mb(); @@ -194,7 +198,7 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invoke( // trsm update const ValueType *KOKKOS_RESTRICT Ap = A + p * as0 + p * as1; - /**/ ValueType *KOKKOS_RESTRICT bp = b + p * bs0; + /**/ ValueType *KOKKOS_RESTRICT bp = b + p * bs0; if (use_unit_diag) trsm_u.serial_invoke(Ap, pb, 1, bp); @@ -209,6 +213,7 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invoke( return 0; } +} // namespace Impl } // namespace KokkosBatched #endif diff --git a/batched/dense/src/KokkosBatched_Trsv_Decl.hpp b/batched/dense/src/KokkosBatched_Trsv_Decl.hpp index d711085c9b..74a5f11a8f 100644 --- a/batched/dense/src/KokkosBatched_Trsv_Decl.hpp +++ b/batched/dense/src/KokkosBatched_Trsv_Decl.hpp @@ -17,6 +17,7 @@ #define KOKKOSBATCHED_TRSV_DECL_HPP /// \author Kyungjoo Kim (kyukim@sandia.gov) +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) #include "KokkosBatched_Util.hpp" #include "KokkosBatched_Vector.hpp" From 7507db4b12051e3eb19d8ff452d5ae3d45d73071 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Fri, 6 Dec 2024 22:29:56 +0900 Subject: [PATCH 2/8] Improve testing of serial trsv Signed-off-by: Yuuichi Asahi --- .../unit_test/Test_Batched_SerialTrsv.hpp | 435 ++++++++++++++---- .../Test_Batched_SerialTrsv_Complex.hpp | 140 ++++-- .../Test_Batched_SerialTrsv_Real.hpp | 72 ++- 3 files changed, 501 insertions(+), 146 deletions(-) diff --git a/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp b/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp index c0a7de9e99..ddc965f1d4 100644 --- a/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp @@ -14,15 +14,16 @@ // //@HEADER /// \author Kyungjoo Kim (kyukim@sandia.gov) +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) #include "gtest/gtest.h" #include "Kokkos_Core.hpp" #include "Kokkos_Random.hpp" #include "KokkosBatched_Util.hpp" +#include "KokkosBlas2_gemv.hpp" #include "KokkosBatched_Trsv_Decl.hpp" - -// #include "KokkosKernels_TestUtils.hpp" +#include "Test_Batched_DenseUtils.hpp" using namespace KokkosBatched; @@ -31,57 +32,94 @@ namespace Trsv { template struct ParamTag { - typedef U uplo; - typedef T trans; - typedef D diag; + using uplo = U; + using trans = T; + using diag = D; }; -template -struct Functor_TestBatchedSerialTrsv { +template +struct Functor_BatchedSerialTrsv { using execution_space = typename DeviceType::execution_space; - ViewType _a, _b; + AViewType m_a; + BViewType m_b; - ScalarType _alpha; + ScalarType m_alpha; KOKKOS_INLINE_FUNCTION - Functor_TestBatchedSerialTrsv(const ScalarType alpha, const ViewType &a, const ViewType &b) - : _a(a), _b(b), _alpha(alpha) {} + Functor_BatchedSerialTrsv(const ScalarType alpha, const AViewType &a, const BViewType &b) + : m_a(a), m_b(b), m_alpha(alpha) {} KOKKOS_INLINE_FUNCTION void operator()(const ParamTagType &, const int k) const { - auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); - auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), 0); + auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto bb = Kokkos::subview(m_b, k, Kokkos::ALL()); SerialTrsv::invoke(_alpha, aa, bb); + AlgoTagType>::invoke(m_alpha, aa, bb); } inline void run() { - typedef typename ViewType::value_type value_type; + using value_type = typename AViewType::non_const_value_type; std::string name_region("KokkosBatched::Test::SerialTrsv"); const std::string name_value_type = Test::value_type_name(); std::string name = name_region + name_value_type; Kokkos::Profiling::pushRegion(name.c_str()); - Kokkos::RangePolicy policy(0, _b.extent(0)); + Kokkos::RangePolicy policy(0, m_b.extent(0)); Kokkos::parallel_for(name.c_str(), policy, *this); Kokkos::Profiling::popRegion(); } }; -template -void impl_test_batched_trsv(const int N, const int BlkSize) { - typedef typename ViewType::value_type value_type; - typedef Kokkos::ArithTraits ats; +template +struct Functor_BatchedSerialGemv { + using execution_space = typename DeviceType::execution_space; + AViewType m_a; + xViewType m_x; + yViewType m_y; + ScalarType m_alpha, m_beta; + + KOKKOS_INLINE_FUNCTION + Functor_BatchedSerialGemv(const ScalarType alpha, const AViewType &a, const xViewType &x, const ScalarType beta, + const yViewType &y) + : m_a(a), m_x(x), m_y(y), m_alpha(alpha), m_beta(beta) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const ParamTagType &, const int k) const { + auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto xx = Kokkos::subview(m_x, k, Kokkos::ALL()); + auto yy = Kokkos::subview(m_y, k, Kokkos::ALL()); + + KokkosBlas::SerialGemv::invoke(m_alpha, aa, xx, m_beta, yy); + } + + inline void run() { + using value_type = typename AViewType::non_const_value_type; + std::string name_region("KokkosBatched::Test::SerialGemv"); + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + Kokkos::RangePolicy policy(0, m_x.extent(0)); + Kokkos::parallel_for(name.c_str(), policy, *this); + } +}; + +template +void impl_test_batched_trsv_blocking(const int N, const int BlkSize) { + using ats = Kokkos::ArithTraits; + using View2DType = Kokkos::View; + using View3DType = Kokkos::View; /// randomized input testing views ScalarType alpha(1.5); - ViewType a0("a0", N, BlkSize, BlkSize), a1("a1", N, BlkSize, BlkSize), b0("b0", N, BlkSize, 1), - b1("b1", N, BlkSize, 1); + View3DType a0("a0", N, BlkSize, BlkSize), a1("a1", N, BlkSize, BlkSize); + View2DType b0("b0", N, BlkSize), b1("b1", N, BlkSize); Kokkos::Random_XorShift64_Pool random(13718); - Kokkos::fill_random(a0, random, value_type(1.0)); - Kokkos::fill_random(b0, random, value_type(1.0)); + Kokkos::fill_random(a0, random, ValueType(1.0)); + Kokkos::fill_random(b0, random, ValueType(1.0)); Kokkos::fence(); @@ -90,82 +128,253 @@ void impl_test_batched_trsv(const int N, const int BlkSize) { Kokkos::deep_copy(a1, a0); Kokkos::deep_copy(b1, b0); - Functor_TestBatchedSerialTrsv(alpha, a0, b0) + Functor_BatchedSerialTrsv(alpha, + a0, b0) + .run(); + Functor_BatchedSerialTrsv(alpha, a1, b1) .run(); - Functor_TestBatchedSerialTrsv(alpha, a1, b1).run(); Kokkos::fence(); /// for comparison send it to host - typename ViewType::HostMirror a0_host = Kokkos::create_mirror_view(a0); - typename ViewType::HostMirror b0_host = Kokkos::create_mirror_view(b0); - typename ViewType::HostMirror b1_host = Kokkos::create_mirror_view(b1); - - Kokkos::deep_copy(a0_host, a0); - Kokkos::deep_copy(b0_host, b0); - Kokkos::deep_copy(b1_host, b1); + auto b0_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), b0); + auto b1_host = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), b1); /// this eps is about 10^-14 - typedef typename ats::mag_type mag_type; + using mag_type = typename ats::mag_type; mag_type sum(1), diff(0); const mag_type eps = 1.0e3 * ats::epsilon(); - /// check b0 and b1 are correct - const value_type one(1); - const bool is_unit_diag = std::is_same::value; - for (int k = 0; k < N; ++k) { - if (std::is_same::value) { - if (std::is_same::value) { - for (int i = 0; i < BlkSize; ++i) { - value_type tmp(0); - for (int j = 0; j <= i; ++j) { - const value_type aval = (i == j && is_unit_diag ? one : a0_host(k, i, j)); - const value_type bval = b0_host(k, j, 0); - tmp += aval * bval; - } - EXPECT_NEAR(ats::abs(tmp), ats::abs(alpha), eps); - } - for (int i = 0; i < BlkSize; ++i) { - value_type tmp(0); - for (int j = 0; j <= i; ++j) { - const value_type aval = (i == j && is_unit_diag ? one : a0_host(k, i, j)); - const value_type bval = b1_host(k, j, 0); - tmp += aval * bval; - } - EXPECT_NEAR(ats::abs(tmp), ats::abs(alpha), eps); - } - } else if (std::is_same::value) { - for (int i = 0; i < BlkSize; ++i) { - value_type tmp(0); - for (int j = i; j < BlkSize; ++j) { - const value_type aval = (i == j && is_unit_diag ? one : a0_host(k, i, j)); - const value_type bval = b0_host(k, j, 0); - tmp += aval * bval; - } - EXPECT_NEAR(ats::abs(tmp), ats::abs(alpha), eps); + /// check b0 = b1 ; + for (int k = 0; k < N; ++k) + for (int i = 0; i < BlkSize; ++i) { + sum += ats::abs(b0_host(k, i)); + diff += ats::abs(b0_host(k, i) - b1_host(k, i)); + } + EXPECT_NEAR(diff / sum, 0.0, eps); +} + +/// \brief Implementation details of batched trsv analytical test +/// Confirm A * x = b, where +/// A: [[1, 1, 1], +/// [1, 2, 2], +/// [1, 2, 3]] +/// b: [1, 1, 1] +/// +/// Upper and Non-transpose +/// x0 + x1 + x2 = 1 +/// +2x1 + 2x2 = 1 +/// 3x2 = 1 +/// x = [1/2, 1/6, 1/3] +/// +/// Upper and Transpose +/// x0 = 1 +/// x0 +2x1 = 1 +/// x0 +2x1 + 3x2 = 1 +/// x = [1, 0, 0] +/// +/// Lower, Non-transpose +/// x0 = 1 +/// x0 +2x1 = 1 +/// x0 +2x1 + 3x2 = 1 +/// x = [1, 0, 0] +/// +/// Lower, Transpose +/// x0 + x1 + x2 = 1 +/// +2x1 + 2x2 = 1 +/// 3x2 = 1 +/// x = [1/2, 1/6, 1/3] +/// \param N [in] Batch size of matrices and RHS +template +void impl_test_batched_trsv_analytical(const std::size_t N) { + using ats = typename Kokkos::ArithTraits; + using RealType = typename ats::mag_type; + + using View2DType = Kokkos::View; + using StridedView2DType = Kokkos::View; + using View3DType = Kokkos::View; + using PivView2DType = Kokkos::View; + + constexpr std::size_t BlkSize = 3, incx = 2; + View3DType A("A", N, BlkSize, BlkSize); + View2DType x("x", N, BlkSize), x_ref("x_ref", N, BlkSize); // Solutions + + // Testing incx argument with strided views + Kokkos::LayoutStride layout{N, incx, BlkSize, N * incx}; + StridedView2DType x_s("x_s", layout); // Solutions + + auto h_A = Kokkos::create_mirror_view(A); + auto h_x_ref = Kokkos::create_mirror_view(x_ref); + for (std::size_t ib = 0; ib < N; ib++) { + h_A(ib, 0, 0) = 1.0; + h_A(ib, 0, 1) = 1.0; + h_A(ib, 0, 2) = 1.0; + h_A(ib, 1, 0) = 1.0; + h_A(ib, 1, 1) = 2.0; + h_A(ib, 1, 2) = 2.0; + h_A(ib, 2, 0) = 1.0; + h_A(ib, 2, 1) = 2.0; + h_A(ib, 2, 2) = 3.0; + + if (std::is_same_v) { + if (std::is_same_v) { + if (std::is_same_v) { + h_x_ref(ib, 0) = 1.0 / 2.0; + h_x_ref(ib, 1) = 1.0 / 6.0; + h_x_ref(ib, 2) = 1.0 / 3.0; + } else { + h_x_ref(ib, 0) = 1.0; + h_x_ref(ib, 1) = -1.0; + h_x_ref(ib, 2) = 1.0; } - for (int i = 0; i < BlkSize; ++i) { - value_type tmp(0); - for (int j = i; j < BlkSize; ++j) { - const value_type aval = (i == j && is_unit_diag ? one : a0_host(k, i, j)); - const value_type bval = b1_host(k, j, 0); - tmp += aval * bval; - } - EXPECT_NEAR(ats::abs(tmp), ats::abs(alpha), eps); + } else { + // Diag::NonUnit does not matter + h_x_ref(ib, 0) = 1.0; + h_x_ref(ib, 1) = 0.0; + h_x_ref(ib, 2) = 0.0; + } + } else { + if (std::is_same_v) { + // Diag::NonUnit does not matter + h_x_ref(ib, 0) = 1.0; + h_x_ref(ib, 1) = 0.0; + h_x_ref(ib, 2) = 0.0; + } else { + if (std::is_same_v) { + h_x_ref(ib, 0) = 1.0 / 2.0; + h_x_ref(ib, 1) = 1.0 / 6.0; + h_x_ref(ib, 2) = 1.0 / 3.0; + } else { + h_x_ref(ib, 0) = 1.0; + h_x_ref(ib, 1) = -1.0; + h_x_ref(ib, 2) = 1.0; } } } } - /// check b0 = b1 ; - for (int k = 0; k < N; ++k) - for (int i = 0; i < BlkSize; ++i) - for (int j = 0; j < 1; ++j) { - sum += ats::abs(b0_host(k, i, j)); - diff += ats::abs(b0_host(k, i, j) - b1_host(k, i, j)); + Kokkos::deep_copy(A, h_A); + + // Set RHS as [1.0, 1.0, 1.0] + Kokkos::deep_copy(x, 1.0); + Kokkos::deep_copy(x_s, x); + + // trsv to solve U * x = b or L * x = b + Functor_BatchedSerialTrsv(1.0, A, + x) + .run(); + Functor_BatchedSerialTrsv( + 1.0, A, x_s) + .run(); + + Kokkos::fence(); + + // Check x = x_ref + RealType eps = 1.0e1 * ats::epsilon(); + auto h_x = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), x); + for (std::size_t ib = 0; ib < N; ib++) { + for (std::size_t i = 0; i < BlkSize; i++) { + EXPECT_NEAR_KK(h_x(ib, i), h_x_ref(ib, i), eps); + } + } + + // Testing for strided views, reusing x + Kokkos::deep_copy(x, x_s); + Kokkos::deep_copy(h_x, x); + for (std::size_t ib = 0; ib < N; ib++) { + for (std::size_t i = 0; i < BlkSize; i++) { + EXPECT_NEAR_KK(h_x(ib, i), h_x_ref(ib, i), eps); + } + } +} + +/// \brief Implementation details of batched trsv test +/// \param N [in] Batch size of matrices and RHS +/// \param BlkSize [in] Block size of matrix A +template +void impl_test_batched_trsv(const std::size_t N, const std::size_t BlkSize) { + using ats = typename Kokkos::ArithTraits; + using RealType = typename ats::mag_type; + + using View2DType = Kokkos::View; + using StridedView2DType = Kokkos::View; + using View3DType = Kokkos::View; + using PivView2DType = Kokkos::View; + + constexpr std::size_t incx = 2; + View3DType A("A", N, BlkSize, BlkSize), AT("AT", N, BlkSize, BlkSize); + View2DType x("x", N, BlkSize), y("y", N, BlkSize), x_ref("x_ref", N, BlkSize); // Solutions + + // Testing incx argument with strided views + Kokkos::LayoutStride layout{N, incx, BlkSize, N * incx}; + StridedView2DType x_s("x_s", layout), y_s("y_s", layout); // Solutions + + using execution_space = typename DeviceType::execution_space; + Kokkos::Random_XorShift64_Pool rand_pool(13718); + ScalarType randStart, randEnd; + KokkosKernels::Impl::getRandomBounds(1.0, randStart, randEnd); + Kokkos::fill_random(A, rand_pool, randStart, randEnd); + Kokkos::fill_random(x, rand_pool, randStart, randEnd); + Kokkos::deep_copy(x_ref, x); // Keep reference solution + Kokkos::deep_copy(x_s, x); + + // Create triangluar matrix + create_triangular_matrix(A, AT); + + // trsv to solve U * x = b or L * x = b + Functor_BatchedSerialTrsv(1.0, + AT, x) + .run(); + + Functor_BatchedSerialTrsv( + 1.0, AT, x_s) + .run(); + + Kokkos::fence(); + + if constexpr (std::is_same_v) { + auto h_AT = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), AT); + for (std::size_t ib = 0; ib < N; ib++) { + for (std::size_t i = 0; i < BlkSize; i++) { + h_AT(ib, i, i) = 1.0; } - EXPECT_NEAR(diff / sum, 0.0, eps); + } + Kokkos::deep_copy(AT, h_AT); + } + + // Compute A * x by gemv + // Gemv to compute A*x, this should be identical to b + Functor_BatchedSerialGemv(1.0, AT, x, 0.0, + y) + .run(); + + // Gemv to compute A*x, this should be identical to b + Functor_BatchedSerialGemv( + 1.0, AT, x_s, 0.0, y_s) + .run(); + + // Check A*x = x_ref + RealType eps = 1.0e3 * ats::epsilon(); + auto h_y = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), y); + auto h_x_ref = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), x_ref); + for (std::size_t ib = 0; ib < N; ib++) { + for (std::size_t i = 0; i < BlkSize; i++) { + EXPECT_NEAR_KK(h_y(ib, i), h_x_ref(ib, i), eps); + } + } + + // Testing for strided views, reusing y + Kokkos::deep_copy(y, y_s); + Kokkos::deep_copy(h_y, y); + for (std::size_t ib = 0; ib < N; ib++) { + for (std::size_t i = 0; i < BlkSize; i++) { + EXPECT_NEAR_KK(h_y(ib, i), h_x_ref(ib, i), eps); + } + } } + } // namespace Trsv } // namespace Test @@ -173,33 +382,55 @@ template ViewType; - Test::Trsv::impl_test_batched_trsv(0, 10); + using LayoutType = Kokkos::LayoutLeft; + Test::Trsv::impl_test_batched_trsv_analytical(1); + Test::Trsv::impl_test_batched_trsv_analytical(2); + + // FIXME: ConjTranspose with blocking is not implemented yet + if constexpr (!std::is_same_v) { + Test::Trsv::impl_test_batched_trsv_blocking(0, 10); + } + for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutLeft, Blksize %d, Uplo %d, Trans %d, Diag - // %d\n", - // i, - // std::is_same::value, - // std::is_same::value, std::is_same::value); - Test::Trsv::impl_test_batched_trsv(1, i); + // FIXME: ConjTranspose with blocking is not implemented yet + if constexpr (!std::is_same_v) { + Test::Trsv::impl_test_batched_trsv_blocking(1, i); + } + Test::Trsv::impl_test_batched_trsv(1, + i); + Test::Trsv::impl_test_batched_trsv(2, + i); } } #endif #if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) { - typedef Kokkos::View ViewType; - Test::Trsv::impl_test_batched_trsv(0, 10); + using LayoutType = Kokkos::LayoutLeft; + Test::Trsv::impl_test_batched_trsv_analytical(1); + Test::Trsv::impl_test_batched_trsv_analytical(2); + + // FIXME: ConjTranspose with blocking is not implemented yet + if constexpr (!std::is_same_v) { + Test::Trsv::impl_test_batched_trsv_blocking(0, 10); + } + for (int i = 0; i < 10; ++i) { - // printf("Testing: LayoutRight, Blksize %d, Uplo %d, Trans %d, Diag - // %d\n", - // i, - // std::is_same::value, - // std::is_same::value, std::is_same::value); - Test::Trsv::impl_test_batched_trsv(1, i); + // FIXME: ConjTranspose with blocking is not implemented yet + if constexpr (!std::is_same_v) { + Test::Trsv::impl_test_batched_trsv_blocking(1, i); + } + Test::Trsv::impl_test_batched_trsv(1, + i); + Test::Trsv::impl_test_batched_trsv(2, + i); } } #endif diff --git a/batched/dense/unit_test/Test_Batched_SerialTrsv_Complex.hpp b/batched/dense/unit_test/Test_Batched_SerialTrsv_Complex.hpp index 73f0e65ed9..470bba3320 100644 --- a/batched/dense/unit_test/Test_Batched_SerialTrsv_Complex.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialTrsv_Complex.hpp @@ -14,46 +14,130 @@ // //@HEADER +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) +TEST_F(TestCategory, batched_scalar_serial_trsv_l_nt_u_fcomplex_fcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_nt_n_fcomplex_fcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_t_u_fcomplex_fcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_t_n_fcomplex_fcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_c_u_fcomplex_fcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_c_n_fcomplex_fcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} + +TEST_F(TestCategory, batched_scalar_serial_trsv_u_nt_u_fcomplex_fcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_u_nt_n_fcomplex_fcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_u_t_u_fcomplex_fcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_u_t_n_fcomplex_fcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_u_c_u_fcomplex_fcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_u_c_n_fcomplex_fcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +#endif + #if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) TEST_F(TestCategory, batched_scalar_serial_trsv_l_nt_u_dcomplex_dcomplex) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_l_nt_n_dcomplex_dcomplex) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_t_u_dcomplex_dcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_t_n_dcomplex_dcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_c_u_dcomplex_dcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_c_n_dcomplex_dcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } + TEST_F(TestCategory, batched_scalar_serial_trsv_u_nt_u_dcomplex_dcomplex) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_u_nt_n_dcomplex_dcomplex) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } - -TEST_F(TestCategory, batched_scalar_serial_trsv_l_nt_u_dcomplex_double) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; - test_batched_trsv, double, param_tag_type, algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_serial_trsv_l_nt_n_dcomplex_double) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; - test_batched_trsv, double, param_tag_type, algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_serial_trsv_u_nt_u_dcomplex_double) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; - test_batched_trsv, double, param_tag_type, algo_tag_type>(); -} -TEST_F(TestCategory, batched_scalar_serial_trsv_u_nt_n_dcomplex_double) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; - test_batched_trsv, double, param_tag_type, algo_tag_type>(); +TEST_F(TestCategory, batched_scalar_serial_trsv_u_t_u_dcomplex_dcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_u_t_n_dcomplex_dcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_u_c_u_dcomplex_dcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_u_c_n_dcomplex_dcomplex) { + using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } #endif diff --git a/batched/dense/unit_test/Test_Batched_SerialTrsv_Real.hpp b/batched/dense/unit_test/Test_Batched_SerialTrsv_Real.hpp index 5998232605..ebb9d8e1c3 100644 --- a/batched/dense/unit_test/Test_Batched_SerialTrsv_Real.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialTrsv_Real.hpp @@ -16,46 +16,86 @@ #if defined(KOKKOSKERNELS_INST_FLOAT) TEST_F(TestCategory, batched_scalar_serial_trsv_l_nt_u_float_float) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv(); } TEST_F(TestCategory, batched_scalar_serial_trsv_l_nt_n_float_float) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_t_u_float_float) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_t_n_float_float) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv(); } TEST_F(TestCategory, batched_scalar_serial_trsv_u_nt_u_float_float) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv(); } TEST_F(TestCategory, batched_scalar_serial_trsv_u_nt_n_float_float) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_u_t_u_float_float) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_u_t_n_float_float) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv(); } #endif #if defined(KOKKOSKERNELS_INST_DOUBLE) TEST_F(TestCategory, batched_scalar_serial_trsv_l_nt_u_double_double) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv(); } TEST_F(TestCategory, batched_scalar_serial_trsv_l_nt_n_double_double) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_t_u_double_double) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_l_t_n_double_double) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv(); } TEST_F(TestCategory, batched_scalar_serial_trsv_u_nt_u_double_double) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv(); } TEST_F(TestCategory, batched_scalar_serial_trsv_u_nt_n_double_double) { - typedef ::Test::Trsv::ParamTag param_tag_type; - typedef Algo::Trsv::Blocked algo_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_u_t_u_double_double) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; + test_batched_trsv(); +} +TEST_F(TestCategory, batched_scalar_serial_trsv_u_t_n_double_double) { + using param_tag_type = ::Test::Trsv::ParamTag; + using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv(); } #endif From 651371fa7d3e80d57477f0f4c195c4ff0fb0a1ca Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Fri, 6 Dec 2024 22:39:27 +0900 Subject: [PATCH 3/8] format Signed-off-by: Yuuichi Asahi --- batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp index 424f1680e1..d01b11810e 100644 --- a/batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Trsv_Serial_Internal.hpp @@ -108,7 +108,7 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalLower::invoke( // trsm update const ValueType *KOKKOS_RESTRICT Ap = A + p * as0 + p * as1; - /**/ ValueType *KOKKOS_RESTRICT bp = b + p * bs0; + /**/ ValueType *KOKKOS_RESTRICT bp = b + p * bs0; if (use_unit_diag) trsm_u.serial_invoke(Ap, pb, 1, bp); @@ -154,8 +154,8 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invok for (int p = (m - 1); p >= 0; --p) { const int iend = p; - const ValueType *KOKKOS_RESTRICT a01 = A + p * as1; - /**/ ValueType *KOKKOS_RESTRICT beta1 = b + p * bs0; + const ValueType *KOKKOS_RESTRICT a01 = A + p * as1; + /**/ ValueType *KOKKOS_RESTRICT beta1 = b + p * bs0; // with KOKKOS_RESTRICT a compiler assumes that the pointer is not // accessed by others op(/=) uses this pointer and changes the associated @@ -198,7 +198,7 @@ KOKKOS_INLINE_FUNCTION int SerialTrsvInternalUpper::invoke( // trsm update const ValueType *KOKKOS_RESTRICT Ap = A + p * as0 + p * as1; - /**/ ValueType *KOKKOS_RESTRICT bp = b + p * bs0; + /**/ ValueType *KOKKOS_RESTRICT bp = b + p * bs0; if (use_unit_diag) trsm_u.serial_invoke(Ap, pb, 1, bp); From 87b2be85bdfd80f47341fc8ddd0bea4057d239a0 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Fri, 6 Dec 2024 22:57:05 +0900 Subject: [PATCH 4/8] fix code QR Signed-off-by: Yuuichi Asahi --- .../unit_test/Test_Batched_SerialTrsv.hpp | 2 -- .../Test_Batched_SerialTrsv_Complex.hpp | 24 +++++++++---------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp b/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp index ddc965f1d4..62e6e87601 100644 --- a/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp @@ -194,7 +194,6 @@ void impl_test_batched_trsv_analytical(const std::size_t N) { using View2DType = Kokkos::View; using StridedView2DType = Kokkos::View; using View3DType = Kokkos::View; - using PivView2DType = Kokkos::View; constexpr std::size_t BlkSize = 3, incx = 2; View3DType A("A", N, BlkSize, BlkSize); @@ -301,7 +300,6 @@ void impl_test_batched_trsv(const std::size_t N, const std::size_t BlkSize) { using View2DType = Kokkos::View; using StridedView2DType = Kokkos::View; using View3DType = Kokkos::View; - using PivView2DType = Kokkos::View; constexpr std::size_t incx = 2; View3DType A("A", N, BlkSize, BlkSize), AT("AT", N, BlkSize, BlkSize); diff --git a/batched/dense/unit_test/Test_Batched_SerialTrsv_Complex.hpp b/batched/dense/unit_test/Test_Batched_SerialTrsv_Complex.hpp index 470bba3320..7dfc77b07c 100644 --- a/batched/dense/unit_test/Test_Batched_SerialTrsv_Complex.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialTrsv_Complex.hpp @@ -80,63 +80,63 @@ TEST_F(TestCategory, batched_scalar_serial_trsv_u_c_n_fcomplex_fcomplex) { #if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) TEST_F(TestCategory, batched_scalar_serial_trsv_l_nt_u_dcomplex_dcomplex) { - using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_l_nt_n_dcomplex_dcomplex) { - using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_l_t_u_dcomplex_dcomplex) { - using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_l_t_n_dcomplex_dcomplex) { - using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_l_c_u_dcomplex_dcomplex) { - using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_l_c_n_dcomplex_dcomplex) { - using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_u_nt_u_dcomplex_dcomplex) { - using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_u_nt_n_dcomplex_dcomplex) { - using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_u_t_u_dcomplex_dcomplex) { - using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_u_t_n_dcomplex_dcomplex) { - using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_u_c_u_dcomplex_dcomplex) { - using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } TEST_F(TestCategory, batched_scalar_serial_trsv_u_c_n_dcomplex_dcomplex) { - using param_tag_type = ::Test::Trsv::ParamTag param_tag_type; + using param_tag_type = ::Test::Trsv::ParamTag; using algo_tag_type = Algo::Trsv::Unblocked; test_batched_trsv, Kokkos::complex, param_tag_type, algo_tag_type>(); } From 2a62d0819e39220d4822ee3774951c17d183faba Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Sun, 8 Dec 2024 04:06:45 +0900 Subject: [PATCH 5/8] simplify tests for serial trsv Signed-off-by: Yuuichi Asahi --- .../dense/unit_test/Test_Batched_SerialTrsv.hpp | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp b/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp index 62e6e87601..e8b501373c 100644 --- a/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp @@ -319,29 +319,19 @@ void impl_test_batched_trsv(const std::size_t N, const std::size_t BlkSize) { Kokkos::deep_copy(x_s, x); // Create triangluar matrix - create_triangular_matrix(A, AT); + create_triangular_matrix(A, AT); // trsv to solve U * x = b or L * x = b Functor_BatchedSerialTrsv(1.0, - AT, x) + A, x) .run(); Functor_BatchedSerialTrsv( - 1.0, AT, x_s) + 1.0, A, x_s) .run(); Kokkos::fence(); - if constexpr (std::is_same_v) { - auto h_AT = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), AT); - for (std::size_t ib = 0; ib < N; ib++) { - for (std::size_t i = 0; i < BlkSize; i++) { - h_AT(ib, i, i) = 1.0; - } - } - Kokkos::deep_copy(AT, h_AT); - } - // Compute A * x by gemv // Gemv to compute A*x, this should be identical to b Functor_BatchedSerialGemv(1.0, AT, x, 0.0, From bc344df74d34e5e7023d7b049aafad1feedf0762 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Sun, 8 Dec 2024 04:10:58 +0900 Subject: [PATCH 6/8] format Signed-off-by: Yuuichi Asahi --- batched/dense/unit_test/Test_Batched_SerialTrsv.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp b/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp index e8b501373c..8b8edc65a5 100644 --- a/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialTrsv.hpp @@ -322,8 +322,8 @@ void impl_test_batched_trsv(const std::size_t N, const std::size_t BlkSize) { create_triangular_matrix(A, AT); // trsv to solve U * x = b or L * x = b - Functor_BatchedSerialTrsv(1.0, - A, x) + Functor_BatchedSerialTrsv(1.0, A, + x) .run(); Functor_BatchedSerialTrsv( From 83c3d7144434e1d01b825fcedca147ec99886ba5 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 10 Dec 2024 10:02:39 +0900 Subject: [PATCH 7/8] static-assertion check outside debug mode Signed-off-by: Yuuichi Asahi --- batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp index 87d32258b8..724d8ecd2a 100644 --- a/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp @@ -27,12 +27,11 @@ namespace Impl { template KOKKOS_INLINE_FUNCTION static int checkTrsvInput([[maybe_unused]] const AViewType &A, [[maybe_unused]] const bViewType &b) { -#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) static_assert(Kokkos::is_view_v, "KokkosBatched::trsv: AViewType is not a Kokkos::View."); static_assert(Kokkos::is_view_v, "KokkosBatched::trsv: bViewType is not a Kokkos::View."); static_assert(AViewType::rank == 2, "KokkosBatched::trsv: AViewType must have rank 2."); static_assert(bViewType::rank == 1, "KokkosBatched::trsv: bViewType must have rank 1."); - +#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) const int lda = A.extent(0), n = A.extent(1); if (lda < Kokkos::max(1, n)) { Kokkos::printf( From 78a206608cfb02dd76447e05bc6f29fee9848a77 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Tue, 10 Dec 2024 13:10:26 +0900 Subject: [PATCH 8/8] disallow vector_length=8 if AVX512 is unavailable Signed-off-by: Yuuichi Asahi --- .../impl/KokkosBatched_Trsv_Serial_Impl.hpp | 48 ++++++++++++------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp index 724d8ecd2a..c542294657 100644 --- a/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp +++ b/batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp @@ -53,14 +53,16 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - typedef typename bViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - + using vector_type = typename bViewType::value_type; const int m = b.extent(0), n = 1; static_assert(is_vector::value, "value type is not vector type"); +#if defined(KOKKOS_ARCH_AVX512XEON) static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, "AVX, AVX2 and AVX512 is supported"); +#else + static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported"); +#endif const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; // no error check @@ -116,14 +118,16 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - typedef typename bViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - + using vector_type = typename bViewType::value_type; const int m = b.extent(0), n = 1; static_assert(is_vector::value, "value type is not vector type"); +#if defined(KOKKOS_ARCH_AVX512XEON) static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, "AVX, AVX2 and AVX512 is supported"); +#else + static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported"); +#endif const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; // no error check @@ -179,14 +183,16 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - typedef typename bViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - + using vector_type = typename bViewType::value_type; const int m = b.extent(0), n = 1; static_assert(is_vector::value, "value type is not vector type"); +#if defined(KOKKOS_ARCH_AVX512XEON) static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, "AVX, AVX2 and AVX512 is supported"); +#else + static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported"); +#endif const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; // no error check @@ -242,14 +248,16 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - typedef typename bViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - + using vector_type = typename bViewType::value_type; const int m = b.extent(0), n = 1; static_assert(is_vector::value, "value type is not vector type"); +#if defined(KOKKOS_ARCH_AVX512XEON) static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, "AVX, AVX2 and AVX512 is supported"); +#else + static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported"); +#endif const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; // no error check @@ -305,14 +313,16 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - typedef typename bViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - + using vector_type = typename bViewType::value_type; const int m = b.extent(0), n = 1; static_assert(is_vector::value, "value type is not vector type"); +#if defined(KOKKOS_ARCH_AVX512XEON) static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, "AVX, AVX2 and AVX512 is supported"); +#else + static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported"); +#endif const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; // no error check @@ -368,14 +378,16 @@ template struct SerialTrsv { template KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) { - typedef typename bViewType::value_type vector_type; - // typedef typename vector_type::value_type value_type; - + using vector_type = typename bViewType::value_type; const int m = b.extent(0), n = 1; static_assert(is_vector::value, "value type is not vector type"); +#if defined(KOKKOS_ARCH_AVX512XEON) static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8, "AVX, AVX2 and AVX512 is supported"); +#else + static_assert(vector_type::vector_length == 4, "AVX and AVX2 is supported"); +#endif const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX; // no error check