Commit a934324

Merge pull request #807 from Ainesh06102004/add_optimizers_monai
Add schedulers monai
sarthakpati authored Mar 13, 2024
2 parents 64962a1 + 3ca6792 commit a934324
Showing 3 changed files with 42 additions and 54 deletions.
3 changes: 3 additions & 0 deletions GANDLF/schedulers/__init__.py
@@ -9,6 +9,7 @@
     cosineannealing,
 )
 
+from .wrap_monai import warmupcosineschedule
 
 # defining dict for schedulers - key is the string and the value is the transform object
 global_schedulers_dict = {
@@ -24,6 +25,8 @@
     "plateau": reduce_on_plateau,
     "reduceonplateau": reduce_on_plateau,
     "cosineannealing": cosineannealing,
+    "warmupcosineschedule": warmupcosineschedule,
+    "wcs": warmupcosineschedule,
 }


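For reference, global_schedulers_dict maps the scheduler names accepted in a GaNDLF configuration to factory functions, so the two new entries make "warmupcosineschedule" and its short form "wcs" selectable. A minimal sketch of how such a registry is typically consumed follows; the helper name and the "type" config key are illustrative assumptions, not part of this diff.

# Sketch only (not from this commit): typical lookup-and-call over the registry.
def get_scheduler_sketch(parameters):
    chosen = parameters["scheduler"]["type"].lower()  # assumed config key
    if chosen not in global_schedulers_dict:
        raise ValueError(f"Scheduler '{chosen}' is not recognized.")
    # every registered factory reads its defaults from parameters["scheduler"]
    # and returns an LR scheduler bound to parameters["optimizer_object"]
    return global_schedulers_dict[chosen](parameters)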
10 changes: 10 additions & 0 deletions GANDLF/schedulers/wrap_monai.py
@@ -0,0 +1,10 @@
+from monai.optimizers import WarmupCosineSchedule as WCS
+
+def warmupcosineschedule(parameters):
+    parameters["scheduler"]["warmup_steps"] = parameters["scheduler"].get("warmup_steps", 0.1 * parameters["num_epochs"])
+
+    return WCS(
+        parameters["optimizer_object"],
+        t_total=parameters["num_epochs"],
+        warmup_steps=parameters["scheduler"]["warmup_steps"]
+    )
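A minimal usage sketch for the new wrapper, assuming GaNDLF and MONAI are installed; the toy model and the numbers below are made up for illustration and are not part of the commit.

# Sketch only: exercising warmupcosineschedule with a dummy optimizer.
import torch
from GANDLF.schedulers.wrap_monai import warmupcosineschedule

model = torch.nn.Linear(4, 2)  # toy model, just to have parameters to optimize
parameters = {
    "num_epochs": 100,
    "optimizer_object": torch.optim.SGD(model.parameters(), lr=0.01),
    "scheduler": {},  # "warmup_steps" is optional; it defaults to 0.1 * num_epochs
}

scheduler = warmupcosineschedule(parameters)
for _ in range(parameters["num_epochs"]):
    # ... one epoch of training would run here ...
    scheduler.step()  # linear warmup, then cosine decay over t_total steps

Because t_total is set to num_epochs, the schedule expects one step() call per epoch rather than per batch.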
83 changes: 29 additions & 54 deletions GANDLF/schedulers/wrap_torch.py
@@ -50,10 +50,8 @@ def base_triangle(parameters):
     """
 
     # pick defaults
-    if not ("min_lr" in parameters["scheduler"]):
-        parameters["scheduler"]["min_lr"] = 10**-3
-    if not ("max_lr" in parameters["scheduler"]):
-        parameters["scheduler"]["max_lr"] = 1
+    parameters["scheduler"]["min_lr"] = parameters["scheduler"].get("min_lr", 10**-3)
+    parameters["scheduler"]["max_lr"] = parameters["scheduler"].get("max_lr", 1)
 
     clr = cyclical_lr(
         parameters["scheduler"]["step_size"],
@@ -65,12 +63,9 @@
 
 def triangle_modified(parameters):
     # pick defaults
-    if not ("min_lr" in parameters["scheduler"]):
-        parameters["scheduler"]["min_lr"] = 0.000001
-    if not ("max_lr" in parameters["scheduler"]):
-        parameters["scheduler"]["max_lr"] = 0.001
-    if not ("max_lr_multiplier" in parameters["scheduler"]):
-        parameters["scheduler"]["max_lr_multiplier"] = 1.0
+    parameters["scheduler"]["min_lr"] = parameters["scheduler"].get("min_lr", 0.000001)
+    parameters["scheduler"]["max_lr"] = parameters["scheduler"].get("max_lr", 0.001)
+    parameters["scheduler"]["max_lr_multiplier"] = parameters["scheduler"].get("max_lr_multiplier", 1.0)
 
     clr = cyclical_lr_modified(
         parameters["scheduler"]["step_size"],
@@ -83,29 +78,22 @@
 
 def cyclic_lr_base(parameters, mode="triangular"):
     # pick defaults for "min_lr", "max_lr", "max_lr_multiplier" if not present in parameters
-    if not ("min_lr" in parameters["scheduler"]):
-        parameters["scheduler"]["min_lr"] = parameters["learning_rate"] * 0.001
-    if not ("max_lr" in parameters["scheduler"]):
-        parameters["scheduler"]["max_lr"] = parameters["learning_rate"]
-    if not ("gamma" in parameters["scheduler"]):
-        parameters["scheduler"]["gamma"] = 0.1
-    if not ("scale_mode" in parameters["scheduler"]):
-        parameters["scheduler"]["scale_mode"] = "cycle"
-    if not ("cycle_momentum" in parameters["scheduler"]):
-        parameters["scheduler"]["cycle_momentum"] = False
-    if not ("base_momentum" in parameters["scheduler"]):
-        parameters["scheduler"]["base_momentum"] = 0.8
-    if not ("max_momentum" in parameters["scheduler"]):
-        parameters["scheduler"]["max_momentum"] = 0.9
+    parameters["scheduler"]["min_lr"] = parameters["scheduler"].get("min_lr", parameters["learning_rate"] * 0.001)
+    parameters["scheduler"]["max_lr"] = parameters["scheduler"].get("max_lr", parameters["learning_rate"])
+    parameters["scheduler"]["gamma"] = parameters["scheduler"].get("gamma", 0.1)
+    parameters["scheduler"]["scale_mode"] = parameters["scheduler"].get("scale_mode", "cycle")
+    parameters["scheduler"]["cycle_momentum"] = parameters["scheduler"].get("cycle_momentum", False)
+    parameters["scheduler"]["base_momentum"] = parameters["scheduler"].get("base_momentum", 0.8)
+    parameters["scheduler"]["max_momentum"] = parameters["scheduler"].get("max_momentum", 0.9)
 
     return CyclicLR(
         parameters["optimizer_object"],
-        parameters["learning_rate"] * 0.001,
-        parameters["learning_rate"],
+        parameters["learning_rate"] * 0.001,  # min_lr
+        parameters["learning_rate"],  # max_lr
         step_size_up=parameters["scheduler"]["step_size"],
         step_size_down=None,
         mode=mode,
-        gamma=1.0,
+        gamma=parameters["scheduler"]["gamma"],
         scale_fn=None,
         scale_mode=parameters["scheduler"]["scale_mode"],
         cycle_momentum=parameters["scheduler"]["cycle_momentum"],
@@ -124,16 +112,14 @@ def cyclic_lr_exp_range(parameters):
 
 
 def exp(parameters):
-    if not ("gamma" in parameters["scheduler"]):
-        parameters["scheduler"]["gamma"] = 0.1
+    parameters["scheduler"]["gamma"] = parameters["scheduler"].get("gamma", 0.1)
     return ExponentialLR(
         parameters["optimizer_object"], parameters["scheduler"]["gamma"]
     )
 
 
 def step(parameters):
-    if not ("gamma" in parameters["scheduler"]):
-        parameters["scheduler"]["gamma"] = 0.1
+    parameters["scheduler"]["gamma"] = parameters["scheduler"].get("gamma", 0.1)
     return StepLR(
         parameters["optimizer_object"],
         parameters["scheduler"]["step_size"],
@@ -142,26 +128,18 @@ def step(parameters):
 
 
 def reduce_on_plateau(parameters):
-    if not ("min_lr" in parameters["scheduler"]):
-        parameters["scheduler"]["min_lr"] = parameters["learning_rate"] * 0.001
-    if not ("gamma" in parameters["scheduler"]):
-        parameters["scheduler"]["gamma"] = 0.1
-    if not ("mode" in parameters["scheduler"]):
-        parameters["scheduler"]["mde"] = "min"
-    if not ("threshold_mode" in parameters["scheduler"]):
-        parameters["scheduler"]["threshold_mode"] = "rel"
-    if not ("factor" in parameters["scheduler"]):
-        parameters["scheduler"]["factor"] = 0.1
-    if not ("patience" in parameters["scheduler"]):
-        parameters["scheduler"]["patience"] = 10
-    if not ("threshold" in parameters["scheduler"]):
-        parameters["scheduler"]["threshold"] = 0.0001
-    if not ("cooldown" in parameters["scheduler"]):
-        parameters["scheduler"]["cooldown"] = 0
+    parameters["scheduler"]["min_lr"] = parameters["scheduler"].get("min_lr", parameters["learning_rate"] * 0.001)
+    parameters["scheduler"]["gamma"] = parameters["scheduler"].get("gamma", 0.1)
+    parameters["scheduler"]["mode"] = parameters["scheduler"].get("mde", "min")
+    parameters["scheduler"]["threshold_mode"] = parameters["scheduler"].get("threshold_mode", "rel")
+    parameters["scheduler"]["factor"] = parameters["scheduler"].get("factor", 0.1)
+    parameters["scheduler"]["patience"] = parameters["scheduler"].get("patience", 10)
+    parameters["scheduler"]["threshold"] = parameters["scheduler"].get("threshold", 0.0001)
+    parameters["scheduler"]["cooldown"] = parameters["scheduler"].get("cooldown", 0)
 
     return ReduceLROnPlateau(
         parameters["optimizer_object"],
-        mode=parameters["scheduler"]["mde"],
+        mode=parameters["scheduler"]["mode"],
        factor=parameters["scheduler"]["factor"],
         patience=parameters["scheduler"]["patience"],
         threshold=parameters["scheduler"]["threshold"],
@@ -172,12 +150,9 @@ def reduce_on_plateau(parameters):
 
 
 def cosineannealing(parameters):
-    if not ("T_0" in parameters["scheduler"]):
-        parameters["scheduler"]["T_0"] = 5
-    if not ("T_mult" in parameters["scheduler"]):
-        parameters["scheduler"]["T_mult"] = 1
-    if not ("min_lr" in parameters["scheduler"]):
-        parameters["scheduler"]["min_lr"] = parameters["learning_rate"] * 0.001
+    parameters["scheduler"]["T_0"] = parameters["scheduler"].get("T_0", 5)
+    parameters["scheduler"]["T_mult"] = parameters["scheduler"].get("T_mult", 1)
+    parameters["scheduler"]["min_lr"] = parameters["scheduler"].get("min_lr", 0.001)
 
     return CosineAnnealingWarmRestarts(
         parameters["optimizer_object"],
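The wrap_torch.py changes are otherwise a mechanical refactor: each "if key not in dict: dict[key] = default" block is collapsed into a single "dict[key] = dict.get(key, default)" assignment, alongside the CyclicLR gamma wiring and the ReduceLROnPlateau mode-key fix visible above. A short, self-contained check of that equivalence, for illustration only:

# Sketch only: dict.get(key, default) keeps an existing value and
# falls back to the default when the key is missing.
cfg = {"max_lr": 0.01}

if not ("min_lr" in cfg):  # old style
    cfg["min_lr"] = 10**-3

cfg["max_lr"] = cfg.get("max_lr", 1)  # new style; existing value is kept
cfg["gamma"] = cfg.get("gamma", 0.1)  # new style; missing key gets the default

assert cfg == {"max_lr": 0.01, "min_lr": 0.001, "gamma": 0.1}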
