Add structured configs to hydra cli, pass cfg to runners #976

Merged (15 commits) on Jan 23, 2025
1 change: 1 addition & 0 deletions packages/fairchem-core/pyproject.toml
@@ -33,6 +33,7 @@ adsorbml = ["dscribe","x3dase","scikit-image"]

[project.scripts]
fairchem = "fairchem.core._cli:main"
fairchemv2 = "fairchem.core._cli_hydra:main"

[project.urls]
repository = "https://github.com/FAIR-Chem/fairchem"
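This registers a second console script alongside the existing fairchem command: once installed, the hydra-based entry point runs as, for example, fairchemv2 --config-yml config.yml, with any extra arguments passed through as hydra overrides (the --config-yml flag and override handling are defined in _cli_hydra.py below). The legacy fairchem CLI is left in place.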
6 changes: 0 additions & 6 deletions src/fairchem/core/_cli.py
@@ -68,12 +68,6 @@ def main(
parser: argparse.ArgumentParser = flags.get_parser()
args, override_args = parser.parse_known_args()

if args.hydra:
from fairchem.core._cli_hydra import main

main(args, override_args)
return

# TODO: rename num_gpus -> num_ranks everywhere
assert (
args.num_gpus > 0
146 changes: 97 additions & 49 deletions src/fairchem/core/_cli_hydra.py
@@ -7,42 +7,92 @@

from __future__ import annotations

import argparse
import logging
import os
import tempfile
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING

import hydra
from omegaconf import OmegaConf

if TYPE_CHECKING:
import argparse

from omegaconf import DictConfig

from fairchem.core.components.runner import Runner


from submitit import AutoExecutor
from submitit.helpers import Checkpointable, DelayedSubmission
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from fairchem.core.common import distutils
from fairchem.core.common.flags import flags
from fairchem.core.common.utils import get_timestamp_uid, setup_env_vars, setup_imports

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# this affects the cli only, since the actual job will be run in subprocesses or remotely
logging.basicConfig(level=logging.INFO)


class SchedulerType(str, Enum):
LOCAL = "local"
SLURM = "slurm"


class DeviceType(str, Enum):
CPU = "cpu"
CUDA = "cuda"


@dataclass
class SchedulerConfig:
mode: SchedulerType = SchedulerType.LOCAL
ranks_per_node: int = 1
num_nodes: int = 1
slurm: dict = field(
default_factory=lambda: {
"mem_gb": 80, # slurm mem in GB
"timeout_hr": 168, # slurm timeout in hours, default to 7 days
"partition": None,
"cpus_per_task": 8,
"qos": None,
"account": None,
}
)


@dataclass
class JobConfig:
run_name: str = field(default_factory=lambda: uuid.uuid4().hex.upper()[0:8])
timestamp_id: str = field(default_factory=lambda: get_timestamp_uid())
run_dir: str = field(default_factory=lambda: tempfile.TemporaryDirectory().name)
device_type: DeviceType = DeviceType.CUDA
debug: bool = False
scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
log_dir_name: str = "logs"
checkpoint_dir_name: str = "checkpoint"

@property
def log_dir(self) -> str:
return os.path.join(self.run_dir, self.timestamp_id, self.log_dir_name)

@property
def checkpoint_dir(self) -> str:
return os.path.join(self.run_dir, self.timestamp_id, self.checkpoint_dir_name)
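
The two dataclasses above act as a typed schema: main() below merges OmegaConf.structured(JobConfig) with the user's YAML, so unknown keys or mistyped values fail at load time rather than mid-job. A minimal sketch of that behavior, assuming only the dataclasses in this file (OmegaConf resolves enum fields by member name, which is why the test YAML at the bottom of this PR says CPU rather than cpu):

from omegaconf import OmegaConf

base = {"job": OmegaConf.structured(JobConfig)}  # typed defaults from the dataclass
user = OmegaConf.create({"job": {"device_type": "CPU"}})  # as if parsed from YAML
merged = OmegaConf.merge(base, user)  # a typo such as job.device_typ would raise here
job: JobConfig = OmegaConf.to_object(merged.job)
assert job.device_type == DeviceType.CPU
print(job.log_dir)  # <run_dir>/<timestamp_id>/logs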


class Submitit(Checkpointable):
def __call__(self, dict_config: DictConfig) -> None:
self.config = dict_config
job_config: JobConfig = OmegaConf.to_object(dict_config.job)
# TODO: setup_imports is not needed if we stop instantiating models with Registry.
setup_imports()
setup_env_vars()
distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args))
distutils.setup(map_job_config_to_dist_config(job_config))
self._init_logger()
runner: Runner = hydra.utils.instantiate(dict_config.runner)
runner.job_config = job_config
runner.load_state()
runner.run()
distutils.cleanup()
@@ -53,7 +103,7 @@ def _init_logger(self) -> None:
if (
"logger" in self.config
and distutils.is_master()
and not self.config.cli_args.debug
and not self.config.job.debug
):
# get a partial function from the config and instantiate wandb with it
logger_initializer = hydra.utils.instantiate(self.config.logger)
@@ -76,13 +126,16 @@ def checkpoint(self, *args, **kwargs) -> DelayedSubmission:
return DelayedSubmission(new_runner, self.config)


def map_cli_args_to_dist_config(cli_args: DictConfig) -> dict:
def map_job_config_to_dist_config(job_cfg: JobConfig) -> dict:
scheduler_config = job_cfg.scheduler
return {
"world_size": cli_args.num_nodes * cli_args.num_gpus,
"distributed_backend": "gloo" if cli_args.cpu else "nccl",
"submit": cli_args.submit,
"world_size": scheduler_config.num_nodes * scheduler_config.ranks_per_node,
"distributed_backend": "gloo"
if job_cfg.device_type == DeviceType.CPU
else "nccl",
"submit": scheduler_config.mode == SchedulerType.SLURM,
"summit": None,
"cpu": cli_args.cpu,
"cpu": job_cfg.device_type == DeviceType.CPU,
"use_cuda_visibile_devices": True,
}
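
For illustration, the mapping above applied to an all-defaults JobConfig produces the following (a sketch assuming the dataclass defaults in this file; the "summit": None entry simply mirrors the legacy config-dict shape that distutils.setup consumes):

dist_cfg = map_job_config_to_dist_config(JobConfig())
# with the defaults above this yields:
# {"world_size": 1, "distributed_backend": "nccl", "submit": False,
#  "summit": None, "cpu": False, "use_cuda_visibile_devices": True}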

@@ -94,70 +147,65 @@ def get_hydra_config_from_yaml(
os.environ["HYDRA_FULL_ERROR"] = "1"
config_directory = os.path.dirname(os.path.abspath(config_yml))
config_name = os.path.basename(config_yml)
hydra.initialize_config_dir(config_directory)
hydra.initialize_config_dir(config_directory, version_base="1.1")
return hydra.compose(config_name=config_name, overrides=overrides_args)
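
A short sketch of composing a config with dot-list overrides, mirroring the updated tests below (the YAML path and the runner.x override are taken from tests/core/test_hydra_cli.py; this assumes hydra's GlobalHydra is not already initialized, which the tests guarantee by clearing it first):

cfg = get_hydra_config_from_yaml(
    "tests/core/test_hydra_cli.yml",
    overrides_args=["runner.x=1000"],  # hydra dot-list override
)
assert cfg.runner.x == 1000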


def runner_wrapper(config: DictConfig):
Submitit()(config)


# this is meant as a future replacement for the main entrypoint
def main(
args: argparse.Namespace | None = None, override_args: list[str] | None = None
):
if args is None:
parser: argparse.ArgumentParser = flags.get_parser()
parser = argparse.ArgumentParser()
parser.add_argument("--config-yml", type=str, required=True)
args, override_args = parser.parse_known_args()

cfg = get_hydra_config_from_yaml(args.config_yml, override_args)
timestamp_id = get_timestamp_uid()
log_dir = os.path.join(args.run_dir, timestamp_id, "logs")
# override timestamp id and logdir
args.timestamp_id = timestamp_id
args.logdir = log_dir
os.makedirs(log_dir)
OmegaConf.update(cfg, "cli_args", vars(args), force_add=True)
if args.submit: # Run on cluster
# merge default structured config with job
cfg = OmegaConf.merge({"job": OmegaConf.structured(JobConfig)}, cfg)
log_dir = OmegaConf.to_object(cfg.job).log_dir
os.makedirs(cfg.job.run_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

job_cfg = cfg.job
scheduler_cfg = cfg.job.scheduler

logging.info(f"Running fairchemv2 cli with {cfg}")
if scheduler_cfg.mode == SchedulerType.SLURM: # Run on cluster
executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3)
executor.update_parameters(
name=args.identifier,
mem_gb=args.slurm_mem,
timeout_min=args.slurm_timeout * 60,
slurm_partition=args.slurm_partition,
gpus_per_node=args.num_gpus,
cpus_per_task=8,
tasks_per_node=args.num_gpus,
nodes=args.num_nodes,
slurm_qos=args.slurm_qos,
slurm_account=args.slurm_account,
name=job_cfg.run_name,
mem_gb=scheduler_cfg.slurm.mem_gb,
timeout_min=scheduler_cfg.slurm.timeout_hr * 60,
slurm_partition=scheduler_cfg.slurm.partition,
gpus_per_node=scheduler_cfg.ranks_per_node,
cpus_per_task=scheduler_cfg.slurm.cpus_per_task,
tasks_per_node=scheduler_cfg.ranks_per_node,
nodes=scheduler_cfg.num_nodes,
slurm_qos=scheduler_cfg.slurm.qos,
slurm_account=scheduler_cfg.slurm.account,
)
job = executor.submit(runner_wrapper, cfg)
logger.info(
f"Submitted job id: {timestamp_id}, slurm id: {job.job_id}, logs: {log_dir}"
logging.info(
f"Submitted job id: {job_cfg.timestamp_id}, slurm id: {job.job_id}, logs: {job_cfg.log_dir}"
)
else:
if args.num_gpus > 1:
logging.info(f"Running in local mode with {args.num_gpus} ranks")
# HACK to disable multiprocess dataloading in local mode
# there is an open issue where LMDB's environment cannot be pickled and used
# during torch multiprocessing https://github.com/pytorch/examples/issues/526
# this HACK only works for a training submission where the config is passed in here
if "optim" in cfg and "num_workers" in cfg["optim"]:
cfg["optim"]["num_workers"] = 0
logging.info(
"WARNING: running in local mode, setting dataloading num_workers to 0, see https://github.com/pytorch/examples/issues/526"
)
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

if scheduler_cfg.ranks_per_node > 1:
logging.info(f"Running in local mode with {job_cfg.ranks_per_node} ranks")
launch_config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=args.num_gpus,
nproc_per_node=scheduler_cfg.ranks_per_node,
rdzv_backend="c10d",
max_restarts=0,
)
elastic_launch(launch_config, runner_wrapper)(cfg)
else:
logger.info("Running in local mode without elastic launch")
logging.info("Running in local mode without elastic launch")
distutils.setup_env_local()
runner_wrapper(cfg)
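
End to end, the new entry point can be driven exactly as the updated test below does; a sketch assuming the test YAML shipped in this PR:

import sys

from fairchem.core._cli_hydra import main

# equivalent to the console script: fairchemv2 --config-yml tests/core/test_hydra_cli.yml
sys.argv[1:] = ["--config-yml", "tests/core/test_hydra_cli.yml"]
main()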
5 changes: 0 additions & 5 deletions src/fairchem/core/common/flags.py
@@ -120,11 +120,6 @@ def add_core_args(self) -> None:
self.parser.add_argument(
"--cpu", action="store_true", help="Run CPU only training"
)
self.parser.add_argument(
"--hydra",
action="store_true",
help="Use hydra configs instead (in development)",
)
self.parser.add_argument(
"--num-nodes",
default=1,
15 changes: 14 additions & 1 deletion src/fairchem/core/components/runner.py
@@ -1,7 +1,12 @@
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import Any
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from fairchem.core._cli_hydra import JobConfig


class Runner(metaclass=ABCMeta):
@@ -11,6 +16,14 @@ class Runner(metaclass=ABCMeta):
This allows us to decouple away from a monolithic trainer class
"""

@property
def job_config(self) -> JobConfig:
return self._job_config

@job_config.setter
def job_config(self, cfg: JobConfig):
self._job_config = cfg

@abstractmethod
def run(self) -> Any:
raise NotImplementedError
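
For reference, a minimal concrete runner against the new contract (the class name is hypothetical; save_state/load_state are assumed to be the remaining abstract methods, based on the load_state() call in Submitit.__call__ above):

from typing import Any

class PrintRunner(Runner):  # hypothetical example, not part of this PR
    def run(self) -> Any:
        # job_config is attached by Submitit before run() is called
        print(f"writing checkpoints under {self.job_config.checkpoint_dir}")

    def save_state(self) -> None:
        pass

    def load_state(self) -> None:
        pass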
8 changes: 2 additions & 6 deletions tests/core/test_hydra_cli.py
@@ -5,14 +5,14 @@
import hydra
import pytest

from fairchem.core._cli import main
from fairchem.core._cli_hydra import main
from fairchem.core.common import distutils


def test_hydra_cli():
distutils.cleanup()
hydra.core.global_hydra.GlobalHydra.instance().clear()
sys_args = ["--hydra", "--config-yml", "tests/core/test_hydra_cli.yml", "--cpu"]
sys_args = ["--config-yml", "tests/core/test_hydra_cli.yml"]
sys.argv[1:] = sys_args
main()

@@ -21,8 +21,6 @@ def test_hydra_cli_throws_error():
distutils.cleanup()
hydra.core.global_hydra.GlobalHydra.instance().clear()
sys_args = [
"--hydra",
"--cpu",
"--config-yml",
"tests/core/test_hydra_cli.yml",
"runner.x=1000",
@@ -38,8 +36,6 @@ def test_hydra_cli_throws_error_on_invalid_inputs():
distutils.cleanup()
hydra.core.global_hydra.GlobalHydra.instance().clear()
sys_args = [
"--hydra",
"--cpu",
"--config-yml",
"tests/core/test_hydra_cli.yml",
"runner.x=1000",
7 changes: 6 additions & 1 deletion tests/core/test_hydra_cli.yml
@@ -1,4 +1,9 @@
job:
device_type: CPU
scheduler:
mode: LOCAL

runner:
_target_: fairchem.core.components.runner.MockRunner
x: 10
y: 23
y: 23
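
The new job block satisfies the structured JobConfig schema added in _cli_hydra.py; OmegaConf matches enum fields by member name, hence CPU and LOCAL rather than the lowercase string values.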