From 50ef7e2faf5cefb04388c210a6499e0c16ed5687 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 12 Dec 2024 13:18:12 +0000 Subject: [PATCH 01/10] Use capital Phi to match other syntax --- cpp/sopt/gradient_utils.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/sopt/gradient_utils.h b/cpp/sopt/gradient_utils.h index 5c4edeb9..4cbdf3e3 100644 --- a/cpp/sopt/gradient_utils.h +++ b/cpp/sopt/gradient_utils.h @@ -15,19 +15,19 @@ class IterationState { IterationState() = delete; IterationState(const T& target, - std::shared_ptr> phi) + std::shared_ptr> Phi) : _target(target) { - _phi = phi; + _Phi = Phi; } const T& target() const { return _target; } - const sopt::LinearTransform& phi() const { return *_phi; } + const sopt::LinearTransform& Phi() const { return *_Phi; } private: const T _target; - std::shared_ptr> _phi; + std::shared_ptr> _Phi; }; } // namespace sopt From b26014e29d969bf32aa094e9c7c19aeec4ea4fc1 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 12 Dec 2024 13:18:40 +0000 Subject: [PATCH 02/10] Avoid target copying --- cpp/sopt/forward_backward.h | 14 ++++++-------- cpp/sopt/imaging_forward_backward.h | 27 +++++++++++++++++++-------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/cpp/sopt/forward_backward.h b/cpp/sopt/forward_backward.h index c1c02986..2b9ffd9e 100644 --- a/cpp/sopt/forward_backward.h +++ b/cpp/sopt/forward_backward.h @@ -65,9 +65,8 @@ class ForwardBackward { //! Setups ForwardBackward //! \param[in] f_function: the differentiable function \f$f\f$ with a gradient //! \param[in] g_function: the non-differentiable function \f$g\f$ with a proximal operator - template ForwardBackward(t_Gradient const &f_gradient, t_Proximal const &g_proximal, - Eigen::MatrixBase const &target) + t_Vector const &target) : itermax_(std::numeric_limits::max()), regulariser_strength_(1e-8), step_size_(1), @@ -77,7 +76,7 @@ class ForwardBackward { Phi_(linear_transform_identity()), f_gradient_(f_gradient), g_proximal_(g_proximal), - target_(target) {} + target_(&target) {} virtual ~ForwardBackward() {} // Macro helps define properties that can be initialized as in @@ -127,11 +126,10 @@ class ForwardBackward { } //! Vector of target measurements - t_Vector const &target() const { return target_; } + t_Vector const &target() const { return *target_; } //! Sets the vector of target measurements - template - ForwardBackward &target(Eigen::MatrixBase const &target) { - target_ = target; + ForwardBackward &target(t_Vector const &target) { + target_ = ⌖ return *this; } @@ -234,7 +232,7 @@ class ForwardBackward { Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) const; //! Vector of measurements - t_Vector target_; + const t_Vector *target_; }; /** diff --git a/cpp/sopt/imaging_forward_backward.h b/cpp/sopt/imaging_forward_backward.h index 4e2ff847..1f75df27 100644 --- a/cpp/sopt/imaging_forward_backward.h +++ b/cpp/sopt/imaging_forward_backward.h @@ -17,6 +17,9 @@ #include "sopt/non_differentiable_func.h" #include "sopt/differentiable_func.h" +#include +#include "sopt/gradient_utils.h" + #ifdef SOPT_MPI #include "sopt/mpi/communicator.h" #include "sopt/mpi/utilities.h" @@ -38,6 +41,7 @@ class ImagingForwardBackward { using t_Gradient = typename FB::t_Gradient; using t_l2Gradient = typename std::function; using t_IsConverged = typename FB::t_IsConverged; + using t_randomUpdater = std::function()>; //! 
Values indicating how the algorithm ran struct Diagnostic : public FB::Diagnostic { @@ -61,10 +65,10 @@ class ImagingForwardBackward { // \f$f\f$ is differentiable with a supplied gradient and \f$g\f$ is non-differentiable with a supplied proximal operator. // Throughout this class we will use `f` and `g` in variables to refer to these two parts of the objective function. //! \param[in] target: Vector of target measurements - template - ImagingForwardBackward(Eigen::MatrixBase const &target) + ImagingForwardBackward(t_Vector const &target) : g_function_(nullptr), f_function_(nullptr), + random_updater_(nullptr), tight_frame_(false), residual_tolerance_(0.), relative_variation_(1e-4), @@ -78,7 +82,7 @@ class ImagingForwardBackward { fista_(true), is_converged_(), Phi_(linear_transform_identity()), - target_(target) {} + target_(&target) {} virtual ~ImagingForwardBackward() {} // Macro helps define properties that can be initialized as in @@ -152,6 +156,13 @@ class ImagingForwardBackward { return *this; } + // Getter and setter for the random updater object + t_randomUpdater &random_updater() { return random_updater_; } + ImagingForwardBackward& random_updater( t_randomUpdater &f_function) { + random_updater_ = random_updater_; // may change this to a move if we don't need to keep it + return *this; + } + t_LinearTransform const &Psi() const { return (g_function_) ? g_function_->Psi() : linear_transform_identity(); @@ -167,14 +178,13 @@ class ImagingForwardBackward { //} //! Vector of target measurements - t_Vector const &target() const { return target_; } + t_Vector const &target() const { return *target_; } //! Minimum of objective_function Real objmin() const { return objmin_; } //! Sets the vector of target measurements - template - ImagingForwardBackward &target(Eigen::MatrixBase const &target) { - target_ = target; + ImagingForwardBackward &target(t_Vector const &target) { + target_ = ⌖ return *this; } @@ -251,9 +261,10 @@ class ImagingForwardBackward { // These should point to an instance of a derived class (e.g. L1GProximal) once set up std::shared_ptr> g_function_; std::shared_ptr> f_function_; + t_randomUpdater random_updater_; //! Vector of measurements - t_Vector target_; + const t_Vector *target_; //! 
Mininum of objective function mutable Real objmin_; From d3c1fbd287ae10cc8f404187720b63104f237c25 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 12 Dec 2024 13:47:08 +0000 Subject: [PATCH 03/10] Add constructor with random update --- cpp/sopt/imaging_forward_backward.h | 33 ++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/cpp/sopt/imaging_forward_backward.h b/cpp/sopt/imaging_forward_backward.h index 1f75df27..657a8cc5 100644 --- a/cpp/sopt/imaging_forward_backward.h +++ b/cpp/sopt/imaging_forward_backward.h @@ -19,6 +19,7 @@ #include #include "sopt/gradient_utils.h" +#include #ifdef SOPT_MPI #include "sopt/mpi/communicator.h" @@ -83,8 +84,37 @@ class ImagingForwardBackward { is_converged_(), Phi_(linear_transform_identity()), target_(&target) {} - virtual ~ImagingForwardBackward() {} + ImagingForwardBackward(t_randomUpdater &updater) + : g_function_(nullptr), + f_function_(nullptr), + random_updater_(updater), + tight_frame_(false), + residual_tolerance_(0.), + relative_variation_(1e-4), + residual_convergence_(nullptr), + objective_convergence_(nullptr), + itermax_(std::numeric_limits::max()), + regulariser_strength_(1e-8), + step_size_(1), + sigma_(1), + sq_op_norm_(1), + fista_(true), + is_converged_() + { + if(random_updater) + { + // target and Phi are not known ahead of time for random data sets so need to be initialised + target_state_ = std::make_unique>(random_updater_); + target_ = &target_state_.target(); + } + else + { + throw std::runtime_error("Attempted to construct ImagingForwardBackward class with a null random updater. To run without random updates supply a target vector instead."); + } + } + + virtual ~ImagingForwardBackward() {} // Macro helps define properties that can be initialized as in // auto padmm = ImagingForwardBackward().prop0(value).prop1(value); #define SOPT_MACRO(NAME, TYPE) \ @@ -262,6 +292,7 @@ class ImagingForwardBackward { std::shared_ptr> g_function_; std::shared_ptr> f_function_; t_randomUpdater random_updater_; + std::shared_ptr> target_state_; //! Vector of measurements const t_Vector *target_; From ac1e6a052f81dbde20f3d40532f5e2a9aa2bcb9d Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 12 Dec 2024 15:24:55 +0000 Subject: [PATCH 04/10] Refactor Phi and target into general problem state --- cpp/sopt/gradient_utils.h | 5 +++++ cpp/sopt/imaging_forward_backward.h | 33 +++++++++++++++++------------ 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/cpp/sopt/gradient_utils.h b/cpp/sopt/gradient_utils.h index 4cbdf3e3..3c7476dc 100644 --- a/cpp/sopt/gradient_utils.h +++ b/cpp/sopt/gradient_utils.h @@ -24,6 +24,11 @@ class IterationState { const sopt::LinearTransform& Phi() const { return *_Phi; } + void Phi(const sopt::LinearTransform &new_phi) + { + _Phi = std::make_shared>(sopt::LinearTransform(new_phi)); + } + private: const T _target; diff --git a/cpp/sopt/imaging_forward_backward.h b/cpp/sopt/imaging_forward_backward.h index 657a8cc5..095c667c 100644 --- a/cpp/sopt/imaging_forward_backward.h +++ b/cpp/sopt/imaging_forward_backward.h @@ -42,7 +42,7 @@ class ImagingForwardBackward { using t_Gradient = typename FB::t_Gradient; using t_l2Gradient = typename std::function; using t_IsConverged = typename FB::t_IsConverged; - using t_randomUpdater = std::function()>; + using t_randomUpdater = std::function>()>; //! 
Values indicating how the algorithm ran struct Diagnostic : public FB::Diagnostic { @@ -81,9 +81,11 @@ class ImagingForwardBackward { sigma_(1), sq_op_norm_(1), fista_(true), - is_converged_(), - Phi_(linear_transform_identity()), - target_(&target) {} + is_converged_() + { + std::shared_ptr Id = std::make_shared(linear_transform_identity()); + problem_state = std::make_shared>(target, Id); + } ImagingForwardBackward(t_randomUpdater &updater) : g_function_(nullptr), @@ -102,11 +104,10 @@ class ImagingForwardBackward { fista_(true), is_converged_() { - if(random_updater) + if(random_updater_) { // target and Phi are not known ahead of time for random data sets so need to be initialised - target_state_ = std::make_unique>(random_updater_); - target_ = &target_state_.target(); + problem_state = random_updater_(); } else { @@ -157,8 +158,13 @@ class ImagingForwardBackward { SOPT_MACRO(fista, bool); //! A function verifying convergence SOPT_MACRO(is_converged, t_IsConverged); + //! Measurement operator - SOPT_MACRO(Phi, t_LinearTransform); + t_LinearTransform const &Phi() const { return problem_state->Phi(); } + ImagingForwardBackward &Phi(t_LinearTransform const &(Phi)) { + problem_state->Phi(Phi); + return *this; + } #ifdef SOPT_MPI //! Communicator for summing objective_function @@ -208,13 +214,13 @@ class ImagingForwardBackward { //} //! Vector of target measurements - t_Vector const &target() const { return *target_; } + t_Vector const &target() const { return problem_state->target(); } //! Minimum of objective_function Real objmin() const { return objmin_; } //! Sets the vector of target measurements ImagingForwardBackward &target(t_Vector const &target) { - target_ = ⌖ + problem_state->target(target); return *this; } @@ -268,7 +274,7 @@ class ImagingForwardBackward { template typename std::enable_if= 1, ImagingForwardBackward &>::type Phi( ARGS &&... args) { - Phi_ = linear_transform(std::forward(args)...); + problem_state->Phi(linear_transform(std::forward(args)...)); return *this; } @@ -292,10 +298,9 @@ class ImagingForwardBackward { std::shared_ptr> g_function_; std::shared_ptr> f_function_; t_randomUpdater random_updater_; - std::shared_ptr> target_state_; + //! Problem state represents Phi and y s.t. the problem to solve is y = Phi x + std::shared_ptr> problem_state; - //! Vector of measurements - const t_Vector *target_; //! 
Mininum of objective function mutable Real objmin_; From ba0c6df93e64ef11ddfaf5ffe5c5e21563c071fe Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 12 Dec 2024 17:01:34 +0000 Subject: [PATCH 05/10] Flexibility for iteration state --- cpp/sopt/gradient_utils.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/sopt/gradient_utils.h b/cpp/sopt/gradient_utils.h index 3c7476dc..5fbd1206 100644 --- a/cpp/sopt/gradient_utils.h +++ b/cpp/sopt/gradient_utils.h @@ -12,7 +12,10 @@ namespace sopt { template class IterationState { public: - IterationState() = delete; + IterationState(const T& target) + { + _Phi = std::make_shared>(linear_transform_identity()); + } IterationState(const T& target, std::shared_ptr> Phi) @@ -26,7 +29,7 @@ class IterationState { void Phi(const sopt::LinearTransform &new_phi) { - _Phi = std::make_shared>(sopt::LinearTransform(new_phi)); + _Phi = std::make_shared>(new_phi); } private: From b8a647934f4954eb86f6770ef2b8076f5637a560 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 12 Dec 2024 17:01:53 +0000 Subject: [PATCH 06/10] Add random updater / problem state into FB --- cpp/sopt/forward_backward.h | 38 +++++++++++++++++++++-------- cpp/sopt/imaging_forward_backward.h | 2 +- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/cpp/sopt/forward_backward.h b/cpp/sopt/forward_backward.h index 2b9ffd9e..1c50a999 100644 --- a/cpp/sopt/forward_backward.h +++ b/cpp/sopt/forward_backward.h @@ -11,6 +11,8 @@ #include "sopt/logging.h" #include "sopt/types.h" +#include "sopt/gradient_utils.h" + namespace sopt::algorithm { /*! \brief Forward Backward Splitting @@ -41,6 +43,7 @@ class ForwardBackward { //! Type of the gradient // The first argument is the output vector, the others are inputs using t_Gradient = std::function; + using t_randomUpdater = std::function>()>; //! Values indicating how the algorithm ran struct Diagnostic { @@ -73,10 +76,12 @@ class ForwardBackward { sq_op_norm_(1), is_converged_(), fista_(true), - Phi_(linear_transform_identity()), f_gradient_(f_gradient), - g_proximal_(g_proximal), - target_(&target) {} + g_proximal_(g_proximal) + { + std::shared_ptr Id = std::make_shared(linear_transform_identity()); + problem_state = std::make_shared>(target, Id); + } virtual ~ForwardBackward() {} // Macro helps define properties that can be initialized as in @@ -106,12 +111,18 @@ class ForwardBackward { //! \brief A function verifying convergence //! \details It takes as input two arguments: the current solution x and the current residual. SOPT_MACRO(is_converged, t_IsConverged); - //! Measurement operator - SOPT_MACRO(Phi, t_LinearTransform); //! First proximal SOPT_MACRO(f_gradient, t_Gradient); //! Second proximal SOPT_MACRO(g_proximal, t_Proximal); + + //! Measurement operator + t_LinearTransform const &Phi() const { return problem_state->Phi(); } + ForwardBackward &Phi(t_LinearTransform const &(Phi)) { + problem_state->Phi(Phi); + return *this; + } + #undef SOPT_MACRO //! \brief Simplifies calling the gradient function void f_gradient(t_Vector &out, t_Vector const &x, t_Vector const &res, t_LinearTransform const &Phi) const { f_gradient()(out, x, res, Phi); } @@ -126,10 +137,10 @@ class ForwardBackward { } //! Vector of target measurements - t_Vector const &target() const { return *target_; } + t_Vector const &target() const { return problem_state->target(); } //! 
Sets the vector of target measurements ForwardBackward &target(t_Vector const &target) { - target_ = ⌖ + problem_state->target(target); return *this; } @@ -183,7 +194,7 @@ class ForwardBackward { //! Set Φ and Φ^† using arguments that sopt::linear_transform understands template typename std::enable_if= 1, ForwardBackward &>::type Phi(ARGS &&... args) { - Phi_ = linear_transform(std::forward(args)...); + problem_state->Phi(linear_transform(std::forward(args)...)); return *this; } @@ -231,8 +242,9 @@ class ForwardBackward { //! \param[in] residuals: initial residuals Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) const; - //! Vector of measurements - const t_Vector *target_; + //! problem state (shared with Imaging Forward Backward) + std::shared_ptr> problem_state; + t_randomUpdater random_updates; }; /** @@ -258,6 +270,12 @@ void ForwardBackward::iteration_step(t_Vector &image, t_Vector &residual const Real weight = regulariser_strength() * step_size(); g_proximal(image, weight, auxilliary_with_step); // apply proximal operator to new image auxilliary_image = image + FISTA_step * (image - prev_image); // update auxilliary vector with FISTA acceleration step + + // set up next iteration + if(random_updates) + { + problem_state = random_updates(); + } residual = (Phi() * auxilliary_image) - target(); // updates the residual for the NEXT iteration (new image). } diff --git a/cpp/sopt/imaging_forward_backward.h b/cpp/sopt/imaging_forward_backward.h index 095c667c..30000172 100644 --- a/cpp/sopt/imaging_forward_backward.h +++ b/cpp/sopt/imaging_forward_backward.h @@ -42,7 +42,7 @@ class ImagingForwardBackward { using t_Gradient = typename FB::t_Gradient; using t_l2Gradient = typename std::function; using t_IsConverged = typename FB::t_IsConverged; - using t_randomUpdater = std::function>()>; + using t_randomUpdater = typename FB::t_randomUpdater; //! Values indicating how the algorithm ran struct Diagnostic : public FB::Diagnostic { From 94b696bf1666de4ad01a15c574147308ff458e8e Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 12 Dec 2024 17:24:25 +0000 Subject: [PATCH 07/10] Stuff isn't const any more --- cpp/sopt/forward_backward.h | 44 ++++++++++++++++++----------- cpp/sopt/imaging_forward_backward.h | 8 ++++-- cpp/sopt/l2_forward_backward.h | 2 +- cpp/tests/forward_backward.cc | 2 +- 4 files changed, 35 insertions(+), 21 deletions(-) diff --git a/cpp/sopt/forward_backward.h b/cpp/sopt/forward_backward.h index 1c50a999..94d3c2a8 100644 --- a/cpp/sopt/forward_backward.h +++ b/cpp/sopt/forward_backward.h @@ -118,8 +118,20 @@ class ForwardBackward { //! Measurement operator t_LinearTransform const &Phi() const { return problem_state->Phi(); } - ForwardBackward &Phi(t_LinearTransform const &(Phi)) { - problem_state->Phi(Phi); + ForwardBackward &Phi(t_LinearTransform const &new_phi) { + problem_state->Phi(new_phi); + return *this; + } + + ForwardBackward &random_updater(t_randomUpdater &rU) + { + random_updater_ = rU; + return *this; + } + + ForwardBackward &set_problem_state(std::shared_ptr> pS) + { + problem_state = pS; return *this; } @@ -151,42 +163,42 @@ class ForwardBackward { //! \brief Calls Forward Backward //! \param[out] out: Output vector x - Diagnostic operator()(t_Vector &out) const { return operator()(out, initial_guess()); } + Diagnostic operator()(t_Vector &out) { return operator()(out, initial_guess()); } //! \brief Calls Forward Backward //! \param[out] out: Output vector x //! 
\param[in] guess: initial guess - Diagnostic operator()(t_Vector &out, std::tuple const &guess) const { + Diagnostic operator()(t_Vector &out, std::tuple const &guess) { return operator()(out, std::get<0>(guess), std::get<1>(guess)); } //! \brief Calls Forward Backward //! \param[out] out: Output vector x //! \param[in] guess: initial guess Diagnostic operator()(t_Vector &out, - std::tuple const &guess) const { + std::tuple const &guess) { return operator()(out, std::get<0>(guess), std::get<1>(guess)); } //! \brief Calls Forward Backward //! \param[in] guess: initial guess - DiagnosticAndResult operator()(std::tuple const &guess) const { + DiagnosticAndResult operator()(std::tuple const &guess) { return operator()(std::tie(std::get<0>(guess), std::get<1>(guess))); } //! \brief Calls Forward Backward //! \param[in] guess: initial guess DiagnosticAndResult operator()( - std::tuple const &guess) const { + std::tuple const &guess) { DiagnosticAndResult result; static_cast(result) = operator()(result.x, guess); return result; } //! \brief Calls Forward Backward //! \param[in] guess: initial guess - DiagnosticAndResult operator()() const { + DiagnosticAndResult operator()() { DiagnosticAndResult result; static_cast(result) = operator()(result.x, initial_guess()); return result; } //! Makes it simple to chain different calls to FB - DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) const { + DiagnosticAndResult operator()(DiagnosticAndResult const &warmstart) { DiagnosticAndResult result = warmstart; static_cast(result) = operator()(result.x, warmstart.x, warmstart.residual); return result; @@ -222,7 +234,7 @@ class ForwardBackward { protected: void iteration_step(t_Vector &out, t_Vector &residual, t_Vector &p, t_Vector &z, - const t_real lambda) const; + const t_real lambda); //! Checks input makes sense void sanity_check(t_Vector const &x_guess, t_Vector const &res_guess) const { @@ -240,11 +252,11 @@ class ForwardBackward { //! \param[out] out: Output vector x //! \param[in] guess: initial guess //! \param[in] residuals: initial residuals - Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res) const; + Diagnostic operator()(t_Vector &out, t_Vector const &guess, t_Vector const &res); //! problem state (shared with Imaging Forward Backward) std::shared_ptr> problem_state; - t_randomUpdater random_updates; + t_randomUpdater random_updater_; }; /** @@ -263,7 +275,7 @@ class ForwardBackward { */ template void ForwardBackward::iteration_step(t_Vector &image, t_Vector &residual, t_Vector &auxilliary_image, - t_Vector &gradient_current, const t_real FISTA_step) const { + t_Vector &gradient_current, const t_real FISTA_step) { t_Vector prev_image = image; f_gradient(gradient_current, auxilliary_image, residual, Phi()); // assigns gradient_current (non normalised) t_Vector auxilliary_with_step = auxilliary_image - step_size() / sq_op_norm() * gradient_current; // step to new image using gradient @@ -272,16 +284,16 @@ void ForwardBackward::iteration_step(t_Vector &image, t_Vector &residual auxilliary_image = image + FISTA_step * (image - prev_image); // update auxilliary vector with FISTA acceleration step // set up next iteration - if(random_updates) + if(random_updater_) { - problem_state = random_updates(); + problem_state = random_updater_(); } residual = (Phi() * auxilliary_image) - target(); // updates the residual for the NEXT iteration (new image). 
} template typename ForwardBackward::Diagnostic ForwardBackward::operator()( - t_Vector &out, t_Vector const &x_guess, t_Vector const &res_guess) const { + t_Vector &out, t_Vector const &x_guess, t_Vector const &res_guess) { SOPT_HIGH_LOG("Performing Forward Backward Splitting"); if (fista()) { SOPT_HIGH_LOG("Using FISTA algorithm"); diff --git a/cpp/sopt/imaging_forward_backward.h b/cpp/sopt/imaging_forward_backward.h index 30000172..2916d1b7 100644 --- a/cpp/sopt/imaging_forward_backward.h +++ b/cpp/sopt/imaging_forward_backward.h @@ -226,7 +226,7 @@ class ImagingForwardBackward { //! \brief Calls Forward Backward //! \param[out] out: Output vector x - Diagnostic operator()(t_Vector &out) const { + Diagnostic operator()(t_Vector &out) { return operator()(out, ForwardBackward::initial_guess(target(), Phi(), sq_op_norm())); } //! \brief Calls Forward Backward @@ -360,14 +360,16 @@ typename ImagingForwardBackward::Diagnostic ImagingForwardBackwardobjmin_ = std::real(scalvar.previous()); return result; }; - auto const fb = ForwardBackward(f_gradient, g_proximal, target()) + auto fb = ForwardBackward(f_gradient, g_proximal, target()) .itermax(itermax()) .step_size(gradient_step_size) .regulariser_strength(regulariser_strength()) .sq_op_norm(sq_op_norm()) .fista(fista()) .Phi(Phi()) - .is_converged(convergence); + .is_converged(convergence) + .random_updater(random_updater_) + .set_problem_state(problem_state); static_cast::Diagnostic &>(result) = fb(out, std::tie(guess, res)); return result; diff --git a/cpp/sopt/l2_forward_backward.h b/cpp/sopt/l2_forward_backward.h index 9bc666b0..1754b83f 100644 --- a/cpp/sopt/l2_forward_backward.h +++ b/cpp/sopt/l2_forward_backward.h @@ -276,7 +276,7 @@ typename L2ForwardBackward::Diagnostic L2ForwardBackward::operat this->objmin_ = std::real(scalvar.previous()); return result; }; - auto const fb = ForwardBackward(f_gradient, g_proximal, target()) + auto fb = ForwardBackward(f_gradient, g_proximal, target()) .itermax(itermax()) .step_size(step_size()) .regulariser_strength(regulariser_strength()) diff --git a/cpp/tests/forward_backward.cc b/cpp/tests/forward_backward.cc index 26a5cbb4..356108fc 100644 --- a/cpp/tests/forward_backward.cc +++ b/cpp/tests/forward_backward.cc @@ -47,7 +47,7 @@ TEST_CASE("Forward Backward with ||x - x0||_2^2 function", "[fb]") { CAPTURE(target0); CAPTURE(x_guess); CAPTURE(res); - auto const fb = algorithm::ForwardBackward(grad, g0, target0) + auto fb = algorithm::ForwardBackward(grad, g0, target0) .itermax(itermax) .regulariser_strength(regulariser_strength) .step_size(beta) From 43cd53213e8a4b51750df94534482f66d4e96785 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 19 Dec 2024 00:30:22 +0000 Subject: [PATCH 08/10] Add stochastic update test --- cpp/tests/CMakeLists.txt | 1 + cpp/tests/stochastic_update.cc | 103 +++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 cpp/tests/stochastic_update.cc diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index af8c93a9..87168661 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -31,6 +31,7 @@ add_catch_test(credible_region LIBRARIES sopt SEED ${RAND_SEED}) add_catch_test(forward_backward LIBRARIES sopt tools_for_tests SEED ${RAND_SEED}) add_catch_test(gradient_operator LIBRARIES sopt tools_for_tests SEED ${RAND_SEED}) add_catch_test(inpainting LIBRARIES sopt tools_for_tests SEED ${RAND_SEED}) +add_catch_test(stochastic_update LIBRARIES sopt tools_for_tests SEED ${RAND_SEED}) add_catch_test(pd_inpainting 
LIBRARIES sopt tools_for_tests SEED ${RAND_SEED}) add_catch_test(linear_transform LIBRARIES sopt SEED ${RAND_SEED}) add_catch_test(maths LIBRARIES sopt SEED ${RAND_SEED}) diff --git a/cpp/tests/stochastic_update.cc b/cpp/tests/stochastic_update.cc new file mode 100644 index 00000000..9ca209bc --- /dev/null +++ b/cpp/tests/stochastic_update.cc @@ -0,0 +1,103 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sopt/imaging_forward_backward.h" +#include "sopt/l1_non_diff_function.h" +#include "sopt/logging.h" +#include "sopt/maths.h" +#include "sopt/relative_variation.h" +#include "sopt/sampling.h" +#include "sopt/types.h" +#include "sopt/utilities.h" +#include "sopt/wavelets.h" +#include "sopt/gradient_utils.h" + +// This header is not part of the installed sopt interface +// It is only present in tests +#include "tools_for_tests/directories.h" +#include "tools_for_tests/tiffwrappers.h" + +// \min_{x} ||\Psi^Tx||_1 \quad \mbox{s.t.} \quad ||y - Ax||_2 < \epsilon and x \geq 0 + +using Scalar = double; +using Vector = sopt::Vector; +using Matrix = sopt::Matrix; +using Image = sopt::Image; + +TEST_CASE("Inpainting"){ + extern std::unique_ptr mersenne; + std::string const input = "cameraman256"; + + Image const image = sopt::tools::read_standard_tiff(input); + + auto const wavelet = sopt::wavelets::factory("DB8", 4); + + auto const psi = sopt::linear_transform(wavelet, image.rows(), image.cols()); + size_t nmeasure = static_cast(image.size() * 0.5); + + double constexpr snr = 30.0; + std::shared_ptr> Phi = + std::make_shared>( + sopt::linear_transform(sopt::Sampling(image.size(), nmeasure, *mersenne))); + Vector y = (*Phi) * Vector::Map(image.data(), image.size()); + + auto sigma = y.stableNorm() / std::sqrt(y.size()) * std::pow(10.0, -(snr / 20.0)); + sopt::t_real constexpr regulariser_strength = 18; + sopt::t_real const beta = sigma*sigma*0.5; + + // Define a stochostic target/operator updater! 
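+  // Each call builds a fresh random Sampling-based measurement operator and the matching
+  // noisy measurement vector, and returns them bundled in an IterationState, so the solver
+  // works against a new (y, Phi) pair on every iteration.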
+ std::unique_ptr *m = &mersenne; + std::function>()> random_updater = [&image, m, sigma, nmeasure](){ + double constexpr snr = 30.0; + std::shared_ptr> Phi = + std::make_shared>(sopt::linear_transform(sopt::Sampling(image.size(), nmeasure, **m))); + Vector y = (*Phi) * Vector::Map(image.data(), image.size()); + + std::normal_distribution<> gaussian_dist(0, sigma); + for (sopt::t_int i = 0; i < y.size(); i++) y(i) = y(i) + gaussian_dist(*mersenne); + + return std::make_shared>(y, Phi); + }; + + auto fb = sopt::algorithm::ImagingForwardBackward(random_updater); + fb.itermax(1000) + .step_size(beta) // stepsize + .sigma(sigma) // sigma + .regulariser_strength(regulariser_strength) // regularisation paramater + .relative_variation(1e-3) + .residual_tolerance(0) + .tight_frame(true); + + // Create a shared pointer to an instance of the L1GProximal class + // and set its properties + auto gp = std::make_shared>(false); + gp->l1_proximal_tolerance(1e-4) + .l1_proximal_nu(1) + .l1_proximal_itermax(50) + .l1_proximal_positivity_constraint(true) + .l1_proximal_real_constraint(true) + .Psi(psi); + + // Once the properties are set, inject it into the ImagingForwardBackward object + fb.g_function(gp); + + auto const diagnostic = fb(); + + CHECK(diagnostic.good); + CHECK(diagnostic.niters < 500); + + // compare input image to cleaned output image + // calculate mean squared error sum_i ( ( x_true(i) - x_est(i) ) **2 ) + // check this is less than the number of pixels * 0.01 + + Eigen::Map flat_image(image.data(), image.size()); + auto mse = (flat_image - diagnostic.x).array().square().sum() / image.size(); + CAPTURE(mse); + CHECK(mse < 0.01); +} From 8a6fdbcb99c1b4bfecc4a832bf25f6618e1d10f2 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 19 Dec 2024 00:30:46 +0000 Subject: [PATCH 09/10] Fix random update setter --- cpp/sopt/imaging_forward_backward.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/sopt/imaging_forward_backward.h b/cpp/sopt/imaging_forward_backward.h index 2916d1b7..d3f4cbfe 100644 --- a/cpp/sopt/imaging_forward_backward.h +++ b/cpp/sopt/imaging_forward_backward.h @@ -194,8 +194,8 @@ class ImagingForwardBackward { // Getter and setter for the random updater object t_randomUpdater &random_updater() { return random_updater_; } - ImagingForwardBackward& random_updater( t_randomUpdater &f_function) { - random_updater_ = random_updater_; // may change this to a move if we don't need to keep it + ImagingForwardBackward& random_updater( t_randomUpdater &new_updater) { + random_updater_ = new_updater; // may change this to a move if we don't need to keep it return *this; } From cd4abee7abc11436cb10f88c4de6fd709b68efd7 Mon Sep 17 00:00:00 2001 From: Michael McLeod Date: Thu, 19 Dec 2024 00:34:38 +0000 Subject: [PATCH 10/10] If f(x) available use in obj convergence o/w l2 norm --- cpp/sopt/imaging_forward_backward.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/sopt/imaging_forward_backward.h b/cpp/sopt/imaging_forward_backward.h index d3f4cbfe..cb16b364 100644 --- a/cpp/sopt/imaging_forward_backward.h +++ b/cpp/sopt/imaging_forward_backward.h @@ -392,7 +392,8 @@ bool ImagingForwardBackward::objective_convergence(ScalarRelativeVariati if (static_cast(objective_convergence())) return objective_convergence()(x, residual); if (scalvar.relative_tolerance() <= 0e0) return true; auto const current = ((regulariser_strength() > 0) ? 
g_function_->function(x)
-      * regulariser_strength() : 0) + std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma());
+      * regulariser_strength() : 0) +
+      ((f_function_) ? f_function_->function(x, target(), Phi()) : std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma()));
   return scalvar(current);
 }
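
Usage sketch (not part of the patch series): the updater hook added above can drive
deterministic mini-batch cycling as well as random sampling. Only IterationState, the
std::function updater signature and the ImagingForwardBackward(updater) constructor come
from these patches; make_cycling_updater and the Batch bookkeeping below are illustrative.

#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "sopt/gradient_utils.h"
#include "sopt/imaging_forward_backward.h"
#include "sopt/linear_transform.h"
#include "sopt/types.h"

using Vector = sopt::Vector<double>;
using State = sopt::IterationState<Vector>;
using Batch = std::pair<Vector, std::shared_ptr<sopt::LinearTransform<Vector>>>;

// Cycle deterministically through pre-built (y_i, Phi_i) mini-batches; each call hands the
// solver the next batch wrapped in an IterationState.
std::function<std::shared_ptr<State>()> make_cycling_updater(std::vector<Batch> batches) {
  return [batches = std::move(batches), i = std::size_t{0}]() mutable {
    auto const &b = batches[i];
    i = (i + 1) % batches.size();  // assumes at least one batch
    return std::make_shared<State>(b.first, b.second);
  };
}

// auto updater = make_cycling_updater(batches);  // batches assembled by the caller
// auto solver = sopt::algorithm::ImagingForwardBackward<double>(updater);
// solver is then configured (g_function, step_size, ...) exactly as in cpp/tests/stochastic_update.cc.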