Skip to content

Commit

Permalink
simplify tests for serial trsv
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
Yuuichi Asahi committed Dec 10, 2024
1 parent 87b2be8 commit 2a62d08
Showing 1 changed file with 3 additions and 13 deletions.
16 changes: 3 additions & 13 deletions batched/dense/unit_test/Test_Batched_SerialTrsv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,29 +319,19 @@ void impl_test_batched_trsv(const std::size_t N, const std::size_t BlkSize) {
Kokkos::deep_copy(x_s, x);

// Create triangluar matrix
create_triangular_matrix<View3DType, View3DType, typename ParamTagType::uplo, KokkosBatched::Diag::NonUnit>(A, AT);
create_triangular_matrix<View3DType, View3DType, typename ParamTagType::uplo, typename ParamTagType::diag>(A, AT);

// trsv to solve U * x = b or L * x = b
Functor_BatchedSerialTrsv<DeviceType, View3DType, View2DType, ScalarType, ParamTagType, Algo::Trsv::Unblocked>(1.0,
AT, x)
A, x)
.run();

Functor_BatchedSerialTrsv<DeviceType, View3DType, StridedView2DType, ScalarType, ParamTagType, Algo::Trsv::Unblocked>(
1.0, AT, x_s)
1.0, A, x_s)
.run();

Kokkos::fence();

if constexpr (std::is_same_v<typename ParamTagType::diag, KokkosBatched::Diag::Unit>) {
auto h_AT = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), AT);
for (std::size_t ib = 0; ib < N; ib++) {
for (std::size_t i = 0; i < BlkSize; i++) {
h_AT(ib, i, i) = 1.0;
}
}
Kokkos::deep_copy(AT, h_AT);
}

// Compute A * x by gemv
// Gemv to compute A*x, this should be identical to b
Functor_BatchedSerialGemv<DeviceType, ScalarType, View3DType, View2DType, View2DType, ParamTagType>(1.0, AT, x, 0.0,
Expand Down

0 comments on commit 2a62d08

Please sign in to comment.