Skip to content

Commit

Permalink
Thyra DefaultMultipliedLinearOp: Caching of intermediate vectors
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Glusa <[email protected]>
  • Loading branch information
cgcgcg committed Jan 21, 2025
1 parent bc6496e commit fb5e877
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ class DefaultMultipliedLinearOp : virtual public MultipliedLinearOpBase<Scalar>
*/
bool opSupportedImpl(EOpTransp M_trans) const;

void allocateVecs(const Ordinal dim) const;

/** \brief . */
void applyImpl(
const EOpTransp M_trans,
Expand All @@ -228,6 +230,7 @@ class DefaultMultipliedLinearOp : virtual public MultipliedLinearOpBase<Scalar>
private:

Array<Teuchos::ConstNonconstObjectContainer<LinearOpBase<Scalar> > > Ops_;
mutable std::vector<Teuchos::RCP<MultiVectorBase<Scalar> > > T_k_;

inline void assertInitialized() const;
inline std::string getClassName() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,23 @@ bool DefaultMultipliedLinearOp<Scalar>::opSupportedImpl(EOpTransp M_trans) const
// ToDo: Cache these?
}

template<class Scalar>
void DefaultMultipliedLinearOp<Scalar>::allocateVecs(const Ordinal dim) const {
const int nOps = Ops_.size();
if ((T_k_.size() != nOps+1) || ((nOps > 0) && (T_k_[0]->domain()->dim() != dim))) {
// op[0]->range
// op[0]->domain == op[1]->range
// ...
// op[nOps-2]->domain == op[nOps-1]->range
// op[nOps-1]->domain
T_k_.resize(0);
for( int k = 0; k < nOps; ++k ) {
T_k_.push_back(createMembers(getOp(k)->range(), dim));
}
T_k_.push_back(createMembers(getOp(nOps-1)->domain(), dim));
}
}


template<class Scalar>
void DefaultMultipliedLinearOp<Scalar>::applyImpl(
Expand All @@ -255,6 +272,7 @@ void DefaultMultipliedLinearOp<Scalar>::applyImpl(
#endif // TEUCHOS_DEBUG
const int nOps = Ops_.size();
const Ordinal m = X.domain()->dim();
allocateVecs(m);
if( real_trans(M_trans)==NOTRANS ) {
//
// Y = alpha * M * X + beta*Y
Expand All @@ -265,7 +283,7 @@ void DefaultMultipliedLinearOp<Scalar>::applyImpl(
for( int k = nOps-1; k >= 0; --k ) {
RCP<MultiVectorBase<Scalar> > Y_k;
RCP<const MultiVectorBase<Scalar> > X_k;
if(k==0) Y_k = rcpFromPtr(Y); else Y_k = T_k = createMembers(getOp(k)->range(), m);
if(k==0) Y_k = rcpFromPtr(Y); else Y_k = T_k = T_k_[k];
if(k==nOps-1) X_k = rcpFromRef(X); else X_k = T_kp1;
if( k > 0 )
Thyra::apply(*getOp(k), M_trans, *X_k, Y_k.ptr());
Expand All @@ -284,7 +302,7 @@ void DefaultMultipliedLinearOp<Scalar>::applyImpl(
for( int k = 0; k <= nOps-1; ++k ) {
RCP<MultiVectorBase<Scalar> > Y_k;
RCP<const MultiVectorBase<Scalar> > X_k;
if(k==nOps-1) Y_k = rcpFromPtr(Y); else Y_k = T_k = createMembers(getOp(k)->domain(), m);
if(k==nOps-1) Y_k = rcpFromPtr(Y); else Y_k = T_k = T_k_[k+1];
if(k==0) X_k = rcpFromRef(X); else X_k = T_km1;
if( k < nOps-1 )
Thyra::apply(*getOp(k), M_trans, *X_k, Y_k.ptr());
Expand Down

0 comments on commit fb5e877

Please sign in to comment.