Skip to content

Commit

Permalink
update host backend based on revies
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Dec 17, 2024
1 parent b87d799 commit 7a6946c
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 24 deletions.
50 changes: 28 additions & 22 deletions fft/src/KokkosFFT_FFTW_Types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ namespace Impl {
enum class FFTWTransformType { R2C, D2Z, C2R, Z2D, C2C, Z2Z };

// Define fft transform types
template <typename ExecutionSpace, typename T1, typename T2>
template <typename T1, typename T2>
struct fftw_transform_type {
static_assert(std::is_same_v<T1, T2>,
"Real to real transform is unavailable");
};

template <typename ExecutionSpace, typename T1, typename T2>
struct fftw_transform_type<ExecutionSpace, T1, Kokkos::complex<T2>> {
template <typename T1, typename T2>
struct fftw_transform_type<T1, Kokkos::complex<T2>> {
static_assert(std::is_same_v<T1, T2>,
"T1 and T2 should have the same precision");
static constexpr FFTWTransformType m_type = std::is_same_v<T1, float>
Expand All @@ -39,8 +39,8 @@ struct fftw_transform_type<ExecutionSpace, T1, Kokkos::complex<T2>> {
static constexpr FFTWTransformType type() { return m_type; };
};

template <typename ExecutionSpace, typename T1, typename T2>
struct fftw_transform_type<ExecutionSpace, Kokkos::complex<T1>, T2> {
template <typename T1, typename T2>
struct fftw_transform_type<Kokkos::complex<T1>, T2> {
static_assert(std::is_same_v<T1, T2>,
"T1 and T2 should have the same precision");
static constexpr FFTWTransformType m_type = std::is_same_v<T2, float>
Expand All @@ -49,9 +49,8 @@ struct fftw_transform_type<ExecutionSpace, Kokkos::complex<T1>, T2> {
static constexpr FFTWTransformType type() { return m_type; };
};

template <typename ExecutionSpace, typename T1, typename T2>
struct fftw_transform_type<ExecutionSpace, Kokkos::complex<T1>,
Kokkos::complex<T2>> {
template <typename T1, typename T2>
struct fftw_transform_type<Kokkos::complex<T1>, Kokkos::complex<T2>> {
static_assert(std::is_same_v<T1, T2>,
"T1 and T2 should have the same precision");
static constexpr FFTWTransformType m_type = std::is_same_v<T1, float>
Expand All @@ -78,7 +77,7 @@ struct ScopedFFTWPlan {
int idist, OutScalarType *out, const int *onembed, int ostride,
int odist, [[maybe_unused]] int sign, unsigned flags) {
init_threads(exec_space);
constexpr auto type = fftw_transform_type<ExecutionSpace, T1, T2>::type();
constexpr auto type = fftw_transform_type<T1, T2>::type();
if constexpr (type == KokkosFFT::Impl::FFTWTransformType::R2C) {
m_plan =
fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist,
Expand All @@ -105,6 +104,7 @@ struct ScopedFFTWPlan {
}
m_is_created = true;
}

~ScopedFFTWPlan() noexcept {
cleanup_threads();
if constexpr (std::is_same_v<plan_type, fftwf_plan>) {
Expand All @@ -120,29 +120,35 @@ struct ScopedFFTWPlan {
ScopedFFTWPlan &operator=(ScopedFFTWPlan &&) = delete;
ScopedFFTWPlan(ScopedFFTWPlan &&) = delete;

const plan_type &plan() const { return m_plan; }
plan_type plan() const noexcept { return m_plan; }

private:
void init_threads([[maybe_unused]] const ExecutionSpace &exec_space) {
#if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS)
int nthreads = exec_space.concurrency();

if constexpr (std::is_same_v<plan_type, fftwf_plan>) {
fftwf_init_threads();
fftwf_plan_with_nthreads(nthreads);
} else {
fftw_init_threads();
fftw_plan_with_nthreads(nthreads);
if constexpr (std::is_same_v<ExecutionSpace,
Kokkos::DefaultHostExecutionSpace>) {
int nthreads = exec_space.concurrency();

if constexpr (std::is_same_v<plan_type, fftwf_plan>) {
fftwf_init_threads();
fftwf_plan_with_nthreads(nthreads);
} else {
fftw_init_threads();
fftw_plan_with_nthreads(nthreads);
}
}
#endif
}

void cleanup_threads() {
#if defined(KOKKOS_ENABLE_OPENMP) || defined(KOKKOS_ENABLE_THREADS)
if constexpr (std::is_same_v<plan_type, fftwf_plan>) {
fftwf_cleanup_threads();
} else {
fftw_cleanup_threads();
if constexpr (std::is_same_v<ExecutionSpace,
Kokkos::DefaultHostExecutionSpace>) {
if constexpr (std::is_same_v<plan_type, fftwf_plan>) {
fftwf_cleanup_threads();
} else {
fftw_cleanup_threads();
}
}
#endif
}
Expand Down
1 change: 1 addition & 0 deletions fft/src/KokkosFFT_Host_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

namespace KokkosFFT {
namespace Impl {

template <typename ScopedPlanType>
void exec_plan(ScopedPlanType& scoped_plan, float* idata, fftwf_complex* odata,
int /*direction*/) {
Expand Down
2 changes: 1 addition & 1 deletion fft/src/KokkosFFT_Host_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ template <typename ExecutionSpace>
using TransformType = FFTWTransformType;

template <typename ExecutionSpace, typename T1, typename T2>
using transform_type = fftw_transform_type<ExecutionSpace, T1, T2>;
using transform_type = fftw_transform_type<T1, T2>;

template <typename ExecutionSpace, typename T1, typename T2>
struct FFTPlanType {
Expand Down
2 changes: 1 addition & 1 deletion fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class Plan {
direction, axes, s, m_is_inplace);
}

~Plan() {}
~Plan() noexcept = default;

Plan() = delete;
Plan(const Plan&) = delete;
Expand Down

0 comments on commit 7a6946c

Please sign in to comment.