From f8e3a2292f0afeaa987da44b074c5a73108bdcb0 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 15 Nov 2024 14:46:30 -0500 Subject: [PATCH] Add an LLM fine-tuning example (#90) * WIP: Add an LLM finetuning example Signed-off-by: Fabrice Normandin * WIP: add / rename more configs Signed-off-by: Fabrice Normandin * Finetuning example seems to be working Signed-off-by: Fabrice Normandin * Making progress, more self-contained example Signed-off-by: Fabrice Normandin * Works! (need to fix the hash used for path though) Signed-off-by: Fabrice Normandin * Improve hashing, reduce default block size Signed-off-by: Fabrice Normandin * Fix val_loss logging and add docstring Signed-off-by: Fabrice Normandin * Increase the number of dataloader workers Signed-off-by: Fabrice Normandin * Use smaller model for now Signed-off-by: Fabrice Normandin * Use FSDP in the example Signed-off-by: Fabrice Normandin * Fix bug in id generation from config classes Signed-off-by: Fabrice Normandin * Tweak config, try to setup mid-epoch checkpointing Signed-off-by: Fabrice Normandin * Rename `HFExample` -> `TextClassificationExample` Signed-off-by: Fabrice Normandin * Fix broken links in nav Signed-off-by: Fabrice Normandin * Remove "huggingface" datamodule config Signed-off-by: Fabrice Normandin * Fix issues in config/tests for text_classification Signed-off-by: Fabrice Normandin * Add an entry to test the llm_finetuning_example Signed-off-by: Fabrice Normandin * Fix issues in the text classification example Signed-off-by: Fabrice Normandin * Fix weird docstring issues with hydra-zen - https://github.com/mit-ll-responsible-ai/hydra-zen/issues/750 Signed-off-by: Fabrice Normandin * Fix test and config of text_classification_example Signed-off-by: Fabrice Normandin * Move test from main_test.py to example_test.py Signed-off-by: Fabrice Normandin * forward_pass is a method of LearningAlgorithmTests Signed-off-by: Fabrice Normandin * Various type hint fixes and tweaks Signed-off-by: 
Fabrice Normandin * WIP: Adding some tests for LLM finetuning example Signed-off-by: Fabrice Normandin * Fix issue in `jax.md` Signed-off-by: Fabrice Normandin * Add link to the example page in index.md Signed-off-by: Fabrice Normandin * Fix tests for the llm finetuning example Signed-off-by: Fabrice Normandin * Fix issue with tuples in regression files Signed-off-by: Fabrice Normandin * Fix test for `get_hash_of` Signed-off-by: Fabrice Normandin * Remove unused _field function Signed-off-by: Fabrice Normandin * Fix issue with built-in modules in autoref plugin Signed-off-by: Fabrice Normandin * Add a bit of info in the example doc Signed-off-by: Fabrice Normandin * Add more links in the doc of the module Signed-off-by: Fabrice Normandin * Fix issue with the text classification example Signed-off-by: Fabrice Normandin * Add skipif mark for LLM finetuning test Signed-off-by: Fabrice Normandin * Fix data_dir of text_classification_example Signed-off-by: Fabrice Normandin * Use the "auto" strategy for LLM Finetuning tests Signed-off-by: Fabrice Normandin * Fix error in fork_rng of LLM finetuning example Signed-off-by: Fabrice Normandin * Try a hacky fix for failing test Signed-off-by: Fabrice Normandin * Don't run llm finetuning tests on github Cloud CI Signed-off-by: Fabrice Normandin * Add missing regression files Signed-off-by: Fabrice Normandin * Rename llm_finetuning_example -> llm_finetuning Signed-off-by: Fabrice Normandin * Fix import error Signed-off-by: Fabrice Normandin --------- Signed-off-by: Fabrice Normandin --- .../glue_cola_algorithm_no_op_test.yaml | 35 ++ .../glue_cola_algorithm_no_op_train.yaml | 35 ++ .../glue_cola_algorithm_no_op_validate.yaml | 35 ++ docs/SUMMARY.md | 3 +- docs/examples/index.md | 23 +- docs/examples/llm_finetuning.md | 22 + docs/examples/nlp.md | 42 -- docs/examples/text_classification.md | 41 ++ docs/features/jax.md | 16 +- project/algorithms/__init__.py | 4 +- .../callbacks/samples_per_second.py | 8 + 
project/algorithms/example_test.py | 17 + project/algorithms/hf_example.py | 127 ---- project/algorithms/llm_finetuning.py | 562 ++++++++++++++++++ project/algorithms/llm_finetuning_test.py | 156 +++++ .../algorithms/testsuites/algorithm_tests.py | 34 +- .../algorithms/text_classification_example.py | 131 ++++ ...py => text_classification_example_test.py} | 18 +- project/configs/algorithm/hf_example.yaml | 10 - .../algorithm/llm_finetuning_example.yaml | 31 + .../algorithm/network/albert-base-v2.yaml | 2 - .../text_classification_example.yaml | 11 + project/configs/datamodule/glue_cola.yaml | 19 + project/configs/datamodule/hf_text.yaml | 7 - .../experiment/llm_finetuning_example.yaml | 37 ++ ....yaml => text_classification_example.yaml} | 6 +- .../trainer/callbacks/model_checkpoint.yaml | 2 +- project/conftest.py | 50 +- project/datamodules/__init__.py | 4 +- project/datamodules/text/__init__.py | 4 +- .../{hf_text.py => text_classification.py} | 116 ++-- ...xt_test.py => text_classification_test.py} | 12 +- project/experiment.py | 4 +- project/main.py | 2 +- project/main_test.py | 32 +- project/utils/autoref_plugin.py | 5 +- project/utils/testutils.py | 1 + pyproject.toml | 4 +- 38 files changed, 1319 insertions(+), 349 deletions(-) create mode 100644 .regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml create mode 100644 .regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml create mode 100644 .regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml create mode 100644 docs/examples/llm_finetuning.md delete mode 100644 docs/examples/nlp.md create mode 100644 docs/examples/text_classification.md delete mode 100644 project/algorithms/hf_example.py create mode 100644 project/algorithms/llm_finetuning.py create mode 100644 project/algorithms/llm_finetuning_test.py create mode 100644 
project/algorithms/text_classification_example.py rename project/algorithms/{hf_example_test.py => text_classification_example_test.py} (85%) delete mode 100644 project/configs/algorithm/hf_example.yaml create mode 100644 project/configs/algorithm/llm_finetuning_example.yaml delete mode 100644 project/configs/algorithm/network/albert-base-v2.yaml create mode 100644 project/configs/algorithm/text_classification_example.yaml create mode 100644 project/configs/datamodule/glue_cola.yaml delete mode 100644 project/configs/datamodule/hf_text.yaml create mode 100644 project/configs/experiment/llm_finetuning_example.yaml rename project/configs/experiment/{hf_example.yaml => text_classification_example.yaml} (64%) rename project/datamodules/text/{hf_text.py => text_classification.py} (81%) rename project/datamodules/text/{hf_text_test.py => text_classification_test.py} (88%) diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml b/.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml new file mode 100644 index 00000000..37d8958b --- /dev/null +++ b/.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_test.yaml @@ -0,0 +1,35 @@ +attention_mask: + device: cpu + max: 1 + mean: '1.021e-01' + min: 0 + shape: + - 32 + - 128 + sum: 418 +input_ids: + device: cpu + max: 29043 + mean: '1.648e+02' + min: 0 + shape: + - 32 + - 128 + sum: 675172 +labels: + device: cpu + max: -1 + mean: '-1.e+00' + min: -1 + shape: + - 32 + sum: -32 +token_type_ids: + device: cpu + max: 0 + mean: '0.e+00' + min: 0 + shape: + - 32 + - 128 + sum: 0 diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml b/.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml new file mode 100644 index 00000000..89d6925e --- /dev/null +++ 
b/.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_train.yaml @@ -0,0 +1,35 @@ +attention_mask: + device: cpu + max: 1 + mean: '8.374e-02' + min: 0 + shape: + - 32 + - 128 + sum: 343 +input_ids: + device: cpu + max: 26101 + mean: '1.597e+02' + min: 0 + shape: + - 32 + - 128 + sum: 654306 +labels: + device: cpu + max: 1 + mean: '7.188e-01' + min: 0 + shape: + - 32 + sum: 23 +token_type_ids: + device: cpu + max: 0 + mean: '0.e+00' + min: 0 + shape: + - 32 + - 128 + sum: 0 diff --git a/.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml b/.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml new file mode 100644 index 00000000..ef5d1104 --- /dev/null +++ b/.regression_files/project/datamodules/datamodules_test/test_first_batch/glue_cola_algorithm_no_op_validate.yaml @@ -0,0 +1,35 @@ +attention_mask: + device: cpu + max: 1 + mean: '9.277e-02' + min: 0 + shape: + - 32 + - 128 + sum: 380 +input_ids: + device: cpu + max: 29043 + mean: '1.362e+02' + min: 0 + shape: + - 32 + - 128 + sum: 557879 +labels: + device: cpu + max: 1 + mean: '7.5e-01' + min: 0 + shape: + - 32 + sum: 24 +token_type_ids: + device: cpu + max: 0 + mean: '0.e+00' + min: 0 + shape: + - 32 + - 128 + sum: 0 diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index e04da3b8..5dba41f0 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -9,7 +9,8 @@ * [Examples πŸ§ͺ](examples/index.md) * [Image Classification (⚑)](examples/torch_sl_example.md) * [Image Classification (jax+⚑)](examples/jax_sl_example.md) - * [NLP (πŸ€—+⚑)](examples/nlp.md) + * [Text Classification (πŸ€—+⚑)](examples/text_classification.md) + * [Fine-tuning an LLM (πŸ€—+⚑)](examples/llm_finetuning.md) * [RL (jax)](examples/jax_rl_example.md) * [Running sweeps](examples/sweeps.md) * [Profiling your codeπŸ“Ž](examples/profiling.md) diff --git a/docs/examples/index.md 
b/docs/examples/index.md index 0a5824c7..4278e2a4 100644 --- a/docs/examples/index.md +++ b/docs/examples/index.md @@ -1,10 +1,21 @@ +--- +additional_python_references: + - project.algorithms.jax_rl_example + - project.algorithms.example + - project.algorithms.jax_example + - project.algorithms.text_classification_example + - project.algorithms.llm_finetuning + - project.trainers.jax_trainer +--- + # Examples This template includes examples that use either Jax, PyTorch, or both! -| Example link | Research Area | Reference link | Frameworks | -| --------------------------------------- | ------------------------------------------ | ------------------ | --------------- | -| [ExampleAlgorithm](torch_sl_example.md) | Supervised Learning (image classification) | `ExampleAlgorithm` | Torch + ⚡ | -| [JaxExample](jax_sl_example.md) | Supervised Learning (image classification) | `JaxExample` | Torch + Jax + ⚡ | -| [HFExample](nlp.md) | NLP (text classification) | `HFExample` | Torch + 🤗 + ⚡ | -| [JaxRLExample](jax_rl_example.md) | RL | `JaxRLExample` | Jax | +| Example link | Research Area | Reference link | Frameworks | +| --------------------------------------------------- | ------------------------------------------ | --------------------------- | --------------- | +| [ExampleAlgorithm](torch_sl_example.md) | Supervised Learning (image classification) | `ExampleAlgorithm` | Torch + ⚡ | +| [JaxExample](jax_sl_example.md) | Supervised Learning (image classification) | `JaxExample` | Torch + Jax + ⚡ | +| [TextClassificationExample](text_classification.md) | NLP (text classification) | `TextClassificationExample` | Torch + 🤗 + ⚡ | +| [JaxRLExample](jax_rl_example.md) | RL | `JaxRLExample` | Jax | +| [LLMFinetuningExample](llm_finetuning.md) | NLP (Causal language modeling) | `LLMFinetuningExample` | Torch + 🤗 + ⚡ | diff --git a/docs/examples/llm_finetuning.md b/docs/examples/llm_finetuning.md new file mode 100644 index 00000000..0a3d07de --- /dev/null +++
b/docs/examples/llm_finetuning.md @@ -0,0 +1,22 @@ +--- +additional_python_references: + - project.algorithms.llm_finetuning +--- +# Fine-tuning LLMs + +This example is based on [this language modeling example from the HuggingFace transformers documentation](https://huggingface.co/docs/transformers/en/tasks/language_modeling). + +To better understand what's going on in this example, it is a good idea to read through these tutorials first: +* [Causal language modeling simple example - HuggingFace docs](https://huggingface.co/docs/transformers/en/tasks/language_modeling) +* [Fine-tune a language model - Colab Notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb#scrollTo=X6HrpprwIrIz) + +The main difference between this example and the original example from HuggingFace is that the `LLMFinetuningExample` is a `LightningModule` that is trained by a `lightning.Trainer`. + +This also means that this example doesn't use [`accelerate`](https://huggingface.co/docs/accelerate/en/index) or the HuggingFace Trainer. + + +## Running the example + +```console +python project/main.py experiment=llm_finetuning_example +``` diff --git a/docs/examples/nlp.md b/docs/examples/nlp.md deleted file mode 100644 index 15915af4..00000000 --- a/docs/examples/nlp.md +++ /dev/null @@ -1,42 +0,0 @@ -# NLP (PyTorch) - - -## Overview - -The [HFExample][project.algorithms.hf_example.HFExample] is a [LightningModule][lightning.pytorch.core.module.LightningModule] for a simple auto-regressive text generation task. - -It accepts a [HFDataModule][project.datamodules.text.HFDataModule] as input, along with a network. - -??? note "Click to show the code for HFExample" - {{ inline('project.algorithms.hf_example.HFExample', 4) }} - -## Config files - -### Algorithm config - -???
note "Click to show the Algorithm config" - Source: project/configs/algorithm/hf_example.yaml - - {{ inline('project/configs/algorithm/hf_example.yaml', 4) }} - -### Datamodule config - -??? note "Click to show the Datamodule config" - Source: project/configs/datamodule/hf_text.yaml - - {{ inline('project/configs/datamodule/hf_text.yaml', 4) }} - -## Running the example - -Here is a configuration file that you can use to launch a simple experiment: - -??? note "Click to show the yaml config file" - Source: project/configs/experiment/hf_example.yaml - - {{ inline('project/configs/experiment/hf_example.yaml', 4) }} - -You can use it like so: - -```console -python project/main.py experiment=example -``` diff --git a/docs/examples/text_classification.md b/docs/examples/text_classification.md new file mode 100644 index 00000000..68122bc5 --- /dev/null +++ b/docs/examples/text_classification.md @@ -0,0 +1,41 @@ +# Text Classification (⚡ + 🤗) + +## Overview + +The [TextClassificationExample][project.algorithms.text_classification_example.TextClassificationExample] is a [LightningModule][lightning.pytorch.core.module.LightningModule] for a simple text classification task. + +It accepts a [TextClassificationDataModule][project.datamodules.text.TextClassificationDataModule] as input, along with a network. + +??? note "Click to show the code for TextClassificationExample" + {{ inline('project.algorithms.text_classification_example.TextClassificationExample', 4) }} + +## Config files + +### Algorithm config + +???
note "Click to show the Datamodule config" + Source: project/configs/datamodule/glue_cola.yaml + + {{ inline('project/configs/datamodule/glue_cola.yaml', 4) }} + +## Running the example + +Here is a configuration file that you can use to launch a simple experiment: + +??? note "Click to show the yaml config file" + Source: project/configs/experiment/text_classification_example.yaml + + {{ inline('project/configs/experiment/text_classification_example.yaml', 4) }} + +You can use it like so: + +```console +python project/main.py experiment=text_classification_example +``` diff --git a/docs/features/jax.md b/docs/features/jax.md index 82a9022b..e54d4b19 100644 --- a/docs/features/jax.md +++ b/docs/features/jax.md @@ -3,7 +3,7 @@ additional_python_references: - project.algorithms.jax_rl_example - project.algorithms.example - project.algorithms.jax_example - - project.algorithms.hf_example + - project.algorithms.text_classification_example - project.trainers.jax_trainer --- @@ -13,12 +13,14 @@ additional_python_references: This template includes examples that use either Jax, PyTorch, or both! -| Example link | Reference | Framework | Lightning? | -| ------------------------------------------------- | ------------------ | ----------- | ------------ | -| [ExampleAlgorithm](../examples/jax_sl_example.md) | `ExampleAlgorithm` | Torch | yes | -| [JaxExample](../examples/jax_sl_example.md) | `JaxExample` | Torch + Jax | yes | -| [HFExample](../examples/nlp.md) | `HFExample` | Torch + πŸ€— | yes | -| [JaxRLExample](../examples/jax_rl_example.md) | `JaxRLExample` | Jax | no (almost!) | + + +| Example link | Reference | Framework | Lightning? 
| +| --------------------------------------------------------------- | --------------------------- | ----------- | ------------ | +| [ExampleAlgorithm](../examples/jax_sl_example.md) | `ExampleAlgorithm` | Torch | yes | +| [JaxExample](../examples/jax_sl_example.md) | `JaxExample` | Torch + Jax | yes | +| [TextClassificationExample](../examples/text_classification.md) | `TextClassificationExample` | Torch + πŸ€— | yes | +| [JaxRLExample](../examples/jax_rl_example.md) | `JaxRLExample` | Jax | no (almost!) | In fact, here you can mix and match both Jax and Torch code. For example, you can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the nice stuff that comes from using PyTorch-Lightning. diff --git a/project/algorithms/__init__.py b/project/algorithms/__init__.py index c92dca8a..de0fcedd 100644 --- a/project/algorithms/__init__.py +++ b/project/algorithms/__init__.py @@ -1,13 +1,13 @@ from .example import ExampleAlgorithm -from .hf_example import HFExample from .jax_example import JaxExample from .jax_rl_example import JaxRLExample from .no_op import NoOp +from .text_classification_example import TextClassificationExample __all__ = [ "ExampleAlgorithm", "JaxExample", "NoOp", - "HFExample", + "TextClassificationExample", "JaxRLExample", ] diff --git a/project/algorithms/callbacks/samples_per_second.py b/project/algorithms/callbacks/samples_per_second.py index 4062b480..d0134cb1 100644 --- a/project/algorithms/callbacks/samples_per_second.py +++ b/project/algorithms/callbacks/samples_per_second.py @@ -1,6 +1,8 @@ import time from typing import Any, Literal +import jax +import torch from lightning import LightningModule, Trainer from torch import Tensor from torch.optim import Optimizer @@ -90,6 +92,12 @@ def log( def get_num_samples(self, batch: BatchType) -> int: if is_sequence_of(batch, Tensor): return batch[0].shape[0] + if isinstance(batch, dict): + return next( + v.shape[0] + for v in 
jax.tree.leaves(batch) + if isinstance(v, torch.Tensor) and v.ndim > 1 + ) raise NotImplementedError( f"Don't know how many 'samples' there are in batch of type {type(batch)}" ) diff --git a/project/algorithms/example_test.py b/project/algorithms/example_test.py index 0b01a38d..d3e69a9b 100644 --- a/project/algorithms/example_test.py +++ b/project/algorithms/example_test.py @@ -1,9 +1,13 @@ """Example showing how the test suite can be used to add tests for a new algorithm.""" +import pytest import torch from transformers import PreTrainedModel from project.algorithms.testsuites.algorithm_tests import LearningAlgorithmTests +from project.configs import Config +from project.conftest import command_line_overrides +from project.datamodules.image_classification.cifar10 import CIFAR10DataModule from project.datamodules.image_classification.image_classification import ( ImageClassificationDataModule, ) @@ -12,6 +16,19 @@ from .example import ExampleAlgorithm +@pytest.mark.parametrize( + command_line_overrides.__name__, ["algorithm=example datamodule=cifar10"], indirect=True +) +def test_example_experiment_defaults(experiment_config: Config) -> None: + """Test to check that the datamodule is required (even when just an algorithm is set?!).""" + + assert experiment_config.algorithm["_target_"] == ( + ExampleAlgorithm.__module__ + "." 
+ ExampleAlgorithm.__qualname__ + ) + + assert isinstance(experiment_config.datamodule, CIFAR10DataModule) + + @run_for_all_configs_of_type("algorithm", ExampleAlgorithm) @run_for_all_configs_of_type("datamodule", ImageClassificationDataModule) @run_for_all_configs_of_type("algorithm/network", torch.nn.Module, excluding=PreTrainedModel) diff --git a/project/algorithms/hf_example.py b/project/algorithms/hf_example.py deleted file mode 100644 index 9b008e73..00000000 --- a/project/algorithms/hf_example.py +++ /dev/null @@ -1,127 +0,0 @@ -from datetime import datetime -from pathlib import Path - -import torch -from evaluate import load as load_metric -from lightning import LightningModule -from torch.optim.adamw import AdamW -from transformers import ( - AutoConfig, - AutoModelForSequenceClassification, - PreTrainedModel, - get_linear_schedule_with_warmup, -) - -from project.datamodules.text.hf_text import HFDataModule - - -def pretrained_network(model_name_or_path: str | Path, **kwargs) -> PreTrainedModel: - config = AutoConfig.from_pretrained(model_name_or_path, **kwargs) - return AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=config) - - -class HFExample(LightningModule): - """Example of a lightning module used to train a huggingface model.""" - - def __init__( - self, - datamodule: HFDataModule, - network: PreTrainedModel, - hf_metric_name: str, - learning_rate: float = 2e-5, - adam_epsilon: float = 1e-8, - warmup_steps: int = 0, - weight_decay: float = 0.0, - **kwargs, - ): - super().__init__() - - self.save_hyperparameters() - self.num_labels = datamodule.num_labels - self.task_name = datamodule.task_name - self.network = network - self.hf_metric_name = hf_metric_name - self.metric = load_metric( - self.hf_metric_name, - self.task_name, - experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S"), - ) - - # Small fix for the `device` property in LightningModule, which is CPU by default. 
- self._device = next((p.device for p in self.parameters()), torch.device("cpu")) - - def forward( - self, - input_ids: torch.Tensor, - token_type_ids: torch.Tensor, - attention_mask: torch.Tensor, - labels: torch.Tensor, - ): - return self.network( - input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels - ) - - def model_step(self, batch: dict[str, torch.Tensor]): - input_ids = batch["input_ids"] - token_type_ids = batch["token_type_ids"] - attention_mask = batch["attention_mask"] - labels = batch["labels"] - - outputs = self.forward(input_ids, token_type_ids, attention_mask, labels) - loss = outputs.loss - logits = outputs.logits - - if self.num_labels > 1: - preds = torch.argmax(logits, axis=1) - else: - preds = logits.squeeze() - - return loss, preds, labels - - def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int): - loss, preds, labels = self.model_step(batch) - self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True) - return {"loss": loss, "preds": preds, "labels": labels} - - def validation_step( - self, batch: dict[str, torch.Tensor], batch_idx: int, dataloader_idx: int = 0 - ): - val_loss, preds, labels = self.model_step(batch) - self.log("val/loss", val_loss, on_step=False, on_epoch=True, prog_bar=True) - return {"val/loss": val_loss, "preds": preds, "labels": labels} - - def configure_optimizers(self): - """Prepare optimizer and schedule (linear warmup and decay)""" - model = self.network - no_decay = ["bias", "LayerNorm.weight"] - optimizer_grouped_parameters = [ - { - "params": [ - p - for n, p in model.named_parameters() - if not any(nd_param in n for nd_param in no_decay) - ], - "weight_decay": self.hparams.weight_decay, - }, - { - "params": [ - p - for n, p in model.named_parameters() - if any(nd_param in n for nd_param in no_decay) - ], - "weight_decay": 0.0, - }, - ] - optimizer = AdamW( - optimizer_grouped_parameters, - lr=self.hparams.learning_rate, - 
eps=self.hparams.adam_epsilon, - ) - - scheduler = get_linear_schedule_with_warmup( - optimizer, - num_warmup_steps=self.hparams.warmup_steps, - num_training_steps=self.trainer.estimated_stepping_batches, - ) - scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} - return [optimizer], [scheduler] diff --git a/project/algorithms/llm_finetuning.py b/project/algorithms/llm_finetuning.py new file mode 100644 index 00000000..9330cf0c --- /dev/null +++ b/project/algorithms/llm_finetuning.py @@ -0,0 +1,562 @@ +"""Example: fine-tuning a language model (GPT, GPT-2, CTRL, OPT, etc.) on a text dataset. + +Large chunks of the code here are taken from [this example script in the transformers GitHub repository](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm_no_trainer.py). + +If you haven't already, you should definitely check out [this walkthrough of that script from the HuggingFace docs.](https://huggingface.co/docs/transformers/en/tasks/language_modeling) +""" + +import dataclasses +import hashlib +import itertools +import os +import shutil +from collections.abc import Callable, Mapping +from dataclasses import dataclass +from logging import getLogger +from pathlib import Path +from typing import Concatenate, ParamSpec, TypeVar + +import datasets +import datasets.distributed +import hydra_zen +import torch +import torch.distributed +from datasets import Dataset, load_from_disk +from datasets.dataset_dict import DatasetDict +from lightning import LightningModule +from torch.optim.adamw import AdamW +from torch.utils.data import DataLoader +from transformers import ( + default_data_collator, + get_linear_schedule_with_warmup, +) +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput +from transformers.models.auto import AutoModelForCausalLM +from transformers.models.auto.tokenization_auto import 
AutoTokenizer +from transformers.tokenization_utils_fast import PreTrainedTokenizerBase + +from project.utils.env_vars import SCRATCH, SLURM_TMPDIR +from project.utils.typing_utils import NestedMapping + +logger = getLogger(__name__) + + +def num_cpus_per_task() -> int: + if hasattr(os, "sched_getaffinity"): + return len(os.sched_getaffinity(0)) + return torch.multiprocessing.cpu_count() + + +@hydra_zen.hydrated_dataclass( + target=AutoModelForCausalLM.from_pretrained, + frozen=True, + unsafe_hash=True, + populate_full_signature=True, +) +class NetworkConfig: + """Configuration options related to the choice of network. + + When instantiated by Hydra, this calls the `target` function passed to the decorator. In this + case, this creates pulls the pretrained network weights from the HuggingFace model hub. + """ + + __doc__ = """Configuration options related to the choice of network. + +When instantiated by Hydra, this calls the `target` function passed to the decorator. In this +case, this creates pulls the pretrained network weights from the HuggingFace model hub. +""" + + pretrained_model_name_or_path: str + trust_remote_code: bool = False + torch_dtype: torch.dtype | None = None + attn_implementation: str | None = None + # cache_dir: Path | None = None + # force_download: bool | None = None + # local_files_only: bool | None = None + # proxies: dict[str, str] | None = None + # revision: str | None = None + # subfolder: str | None = None + # use_auth_token: bool | None = None + # token: str | bool | None = None + + +# BUG: Hydra-zen includes the doc of the target, so doctest complains here. +NetworkConfig.__doc__ = """Configuration options related to the choice of network. + +When instantiated by Hydra, this calls the `target` function passed to the decorator. In this +case, this creates pulls the pretrained network weights from the HuggingFace model hub. 
+""" + + +@hydra_zen.hydrated_dataclass( + target=AutoTokenizer.from_pretrained, + frozen=True, + unsafe_hash=True, + populate_full_signature=True, +) +class TokenizerConfig: + """Configuration options for the tokenizer.""" + + pretrained_model_name_or_path: str + cache_dir: Path | None = None # use standard cache by default. + force_download: bool = False + local_files_only: bool = False + token: str | bool | None = None + revision: str = "main" + use_fast: bool = True + config: PretrainedConfig | None = None + # proxies: dict[str, str] = dataclasses.field(default_factory=dict, hash=False) + subfolder: str = "" + tokenizer_type: str | None = None + trust_remote_code: bool = False + + +# BUG: Hydra-zen includes the doc of the target, so doctest complains here. +TokenizerConfig.__doc__ = """Configuration options for the tokenizer.""" + + +@dataclass(frozen=True, unsafe_hash=True) +class DatasetConfig: + """Configuration options related to the dataset preparation.""" + + dataset_path: str + """Name of the dataset "family"? + + For example, to load "wikitext/wikitext-103-v1", this would be "wikitext". + """ + + dataset_name: str | None = None + """Name of the specific dataset? + + For example, to load "wikitext/wikitext-103-v1", this would be "wikitext-103-v1". + """ + + # Don't include those fields when computign the 'id' of the config, which we use to determine + # if we've already prepared the dataset or not. 
+ per_device_eval_batch_size: int = dataclasses.field( + default=8, metadata={"include_in_id": False} + ) + per_device_train_batch_size: int = dataclasses.field( + default=8, metadata={"include_in_id": False} + ) + + block_size: int = 1024 + + preprocessing_num_workers: int = num_cpus_per_task() + + validation_split_percentage: int = 10 + """Fraction of the train dataset to use for validation if there isn't already a validation + split.""" + + overwrite_cache: bool = False + + +def load_raw_datasets(config: DatasetConfig): + raw_datasets = datasets.load_dataset(config.dataset_path, config.dataset_name) + assert isinstance(raw_datasets, DatasetDict) + if "validation" not in raw_datasets.keys() and config.validation_split_percentage > 0: + raw_datasets["validation"] = datasets.load_dataset( + config.dataset_path, + config.dataset_name, + split=f"train[:{config.validation_split_percentage}%]", + ) + raw_datasets["train"] = datasets.load_dataset( + config.dataset_path, + config.dataset_name, + split=f"train[{config.validation_split_percentage}%:]", + ) + return raw_datasets + + +def prepare_datasets( + dataset_config: DatasetConfig, tokenizer_config: TokenizerConfig +) -> DatasetDict: + # todo: an improvement could be to cache each portion, so that if we just change the block + # size, we don't have to re-tokenize the dataset for example. 
def group_text_into_blocks(
    tokenized_datasets: DatasetDict,
    tokenizer: PreTrainedTokenizerBase,
    config: DatasetConfig,
) -> DatasetDict:
    """Group the tokenized texts into blocks of token ids.

    The block size is `config.block_size`, capped at `tokenizer.model_max_length` (a warning is
    logged when the cap applies). The grouping itself is done by `group_texts`, applied in
    batches through `tokenized_datasets.map`.

    Args:
        tokenized_datasets: The tokenized datasets; `group_texts` is mapped over them.
        tokenizer: Tokenizer; only its `model_max_length` attribute is read here.
        config: Dataset options. Reads `block_size`, `overwrite_cache` and
            `preprocessing_num_workers`.

    Returns:
        The result of `tokenized_datasets.map`, i.e. the grouped datasets.
    """
    block_size = config.block_size
    if block_size > tokenizer.model_max_length:
        logger.warning(
            f"The block_size passed ({block_size}) is larger than the maximum length for the model"
            f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
        )
        block_size = tokenizer.model_max_length

    return tokenized_datasets.map(
        group_texts,
        fn_kwargs={"block_size": block_size},
        batched=True,
        # Honor `overwrite_cache` for this step as well (like `tokenize_datasets` does), so that
        # asking to overwrite the cache does not silently reuse previously grouped blocks.
        load_from_cache_file=not config.overwrite_cache,
        num_proc=config.preprocessing_num_workers,
        desc=f"Grouping tokens into chunks of size {block_size}",
    )


def group_texts(examples: dict, block_size: int):
    """Concatenate all the sequences of a batch and split them into chunks of `block_size`.

    Args:
        examples: Batch of columns: maps each column name (e.g. "input_ids") to a list of
            sequences. Must contain an "input_ids" column.
        block_size: Number of items in each chunk.

    Returns:
        A dict with the same columns, each one a list of chunks, plus a "labels" column that is
        a shallow copy of the chunked "input_ids" column.
    """
    # Concatenate all texts, column by column.
    concatenated_examples = {
        name: list(itertools.chain.from_iterable(sequences))
        for name, sequences in examples.items()
    }
    # The length of the first column is used for every column.
    total_length = len(next(iter(concatenated_examples.values()), []))
    # We drop the small remainder. We could add padding instead if the model supported it; you
    # can customize this part to your needs. A batch shorter than one block is kept as a single
    # (short) chunk.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of `block_size`.
    result = {
        name: [tokens[i : i + block_size] for i in range(0, total_length, block_size)]
        for name, tokens in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
def configure_model(self) -> None:
    """Create the network from `self.network_config`, if it has not been created already.

    See
    https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization
    The intent is to initialize the weights on the GPU if we have one, so we don't request lots
    of RAM just to load up the model weights and then not use it.
    """
    network_already_created = self.network is not None
    if network_already_created:
        return

    logger.info(f"Rank {self.local_rank}: {self.device=}")
    # Only hand the device to `fork_rng` when it is a CUDA device.
    devices_to_fork = [self.device] if self.device.type == "cuda" else []
    with torch.random.fork_rng(devices=devices_to_fork):
        # Deterministic weight initialization: seed inside the forked RNG context.
        torch.manual_seed(self.init_seed)
        self.network = hydra_zen.instantiate(self.network_config)
+ # Otherwise do the tokenization and grouping and save it to the local fast directory, then + # copy it to the $SCRATCH directory for future use. + if _try_to_load_prepared_dataset_from(self.fast_prepared_dataset_dir): + logger.info( + f"Dataset is already prepared on this node at {self.fast_prepared_dataset_dir}" + ) + return + logger.debug("Dataset hasn't been prepared on this node yet.") + + if not self.scratch_prepared_dataset_dir: + # Let's assume that you're using SLURM for multi-node jobs for now. + # SCRATCH isn't set --> not on a SLURM cluster. + assert self.trainer.num_nodes == 1 + logger.info(f"Preparing dataset at {self.fast_prepared_dataset_dir}.") + lm_datasets = prepare_datasets(self.dataset_config, self.tokenizer_config) + lm_datasets.save_to_disk(self.fast_prepared_dataset_dir) + return + + if _try_to_load_prepared_dataset_from(self.scratch_prepared_dataset_dir): + logger.info( + f"Dataset is already prepared on the shared filesystem at " + f"{self.scratch_prepared_dataset_dir}" + ) + copy_dataset_files(self.scratch_prepared_dataset_dir, self.fast_prepared_dataset_dir) + return + + logger.debug("Dataset has not yet been prepared with this config yet.") + + if self.trainer.num_nodes == 1: + logger.debug("Single-node training. Preparing the dataset.") + lm_datasets = prepare_datasets(self.dataset_config, self.tokenizer_config) + lm_datasets.save_to_disk(self.fast_prepared_dataset_dir) + logger.info(f"Saved processed dataset to {self.fast_prepared_dataset_dir}") + copy_dataset_files(self.fast_prepared_dataset_dir, self.scratch_prepared_dataset_dir) + return + + # NOTE: There might be a way to distribute the preprocessing across nodes, I'm not sure. + # todo: Would be even better to add an `srun` step before this with `ntasks_per_node=1` to + # speed up the preprocessing! + _barrier_name = "prepare_data" + if self.global_rank == 0: + logger.info( + f"Task {self.global_rank}: Preparing the dataset in $SLURM_TMPDIR and copying it to $SCRATCH." 
def setup(self, stage: str):
    """Hook from Lightning that is called at the start of training, validation and testing.

    Loads the tokenizer and the prepared dataset from `self.fast_prepared_dataset_dir`, then
    sets `self.tokenizer`, `self.train_dataset` and `self.valid_dataset`.

    TODO: Later perhaps we could do the preprocessing in a distributed manner like this:
    https://discuss.huggingface.co/t/how-to-save-datasets-as-distributed-with-save-to-disk/25674/2
    """
    # The loading is done here (and not only in `prepare_data`) because in distributed training
    # jobs, `prepare_data` is only called in the first task on each node, while `setup` is
    # called in every task.
    self.tokenizer = load_tokenizer(self.tokenizer_config)

    # Log before loading, so the message shows up even if the loading is slow or fails.
    logger.info(f"Loading processed dataset from {self.fast_prepared_dataset_dir}")
    lm_datasets = datasets.load_from_disk(self.fast_prepared_dataset_dir)
    assert isinstance(lm_datasets, DatasetDict)
    self.train_dataset = lm_datasets["train"]
    self.valid_dataset = lm_datasets["validation"]

    # todo: Should we be using `datasets.distributed.split_dataset_by_node` here? Or do we let
    # PyTorch-Lightning setup the distributed sampler for us?
    # https://huggingface.co/docs/datasets/v3.1.0/en/use_with_pytorch#distributed
def copy_dataset_files(src: Path, dest: Path):
    """Recursively copy the dataset directory `src` to `dest`.

    `dest` may already exist (for example a stale or partially-copied directory left by a
    previous run): files from `src` then overwrite the files with the same name in `dest`,
    instead of the copy failing with `FileExistsError`.

    Args:
        src: Directory containing the dataset files to copy.
        dest: Directory to copy the files into. Created if it doesn't exist.
    """
    logger.info(f"Copying dataset from {src} --> {dest}")
    shutil.copytree(src, dest, dirs_exist_ok=True)
def flatten_dict(d: NestedMapping[str, V]) -> dict[str, V]:
    """Flatten a nested mapping into a single-level dict with dot-separated keys.

    For example, `{"a": {"b": 1}, "c": 2}` becomes `{"a.b": 1, "c": 2}`. Entries are kept in
    iteration order. A nested mapping that is empty contributes no entries.
    """
    flat: dict = {}
    for key, value in d.items():
        if not isinstance(value, Mapping):
            flat[key] = value
            continue
        # Nested mapping: flatten it first, then prefix its keys with the parent key.
        for sub_key, sub_value in flatten_dict(value).items():
            flat[f"{key}.{sub_key}"] = sub_value
    return flat
@get_simple_attributes.register(tuple)
def _get_tuple_attributes(value: tuple, precision: int | None):
    """Get some simple stats about a tuple, to be stored in regression files during tests.

    Registered as the `tuple` handler of `get_simple_attributes` (per the original author's
    note, the tensor_regression package doesn't already have one).

    Note: the resulting information is not very descriptive. Per the original author's note,
    this is only called for the `out.past_key_values` entry of the `CausalLMOutputWithPast`
    that is returned from the forward pass.

    Returns:
        A dict with the "length" of the tuple and, keyed by their index (as a string), the
        simple attributes of the first few items.
    """
    num_items_to_include = 5  # only show the stats of some of the items.
    stats = {"length": len(value)}
    for index, item in enumerate(value[:num_items_to_include]):
        stats[str(index)] = get_simple_attributes(item, precision=precision)
    return stats
+ + # The batch of data will always be the same because the dataloaders are passed a Generator + # object in their constructor. + assert isinstance(train_dataloader, DataLoader) + dataloader_iterator = iter(train_dataloader) + + with torch.random.fork_rng(list(range(torch.cuda.device_count()))): + # TODO: This ugliness is because torchvision transforms use the global pytorch RNG! + torch.random.manual_seed(42) + batch = next(dataloader_iterator) + + return jax.tree.map(operator.methodcaller("to", device=device), batch) + + @pytest.fixture(scope="function") + def forward_pass_input(self, training_batch: PyTree[torch.Tensor], device: torch.device): + """Extracts the model input from a batch of data coming from the dataloader. + + Overwrite this if your batches are not tuples of tensors (i.e. if your algorithm isn't a + simple supervised learning algorithm like the example). + """ + assert isinstance(training_batch, dict) + return training_batch + + # Checking all the weights against the 900mb reference .npz file is a bit slow. 
def forward_pass(self, algorithm: LightningModule, input: PyTree[torch.Tensor]):
    """Performs the forward pass with the lightningmodule, unpacking the inputs if necessary.

    The input is unpacked only when both the signature of `algorithm.forward` and the type of
    the input call for it:

    - a mapping is passed as keyword arguments if `forward` accepts `**kwargs`;
    - a tuple or list is passed as positional arguments if `forward` accepts `*args`;
    - anything else (e.g. a single tensor) is passed as-is, as the only argument.

    Looking at the signature alone is not enough: a dict given to a `forward(*args)` would be
    unpacked into its keys, and a tensor given to a `forward(x, **kwargs)` cannot be
    `**`-unpacked at all.

    Overwrite this if your algorithm's forward method is more complicated.
    """
    # Local import, so this method doesn't depend on additional module-level imports.
    from collections.abc import Mapping

    parameter_kinds = {
        parameter.kind for parameter in inspect.signature(algorithm.forward).parameters.values()
    }
    if isinstance(input, Mapping) and inspect.Parameter.VAR_KEYWORD in parameter_kinds:
        return algorithm(**input)
    if isinstance(input, (tuple, list)) and inspect.Parameter.VAR_POSITIONAL in parameter_kinds:
        return algorithm(*input)
    return algorithm(input)
+ with trainer.init_module(): + algorithm.configure_model() + tensor_regression.check( algorithm.state_dict(), # Save the regression files on a different subfolder for each device (cpu / cuda) @@ -161,7 +173,7 @@ def test_forward_pass_is_reproducible( """Check that the forward pass is reproducible given the same input and random seed.""" with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))): torch.random.manual_seed(seed) - out = forward_pass(algorithm, forward_pass_input) + out = self.forward_pass(algorithm, forward_pass_input) tensor_regression.check( {"input": forward_pass_input, "out": out}, diff --git a/project/algorithms/text_classification_example.py b/project/algorithms/text_classification_example.py new file mode 100644 index 00000000..25b7f6d0 --- /dev/null +++ b/project/algorithms/text_classification_example.py @@ -0,0 +1,131 @@ +from datetime import datetime +from typing import TypeVar + +import evaluate +import hydra_zen +import torch +from hydra_zen.typing import Builds +from lightning import LightningModule +from torch.optim.adamw import AdamW +from transformers import ( + PreTrainedModel, + get_linear_schedule_with_warmup, +) +from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput + +from project.datamodules.text.text_classification import TextClassificationDataModule + +T = TypeVar("T") +# Config that returns the object of type T when instantiated. 
class TextClassificationExample(LightningModule):
    """Example of a lightning module used to train a huggingface model for text classification.

    The network is not created in the constructor: it is instantiated from its config in
    `configure_model`, with the random seed `init_seed`.
    """

    def __init__(
        self,
        datamodule: TextClassificationDataModule,
        network: ConfigFor[PreTrainedModel],
        hf_metric_name: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        init_seed: int = 42,
    ):
        """Store the hyper-parameters and load the evaluation metric.

        Args:
            datamodule: Datamodule; its `task_name` (and `num_classes`, if present) are read.
            network: Config that creates the network when instantiated with hydra-zen.
            hf_metric_name: Name of the metric to load with `evaluate.load`.
            learning_rate: Learning rate of the AdamW optimizer.
            adam_epsilon: `eps` of the AdamW optimizer.
            warmup_steps: Number of warmup steps of the linear learning rate schedule.
            weight_decay: Weight decay for the parameters that are not excluded from it in
                `configure_optimizers`.
            init_seed: Random seed used when creating the network.
        """
        super().__init__()
        self.network_config = network
        self.num_labels = getattr(datamodule, "num_classes", None)
        self.task_name = datamodule.task_name
        self.init_seed = init_seed
        self.hf_metric_name = hf_metric_name
        self.learning_rate = learning_rate
        self.adam_epsilon = adam_epsilon
        self.warmup_steps = warmup_steps
        self.weight_decay = weight_decay

        self.metric = evaluate.load(
            self.hf_metric_name,
            self.task_name,
            # todo: replace with hydra job id perhaps?
            experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S"),
        )

        self.save_hyperparameters(ignore=["network", "datamodule"])

        # Created in `configure_model`.
        self.network: PreTrainedModel | None = None

    def configure_model(self) -> None:
        """Create the network from its config, unless it was already created."""
        # This hook can be called more than once by Lightning: don't re-create (and
        # re-initialize) the network on subsequent calls.
        if self.network is None:
            # Only hand the device to `fork_rng` when it is a CUDA device (same guard as in the
            # LLM fine-tuning example): a CPU device is not a valid entry for `devices`.
            devices_to_fork = [self.device] if self.device.type == "cuda" else []
            with torch.random.fork_rng(devices=devices_to_fork):
                # deterministic weight initialization
                torch.manual_seed(self.init_seed)
                self.network = hydra_zen.instantiate(self.network_config)

        return super().configure_model()

    def forward(self, inputs: dict[str, torch.Tensor]) -> BaseModelOutput:
        """Pass the batch to the network as keyword arguments and return its output."""
        assert self.network is not None
        return self.network(**inputs)

    def shared_step(self, batch: dict[str, torch.Tensor], batch_idx: int, stage: str):
        """Compute and log the loss (and the metric, for classification outputs) of a batch.

        Args:
            batch: Batch of inputs for the network; `batch["labels"]` is used as the reference
                for the metric.
            batch_idx: Index of the batch (unused).
            stage: Prefix of the logged values, e.g. "train" or "val".

        Returns:
            The loss of the network on this batch.
        """
        outputs: CausalLMOutput | SequenceClassifierOutput = self(batch)
        loss = outputs.loss
        assert isinstance(loss, torch.Tensor), loss
        # todo: log the output of the metric.
        self.log(f"{stage}/loss", loss, prog_bar=True)
        if isinstance(outputs, SequenceClassifierOutput):
            metric_value = self.metric.compute(
                # logits=outputs.logits,
                predictions=outputs.logits.argmax(-1),
                references=batch["labels"],
            )
            assert isinstance(metric_value, dict)
            for k, v in metric_value.items():
                self.log(
                    f"{stage}/{k}",
                    v,
                    prog_bar=True,
                )
        return loss

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int):
        return self.shared_step(batch, batch_idx, "train")

    def validation_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int, dataloader_idx: int = 0
    ):
        return self.shared_step(batch, batch_idx, "val")

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)"""
        model = self.network
        assert model is not None
        # Parameters whose name contains one of these get no weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if not any(nd_param in n for nd_param in no_decay)
                ],
                "weight_decay": self.weight_decay,
            },
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if any(nd_param in n for nd_param in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.learning_rate,
            eps=self.adam_epsilon,
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        # Step the learning rate scheduler after every optimizer step.
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
PreTrainedModel from typing_extensions import override -from project.algorithms.hf_example import HFExample -from project.datamodules.text.hf_text import HFDataModule +from project.algorithms.text_classification_example import TextClassificationExample +from project.datamodules.text.text_classification import TextClassificationDataModule from project.utils.env_vars import SLURM_JOB_ID from project.utils.testutils import run_for_all_configs_of_type @@ -47,10 +47,10 @@ def total_vram_gb() -> float: @pytest.mark.skipif(total_vram_gb() < 16, reason="Not enough VRAM to run this test.") -@run_for_all_configs_of_type("algorithm", HFExample) -@run_for_all_configs_of_type("datamodule", HFDataModule) +@run_for_all_configs_of_type("algorithm", TextClassificationExample) +@run_for_all_configs_of_type("datamodule", TextClassificationDataModule) @run_for_all_configs_of_type("algorithm/network", PreTrainedModel) -class TestHFExample(LearningAlgorithmTests[HFExample]): +class TestTextClassificationExample(LearningAlgorithmTests[TextClassificationExample]): """Tests for the HF example.""" @pytest.mark.xfail( @@ -60,8 +60,8 @@ class TestHFExample(LearningAlgorithmTests[HFExample]): ) def test_backward_pass_is_reproducible( # type: ignore self, - datamodule: HFDataModule, - algorithm: HFExample, + datamodule: TextClassificationDataModule, + algorithm: TextClassificationExample, seed: int, accelerator: str, devices: int | list[int], @@ -82,8 +82,8 @@ def test_backward_pass_is_reproducible( # type: ignore @pytest.mark.slow def test_overfit_batch( self, - algorithm: HFExample, - datamodule: HFDataModule, + algorithm: TextClassificationExample, + datamodule: TextClassificationDataModule, tmp_path: Path, num_steps: int = 3, ): diff --git a/project/configs/algorithm/hf_example.yaml b/project/configs/algorithm/hf_example.yaml deleted file mode 100644 index b8a2ed16..00000000 --- a/project/configs/algorithm/hf_example.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# Config for the JaxExample algorithm 
-defaults: - - network: albert-base-v2.yaml - # - /datamodule@_global_.datamodule: hf_text.yaml - -_target_: project.algorithms.hf_example.HFExample -# NOTE: Why _partial_? Because the config doesn't create the algo directly, it creates a function -# that will accept the datamodule and network and return the algo. -_partial_: true -hf_metric_name: glue diff --git a/project/configs/algorithm/llm_finetuning_example.yaml b/project/configs/algorithm/llm_finetuning_example.yaml new file mode 100644 index 00000000..d516b969 --- /dev/null +++ b/project/configs/algorithm/llm_finetuning_example.yaml @@ -0,0 +1,31 @@ +_target_: project.algorithms.llm_finetuning.LLMFinetuningExample +network_config: + _target_: project.algorithms.llm_finetuning.NetworkConfig + _recursive_: false + _convert_: object + pretrained_model_name_or_path: facebook/opt-350m + # Uncomment to use fp16 for training. Beware of nans! + # torch_dtype: + # _target_: hydra.utils.get_object + # path: torch.float16 + # attn_implementation: "flash_attention_2" +tokenizer_config: + _target_: project.algorithms.llm_finetuning.TokenizerConfig + _recursive_: false + _convert_: object + # Use the same key as in the network config. Avoids having to duplicate the value. + pretrained_model_name_or_path: ${..network_config.pretrained_model_name_or_path} + use_fast: true + trust_remote_code: true +dataset_config: + _target_: project.algorithms.llm_finetuning.DatasetConfig + dataset_path: wikitext + dataset_name: wikitext-2-v1 # Small dataset for this demo. `wikitext-103-v1` is larger. 
+ per_device_train_batch_size: 8 + per_device_eval_batch_size: 8 + block_size: 256 +learning_rate: 2e-5 +adam_epsilon: 1e-8 +warmup_steps: 0 +weight_decay: 0 +init_seed: 42 diff --git a/project/configs/algorithm/network/albert-base-v2.yaml b/project/configs/algorithm/network/albert-base-v2.yaml deleted file mode 100644 index 52b0fd63..00000000 --- a/project/configs/algorithm/network/albert-base-v2.yaml +++ /dev/null @@ -1,2 +0,0 @@ -_target_: project.algorithms.hf_example.pretrained_network -model_name_or_path: albert-base-v2 diff --git a/project/configs/algorithm/text_classification_example.yaml b/project/configs/algorithm/text_classification_example.yaml new file mode 100644 index 00000000..2540a5fe --- /dev/null +++ b/project/configs/algorithm/text_classification_example.yaml @@ -0,0 +1,11 @@ +# Config for the Text classification example algorithm +_target_: project.algorithms.text_classification_example.TextClassificationExample +_recursive_: false +network: + _target_: transformers.models.auto.modeling_auto.AutoModelForSequenceClassification.from_pretrained + pretrained_model_name_or_path: albert-base-v2 + +# NOTE: Why _partial_? Because the config doesn't create the algo directly, it creates a function +# that will accept the datamodule and network and return the algo. 
+_partial_: true +hf_metric_name: glue diff --git a/project/configs/datamodule/glue_cola.yaml b/project/configs/datamodule/glue_cola.yaml new file mode 100644 index 00000000..078a153d --- /dev/null +++ b/project/configs/datamodule/glue_cola.yaml @@ -0,0 +1,19 @@ +_target_: project.datamodules.text.TextClassificationDataModule +data_dir: ${oc.env:SCRATCH,.}/data +hf_dataset_path: glue +task_name: cola +text_fields: + - "sentence" +tokenizer: + _target_: transformers.models.auto.tokenization_auto.AutoTokenizer.from_pretrained + use_fast: true + # Note: We could interpolate this value with `${/algorithm/network/pretrained_model_name_or_path}` + # to avoid duplicating a value, but this also makes it harder to use this by itself or with + # another algorithm. + pretrained_model_name_or_path: albert-base-v2 + cache_dir: ${..data_dir} + trust_remote_code: true +num_classes: 2 +max_seq_length: 128 +train_batch_size: 32 +eval_batch_size: 32 diff --git a/project/configs/datamodule/hf_text.yaml b/project/configs/datamodule/hf_text.yaml deleted file mode 100644 index ba0e6315..00000000 --- a/project/configs/datamodule/hf_text.yaml +++ /dev/null @@ -1,7 +0,0 @@ -_target_: project.datamodules.HFDataModule -tokenizer: albert-base-v2 -hf_dataset_path: glue -task_name: cola -max_seq_length: 128 -train_batch_size: 32 -eval_batch_size: 32 diff --git a/project/configs/experiment/llm_finetuning_example.yaml b/project/configs/experiment/llm_finetuning_example.yaml new file mode 100644 index 00000000..30ae5e6a --- /dev/null +++ b/project/configs/experiment/llm_finetuning_example.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +defaults: + - override /algorithm: llm_finetuning_example + - override /trainer/callbacks: default + +algorithm: + dataset_config: + per_device_eval_batch_size: 4 + per_device_train_batch_size: 4 + block_size: 256 + validation_split_percentage: 10 + overwrite_cache: false + +trainer: + max_epochs: 10 + devices: auto + strategy: + _target_: 
lightning.pytorch.strategies.FSDPStrategy + # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#optimize-the-sharding-strategy + sharding_strategy: "FULL_SHARD" + limit_val_batches: 1 + num_sanity_val_steps: 0 + val_check_interval: 50 + enable_checkpointing: true + detect_anomaly: false # recommended to turn this on when debugging nans with fp16 training. + callbacks: + model_checkpoint: + verbose: true + # every_n_train_steps: 1000 # todo: restarting from a within-epoch checkpoint doesn't seem to work! + +hydra: + run: + # output directory, generated dynamically on each run + dir: logs/${name} +name: llm_finetuning_example +ckpt_path: last diff --git a/project/configs/experiment/hf_example.yaml b/project/configs/experiment/text_classification_example.yaml similarity index 64% rename from project/configs/experiment/hf_example.yaml rename to project/configs/experiment/text_classification_example.yaml index 99037beb..5f81445f 100644 --- a/project/configs/experiment/hf_example.yaml +++ b/project/configs/experiment/text_classification_example.yaml @@ -1,9 +1,7 @@ # @package _global_ - defaults: - - override /datamodule: hf_text - - override /algorithm: hf_example - - override /algorithm/network: albert-base-v2 + - override /algorithm: text_classification_example + - override /datamodule: glue_cola - override /trainer/callbacks: none trainer: diff --git a/project/configs/trainer/callbacks/model_checkpoint.yaml b/project/configs/trainer/callbacks/model_checkpoint.yaml index f8ddf50c..65b594b4 100644 --- a/project/configs/trainer/callbacks/model_checkpoint.yaml +++ b/project/configs/trainer/callbacks/model_checkpoint.yaml @@ -8,7 +8,7 @@ model_checkpoint: filename: null # checkpoint filename monitor: null # name of the logged metric which determines when model is improving verbose: False # verbosity mode - save_last: null # additionally always save an exact copy of the last checkpoint to a file last.ckpt + save_last: true # additionally always 
save an exact copy of the last checkpoint to a file last.ckpt save_top_k: 1 # save k best models (determined by above metric) mode: "min" # "max" means higher metric value is better, can be also "min" auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name diff --git a/project/conftest.py b/project/conftest.py index b240ff92..bc7ddc2a 100644 --- a/project/conftest.py +++ b/project/conftest.py @@ -92,6 +92,7 @@ setup_logging, ) from project.main import PROJECT_NAME +from project.trainers.jax_trainer import JaxTrainer from project.utils.env_vars import REPO_ROOTDIR from project.utils.hydra_utils import resolve_dictconfig from project.utils.seeding import seeded_rng @@ -251,7 +252,7 @@ def experiment_dictconfig( return dict_config -@pytest.fixture() +@pytest.fixture(scope="function") def experiment_config( experiment_dictconfig: DictConfig, ) -> Config: @@ -281,23 +282,36 @@ def algorithm( ): """Fixture that creates the "algorithm" (a [LightningModule][lightning.pytorch.core.module.LightningModule]).""" + # todo: Use the `with device` block only for `configure_model` to replicate the same conditions + # as when we're using the PyTorch-Lightning Trainer. 
with device: - return instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule) + algorithm = instantiate_algorithm(experiment_config.algorithm, datamodule=datamodule) + if isinstance(algorithm, lightning.LightningModule): + algorithm.configure_model() + return algorithm @pytest.fixture(scope="function") def trainer( experiment_config: Config, -) -> pl.Trainer: +) -> pl.Trainer | JaxTrainer: setup_logging(log_level=experiment_config.log_level) lightning.seed_everything(experiment_config.seed, workers=True) return instantiate_trainer(experiment_config) @pytest.fixture(scope="session") -def train_dataloader(datamodule: DataModule) -> DataLoader: +def train_dataloader( + datamodule: lightning.LightningDataModule | None, request: pytest.FixtureRequest +) -> DataLoader: if isinstance(datamodule, VisionDataModule) or hasattr(datamodule, "num_workers"): datamodule.num_workers = 0 # type: ignore + if datamodule is None: + raise NotImplementedError( + "This test is trying to use `train_dataloader` directly or indirectly but the " + "algorithm that is being tested does not use a datamodule (or the test was not " + "configured properly)! Consider overwriting this fixture in your test class." 
+ ) datamodule.prepare_data() datamodule.setup("fit") train_dataloader = datamodule.train_dataloader() @@ -324,7 +338,7 @@ def training_batch( return jax.tree.map(operator.methodcaller("to", device=device), batch) -@pytest.fixture(autouse=True) +@pytest.fixture(autouse=True, scope="function") def seed(request: pytest.FixtureRequest, make_torch_deterministic: None): """Fixture that seeds everything for reproducibility and yields the random seed used.""" random_seed = getattr(request, "param", DEFAULT_SEED) @@ -635,25 +649,9 @@ def pytest_configure(config: pytest.Config): ) -# import numpy as np -# def fixed_hash_fn(v: jax.Array | np.ndarray | torch.Tensor) -> int: -# if isinstance(v, torch.Tensor): -# return hash(tuple(v.detach().cpu().contiguous().numpy().flatten().tolist())) -# if isinstance(v, jax.Array | np.ndarray): -# return hash(tuple(v.flatten().tolist())) -# raise NotImplementedError(f"Don't know how to hash value {v} of type {type(v)}.") - -# tensor_regression.stats._hash = fixed_hash_fn - - -def _patched_simple_attributes(v, precision: int | None): - stats = tensor_regression.stats.get_simple_attributes(v, precision=precision) - stats.pop("hash", None) - return stats - - +# TODO: remove these, add this fix to the tensor_regression package instead. 
@pytest.fixture(autouse=True) -def dont_use_tensor_hashes_in_regression_files(monkeypatch: pytest.MonkeyPatch): +def _dont_use_tensor_hashes_in_regression_files(monkeypatch: pytest.MonkeyPatch): """Temporarily remove the hash of tensors from the regression files.""" monkeypatch.setattr( @@ -661,3 +659,9 @@ def dont_use_tensor_hashes_in_regression_files(monkeypatch: pytest.MonkeyPatch): tensor_regression.fixture.get_simple_attributes.__name__, # type: ignore _patched_simple_attributes, ) + + +def _patched_simple_attributes(v, precision: int | None): + stats = tensor_regression.stats.get_simple_attributes(v, precision=precision) + stats.pop("hash", None) + return stats diff --git a/project/datamodules/__init__.py b/project/datamodules/__init__.py index 8bc4c5e2..a9905dfc 100644 --- a/project/datamodules/__init__.py +++ b/project/datamodules/__init__.py @@ -10,7 +10,7 @@ from .image_classification.imagenet32 import ImageNet32DataModule, imagenet32_normalization from .image_classification.inaturalist import INaturalistDataModule from .image_classification.mnist import MNISTDataModule -from .text.hf_text import HFDataModule +from .text.text_classification import TextClassificationDataModule from .vision import VisionDataModule __all__ = [ @@ -24,5 +24,5 @@ "ImageNetDataModule", "MNISTDataModule", "VisionDataModule", - "HFDataModule", + "TextClassificationDataModule", ] diff --git a/project/datamodules/text/__init__.py b/project/datamodules/text/__init__.py index f949b53d..05c14809 100644 --- a/project/datamodules/text/__init__.py +++ b/project/datamodules/text/__init__.py @@ -1,3 +1,3 @@ -from .hf_text import HFDataModule +from .text_classification import TextClassificationDataModule -__all__ = ["HFDataModule"] +__all__ = ["TextClassificationDataModule"] diff --git a/project/datamodules/text/hf_text.py b/project/datamodules/text/text_classification.py similarity index 81% rename from project/datamodules/text/hf_text.py rename to 
project/datamodules/text/text_classification.py index f7a7edb7..ffbb9aa8 100644 --- a/project/datamodules/text/hf_text.py +++ b/project/datamodules/text/text_classification.py @@ -1,3 +1,10 @@ +"""Example algorithm that can train a huggingface model. + +Also check out this link for more detailed example script: + +https://github.com/lebrice/mila-docs/blob/llm_training/docs/examples/distributed/LLM_training/main.py +""" + from __future__ import annotations import shutil @@ -10,7 +17,7 @@ from datasets import DatasetDict, load_dataset from lightning import LightningDataModule from torch.utils.data import DataLoader -from transformers import AutoTokenizer +from transformers import PreTrainedTokenizerBase from project.utils.env_vars import REPO_ROOTDIR, SCRATCH, SLURM_TMPDIR @@ -23,53 +30,47 @@ SupportedTask = Literal["cola", "sst2", "mrpc", "qqp", "stsb", "mnli", "qnli", "rte", "wnli", "ax"] -def get_task_info(task_name: SupportedTask): - task_field_map = { - "cola": ["sentence"], - "sst2": ["sentence"], - "mrpc": ["sentence1", "sentence2"], - "qqp": ["question1", "question2"], - "stsb": ["sentence1", "sentence2"], - "mnli": ["premise", "hypothesis"], - "qnli": ["question", "sentence"], - "rte": ["sentence1", "sentence2"], - "wnli": ["sentence1", "sentence2"], - "ax": ["premise", "hypothesis"], - } - - num_labels = { - "cola": 2, - "sst2": 2, - "mrpc": 2, - "qqp": 2, - "stsb": 1, - "mnli": 3, - "qnli": 2, - "rte": 2, - "wnli": 2, - "ax": 3, - } - - task_map = task_field_map.get(task_name, None) - num_labels = num_labels.get(task_name, None) - - if task_map is None: - raise ValueError(f"Task {task_name} task fields currently not supported.") - - if num_labels is None: - raise ValueError(f"Task {task_name} labels currently not supported.") - - return task_map, num_labels - - -class HFDataModule(LightningDataModule): ## to be homogenized with the base text class - """Lightning data module for HF datasets.""" +task_field_map = { + "cola": ["sentence"], + "sst2": 
["sentence"], + "mrpc": ["sentence1", "sentence2"], + "qqp": ["question1", "question2"], + "stsb": ["sentence1", "sentence2"], + "mnli": ["premise", "hypothesis"], + "qnli": ["question", "sentence"], + "rte": ["sentence1", "sentence2"], + "wnli": ["sentence1", "sentence2"], + "ax": ["premise", "hypothesis"], +} + +num_labels = { + "cola": 2, + "sst2": 2, + "mrpc": 2, + "qqp": 2, + "stsb": 1, + "mnli": 3, + "qnli": 2, + "rte": 2, + "wnli": 2, + "ax": 3, +} + + +class TextClassificationDataModule(LightningDataModule): + """Lightning data module for HF text classification datasets. + + This is based on this tutorial: + https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/text-transformers.html + """ def __init__( self, hf_dataset_path: str, - tokenizer: str, - task_name: SupportedTask, + tokenizer: PreTrainedTokenizerBase, + task_name: str, + text_fields: list[str] | None = None, + num_classes: int | None = None, data_dir: str | Path = SCRATCH or REPO_ROOTDIR / "data", loader_columns: list = [ "datasets_idx", @@ -108,21 +109,21 @@ def __init__( self.processed_dataset_path = ( self.data_dir / f"{self.hf_dataset_path}_{self.task_name}_dataset" ) + + if text_fields is None: + text_fields = task_field_map.get(task_name) + self.text_fields = text_fields or ["text"] + + if num_classes is None: + num_classes = num_labels.get(task_name) + self.num_classes = num_classes + if SLURM_TMPDIR: self.working_path = SLURM_TMPDIR / self.processed_dataset_path.name else: self.working_path = self.processed_dataset_path - # self.dataset_path = self.working_path = self.data_dir / f"{self.task_name}_dataset" - - self.text_fields, self.num_labels = get_task_info(task_name) - - ## potential inconsistency in text_fields and task_map ## todo: verify authentication method setup. Is trust_remote_code the right play here? 
- self.tokenizer = AutoTokenizer.from_pretrained( - self.tokenizer, use_fast=True, cache_dir=self.data_dir, trust_remote_code=True - ) - _rng = torch.Generator(device="cpu").manual_seed(self.seed) self.train_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) self.val_dl_rng_seed = int(torch.randint(0, int(1e6), (1,), generator=_rng).item()) @@ -135,12 +136,13 @@ def prepare_data(self): self.hf_dataset_path, self.task_name, cache_dir=str(self.data_dir / ".cache/huggingface/datasets"), + save_infos=True, ) # Tokenize and save to $SCRATCH tokenized_dataset = dataset.map( self.convert_to_features, batched=True, - remove_columns=["label"], + remove_columns=(["label"] if "label" in dataset.column_names else []), load_from_cache_file=True, ) logger.debug(f"Saving (overwriting) tokenized dataset at {self.processed_dataset_path}") @@ -242,9 +244,9 @@ def convert_to_features(self, example_batch, indices=None): pad_to_max_length=True, truncation=True, ) - - # Rename label to labels to make it easier to pass to model forward - features["labels"] = example_batch["label"] + if "label" in example_batch and "labels" not in example_batch: + # Rename label to labels to make it easier to pass to model forward + features["labels"] = example_batch["label"] return features diff --git a/project/datamodules/text/hf_text_test.py b/project/datamodules/text/text_classification_test.py similarity index 88% rename from project/datamodules/text/hf_text_test.py rename to project/datamodules/text/text_classification_test.py index 191958a1..c4878fc3 100644 --- a/project/datamodules/text/hf_text_test.py +++ b/project/datamodules/text/text_classification_test.py @@ -3,7 +3,7 @@ import pytest from omegaconf import DictConfig -from project.datamodules.text.hf_text import HFDataModule +from project.datamodules.text.text_classification import TextClassificationDataModule from project.experiment import ( instantiate_datamodule, ) @@ -38,7 +38,7 @@ def datamodule( @pytest.fixture() 
def prepared_datamodule( - datamodule: HFDataModule, + datamodule: TextClassificationDataModule, tmp_path_factory: pytest.TempPathFactory, ): tmp_path = tmp_path_factory.mktemp("data") @@ -64,9 +64,9 @@ def prepared_datamodule( datamodule.working_path = _slurm_tmpdir_before -@run_for_all_configs_of_type("datamodule", HFDataModule) +@run_for_all_configs_of_type("datamodule", TextClassificationDataModule) def test_dataset_location( - prepared_datamodule: HFDataModule, + prepared_datamodule: TextClassificationDataModule, ): """Test that the dataset is downloaded to the correct location.""" datamodule = prepared_datamodule @@ -81,10 +81,10 @@ def test_dataset_location( assert file_path.exists(), f"Expected file: {file_name} not found at {file_path}." -@run_for_all_configs_of_type("datamodule", HFDataModule) +@run_for_all_configs_of_type("datamodule", TextClassificationDataModule) @pytest.mark.skip(reason="Not implemented") def test_pretrained_weight_location( - prepared_datamodule: HFDataModule, + prepared_datamodule: TextClassificationDataModule, ): """Test that the pretrained weights are downloaded to the correct location.""" # datamodule = prepared_datamodule diff --git a/project/experiment.py b/project/experiment.py index b6998926..940537f6 100644 --- a/project/experiment.py +++ b/project/experiment.py @@ -27,7 +27,7 @@ from lightning import Callback, LightningDataModule, LightningModule, Trainer from project.configs.config import Config -from project.trainers.jax_trainer import JaxModule +from project.trainers.jax_trainer import JaxModule, JaxTrainer from project.utils.typing_utils.protocols import DataModule from project.utils.utils import validate_datamodule @@ -66,7 +66,7 @@ def setup_logging(log_level: str, global_log_level: str = "WARNING") -> None: project_logger.setLevel(log_level.upper()) -def instantiate_trainer(experiment_config: Config) -> Trainer: +def instantiate_trainer(experiment_config: Config) -> Trainer | JaxTrainer: # NOTE: Need to do a bit of 
sneaky type tricks to convince the outside world that these # fields have the right type. diff --git a/project/main.py b/project/main.py index da6b7dd1..3d970cf2 100644 --- a/project/main.py +++ b/project/main.py @@ -53,7 +53,7 @@ schemas_dir=REPO_ROOTDIR / ".schemas", regen_schemas=False, stop_on_error=False, - quiet=True, + quiet=False, verbose=False, add_headers=False, # don't fallback to adding headers if we can't use vscode settings file. ) diff --git a/project/main_test.py b/project/main_test.py index a8ae8d15..08e0fba2 100644 --- a/project/main_test.py +++ b/project/main_test.py @@ -6,7 +6,6 @@ import uuid from unittest.mock import Mock -import hydra_zen import omegaconf.errors import pytest import torch @@ -15,10 +14,7 @@ from omegaconf import DictConfig import project.main -from project.algorithms.example import ExampleAlgorithm -from project.configs.config import Config from project.conftest import command_line_overrides -from project.datamodules.image_classification.cifar10 import CIFAR10DataModule from project.utils.env_vars import REPO_ROOTDIR, SLURM_JOB_ID from project.utils.hydra_utils import resolve_dictconfig from project.utils.testutils import IN_GITHUB_CI @@ -81,7 +77,7 @@ def mock_evaluate_jax_module(monkeypatch: pytest.MonkeyPatch): experiment_commands_to_test = [ "experiment=example trainer.fast_dev_run=True", - "experiment=hf_example trainer.fast_dev_run=True", + "experiment=text_classification_example trainer.fast_dev_run=True", # "experiment=jax_example trainer.fast_dev_run=True", "experiment=jax_rl_example trainer.max_epochs=1", pytest.param( @@ -118,13 +114,18 @@ def mock_evaluate_jax_module(monkeypatch: pytest.MonkeyPatch): "trainer.fast_dev_run=True ", # make each job quicker to run marks=pytest.mark.slow, ), - pytest.param( - "experiment=profiling " - "algorithm=no_op " + ( + "experiment=profiling algorithm=no_op " "datamodule=cifar10 " # Run a small dataset instead of ImageNet (would take ~6min to process on a compute node..) 
"trainer/logger=tensorboard " # Use Tensorboard logger because DeviceStatsMonitor requires a logger being used. "trainer.fast_dev_run=True " # make each job quicker to run ), + pytest.param( + "experiment=llm_finetuning_example trainer.fast_dev_run=True trainer/logger=[]", + marks=pytest.mark.skipif( + SLURM_JOB_ID is None, reason="Can only be run on a slurm cluster." + ), + ), ] @@ -208,21 +209,6 @@ def test_setting_just_algorithm_isnt_enough(experiment_dictconfig: DictConfig) - _ = resolve_dictconfig(experiment_dictconfig) -@pytest.mark.parametrize( - command_line_overrides.__name__, ["algorithm=example datamodule=cifar10"], indirect=True -) -def test_example_experiment_defaults(experiment_config: Config) -> None: - """Test to check that the datamodule is required (even when just an algorithm is set?!).""" - - assert experiment_config.algorithm["_target_"] == ( - ExampleAlgorithm.__module__ + "." + ExampleAlgorithm.__qualname__ - ) - assert ( - isinstance(experiment_config.datamodule, CIFAR10DataModule) - or hydra_zen.get_target(experiment_config.datamodule) is CIFAR10DataModule - ) - - @pytest.mark.skipif( IN_GITHUB_CI and sys.platform == "darwin", reason="TODO: Getting a 'MPS backend out of memory' error on the Github CI. ", diff --git a/project/utils/autoref_plugin.py b/project/utils/autoref_plugin.py index d0dd108a..ded6f6ac 100644 --- a/project/utils/autoref_plugin.py +++ b/project/utils/autoref_plugin.py @@ -155,7 +155,10 @@ def _expand(obj: types.ModuleType | object) -> list[object]: return [ v for v in objects_in_global_scope - if not (inspect.ismodule(v) and inspect.getsourcefile(v) != source_file) + if not ( + (inspect.ismodule(v) and getattr(v, "__file__", None) is None) # built-in module. 
+ or (inspect.ismodule(v) and inspect.getsourcefile(v) != source_file) + ) ] diff --git a/project/utils/testutils.py b/project/utils/testutils.py index 3d31606b..d4e9b546 100644 --- a/project/utils/testutils.py +++ b/project/utils/testutils.py @@ -23,6 +23,7 @@ IN_GITHUB_CI = "GITHUB_ACTIONS" in os.environ IN_SELF_HOSTED_GITHUB_CI = IN_GITHUB_CI and "self-hosted" in os.environ.get("RUNNER_LABELS", "") +IN_GITHUB_COULD_CI = IN_GITHUB_CI and not IN_SELF_HOSTED_GITHUB_CI PARAM_WHEN_USED_MARK_NAME = "parametrize_when_used" diff --git a/pyproject.toml b/pyproject.toml index 61d2c97d..5927c145 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,9 +60,7 @@ docs = [ "mkdocs-section-index>=0.3.9", "mkdocs-macros-plugin>=1.0.5", ] -gpu = [ - "jax[cuda12]>=0.4.31" -] +gpu = ["jax[cuda12]>=0.4.31"] [build-system] requires = ["hatchling"]