Skip to content

Commit

Permalink
remove multi-dataset work
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Dec 16, 2023
1 parent 8abbb5c commit 6c2f8b0
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions train_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
from helpers.multiaspect.dataset import MultiAspectDataset
from helpers.multiaspect.bucket import BucketManager
from helpers.multiaspect.sampler import MultiAspectSampler

# from helpers.multiaspect.factory import configure_multi_dataset
from helpers.training.state_tracker import StateTracker
from helpers.training.collate import collate_fn
from helpers.training.deepspeed import deepspeed_zero_init_disabled_context_manager
Expand Down Expand Up @@ -262,7 +260,7 @@ def main():
data_backend = LocalDataBackend(accelerator=accelerator)
if not os.path.exists(args.instance_data_dir):
raise FileNotFoundError(
f"Instance {args.instance_data_dir} images root doesn't exist. Cannot continue."
f"Instance {args.instance_data_root} images root doesn't exist. Cannot continue."
)
elif args.data_backend == "aws":
from helpers.data_backend.aws import S3DataBackend
Expand All @@ -278,6 +276,8 @@ def main():
else:
raise ValueError(f"Unsupported data backend: {args.data_backend}")

# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# Bucket manager. We keep the aspect config in the dataset so that switching datasets is simpler.
bucket_manager = BucketManager(
instance_data_root=args.instance_data_dir,
Expand Down Expand Up @@ -420,10 +420,18 @@ def print_bucket_info(bucket_manager):

# Data loader
train_dataset = MultiAspectDataset(
bucket_manager=bucket_manager,
data_backend=data_backend,
instance_data_root=args.instance_data_dir,
accelerator=accelerator,
size=args.resolution,
size_type=args.resolution_type,
print_names=args.print_filenames or False,
datasets=configure_multi_dataset(
args, accelerator
), # We need to store the list of datasets inside the MAD so that it knows their lengths.
prepend_instance_prompt=args.prepend_instance_prompt or False,
use_captions=not args.only_instance_prompt or False,
use_precomputed_token_ids=True,
debug_dataset_loader=args.debug_dataset_loader,
caption_strategy=args.caption_strategy,
)
logger.info("Creating aspect bucket sampler")
custom_balanced_sampler = MultiAspectSampler(
Expand Down Expand Up @@ -1145,9 +1153,8 @@ def print_bucket_info(bucket_manager):
# This is discussed in Section 4.2 of the same paper.
training_logger.debug(f"Using min-SNR loss")
snr = compute_snr(timesteps, noise_scheduler)
snr_divisor = snr
if noise_scheduler.config.prediction_type == "v_prediction":
snr_divisor = snr + 1
snr = snr + 1

training_logger.debug(
f"Calculating MSE loss weights using SNR as divisor"
Expand All @@ -1157,7 +1164,7 @@ def print_bucket_info(bucket_manager):
[snr, args.snr_gamma * torch.ones_like(timesteps)],
dim=1,
).min(dim=1)[0]
/ snr_divisor
/ snr
)

# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
Expand Down

0 comments on commit 6c2f8b0

Please sign in to comment.