Skip to content

Commit

Permalink
Added 64-bit geqrf (#310)
Browse files Browse the repository at this point in the history
* Added 64-bit geqrf APIs
* Testing for 64-bit geqrf APIs
* Updated documentation
* Fixed failing cuda build
  • Loading branch information
tfalders authored Aug 14, 2024
1 parent 1080e91 commit 7449ed7
Show file tree
Hide file tree
Showing 9 changed files with 1,062 additions and 232 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ Full documentation for hipSOLVER is available at the [hipSOLVER Documentation](h
- Added functions
- auxiliary
- hipsolverSetDeterministicMode, hipsolverGetDeterministicMode
- Added compatibility-only functions
- geqrf
- hipsolverDnXgeqrf_bufferSize
- hipsolverDnXgeqrf

### Optimized
### Changed
Expand Down
74 changes: 69 additions & 5 deletions clients/gtest/geqrf_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Arguments geqrf_setup_arguments(geqrf_tuple tup)
return arg;
}

template <testAPI_t API>
template <testAPI_t API, typename I, typename SIZE>
class GEQRF_BASE : public ::TestWithParam<geqrf_tuple>
{
protected:
Expand All @@ -98,18 +98,26 @@ class GEQRF_BASE : public ::TestWithParam<geqrf_tuple>
Arguments arg = geqrf_setup_arguments(GetParam());

if(arg.peek<rocblas_int>("m") == -1 && arg.peek<rocblas_int>("n") == -1)
testing_geqrf_bad_arg<API, BATCHED, STRIDED, T>();
testing_geqrf_bad_arg<API, BATCHED, STRIDED, T, I, SIZE>();

arg.batch_count = 1;
testing_geqrf<API, BATCHED, STRIDED, T>(arg);
testing_geqrf<API, BATCHED, STRIDED, T, I, SIZE>(arg);
}
};

class GEQRF : public GEQRF_BASE<API_NORMAL>
class GEQRF : public GEQRF_BASE<API_NORMAL, int, int>
{
};

class GEQRF_FORTRAN : public GEQRF_BASE<API_FORTRAN>
class GEQRF_FORTRAN : public GEQRF_BASE<API_FORTRAN, int, int>
{
};

class GEQRF_COMPAT : public GEQRF_BASE<API_COMPAT, int64_t, size_t>
{
};

class GEQRF_COMPAT_64 : public GEQRF_BASE<API_COMPAT, int64_t, size_t>
{
};

Expand Down Expand Up @@ -155,6 +163,46 @@ TEST_P(GEQRF_FORTRAN, __double_complex)
run_tests<false, false, rocblas_double_complex>();
}

TEST_P(GEQRF_COMPAT, __float)
{
run_tests<false, false, float>();
}

TEST_P(GEQRF_COMPAT, __double)
{
run_tests<false, false, double>();
}

TEST_P(GEQRF_COMPAT, __float_complex)
{
run_tests<false, false, rocblas_float_complex>();
}

TEST_P(GEQRF_COMPAT, __double_complex)
{
run_tests<false, false, rocblas_double_complex>();
}

TEST_P(GEQRF_COMPAT_64, __float)
{
run_tests<false, false, float>();
}

TEST_P(GEQRF_COMPAT_64, __double)
{
run_tests<false, false, double>();
}

TEST_P(GEQRF_COMPAT_64, __float_complex)
{
run_tests<false, false, rocblas_float_complex>();
}

TEST_P(GEQRF_COMPAT_64, __double_complex)
{
run_tests<false, false, rocblas_double_complex>();
}

// INSTANTIATE_TEST_SUITE_P(daily_lapack,
// GEQRF,
// Combine(ValuesIn(large_matrix_size_range), ValuesIn(large_n_size_range)));
Expand All @@ -170,3 +218,19 @@ INSTANTIATE_TEST_SUITE_P(checkin_lapack,
INSTANTIATE_TEST_SUITE_P(checkin_lapack,
GEQRF_FORTRAN,
Combine(ValuesIn(matrix_size_range), ValuesIn(n_size_range)));

// INSTANTIATE_TEST_SUITE_P(daily_lapack,
// GEQRF_COMPAT,
// Combine(ValuesIn(large_matrix_size_range), ValuesIn(large_n_size_range)));

INSTANTIATE_TEST_SUITE_P(checkin_lapack,
GEQRF_COMPAT,
Combine(ValuesIn(matrix_size_range), ValuesIn(n_size_range)));

// INSTANTIATE_TEST_SUITE_P(daily_lapack,
// GEQRF_COMPAT_64,
// Combine(ValuesIn(large_matrix_size_range), ValuesIn(large_n_size_range)));

INSTANTIATE_TEST_SUITE_P(checkin_lapack,
GEQRF_COMPAT_64,
Combine(ValuesIn(matrix_size_range), ValuesIn(n_size_range)));
Loading

0 comments on commit 7449ed7

Please sign in to comment.