Skip to content

Commit

Permalink
Merge pull request #117 from mila-iqia/add_comet
Browse files Browse the repository at this point in the history
Add comet, updates pytorch lightning
  • Loading branch information
mirkobronzi authored Apr 8, 2024
2 parents 6327a28 + 8dba668 commit 6e2d81f
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 39 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,21 @@ For both these cases, there is the possibility to run with or without Orion.
(Orion is a hyper-parameter search tool - see https://github.com/Epistimio/orion -
that is already configured in this project)

In any case, the run script will take multiple config files as arguments (`--configs`).
The config files are merged together with OmegaConf; when the same parameter appears in
more than one file, the value from the file listed later on the command line takes precedence.
Note the param `--cli-config-params` can also be used, at CLI time, to modify/add more parameters.

### Loggers
Currently, Tensorboard, Comet and Aim are supported.
For Comet, you will have to specify the key and the project.
This can be done in several ways (see the Comet-ML docs); a quick way is to set the env variables:
```
COMET_WORKSPACE=...
COMET_PROJECT_NAME=...
COMET_API_KEY=...
```

#### Run locally

For example, to run on your local machine without Orion:
Expand Down
1 change: 1 addition & 0 deletions amlrt_project/data/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Logger identifiers: values used in the `experiment_loggers` config section
# to select which experiment loggers to enable.
TENSORBOARD = 'tensorboard'  # TensorBoard logger key
AIM = 'aim'  # Aim logger key
COMET = 'comet'  # Comet logger key (requires COMET_* env vars - see README)
# Config keys.
LOG_FOLDER = 'log_folder'  # folder where the (aim) logger writes its data
EXP_LOGGERS = 'experiment_loggers'  # top-level config section listing enabled loggers
18 changes: 7 additions & 11 deletions amlrt_project/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import sys

import comet_ml # noqa
import pytorch_lightning as pl

from amlrt_project.data.data_loader import FashionMnistDM
Expand All @@ -25,9 +26,8 @@ def main():
parser.add_argument('--log', help='log to this file (in addition to stdout/err)')
parser.add_argument('--ckpt-path', help='Path to best model')
parser.add_argument('--data', help='path to data', required=True)
parser.add_argument('--gpus', default=None,
help='list of GPUs to use. If not specified, runs on CPU.'
'Example of GPU usage: 1 means run on GPU 1, 0 on GPU 0.')
parser.add_argument('--accelerator', default='auto',
help='The accelerator to use - default is "auto".')
add_config_file_params_to_argparser(parser)
args = parser.parse_args()

Expand All @@ -44,7 +44,7 @@ def main():
root.setLevel(logging.INFO)
root.addHandler(handler)

hyper_params = load_configs(args.configs, args.cli_config_params)
hyper_params = load_configs(args.config, args.cli_config_params)

evaluate(args, data_dir, hyper_params)

Expand All @@ -58,27 +58,23 @@ def evaluate(args, data_dir, hyper_params):
output_dir (str): path to output folder
hyper_params (dict): hyper parameters from the config file
"""
# __TODO__ change the hparam that are used from the training algorithm
# (and NOT the model - these will be specified in the model itself)
logger.info('List of hyper-parameters:')
check_and_log_hp(
['architecture', 'batch_size', 'exp_name', 'early_stopping'],
hyper_params)

trainer = pl.Trainer(
gpus=args.gpus,
accelerator=args.accelerator,
)

datamodule = FashionMnistDM(data_dir, hyper_params)
datamodule.setup()

model = load_model(hyper_params)
model = model.load_from_checkpoint(args.ckpt_path)

val_metrics = trainer.validate(model, datamodule=datamodule)
test_metrics = trainer.test(model, datamodule=datamodule)
val_metrics = trainer.validate(model, datamodule=datamodule, ckpt_path=args.ckpt_path)
test_metrics = trainer.test(model, datamodule=datamodule, ckpt_path=args.ckpt_path)

    # We can have many val/test sets, so iterate through their results.
logger.info(f"Validation Metrics: {val_metrics}")
logger.info(f"Test Metrics: {test_metrics}")

Expand Down
23 changes: 10 additions & 13 deletions amlrt_project/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import shutil
import sys

import comet_ml # noqa
import orion
import pytorch_lightning as pl
from orion.client import report_results
Expand Down Expand Up @@ -48,9 +49,8 @@ def main():
help='will disable the progressbar while going over the mini-batch')
parser.add_argument('--start-from-scratch', action='store_true',
help='will delete the output folder before starting the experiment')
parser.add_argument('--gpus', default=None,
help='list of GPUs to use. If not specified, runs on CPU.'
'Example of GPU usage: 1 means run on GPU 1, 0 on GPU 0.')
parser.add_argument('--accelerator', default='auto',
help='The accelerator to use - default is "auto".')
parser.add_argument('--debug', action='store_true')
add_config_file_params_to_argparser(parser)
args = parser.parse_args()
Expand Down Expand Up @@ -86,7 +86,7 @@ def main():
root.setLevel(logging.INFO)
root.addHandler(handler)

hyper_params = load_configs(args.configs, args.cli_config_params)
hyper_params = load_configs(args.config, args.cli_config_params)
save_hparams(hyper_params, os.path.join(args.output, CONFIG_FILE_NAME))

run(args, data_dir, output_dir, hyper_params)
Expand Down Expand Up @@ -121,7 +121,7 @@ def run(args, data_dir, output_dir, hyper_params):
model = load_model(hyper_params)

train(model=model, datamodule=datamodule, output=output_dir, hyper_params=hyper_params,
use_progress_bar=not args.disable_progressbar, gpus=args.gpus)
use_progress_bar=not args.disable_progressbar, accelerator=args.accelerator)


def train(**kwargs): # pragma: no cover
Expand All @@ -144,7 +144,8 @@ def train(**kwargs): # pragma: no cover
value=-float(best_dev_metric))])


def train_impl(model, datamodule, output, hyper_params, use_progress_bar, gpus): # pragma: no cover
def train_impl(model, datamodule, output, hyper_params, use_progress_bar,
accelerator): # pragma: no cover
"""Main training loop implementation.
Args:
Expand All @@ -153,7 +154,7 @@ def train_impl(model, datamodule, output, hyper_params, use_progress_bar, gpus):
output (str): Output directory.
hyper_params (dict): Dict containing hyper-parameters.
use_progress_bar (bool): Use tqdm progress bar (can be disabled when logging).
gpus: number of GPUs to use.
accelerator: the device where to run.
"""
check_and_log_hp(['max_epoch'], hyper_params)

Expand Down Expand Up @@ -192,18 +193,14 @@ def train_impl(model, datamodule, output, hyper_params, use_progress_bar, gpus):
trainer = pl.Trainer(
callbacks=[early_stopping, best_checkpoint_callback, last_checkpoint_callback],
max_epochs=hyper_params['max_epoch'],
resume_from_checkpoint=resume_from_checkpoint,
gpus=gpus,
accelerator=accelerator,
logger=name2loggers.values()
)

trainer.fit(model, datamodule=datamodule)
trainer.fit(model, datamodule=datamodule, ckpt_path=resume_from_checkpoint)

# Log the best result and associated hyper parameters
best_dev_result = float(early_stopping.best_score.cpu().numpy())
# logging hyper-parameters again - this time also passing the final result
log_hyper_parameters(name2loggers, hyper_params, best_dev_result)
# logging to file
with open(os.path.join(output, 'results.txt'), 'w') as stream_out:
stream_out.write(f'final best_dev_metric: {best_dev_result}\n')

Expand Down
4 changes: 2 additions & 2 deletions amlrt_project/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def save_hparams(hparams: dict, output_file: str):

def add_config_file_params_to_argparser(parser):
"""Add the parser options to deal with multiple config files and CLI config."""
parser.add_argument('--configs', nargs='*', default=[],
parser.add_argument('--config', nargs='*', default=[],
help='config files with generic hyper-parameters, such as optimizer, '
'batch_size, ... - in yaml format. Can be zero, one or more than '
'one file. If multiple configs are passed, the latter files will '
Expand All @@ -49,7 +49,7 @@ def main():
'config file here', required=True)
add_config_file_params_to_argparser(parser)
args = parser.parse_args()
hyper_params = load_configs(args.configs, args.cli_config_params)
hyper_params = load_configs(args.config, args.cli_config_params)
save_hparams(hyper_params, args.merged_config_file)


Expand Down
14 changes: 7 additions & 7 deletions amlrt_project/utils/logging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import pytorch_lightning as pl
from git import InvalidGitRepositoryError, Repo
from pip._internal.operations import freeze
from pytorch_lightning.loggers import CometLogger

from amlrt_project.data.constants import AIM, EXP_LOGGERS, TENSORBOARD
from amlrt_project.data.constants import AIM, COMET, EXP_LOGGERS, TENSORBOARD
from amlrt_project.utils.aim_logger_utils import prepare_aim_logger

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,16 +100,15 @@ def load_experiment_loggers(
continue
aim_logger = prepare_aim_logger(hyper_params, options, output)
name2loggers[AIM] = aim_logger
elif logger_name == COMET:
comet_logger = CometLogger()
name2loggers[COMET] = comet_logger
else:
raise NotImplementedError(f"logger {logger_name} is not supported")
return name2loggers


def log_hyper_parameters(name2loggers, hyper_params, best_dev_result=None):
def log_hyper_parameters(name2loggers, hyper_params):
"""Log the experiment hyper-parameters to all the loggers."""
for name, logger in name2loggers.items():
if name == AIM:
logger.log_hyperparams(hyper_params)
elif name == TENSORBOARD:
if best_dev_result is not None:
logger.log_hyperparams(hyper_params, metrics={'best_dev_metric': best_dev_result})
logger.log_hyperparams(hyper_params)
3 changes: 2 additions & 1 deletion examples/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ experiment_loggers:
aim:
# change this to an absolute path to always use the same aim db file
log_folder: ./
# comet: null << uncomment to use comet - see README for more info on setup

# architecture
hidden_dim: 256
Expand All @@ -25,4 +26,4 @@ architecture: simple_mlp
early_stopping:
metric: val_loss
mode: min
patience: 3
patience: 3
4 changes: 2 additions & 2 deletions examples/local_orion/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ set -e
export ORION_DB_ADDRESS='orion_db.pkl'
export ORION_DB_TYPE='pickleddb'

merge-configs --configs ../config.yaml config.yaml --merged-config-file merged_config.yaml
orion -v hunt --config orion_config.yaml amlrt-train --data ../data \
merge-configs --config ../config.yaml config.yaml --merged-config-file merged_config.yaml
orion -vvv -v hunt --config orion_config.yaml amlrt-train --data ../data \
--config merged_config.yaml --disable-progressbar \
--output '{exp.working_dir}/{trial.id}/' \
--log '{exp.working_dir}/{trial.id}/exp.log'
2 changes: 1 addition & 1 deletion examples/slurm/to_submit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@

export MLFLOW_TRACKING_URI='mlruns'

amlrt-train --data ../data --output output --configs ../config.yaml config.yaml --tmp-folder ${SLURM_TMPDIR} --disable-progressbar
amlrt-train --data ../data --output output --config ../config.yaml config.yaml --tmp-folder ${SLURM_TMPDIR} --disable-progressbar
2 changes: 1 addition & 1 deletion examples/slurm_orion/to_submit.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
export ORION_DB_ADDRESS='orion_db.pkl'
export ORION_DB_TYPE='pickleddb'

merge-configs --configs ../config.yaml config.yaml --merged-config-file merged_config.yaml
merge-configs --config ../config.yaml config.yaml --merged-config-file merged_config.yaml
orion -v hunt --config orion_config.yaml \
amlrt-train --data ../data --config merged_config.yaml --disable-progressbar \
--output '{exp.working_dir}/{trial.id}/' \
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
python_requires='>=3.9',
install_requires=[
'aim==3.18.1; os_name!="nt"',
'comet-ml==3.39.3',
'flake8==4.0.1',
'flake8-docstrings==1.6.0',
'gitpython==3.1.27',
Expand All @@ -18,7 +19,7 @@
'pyyaml==6.0',
'pytest==7.1.2',
'pytest-cov==3.0.0',
'pytorch_lightning==1.8.3',
'pytorch_lightning==2.2.1',
'pytype==2024.2.27',
'sphinx==7.2.6',
'sphinx-autoapi==3.0.0',
Expand Down

0 comments on commit 6e2d81f

Please sign in to comment.