From 1725758cf13087fc15716c9586ff80e7b17f11de Mon Sep 17 00:00:00 2001 From: npriyadarshi Date: Mon, 6 Mar 2017 22:28:24 +0530 Subject: [PATCH] fixed NMatrix#inverse_exact method for MRI (does not apply to JRuby): issues #444, #569, #581, #582 --- ext/nmatrix/math.cpp | 28 ++++++++++----------- lib/nmatrix/math.rb | 58 +++++++++++++++++++++++++++++++++++++++++++- spec/math_spec.rb | 17 +++++++++++++ 3 files changed, 88 insertions(+), 15 deletions(-) diff --git a/ext/nmatrix/math.cpp b/ext/nmatrix/math.cpp index a69c8f52..995ae7ee 100644 --- a/ext/nmatrix/math.cpp +++ b/ext/nmatrix/math.cpp @@ -188,7 +188,7 @@ extern "C" { // Math Functions // //////////////////// -namespace nm { +namespace nm { namespace math { /* @@ -335,18 +335,18 @@ namespace nm { for (int row = k + 1; row < M; ++row) { typename MagnitudeDType::type big; big = magnitude( matrix[M*row + k] ); // element below the temp pivot - + if ( big > akk ) { interchange = row; - akk = big; + akk = big; } - } + } if (interchange != k) { // check if rows need flipping DType temp; for (int col = 0; col < M; ++col) { - NM_SWAP(matrix[interchange*M + col], matrix[k*M + col], temp); + NM_SWAP(matrix[interchange*M + col], matrix[k*M + col], temp); } } @@ -360,7 +360,7 @@ namespace nm { DType pivot = matrix[k * (M + 1)]; matrix[k * (M + 1)] = (DType)(1); // set diagonal as 1 for in-place inversion - for (int col = 0; col < M; ++col) { + for (int col = 0; col < M; ++col) { // divide each element in the kth row with the pivot matrix[k*M + col] = matrix[k*M + col] / pivot; } @@ -369,7 +369,7 @@ namespace nm { if (kk == k) continue; DType dum = matrix[k + M*kk]; - matrix[k + M*kk] = (DType)(0); // prepare for inplace inversion + matrix[k + M*kk] = (DType)(0); // prepare for inplace inversion for (int col = 0; col < M; ++col) { matrix[M*kk + col] = matrix[M*kk + col] - matrix[M*k + col] * dum; } @@ -384,7 +384,7 @@ namespace nm { for (int row = 0; row < M; ++row) { NM_SWAP(matrix[row * M + row_index[k]], matrix[row * M + col_index[k]], - temp); + temp); } } } @@ -410,14 +410,14 @@ namespace nm { DType sum_of_squares, *p_row, *psubdiag, *p_a, scale, innerproduct; int i, k, col; - // For each column use a Householder transformation to zero all entries + // For each column use a Householder transformation to zero all entries // below the subdiagonal. - for (psubdiag = a + nrows, col = 0; col < nrows - 2; psubdiag += nrows + 1, + for (psubdiag = a + nrows, col = 0; col < nrows - 2; psubdiag += nrows + 1, col++) { // Calculate the signed square root of the sum of squares of the // elements below the diagonal. - for (p_a = psubdiag, sum_of_squares = 0.0, i = col + 1; i < nrows; + for (p_a = psubdiag, sum_of_squares = 0.0, i = col + 1; i < nrows; p_a += nrows, i++) { sum_of_squares += *p_a * *p_a; } @@ -447,7 +447,7 @@ namespace nm { *p_a -= u[k] * innerproduct; } } - + // Postmultiply QA by Q for (p_row = a, i = 0; i < nrows; p_row += nrows, i++) { for (innerproduct = 0.0, k = col + 1; k < nrows; k++) { @@ -485,7 +485,7 @@ namespace nm { B[0] = A[lda+1] / det; B[1] = -A[1] / det; B[ldb] = -A[lda] / det; - B[ldb+1] = -A[0] / det; + B[ldb+1] = A[0] / det; } else if (M == 3) { // Calculate the exact determinant. @@ -1313,7 +1313,7 @@ void nm_math_hessenberg(VALUE a) { NULL, NULL, // does not support Complex NULL // no support for Ruby Object }; - + ttable[NM_DTYPE(a)](NM_SHAPE0(a), NM_STORAGE_DENSE(a)->elements); } /* diff --git a/lib/nmatrix/math.rb b/lib/nmatrix/math.rb index eca6e1fc..cf185945 100644 --- a/lib/nmatrix/math.rb +++ b/lib/nmatrix/math.rb @@ -112,6 +112,62 @@ def invert end alias :inverse :invert + # + # call-seq: + # invert_exact! -> NMatrix + # + # Calulates inverse_exact of a matrix of size 2 or 3. + # Only works on dense matrices. + # + # * *Raises* : + # - +StorageTypeError+ -> only implemented on dense matrices. + # - +ShapeError+ -> matrix must be square. + # - +DataTypeError+ -> cannot invert an integer matrix in-place. + # - +NotImplementedError+ -> cannot find exact inverse of matrix with size greater than 3 + # + def invert_exact! + raise(StorageTypeError, "invert only works on dense matrices currently") unless self.dense? + raise(ShapeError, "Cannot invert non-square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1] + raise(DataTypeError, "Cannot invert an integer matrix in-place") if self.integer_dtype? + #No internal implementation of getri, so use this other function + n = self.shape[0] + if n>3 + raise(NotImplementedError, "Cannot find exact inverse of matrix of size greater than 3") + else + clond=self.clone + __inverse_exact__(clond, n, n) + end + end + + # + # call-seq: + # invert_exact -> NMatrix + # + # Make a copy of the matrix, then invert using exact_inverse + # + # * *Returns* : + # - A dense NMatrix. Will be the same type as the input NMatrix, + # except if the input is an integral dtype, in which case it will be a + # :float64 NMatrix. + # + # * *Raises* : + # - +StorageTypeError+ -> only implemented on dense matrices. + # - +ShapeError+ -> matrix must be square. + # - +NotImplementedError+ -> cannot find exact inverse of matrix with size greater than 3 + # + def invert_exact + #write this in terms of invert_exact! so plugins will only have to overwrite + #invert_exact! and not invert_exact + if self.integer_dtype? + cloned = self.cast(dtype: :float64) + cloned.invert_exact! + else + cloned = self.clone + cloned.invert_exact! + end + end + alias :inverse_exact :invert_exact + # # call-seq: # adjugate! -> NMatrix @@ -1393,4 +1449,4 @@ def dtype_for_floor_or_ceil self.__dense_map__ { |l| l.send(op,rhs) } end end -end \ No newline at end of file +end diff --git a/spec/math_spec.rb b/spec/math_spec.rb index 5003e526..0bc59620 100644 --- a/spec/math_spec.rb +++ b/spec/math_spec.rb @@ -488,6 +488,23 @@ expect(a.invert).to be_within(err).of(b) end + + it "should correctly find exact inverse" do + pending("not yet implemented for NMatrix-JRuby") if jruby? + a = NMatrix.new(:dense, 3, [1,2,3,0,1,4,5,6,0], dtype) + b = NMatrix.new(:dense, 3, [-24,18,5,20,-15,-4,-5,4,1], dtype) + + expect(a.invert_exact).to be_within(err).of(b) + end + + it "should correctly find exact inverse" do + pending("not yet implemented for NMatrix-JRuby") if jruby? + a = NMatrix.new(:dense, 2, [1,3,3,8,], dtype) + b = NMatrix.new(:dense, 2, [-8,3,3,-1], dtype) + + expect(a.invert_exact).to be_within(err).of(b) + end + end end