Skip to content

Commit

Permalink
Allow dyn-rank-view in serial trsv
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
Yuuichi Asahi committed Dec 17, 2024
1 parent ad284f0 commit 84fa11b
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 46 deletions.
93 changes: 77 additions & 16 deletions batched/dense/impl/KokkosBatched_Trsv_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
/// \author Kyungjoo Kim ([email protected])
/// \author Yuuichi Asahi ([email protected])

#include <Kokkos_DynRankView.hpp>
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Trsv_Serial_Internal.hpp"

Expand All @@ -27,11 +28,27 @@ namespace Impl {
template <typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int checkTrsvInput([[maybe_unused]] const AViewType &A,
[[maybe_unused]] const bViewType &b) {
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::trsv: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<bViewType>, "KokkosBatched::trsv: bViewType is not a Kokkos::View.");
static_assert(AViewType::rank == 2, "KokkosBatched::trsv: AViewType must have rank 2.");
static_assert(bViewType::rank == 1, "KokkosBatched::trsv: bViewType must have rank 1.");
static_assert(Kokkos::is_view_v<AViewType> || Kokkos::is_dyn_rank_view_v<AViewType>,
"KokkosBatched::trsv: AViewType must be either a Kokkos::View or a Kokkos::DynRankView.");
static_assert(Kokkos::is_view_v<bViewType> || Kokkos::is_dyn_rank_view_v<bViewType>,
"KokkosBatched::trsv: bViewType must be either a Kokkos::View or a Kokkos::DynRankView.");
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
if (A.rank() != 2) {
Kokkos::printf(
"KokkosBatched::trsv: A must be a rank 2 View."
"A.rank() = %d\n",
A.rank());
return 1;
}

if (b.rank() != 1) {
Kokkos::printf(
"KokkosBatched::trsv: b must be a rank 1 View."
"b.rank() = %d\n",
b.rank());
return 1;
}

const int lda = A.extent(0), n = A.extent(1);
if (lda < Kokkos::max(1, n)) {
Kokkos::printf(
Expand All @@ -53,6 +70,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -93,7 +116,8 @@ struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::Unblocke
auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride_0(), A.stride_1(), b.data(),
b.stride_0());
}
};

Expand All @@ -107,7 +131,8 @@ struct SerialTrsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Trsv::Blocked>
auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride_0(), A.stride_1(), b.data(),
b.stride_0());
}
};

Expand All @@ -118,6 +143,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -158,7 +189,8 @@ struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::Unblocked>
auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(),
b.stride_0());
}
};

Expand All @@ -172,7 +204,8 @@ struct SerialTrsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Trsv::Blocked> {
auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(),
b.stride_0());
}
};

Expand All @@ -183,6 +216,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -223,7 +262,7 @@ struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Unbloc
auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(), b.stride_0());
}
};

Expand All @@ -237,7 +276,7 @@ struct SerialTrsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Blocke
auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(), b.stride_0());
}
};

Expand All @@ -248,6 +287,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -288,7 +333,8 @@ struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::Unblocke
auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride_0(), A.stride_1(), b.data(),
b.stride_0());
}
};

Expand All @@ -302,7 +348,8 @@ struct SerialTrsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Trsv::Blocked>
auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalUpper<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride(0), A.stride(1), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(0), alpha, A.data(), A.stride_0(), A.stride_1(), b.data(),
b.stride_0());
}
};

Expand All @@ -313,6 +360,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -353,7 +406,8 @@ struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::Unblocked>
auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(),
b.stride_0());
}
};

Expand All @@ -367,7 +421,8 @@ struct SerialTrsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Trsv::Blocked> {
auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, false, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(),
b.stride_0());
}
};

Expand All @@ -378,6 +433,12 @@ template <typename ArgDiag>
struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::CompactMKL> {
template <typename ScalarType, typename AViewType, typename bViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const bViewType &b) {
// Quick return if possible
if (A.extent(1) == 0) return 0;

auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;

using vector_type = typename bViewType::value_type;
const int m = b.extent(0), n = 1;

Expand Down Expand Up @@ -418,7 +479,7 @@ struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Unbloc
auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(), b.stride_0());
}
};

Expand All @@ -432,7 +493,7 @@ struct SerialTrsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Trsv::Blocke
auto info = KokkosBatched::Impl::checkTrsvInput(A, b);
if (info) return info;
return KokkosBatched::Impl::SerialTrsvInternalLower<Algo::Trsv::Blocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride(1), A.stride(0), b.data(), b.stride(0));
ArgDiag::use_unit_diag, true, A.extent(1), alpha, A.data(), A.stride_1(), A.stride_0(), b.data(), b.stride_0());
}
};

Expand Down
Loading

0 comments on commit 84fa11b

Please sign in to comment.