diff --git a/CHANGELOG.md b/CHANGELOG.md index 4b4a1c32..dc005d28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ Full documentation for hipSOLVER is available at the [hipSOLVER Documentation](h - Added functions - auxiliary - hipsolverSetDeterministicMode, hipsolverGetDeterministicMode +- Added compatibility-only functions + - geqrf + - hipsolverDnXgeqrf_bufferSize + - hipsolverDnXgeqrf ### Optimized ### Changed diff --git a/clients/gtest/geqrf_gtest.cpp b/clients/gtest/geqrf_gtest.cpp index d738d049..0de1e4be 100644 --- a/clients/gtest/geqrf_gtest.cpp +++ b/clients/gtest/geqrf_gtest.cpp @@ -83,7 +83,7 @@ Arguments geqrf_setup_arguments(geqrf_tuple tup) return arg; } -template +template class GEQRF_BASE : public ::TestWithParam { protected: @@ -98,18 +98,26 @@ class GEQRF_BASE : public ::TestWithParam Arguments arg = geqrf_setup_arguments(GetParam()); if(arg.peek("m") == -1 && arg.peek("n") == -1) - testing_geqrf_bad_arg(); + testing_geqrf_bad_arg(); arg.batch_count = 1; - testing_geqrf(arg); + testing_geqrf(arg); } }; -class GEQRF : public GEQRF_BASE +class GEQRF : public GEQRF_BASE { }; -class GEQRF_FORTRAN : public GEQRF_BASE +class GEQRF_FORTRAN : public GEQRF_BASE +{ +}; + +class GEQRF_COMPAT : public GEQRF_BASE +{ +}; + +class GEQRF_COMPAT_64 : public GEQRF_BASE { }; @@ -155,6 +163,46 @@ TEST_P(GEQRF_FORTRAN, __double_complex) run_tests(); } +TEST_P(GEQRF_COMPAT, __float) +{ + run_tests(); +} + +TEST_P(GEQRF_COMPAT, __double) +{ + run_tests(); +} + +TEST_P(GEQRF_COMPAT, __float_complex) +{ + run_tests(); +} + +TEST_P(GEQRF_COMPAT, __double_complex) +{ + run_tests(); +} + +TEST_P(GEQRF_COMPAT_64, __float) +{ + run_tests(); +} + +TEST_P(GEQRF_COMPAT_64, __double) +{ + run_tests(); +} + +TEST_P(GEQRF_COMPAT_64, __float_complex) +{ + run_tests(); +} + +TEST_P(GEQRF_COMPAT_64, __double_complex) +{ + run_tests(); +} + // INSTANTIATE_TEST_SUITE_P(daily_lapack, // GEQRF, // Combine(ValuesIn(large_matrix_size_range), ValuesIn(large_n_size_range))); @@ 
-170,3 +218,19 @@ INSTANTIATE_TEST_SUITE_P(checkin_lapack, INSTANTIATE_TEST_SUITE_P(checkin_lapack, GEQRF_FORTRAN, Combine(ValuesIn(matrix_size_range), ValuesIn(n_size_range))); + +// INSTANTIATE_TEST_SUITE_P(daily_lapack, +// GEQRF_COMPAT, +// Combine(ValuesIn(large_matrix_size_range), ValuesIn(large_n_size_range))); + +INSTANTIATE_TEST_SUITE_P(checkin_lapack, + GEQRF_COMPAT, + Combine(ValuesIn(matrix_size_range), ValuesIn(n_size_range))); + +// INSTANTIATE_TEST_SUITE_P(daily_lapack, +// GEQRF_COMPAT_64, +// Combine(ValuesIn(large_matrix_size_range), ValuesIn(large_n_size_range))); + +INSTANTIATE_TEST_SUITE_P(checkin_lapack, + GEQRF_COMPAT_64, + Combine(ValuesIn(matrix_size_range), ValuesIn(n_size_range))); diff --git a/clients/include/hipsolver.hpp b/clients/include/hipsolver.hpp index 3bfffb17..934cee5f 100644 --- a/clients/include/hipsolver.hpp +++ b/clients/include/hipsolver.hpp @@ -2031,130 +2031,320 @@ inline hipsolverStatus_t hipsolver_gels(testAPI_t API, /******************** GEQRF ********************/ // normal and strided_batched -inline hipsolverStatus_t hipsolver_geqrf_bufferSize( - testAPI_t API, hipsolverHandle_t handle, int m, int n, float* A, int lda, int* lwork) +inline hipsolverStatus_t hipsolver_geqrf_bufferSize(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int m, + int n, + float* A, + int lda, + float* tau, + int* lworkOnDevice, + int* lworkOnHost) { + *lworkOnHost = 0; switch(API) { case API_NORMAL: - return hipsolverSgeqrf_bufferSize(handle, m, n, A, lda, lwork); + return hipsolverSgeqrf_bufferSize(handle, m, n, A, lda, lworkOnDevice); case API_FORTRAN: - return hipsolverSgeqrf_bufferSizeFortran(handle, m, n, A, lda, lwork); + return hipsolverSgeqrf_bufferSizeFortran(handle, m, n, A, lda, lworkOnDevice); + case API_COMPAT: + return hipsolverDnSgeqrf_bufferSize(handle, m, n, A, lda, lworkOnDevice); default: + *lworkOnDevice = 0; return HIPSOLVER_STATUS_NOT_SUPPORTED; } } -inline hipsolverStatus_t 
hipsolver_geqrf_bufferSize( - testAPI_t API, hipsolverHandle_t handle, int m, int n, double* A, int lda, int* lwork) +inline hipsolverStatus_t hipsolver_geqrf_bufferSize(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int m, + int n, + double* A, + int lda, + double* tau, + int* lworkOnDevice, + int* lworkOnHost) { + *lworkOnHost = 0; switch(API) { case API_NORMAL: - return hipsolverDgeqrf_bufferSize(handle, m, n, A, lda, lwork); + return hipsolverDgeqrf_bufferSize(handle, m, n, A, lda, lworkOnDevice); case API_FORTRAN: - return hipsolverDgeqrf_bufferSizeFortran(handle, m, n, A, lda, lwork); + return hipsolverDgeqrf_bufferSizeFortran(handle, m, n, A, lda, lworkOnDevice); + case API_COMPAT: + return hipsolverDnDgeqrf_bufferSize(handle, m, n, A, lda, lworkOnDevice); default: + *lworkOnDevice = 0; return HIPSOLVER_STATUS_NOT_SUPPORTED; } } -inline hipsolverStatus_t hipsolver_geqrf_bufferSize( - testAPI_t API, hipsolverHandle_t handle, int m, int n, hipsolverComplex* A, int lda, int* lwork) +inline hipsolverStatus_t hipsolver_geqrf_bufferSize(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int m, + int n, + hipsolverComplex* A, + int lda, + hipsolverComplex* tau, + int* lworkOnDevice, + int* lworkOnHost) { + *lworkOnHost = 0; switch(API) { case API_NORMAL: - return hipsolverCgeqrf_bufferSize(handle, m, n, (hipFloatComplex*)A, lda, lwork); + return hipsolverCgeqrf_bufferSize(handle, m, n, (hipFloatComplex*)A, lda, lworkOnDevice); case API_FORTRAN: - return hipsolverCgeqrf_bufferSizeFortran(handle, m, n, (hipFloatComplex*)A, lda, lwork); + return hipsolverCgeqrf_bufferSizeFortran( + handle, m, n, (hipFloatComplex*)A, lda, lworkOnDevice); + case API_COMPAT: + return hipsolverDnCgeqrf_bufferSize(handle, m, n, (hipFloatComplex*)A, lda, lworkOnDevice); default: + *lworkOnDevice = 0; return HIPSOLVER_STATUS_NOT_SUPPORTED; } } inline hipsolverStatus_t hipsolver_geqrf_bufferSize(testAPI_t API, hipsolverHandle_t handle, + 
hipsolverDnParams_t params, int m, int n, hipsolverDoubleComplex* A, int lda, - int* lwork) + hipsolverDoubleComplex* tau, + int* lworkOnDevice, + int* lworkOnHost) { + *lworkOnHost = 0; switch(API) { case API_NORMAL: - return hipsolverZgeqrf_bufferSize(handle, m, n, (hipDoubleComplex*)A, lda, lwork); + return hipsolverZgeqrf_bufferSize(handle, m, n, (hipDoubleComplex*)A, lda, lworkOnDevice); case API_FORTRAN: - return hipsolverZgeqrf_bufferSizeFortran(handle, m, n, (hipDoubleComplex*)A, lda, lwork); + return hipsolverZgeqrf_bufferSizeFortran( + handle, m, n, (hipDoubleComplex*)A, lda, lworkOnDevice); + case API_COMPAT: + return hipsolverDnZgeqrf_bufferSize(handle, m, n, (hipDoubleComplex*)A, lda, lworkOnDevice); default: + *lworkOnDevice = 0; return HIPSOLVER_STATUS_NOT_SUPPORTED; } } -inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, - hipsolverHandle_t handle, - int m, - int n, - float* A, - int lda, - int stA, - float* tau, - int stT, - float* work, - int lwork, - int* info, - int bc) +inline hipsolverStatus_t hipsolver_geqrf_bufferSize(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int64_t m, + int64_t n, + float* A, + int64_t lda, + float* tau, + size_t* lworkOnDevice, + size_t* lworkOnHost) +{ + switch(API) + { + case API_COMPAT: + return hipsolverDnXgeqrf_bufferSize(handle, + params, + m, + n, + HIP_R_32F, + A, + lda, + HIP_R_32F, + tau, + HIP_R_32F, + lworkOnDevice, + lworkOnHost); + default: + *lworkOnDevice = 0; + *lworkOnHost = 0; + return HIPSOLVER_STATUS_NOT_SUPPORTED; + } +} + +inline hipsolverStatus_t hipsolver_geqrf_bufferSize(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int64_t m, + int64_t n, + double* A, + int64_t lda, + double* tau, + size_t* lworkOnDevice, + size_t* lworkOnHost) +{ + switch(API) + { + case API_COMPAT: + return hipsolverDnXgeqrf_bufferSize(handle, + params, + m, + n, + HIP_R_64F, + A, + lda, + HIP_R_64F, + tau, + HIP_R_64F, + lworkOnDevice, + lworkOnHost); + 
default: + *lworkOnDevice = 0; + *lworkOnHost = 0; + return HIPSOLVER_STATUS_NOT_SUPPORTED; + } +} + +inline hipsolverStatus_t hipsolver_geqrf_bufferSize(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int64_t m, + int64_t n, + hipsolverComplex* A, + int64_t lda, + hipsolverComplex* tau, + size_t* lworkOnDevice, + size_t* lworkOnHost) +{ + switch(API) + { + case API_COMPAT: + return hipsolverDnXgeqrf_bufferSize(handle, + params, + m, + n, + HIP_C_32F, + A, + lda, + HIP_C_32F, + tau, + HIP_C_32F, + lworkOnDevice, + lworkOnHost); + default: + *lworkOnDevice = 0; + *lworkOnHost = 0; + return HIPSOLVER_STATUS_NOT_SUPPORTED; + } +} + +inline hipsolverStatus_t hipsolver_geqrf_bufferSize(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int64_t m, + int64_t n, + hipsolverDoubleComplex* A, + int64_t lda, + hipsolverDoubleComplex* tau, + size_t* lworkOnDevice, + size_t* lworkOnHost) +{ + switch(API) + { + case API_COMPAT: + return hipsolverDnXgeqrf_bufferSize(handle, + params, + m, + n, + HIP_C_64F, + A, + lda, + HIP_C_64F, + tau, + HIP_C_64F, + lworkOnDevice, + lworkOnHost); + default: + *lworkOnDevice = 0; + *lworkOnHost = 0; + return HIPSOLVER_STATUS_NOT_SUPPORTED; + } +} + +inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int m, + int n, + float* A, + int lda, + int stA, + float* tau, + int stT, + float* workOnDevice, + int lworkOnDevice, + float* workOnHost, + int lworkOnHost, + int* info, + int bc) { switch(API) { case API_NORMAL: - return hipsolverSgeqrf(handle, m, n, A, lda, tau, work, lwork, info); + return hipsolverSgeqrf(handle, m, n, A, lda, tau, workOnDevice, lworkOnDevice, info); case API_FORTRAN: - return hipsolverSgeqrfFortran(handle, m, n, A, lda, tau, work, lwork, info); + return hipsolverSgeqrfFortran(handle, m, n, A, lda, tau, workOnDevice, lworkOnDevice, info); + case API_COMPAT: + return hipsolverDnSgeqrf(handle, m, n, A, lda, tau, 
workOnDevice, lworkOnDevice, info); default: return HIPSOLVER_STATUS_NOT_SUPPORTED; } } -inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, - hipsolverHandle_t handle, - int m, - int n, - double* A, - int lda, - int stA, - double* tau, - int stT, - double* work, - int lwork, - int* info, - int bc) +inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int m, + int n, + double* A, + int lda, + int stA, + double* tau, + int stT, + double* workOnDevice, + int lworkOnDevice, + double* workOnHost, + int lworkOnHost, + int* info, + int bc) { switch(API) { case API_NORMAL: - return hipsolverDgeqrf(handle, m, n, A, lda, tau, work, lwork, info); + return hipsolverDgeqrf(handle, m, n, A, lda, tau, workOnDevice, lworkOnDevice, info); case API_FORTRAN: - return hipsolverDgeqrfFortran(handle, m, n, A, lda, tau, work, lwork, info); + return hipsolverDgeqrfFortran(handle, m, n, A, lda, tau, workOnDevice, lworkOnDevice, info); + case API_COMPAT: + return hipsolverDnDgeqrf(handle, m, n, A, lda, tau, workOnDevice, lworkOnDevice, info); default: return HIPSOLVER_STATUS_NOT_SUPPORTED; } } -inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, - hipsolverHandle_t handle, - int m, - int n, - hipsolverComplex* A, - int lda, - int stA, - hipsolverComplex* tau, - int stT, - hipsolverComplex* work, - int lwork, - int* info, - int bc) +inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int m, + int n, + hipsolverComplex* A, + int lda, + int stA, + hipsolverComplex* tau, + int stT, + hipsolverComplex* workOnDevice, + int lworkOnDevice, + hipsolverComplex* workOnHost, + int lworkOnHost, + int* info, + int bc) { switch(API) { @@ -2165,8 +2355,8 @@ inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, (hipFloatComplex*)A, lda, (hipFloatComplex*)tau, - (hipFloatComplex*)work, - lwork, + (hipFloatComplex*)workOnDevice, + lworkOnDevice, info); case API_FORTRAN: 
return hipsolverCgeqrfFortran(handle, @@ -2175,9 +2365,19 @@ inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, (hipFloatComplex*)A, lda, (hipFloatComplex*)tau, - (hipFloatComplex*)work, - lwork, + (hipFloatComplex*)workOnDevice, + lworkOnDevice, info); + case API_COMPAT: + return hipsolverDnCgeqrf(handle, + m, + n, + (hipFloatComplex*)A, + lda, + (hipFloatComplex*)tau, + (hipFloatComplex*)workOnDevice, + lworkOnDevice, + info); default: return HIPSOLVER_STATUS_NOT_SUPPORTED; } @@ -2185,6 +2385,7 @@ inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, hipsolverHandle_t handle, + hipsolverDnParams_t params, int m, int n, hipsolverDoubleComplex* A, @@ -2192,8 +2393,10 @@ inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, int stA, hipsolverDoubleComplex* tau, int stT, - hipsolverDoubleComplex* work, - int lwork, + hipsolverDoubleComplex* workOnDevice, + int lworkOnDevice, + hipsolverDoubleComplex* workOnHost, + int lworkOnHost, int* info, int bc) { @@ -2206,8 +2409,8 @@ inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, (hipDoubleComplex*)A, lda, (hipDoubleComplex*)tau, - (hipDoubleComplex*)work, - lwork, + (hipDoubleComplex*)workOnDevice, + lworkOnDevice, info); case API_FORTRAN: return hipsolverZgeqrfFortran(handle, @@ -2216,9 +2419,179 @@ inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, (hipDoubleComplex*)A, lda, (hipDoubleComplex*)tau, - (hipDoubleComplex*)work, - lwork, + (hipDoubleComplex*)workOnDevice, + lworkOnDevice, info); + case API_COMPAT: + return hipsolverDnZgeqrf(handle, + m, + n, + (hipDoubleComplex*)A, + lda, + (hipDoubleComplex*)tau, + (hipDoubleComplex*)workOnDevice, + lworkOnDevice, + info); + default: + return HIPSOLVER_STATUS_NOT_SUPPORTED; + } +} + +inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int64_t m, + int64_t n, + float* A, + int64_t lda, + int64_t stA, + float* tau, + int64_t stT, + float* 
workOnDevice, + size_t lworkOnDevice, + float* workOnHost, + size_t lworkOnHost, + int* info, + int bc) +{ + switch(API) + { + case API_COMPAT: + return hipsolverDnXgeqrf(handle, + params, + m, + n, + HIP_R_32F, + A, + lda, + HIP_R_32F, + tau, + HIP_R_32F, + workOnDevice, + lworkOnDevice, + workOnHost, + lworkOnHost, + info); + default: + return HIPSOLVER_STATUS_NOT_SUPPORTED; + } +} + +inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int64_t m, + int64_t n, + double* A, + int64_t lda, + int64_t stA, + double* tau, + int64_t stT, + double* workOnDevice, + size_t lworkOnDevice, + double* workOnHost, + size_t lworkOnHost, + int* info, + int bc) +{ + switch(API) + { + case API_COMPAT: + return hipsolverDnXgeqrf(handle, + params, + m, + n, + HIP_R_64F, + A, + lda, + HIP_R_64F, + tau, + HIP_R_64F, + workOnDevice, + lworkOnDevice, + workOnHost, + lworkOnHost, + info); + default: + return HIPSOLVER_STATUS_NOT_SUPPORTED; + } +} + +inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int64_t m, + int64_t n, + hipsolverComplex* A, + int64_t lda, + int64_t stA, + hipsolverComplex* tau, + int64_t stT, + hipsolverComplex* workOnDevice, + size_t lworkOnDevice, + hipsolverComplex* workOnHost, + size_t lworkOnHost, + int* info, + int bc) +{ + switch(API) + { + case API_COMPAT: + return hipsolverDnXgeqrf(handle, + params, + m, + n, + HIP_C_32F, + A, + lda, + HIP_C_32F, + tau, + HIP_C_32F, + workOnDevice, + lworkOnDevice, + workOnHost, + lworkOnHost, + info); + default: + return HIPSOLVER_STATUS_NOT_SUPPORTED; + } +} + +inline hipsolverStatus_t hipsolver_geqrf(testAPI_t API, + hipsolverHandle_t handle, + hipsolverDnParams_t params, + int64_t m, + int64_t n, + hipsolverDoubleComplex* A, + int64_t lda, + int64_t stA, + hipsolverDoubleComplex* tau, + int64_t stT, + hipsolverDoubleComplex* workOnDevice, + size_t lworkOnDevice, + hipsolverDoubleComplex* 
workOnHost, + size_t lworkOnHost, + int* info, + int bc) +{ + switch(API) + { + case API_COMPAT: + return hipsolverDnXgeqrf(handle, + params, + m, + n, + HIP_C_64F, + A, + lda, + HIP_C_64F, + tau, + HIP_C_64F, + workOnDevice, + lworkOnDevice, + workOnHost, + lworkOnHost, + info); default: return HIPSOLVER_STATUS_NOT_SUPPORTED; } diff --git a/clients/include/hipsolver_dispatcher.hpp b/clients/include/hipsolver_dispatcher.hpp index 5bacbf29..db50f6f6 100644 --- a/clients/include/hipsolver_dispatcher.hpp +++ b/clients/include/hipsolver_dispatcher.hpp @@ -78,14 +78,15 @@ class hipsolver_dispatcher static const func_map map = { {"gebrd", testing_gebrd}, {"gels", testing_gels}, - {"geqrf", testing_geqrf}, + {"geqrf", testing_geqrf}, + {"geqrf_64", testing_geqrf}, {"gesv", testing_gesv}, {"gesvd", testing_gesvd}, {"gesvda_strided_batched", testing_gesvda}, {"gesvdj", testing_gesvdj}, {"gesvdj_batched", testing_gesvdj}, {"getrf", testing_getrf}, - {"getrf_64", testing_getrf}, + {"getrf_64", testing_getrf}, {"getrs", testing_getrs}, {"getrs_64", testing_getrs}, {"potrf", testing_potrf}, diff --git a/clients/include/testing_geqrf.hpp b/clients/include/testing_geqrf.hpp index e21a5875..761511e3 100644 --- a/clients/include/testing_geqrf.hpp +++ b/clients/include/testing_geqrf.hpp @@ -25,53 +25,130 @@ #include "clientcommon.hpp" -template -void geqrf_checkBadArgs(const hipsolverHandle_t handle, - const int m, - const int n, - T dA, - const int lda, - const int stA, - U dIpiv, - const int stP, - U dWork, - const int lwork, - V dInfo, - const int bc) +template +void geqrf_checkBadArgs(const hipsolverHandle_t handle, + const hipsolverDnParams_t params, + const I m, + const I n, + Td dA, + const I lda, + const I stA, + Td dIpiv, + const I stP, + Td dWork, + const SIZE dlwork, + Th hWork, + const SIZE hlwork, + INTd dInfo, + const int bc) { // handle - EXPECT_ROCBLAS_STATUS( - hipsolver_geqrf(API, nullptr, m, n, dA, lda, stA, dIpiv, stP, dWork, lwork, dInfo, bc), - 
HIPSOLVER_STATUS_NOT_INITIALIZED); + EXPECT_ROCBLAS_STATUS(hipsolver_geqrf(API, + nullptr, + params, + m, + n, + dA, + lda, + stA, + dIpiv, + stP, + dWork, + dlwork, + hWork, + hlwork, + dInfo, + bc), + HIPSOLVER_STATUS_NOT_INITIALIZED); // values // N/A #if defined(__HIP_PLATFORM_HCC__) || defined(__HIP_PLATFORM_AMD__) // pointers - EXPECT_ROCBLAS_STATUS( - hipsolver_geqrf( - API, handle, m, n, (T) nullptr, lda, stA, dIpiv, stP, dWork, lwork, dInfo, bc), - HIPSOLVER_STATUS_INVALID_VALUE); - EXPECT_ROCBLAS_STATUS( - hipsolver_geqrf(API, handle, m, n, dA, lda, stA, (U) nullptr, stP, dWork, lwork, dInfo, bc), - HIPSOLVER_STATUS_INVALID_VALUE); - EXPECT_ROCBLAS_STATUS( - hipsolver_geqrf(API, handle, m, n, dA, lda, stA, dIpiv, stP, dWork, lwork, (V) nullptr, bc), - HIPSOLVER_STATUS_INVALID_VALUE); + if constexpr(!std::is_same::value) + EXPECT_ROCBLAS_STATUS(hipsolver_geqrf(API, + handle, + (hipsolverDnParams_t) nullptr, + m, + n, + dA, + lda, + stA, + dIpiv, + stP, + dWork, + dlwork, + hWork, + hlwork, + dInfo, + bc), + HIPSOLVER_STATUS_INVALID_VALUE); + EXPECT_ROCBLAS_STATUS(hipsolver_geqrf(API, + handle, + params, + m, + n, + (Td) nullptr, + lda, + stA, + dIpiv, + stP, + dWork, + dlwork, + hWork, + hlwork, + dInfo, + bc), + HIPSOLVER_STATUS_INVALID_VALUE); + EXPECT_ROCBLAS_STATUS(hipsolver_geqrf(API, + handle, + params, + m, + n, + dA, + lda, + stA, + (Td) nullptr, + stP, + dWork, + dlwork, + hWork, + hlwork, + dInfo, + bc), + HIPSOLVER_STATUS_INVALID_VALUE); + EXPECT_ROCBLAS_STATUS(hipsolver_geqrf(API, + handle, + params, + m, + n, + dA, + lda, + stA, + dIpiv, + stP, + dWork, + dlwork, + hWork, + hlwork, + (INTd) nullptr, + bc), + HIPSOLVER_STATUS_INVALID_VALUE); #endif } -template +template void testing_geqrf_bad_arg() { // safe arguments hipsolver_local_handle handle; - int m = 1; - int n = 1; - int lda = 1; - int stA = 1; - int stP = 1; + hipsolver_local_params params; + I m = 1; + I n = 1; + I lda = 1; + I stA = 1; + I stP = 1; int bc = 1; if(BATCHED) @@ -84,25 
+161,30 @@ void testing_geqrf_bad_arg() // CHECK_HIP_ERROR(dIpiv.memcheck()); // CHECK_HIP_ERROR(dInfo.memcheck()); - // int size_W; - // hipsolver_geqrf_bufferSize(API, handle, m, n, dA.data(), lda, &size_W); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); - // if(size_W) + // SIZE size_dW, size_hW; + // hipsolver_geqrf_bufferSize( + // API, handle, params, m, n, dA.data(), lda, dIpiv.data(), &size_dW, &size_hW); + // host_strided_batch_vector hWork(size_hW, 1, size_hW, 1); + // device_strided_batch_vector dWork(size_dW, 1, size_dW, 1); + // if(size_dW) // CHECK_HIP_ERROR(dWork.memcheck()); // // check bad arguments // geqrf_checkBadArgs(handle, - // m, - // n, - // dA.data(), - // lda, - // stA, - // dIpiv.data(), - // stP, - // dWork.data(), - // size_W, - // dInfo.data(), - // bc); + // params, + // m, + // n, + // dA.data(), + // lda, + // stA, + // dIpiv.data(), + // stP, + // dWork.data(), + // size_dW, + // hWork.data(), + // size_hW, + // dInfo.data(), + // bc); } else { @@ -114,14 +196,17 @@ void testing_geqrf_bad_arg() CHECK_HIP_ERROR(dIpiv.memcheck()); CHECK_HIP_ERROR(dInfo.memcheck()); - int size_W; - hipsolver_geqrf_bufferSize(API, handle, m, n, dA.data(), lda, &size_W); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); - if(size_W) + SIZE size_dW, size_hW; + hipsolver_geqrf_bufferSize( + API, handle, params, m, n, dA.data(), lda, dIpiv.data(), &size_dW, &size_hW); + host_strided_batch_vector hWork(size_hW, 1, size_hW, 1); + device_strided_batch_vector dWork(size_dW, 1, size_dW, 1); + if(size_dW) CHECK_HIP_ERROR(dWork.memcheck()); // check bad arguments geqrf_checkBadArgs(handle, + params, m, n, dA.data(), @@ -130,21 +215,30 @@ void testing_geqrf_bad_arg() dIpiv.data(), stP, dWork.data(), - size_W, + size_dW, + hWork.data(), + size_hW, dInfo.data(), bc); } } -template +template void geqrf_initData(const hipsolverHandle_t handle, - const int m, - const int n, + const I m, + const I n, Td& dA, - const int lda, - const int stA, + 
const I lda, + const I stA, Ud& dIpiv, - const int stP, + const I stP, const int bc, Th& hA, Uh& hIpiv) @@ -156,9 +250,9 @@ void geqrf_initData(const hipsolverHandle_t handle, // scale A to avoid singularities for(int b = 0; b < bc; ++b) { - for(int i = 0; i < m; i++) + for(I i = 0; i < m; i++) { - for(int j = 0; j < n; j++) + for(I j = 0; j < n; j++) { if(i == j) hA[b][i + j * lda] += 400; @@ -178,30 +272,35 @@ void geqrf_initData(const hipsolverHandle_t handle, template -void geqrf_getError(const hipsolverHandle_t handle, - const int m, - const int n, - Td& dA, - const int lda, - const int stA, - Ud& dIpiv, - const int stP, - Ud& dWork, - const int lwork, - Vd& dInfo, - const int bc, - Th& hA, - Th& hARes, - Uh& hIpiv, - Vh& hInfo, - Vh& hInfoRes, - double* max_err) + typename INTh> +void geqrf_getError(const hipsolverHandle_t handle, + const hipsolverDnParams_t params, + const I m, + const I n, + Td& dA, + const I lda, + const I stA, + Ud& dIpiv, + const I stP, + Ud& dWork, + const SIZE dlwork, + Uh& hWork, + const SIZE hlwork, + INTd& dInfo, + const int bc, + Th& hA, + Th& hARes, + Uh& hIpiv, + INTh& hInfo, + INTh& hInfoRes, + double* max_err) { std::vector hW(n); @@ -212,6 +311,7 @@ void geqrf_getError(const hipsolverHandle_t handle, // GPU lapack CHECK_ROCBLAS_ERROR(hipsolver_geqrf(API, handle, + params, m, n, dA.data(), @@ -220,7 +320,9 @@ void geqrf_getError(const hipsolverHandle_t handle, dIpiv.data(), stP, dWork.data(), - lwork, + dlwork, + hWork.data(), + hlwork, dInfo.data(), bc)); CHECK_HIP_ERROR(hARes.transfer_from(dA)); @@ -255,31 +357,36 @@ void geqrf_getError(const hipsolverHandle_t handle, template -void geqrf_getPerfData(const hipsolverHandle_t handle, - const int m, - const int n, - Td& dA, - const int lda, - const int stA, - Ud& dIpiv, - const int stP, - Ud& dWork, - const int lwork, - Vd& dInfo, - const int bc, - Th& hA, - Uh& hIpiv, - Vh& hInfo, - double* gpu_time_used, - double* cpu_time_used, - const int hot_calls, - const bool perf) + 
typename INTh> +void geqrf_getPerfData(const hipsolverHandle_t handle, + const hipsolverDnParams_t params, + const I m, + const I n, + Td& dA, + const I lda, + const I stA, + Ud& dIpiv, + const I stP, + Ud& dWork, + const SIZE dlwork, + Uh& hWork, + const SIZE hlwork, + INTd& dInfo, + const int bc, + Th& hA, + Uh& hIpiv, + INTh& hInfo, + double* gpu_time_used, + double* cpu_time_used, + const int hot_calls, + const bool perf) { std::vector hW(n); @@ -303,6 +410,7 @@ void geqrf_getPerfData(const hipsolverHandle_t handle, CHECK_ROCBLAS_ERROR(hipsolver_geqrf(API, handle, + params, m, n, dA.data(), @@ -311,7 +419,9 @@ void geqrf_getPerfData(const hipsolverHandle_t handle, dIpiv.data(), stP, dWork.data(), - lwork, + dlwork, + hWork.data(), + hlwork, dInfo.data(), bc)); } @@ -328,6 +438,7 @@ void geqrf_getPerfData(const hipsolverHandle_t handle, start = get_time_us_sync(stream); hipsolver_geqrf(API, handle, + params, m, n, dA.data(), @@ -336,7 +447,9 @@ void geqrf_getPerfData(const hipsolverHandle_t handle, dIpiv.data(), stP, dWork.data(), - lwork, + dlwork, + hWork.data(), + hlwork, dInfo.data(), bc); *gpu_time_used += get_time_us_sync(stream) - start; @@ -344,21 +457,22 @@ void geqrf_getPerfData(const hipsolverHandle_t handle, *gpu_time_used /= hot_calls; } -template +template void testing_geqrf(Arguments& argus) { // get arguments hipsolver_local_handle handle; - int m = argus.get("m"); - int n = argus.get("n", m); - int lda = argus.get("lda", m); - int stA = argus.get("strideA", lda * n); - int stP = argus.get("strideP", min(m, n)); + hipsolver_local_params params; + I m = argus.get("m"); + I n = argus.get("n", m); + I lda = argus.get("lda", m); + I stA = argus.get("strideA", lda * n); + I stP = argus.get("strideP", min(m, n)); int bc = argus.batch_count; int hot_calls = argus.iters; - int stARes = (argus.unit_check || argus.norm_check) ? stA : 0; + I stARes = (argus.unit_check || argus.norm_check) ? 
stA : 0; // check non-supported values // N/A @@ -378,6 +492,7 @@ void testing_geqrf(Arguments& argus) { // EXPECT_ROCBLAS_STATUS(hipsolver_geqrf(API, // handle, + // params, // m, // n, // (T* const*)nullptr, @@ -386,7 +501,9 @@ void testing_geqrf(Arguments& argus) // (T*)nullptr, // stP, // (T*)nullptr, - // 0, + // (SIZE)0, + // (T*)nullptr, + // (SIZE)0, // (int*)nullptr, // bc), // HIPSOLVER_STATUS_INVALID_VALUE); @@ -395,6 +512,7 @@ void testing_geqrf(Arguments& argus) { EXPECT_ROCBLAS_STATUS(hipsolver_geqrf(API, handle, + params, m, n, (T*)nullptr, @@ -403,7 +521,9 @@ void testing_geqrf(Arguments& argus) (T*)nullptr, stP, (T*)nullptr, - 0, + (SIZE)0, + (T*)nullptr, + (SIZE)0, (int*)nullptr, bc), HIPSOLVER_STATUS_INVALID_VALUE); @@ -416,12 +536,13 @@ void testing_geqrf(Arguments& argus) } // memory size query is necessary - int size_W; - hipsolver_geqrf_bufferSize(API, handle, m, n, (T*)nullptr, lda, &size_W); + SIZE size_dW, size_hW; + hipsolver_geqrf_bufferSize( + API, handle, params, m, n, (T*)nullptr, lda, (T*)nullptr, &size_dW, &size_hW); if(argus.mem_query) { - rocsolver_bench_inform(inform_mem_query, size_W); + rocsolver_bench_inform(inform_mem_query, size_dW); return; } @@ -433,60 +554,67 @@ void testing_geqrf(Arguments& argus) // host_strided_batch_vector hIpiv(size_P, 1, stP, bc); // host_strided_batch_vector hInfo(1, 1, 1, bc); // host_strided_batch_vector hInfoRes(1, 1, 1, bc); + // host_strided_batch_vector hWork(size_hW, 1, size_hW, 1); // size_hW accounts for bc // device_batch_vector dA(size_A, 1, bc); // device_strided_batch_vector dIpiv(size_P, 1, stP, bc); // device_strided_batch_vector dInfo(1, 1, 1, bc); - // device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + // device_strided_batch_vector dWork(size_dW, 1, size_dW, 1); // size_dW accounts for bc // if(size_A) // CHECK_HIP_ERROR(dA.memcheck()); // if(size_P) // CHECK_HIP_ERROR(dIpiv.memcheck()); // CHECK_HIP_ERROR(dInfo.memcheck()); - // if(size_W) + // 
if(size_dW) // CHECK_HIP_ERROR(dWork.memcheck()); // // check computations // if(argus.unit_check || argus.norm_check) // geqrf_getError(handle, - // m, - // n, - // dA, - // lda, - // stA, - // dIpiv, - // stP, - // dWork, - // size_W, - // dInfo, - // bc, - // hA, - // hARes, - // hIpiv, - // hInfo, - // hInfoRes, - // &max_error); + // params, + // m, + // n, + // dA, + // lda, + // stA, + // dIpiv, + // stP, + // dWork, + // size_dW, + // hWork, + // size_hW, + // dInfo, + // bc, + // hA, + // hARes, + // hIpiv, + // hInfo, + // hInfoRes, + // &max_error); // // collect performance data // if(argus.timing) // geqrf_getPerfData(handle, - // m, - // n, - // dA, - // lda, - // stA, - // dIpiv, - // stP, - // dWork, - // size_W, - // dInfo, - // bc, - // hA, - // hIpiv, - // hInfo, - // &gpu_time_used, - // &cpu_time_used, - // hot_calls, - // argus.perf); + // params, + // m, + // n, + // dA, + // lda, + // stA, + // dIpiv, + // stP, + // dWork, + // size_dW, + // hWork, + // size_hW, + // dInfo, + // bc, + // hA, + // hIpiv, + // hInfo, + // &gpu_time_used, + // &cpu_time_used, + // hot_calls, + // argus.perf); } else @@ -497,21 +625,23 @@ void testing_geqrf(Arguments& argus) host_strided_batch_vector hIpiv(size_P, 1, stP, bc); host_strided_batch_vector hInfo(1, 1, 1, bc); host_strided_batch_vector hInfoRes(1, 1, 1, bc); + host_strided_batch_vector hWork(size_hW, 1, size_hW, 1); // size_hW accounts for bc device_strided_batch_vector dA(size_A, 1, stA, bc); device_strided_batch_vector dIpiv(size_P, 1, stP, bc); device_strided_batch_vector dInfo(1, 1, 1, bc); - device_strided_batch_vector dWork(size_W, 1, size_W, 1); // size_W accounts for bc + device_strided_batch_vector dWork(size_dW, 1, size_dW, 1); // size_dW accounts for bc if(size_A) CHECK_HIP_ERROR(dA.memcheck()); if(size_P) CHECK_HIP_ERROR(dIpiv.memcheck()); CHECK_HIP_ERROR(dInfo.memcheck()); - if(size_W) + if(size_dW) CHECK_HIP_ERROR(dWork.memcheck()); // check computations if(argus.unit_check || 
argus.norm_check) geqrf_getError(handle, + params, m, n, dA, @@ -520,7 +650,9 @@ void testing_geqrf(Arguments& argus) dIpiv, stP, dWork, - size_W, + size_dW, + hWork, + size_hW, dInfo, bc, hA, @@ -533,6 +665,7 @@ void testing_geqrf(Arguments& argus) // collect performance data if(argus.timing) geqrf_getPerfData(handle, + params, m, n, dA, @@ -541,7 +674,9 @@ void testing_geqrf(Arguments& argus) dIpiv, stP, dWork, - size_W, + size_dW, + hWork, + size_hW, dInfo, bc, hA, diff --git a/docs/reference/dense-api/lapack.rst b/docs/reference/dense-api/lapack.rst index 795361d9..44592058 100644 --- a/docs/reference/dense-api/lapack.rst +++ b/docs/reference/dense-api/lapack.rst @@ -133,6 +133,8 @@ Orthogonal factorizations hipsolverDngeqrf_bufferSize() --------------------------------------------------- +.. doxygenfunction:: hipsolverDnXgeqrf_bufferSize + :outline: .. doxygenfunction:: hipsolverDnZgeqrf_bufferSize :outline: .. doxygenfunction:: hipsolverDnCgeqrf_bufferSize @@ -145,6 +147,8 @@ hipsolverDngeqrf_bufferSize() hipsolverDngeqrf() --------------------------------------------------- +.. doxygenfunction:: hipsolverDnXgeqrf + :outline: .. doxygenfunction:: hipsolverDnZgeqrf :outline: .. 
doxygenfunction:: hipsolverDnCgeqrf
diff --git a/library/include/internal/hipsolver-dense64.h b/library/include/internal/hipsolver-dense64.h
index 1fb26a38..b2792447 100644
--- a/library/include/internal/hipsolver-dense64.h
+++ b/library/include/internal/hipsolver-dense64.h
@@ -33,6 +33,36 @@ HIPSOLVER_EXPORT hipsolverStatus_t hipsolverDnSetAdvOptions(hipsolverDnParams_t
                                                             hipsolverDnFunction_t func,
                                                             hipsolverAlgMode_t alg);
 
+// geqrf
+HIPSOLVER_EXPORT hipsolverStatus_t hipsolverDnXgeqrf_bufferSize(hipsolverDnHandle_t handle,
+                                                                hipsolverDnParams_t params,
+                                                                int64_t m,
+                                                                int64_t n,
+                                                                hipDataType dataTypeA,
+                                                                const void* A,
+                                                                int64_t lda,
+                                                                hipDataType dataTypeTau,
+                                                                const void* tau,
+                                                                hipDataType computeType,
+                                                                size_t* lworkOnDevice,
+                                                                size_t* lworkOnHost);
+
+HIPSOLVER_EXPORT hipsolverStatus_t hipsolverDnXgeqrf(hipsolverDnHandle_t handle,
+                                                     hipsolverDnParams_t params,
+                                                     int64_t m,
+                                                     int64_t n,
+                                                     hipDataType dataTypeA,
+                                                     void* A,
+                                                     int64_t lda,
+                                                     hipDataType dataTypeTau,
+                                                     void* tau,
+                                                     hipDataType computeType,
+                                                     void* workOnDevice,
+                                                     size_t lworkOnDevice,
+                                                     void* workOnHost,
+                                                     size_t lworkOnHost,
+                                                     int* devInfo);
+
 // getrf
 HIPSOLVER_EXPORT hipsolverStatus_t hipsolverDnXgetrf_bufferSize(hipsolverDnHandle_t handle,
                                                                 hipsolverDnParams_t params,
diff --git a/library/src/amd_detail/hipsolver_dense64.cpp b/library/src/amd_detail/hipsolver_dense64.cpp
index 9490c51c..3d6a4678 100644
--- a/library/src/amd_detail/hipsolver_dense64.cpp
+++ b/library/src/amd_detail/hipsolver_dense64.cpp
@@ -162,6 +162,164 @@ catch(...)
     return hipsolver::exception2hip_status();
 }
 
+/******************** GEQRF ********************/
+// Queries the device workspace size rocSOLVER's 64-bit geqrf needs for an m-by-n problem.
+// On return *lworkOnDevice holds that size and *lworkOnHost is 0. A and tau are not
+// read; the query passes nullptr in their place.
+// Returns NOT_INITIALIZED for a null handle, INVALID_VALUE for null params or null size
+// outputs, and INVALID_ENUM unless dataTypeA, dataTypeTau and computeType are all the
+// same one of HIP_R_32F, HIP_R_64F, HIP_C_32F, HIP_C_64F.
+hipsolverStatus_t hipsolverDnXgeqrf_bufferSize(hipsolverDnHandle_t handle,
+                                               hipsolverDnParams_t params,
+                                               int64_t m,
+                                               int64_t n,
+                                               hipDataType dataTypeA,
+                                               const void* A,
+                                               int64_t lda,
+                                               hipDataType dataTypeTau,
+                                               const void* tau,
+                                               hipDataType computeType,
+                                               size_t* lworkOnDevice,
+                                               size_t* lworkOnHost)
+try
+{
+    if(!handle)
+        return HIPSOLVER_STATUS_NOT_INITIALIZED;
+    if(!params)
+        return HIPSOLVER_STATUS_INVALID_VALUE;
+    if(!lworkOnDevice || !lworkOnHost)
+        return HIPSOLVER_STATUS_INVALID_VALUE;
+
+    *lworkOnDevice = 0;
+    *lworkOnHost = 0;
+
+    // The handle stays in size-query mode from here until the matching stop call, so
+    // every path below must reach rocblas_stop_device_memory_size_query.
+    rocblas_start_device_memory_size_query((rocblas_handle)handle);
+    hipsolverStatus_t status;
+    if(dataTypeA == HIP_R_32F && dataTypeTau == HIP_R_32F && computeType == HIP_R_32F)
+    {
+        status = hipsolver::rocblas2hip_status(
+            rocsolver_sgeqrf_64((rocblas_handle)handle, m, n, nullptr, lda, nullptr));
+    }
+    else if(dataTypeA == HIP_R_64F && dataTypeTau == HIP_R_64F && computeType == HIP_R_64F)
+    {
+        status = hipsolver::rocblas2hip_status(
+            rocsolver_dgeqrf_64((rocblas_handle)handle, m, n, nullptr, lda, nullptr));
+    }
+    else if(dataTypeA == HIP_C_32F && dataTypeTau == HIP_C_32F && computeType == HIP_C_32F)
+    {
+        status = hipsolver::rocblas2hip_status(
+            rocsolver_cgeqrf_64((rocblas_handle)handle, m, n, nullptr, lda, nullptr));
+    }
+    else if(dataTypeA == HIP_C_64F && dataTypeTau == HIP_C_64F && computeType == HIP_C_64F)
+    {
+        status = hipsolver::rocblas2hip_status(
+            rocsolver_zgeqrf_64((rocblas_handle)handle, m, n, nullptr, lda, nullptr));
+    }
+    else
+    {
+        // Unsupported type combination: end the query before bailing out so the handle
+        // is not left in size-query mode, and keep the reported size at 0.
+        rocblas_stop_device_memory_size_query((rocblas_handle)handle, lworkOnDevice);
+        *lworkOnDevice = 0;
+        return HIPSOLVER_STATUS_INVALID_ENUM;
+    }
+    rocblas_stop_device_memory_size_query((rocblas_handle)handle, lworkOnDevice);
+
+    return status;
+}
+catch(...)
+{
+    return hipsolver::exception2hip_status();
+}
+
+// Runs rocSOLVER's 64-bit geqrf for the type selected by dataTypeA/dataTypeTau/computeType
+// (all three must match one supported type; otherwise INVALID_ENUM is returned).
+// If workOnDevice is non-null and lworkOnDevice is non-zero, that buffer is installed as
+// the handle's workspace; otherwise the required size is queried and handed to
+// hipsolverManageWorkspace. workOnHost is not read here. devInfo goes to hipsolverZeroInfo.
+hipsolverStatus_t hipsolverDnXgeqrf(hipsolverDnHandle_t handle,
+                                    hipsolverDnParams_t params,
+                                    int64_t m,
+                                    int64_t n,
+                                    hipDataType dataTypeA,
+                                    void* A,
+                                    int64_t lda,
+                                    hipDataType dataTypeTau,
+                                    void* tau,
+                                    hipDataType computeType,
+                                    void* workOnDevice,
+                                    size_t lworkOnDevice,
+                                    void* workOnHost,
+                                    size_t lworkOnHost,
+                                    int* devInfo)
+try
+{
+    if(!handle)
+        return HIPSOLVER_STATUS_NOT_INITIALIZED;
+    if(!params)
+        return HIPSOLVER_STATUS_INVALID_VALUE;
+
+    // Use the caller's buffer when one is supplied; otherwise size and manage one here.
+    if(workOnDevice && lworkOnDevice)
+        CHECK_ROCBLAS_ERROR(
+            rocblas_set_workspace((rocblas_handle)handle, workOnDevice, lworkOnDevice));
+    else
+    {
+        CHECK_HIPSOLVER_ERROR(hipsolverDnXgeqrf_bufferSize((rocblas_handle)handle,
+                                                           params,
+                                                           m,
+                                                           n,
+                                                           dataTypeA,
+                                                           A,
+                                                           lda,
+                                                           dataTypeTau,
+                                                           tau,
+                                                           computeType,
+                                                           &lworkOnDevice,
+                                                           &lworkOnHost));
+        CHECK_ROCBLAS_ERROR(hipsolverManageWorkspace((rocblas_handle)handle, lworkOnDevice));
+    }
+
+    CHECK_ROCBLAS_ERROR(hipsolverZeroInfo((rocblas_handle)handle, devInfo, 1));
+
+    if(dataTypeA == HIP_R_32F && dataTypeTau == HIP_R_32F && computeType == HIP_R_32F)
+    {
+        return hipsolver::rocblas2hip_status(
+            rocsolver_sgeqrf_64((rocblas_handle)handle, m, n, (float*)A, lda, (float*)tau));
+    }
+    else if(dataTypeA == HIP_R_64F && dataTypeTau == HIP_R_64F && computeType == HIP_R_64F)
+    {
+        return hipsolver::rocblas2hip_status(
+            rocsolver_dgeqrf_64((rocblas_handle)handle, m, n, (double*)A, lda, (double*)tau));
+    }
+    else if(dataTypeA == HIP_C_32F && dataTypeTau == HIP_C_32F && computeType == HIP_C_32F)
+    {
+        return hipsolver::rocblas2hip_status(rocsolver_cgeqrf_64((rocblas_handle)handle,
+                                                                 m,
+                                                                 n,
+                                                                 (rocblas_float_complex*)A,
+                                                                 lda,
+                                                                 (rocblas_float_complex*)tau));
+    }
+    else if(dataTypeA == HIP_C_64F && dataTypeTau == HIP_C_64F && computeType == HIP_C_64F)
+    {
+        return hipsolver::rocblas2hip_status(rocsolver_zgeqrf_64((rocblas_handle)handle,
+                                                                 m,
+                                                                 n,
+                                                                 (rocblas_double_complex*)A,
+                                                                 lda,
+                                                                 (rocblas_double_complex*)tau));
+    }
+    else
+        return HIPSOLVER_STATUS_INVALID_ENUM;
+}
+catch(...)
+{
+    return hipsolver::exception2hip_status();
+}
+
 /******************** GETRF ********************/
 hipsolverStatus_t hipsolverDnXgetrf_bufferSize(hipsolverDnHandle_t handle,
                                                hipsolverDnParams_t params,
diff --git a/library/src/nvidia_detail/hipsolver_dense64.cpp b/library/src/nvidia_detail/hipsolver_dense64.cpp
index 2700b6a1..b0351106 100644
--- a/library/src/nvidia_detail/hipsolver_dense64.cpp
+++ b/library/src/nvidia_detail/hipsolver_dense64.cpp
@@ -77,6 +77,87 @@ catch(...)
     return hipsolver::exception2hip_status();
 }
 
+/******************** GEQRF ********************/
+hipsolverStatus_t hipsolverDnXgeqrf_bufferSize(hipsolverDnHandle_t handle,
+                                               hipsolverDnParams_t params,
+                                               int64_t m,
+                                               int64_t n,
+                                               hipDataType dataTypeA,
+                                               const void* A,
+                                               int64_t lda,
+                                               hipDataType dataTypeTau,
+                                               const void* tau,
+                                               hipDataType computeType,
+                                               size_t* lworkOnDevice,
+                                               size_t* lworkOnHost)
+try
+{
+    if(!handle)
+        return HIPSOLVER_STATUS_NOT_INITIALIZED;
+    if(!params)
+        return HIPSOLVER_STATUS_INVALID_VALUE;
+
+    return hipsolver::cuda2hip_status(cusolverDnXgeqrf_bufferSize((cusolverDnHandle_t)handle,
+                                                                  (cusolverDnParams_t)params,
+                                                                  m,
+                                                                  n,
+                                                                  dataTypeA,
+                                                                  A,
+                                                                  lda,
+                                                                  dataTypeTau,
+                                                                  tau,
+                                                                  computeType,
+                                                                  lworkOnDevice,
+                                                                  lworkOnHost));
+}
+catch(...)
+{
+    return hipsolver::exception2hip_status();
+}
+
+hipsolverStatus_t hipsolverDnXgeqrf(hipsolverDnHandle_t handle,
+                                    hipsolverDnParams_t params,
+                                    int64_t m,
+                                    int64_t n,
+                                    hipDataType dataTypeA,
+                                    void* A,
+                                    int64_t lda,
+                                    hipDataType dataTypeTau,
+                                    void* tau,
+                                    hipDataType computeType,
+                                    void* workOnDevice,
+                                    size_t lworkOnDevice,
+                                    void* workOnHost,
+                                    size_t lworkOnHost,
+                                    int* devInfo)
+try
+{
+    if(!handle)
+        return HIPSOLVER_STATUS_NOT_INITIALIZED;
+    if(!params)
+        return HIPSOLVER_STATUS_INVALID_VALUE;
+
+    return hipsolver::cuda2hip_status(cusolverDnXgeqrf((cusolverDnHandle_t)handle,
+                                                       (cusolverDnParams_t)params,
+                                                       m,
+                                                       n,
+                                                       dataTypeA,
+                                                       A,
+                                                       lda,
+                                                       dataTypeTau,
+                                                       tau,
+                                                       computeType,
+                                                       workOnDevice,
+                                                       lworkOnDevice,
+                                                       workOnHost,
+                                                       lworkOnHost,
+                                                       devInfo));
+}
+catch(...)
+{
+    return hipsolver::exception2hip_status();
+}
+
 /******************** GETRF ********************/
 hipsolverStatus_t hipsolverDnXgetrf_bufferSize(hipsolverDnHandle_t handle,
                                                hipsolverDnParams_t params,