Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stochastic update functionality to FB algorithm #445

Merged
merged 10 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 53 additions & 25 deletions cpp/sopt/forward_backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include "sopt/logging.h"
#include "sopt/types.h"

#include "sopt/gradient_utils.h"

namespace sopt::algorithm {

/*! \brief Forward Backward Splitting
Expand Down Expand Up @@ -41,6 +43,7 @@ class ForwardBackward {
//! Type of the gradient
// The first argument is the output vector, the others are inputs
using t_Gradient = std::function<void(t_Vector &gradient, const t_Vector &image, const t_Vector &residual, const t_LinearTransform& Phi)>;
//! Type of the callback that supplies a new problem state
// Takes no arguments and returns a shared pointer to an IterationState over t_Vector
using t_randomUpdater = std::function<std::shared_ptr<IterationState<t_Vector>>()>;

//! Values indicating how the algorithm ran
struct Diagnostic {
Expand All @@ -65,19 +68,20 @@ class ForwardBackward {
//! Sets up ForwardBackward
//! \param[in] f_gradient: gradient of the differentiable function \f$f\f$
//! \param[in] g_proximal: proximal operator of the non-differentiable function \f$g\f$
//! \param[in] target: vector of target measurements
template <typename DERIVED>
ForwardBackward(t_Gradient const &f_gradient, t_Proximal const &g_proximal,
Eigen::MatrixBase<DERIVED> const &target)
t_Vector const &target)
: itermax_(std::numeric_limits<t_uint>::max()),
regulariser_strength_(1e-8),
step_size_(1),
sq_op_norm_(1),
is_converged_(),
fista_(true),
Phi_(linear_transform_identity<Scalar>()),
f_gradient_(f_gradient),
g_proximal_(g_proximal),
target_(target) {}
g_proximal_(g_proximal)
{
// The target and the measurement operator are kept in the problem state;
// the operator starts out as the identity transform.
std::shared_ptr<t_LinearTransform> Id = std::make_shared<t_LinearTransform>(linear_transform_identity<Scalar>());
problem_state = std::make_shared<IterationState<t_Vector>>(target, Id);
}
virtual ~ForwardBackward() {}

// Macro helps define properties that can be initialized as in
Expand Down Expand Up @@ -107,12 +111,30 @@ class ForwardBackward {
//! \brief A function verifying convergence
//! \details It takes as input two arguments: the current solution x and the current residual.
SOPT_MACRO(is_converged, t_IsConverged);
//! Measurement operator
SOPT_MACRO(Phi, t_LinearTransform);
//! Gradient of the differentiable function f
SOPT_MACRO(f_gradient, t_Gradient);
//! Proximal operator of the non-differentiable function g
SOPT_MACRO(g_proximal, t_Proximal);

//! \brief Measurement operator of the current problem state
//! \return const reference to the operator returned by problem_state->Phi()
t_LinearTransform const &Phi() const {
  auto const &state = *problem_state;
  return state.Phi();
}
//! \brief Replaces the measurement operator of the current problem state
//! \param[in] new_phi: operator handed on to problem_state->Phi
//! \return reference to this object, so that setters can be chained
ForwardBackward<SCALAR> &Phi(t_LinearTransform const &new_phi) {
  auto &state = *problem_state;
  state.Phi(new_phi);
  return *this;
}

//! \brief Installs the callback that supplies a new problem state
//! \details When a non-empty callback is stored, iteration_step calls it and
//! replaces the problem state with the returned one.
//! \param[in] rU: callback to store (copied). Taken by const reference so that
//! temporaries, e.g. a lambda converted to t_randomUpdater, bind as well as
//! named objects.
//! \return reference to this object, so that setters can be chained
ForwardBackward<SCALAR> &random_updater(t_randomUpdater const &rU)
{
  random_updater_ = rU;
  return *this;
}

//! \brief Replaces the problem state read by Phi() and target()
//! \param[in] pS: shared pointer to the new state; this object keeps a copy of the pointer
//! \return reference to this object, so that setters can be chained
ForwardBackward<SCALAR> &set_problem_state(std::shared_ptr<IterationState<t_Vector>> pS)
{
  auto &self = *this;
  self.problem_state = pS;
  return self;
}

#undef SOPT_MACRO
//! \brief Convenience wrapper that invokes the stored gradient function
//! \param[out] out: receives the gradient
//! \param[in] x: image passed to the gradient function
//! \param[in] res: residual passed to the gradient function
//! \param[in] Phi: linear transform passed to the gradient function
void f_gradient(t_Vector &out, t_Vector const &x, t_Vector const &res, t_LinearTransform const &Phi) const {
  auto const &gradient_fn = f_gradient();
  gradient_fn(out, x, res, Phi);
}
Expand All @@ -127,11 +149,10 @@ class ForwardBackward {
}

//! Vector of target measurements
t_Vector const &target() const { return target_; }
t_Vector const &target() const { return problem_state->target(); }
//! Sets the vector of target measurements
//! \param[in] target: new vector of target measurements
//! \return reference to this object, so that setters can be chained
template <typename DERIVED>
ForwardBackward<Scalar> &target(Eigen::MatrixBase<DERIVED> const &target) {
target_ = target;
ForwardBackward<Scalar> &target(t_Vector const &target) {
// NOTE(review): IterationState in gradient_utils.h only shows a const target()
// getter over a `const T _target` member; confirm a one-argument target()
// overload exists, otherwise this call cannot compile once instantiated.
problem_state->target(target);
return *this;
}

Expand All @@ -142,50 +163,50 @@ class ForwardBackward {

//! \brief Calls Forward Backward
//! \param[out] out: Output vector x
Diagnostic operator()(t_Vector &out) const { return operator()(out, initial_guess()); }
Diagnostic operator()(t_Vector &out) { return operator()(out, initial_guess()); }
//! \brief Calls Forward Backward
//! \param[out] out: Output vector x
//! \param[in] guess: initial guess
Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) const {
Diagnostic operator()(t_Vector &out, std::tuple<t_Vector, t_Vector> const &guess) {
return operator()(out, std::get<0>(guess), std::get<1>(guess));
}
//! \brief Calls Forward Backward
//! \param[out] out: Output vector x
//! \param[in] guess: initial guess
Diagnostic operator()(t_Vector &out,
std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
std::tuple<t_Vector const &, t_Vector const &> const &guess) {
return operator()(out, std::get<0>(guess), std::get<1>(guess));
}
//! \brief Calls Forward Backward
//! \param[in] guess: initial guess
DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) const {
DiagnosticAndResult operator()(std::tuple<t_Vector, t_Vector> const &guess) {
return operator()(std::tie(std::get<0>(guess), std::get<1>(guess)));
}
//! \brief Calls Forward Backward
//! \param[in] guess: initial guess
DiagnosticAndResult operator()(
std::tuple<t_Vector const &, t_Vector const &> const &guess) const {
std::tuple<t_Vector const &, t_Vector const &> const &guess) {
DiagnosticAndResult result;
static_cast<Diagnostic &>(result) = operator()(result.x, guess);
return result;
}
//! \brief Calls Forward Backward starting from initial_guess()
//! \return the diagnostic of the run together with the output vector x
DiagnosticAndResult operator()() const {
DiagnosticAndResult operator()() {
DiagnosticAndResult result;
// run the solver into result.x and copy the returned diagnostic into the base part of result
static_cast<Diagnostic &>(result) = operator()(result.x, initial_guess());
return result;
}
//! Makes it simple to chain different calls to FB
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const {
DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) {
DiagnosticAndResult result = warmstart;
static_cast<Diagnostic &>(result) = operator()(result.x, warmstart.x, warmstart.residual);
return result;
}
//! Set Φ and Φ^† using arguments that sopt::linear_transform understands
//! \details The enable_if condition removes this overload when no argument is given.
//! \param[in] args: forwarded unchanged to linear_transform
//! \return reference to this object, so that setters can be chained
template <typename... ARGS>
typename std::enable_if<sizeof...(ARGS) >= 1, ForwardBackward &>::type Phi(ARGS &&... args) {
Phi_ = linear_transform(std::forward<ARGS>(args)...);
problem_state->Phi(linear_transform(std::forward<ARGS>(args)...));
return *this;
}

Expand Down Expand Up @@ -213,7 +234,7 @@ class ForwardBackward {

protected:
void iteration_step(t_Vector &out, t_Vector &residual, t_Vector &p, t_Vector &z,
const t_real lambda) const;
const t_real lambda);

//! Checks input makes sense
void sanity_check(t_Vector const &x_guess, t_Vector const &res_guess) const {
Expand All @@ -231,10 +252,11 @@ class ForwardBackward {
//! \param[out] out: Output vector x
//! \param[in] guess: initial guess
//! \param[in] residuals: initial residuals
Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) const;
Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res);

//! Vector of measurements
t_Vector target_;
//! problem state (shared with Imaging Forward Backward)
// Phi() and target() read the measurement operator and the target from this object.
std::shared_ptr<IterationState<t_Vector>> problem_state;
//! Optional callback returning a new problem state; when non-empty, iteration_step
//! calls it and replaces problem_state with the result.
t_randomUpdater random_updater_;
};

/**
Expand All @@ -253,19 +275,25 @@ class ForwardBackward {
*/
template <typename SCALAR>
void ForwardBackward<SCALAR>::iteration_step(t_Vector &image, t_Vector &residual, t_Vector &auxilliary_image,
t_Vector &gradient_current, const t_real FISTA_step) const {
t_Vector &gradient_current, const t_real FISTA_step) {
// keep the previous iterate: the acceleration step below extrapolates along (image - prev_image)
t_Vector prev_image = image;
f_gradient(gradient_current, auxilliary_image, residual, Phi()); // assigns gradient_current (non normalised)
t_Vector auxilliary_with_step = auxilliary_image - step_size() / sq_op_norm() * gradient_current; // step to new image using gradient
// threshold handed to the proximal operator: regulariser strength scaled by the step size
const Real weight = regulariser_strength() * step_size();
g_proximal(image, weight, auxilliary_with_step); // apply proximal operator to new image
auxilliary_image = image + FISTA_step * (image - prev_image); // update auxiliary vector with FISTA acceleration step

// set up next iteration
// When a random updater is installed, the problem state is replaced by the one it
// returns, so the Phi() and target() calls below already use the new state.
// NOTE(review): the returned pointer is dereferenced below without a null check.
if(random_updater_)
{
problem_state = random_updater_();
}
residual = (Phi() * auxilliary_image) - target(); // updates the residual for the NEXT iteration (new image).
}

template <typename SCALAR>
typename ForwardBackward<SCALAR>::Diagnostic ForwardBackward<SCALAR>::operator()(
t_Vector &out, t_Vector const &x_guess, t_Vector const &res_guess) const {
t_Vector &out, t_Vector const &x_guess, t_Vector const &res_guess) {
SOPT_HIGH_LOG("Performing Forward Backward Splitting");
if (fista()) {
SOPT_HIGH_LOG("Using FISTA algorithm");
Expand Down
18 changes: 13 additions & 5 deletions cpp/sopt/gradient_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,30 @@ namespace sopt {
template <typename T>
class IterationState {
public:
IterationState() = delete;
IterationState(const T& target)
{
_Phi = std::make_shared<sopt::LinearTransform<T>>(linear_transform_identity<T>());
}

//! \brief Builds a state from a target vector and a shared measurement operator
//! \param[in] target: target vector, copied into the state
//! \param[in] Phi: shared pointer to the operator; the state stores the pointer, not a copy of the operator
IterationState(const T& target,
std::shared_ptr<sopt::LinearTransform<T>> phi)
std::shared_ptr<sopt::LinearTransform<T>> Phi)
: _target(target) {
_phi = phi;
_Phi = Phi;
}

//! \brief Target vector held by this state
//! \return const reference to the stored target
const T& target() const
{
  return this->_target;
}

const sopt::LinearTransform<T>& phi() const { return *_phi; }
const sopt::LinearTransform<T>& Phi() const { return *_Phi; }

//! \brief Replaces the stored operator with a freshly allocated copy of new_phi
//! \param[in] new_phi: operator to copy
void Phi(const sopt::LinearTransform<T> &new_phi)
{
  // build the copy first, then exchange it with the stored pointer;
  // the local releases its reference to the previous operator on return
  auto fresh = std::make_shared<sopt::LinearTransform<T>>(new_phi);
  _Phi.swap(fresh);
}

private:
const T _target;

std::shared_ptr<sopt::LinearTransform<T>> _phi;
std::shared_ptr<sopt::LinearTransform<T>> _Phi;
};

} // namespace sopt
Expand Down
Loading
Loading