diff --git a/batched/dense/impl/KokkosBatched_Getrf_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Getrf_Serial_Internal.hpp index b765aa4af8..bfbe65f9c8 100644 --- a/batched/dense/impl/KokkosBatched_Getrf_Serial_Internal.hpp +++ b/batched/dense/impl/KokkosBatched_Getrf_Serial_Internal.hpp @@ -126,9 +126,8 @@ KOKKOS_INLINE_FUNCTION int SerialGetrfInternalHost::invo // Use recursive code auto n1 = Kokkos::min(m, n) / 2; - // [ A00 ] - // Factor [ --- ] - // [ A10 ] + // Factor A0 = [[A00], + // [A10]] // split A into two submatrices A = [A0, A1] auto A0 = Kokkos::subview(A, Kokkos::ALL, Kokkos::pair(0, n1)); @@ -138,9 +137,8 @@ KOKKOS_INLINE_FUNCTION int SerialGetrfInternalHost::invo if (info == 0 && iinfo > 0) info = iinfo; - // [ A01 ] - // Apply interchanges to [ --- ] - // [ A11 ] + // Apply interchanges to A1 = [[A01], + // [A11]] [[maybe_unused]] auto info_laswp = KokkosBatched::SerialLaswp::invoke(ipiv0, A1); @@ -202,7 +200,7 @@ KOKKOS_INLINE_FUNCTION int SerialGetrfInternalDevice::in if (m <= 0 || n <= 0) return 0; while (!stack.isEmpty()) { - // First of make a subview based on the current state + // Firstly, make a subview based on the current state int current[7]; stack.pop(current); @@ -280,13 +278,11 @@ KOKKOS_INLINE_FUNCTION int SerialGetrfInternalDevice::in } else if (state == 1) { // after first recursive call - // [ A00 ] - // Factor [ --- ] - // [ A10 ] + // Factor A0 = [[A00], + // [A10]] - // [ A01 ] - // Apply interchanges to [ --- ] - // [ A11 ] + // Apply interchanges to A1 = [[A01], + // [A11]] KokkosBatched::SerialLaswp::invoke(ipiv0, A1); // Solve A00 * X = A01 diff --git a/batched/dense/src/KokkosBatched_Getrf.hpp b/batched/dense/src/KokkosBatched_Getrf.hpp index 058364dfbd..1b1bcac903 100644 --- a/batched/dense/src/KokkosBatched_Getrf.hpp +++ b/batched/dense/src/KokkosBatched_Getrf.hpp @@ -30,6 +30,20 @@ namespace KokkosBatched { /// where P is a permutation matrix, L is lower triangular with unit /// diagonal elements (lower trapezoidal if m > n), and U is upper /// triangular (upper trapezoidal if m < n). +/// +/// This is the recusive version of the algorithm. It divides the matrix +/// into four submatrices: +/// A = [[A00, A01], +/// [A10, A11]] +/// where A00 is a square matrix of size n0, A11 is a matrix of size n1 by n1 +/// with n0 = min(m, n) / 2 and n1 = n - n0. +/// +/// This function calls itself to factorize A0 = [[A00], +// [A10]] +/// do the swaps on A1 = [[A01], +/// [A11]] +/// solve A01, update A11, then calls itself to factorize A11 +/// and do the swaps on A10. /// \tparam AViewType: Input type for the matrix, needs to be a 2D view /// \tparam PivViewType: Input type for the pivot indices, needs to be a 1D view ///