Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
decade-afk committed Jan 26, 2025
1 parent f69f370 commit 8c9771c
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 11 deletions.
19 changes: 9 additions & 10 deletions paddle/phi/kernels/cpu/lu_solve_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,15 @@ void LuSolveKernel(const Context& dev_ctx,
auto out_data_item = &out_data[i * n_int * n_int];
auto lu_data_item = &lu_data[i * n_int * n_int];
auto pivots_data_item = &pivots_data[i * n_int];
phi::dynload::sgetrs_(
trans_char,
n_int, // Order of matrix A
nrhs_int, // Number of right hand sides
lu_data_item, // LU factorization
lda, // Leading dimension of A
pivots_data, // Pivot indices
out_data_item, // RHS/solution matrix
ldb, // Leading dimension of B
info); // Status indicator
phi::funcs::lapackLuSolve<T>(trans_char,
n_int,
nrhs_int,
lu_data_item,
lda,
pivots_data_item,
out_data_item,
ldb,
info);
}
} else if (std::is_same<T, double>::value) {
phi::dynload::dgetrs_(
Expand Down
27 changes: 27 additions & 0 deletions paddle/phi/kernels/funcs/lapack/lapack_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,33 @@ void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
dynload::sgetrf_(&m, &n, a, &lda, ipiv, info);
}

// lu_solve
template <>
void lapackLuSolve<float>(char trans,
int n,
int nrhs,
float *a,
int lda,
int *ipiv,
float *b,
int ldb,
int *info) {
dynload::sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

template <>
void lapackLuSolve<double>(char trans,
int n,
int nrhs,
double *a,
int lda,
int *ipiv,
double *b,
int ldb,
int *info) {
dynload::sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

// eigh
template <>
void lapackEigh<float>(char jobz,
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/kernels/funcs/lapack/lapack_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ namespace funcs {
template <typename T>
void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info);

// Lu_solve
template <typename T>
void lapackLuSolve(char trans,
int n,
int nrhs,
T *a,
int lda,
int *ipiv,
T *b,
int ldb,
int *info);

// Eigh
template <typename T, typename ValueType = T>
void lapackEigh(char jobz,
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3156,7 +3156,7 @@
func : lu_solve
data_type : x
backward : lu_solve_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface
# interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : lu_unpack
args : (Tensor x, Tensor y, bool unpack_ludata = true, bool unpack_pivots = true)
Expand Down

0 comments on commit 8c9771c

Please sign in to comment.