Merge pull request #1294 from bghira/feature/prodigy-schedulefree-bf16
add prodigy optimiser with full bf16 support
bghira authored Jan 22, 2025
2 parents 5aa4b6f + 1c03842 commit b73d4eb
Showing 9 changed files with 149 additions and 20 deletions.
11 changes: 11 additions & 0 deletions helpers/configuration/cmd_args.py
@@ -1336,6 +1336,17 @@ def get_argument_parser():
default=1e-08,
help="Epsilon value for the Adam optimizer",
)
parser.add_argument(
"--prodigy_steps",
type=int,
default=None,
help=(
"When training with Prodigy, this defines how many steps it should be adjusting its learning rate for."
" It seems to be that Diffusion models benefit from a capping off of the adjustments after 25 percent"
" of the training run (dependent on batch size, repeats, and epochs)."
" It this value is not supplied, it will be calculated at 25 percent of your training steps."
),
)
parser.add_argument(
"--max_grad_norm",
default=2.0,
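The new --prodigy_steps flag falls back to 25 percent of --max_train_steps when it is not supplied. A quick worked sketch of that fallback, using a hypothetical args namespace as a stand-in for the parsed CLI arguments (the real resolution happens in optimizer_parameters in helpers/training/optimizer_param.py below):

from argparse import Namespace

# Hypothetical parsed arguments: --prodigy_steps left unset, 10,000 total steps.
args = Namespace(optimizer="prodigy", prodigy_steps=None, max_train_steps=10_000)

prodigy_steps = (
    args.prodigy_steps
    if args.prodigy_steps and args.prodigy_steps > 0
    else int(args.max_train_steps * 0.25)  # default: 25 percent of the run
)
print(prodigy_steps)  # 2500
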
59 changes: 58 additions & 1 deletion helpers/training/optimizer_param.py
@@ -61,6 +61,18 @@
if "AdEMAMix" in dir(bitsandbytes.optim):
is_ademamix_available = True

is_prodigy_available = False
try:
import prodigyplus

is_prodigy_available = True
except ImportError:
if torch.cuda.is_available():
logger.warning(
"Could not load prodigyplus library. Prodigy will not be available."
)


optimizer_choices = {
"adamw_bf16": {
"precision": "bf16",
@@ -456,6 +468,42 @@
}
)

if is_prodigy_available:
optimizer_choices.update(
{
"prodigy": {
"precision": "any",
"override_lr_scheduler": True,
"can_warmup": False,
"default_settings": {
"lr": 1.0,
"betas": (0.9, 0.99),
"beta3": None,
"weight_decay": 0.0,
"weight_decay_by_lr": True,
"use_bias_correction": False,
"d0": 1e-6,
"d_coef": 1,
"prodigy_steps": 0,
"use_speed": False,
"eps": 1e-8,
"split_groups": True,
"split_groups_mean": True,
"factored": True,
"factored_fp32": True,
"fused_back_pass": False,
"use_stableadamw": True,
"use_muon_pp": False,
"use_cautious": False,
"use_grams": False,
"use_adopt": False,
"stochastic_rounding": True,
},
"class": prodigyplus.prodigy_plus_schedulefree.ProdigyPlusScheduleFree,
}
}
)

args_to_optimizer_mapping = {
"use_adafactor_optimizer": "adafactor",
"use_prodigy_optimizer": "prodigy",
@@ -465,7 +513,6 @@
}

deprecated_optimizers = {
"prodigy": "Prodigy optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.",
"dadaptation": "D-adaptation optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.",
"adafactor": "Adafactor optimiser has been removed in favour of optimi-stableadamw, which offers improved memory efficiency and convergence.",
"adamw8bit": "AdamW8Bit has been removed in favour of optimi-adamw optimiser, which offers better low-precision support. Please use this or adamw_bf16 instead.",
@@ -512,6 +559,16 @@ def optimizer_parameters(optimizer, args):
if args.optimizer_release_gradients and "optimi-" in optimizer:
optimizer_params["gradient_release"] = True
optimizer_details["default_settings"] = optimizer_params
if args.optimizer == "prodigy":
prodigy_steps = args.prodigy_steps
if prodigy_steps and prodigy_steps > 0:
optimizer_params["prodigy_steps"] = prodigy_steps
else:
# 25% of the total number of steps
optimizer_params["prodigy_steps"] = int(args.max_train_steps * 0.25)
logger.info(
f"Using Prodigy optimiser with {optimizer_params['prodigy_steps']} steps of learning rate adjustment."
)
return optimizer_class, optimizer_details
else:
raise ValueError(f"Optimizer {optimizer} not found.")
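The "prodigy" entry above registers the optimiser with override_lr_scheduler set and can_warmup disabled, so no external LR schedule is applied and lr stays at 1.0 while Prodigy adapts its internal step size d. A minimal instantiation sketch, assuming the default_settings keys map one-to-one onto ProdigyPlusScheduleFree's constructor arguments and using a toy module as a stand-in for the real model (verify against the installed prodigy-plus-schedule-free release):

import torch
import prodigyplus

model = torch.nn.Linear(8, 8)  # toy stand-in for the trained model

optimizer = prodigyplus.prodigy_plus_schedulefree.ProdigyPlusScheduleFree(
    model.parameters(),
    lr=1.0,                    # Prodigy scales this by its adapted d, so it stays at 1.0
    betas=(0.9, 0.99),
    weight_decay=0.0,
    d0=1e-6,                   # initial step-size estimate
    prodigy_steps=2500,        # stop adapting d after roughly 25 percent of the run
    use_stableadamw=True,
    stochastic_rounding=True,  # helps keep pure-bf16 weight updates stable
)

Because the entry overrides the LR scheduler, the trainer simply steps this optimiser without any warmup or decay schedule of its own.
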
44 changes: 30 additions & 14 deletions helpers/training/trainer.py
@@ -1232,9 +1232,7 @@ def init_optimizer(self):
def init_lr_scheduler(self):
self.config.is_schedulefree = is_lr_scheduler_disabled(self.config.optimizer)
if self.config.is_schedulefree:
logger.info(
"Using experimental AdamW ScheduleFree optimiser from Facebook. Experimental due to newly added Kahan summation."
)
logger.info("Using experimental ScheduleFree optimiser..")
# we don't use LR schedulers with schedulefree optimisers
lr_scheduler = None
if not self.config.use_deepspeed_scheduler and not self.config.is_schedulefree:
@@ -2778,12 +2776,14 @@ def train(self):
if param.grad is not None:
param.grad.data = param.grad.data.to(torch.float32)

self.grad_norm = self._max_grad_value()
if (
self.accelerator.sync_gradients
and self.config.optimizer != "optimi-stableadamw"
and self.config.optimizer
not in ["optimi-stableadamw", "prodigy"]
and self.config.max_grad_norm > 0
):
# StableAdamW does not need clipping, similar to Adafactor.
# StableAdamW/Prodigy do not need clipping, similar to Adafactor.
if self.config.grad_clip_method == "norm":
self.grad_norm = self.accelerator.clip_grad_norm_(
self._get_trainable_parameters(),
@@ -2793,7 +2793,6 @@
# deepspeed can only do norm clipping (internally)
pass
elif self.config.grad_clip_method == "value":
self.grad_norm = self._max_grad_value()
self.accelerator.clip_grad_value_(
self._get_trainable_parameters(),
self.config.max_grad_norm,
@@ -2824,7 +2823,22 @@
wandb_logs = {}
if self.accelerator.sync_gradients:
try:
if self.config.is_schedulefree:
if "prodigy" in self.config.optimizer:
self.lr = self.optimizer.param_groups[0]["d"]
wandb_logs.update(
{
"prodigy/d": self.optimizer.param_groups[0]["d"],
"prodigy/d_prev": self.optimizer.param_groups[0][
"d_prev"
],
"prodigy/d0": self.optimizer.param_groups[0]["d0"],
"prodigy/d_coef": self.optimizer.param_groups[0][
"d_coef"
],
"prodigy/k": self.optimizer.param_groups[0]["k"],
}
)
elif self.config.is_schedulefree:
# hackjob method of retrieving LR from accelerated optims
self.lr = StateTracker.get_last_lr()
else:
@@ -2834,12 +2848,14 @@
logger.error(
f"Failed to get the last learning rate from the scheduler. Error: {e}"
)
wandb_logs = {
"train_loss": self.train_loss,
"optimization_loss": loss,
"learning_rate": self.lr,
"epoch": epoch,
}
wandb_logs.update(
{
"train_loss": self.train_loss,
"optimization_loss": loss,
"learning_rate": self.lr,
"epoch": epoch,
}
)
if parent_loss is not None:
wandb_logs["regularisation_loss"] = parent_loss
if self.config.model_family == "flux" and self.guidance_values_list:
@@ -2850,7 +2866,7 @@
if self.grad_norm is not None:
if self.config.grad_clip_method == "norm":
wandb_logs["grad_norm"] = self.grad_norm
elif self.config.grad_clip_method == "value":
else:
wandb_logs["grad_absmax"] = self.grad_norm
if self.validation is not None and hasattr(
self.validation, "evaluation_result"
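Because Prodigy supplies its own step size, the logging block above reports the adapted d as the learning rate and mirrors a few internals to wandb. A small readout sketch with a hypothetical helper, assuming the param-group keys shown in the diff ("d", "d_prev", "d0", "d_coef", "k") are present on prodigyplus param groups:

def prodigy_wandb_logs(optimizer):
    # `optimizer` is assumed to be a ProdigyPlusScheduleFree instance, e.g. the
    # one from the earlier sketch; key names match those read in trainer.py.
    group = optimizer.param_groups[0]
    lr = group["d"]  # reported as the effective learning rate
    return lr, {
        "prodigy/d": group["d"],            # current adapted step size
        "prodigy/d_prev": group["d_prev"],  # previous estimate
        "prodigy/d0": group["d0"],          # initial estimate (default 1e-6)
        "prodigy/d_coef": group["d_coef"],  # user-supplied scaling coefficient
        "prodigy/k": group["k"],            # adaptation step counter
    }
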
18 changes: 16 additions & 2 deletions install/apple/poetry.lock


1 change: 1 addition & 0 deletions install/apple/pyproject.toml
@@ -47,6 +47,7 @@ torchao = "^0.7.0"
torchaudio = "^2.5.0"
atomicwrites = "^1.4.1"
beautifulsoup4 = "^4.12.3"
prodigy-plus-schedule-free = "^1.8.51"


[build-system]
18 changes: 16 additions & 2 deletions install/rocm/poetry.lock


1 change: 1 addition & 0 deletions install/rocm/pyproject.toml
@@ -47,6 +47,7 @@ bitsandbytes = "^0.44.1"
atomicwrites = "^1.4.1"
torchao = "^0.7.0"
beautifulsoup4 = "^4.12.3"
prodigy-plus-schedule-free = "^1.8.51"


[build-system]
16 changes: 15 additions & 1 deletion poetry.lock


1 change: 1 addition & 0 deletions pyproject.toml
@@ -52,6 +52,7 @@ nvidia-cudnn-cu12 = "*"
nvidia-nccl-cu12 = "*"
atomicwrites = "^1.4.1"
beautifulsoup4 = "^4.12.3"
prodigy-plus-schedule-free = "^1.8.51"



