Addition of loggers and dist strategy #180

**Open** · wants to merge 18 commits into base `main`
189 changes: 159 additions & 30 deletions src/itwinai/torch/inference.py

> **Reviewer comment (Collaborator):** This file has changed on `main` and there are some conflicts to resolve here.

@@ -5,19 +5,28 @@
#
# Credit:
# - Matteo Bunino <[email protected]> - CERN
# - Rakesh Sarma <[email protected]> - FZJ
# --------------------------------------------------------------------------------------

import abc
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

from ..components import Predictor, monitor_exec
from ..loggers import Logger
from ..serialization import ModelLoader
from ..utils import clear_key, dynamically_import_class
from .config import TrainingConfiguration
from .distributed import (
DeepSpeedStrategy,
HorovodStrategy,
NonDistributedStrategy,
TorchDDPStrategy,
TorchDistributedStrategy,
distributed_resources_available,
)
from .type import Batch


@@ -42,10 +51,11 @@ def __call__(self) -> nn.Module:
"""
if os.path.exists(self.model_uri):
# Model is on local filesystem.
model = torch.load(self.model_uri)
return model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

> **Reviewer comment (Collaborator):** If this is to be done distributed, would you not need a more sophisticated way of getting the specific device? I might be wrong here.
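A possible refinement along those lines, as a minimal sketch: it assumes a torchrun-style launcher that exports `LOCAL_RANK` for each worker process (the strategy classes added in this PR may already expose an equivalent helper):

```python
import os

import torch


def resolve_device() -> torch.device:
    """Pick the per-process GPU in a distributed run; fall back to CPU."""
    if torch.cuda.is_available():
        # torchrun and torch.distributed launchers export LOCAL_RANK per worker.
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        return torch.device(f"cuda:{local_rank}")
    return torch.device("cpu")
```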

model = torch.load(self.model_uri, map_location=device)
return model

> **Reviewer comment (Collaborator):** Return the model in inference mode, as it was before, i.e. `return model.eval()`.


if self.model_uri.startswith("mlflow+"):
if self.model_uri.startswith('mlflow+'):
# Model is on an MLFLow server
# Form is 'mlflow+MLFLOW_TRACKING_URI+RUN_ID+ARTIFACT_PATH'
import mlflow
@@ -68,7 +78,7 @@ def __call__(self) -> nn.Module:
tracking_uri=mlflow.get_tracking_uri(),
)
model = torch.load(ckpt_path)
return model.eval()
return model

> **Reviewer comment (Collaborator):** I saw that the `model.eval()` was removed here as well. Is there any reason?


raise ValueError(
"Unrecognized model URI: model may not be there! "
@@ -79,43 +89,115 @@
class TorchPredictor(Predictor):
"""Applies a pre-trained torch model to unseen data."""

_strategy: TorchDistributedStrategy = None
#: PyTorch ``DataLoader`` for inference dataset.
inference_dataloader: DataLoader = None
#: Pre-trained PyTorch model used to make predictions.
model: nn.Module = None
#: ``Dataset`` on which to make predictions (ML inference).
test_dataset: Dataset
#: ``DataLoader`` for test dataset.
test_dataloader: DataLoader = None
#: PyTorch random number generator (PRNG).
torch_rng: torch.Generator = None
#: itwinai ``itwinai.Logger``
logger: Logger = None

def __init__(
self,
config: Union[Dict, TrainingConfiguration],
model: Union[nn.Module, ModelLoader],
test_dataloader_class: str = "torch.utils.data.DataLoader",
test_dataloader_kwargs: Optional[Dict] = None,
name: str = None,
strategy: Literal["ddp", "deepspeed", "horovod"] = 'ddp',
logger: Optional[Logger] = None,
checkpoints_location: str = "checkpoints",
name: str = None
) -> None:
super().__init__(model=model, name=name)
self.save_parameters(**self.locals2params(locals()))
self.model = self.model.eval()

# Train and validation dataloaders
self.test_dataloader_class = dynamically_import_class(test_dataloader_class)
test_dataloader_kwargs = (
test_dataloader_kwargs if test_dataloader_kwargs is not None else {}
if isinstance(config, dict):
self.config = TrainingConfiguration(**config)
else:
self.config = config
self.strategy = strategy
self.logger = logger
self.checkpoints_location = checkpoints_location

@property
def strategy(self) -> TorchDistributedStrategy:
"""Strategy currently in use."""
return self._strategy

@strategy.setter
def strategy(self, strategy: Union[str, TorchDistributedStrategy]) -> None:
if isinstance(strategy, TorchDistributedStrategy):
self._strategy = strategy
else:
self._strategy = self._detect_strategy(strategy)

@property
def device(self) -> str:
"""Current device from distributed strategy."""
return self.strategy.device()

def _detect_strategy(self, strategy: str) -> TorchDistributedStrategy:
if not distributed_resources_available():
print("WARNING: falling back to non-distributed strategy.")
dist_str = NonDistributedStrategy()
elif strategy == 'ddp':
dist_str = TorchDDPStrategy(backend='nccl')
elif strategy == 'horovod':
dist_str = HorovodStrategy()
elif strategy == 'deepspeed':
dist_str = DeepSpeedStrategy(backend='nccl')
else:
raise NotImplementedError(
f"Strategy '{strategy}' is not recognized/implemented.")

> **Reviewer comment (Collaborator):** Make sure to add a newline before the closing parenthesis here.
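For reference, the layout the reviewer is asking for would look like:

```python
raise NotImplementedError(
    f"Strategy '{strategy}' is not recognized/implemented."
)
```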

return dist_str

def _init_distributed_strategy(self) -> None:
if not self.strategy.is_initialized:
self.strategy.init()
> **Reviewer comment (Collaborator), on lines +121 to +155:** These methods could be inherited from the `TorchTrainer` to avoid code duplication. Some time ago @annaelisalappe, @jarlsondre, and I were brainstorming ways to move all these methods into a common parent class, from which both trainer(s) and predictor can inherit. It is an option, but not critical. The benefit is that we would reduce code duplication, resulting in easier code maintenance. On the other hand, we would add an element to the hierarchy.


> **Reviewer comment (Collaborator):** Another option could be adding a `predict` method to the `TorchTrainer` and avoiding all of the above. This is how Lightning does it.
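A rough sketch of the shared-parent idea, purely illustrative; the name `TorchDistributedMixin` is invented here and is not part of itwinai:

```python
from typing import Union

from itwinai.torch.distributed import TorchDistributedStrategy


class TorchDistributedMixin:
    """Hypothetical common parent holding the strategy plumbing that is
    currently duplicated between trainer(s) and predictor."""

    _strategy: TorchDistributedStrategy = None

    @property
    def strategy(self) -> TorchDistributedStrategy:
        """Strategy currently in use."""
        return self._strategy

    @strategy.setter
    def strategy(self, strategy: Union[str, TorchDistributedStrategy]) -> None:
        if isinstance(strategy, TorchDistributedStrategy):
            self._strategy = strategy
        else:
            # _detect_strategy would also move here from the predictor/trainer.
            self._strategy = self._detect_strategy(strategy)

    @property
    def device(self) -> str:
        """Current device from the distributed strategy."""
        return self.strategy.device()


# Both classes would then inherit the shared plumbing:
# class TorchTrainer(TorchDistributedMixin, Trainer): ...
# class TorchPredictor(TorchDistributedMixin, Predictor): ...
```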


def distribute_model(self) -> None:
"""
Distribute the torch model with the chosen strategy.

> **Reviewer comment (Collaborator), on lines +158 to +159:** Avoid a newline after the opening `"""`.
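That is, collapsing the summary onto the opening quotes:

```python
def distribute_model(self) -> None:
    """Distribute the torch model with the chosen strategy."""
```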

"""
if self.model is None:
raise ValueError(
"self.model is None! Mandatory constructor argument "

> **Reviewer comment (Collaborator):** Either move this check into the constructor, or remove "constructor" from the error message. It may be a bit misleading otherwise.

)
distribute_kwargs = {}
# Distributed model
self.model, _, _ = self.strategy.distributed(
self.model, None, None, **distribute_kwargs

> **Reviewer comment (Collaborator):** In this case you would be better off using keyword arguments. Just reading this, I don't know which parameters are being set to `None`, etc.
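A keyword-argument version might read as follows; the parameter names are assumptions, since the signature of `strategy.distributed` is not shown in this diff:

```python
# Hypothetical keyword names; check TorchDistributedStrategy.distributed.
self.model, _, _ = self.strategy.distributed(
    model=self.model,
    optimizer=None,
    lr_scheduler=None,
    **distribute_kwargs,
)
```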

)
self.test_dataloader_kwargs = clear_key(
test_dataloader_kwargs, "train_dataloader_kwargs", "dataset"

def create_dataloaders(
self,
inference_dataset: Dataset
) -> None:
"""
Create inference dataloader.

> **Reviewer comment (Collaborator), on lines +175 to +176:** Avoid a newline after the opening `"""`.


Args:
inference_dataset (Dataset): inference dataset object.
"""

self.inference_dataloader = self.strategy.create_dataloader(
dataset=inference_dataset,
batch_size=self.config.batch_size,
num_workers=self.config.num_workers_dataloader,
pin_memory=self.config.pin_gpu_memory,
generator=self.torch_rng,
shuffle=self.config.shuffle_test
)

@monitor_exec
def execute(
self,
test_dataset: Dataset,
inference_dataset: Dataset,
model: nn.Module = None,

> **Reviewer comment (Collaborator):** Since this is defaulted to `None`, you have to include that in the type annotation using `nn.Module | None`.
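Applied to this signature, the suggestion reads:

```python
def execute(
    self,
    inference_dataset: Dataset,
    model: nn.Module | None = None,
) -> Dict[str, Any]:
    ...
```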

) -> Dict[str, Any]:
"""Applies a torch model to a dataset for inference.

Args:
test_dataset (Dataset[str, Any]): each item in this dataset is a
inference_dataset (Dataset[str, Any]): each item in this dataset is a
couple (item_unique_id, item)
model (nn.Module, optional): torch model. Overrides the existing
model, if given. Defaults to None.
@@ -124,18 +206,27 @@ def execute(
Dict[str, Any]: maps each item ID to the corresponding predicted
value(s).
"""
self._init_distributed_strategy()
if model is not None:
# Overrides existing "internal" model
self.model = model

test_dataloader = self.test_dataloader_class(
test_dataset, **self.test_dataloader_kwargs
self.create_dataloaders(
inference_dataset=inference_dataset
)

> **Reviewer comment (Collaborator), on lines +214 to 216:** Why the newline here? Surely this fits on a single line?
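Presumably the intended single-line form:

```python
self.create_dataloaders(inference_dataset=inference_dataset)
```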


self.distribute_model()

if self.logger:
self.logger.create_logger_context(rank=self.strategy.global_rank())
hparams = self.config.model_dump()
hparams['distributed_strategy'] = self.strategy.__class__.__name__
self.logger.save_hyperparameters(hparams)

all_predictions = dict()
for samples_ids, samples in test_dataloader:
for ids, (samples_ids, samples) in enumerate(self.inference_dataloader):

> **Reviewer comment (Collaborator):** What are `ids` and `enumerate` used for?

with torch.no_grad():
pred = self.model(samples)
pred = self.model(samples.to(self.device))

> **Reviewer comment (Collaborator):** I can only leave a comment on this line, as GitHub is not letting me review lines that were not changed. I would suggest moving the block from line 215 to line 226 into a separate method called `predict`, which is the equivalent of the `train` method in the trainer.
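A sketch of that extraction (hypothetical; the per-item sparse/dense post-processing is summarized in a comment because those lines are collapsed in this diff view):

```python
def predict(self) -> Dict[str, Any]:
    """Hypothetical inference-loop counterpart of the trainer's train()."""
    all_predictions = dict()
    for samples_ids, samples in self.inference_dataloader:
        with torch.no_grad():
            pred = self.model(samples.to(self.device))
        pred = self.transform_predictions(pred)
        for idx, pre in zip(samples_ids, pred):
            # Convert each prediction as in execute() (dense/sparse handling).
            all_predictions[idx] = pre
    return all_predictions
```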

pred = self.transform_predictions(pred)
for idx, pre in zip(samples_ids, pred):
# For each item in the batch
@@ -144,11 +235,49 @@
else:
pre = pre.to_dense().tolist()
all_predictions[idx] = pre

if self.logger:
self.logger.destroy_logger_context()

self.strategy.clean_up()

return all_predictions

@abc.abstractmethod
def log(
self,
item: Union[Any, List[Any]],
identifier: Union[str, List[str]],
kind: str = 'metric',
step: Optional[int] = None,
batch_idx: Optional[int] = None,

> **Reviewer comment (Collaborator), on lines +248 to +252:** When doing type hints, we have decided to use `|` instead of `Union`, and `| None` instead of `Optional`, in the codebase.
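Rewritten with that convention, the signature would become:

```python
def log(
    self,
    item: Any | List[Any],
    identifier: str | List[str],
    kind: str = 'metric',
    step: int | None = None,
    batch_idx: int | None = None,
    **kwargs,
) -> None:
    ...
```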

**kwargs
) -> None:
"""Log ``item`` with ``identifier`` name of ``kind`` type at ``step``
time step.

Args:
item (Union[Any, List[Any]]): element to be logged (e.g., metric).
identifier (Union[str, List[str]]): unique identifier for the
element to log(e.g., name of a metric).
kind (str, optional): type of the item to be logged. Must be one
among the list of self.supported_types. Defaults to 'metric'.
step (Optional[int], optional): logging step. Defaults to None.
batch_idx (Optional[int], optional): DataLoader batch counter
(i.e., batch idx), if available. Defaults to None.
"""
if self.logger:
self.logger.log(
item=item,
identifier=identifier,
kind=kind,
step=step,
batch_idx=batch_idx,
**kwargs
)


> **Reviewer comment (Collaborator):** Again, I cannot comment on unchanged lines, but it would be nice to remove the newline after the opening `"""` in the docstring below. It is something that we are trying to gradually clean up.

def transform_predictions(self, batch: Batch) -> Batch:
"""Post-process the predictions of the torch model (e.g., apply
"""

> **Reviewer comment (Collaborator):** Remove the newline here.

Post-process the predictions of the torch model (e.g., apply
threshold in case of multi-label classifier).
"""

3 changes: 2 additions & 1 deletion tests/use-cases/test_mnist.py
@@ -87,7 +87,8 @@ def test_mnist_inference_torch(torch_env, install_requirements):
)
with tempfile.TemporaryDirectory() as temp_dir:
# Create fake inference dataset and checkpoint
generate_model_cmd = f"{torch_env}/bin/python {exec} " f"--root {temp_dir}"
generate_model_cmd = (f"{torch_env}/bin/python {exec} "

> **Reviewer comment (Collaborator):** Generally, when we have to use multiple lines for the contents of some parentheses (`(`, `[`, and `{`), we use newlines after the opening parenthesis and before the closing parenthesis. In this case, the code would thus be:
>
> ```python
> generate_model_cmd = (
>     ...
> )
> ```

f"--root {temp_dir}")
subprocess.run(generate_model_cmd.split(), check=True, cwd=temp_dir)

# Running inference
10 changes: 7 additions & 3 deletions use-cases/mnist/torch/config.yaml
@@ -98,11 +98,15 @@ inference_pipeline:
class_path: itwinai.torch.inference.TorchModelLoader
init_args:
model_uri: ${inference_model_mlflow_uri}
test_dataloader_kwargs:
batch_size: ${batch_size}
config:
batch_size: 32
num_workers_dataloader: 4
pin_gpu_memory: true
shuffle_test: false
strategy: ${strategy}

- class_path: saver.TorchMNISTLabelSaver
init_args:
save_dir: ${predictions_dir}
predictions_file: ${predictions_file}
class_labels: ${class_labels}
class_labels: ${class_labels}
59 changes: 54 additions & 5 deletions use-cases/xtclim/README.md
@@ -24,14 +24,63 @@ The file `train.py` trains the network. Caution: It will overwrite the weights o

The `anomaly.py` file evaluates the network on the available datasets - train, test, and projection.

## How to launch pipeline
## Installation

The config file `pipeline.yaml` contains all the steps to execute the workflow. You can launch it from the root of the repository with:
Please follow the documentation to install the itwinai environment.
After that, install the required libraries within the itwinai environment with:

```bash
python train.py -p pipeline.yaml
pip install -r Requirements.txt

> **Reviewer comment (Collaborator):** To be consistent with the other use cases (and, more generally, with pip standards), it would be good to rename `Requirements.txt` to `requirements.txt`.

```

## How to launch pipeline locally

The config file `pipeline.yaml` contains all the steps to execute the workflow.
This file also contains all the seasons, and a separate run is launched for each season.
You can launch the pipeline through `train.py` from the root of the repository with:

```bash
python train.py

```

## TODOs
Integration of post-processing step + distributed strategies
## How to launch pipeline on an HPC system

The `startscript` job script can be used to launch a pipeline with SLURM on an HPC system.
Follow these steps to export the environment variables required by the script:

```bash
# Distributed training with torch DistributedDataParallel
PYTHON_VENV="../../envAI_hdfml"

> **Reviewer comment (Collaborator):** Keep in mind that we have changed the default name of the venv to `.venv`, as this generalizes better across HPC systems. This happened when we migrated to the uv package manager. You are free to use different names locally, but then we suggest that you symlink it to `.venv`. Check out our document on uv; I believe it is called `uv-tutorial.md`.

DIST_MODE="ddp"
RUN_NAME="ddp-cerfacs"
sbatch --export=ALL,DIST_MODE="$DIST_MODE",RUN_NAME="$RUN_NAME",PYTHON_VENV="$PYTHON_VENV" \
startscript
```

The results and/or errors are available in `job.out` and `job.err` log files.
Training and inference steps are defined in the pipeline, and distributed resources
are exploited in both steps.

With the MLFlow logger, the logs can be visualized in the MLFlow UI:

```bash
mlflow ui --backend-store-uri mllogs/mlflow

# In background
mlflow ui --backend-store-uri mllogs/mlflow > /dev/null 2>&1 &
```

### Hyperparameter Optimization (HPO)

The repository also provides functionality to perform HPO with Ray. With HPO,
multiple trials with different hyperparameter configurations are run on a distributed
infrastructure, typically in an HPC environment. This helps find the optimal
configuration, i.e. the one that minimizes/maximizes the loss for the investigated network.
The `hpo.py` file contains the implementation, which launches the `pipeline.yaml` pipeline.
To launch an HPO experiment, simply run:

```bash
sbatch slurm_ray.sh
```

The arguments parsed by `hpo.py` can be changed to customize which parameters
are considered in the HPO process.