Use global step instead of fractional epochs for the cosine LR schedule
bghira committed Nov 2, 2023
1 parent 7d8e906 commit 0e6ccfb
Showing 2 changed files with 64 additions and 58 deletions.
96 changes: 44 additions & 52 deletions helpers/training/custom_schedule.py
@@ -156,66 +156,56 @@ class CosineAnnealingWarmRestarts(LRScheduler):
     """

     def __init__(
-        self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False
+        self,
+        optimizer,
+        T_0,
+        steps_per_epoch,
+        T_mult=1,
+        eta_min=0,
+        last_step=-1,
+        verbose=False,
     ):
         if T_0 <= 0 or not isinstance(T_0, int):
-            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
+            raise ValueError(f"Expected positive integer T_0, but got {T_0}")
         if T_mult < 1 or not isinstance(T_mult, int):
-            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
+            raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
         self.T_0 = T_0
+        self.steps_per_epoch = steps_per_epoch
         self.T_i = T_0
         self.T_mult = T_mult
         self.eta_min = eta_min
-        self.T_cur = last_epoch
-        super().__init__(optimizer, last_epoch, verbose)
+        self.T_cur = last_step
+        super().__init__(optimizer, last_step, verbose)

     def get_lr(self):
         if not self._get_lr_called_within_step:
             warnings.warn(
                 "To get the last learning rate computed by the scheduler, "
                 "please use `get_last_lr()`.",
                 UserWarning,
             )

         lrs = [
             self.eta_min
             + (base_lr - self.eta_min)
             * (1 + math.cos(math.pi * self.T_cur / self.T_i))
             / 2
             for base_lr in self.base_lrs
         ]

         # Debugging print statements
         if self.verbose:
             print(f"T_cur: {self.T_cur}, T_i: {self.T_i}")
             print(f"Learning rates: {lrs}")

         return lrs

-    def step(self, epoch=None):
-        if epoch is None and self.last_epoch < 0:
-            epoch = 0
+    def step(self, step=None):
+        if step is None and self.last_epoch < 0:
+            step = 0

-        if epoch is None:
-            epoch = self.last_epoch + 1
-            self.T_cur = self.T_cur + 1
-            if self.T_cur >= self.T_i:
-                self.T_cur = self.T_cur - self.T_i
-                self.T_i = self.T_i * self.T_mult
+        if step is None:
+            step = self.last_epoch + 1
+            self.T_cur = (step // self.steps_per_epoch) + (
+                step % self.steps_per_epoch
+            ) / self.steps_per_epoch
         else:
-            if epoch < 0:
-                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
-            if epoch >= self.T_0:
-                if self.T_mult == 1:
-                    self.T_cur = epoch % self.T_0
-                else:
-                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
-                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
-                    self.T_i = self.T_0 * self.T_mult ** (n)
-            else:
-                self.T_i = self.T_0
-                self.T_cur = epoch
-        self.last_epoch = epoch
+            self.T_cur = (step // self.steps_per_epoch) + (
+                step % self.steps_per_epoch
+            ) / self.steps_per_epoch
+
+        if self.T_cur >= self.T_i:
+            self.T_cur = self.T_cur - self.T_i
+            self.T_i = self.T_i * self.T_mult
+
+        self.last_epoch = step

     class _enable_get_lr_call:
         def __init__(self, o):
@@ -232,20 +222,22 @@ def __exit__(self, type, value, traceback):
         with _enable_get_lr_call(self):
             for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                 param_group, lr = data
-                param_group['lr'] = math.floor(lr * 1e9) / 1e9
-                self.print_lr(self.verbose, i, lr, epoch)
+                param_group["lr"] = math.floor(lr * 1e9) / 1e9
+                self.print_lr(self.verbose, i, lr, step)

-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

     def print_lr(self, is_verbose, group, lr, epoch=None):
-        """Display the current learning rate.
-        """
+        """Display the current learning rate."""
         if is_verbose:
             if epoch is None:
-                print('Adjusting learning rate'
-                      ' of group {} to {:.8e}.'.format(group, lr))
+                print(
+                    "Adjusting learning rate"
+                    " of group {} to {:.8e}.".format(group, lr)
+                )
             else:
-                epoch_str = ("%.2f" if isinstance(epoch, float) else
-                             "%.5d") % epoch
-                print('Epoch {}: adjusting learning rate'
-                      ' of group {} to {:.8e}.'.format(epoch_str, group, lr))
+                epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
+                print(
+                    "Epoch {}: adjusting learning rate"
+                    " of group {} to {:.8e}.".format(epoch_str, group, lr)
+                )
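
For orientation, here is a small standalone sketch (not part of the commit) of the arithmetic the patched step() now performs: the run-wide global step is converted into whole epochs plus the elapsed fraction of the current epoch before the cosine is evaluated. The steps_per_epoch and global_step values below are made up for illustration.

# Illustrative only: mirrors the T_cur arithmetic from the patched step().
steps_per_epoch = 1000   # assumed example value
global_step = 2500       # assumed example value

# Whole epochs completed plus the fraction of the current epoch elapsed.
t_cur = (global_step // steps_per_epoch) + (global_step % steps_per_epoch) / steps_per_epoch
print(t_cur)  # 2.5 -> the cosine sees the same position regardless of where the dataloader is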
26 changes: 20 additions & 6 deletions train_sdxl.py
@@ -1020,7 +1020,7 @@ def print_bucket_info(bucket_manager):
         current_epoch_step = 0
         for step, batch in enumerate(train_dataloader):
             if args.lr_scheduler == "cosine_annealing_warm_restarts":
-                scheduler_kwargs["epoch"] = epoch + step / len(bucket_manager)
+                scheduler_kwargs["step"] = global_step
             if accelerator.is_main_process:
                 progress_bar.set_description(
                     f"Epoch {current_epoch}/{args.num_train_epochs}, Steps"
@@ -1050,15 +1050,29 @@ def print_bucket_info(bucket_manager):
                 # Sample noise that we'll add to the latents - args.noise_offset might need to be set to 0.1 by default.
                 noise = torch.randn_like(latents)
                 if args.offset_noise:
-                    if args.noise_offset_probability == 1.0 or random.random() < args.noise_offset_probability:
-                        noise = torch.randn_like(latents) + args.noise_offset * torch.randn(
-                            latents.shape[0], latents.shape[1], 1, 1, device=latents.device
+                    if (
+                        args.noise_offset_probability == 1.0
+                        or random.random() < args.noise_offset_probability
+                    ):
+                        noise = torch.randn_like(
+                            latents
+                        ) + args.noise_offset * torch.randn(
+                            latents.shape[0],
+                            latents.shape[1],
+                            1,
+                            1,
+                            device=latents.device,
                         )
                     else:
                         noise = torch.randn_like(latents)
                 if args.input_perturbation:
-                    if args.input_perturbation_probability == 1.0 or random.random() < args.input_perturbation_probability:
-                        noise = noise + args.input_perturbation * torch.randn_like(noise)
+                    if (
+                        args.input_perturbation_probability == 1.0
+                        or random.random() < args.input_perturbation_probability
+                    ):
+                        noise = noise + args.input_perturbation * torch.randn_like(
+                            noise
+                        )

                 bsz = latents.shape[0]
                 training_logger.debug(f"Working on batch size: {bsz}")
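
A minimal usage sketch of the scheduler after this change, not taken verbatim from the repository: the training loop hands the scheduler a monotonically increasing global step rather than a fractional epoch. The toy model, optimizer, and hyperparameter values below are assumptions for illustration.

import torch
from helpers.training.custom_schedule import CosineAnnealingWarmRestarts

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
steps_per_epoch = 4  # assumed tiny epoch for the example

scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=2,                            # restart period, in epochs (example value)
    steps_per_epoch=steps_per_epoch,
    eta_min=1e-7,
)

global_step = 0
for epoch in range(3):
    for step in range(steps_per_epoch):
        optimizer.zero_grad()
        model(torch.randn(1, 4)).sum().backward()
        optimizer.step()
        # The scheduler receives the run-wide counter, matching
        # scheduler_kwargs["step"] = global_step in train_sdxl.py above.
        scheduler.step(step=global_step)
        global_step += 1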
