Skip to content

Commit

Permalink
[LinearAlgebra] constexpr if statement when possible (#4352)
Browse files Browse the repository at this point in the history
* [LinearAlgebra] constexpr if statement when possible

* Use structured binding
  • Loading branch information
alxbilger authored Dec 14, 2023
1 parent 6ed9f8f commit 8774d29
Showing 1 changed file with 88 additions and 64 deletions.
152 changes: 88 additions & 64 deletions Sofa/framework/LinearAlgebra/src/sofa/linearalgebra/BaseMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
namespace sofa::linearalgebra
{

BaseMatrix::BaseMatrix() {}
BaseMatrix::BaseMatrix() = default;

BaseMatrix::~BaseMatrix()
{}
Expand Down Expand Up @@ -64,20 +64,27 @@ struct BaseMatrixLinearOpMV_BlockDiagonal
const Index colSize = mat->colSize();
BlockData buffer;

if (!add)
opVresize(result, (transpose ? colSize : rowSize));
for (std::pair<RowBlockConstIterator, RowBlockConstIterator> rowRange = mat->bRowsRange();
rowRange.first != rowRange.second;
++rowRange.first)
if constexpr (!add)
{
if constexpr (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
for (auto [rowIt, rowEnd] = mat->bRowsRange(); rowIt != rowEnd; ++rowIt)
{
std::pair<ColBlockConstIterator,ColBlockConstIterator> colRange = rowRange.first.range();
if (colRange.first != colRange.second) // diagonal block exists
auto [colBegin, colEnd] = rowIt.range();
if (colBegin != colEnd) // diagonal block exists
{
BlockConstAccessor block = colRange.first.bloc();
BlockConstAccessor block = colBegin.bloc();
const BlockData& bdata = *(const BlockData*)block.elements(buffer.ptr());
const Index i = block.getRow() * NL;
const Index j = block.getCol() * NC;
if (!transpose)
if constexpr (!transpose)
{
type::VecNoInit<NC,Real> vj;
for (int bj = 0; bj < NC; ++bj)
Expand Down Expand Up @@ -162,8 +169,17 @@ struct BaseMatrixLinearOpMV_BlockDiagonal<Real, 1, 1, add, transpose, M, V1, V2>
{
const Index rowSize = mat->rowSize();
const Index colSize = mat->colSize();
if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if constexpr (!add)
{
if constexpr (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
const Index size = (rowSize < colSize) ? rowSize : colSize;
for (Index i=0; i<size; ++i)
{
Expand All @@ -188,14 +204,14 @@ struct BaseMatrixLinearOpMV_BlockSparse
BlockData buffer;
type::Vec<NC,Real> vtmpj;
type::Vec<NL,Real> vtmpi;
if (!add)
if constexpr (!add)
{
opVresize(result, (transpose ? colSize : rowSize));
for (std::pair<RowBlockConstIterator, RowBlockConstIterator> rowRange = mat->bRowsRange();
rowRange.first != rowRange.second;
++rowRange.first)
}
for (auto [rowIt, rowEnd] = mat->bRowsRange(); rowIt != rowEnd; ++rowIt)
{
const Index i = rowRange.first.row() * NL;
if (!transpose)
const Index i = rowIt.row() * NL;
if constexpr (!transpose)
{
for (int bi = 0; bi < NL; ++bi)
vtmpi[bi] = (Real)0;
Expand All @@ -205,14 +221,12 @@ struct BaseMatrixLinearOpMV_BlockSparse
for (int bi = 0; bi < NL; ++bi)
vtmpi[bi] = (Real)opVget(v, i+bi);
}
for (std::pair<ColBlockConstIterator,ColBlockConstIterator> colRange = rowRange.first.range();
colRange.first != colRange.second;
++colRange.first)
for (auto [colIt, colEnd] = rowIt.range(); colIt != colEnd; ++colIt)
{
BlockConstAccessor block = colRange.first.bloc();
BlockConstAccessor block = colIt.bloc();
const BlockData& bdata = *(const BlockData*)block.elements(buffer.ptr());
const Index j = block.getCol() * NC;
if (!transpose)
if constexpr (!transpose)
{
for (int bj = 0; bj < NC; ++bj)
vtmpj[bj] = (Real)opVget(v, j+bj);
Expand All @@ -231,14 +245,11 @@ struct BaseMatrixLinearOpMV_BlockSparse
opVadd(result, j+bj, vtmpj[bj]);
}
}
if (!transpose)
if constexpr (!transpose)
{
for (int bi = 0; bi < NL; ++bi)
opVadd(result, i+bi, vtmpi[bi]);
}
else
{
}
}
}
};
Expand All @@ -253,9 +264,18 @@ class BaseMatrixLinearOpMV
{
const Index rowSize = mat->rowSize();
const Index colSize = mat->colSize();
if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if (!transpose)
if constexpr (!add)
{
if constexpr (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
if constexpr (!transpose)
{
for (Index i=0; i<rowSize; ++i)
{
Expand Down Expand Up @@ -285,8 +305,17 @@ class BaseMatrixLinearOpMV
{
const Index rowSize = mat->rowSize();
const Index colSize = mat->colSize();
if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if constexpr (!add)
{
if constexpr (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
const Index size = (rowSize < colSize) ? rowSize : colSize;
for (Index i=0; i<size; ++i)
{
Expand All @@ -300,8 +329,17 @@ class BaseMatrixLinearOpMV
{
const Index rowSize = mat->rowSize();
const Index colSize = mat->colSize();
if (!add)
opVresize(result, (transpose ? colSize : rowSize));
if constexpr (!add)
{
if (transpose)
{
opVresize(result, colSize);
}
else
{
opVresize(result, rowSize);
}
}
const Index size = (rowSize < colSize) ? rowSize : colSize;
for (Index i=0; i<size; ++i)
{
Expand Down Expand Up @@ -498,22 +536,18 @@ struct BaseMatrixLinearOpAM_BlockSparse
{
BlockData buffer;

for (std::pair<RowBlockConstIterator, RowBlockConstIterator> rowRange = m1->bRowsRange();
rowRange.first != rowRange.second;
++rowRange.first)
for (auto [rowIt, rowEnd] = m1->bRowsRange(); rowIt != rowEnd; ++rowIt)
{
const Index i = rowRange.first.row() * NL;
const Index i = rowIt.row() * NL;

for (std::pair<ColBlockConstIterator,ColBlockConstIterator> colRange = rowRange.first.range();
colRange.first != colRange.second;
++colRange.first)
for (auto [colIt, colEnd] = rowIt.range(); colIt != colEnd; ++colIt)
{

BlockConstAccessor block = colRange.first.bloc();
BlockConstAccessor block = colIt.bloc();
const BlockData& bdata = *(const BlockData*)block.elements(buffer.ptr());
const Index j = block.getCol() * NC;

if (!transpose)
if constexpr (!transpose)
{
for (int bi = 0; bi < NL; ++bi)
for (int bj = 0; bj < NC; ++bj)
Expand Down Expand Up @@ -544,21 +578,16 @@ struct BaseMatrixLinearOpAMS_BlockSparse
{
BlockData buffer;

for (std::pair<RowBlockConstIterator, RowBlockConstIterator> rowRange = m1->bRowsRange();
rowRange.first != rowRange.second;
++rowRange.first)
for (auto [rowIt, rowEnd] = m1->bRowsRange(); rowIt != rowEnd; ++rowIt)
{
const Index i = rowRange.first.row() * NL;
const Index i = rowIt.row() * NL;

for (std::pair<ColBlockConstIterator,ColBlockConstIterator> colRange = rowRange.first.range();
colRange.first != colRange.second;
++colRange.first)
for (auto [colIt, colEnd] = rowIt.range(); colIt != colEnd; ++colIt)
{

BlockConstAccessor block = colRange.first.bloc();
BlockConstAccessor block = colIt.bloc();
const BlockData& bdata = *(const BlockData*)block.elements(buffer.ptr());
const Index j = block.getCol() * NC;
if (!transpose)
if constexpr (!transpose)
{
for (int bi = 0; bi < NL; ++bi)
for (int bj = 0; bj < NC; ++bj)
Expand Down Expand Up @@ -588,22 +617,17 @@ struct BaseMatrixLinearOpAM1_BlockSparse
{
BlockData buffer;

for (std::pair<RowBlockConstIterator, RowBlockConstIterator> rowRange = m1->bRowsRange();
rowRange.first != rowRange.second;
++rowRange.first)
for (auto [rowIt, rowEnd] = m1->bRowsRange(); rowIt != rowEnd; ++rowIt)
{
const Index i = rowRange.first.row();
const Index i = rowIt.row();

for (std::pair<ColBlockConstIterator,ColBlockConstIterator> colRange = rowRange.first.range();
colRange.first != colRange.second;
++colRange.first)
for (auto [colIt, colEnd] = rowIt.range(); colIt != colEnd; ++colIt)
{

BlockConstAccessor block = colRange.first.bloc();
BlockConstAccessor block = colIt.bloc();
const BlockData& bdata = *(const BlockData*)block.elements(&buffer);
const Index j = block.getCol();

if (!transpose)
if constexpr (!transpose)
{
m2->add(i,j,bdata * fact);
}
Expand All @@ -626,7 +650,7 @@ class BaseMatrixLinearOpAM
{
const Index rowSize = m1->rowSize();
const Index colSize = m2->colSize();
if (!transpose)
if constexpr (!transpose)
{
for (Index j=0; j<rowSize; ++j)
{
Expand Down

0 comments on commit 8774d29

Please sign in to comment.