Skip to content

Commit

Permalink
update
Browse files (browse the repository at this point in the history)
  • Loading branch information
decade-afk committed Jan 26, 2025
1 parent de31d9b commit c03540e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
2 changes: 2 additions & 0 deletions paddle/phi/backends/dynload/cusolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);

#if CUDA_VERSION >= 9020
#define CUSOLVER_ROUTINE_EACH_R1(__macro) \
__macro(cusolverDnSgetrs); \
__macro(cusolverDnDgetrs); \
__macro(cusolverDnSpotrfBatched); \
__macro(cusolverDnDpotrfBatched); \
__macro(cusolverDnSpotrsBatched); \
Expand Down
11 changes: 9 additions & 2 deletions paddle/phi/kernels/cpu/lu_solve_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,16 @@ void LuSolveKernel(const Context& dev_ctx,
const auto& x_dims = x.dims();
const int64_t nrhs = x_dims[x_dims.size() - 1]; // Number of columns

// Get number of right-hand sides from x
const auto& x_dims = x.dims();
const int64_t nrhs = x_dims[x_dims.size() - 1]; // Number of columns

// Allocate output tensor
dev_ctx.template Alloc<T>(out);

// Copy RHS data to output (will be overwritten with solution)
std::copy_n(x.data<T>(), x.numel(), out->data<T>());
// std::copy_n(x.data<T>(), x.numel(), out->data<T>());
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);

// Prepare LAPACK parameters
char trans_char = (trans == "N") ? 'N' : ((trans == "T") ? 'T' : 'C');
Expand All @@ -64,9 +69,11 @@ void LuSolveKernel(const Context& dev_ctx,
auto outdims = out->dims();
auto outrank = outdims.size();
auto batchsize = product(common::slice_ddim(outdims, 0, outrank - 2));

auto out_data = out->data<T>();
auto lu_data = lu.data<T>();
auto pivots_data = pivots.data<int>();

for (int i = 0; i < batchsize; i++) {
auto out_data_item = &out_data[i * n_int * n_int];
auto* lu_data_item = &lu_data[i * n_int * n_int];
Expand All @@ -79,7 +86,7 @@ void LuSolveKernel(const Context& dev_ctx,
pivots_data_item,
out_data_item,
ldb,
info);
*info);

PADDLE_ENFORCE_EQ(
info,
Expand Down

0 comments on commit c03540e

Please sign in to comment.