diff --git a/paddle/phi/kernels/cpu/lu_solve_kernel.cc b/paddle/phi/kernels/cpu/lu_solve_kernel.cc index 030e2a5d428261..4ea604089ee19e 100644 --- a/paddle/phi/kernels/cpu/lu_solve_kernel.cc +++ b/paddle/phi/kernels/cpu/lu_solve_kernel.cc @@ -73,16 +73,15 @@ void LuSolveKernel(const Context& dev_ctx, auto out_data_item = &out_data[i * n_int * n_int]; auto lu_data_item = &lu_data[i * n_int * n_int]; auto pivots_data_item = &pivots_data[i * n_int]; - phi::dynload::sgetrs_( - trans_char, - n_int, // Order of matrix A - nrhs_int, // Number of right hand sides - lu_data_item, // LU factorization - lda, // Leading dimension of A - pivots_data, // Pivot indices - out_data_item, // RHS/solution matrix - ldb, // Leading dimension of B - info); // Status indicator + phi::funcs::lapackLuSolve(trans_char, + n_int, + nrhs_int, + lu_data_item, + lda, + pivots_data_item, + out_data_item, + ldb, + info); } } else if (std::is_same::value) { phi::dynload::dgetrs_( diff --git a/paddle/phi/kernels/funcs/lapack/lapack_function.cc b/paddle/phi/kernels/funcs/lapack/lapack_function.cc index ebfd53291c36fa..8759accc565a26 100644 --- a/paddle/phi/kernels/funcs/lapack/lapack_function.cc +++ b/paddle/phi/kernels/funcs/lapack/lapack_function.cc @@ -30,6 +30,33 @@ void lapackLu(int m, int n, float *a, int lda, int *ipiv, int *info) { dynload::sgetrf_(&m, &n, a, &lda, ipiv, info); } +// lu_solve +template <> +void lapackLuSolve(char trans, + int n, + int nrhs, + float *a, + int lda, + int *ipiv, + float *b, + int ldb, + int *info) { + dynload::sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info); +} + +template <> +void lapackLuSolve(char trans, + int n, + int nrhs, + double *a, + int lda, + int *ipiv, + double *b, + int ldb, + int *info) { + dynload::sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info); +} + // eigh template <> void lapackEigh(char jobz, diff --git a/paddle/phi/kernels/funcs/lapack/lapack_function.h b/paddle/phi/kernels/funcs/lapack/lapack_function.h index d251095bb79f06..e54792e1c5bb27 100644 --- a/paddle/phi/kernels/funcs/lapack/lapack_function.h +++ b/paddle/phi/kernels/funcs/lapack/lapack_function.h @@ -21,6 +21,18 @@ namespace funcs { template void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info); +// Lu_solve +template +void lapackLuSolve(char trans, + int n, + int nrhs, + T *a, + int lda, + int *ipiv, + T *b, + int ldb, + int *info); + // Eigh template void lapackEigh(char jobz, diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 527b01e5bc7b79..74b2082b9394cf 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3156,7 +3156,7 @@ func : lu_solve data_type : x backward : lu_solve_grad - interfaces : paddle::dialect::InferSymbolicShapeInterface + # interfaces : paddle::dialect::InferSymbolicShapeInterface - op : lu_unpack args : (Tensor x, Tensor y, bool unpack_ludata = true, bool unpack_pivots = true)