From 97e666025e2b754dcbc1878e0676ce6fe456bd9c Mon Sep 17 00:00:00 2001
From: Robin Lobel
Date: Sat, 6 Apr 2024 21:20:28 +0200
Subject: [PATCH] 0.9.18

---
 torchstudio/analyzers/multiclass.py         |   2 +-
 torchstudio/analyzers/multilabel.py         |   2 +-
 torchstudio/analyzers/valuesdistribution.py |   2 +-
 torchstudio/metricsplot.py                  |   2 +-
 torchstudio/modelbuild.py                   |   4 +-
 torchstudio/modeltrain.py                   |   6 +-
 torchstudio/optim/adamw_schedulefree.py     | 162 ++++++++++++++++++++
 torchstudio/parametersplot.py               |   2 +-
 torchstudio/pythonparse.py                  |   2 +
 torchstudio/renderers/bitmap.py             |   2 +-
 torchstudio/renderers/boundingbox.py        |   2 +-
 torchstudio/renderers/labels.py             |   2 +-
 torchstudio/renderers/signal.py             |   2 +-
 torchstudio/renderers/spectrogram.py        |   2 +-
 torchstudio/renderers/volume.py             |   2 +-
 torchstudio/schedulers/multistep.py         |  15 +-
 torchstudio/schedulers/noschedule.py        |  12 ++
 torchstudio/schedulers/onecycle.py          |  21 ++-
 torchstudio/schedulers/step.py              |  12 +-
 torchstudio/tcpcodec.py                     |   2 +-
 20 files changed, 232 insertions(+), 26 deletions(-)
 create mode 100644 torchstudio/optim/adamw_schedulefree.py
 create mode 100644 torchstudio/schedulers/noschedule.py

diff --git a/torchstudio/analyzers/multiclass.py b/torchstudio/analyzers/multiclass.py
index 6a3e9b4..6f070c9 100644
--- a/torchstudio/analyzers/multiclass.py
+++ b/torchstudio/analyzers/multiclass.py
@@ -100,7 +100,7 @@ def generate_report(self, size, dpi):
 
         canvas = plt.get_current_fig_manager().canvas
         canvas.draw()
-        img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+        img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
 
         plt.close()
         return img
diff --git a/torchstudio/analyzers/multilabel.py b/torchstudio/analyzers/multilabel.py
index f39fe71..b0d47c4 100644
--- a/torchstudio/analyzers/multilabel.py
+++ b/torchstudio/analyzers/multilabel.py
@@ -87,7 +87,7 @@ def generate_report(self, size, dpi):
 
         canvas = plt.get_current_fig_manager().canvas
         canvas.draw()
-        img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+        img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
 
         plt.close()
         return img
diff --git a/torchstudio/analyzers/valuesdistribution.py b/torchstudio/analyzers/valuesdistribution.py
index e9c7e5e..9bad976 100644
--- a/torchstudio/analyzers/valuesdistribution.py
+++ b/torchstudio/analyzers/valuesdistribution.py
@@ -105,7 +105,7 @@ def generate_report(self, size, dpi):
 
         canvas = plt.get_current_fig_manager().canvas
         canvas.draw()
-        img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+        img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
 
         plt.close()
         return img
diff --git a/torchstudio/metricsplot.py b/torchstudio/metricsplot.py
index 7f4eb01..f403ecf 100644
--- a/torchstudio/metricsplot.py
+++ b/torchstudio/metricsplot.py
@@ -130,7 +130,7 @@ def inverse(x):
 
     canvas = plt.get_current_fig_manager().canvas
     canvas.draw()
-    img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+    img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
 
     plt.close()
     return img
diff --git a/torchstudio/modelbuild.py b/torchstudio/modelbuild.py
index 2dc7950..20e6e7a 100644
--- a/torchstudio/modelbuild.py
+++ b/torchstudio/modelbuild.py
@@ -238,8 +238,8 @@ def level_trace(root):
         for tensor in output_tensors:
             metric.append("Accuracy")
 
-        tc.send_msg(app_socket, 'SetHyperParametersValues', tc.encode_ints([64,0,100,20]))
-        tc.send_msg(app_socket, 'SetHyperParametersNames', tc.encode_strings(loss+metric+['Adam','Step']))
+        tc.send_msg(app_socket, 'SetHyperParametersValues', tc.encode_ints([64,0,100,30]))
+        tc.send_msg(app_socket, 'SetHyperParametersNames', tc.encode_strings(loss+metric+['AdamWScheduleFree','NoSchedule']))
 
     if msg_type == 'Exit':
         break
diff --git a/torchstudio/modeltrain.py b/torchstudio/modeltrain.py
index 9697e4a..3df19a1 100644
--- a/torchstudio/modeltrain.py
+++ b/torchstudio/modeltrain.py
@@ -16,7 +16,7 @@ from tqdm.auto import tqdm
 from collections.abc import Iterable
 import threading
-
+import math
 
 class CachedDataset(Dataset):
     def __init__(self, train=True, hash=None):
@@ -294,6 +294,8 @@ def deepcopy_cpu(value):
 
             #training
             model.train()
+            if hasattr(optimizer, 'train'):
+                optimizer.train()
             train_loss = 0
             train_metrics = []
             for metric in metrics:
@@ -334,6 +336,8 @@ def deepcopy_cpu(value):
 
             #validation
             model.eval()
+            if hasattr(optimizer, 'eval'):
+                optimizer.eval()
             valid_loss = 0
             valid_metrics = []
             for metric in metrics:
diff --git a/torchstudio/optim/adamw_schedulefree.py b/torchstudio/optim/adamw_schedulefree.py
new file mode 100644
index 0000000..4fd5361
--- /dev/null
+++ b/torchstudio/optim/adamw_schedulefree.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+import torch.optim
+import math
+
+class AdamWScheduleFree(torch.optim.Optimizer):
+    r"""
+    Schedule-Free AdamW
+    As the name suggests, no scheduler is needed with this optimizer.
+    To add warmup, rather than using a learning rate schedule you can just
+    set the warmup_steps parameter.
+
+    This optimizer requires that .train() and .eval() be called before the
+    beginning of training and evaluation respectively.
+
+    Arguments:
+        params (iterable):
+            Iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float):
+            Learning rate parameter (default 0.0025)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999)).
+        eps (float):
+            Term added to the denominator outside of the root operation to
+            improve numerical stability. (default: 1e-8).
+        weight_decay (float):
+            Weight decay, i.e. an L2 penalty (default: 0).
+        warmup_steps (int): Enables a linear learning rate warmup (default 0).
+        r (float): Use polynomial weighting in the average
+            with power r (default 0).
+        weight_lr_power (float): During warmup, the weights in the average will
+            be equal to lr raised to this power. Set to 0 for no weighting
+            (default 2.0).
+ """ + def __init__(self, + params, + lr=0.0025, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + warmup_steps=0, + r=0.0, + weight_lr_power=2.0, + ): + + defaults = dict(lr=lr, + betas=betas, + eps=eps, + r=r, + k=0, + warmup_steps=warmup_steps, + train_mode = True, + weight_sum=0.0, + lr_max=-1.0, + weight_lr_power=weight_lr_power, + weight_decay=weight_decay) + super().__init__(params, defaults) + + def eval(self): + for group in self.param_groups: + train_mode = group['train_mode'] + beta1, _ = group['betas'] + if train_mode: + for p in group['params']: + state = self.state[p] + if 'z' in state: + # Set p.data to x + p.data.lerp_(end=state['z'], weight=1-1/beta1) + group['train_mode'] = False + + def train(self): + for group in self.param_groups: + train_mode = group['train_mode'] + beta1, _ = group['betas'] + if not train_mode: + for p in group['params']: + state = self.state[p] + if 'z' in state: + # Set p.data to y + p.data.lerp_(end=state['z'], weight=1-beta1) + group['train_mode'] = True + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + eps = group['eps'] + beta1, beta2 = group['betas'] + decay = group['weight_decay'] + k = group['k'] + r = group['r'] + warmup_steps = group['warmup_steps'] + weight_lr_power = group['weight_lr_power'] + + if k < warmup_steps: + sched = (k+1) / warmup_steps + else: + sched = 1.0 + + bias_correction2 = 1 - beta2 ** (k+1) + lr = group['lr']*sched*math.sqrt(bias_correction2) + + lr_max = group['lr_max'] = max(lr, group['lr_max']) + + weight = ((k+1)**r) * (lr_max**weight_lr_power) + weight_sum = group['weight_sum'] = group['weight_sum'] + weight + + ckp1 = weight/weight_sum + + if not group['train_mode']: + raise Exception("Not in train mode!") + + for p in group['params']: + if p.grad is None: + continue + + y = p.data # Notation to match theory + grad = p.grad.data + + state = self.state[p] + + if 'z' not in state: + state['z'] = torch.clone(y) + state['exp_avg_sq'] = torch.zeros_like(p.data) + + z = state['z'] + exp_avg_sq = state['exp_avg_sq'] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) + denom = exp_avg_sq.sqrt().add_(eps) + + # Reuse grad buffer for memory efficiency + grad_normalized = grad.div_(denom) + + # Weight decay calculated at y + if decay != 0: + grad_normalized.add_(y, alpha=decay) + + # These operations update y in-place, + # without computing x explicitly. 
+                y.lerp_(end=z, weight=ckp1)
+                y.add_(grad_normalized, alpha=lr*(beta1*(1-ckp1)-1))
+
+                # z step
+                z.sub_(grad_normalized, alpha=lr)
+
+            group['k'] = k+1
+        return loss
diff --git a/torchstudio/parametersplot.py b/torchstudio/parametersplot.py
index be27d22..80b154c 100644
--- a/torchstudio/parametersplot.py
+++ b/torchstudio/parametersplot.py
@@ -134,7 +134,7 @@ def plot_parameters(size, dpi,
 
     canvas = plt.get_current_fig_manager().canvas
    canvas.draw()
-    img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+    img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
 
    plt.close()
    return img
diff --git a/torchstudio/pythonparse.py b/torchstudio/pythonparse.py
index 00b175d..1ff31ce 100644
--- a/torchstudio/pythonparse.py
+++ b/torchstudio/pythonparse.py
@@ -157,6 +157,7 @@ def filter_parent_objects(objects:List[Dict]) -> List[Dict]:
 
 generated_class="""\
 import typing
+from typing import Any, Callable, List, Tuple, Union, Sequence, Optional
 import pathlib
 import torch
 import torch.nn as nn
@@ -172,6 +173,7 @@ def __init__({4}):
 
 generated_function="""\
 import typing
+from typing import Any, Callable, List, Tuple, Union, Sequence, Optional
 import pathlib
 import torch
 import torch.nn as nn
diff --git a/torchstudio/renderers/bitmap.py b/torchstudio/renderers/bitmap.py
index abd5fe3..f820350 100644
--- a/torchstudio/renderers/bitmap.py
+++ b/torchstudio/renderers/bitmap.py
@@ -118,6 +118,6 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
         canvas = plt.get_current_fig_manager().canvas
         canvas.draw()
-        img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+        img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
 
         plt.close()
         return img
diff --git a/torchstudio/renderers/boundingbox.py b/torchstudio/renderers/boundingbox.py
index 38d4126..bcd9ce5 100644
--- a/torchstudio/renderers/boundingbox.py
+++ b/torchstudio/renderers/boundingbox.py
@@ -122,7 +122,7 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
 
         canvas = plt.get_current_fig_manager().canvas
         canvas.draw()
-        img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+        img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
 
         plt.close()
         return img
diff --git a/torchstudio/renderers/labels.py b/torchstudio/renderers/labels.py
index 1e2d339..9f48671 100644
--- a/torchstudio/renderers/labels.py
+++ b/torchstudio/renderers/labels.py
@@ -129,7 +129,7 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
 
         canvas = plt.get_current_fig_manager().canvas
         canvas.draw()
-        img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+        img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
 
         plt.close()
         return img
diff --git a/torchstudio/renderers/signal.py b/torchstudio/renderers/signal.py
index 843d92e..18ce33c 100644
--- a/torchstudio/renderers/signal.py
+++ b/torchstudio/renderers/signal.py
@@ -92,7 +92,7 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
 
         canvas = plt.get_current_fig_manager().canvas
         canvas.draw()
-        img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+        img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
 
         plt.close()
         return img
diff --git a/torchstudio/renderers/spectrogram.py b/torchstudio/renderers/spectrogram.py
index fd842f5..cd5e6d5 100644
--- a/torchstudio/renderers/spectrogram.py
+++ b/torchstudio/renderers/spectrogram.py
@@ -124,7 +124,7 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
 
         canvas = plt.get_current_fig_manager().canvas
         canvas.draw()
-        img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+        img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
 
         plt.close()
         return img
diff --git a/torchstudio/renderers/volume.py b/torchstudio/renderers/volume.py
index eba7bc7..4f1ae5f 100644
--- a/torchstudio/renderers/volume.py
+++ b/torchstudio/renderers/volume.py
@@ -122,7 +122,7 @@ def render(self, title, tensor, size, dpi, shift=(0,0,0,0), scale=(1,1,1,1), inp
 
         canvas = plt.get_current_fig_manager().canvas
         canvas.draw()
-        img = PIL.Image.frombytes('RGB',canvas.get_width_height(),canvas.tostring_rgb())
+        img = PIL.Image.frombytes('RGBA',canvas.get_width_height(),canvas.buffer_rgba())
 
         plt.close()
         return img
diff --git a/torchstudio/schedulers/multistep.py b/torchstudio/schedulers/multistep.py
index 2b939fd..2bfee20 100644
--- a/torchstudio/schedulers/multistep.py
+++ b/torchstudio/schedulers/multistep.py
@@ -12,8 +12,15 @@ class MultiStep(lr_scheduler.MultiStepLR):
         gamma (float): Multiplicative factor of learning rate decay.
             Default: 0.1.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+
+            .. deprecated:: 2.2
+                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
+                learning rate.
 
     Example:
+        >>> # xdoctest: +SKIP
         >>> # Assuming optimizer uses lr = 0.05 for all groups
         >>> # lr = 0.05 if epoch < 30
         >>> # lr = 0.005 if 30 <= epoch < 80
@@ -22,6 +29,8 @@ class MultiStep(lr_scheduler.MultiStepLR):
         >>> for epoch in range(100):
         >>>     train(...)
         >>>     validate(...)
-        >>> scheduler.step()"""
-    def __init__(self, optimizer, milestones=[75, 100, 125], gamma=0.1, last_epoch=-1):
-        super().__init__(optimizer, milestones, gamma, last_epoch, verbose=False)
+        >>> scheduler.step()
+    """
+
+    def __init__(self, optimizer, milestones=[75, 100, 125], gamma=0.1, last_epoch=-1, verbose="deprecated"):
+        super().__init__(optimizer, milestones, gamma, last_epoch, verbose)
diff --git a/torchstudio/schedulers/noschedule.py b/torchstudio/schedulers/noschedule.py
new file mode 100644
index 0000000..f258774
--- /dev/null
+++ b/torchstudio/schedulers/noschedule.py
@@ -0,0 +1,12 @@
+class NoSchedule():
+    """No Schedule
+    """
+    def __init__(self,
+                 optimizer,
+                 last_epoch=-1):
+        self.last_epoch=0 if last_epoch<0 else last_epoch
+
+    def step(self):
+        self.last_epoch+=1
+
+
diff --git a/torchstudio/schedulers/onecycle.py b/torchstudio/schedulers/onecycle.py
index 38d8353..0c48698 100644
--- a/torchstudio/schedulers/onecycle.py
+++ b/torchstudio/schedulers/onecycle.py
@@ -1,7 +1,7 @@
 import torch.optim.lr_scheduler as lr_scheduler
 
 class OneCycle(lr_scheduler.OneCycleLR):
-    """Sets the learning rate of each parameter group according to the
+    r"""Sets the learning rate of each parameter group according to the
     1cycle learning rate policy. The 1cycle policy anneals the learning
     rate from an initial learning rate to some maximum learning rate and then
     from that maximum learning rate to some minimum learning rate much lower
@@ -84,24 +84,32 @@ class OneCycle(lr_scheduler.OneCycleLR):
             number of *batches* computed, not the total number of epochs computed.
             When last_epoch=-1, the schedule is started from the beginning.
             Default: -1
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+
+            .. deprecated:: 2.2
+                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
+                learning rate.
 
     Example:
+        >>> # xdoctest: +SKIP
         >>> data_loader = torch.utils.data.DataLoader(...)
         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
         >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
         >>> for epoch in range(10):
         >>>     for batch in data_loader:
         >>>         train_batch(...)
+        >>>         optimizer.step()
         >>>         scheduler.step()
 
-    .. _Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates:
+    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
         https://arxiv.org/abs/1708.07120
     """
     def __init__(self,
                  optimizer,
-                 max_lr=1,
-                 total_steps=200,
+                 max_lr,
+                 total_steps=None,
                  epochs=None,
                  steps_per_epoch=None,
                  pct_start=0.3,
@@ -112,5 +120,6 @@ def __init__(self,
                  div_factor=25.,
                  final_div_factor=1e4,
                  three_phase=False,
-                 last_epoch=-1):
-        super().__init__(optimizer, max_lr, total_steps, epochs, steps_per_epoch, pct_start, anneal_strategy, cycle_momentum, max_momentum, div_factor, final_div_factor, three_phase, last_epoch, verbose=False)
+                 last_epoch=-1,
+                 verbose="deprecated"):
+        super().__init__(optimizer, max_lr, total_steps, epochs, steps_per_epoch, pct_start, anneal_strategy, cycle_momentum, max_momentum, div_factor, final_div_factor, three_phase, last_epoch, verbose)
diff --git a/torchstudio/schedulers/step.py b/torchstudio/schedulers/step.py
index 6cdaa43..99508cb 100644
--- a/torchstudio/schedulers/step.py
+++ b/torchstudio/schedulers/step.py
@@ -12,8 +12,15 @@ class Step(lr_scheduler.StepLR):
         gamma (float): Multiplicative factor of learning rate decay.
             Default: 0.1.
         last_epoch (int): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+
+            .. deprecated:: 2.2
+                ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the
+                learning rate.
 
     Example:
+        >>> # xdoctest: +SKIP
         >>> # Assuming optimizer uses lr = 0.05 for all groups
         >>> # lr = 0.05 if epoch < 30
         >>> # lr = 0.005 if 30 <= epoch < 60
@@ -25,5 +32,6 @@ class Step(lr_scheduler.StepLR):
         >>>     validate(...)
         >>>     scheduler.step()
     """
-    def __init__(self, optimizer, step_size=75, gamma=0.1, last_epoch=-1):
-        super().__init__(optimizer, step_size, gamma, last_epoch, verbose=False)
+
+    def __init__(self, optimizer, step_size=75, gamma=0.1, last_epoch=-1, verbose="deprecated"):
+        super().__init__(optimizer, step_size, gamma, last_epoch, verbose)
diff --git a/torchstudio/tcpcodec.py b/torchstudio/tcpcodec.py
index 2625d4a..3ce9619 100644
--- a/torchstudio/tcpcodec.py
+++ b/torchstudio/tcpcodec.py
@@ -138,7 +138,7 @@ def encode_numpy_tensors(tensors):
     buffer = bytes()
     for tensor in tensors:
         size=len(tensor.shape)
-        buffer+=bytes(tensor.dtype.byteorder.replace('=','<' if sys.byteorder == 'little' else '>')+tensor.dtype.kind,'utf-8')+tensor.dtype.itemsize.to_bytes(1,byteorder='little')+struct.pack(f'')+tensor.dtype.kind,'utf-8')+tensor.dtype.itemsize.to_bytes(1,byteorder='little')+struct.pack(f'
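
Usage note: the sketch below illustrates how the AdamWScheduleFree optimizer and the NoSchedule scheduler introduced by this patch are meant to be driven from a training loop, mirroring the optimizer.train()/optimizer.eval() calls added to modeltrain.py. The model, loss function and synthetic data are hypothetical placeholders, not TorchStudio code.

import torch
from torchstudio.optim.adamw_schedulefree import AdamWScheduleFree
from torchstudio.schedulers.noschedule import NoSchedule

# Placeholder model, loss and synthetic data, purely for illustration.
model = torch.nn.Linear(10, 2)
loss_fn = torch.nn.CrossEntropyLoss()
train_loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(4)]
valid_loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(2)]

optimizer = AdamWScheduleFree(model.parameters(), lr=0.0025, warmup_steps=100)
scheduler = NoSchedule(optimizer)

for epoch in range(3):
    model.train()
    if hasattr(optimizer, 'train'):   # schedule-free: switch parameters to the fast 'y' iterate
        optimizer.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()

    model.eval()
    if hasattr(optimizer, 'eval'):    # schedule-free: switch parameters to the averaged 'x' iterate
        optimizer.eval()
    with torch.no_grad():
        for inputs, targets in valid_loader:
            loss_fn(model(inputs), targets)

    scheduler.step()                  # NoSchedule only advances its epoch counter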