Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RAII issue by introducing wrapper classes for backend plans #208

Open
wants to merge 41 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
8fb15ac
fix: conflicts
Dec 5, 2024
737f3ca
Fix and wrapper for FFTW handle
Dec 3, 2024
95419cb
Wrapper for cufft handle
Dec 3, 2024
9baa59d
fix: conflicts
Dec 5, 2024
8d952e5
Wrapper for rocfft handle
Dec 3, 2024
dd907a8
fix: conflicts
Dec 3, 2024
6b1059b
Cleanup plan class based on the introduction of wrappers
Dec 3, 2024
f46c56d
fix: conflicts
Dec 5, 2024
5116c1c
fix: conflicts
Dec 5, 2024
b93c07c
fix: unused variable
Dec 3, 2024
1b908fe
fix: work buffer allocation
Dec 3, 2024
4010401
remove unused variable
Dec 3, 2024
ac5b34d
remove unused lines
Dec 3, 2024
734a55a
Add missing include header file in KokkosFFT_ROCM_types.hpp
Dec 3, 2024
f161d46
fix: fftwHandle type in SYCL types
Dec 4, 2024
31cd597
Do not return const plan type for fftw
Dec 4, 2024
2030d04
fix: remove const
Dec 4, 2024
a8741bd
fix: fftw plan creation
Dec 4, 2024
a3c94e3
fix: set created
Dec 4, 2024
0d8a616
fix: cleanup
Dec 4, 2024
f1f4f30
fix constructor of fftw wrapper
Dec 4, 2024
2b75678
fix: conflicts
Dec 5, 2024
88310c9
Remove non-default constructors from FFTW wrapper
Dec 5, 2024
995bb4a
Remove non-default constructors from cufft wrapper
Dec 5, 2024
751810c
Remove non-default constructors from hipfft wrapper
Dec 5, 2024
4ba04f0
Remove non-default constructors from rocfft wrapper
Dec 5, 2024
2eac65c
update FFTW wrapper class name
Dec 5, 2024
35afb3d
fix: host plan type
Dec 5, 2024
5c25bce
fix: fftw rapper name in ROCM_types
Dec 5, 2024
c363d3d
update cuda backed based on reviews
Dec 17, 2024
975b2f1
update hip backend based on reviews
Dec 17, 2024
405b36f
update rocm backend based on reviews
Dec 17, 2024
d394672
update host backend based on revies
Dec 17, 2024
7c085eb
fix: Rocm types
Dec 17, 2024
a6ccd56
fix: ROCM types
Dec 17, 2024
701136e
fix: Rocm types
Dec 17, 2024
622ac0e
fix: header files
Dec 17, 2024
1d33d7a
fix: rocm types
Dec 17, 2024
34ed2b3
fix: rocm types
Dec 18, 2024
45c489b
remove unused lines
Dec 18, 2024
3fdd5bf
fix: rocm types
Dec 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 18 additions & 60 deletions fft/src/KokkosFFT_Cuda_plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@ namespace KokkosFFT {
namespace Impl {
// 1D transform
template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType, typename BufferViewType, typename InfoType,
typename OutViewType,
std::enable_if_t<InViewType::rank() == 1 &&
std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto create_plan(const ExecutionSpace& exec_space,
std::unique_ptr<PlanType>& plan, const InViewType& in,
const OutViewType& out, BufferViewType&, InfoType&,
Direction /*direction*/, axis_type<1> axes, shape_type<1> s,
bool is_inplace) {
const OutViewType& out, Direction /*direction*/,
axis_type<1> axes, shape_type<1> s, bool is_inplace) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand All @@ -34,37 +33,28 @@ auto create_plan(const ExecutionSpace& exec_space,
"InViewType and OutViewType.");
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes, s, is_inplace);
const int nx = fft_extents.at(0);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftPlan1d(&(*plan), nx, type, howmany);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan1d failed");

cudaStream_t stream = exec_space.cuda_stream();
cufft_rt = cufftSetStream((*plan), stream);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed");
plan = std::make_unique<PlanType>(exec_space, nx, type, howmany);

return fft_size;
}

// 2D transform
template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType, typename BufferViewType, typename InfoType,
typename OutViewType,
std::enable_if_t<InViewType::rank() == 2 &&
std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto create_plan(const ExecutionSpace& exec_space,
std::unique_ptr<PlanType>& plan, const InViewType& in,
const OutViewType& out, BufferViewType&, InfoType&,
Direction /*direction*/, axis_type<2> axes, shape_type<2> s,
bool is_inplace) {
const OutViewType& out, Direction /*direction*/,
axis_type<2> axes, shape_type<2> s, bool is_inplace) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand All @@ -75,37 +65,28 @@ auto create_plan(const ExecutionSpace& exec_space,
"InViewType and OutViewType.");
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
[[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] =
KokkosFFT::Impl::get_extents(in, out, axes, s, is_inplace);
const int nx = fft_extents.at(0), ny = fft_extents.at(1);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftPlan2d(&(*plan), nx, ny, type);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan2d failed");

cudaStream_t stream = exec_space.cuda_stream();
cufft_rt = cufftSetStream((*plan), stream);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed");
plan = std::make_unique<PlanType>(exec_space, nx, ny, type);

return fft_size;
}

// 3D transform
template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType, typename BufferViewType, typename InfoType,
typename OutViewType,
std::enable_if_t<InViewType::rank() == 3 &&
std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto create_plan(const ExecutionSpace& exec_space,
std::unique_ptr<PlanType>& plan, const InViewType& in,
const OutViewType& out, BufferViewType&, InfoType&,
Direction /*direction*/, axis_type<3> axes, shape_type<3> s,
bool is_inplace) {
const OutViewType& out, Direction /*direction*/,
axis_type<3> axes, shape_type<3> s, bool is_inplace) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand All @@ -116,7 +97,6 @@ auto create_plan(const ExecutionSpace& exec_space,
"InViewType and OutViewType.");
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

auto type = KokkosFFT::Impl::transform_type<ExecutionSpace, in_value_type,
out_value_type>::type();
[[maybe_unused]] auto [in_extents, out_extents, fft_extents, howmany] =
Expand All @@ -126,29 +106,21 @@ auto create_plan(const ExecutionSpace& exec_space,
nz = fft_extents.at(2);
int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1,
std::multiplies<>());

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftPlan3d(&(*plan), nx, ny, nz, type);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlan3d failed");

cudaStream_t stream = exec_space.cuda_stream();
cufft_rt = cufftSetStream((*plan), stream);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed");
plan = std::make_unique<PlanType>(exec_space, nx, ny, nz, type);

return fft_size;
}

// batched transform, over ND Views
template <typename ExecutionSpace, typename PlanType, typename InViewType,
typename OutViewType, typename BufferViewType, typename InfoType,
std::size_t fft_rank = 1,
typename OutViewType, std::size_t fft_rank = 1,
std::enable_if_t<std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
auto create_plan(const ExecutionSpace& exec_space,
std::unique_ptr<PlanType>& plan, const InViewType& in,
const OutViewType& out, BufferViewType&, InfoType&,
Direction /*direction*/, axis_type<fft_rank> axes,
shape_type<fft_rank> s, bool is_inplace) {
const OutViewType& out, Direction /*direction*/,
axis_type<fft_rank> axes, shape_type<fft_rank> s,
bool is_inplace) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
Expand Down Expand Up @@ -179,27 +151,13 @@ auto create_plan(const ExecutionSpace& exec_space,

// For the moment, considering the contiguous layout only
int istride = 1, ostride = 1;

plan = std::make_unique<PlanType>();
cufftResult cufft_rt = cufftPlanMany(
&(*plan), rank, fft_extents.data(), in_extents.data(), istride, idist,
plan = std::make_unique<PlanType>(
exec_space, rank, fft_extents.data(), in_extents.data(), istride, idist,
out_extents.data(), ostride, odist, type, howmany);

KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftPlanMany failed");

cudaStream_t stream = exec_space.cuda_stream();
cufft_rt = cufftSetStream((*plan), stream);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftSetStream failed");

return fft_size;
}

template <typename ExecutionSpace, typename PlanType, typename InfoType,
std::enable_if_t<std::is_same_v<ExecutionSpace, Kokkos::Cuda>,
std::nullptr_t> = nullptr>
void destroy_plan_and_info(std::unique_ptr<PlanType>& plan, InfoType&) {
cufftDestroy(*plan);
}
} // namespace Impl
} // namespace KokkosFFT

Expand Down
50 changes: 26 additions & 24 deletions fft/src/KokkosFFT_Cuda_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,47 @@

namespace KokkosFFT {
namespace Impl {
template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftReal* idata, cufftComplex* odata,
int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecR2C(plan, idata, odata);

struct ScopedCufftPlan;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel we should include KokkosFFT_Cuda_type.hpp instead because the code below requires the complete definition of the ScopedCufftPlan ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure


inline void exec_plan(const ScopedCufftPlan& scoped_plan, cufftReal* idata,
cufftComplex* odata, int /*direction*/) {
cufftResult cufft_rt = cufftExecR2C(scoped_plan.plan(), idata, odata);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecR2C failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftDoubleReal* idata,
cufftDoubleComplex* odata, int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecD2Z(plan, idata, odata);
inline void exec_plan(const ScopedCufftPlan& scoped_plan,
cufftDoubleReal* idata, cufftDoubleComplex* odata,
int /*direction*/) {
cufftResult cufft_rt = cufftExecD2Z(scoped_plan.plan(), idata, odata);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecD2Z failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftComplex* idata, cufftReal* odata,
int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecC2R(plan, idata, odata);
inline void exec_plan(const ScopedCufftPlan& scoped_plan, cufftComplex* idata,
cufftReal* odata, int /*direction*/) {
cufftResult cufft_rt = cufftExecC2R(scoped_plan.plan(), idata, odata);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecC2R failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata,
cufftDoubleReal* odata, int /*direction*/, Args...) {
cufftResult cufft_rt = cufftExecZ2D(plan, idata, odata);
inline void exec_plan(const ScopedCufftPlan& scoped_plan,
cufftDoubleComplex* idata, cufftDoubleReal* odata,
int /*direction*/) {
cufftResult cufft_rt = cufftExecZ2D(scoped_plan.plan(), idata, odata);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecZ2D failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftComplex* idata,
cufftComplex* odata, int direction, Args...) {
cufftResult cufft_rt = cufftExecC2C(plan, idata, odata, direction);
inline void exec_plan(const ScopedCufftPlan& scoped_plan, cufftComplex* idata,
cufftComplex* odata, int direction) {
cufftResult cufft_rt =
cufftExecC2C(scoped_plan.plan(), idata, odata, direction);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecC2C failed");
}

template <typename... Args>
inline void exec_plan(cufftHandle& plan, cufftDoubleComplex* idata,
cufftDoubleComplex* odata, int direction, Args...) {
cufftResult cufft_rt = cufftExecZ2Z(plan, idata, odata, direction);
inline void exec_plan(const ScopedCufftPlan& scoped_plan,
cufftDoubleComplex* idata, cufftDoubleComplex* odata,
int direction) {
cufftResult cufft_rt =
cufftExecZ2Z(scoped_plan.plan(), idata, odata, direction);
KOKKOSFFT_THROW_IF(cufft_rt != CUFFT_SUCCESS, "cufftExecZ2Z failed");
}
} // namespace Impl
Expand Down
Loading
Loading