Skip to content

Commit

Permalink
fixed NMatrix#inverse_exact method for MRI (does not apply to JRuby):…
Browse files Browse the repository at this point in the history
… issues #444, #569, #581, #582
  • Loading branch information
npriyadarshi authored and translunar committed Mar 11, 2017
1 parent cfadf50 commit 1725758
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 15 deletions.
28 changes: 14 additions & 14 deletions ext/nmatrix/math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ extern "C" {
// Math Functions //
////////////////////

namespace nm {
namespace nm {
namespace math {

/*
Expand Down Expand Up @@ -335,18 +335,18 @@ namespace nm {
for (int row = k + 1; row < M; ++row) {
typename MagnitudeDType<DType>::type big;
big = magnitude( matrix[M*row + k] ); // element below the temp pivot

if ( big > akk ) {
interchange = row;
akk = big;
akk = big;
}
}
}

if (interchange != k) { // check if rows need flipping
DType temp;

for (int col = 0; col < M; ++col) {
NM_SWAP(matrix[interchange*M + col], matrix[k*M + col], temp);
NM_SWAP(matrix[interchange*M + col], matrix[k*M + col], temp);
}
}

Expand All @@ -360,7 +360,7 @@ namespace nm {
DType pivot = matrix[k * (M + 1)];
matrix[k * (M + 1)] = (DType)(1); // set diagonal as 1 for in-place inversion

for (int col = 0; col < M; ++col) {
for (int col = 0; col < M; ++col) {
// divide each element in the kth row with the pivot
matrix[k*M + col] = matrix[k*M + col] / pivot;
}
Expand All @@ -369,7 +369,7 @@ namespace nm {
if (kk == k) continue;

DType dum = matrix[k + M*kk];
matrix[k + M*kk] = (DType)(0); // prepare for inplace inversion
matrix[k + M*kk] = (DType)(0); // prepare for inplace inversion
for (int col = 0; col < M; ++col) {
matrix[M*kk + col] = matrix[M*kk + col] - matrix[M*k + col] * dum;
}
Expand All @@ -384,7 +384,7 @@ namespace nm {

for (int row = 0; row < M; ++row) {
NM_SWAP(matrix[row * M + row_index[k]], matrix[row * M + col_index[k]],
temp);
temp);
}
}
}
Expand All @@ -410,14 +410,14 @@ namespace nm {
DType sum_of_squares, *p_row, *psubdiag, *p_a, scale, innerproduct;
int i, k, col;

// For each column use a Householder transformation to zero all entries
// For each column use a Householder transformation to zero all entries
// below the subdiagonal.
for (psubdiag = a + nrows, col = 0; col < nrows - 2; psubdiag += nrows + 1,
for (psubdiag = a + nrows, col = 0; col < nrows - 2; psubdiag += nrows + 1,
col++) {
// Calculate the signed square root of the sum of squares of the
// elements below the diagonal.

for (p_a = psubdiag, sum_of_squares = 0.0, i = col + 1; i < nrows;
for (p_a = psubdiag, sum_of_squares = 0.0, i = col + 1; i < nrows;
p_a += nrows, i++) {
sum_of_squares += *p_a * *p_a;
}
Expand Down Expand Up @@ -447,7 +447,7 @@ namespace nm {
*p_a -= u[k] * innerproduct;
}
}

// Postmultiply QA by Q
for (p_row = a, i = 0; i < nrows; p_row += nrows, i++) {
for (innerproduct = 0.0, k = col + 1; k < nrows; k++) {
Expand Down Expand Up @@ -485,7 +485,7 @@ namespace nm {
B[0] = A[lda+1] / det;
B[1] = -A[1] / det;
B[ldb] = -A[lda] / det;
B[ldb+1] = -A[0] / det;
B[ldb+1] = A[0] / det;

} else if (M == 3) {
// Calculate the exact determinant.
Expand Down Expand Up @@ -1313,7 +1313,7 @@ void nm_math_hessenberg(VALUE a) {
NULL, NULL, // does not support Complex
NULL // no support for Ruby Object
};

ttable[NM_DTYPE(a)](NM_SHAPE0(a), NM_STORAGE_DENSE(a)->elements);
}
/*
Expand Down
58 changes: 57 additions & 1 deletion lib/nmatrix/math.rb
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,62 @@ def invert
end
alias :inverse :invert

#
# call-seq:
# invert_exact! -> NMatrix
#
# Calulates inverse_exact of a matrix of size 2 or 3.
# Only works on dense matrices.
#
# * *Raises* :
# - +StorageTypeError+ -> only implemented on dense matrices.
# - +ShapeError+ -> matrix must be square.
# - +DataTypeError+ -> cannot invert an integer matrix in-place.
# - +NotImplementedError+ -> cannot find exact inverse of matrix with size greater than 3
#
def invert_exact!
raise(StorageTypeError, "invert only works on dense matrices currently") unless self.dense?
raise(ShapeError, "Cannot invert non-square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1]
raise(DataTypeError, "Cannot invert an integer matrix in-place") if self.integer_dtype?
#No internal implementation of getri, so use this other function
n = self.shape[0]
if n>3
raise(NotImplementedError, "Cannot find exact inverse of matrix of size greater than 3")
else
clond=self.clone
__inverse_exact__(clond, n, n)
end
end

#
# call-seq:
# invert_exact -> NMatrix
#
# Make a copy of the matrix, then invert using exact_inverse
#
# * *Returns* :
# - A dense NMatrix. Will be the same type as the input NMatrix,
# except if the input is an integral dtype, in which case it will be a
# :float64 NMatrix.
#
# * *Raises* :
# - +StorageTypeError+ -> only implemented on dense matrices.
# - +ShapeError+ -> matrix must be square.
# - +NotImplementedError+ -> cannot find exact inverse of matrix with size greater than 3
#
def invert_exact
#write this in terms of invert_exact! so plugins will only have to overwrite
#invert_exact! and not invert_exact
if self.integer_dtype?
cloned = self.cast(dtype: :float64)
cloned.invert_exact!
else
cloned = self.clone
cloned.invert_exact!
end
end
alias :inverse_exact :invert_exact

#
# call-seq:
# adjugate! -> NMatrix
Expand Down Expand Up @@ -1393,4 +1449,4 @@ def dtype_for_floor_or_ceil
self.__dense_map__ { |l| l.send(op,rhs) }
end
end
end
end
17 changes: 17 additions & 0 deletions spec/math_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,23 @@

expect(a.invert).to be_within(err).of(b)
end

it "should correctly find exact inverse" do
pending("not yet implemented for NMatrix-JRuby") if jruby?
a = NMatrix.new(:dense, 3, [1,2,3,0,1,4,5,6,0], dtype)
b = NMatrix.new(:dense, 3, [-24,18,5,20,-15,-4,-5,4,1], dtype)

expect(a.invert_exact).to be_within(err).of(b)
end

it "should correctly find exact inverse" do
pending("not yet implemented for NMatrix-JRuby") if jruby?
a = NMatrix.new(:dense, 2, [1,3,3,8,], dtype)
b = NMatrix.new(:dense, 2, [-8,3,3,-1], dtype)

expect(a.invert_exact).to be_within(err).of(b)
end

end
end

Expand Down

0 comments on commit 1725758

Please sign in to comment.