From c2e4cc89a1958c63695d2bf53a4c3ce09796badb Mon Sep 17 00:00:00 2001 From: Gopalji Gaur Date: Sun, 8 Dec 2024 18:43:26 +0100 Subject: [PATCH 1/4] Added support for PyTorch Lightning in the DDP backend. --- neps/runtime.py | 50 +++++++++++++++++++++++++++++++++++++++- neps/state/filebased.py | 9 ++++++++ neps/state/neps_state.py | 6 +++++ neps/state/protocols.py | 4 ++++ 4 files changed, 68 insertions(+), 1 deletion(-) diff --git a/neps/runtime.py b/neps/runtime.py index 71b72f878..81fc70a5c 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -26,7 +26,10 @@ WorkerRaiseError, ) from neps.state._eval import evaluate_trial -from neps.state.filebased import create_or_load_filebased_neps_state +from neps.state.filebased import ( + create_or_load_filebased_neps_state, + load_filebased_neps_state, +) from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings from neps.state.trial import Trial @@ -43,6 +46,24 @@ def _default_worker_name() -> str: return f"{os.getpid()}-{isoformat}" +def _is_ddp_and_not_rank_zero() -> bool: + import torch.distributed as dist + + # Check for environment variables typically set by DDP + ddp_env_vars = ["WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"] + rank_env_vars = ["RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK"] + + # Check if PyTorch distributed is initialized + if (dist.is_available() and dist.is_initialized()) or all( + var in os.environ for var in ddp_env_vars + ): + for var in rank_env_vars: + rank = os.environ.get(var) + if rank is not None: + return int(rank) != 0 + return False + + N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 0 N_FAILED_TO_SET_TRIAL_STATE = 10 @@ -488,6 +509,26 @@ def run(self) -> None: # noqa: C901, PLR0915 ) +def _launch_ddp_runtime( + *, + evaluation_fn: Callable[..., float | Mapping[str, Any]], + optimization_dir: Path, +) -> None: + neps_state = load_filebased_neps_state(directory=optimization_dir) + + # TODO: This is a bit of a hack to get the current trial to evaluate. Sometimes + # the previous trial gets sampled when we don't want it to. This is a bit of a + # hack to get around that. + prev_trial = None + while True: + current_trial = neps_state.get_current_evaluating_trial() + if current_trial is not None and ( + prev_trial is None or current_trial.id != prev_trial.id # type: ignore[unreachable] + ): + evaluation_fn(**current_trial.config) + prev_trial = current_trial + + # TODO: This should be done directly in `api.run` at some point to make it clearer at an # entryy point how the woerer is set up to run if someone reads the entry point code. def _launch_runtime( # noqa: PLR0913 @@ -506,6 +547,13 @@ def _launch_runtime( # noqa: PLR0913 max_evaluations_for_worker: int | None, pre_load_hooks: Iterable[Callable[[BaseOptimizer], BaseOptimizer]] | None, ) -> None: + if _is_ddp_and_not_rank_zero(): + # Do not launch a new worker if we are in a DDP setup and not rank 0 + _launch_ddp_runtime( + evaluation_fn=evaluation_fn, optimization_dir=optimization_dir + ) + return + if overwrite_optimization_dir and optimization_dir.exists(): logger.info( f"Overwriting optimization directory '{optimization_dir}' as" diff --git a/neps/state/filebased.py b/neps/state/filebased.py index cf53c6225..ed21e81fe 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -209,6 +209,15 @@ def pending(self) -> Iterable[tuple[str, Synced[Trial, Path]]]: ] return iter((_id, t) for _id, t, _ in sorted(pending, key=lambda x: x[2])) + @override + def evaluating(self) -> Iterable[tuple[str, Synced[Trial, Path]]]: + evaluating = [ + (_id, t, trial.metadata.time_sampled) + for (_id, t) in self.all().items() + if (trial := t.synced()).state == Trial.State.EVALUATING + ] + return iter((_id, t) for _id, t, _ in sorted(evaluating, key=lambda x: x[2])) + @dataclass class ReaderWriterTrial(ReaderWriter[Trial, Path]): diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 1ed3f67b0..22fda4c6d 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -214,6 +214,12 @@ def get_next_pending_trial(self, n: int | None = None) -> Trial | list[Trial] | return take(n, _pending_itr) return next(_pending_itr, None) + def get_current_evaluating_trial(self) -> Trial | None: + """Get the current evaluating trial.""" + for _, shared_trial in self._trials.evaluating(): + return shared_trial.synced() + return None + def all_trial_ids(self) -> set[str]: """Get all the trial ids that are known about.""" return self._trials.all_trial_ids() diff --git a/neps/state/protocols.py b/neps/state/protocols.py index 51bff7d36..e7b31302e 100644 --- a/neps/state/protocols.py +++ b/neps/state/protocols.py @@ -138,6 +138,10 @@ def pending(self) -> Iterable[tuple[str, Synced[Trial, K]]]: """ ... + def evaluating(self) -> Iterable[tuple[str, Synced[Trial, K]]]: + """Get all evaluating trials in the repo.""" + ... + @dataclass class VersionedResource(Generic[T, K]): From 3af8969d9eb6a5155fb0b30c7f415d42ebf658c3 Mon Sep 17 00:00:00 2001 From: Gopalji Gaur Date: Wed, 18 Dec 2024 01:52:01 +0100 Subject: [PATCH 2/4] hacky fix for multiple DDP launches --- neps/runtime.py | 30 ++++++++++++++++++++++++------ neps/state/neps_state.py | 11 ++++++----- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 81fc70a5c..10823e011 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -521,12 +521,30 @@ def _launch_ddp_runtime( # hack to get around that. prev_trial = None while True: - current_trial = neps_state.get_current_evaluating_trial() - if current_trial is not None and ( - prev_trial is None or current_trial.id != prev_trial.id # type: ignore[unreachable] - ): - evaluation_fn(**current_trial.config) - prev_trial = current_trial + current_eval_trials = neps_state.get_current_evaluating_trials() + # If the worker id on previous trial is the same as the current one, only then + # evaluate it. + + if len(current_eval_trials) > 0: + current_trial = None + if prev_trial is None: + # TODO: This is wrong. we evaluate the first trial in the list + # Instead, we need to check and evaluate the trial that is being + # evaluated by the parent process. + # Currently this only works if the DDP trainings are launched after some + # trials evaluation has begun. + current_trial = current_eval_trials[0] + else: + for trial in current_eval_trials: # type: ignore[unreachable] + if ( + trial.metadata.evaluating_worker_id + == prev_trial.metadata.evaluating_worker_id + ) and (trial.id != prev_trial.id): + current_trial = trial + break + if current_trial: + evaluation_fn(**current_trial.config) + prev_trial = current_trial # TODO: This should be done directly in `api.run` at some point to make it clearer at an diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 22fda4c6d..4bb9b5107 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -214,11 +214,12 @@ def get_next_pending_trial(self, n: int | None = None) -> Trial | list[Trial] | return take(n, _pending_itr) return next(_pending_itr, None) - def get_current_evaluating_trial(self) -> Trial | None: - """Get the current evaluating trial.""" - for _, shared_trial in self._trials.evaluating(): - return shared_trial.synced() - return None + def get_current_evaluating_trials(self) -> list[Trial]: + """Get all the current evaluating trials.""" + _eval_itr = ( + shared_trial.synced() for _, shared_trial in self._trials.evaluating() + ) + return list(_eval_itr) def all_trial_ids(self) -> set[str]: """Get all the trial ids that are known about.""" From 7e7810cbf0575f136f8258c2f8d07ced81cb2f58 Mon Sep 17 00:00:00 2001 From: Gopalji Gaur Date: Sun, 22 Dec 2024 23:27:23 +0100 Subject: [PATCH 3/4] use env var to share config with higher rank workers --- neps/runtime.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 10823e011..48738e5f8 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -46,6 +46,9 @@ def _default_worker_name() -> str: return f"{os.getpid()}-{isoformat}" +_DDP_ENV_VAR_NAME = "NEPS_DDP_TRIAL_ID" + + def _is_ddp_and_not_rank_zero() -> bool: import torch.distributed as dist @@ -64,6 +67,11 @@ def _is_ddp_and_not_rank_zero() -> bool: return False +def _set_ddp_env_var(trial_id: str) -> None: + """Sets an environment variable with current trial_id in a DDP setup.""" + os.environ[_DDP_ENV_VAR_NAME] = trial_id + + N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 0 N_FAILED_TO_SET_TRIAL_STATE = 10 @@ -131,6 +139,7 @@ def _set_global_trial(trial: Trial) -> Iterator[None]: "\n\nThis is most likely a bug and should be reported to NePS!" ) _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = trial + _set_ddp_env_var(trial.id) yield for _key, callback in _TRIAL_END_CALLBACKS.items(): callback(trial) @@ -515,25 +524,25 @@ def _launch_ddp_runtime( optimization_dir: Path, ) -> None: neps_state = load_filebased_neps_state(directory=optimization_dir) - - # TODO: This is a bit of a hack to get the current trial to evaluate. Sometimes - # the previous trial gets sampled when we don't want it to. This is a bit of a - # hack to get around that. prev_trial = None while True: current_eval_trials = neps_state.get_current_evaluating_trials() # If the worker id on previous trial is the same as the current one, only then # evaluate it. - if len(current_eval_trials) > 0: current_trial = None if prev_trial is None: - # TODO: This is wrong. we evaluate the first trial in the list - # Instead, we need to check and evaluate the trial that is being - # evaluated by the parent process. - # Currently this only works if the DDP trainings are launched after some - # trials evaluation has begun. - current_trial = current_eval_trials[0] + # In the beginning, we simply read the current trial from the + # environment variable + if _DDP_ENV_VAR_NAME in os.environ: + current_id = os.getenv(_DDP_ENV_VAR_NAME) + if current_id is None: + raise RuntimeError( + "In a pytorch-lightning DDP setup, the environment variable" + f" '{_DDP_ENV_VAR_NAME}' was not set. This is probably a bug in" + " NePS and should be reported." + ) + current_trial = neps_state.get_trial_by_id(current_id) else: for trial in current_eval_trials: # type: ignore[unreachable] if ( From f9b74259b503b15f2fb76f5e9c3cc79c23cd90e6 Mon Sep 17 00:00:00 2001 From: Gopalji Gaur Date: Sat, 18 Jan 2025 00:00:34 +0100 Subject: [PATCH 4/4] docs: added example for DDP with PyTorch Lightning --- neps_examples/__init__.py | 1 + .../efficiency/pytorch_lightning_ddp.py | 101 ++++++++++++++++++ 2 files changed, 102 insertions(+) create mode 100644 neps_examples/efficiency/pytorch_lightning_ddp.py diff --git a/neps_examples/__init__.py b/neps_examples/__init__.py index dc7468928..be18683ec 100644 --- a/neps_examples/__init__.py +++ b/neps_examples/__init__.py @@ -17,6 +17,7 @@ "expert_priors_for_hyperparameters", "multi_fidelity", "multi_fidelity_and_expert_priors", + "pytorch_lightning_ddp", ], } diff --git a/neps_examples/efficiency/pytorch_lightning_ddp.py b/neps_examples/efficiency/pytorch_lightning_ddp.py new file mode 100644 index 000000000..4b387ed43 --- /dev/null +++ b/neps_examples/efficiency/pytorch_lightning_ddp.py @@ -0,0 +1,101 @@ +import lightning as L +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, random_split +import neps +import logging + +NUM_GPU = 8 # Number of GPUs to use for DDP + +class ToyModel(nn.Module): + """ Taken from https://pytorch.org/tutorials/intermediate/ddp_tutorial.html """ + def __init__(self): + super(ToyModel, self).__init__() + self.net1 = nn.Linear(10, 10) + self.relu = nn.ReLU() + self.net2 = nn.Linear(10, 5) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + +class LightningModel(L.LightningModule): + def __init__(self, lr): + super().__init__() + self.lr = lr + self.model = ToyModel() + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.mse_loss(y_hat, y) + self.log("train_loss", loss, prog_bar=True, sync_dist=True) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.mse_loss(y_hat, y) + self.log("val_loss", loss, prog_bar=True, sync_dist=True) + return loss + + def test_step(self, batch, batch_idx): + x, y = batch + y_hat = self.model(x) + loss = F.mse_loss(y_hat, y) + self.log("test_loss", loss, prog_bar=True, sync_dist=True) + return loss + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=self.lr) + +def evaluate_pipeline(lr=0.1, epoch=20): + L.seed_everything(42) + # Model + model = LightningModel(lr=lr) + + # Generate random tensors for data and labels + data = torch.rand((1000, 10)) + labels = torch.rand((1000, 5)) + + dataset = list(zip(data, labels)) + + train_dataset, val_dataset, test_dataset = random_split(dataset, [600, 200, 200]) + + # Define simple data loaders using tensors and slicing + train_dataloader = DataLoader(train_dataset, batch_size=20, shuffle=True) + val_dataloader = DataLoader(val_dataset, batch_size=20, shuffle=False) + test_dataloader = DataLoader(test_dataset, batch_size=20, shuffle=False) + + # Trainer with DDP Strategy + trainer = L.Trainer(gradient_clip_val=0.25, + max_epochs=epoch, + fast_dev_run=False, + strategy='ddp', + devices=NUM_GPU + ) + trainer.fit(model, train_dataloader, val_dataloader) + trainer.validate(model, test_dataloader) + return trainer.logged_metrics["val_loss"] + +pipeline_space = dict( + lr=neps.Float( + lower=0.001, + upper=0.1, + log=True, + prior=0.01 + ), + epoch=neps.Integer( + lower=1, + upper=3, + is_fidelity=True + ) + ) + +logging.basicConfig(level=logging.INFO) +neps.run( + evaluate_pipeline=evaluate_pipeline, + pipeline_space=pipeline_space, + root_directory="results/pytorch_lightning_ddp", + max_evaluations_total=5 + ) \ No newline at end of file