diff --git a/mrmustard/training/optimizer.py b/mrmustard/training/optimizer.py
index c5a9bffb9..dec84e797 100644
--- a/mrmustard/training/optimizer.py
+++ b/mrmustard/training/optimizer.py
@@ -18,7 +18,7 @@
 from itertools import chain, groupby
 from typing import List, Callable, Sequence, Union, Mapping, Dict
 
-from mrmustard import math
+from mrmustard import math, settings
 from mrmustard.math.parameters import Constant, Variable
 from mrmustard.training.callbacks import Callback
 from mrmustard.training.progress_bar import ProgressBar
@@ -98,35 +98,58 @@ def minimize(
     def _minimize(self, cost_fn, by_optimizing, max_steps, callbacks):
         # finding out which parameters are trainable from the ops
         trainable_params = self._get_trainable_params(by_optimizing)
+        if settings.PROGRESSBAR:
+            bar = ProgressBar(max_steps)
+            with bar:
+                self._optimization_loop(cost_fn, trainable_params, max_steps, callbacks, bar)
+        else:
+            self._optimization_loop(cost_fn, trainable_params, max_steps, callbacks)
+
+    def _optimization_loop(
+        self, cost_fn, trainable_params, max_steps, callbacks, progress_bar=None
+    ):
+        """Internal method that performs the main optimization loop.
+
+        Args:
+            cost_fn (Callable): The cost function to minimize
+            trainable_params (dict): Dictionary of trainable parameters
+            max_steps (int): Maximum number of optimization steps
+            callbacks (dict): Dictionary of callback functions to execute during optimization
+            progress_bar (ProgressBar, optional): Progress bar instance for displaying optimization progress.
+                If None, no progress will be displayed. Defaults to None.
+
+        Note:
+            This method maintains internal state in self.opt_history and self.callback_history,
+            tracking the optimization progress and callback results respectively.
+        """
         cost_fn_modified = False
         orig_cost_fn = cost_fn
 
-        bar = ProgressBar(max_steps)
-        with bar:
-            while not self.should_stop(max_steps):
-                cost, grads = self.compute_loss_and_gradients(cost_fn, trainable_params.values())
+        while not self.should_stop(max_steps):
+            cost, grads = self.compute_loss_and_gradients(cost_fn, trainable_params.values())
 
-                trainables = {tag: (x, dx) for (tag, x), dx in zip(trainable_params.items(), grads)}
+            trainables = {tag: (x, dx) for (tag, x), dx in zip(trainable_params.items(), grads)}
 
-                if cost_fn_modified:
-                    self.callback_history["orig_cost"].append(orig_cost_fn())
+            if cost_fn_modified:
+                self.callback_history["orig_cost"].append(orig_cost_fn())
 
-                new_cost_fn, new_grads = self._run_callbacks(
-                    callbacks=callbacks,
-                    cost_fn=cost_fn,
-                    cost=cost,
-                    trainables=trainables,
-                )
+            new_cost_fn, new_grads = self._run_callbacks(
+                callbacks=callbacks,
+                cost_fn=cost_fn,
+                cost=cost,
+                trainables=trainables,
+            )
 
-                self.apply_gradients(trainable_params.values(), new_grads or grads)
-                self.opt_history.append(cost)
-                bar.step(math.asnumpy(cost))
+            self.apply_gradients(trainable_params.values(), new_grads or grads)
+            self.opt_history.append(cost)
+            if progress_bar is not None:
+                progress_bar.step(math.asnumpy(cost))
 
-                if callable(new_cost_fn):
-                    cost_fn = new_cost_fn
-                    if not cost_fn_modified:
-                        cost_fn_modified = True
-                        self.callback_history["orig_cost"] = self.opt_history.copy()
+            if callable(new_cost_fn):
+                cost_fn = new_cost_fn
+                if not cost_fn_modified:
+                    cost_fn_modified = True
+                    self.callback_history["orig_cost"] = self.opt_history.copy()
 
     def apply_gradients(self, trainable_params, grads):
         """Apply gradients to variables.
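
For reference, a minimal usage sketch (not part of the patch) of how the new settings.PROGRESSBAR switch would drive the refactored code path: with the flag off, _minimize skips ProgressBar entirely and calls _optimization_loop with progress_bar=None. It assumes the public mrmustard.lab Dgate and mrmustard.training.Optimizer APIs and that gate parameters expose a .value attribute via mrmustard.math.parameters; the target value, learning rate, and step count are arbitrary illustration choices.

    from mrmustard import settings
    from mrmustard.lab import Dgate
    from mrmustard.training import Optimizer

    settings.PROGRESSBAR = False  # with this patch, no ProgressBar is created at all

    gate = Dgate(x=0.1, x_trainable=True)

    def cost_fn():
        # toy objective: drive the displacement toward x = 0.5
        # (assumes gate.x.value returns the underlying trainable tensor)
        return (gate.x.value - 0.5) ** 2

    opt = Optimizer(euclidean_lr=0.05)
    opt.minimize(cost_fn, by_optimizing=[gate], max_steps=50)
    print(opt.opt_history[-1])  # final cost, recorded by the optimization loop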