Skip to content

Commit

Permalink
Merge pull request #1214 from bghira/chore/configurator-grad-checkpointing-interval-model-compat

Browse files Browse the repository at this point in the history

configurator should avoid asking about checkpointing intervals when the model family does not support it
  • Loading branch information
bghira authored Dec 13, 2024
2 parents 4eb7aee + 1014665 commit 749ae60
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,23 +547,24 @@ def configure_env():
)
)
env_contents["--gradient_checkpointing"] = "true"
gradient_checkpointing_interval = prompt_user(
"Would you like to configure a gradient checkpointing interval? A value larger than 1 will increase VRAM usage but speed up training by skipping checkpoint creation every Nth layer, and a zero will disable this feature.",
0,
)
try:
if int(gradient_checkpointing_interval) > 1:
env_contents["--gradient_checkpointing_interval"] = int(
gradient_checkpointing_interval
)
except:
print("Could not parse gradient checkpointing interval. Not enabling.")
pass
if env_contents["--model_family"] in ["sdxl", "flux", "sd3", "sana"]:
gradient_checkpointing_interval = prompt_user(
"Would you like to configure a gradient checkpointing interval? A value larger than 1 will increase VRAM usage but speed up training by skipping checkpoint creation every Nth layer, and a zero will disable this feature.",
0,
)
try:
if int(gradient_checkpointing_interval) > 1:
env_contents["--gradient_checkpointing_interval"] = int(
gradient_checkpointing_interval
)
except:
print("Could not parse gradient checkpointing interval. Not enabling.")
pass

env_contents["--caption_dropout_probability"] = float(
prompt_user(
"Set the caption dropout rate, or use 0.0 to disable it. Dropout is not recommended for LoRA/LyCORIS training unless you are training for style transfer.",
"0.0" if any([use_lora, use_lycoris]) else "0.1",
"Set the caption dropout rate, or use 0.0 to disable it. Dropout might be a good idea to disable for Flux training, but experimentation is warranted.",
"0.05" if any([use_lora, use_lycoris]) else "0.1",
)
)

Expand Down

0 comments on commit 749ae60

Please sign in to comment.