Skip to content

Commit

Permalink
Merge pull request #1276 from rkarhila-amd/fixes_for_amd_gpu
Browse files Browse the repository at this point in the history
Small fixes for running on AMD GPUs
  • Loading branch information
bghira authored Jan 16, 2025
2 parents e4cf9a6 + e4d3309 commit d049c7a
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 37 deletions.
2 changes: 2 additions & 0 deletions helpers/training/default_settings/safety_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from diffusers.utils import is_wandb_available
from helpers.training.multi_process import _get_rank as get_rank
from helpers.training.state_tracker import StateTracker
from torch.version import cuda as cuda_version

logger = logging.getLogger(__name__)
if get_rank() == 0:
Expand Down Expand Up @@ -73,6 +74,7 @@ def safety_check(args, accelerator):
accelerator is not None
and accelerator.device.type == "cuda"
and accelerator.is_main_process
and cuda_version is not None
):
import subprocess

Expand Down
85 changes: 48 additions & 37 deletions helpers/training/optimizer_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@
"Could not load bitsandbytes library. BnB-specific optimisers and other functionality will be unavailable."
)

# Some optimizers are not available in multibackend bitsandbytes as of January 2025.
is_ademamix_available = False
if is_bitsandbytes_available:
if 'AdEMAMix' in dir(bitsandbytes.optim):
is_ademamix_available = True

optimizer_choices = {
"adamw_bf16": {
"precision": "bf16",
Expand Down Expand Up @@ -353,6 +359,47 @@
},
"class": bitsandbytes.optim.PagedAdamW8bit,
},
"bnb-lion": {
"precision": "any",
"default_settings": {
"betas": (0.9, 0.99),
"weight_decay": 0.0,
"min_8bit_size": 4096,
},
"class": bitsandbytes.optim.Lion,
},
"bnb-lion8bit": {
"precision": "any",
"default_settings": {
"betas": (0.9, 0.99),
"weight_decay": 0.0,
"min_8bit_size": 4096,
},
"class": bitsandbytes.optim.Lion8bit,
},
"bnb-lion-paged": {
"precision": "any",
"default_settings": {
"betas": (0.9, 0.99),
"weight_decay": 0.0,
"min_8bit_size": 4096,
},
"class": bitsandbytes.optim.PagedLion,
},
"bnb-lion8bit-paged": {
"precision": "any",
"default_settings": {
"betas": (0.9, 0.99),
"weight_decay": 0.0,
"min_8bit_size": 4096,
},
"class": bitsandbytes.optim.PagedLion8bit,
},
})

if is_ademamix_available:
optimizer_choices.update(
{
"bnb-ademamix": {
"precision": "any",
"default_settings": {
Expand Down Expand Up @@ -404,43 +451,7 @@
"min_8bit_size": 4096,
},
"class": bitsandbytes.optim.PagedAdEMAMix8bit,
},
"bnb-lion": {
"precision": "any",
"default_settings": {
"betas": (0.9, 0.99),
"weight_decay": 0.0,
"min_8bit_size": 4096,
},
"class": bitsandbytes.optim.Lion,
},
"bnb-lion8bit": {
"precision": "any",
"default_settings": {
"betas": (0.9, 0.99),
"weight_decay": 0.0,
"min_8bit_size": 4096,
},
"class": bitsandbytes.optim.Lion8bit,
},
"bnb-lion-paged": {
"precision": "any",
"default_settings": {
"betas": (0.9, 0.99),
"weight_decay": 0.0,
"min_8bit_size": 4096,
},
"class": bitsandbytes.optim.PagedLion,
},
"bnb-lion8bit-paged": {
"precision": "any",
"default_settings": {
"betas": (0.9, 0.99),
"weight_decay": 0.0,
"min_8bit_size": 4096,
},
"class": bitsandbytes.optim.PagedLion8bit,
},
}
}
)

Expand Down

0 comments on commit d049c7a

Please sign in to comment.