Skip to content

Commit

Permalink
Add missing specialization of Serial Trsv
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
Yuuichi Asahi committed Dec 6, 2024
1 parent 8def092 commit 84c1d54
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 80 deletions.
251 changes: 209 additions & 42 deletions batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,37 @@
#define KOKKOSBATCHED_TRSV_SERIAL_IMPL_HPP

/// \author Kyungjoo Kim ([email protected])
/// \author Yuuichi Asahi ([email protected])

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Trsv_Serial_Internal.hpp"

namespace KokkosBatched {
namespace Impl {
template <typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int checkTrsvInput([[maybe_unused]] const AViewType &A,
[[maybe_unused]] const bViewType &b) {
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::trsv: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<bViewType>, "KokkosBatched::trsv: bViewType is not a Kokkos::View.");
static_assert(AViewType::rank == 2, "KokkosBatched::trsv: AViewType must have rank 2.");
static_assert(bViewType::rank == 1, "KokkosBatched::trsv: bViewType must have rank 1.");

const int lda = A.extent(0), n = A.extent(1);
if (lda < Kokkos::max(1, n)) {
Kokkos::printf(
"KokkosBatched::trsv: leading dimension of A must not be smaller than "
"max(1, n): "
"lda = %d, n = %d\n",
lda, n);
return 1;
}
#endif
return 0;
}
} // namespace Impl

///
/// Serial Impl
/// ===========

///
/// Implemented:
/// L/NT, U/NT, L/T, U/T
///
/// Not yet implemented
/// L/CT, U/CT

///
/// L/NT
///

//// Lower non-transpose ////
#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__)
template <typename ArgDiag>
Expand Down Expand Up @@ -76,24 +86,31 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::Unblocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
return SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(ArgDiag::use_unit_diag, A.extent(0), alpha, A.data(),
A.stride_0(), A.stride_1(), b.data(), b.stride_0());
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
}
};

template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::Blocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
return SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(ArgDiag::use_unit_diag, A.extent(0), alpha, A.data(),
A.stride_0(), A.stride_1(), b.data(), b.stride_0());
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
}
};

///
/// L/T
///

//// Lower transpose ////
#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__)
template <typename ArgDiag>
Expand Down Expand Up @@ -132,24 +149,94 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::Unblocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
return SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(ArgDiag::use_unit_diag, A.extent(1), alpha, A.data(),
A.stride_1(), A.stride_0(), b.data(), b.stride_0());
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
}
};

template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::Blocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
return SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(ArgDiag::use_unit_diag, A.extent(1), alpha, A.data(),
A.stride_1(), A.stride_0(), b.data(), b.stride_0());
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
}
};

//// Lower conjugate-transpose ////
#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__)
template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
typedef typename bViewType::value_type vector_type;
// typedef typename vector_type::value_type value_type;

const int m = b.extent(0), n = 1;

static_assert(is_vector<vector_type>::value, "value type is not vector type");
static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8,
"AVX, AVX2 and AVX512 is supported");
const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX;

// no error check
int r_val = 0;
if (A.stride_0() == 1) {
mkl_dtrsm_compact(MKL_COL_MAJOR, MKL_LEFT, MKL_LOWER, MKL_CONJTRANS,
ArgDiag::use_unit_diag ? MKL_UNIT : MKL_NONUNIT, m, n, alpha, (const double *)A.data(),
A.stride_0(), (double *)b.data(), b.stride_0(), format, (MKL_INT)vector_type::vector_length);
} else if (A.stride_1() == 1) {
mkl_dtrsm_compact(MKL_ROW_MAJOR, MKL_LEFT, MKL_LOWER, MKL_CONJTRANS,
ArgDiag::use_unit_diag ? MKL_UNIT : MKL_NONUNIT, m, n, alpha, (const double *)A.data(),
A.stride_0(), (double *)b.data(), b.stride_0(), format, (MKL_INT)vector_type::vector_length);
} else {
r_val = -1;
}
return r_val;
}
};
#endif

template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Unblocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
}
};

///
/// U/NT
///
template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Blocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
}
};

//// Upper non-transpose ////
#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__)
template <typename ArgDiag>
Expand Down Expand Up @@ -188,24 +275,31 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::Unblocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
return SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(ArgDiag::use_unit_diag, A.extent(0), alpha, A.data(),
A.stride_0(), A.stride_1(), b.data(), b.stride_0());
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
}
};

template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::Blocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
return SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(ArgDiag::use_unit_diag, A.extent(0), alpha, A.data(),
A.stride_0(), A.stride_1(), b.data(), b.stride_0());
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
}
};

///
/// U/T
///

//// Upper transpose ////
#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__)
template <typename ArgDiag>
Expand Down Expand Up @@ -244,17 +338,90 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::Unblocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
return SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(ArgDiag::use_unit_diag, A.extent(1), alpha, A.data(),
A.stride_1(), A.stride_0(), b.data(), b.stride_0());
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
}
};

template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::Blocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
return SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(ArgDiag::use_unit_diag, A.extent(1), alpha, A.data(),
A.stride_1(), A.stride_0(), b.data(), b.stride_0());
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
}
};

//// Upper conjugate-transpose ////
#if defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL) && defined(KOKKOSBATCHED_IMPL_ENABLE_INTEL_MKL_BATCHED) && \
defined(__KOKKOSBATCHED_ENABLE_INTEL_MKL_COMPACT_BATCHED__)
template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
typedef typename bViewType::value_type vector_type;
// typedef typename vector_type::value_type value_type;

const int m = b.extent(0), n = 1;

static_assert(is_vector<vector_type>::value, "value type is not vector type");
static_assert(vector_type::vector_length == 4 || vector_type::vector_length == 8,
"AVX, AVX2 and AVX512 is supported");
const MKL_COMPACT_PACK format = vector_type::vector_length == 8 ? MKL_COMPACT_AVX512 : MKL_COMPACT_AVX;

// no error check
int r_val = 0;
if (A.stride_0() == 1) {
mkl_dtrsm_compact(MKL_COL_MAJOR, MKL_LEFT, MKL_UPPER, MKL_CONJTRANS,
ArgDiag::use_unit_diag ? MKL_UNIT : MKL_NONUNIT, m, n, alpha, (const double *)A.data(),
A.stride_0(), (double *)b.data(), b.stride_0(), format, (MKL_INT)vector_type::vector_length);
} else if (A.stride_1() == 1) {
mkl_dtrsm_compact(MKL_ROW_MAJOR, MKL_LEFT, MKL_UPPER, MKL_CONJTRANS,
ArgDiag::use_unit_diag ? MKL_UNIT : MKL_NONUNIT, m, n, alpha, (const double *)A.data(),
A.stride_0(), (double *)b.data(), b.stride_0(), format, (MKL_INT)vector_type::vector_length);
} else {
r_val = -1;
}
return r_val;
}
};
#endif

template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Unblocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
}
};

template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Blocked> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
}
};

Expand Down
Loading

0 comments on commit 84c1d54

Please sign in to comment.