Skip to content

Commit

Permalink
Only allow float and double to repreesnt real values (#118)
Browse files Browse the repository at this point in the history
Co-authored-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Jul 15, 2024
1 parent 0e29048 commit aff3416
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 26 deletions.
4 changes: 2 additions & 2 deletions common/src/KokkosFFT_Helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ namespace KokkosFFT {
template <typename ExecutionSpace, typename RealType>
auto fftfreq(const ExecutionSpace&, const std::size_t n,
const RealType d = 1.0) {
static_assert(std::is_floating_point<RealType>::value,
static_assert(KokkosFFT::Impl::is_real_v<RealType>,
"fftfreq: d must be float or double");
using ViewType = Kokkos::View<RealType*, ExecutionSpace>;
ViewType freq("freq", n);
Expand Down Expand Up @@ -216,7 +216,7 @@ auto fftfreq(const ExecutionSpace&, const std::size_t n,
template <typename ExecutionSpace, typename RealType>
auto rfftfreq(const ExecutionSpace&, const std::size_t n,
const RealType d = 1.0) {
static_assert(std::is_floating_point<RealType>::value,
static_assert(KokkosFFT::Impl::is_real_v<RealType>,
"fftfreq: d must be float or double");
using ViewType = Kokkos::View<RealType*, ExecutionSpace>;

Expand Down
8 changes: 4 additions & 4 deletions common/src/KokkosFFT_layouts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ auto get_extents(const InViewType& in, const OutViewType& out,
_fft_extents.push_back(fft_extent);
}

if (std::is_floating_point<in_value_type>::value) {
if (is_real_v<in_value_type>) {
// Then R2C
if (is_complex<out_value_type>::value) {
if (is_complex_v<out_value_type>) {
assert(_out_extents.at(inner_most_axis) ==
_in_extents.at(inner_most_axis) / 2 + 1);
} else {
Expand All @@ -75,9 +75,9 @@ auto get_extents(const InViewType& in, const OutViewType& out,
}
}

if (std::is_floating_point<out_value_type>::value) {
if (is_real_v<out_value_type>) {
// Then C2R
if (is_complex<in_value_type>::value) {
if (is_complex_v<in_value_type>) {
assert(_in_extents.at(inner_most_axis) ==
_out_extents.at(inner_most_axis) / 2 + 1);
} else {
Expand Down
3 changes: 1 addition & 2 deletions common/src/KokkosFFT_padding.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ auto get_modified_shape(const InViewType in, const OutViewType /* out */,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

bool is_C2R = is_complex<in_value_type>::value &&
std::is_floating_point_v<out_value_type>;
bool is_C2R = is_complex_v<in_value_type> && is_real_v<out_value_type>;

if (is_C2R) {
int reshaped_axis = positive_axes.back();
Expand Down
4 changes: 2 additions & 2 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,14 @@ class Plan {
ExecutionSpace, typename OutViewType::memory_space>::accessible,
"Plan::Plan: execution_space cannot access data in OutViewType");

if (std::is_floating_point<in_value_type>::value &&
if (KokkosFFT::Impl::is_real_v<in_value_type> &&
m_direction != KokkosFFT::Direction::forward) {
throw std::runtime_error(
"Plan::Plan: real to complex transform is constrcuted with backward "
"direction.");
}

if (std::is_floating_point<out_value_type>::value &&
if (KokkosFFT::Impl::is_real_v<out_value_type> &&
m_direction != KokkosFFT::Direction::backward) {
throw std::runtime_error(
"Plan::Plan: complex to real transform is constrcuted with forward "
Expand Down
32 changes: 16 additions & 16 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,9 @@ void rfft(const ExecutionSpace& exec_space, const InViewType& in,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(std::is_floating_point<in_value_type>::value,
static_assert(KokkosFFT::Impl::is_real_v<in_value_type>,
"rfft: InViewType must be real");
static_assert(KokkosFFT::Impl::is_complex<out_value_type>::value,
static_assert(KokkosFFT::Impl::is_complex_v<out_value_type>,
"rfft: OutViewType must be complex");

fft(exec_space, in, out, norm, axis, n);
Expand Down Expand Up @@ -341,9 +341,9 @@ void irfft(const ExecutionSpace& exec_space, const InViewType& in,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(KokkosFFT::Impl::is_complex<in_value_type>::value,
static_assert(KokkosFFT::Impl::is_complex_v<in_value_type>,
"irfft: InViewType must be complex");
static_assert(std::is_floating_point<out_value_type>::value,
static_assert(KokkosFFT::Impl::is_real_v<out_value_type>,
"irfft: OutViewType must be real");
ifft(exec_space, in, out, norm, axis, n);
}
Expand Down Expand Up @@ -391,9 +391,9 @@ void hfft(const ExecutionSpace& exec_space, const InViewType& in,
// type
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
static_assert(KokkosFFT::Impl::is_complex<in_value_type>::value,
static_assert(KokkosFFT::Impl::is_complex_v<in_value_type>,
"hfft: InViewType must be complex");
static_assert(std::is_floating_point<out_value_type>::value,
static_assert(KokkosFFT::Impl::is_real_v<out_value_type>,
"hfft: OutViewType must be real");
auto new_norm = KokkosFFT::Impl::swap_direction(norm);
// using ComplexViewType = typename
Expand Down Expand Up @@ -444,9 +444,9 @@ void ihfft(const ExecutionSpace& exec_space, const InViewType& in,

using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
static_assert(std::is_floating_point<in_value_type>::value,
static_assert(KokkosFFT::Impl::is_real_v<in_value_type>,
"ihfft: InViewType must be real");
static_assert(KokkosFFT::Impl::is_complex<out_value_type>::value,
static_assert(KokkosFFT::Impl::is_complex_v<out_value_type>,
"ihfft: OutViewType must be complex");

auto new_norm = KokkosFFT::Impl::swap_direction(norm);
Expand Down Expand Up @@ -585,9 +585,9 @@ void rfft2(const ExecutionSpace& exec_space, const InViewType& in,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(std::is_floating_point<in_value_type>::value,
static_assert(KokkosFFT::Impl::is_real_v<in_value_type>,
"rfft2: InViewType must be real");
static_assert(KokkosFFT::Impl::is_complex<out_value_type>::value,
static_assert(KokkosFFT::Impl::is_complex_v<out_value_type>,
"rfft2: OutViewType must be complex");

fft2(exec_space, in, out, norm, axes, s);
Expand Down Expand Up @@ -635,9 +635,9 @@ void irfft2(const ExecutionSpace& exec_space, const InViewType& in,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(KokkosFFT::Impl::is_complex<in_value_type>::value,
static_assert(KokkosFFT::Impl::is_complex_v<in_value_type>,
"irfft2: InViewType must be complex");
static_assert(std::is_floating_point<out_value_type>::value,
static_assert(KokkosFFT::Impl::is_real_v<out_value_type>,
"irfft2: OutViewType must be real");

ifft2(exec_space, in, out, norm, axes, s);
Expand Down Expand Up @@ -775,9 +775,9 @@ void rfftn(const ExecutionSpace& exec_space, const InViewType& in,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(std::is_floating_point<in_value_type>::value,
static_assert(KokkosFFT::Impl::is_real_v<in_value_type>,
"rfftn: InViewType must be real");
static_assert(KokkosFFT::Impl::is_complex<out_value_type>::value,
static_assert(KokkosFFT::Impl::is_complex_v<out_value_type>,
"rfftn: OutViewType must be complex");

fftn(exec_space, in, out, axes, norm, s);
Expand Down Expand Up @@ -826,9 +826,9 @@ void irfftn(const ExecutionSpace& exec_space, const InViewType& in,
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;

static_assert(KokkosFFT::Impl::is_complex<in_value_type>::value,
static_assert(KokkosFFT::Impl::is_complex_v<in_value_type>,
"irfftn: InViewType must be complex");
static_assert(std::is_floating_point<out_value_type>::value,
static_assert(KokkosFFT::Impl::is_real_v<out_value_type>,
"irfftn: OutViewType must be real");

ifftn(exec_space, in, out, axes, norm, s);
Expand Down

0 comments on commit aff3416

Please sign in to comment.