Skip to content

Commit

Permalink
cuda: try simplifying cufinufft signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
janden committed Dec 18, 2023
1 parent 78d03ab commit 4491762
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
16 changes: 13 additions & 3 deletions include/cufinufft.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
// Defines the C++/C user interface to CUFINUFFT library.
#include <cufft.h>

#include <stdint.h>

#ifdef __cplusplus
#define _USE_MATH_DEFINES
#include <complex>
#define CUFINUFFT_CPX_FLT std::complex<float>
#define CUFINUFFT_CPX_DBL std::complex<double>
#else
#include <complex.h>
#define CUFINUFFT_CPX_FLT float complex
#define CUFINUFFT_CPX_DBL double complex
#endif

#include <cufinufft_opts.h>
#include <finufft_errors.h>

Expand All @@ -24,8 +34,8 @@ int cufinufft_setpts(cufinufft_plan d_plan, int M, double *h_kx, double *h_ky, d
int cufinufftf_setpts(cufinufftf_plan d_plan, int M, float *h_kx, float *h_ky, float *h_kz, int N, float *h_s,
float *h_t, float *h_u);

int cufinufft_execute(cufinufft_plan d_plan, cuDoubleComplex *h_c, cuDoubleComplex *h_fk);
int cufinufftf_execute(cufinufftf_plan d_plan, cuFloatComplex *h_c, cuFloatComplex *h_fk);
int cufinufft_execute(cufinufft_plan d_plan, CUFINUFFT_CPX_DBL *h_c, CUFINUFFT_CPX_DBL *h_fk);
int cufinufftf_execute(cufinufftf_plan d_plan, CUFINUFFT_CPX_FLT*h_c, CUFINUFFT_CPX_FLT*h_fk);

int cufinufft_destroy(cufinufft_plan d_plan);
int cufinufftf_destroy(cufinufftf_plan d_plan);
Expand Down
10 changes: 6 additions & 4 deletions src/cuda/cufinufft.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <iostream>
#include <limits>

#include <cuComplex.h>

#include <cufinufft.h>
#include <cufinufft/impl.h>

Expand Down Expand Up @@ -63,12 +65,12 @@ int cufinufft_setpts(cufinufft_plan d_plan, int M, double *d_kx, double *d_ky, d
return cufinufft_setpts_impl(M, d_kx, d_ky, d_kz, N, d_s, d_t, d_u, (cufinufft_plan_t<double> *)d_plan);
}

int cufinufftf_execute(cufinufftf_plan d_plan, cuFloatComplex *d_c, cuFloatComplex *d_fk) {
return cufinufft_execute_impl<float>(d_c, d_fk, (cufinufft_plan_t<float> *)d_plan);
int cufinufftf_execute(cufinufftf_plan d_plan, CUFINUFFT_CPX_FLT *d_c, CUFINUFFT_CPX_FLT *d_fk) {
return cufinufft_execute_impl<float>((cuFloatComplex *) d_c, (cuFloatComplex *) d_fk, (cufinufft_plan_t<float> *)d_plan);
}

int cufinufft_execute(cufinufft_plan d_plan, cuDoubleComplex *d_c, cuda_complex<double> *d_fk) {
return cufinufft_execute_impl<double>(d_c, d_fk, (cufinufft_plan_t<double> *)d_plan);
int cufinufft_execute(cufinufft_plan d_plan, CUFINUFFT_CPX_DBL *d_c, CUFINUFFT_CPX_DBL *d_fk) {
return cufinufft_execute_impl<double>((cuDoubleComplex *) d_c, (cuDoubleComplex *) d_fk, (cufinufft_plan_t<double> *)d_plan);
}

int cufinufftf_destroy(cufinufftf_plan d_plan) {
Expand Down

0 comments on commit 4491762

Please sign in to comment.