Skip to content

Commit

Permalink
[LinearAlgebra] Use class template argument deduction with MatrixExpr (
Browse files Browse the repository at this point in the history
  • Loading branch information
alxbilger authored Dec 14, 2023
1 parent 4a3db69 commit 6ed9f8f
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1311,7 +1311,7 @@ class CompressedRowSparseMatrixMechanical final // final is used to allow the co
template<class TBlock2, class TPolicy2>
void operator-=(const CompressedRowSparseMatrixMechanical<TBlock2, TPolicy2>& m)
{
equal(MatrixExpr< MatrixNegative< CompressedRowSparseMatrixMechanical<TBlock2, TPolicy2> > >(MatrixNegative< CompressedRowSparseMatrixMechanical<TBlock2, TPolicy2> >(m)), true);
equal(MatrixExpr { MatrixNegative< CompressedRowSparseMatrixMechanical<TBlock2, TPolicy2> >(m) }, true);
}

template<class Expr2>
Expand All @@ -1329,23 +1329,23 @@ class CompressedRowSparseMatrixMechanical final // final is used to allow the co
template<class Expr2>
void operator-=(const MatrixExpr< Expr2 >& m)
{
addEqual(MatrixExpr< MatrixNegative< Expr2 > >(MatrixNegative< Expr2 >(m)));
addEqual(MatrixExpr{ MatrixNegative< Expr2 >(m) } );
}

MatrixExpr< MatrixTranspose< Matrix > > t() const
{
return MatrixExpr< MatrixTranspose< Matrix > >(MatrixTranspose< Matrix >(*this));
return MatrixExpr{ MatrixTranspose< Matrix >{*this} };
}


MatrixExpr< MatrixNegative< Matrix > > operator-() const
{
return MatrixExpr< MatrixNegative< Matrix > >(MatrixNegative< Matrix >(*this));
return MatrixExpr{ MatrixNegative< Matrix >(*this) };
}

MatrixExpr< MatrixScale< Matrix, double > > operator*(const double& r) const
{
return MatrixExpr< MatrixScale< Matrix, double > >(MatrixScale< Matrix, double >(*this, r));
return MatrixExpr{ MatrixScale< Matrix, double >(*this, r) };
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,27 +283,27 @@ class DiagonalMatrix : public linearalgebra::BaseMatrix
template<class Expr2>
void operator-=(const MatrixExpr< Expr2 >& m)
{
addEqual(MatrixExpr< MatrixNegative< Expr2 > >(MatrixNegative< Expr2 >(m)));
addEqual(MatrixExpr { MatrixNegative< Expr2 >(m) } );
}

MatrixExpr< MatrixTranspose< DiagonalMatrix<T> > > t() const
{
return MatrixExpr< MatrixTranspose< DiagonalMatrix<T> > >(MatrixTranspose< DiagonalMatrix<T> >(*this));
return MatrixExpr { MatrixTranspose< DiagonalMatrix<T> >(*this) };
}

MatrixExpr< MatrixInverse< DiagonalMatrix<T> > > i() const
{
return MatrixExpr< MatrixInverse< DiagonalMatrix<T> > >(MatrixInverse< DiagonalMatrix<T> >(*this));
return MatrixExpr { MatrixInverse< DiagonalMatrix<T> >(*this) };
}

MatrixExpr< MatrixNegative< DiagonalMatrix<T> > > operator-() const
{
return MatrixExpr< MatrixNegative< DiagonalMatrix<T> > >(MatrixNegative< DiagonalMatrix<T> >(*this));
return MatrixExpr { MatrixNegative< DiagonalMatrix<T> >(*this) };
}

MatrixExpr< MatrixScale< DiagonalMatrix<T>, double > > operator*(const double& r) const
{
return MatrixExpr< MatrixScale< DiagonalMatrix<T>, double > >(MatrixScale< DiagonalMatrix<T>, double >(*this, r));
return MatrixExpr { MatrixScale< DiagonalMatrix<T>, double >(*this, r) };
}

friend std::ostream& operator << (std::ostream& out, const DiagonalMatrix<T>& v )
Expand Down
33 changes: 19 additions & 14 deletions Sofa/framework/LinearAlgebra/src/sofa/linearalgebra/MatrixExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,60 +45,62 @@ class MatrixNegative;
template<class M1, class R2>
class MatrixScale;

/// Data structure representing an operation on matrices. Used in the context of
/// the expression templates pattern.
template<class T>
class MatrixExpr : public T
{
public:
typedef T Expr;

MatrixExpr(const Expr& e) : Expr(e) {}
explicit MatrixExpr(const Expr& e) : Expr(e) {}

template<class M2>
MatrixExpr< MatrixProduct< Expr, typename M2::Expr > > operator*(const M2& m) const
{
return MatrixExpr< MatrixProduct< Expr, typename M2::Expr > >(MatrixProduct< Expr, typename M2::Expr >(*this, m));
return MatrixExpr { MatrixProduct< Expr, typename M2::Expr >(*this, m) };
}
template<class M2>
MatrixExpr< MatrixAddition< Expr, typename M2::Expr > > operator+(const M2& m) const
{
return MatrixExpr< MatrixAddition< Expr, typename M2::Expr > >(MatrixAddition< Expr, typename M2::Expr >(*this, m));
return MatrixExpr { MatrixAddition< Expr, typename M2::Expr >(*this, m) };
}
template<class M2>
MatrixExpr< MatrixSubstraction< Expr, typename M2::Expr > > operator-(const M2& m) const
{
return MatrixExpr< MatrixSubstraction< Expr, typename M2::Expr > >(MatrixSubstraction< Expr, typename M2::Expr >(*this, m));
return MatrixExpr { MatrixSubstraction< Expr, typename M2::Expr >(*this, m) };
}
MatrixExpr< MatrixNegative< Expr > > operator-() const
{
return MatrixExpr< MatrixNegative< Expr > >(MatrixNegative< Expr >(*this));
return MatrixExpr { MatrixNegative< Expr >(*this) };
}
MatrixExpr< MatrixTranspose< Expr > > t() const
{
return MatrixExpr< MatrixTranspose< Expr > >(MatrixTranspose< Expr >(*this));
return MatrixExpr { MatrixTranspose< Expr >(*this) };
}

MatrixExpr< MatrixScale< Expr, double > > operator*(double d) const
{
return MatrixExpr< MatrixScale< Expr, double > >(MatrixScale< Expr, double >(*this, d));
return MatrixExpr { MatrixScale< Expr, double >(*this, d) };
}
friend MatrixExpr< MatrixScale< Expr, double > > operator*(double d, const MatrixExpr<Expr>& m)
{
return MatrixExpr< MatrixScale< Expr, double > >(MatrixScale< Expr, double >(m, d));
return MatrixExpr { MatrixScale< Expr, double >(m, d) };
}
template<class M1>
friend MatrixExpr< MatrixProduct< typename M1::Expr, Expr > > operator*(const M1& m1, const MatrixExpr<Expr>& m2)
{
return MatrixExpr< MatrixProduct< typename M1::Expr, Expr > >(MatrixProduct< typename M1::Expr, Expr >(m1,m2));
return MatrixExpr { MatrixProduct< typename M1::Expr, Expr >(m1,m2) };
}
template<class M1>
friend MatrixExpr< MatrixAddition< typename M1::Expr, Expr > > operator+(const M1& m1, const MatrixExpr<Expr>& m2)
{
return MatrixExpr< MatrixAddition< typename M1::Expr, Expr > >(MatrixAddition< typename M1::Expr, Expr >(m1,m2));
return MatrixExpr { MatrixAddition< typename M1::Expr, Expr >(m1,m2) };
}
template<class M1>
friend MatrixExpr< MatrixSubstraction< typename M1::Expr, Expr > > operator-(const M1& m1, const MatrixExpr<Expr>& m2)
{
return MatrixExpr< MatrixSubstraction< typename M1::Expr, Expr > >(MatrixSubstraction< typename M1::Expr, Expr >(m1,m2));
return MatrixExpr { MatrixSubstraction< typename M1::Expr, Expr >(m1,m2) };
}
};

Expand Down Expand Up @@ -216,7 +218,8 @@ class MatrixNegative
typedef typename M1::matrix_type matrix_type;

const M1& m1;
MatrixNegative(const M1& m1) : m1(m1)

explicit MatrixNegative(const M1& m1) : m1(m1)
{}

bool valid() const
Expand Down Expand Up @@ -274,7 +277,8 @@ class MatrixTranspose
typedef typename M1::matrix_type matrix_type;

const M1& m1;
MatrixTranspose(const M1& m1) : m1(m1)

explicit MatrixTranspose(const M1& m1) : m1(m1)
{}

bool valid() const
Expand Down Expand Up @@ -556,7 +560,8 @@ class MatrixInverse
enum { operand = 0 };

const M1& m1;
MatrixInverse(const M1& m1) : m1(m1)

explicit MatrixInverse(const M1& m1) : m1(m1)
{}

bool valid() const
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,35 +331,35 @@ class SparseMatrix : public linearalgebra::BaseMatrix

MatrixExpr< MatrixTranspose< SparseMatrix<T> > > t() const
{
return MatrixExpr< MatrixTranspose< SparseMatrix<T> > >(MatrixTranspose< SparseMatrix<T> >(*this));
return MatrixExpr { MatrixTranspose< SparseMatrix<T> >(*this) };
}

MatrixExpr< MatrixNegative< SparseMatrix<T> > > operator-() const
{
return MatrixExpr< MatrixNegative< SparseMatrix<T> > >(MatrixNegative< SparseMatrix<T> >(*this));
return MatrixExpr { MatrixNegative< SparseMatrix<T> >(*this) };
}

template<class Real2>
MatrixExpr< MatrixProduct< SparseMatrix<T>, SparseMatrix<Real2> > > operator*(const SparseMatrix<Real2>& m) const
{
return MatrixExpr< MatrixProduct< SparseMatrix<T>, SparseMatrix<Real2> > >(MatrixProduct< SparseMatrix<T>, SparseMatrix<Real2> >(*this, m));
return MatrixExpr { MatrixProduct< SparseMatrix<T>, SparseMatrix<Real2> >(*this, m) };
}

MatrixExpr< MatrixScale< SparseMatrix<T>, double > > operator*(const double& r) const
{
return MatrixExpr< MatrixScale< SparseMatrix<T>, double > >(MatrixScale< SparseMatrix<T>, double >(*this, r));
return MatrixExpr { MatrixScale< SparseMatrix<T>, double >(*this, r) };
}

template<class Real2>
MatrixExpr< MatrixAddition< SparseMatrix<T>, SparseMatrix<Real2> > > operator+(const SparseMatrix<Real2>& m) const
{
return MatrixExpr< MatrixAddition< SparseMatrix<T>, SparseMatrix<Real2> > >(MatrixAddition< SparseMatrix<T>, SparseMatrix<Real2> >(*this, m));
return MatrixExpr { MatrixAddition< SparseMatrix<T>, SparseMatrix<Real2> >(*this, m) };
}

template<class Real2>
MatrixExpr< MatrixAddition< SparseMatrix<T>, SparseMatrix<Real2> > > operator-(const SparseMatrix<Real2>& m) const
{
return MatrixExpr< MatrixAddition< SparseMatrix<T>, SparseMatrix<Real2> > >(MatrixAddition< SparseMatrix<T>, SparseMatrix<Real2> >(*this, m));
return MatrixExpr { MatrixAddition< SparseMatrix<T>, SparseMatrix<Real2> >(*this, m) };
}

void swap(SparseMatrix<T>& m)
Expand Down

0 comments on commit 6ed9f8f

Please sign in to comment.