Skip to content

Commit

Permalink
[fix] Capture Anemoi Training subcommands in MLFlow (#61)
Browse files Browse the repository at this point in the history
* fix: modify sysargv with subcommands

* docs: add changelog
  • Loading branch information
JesperDramsch authored Sep 26, 2024
1 parent e808f64 commit 98b506d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Keep it human-readable, your future self will thank you!

- Fix `TypeError` raised when trying to JSON serialise `datetime.timedelta` object - [#43](https://github.com/ecmwf/anemoi-training/pull/43)
- Bugfixes for CI (#56)
- Show correct subcommand in MLFlow - Addresses [#39](https://github.com/ecmwf/anemoi-training/issues/39) in [#61](https://github.com/ecmwf/anemoi-training/pull/61)

### Changed

Expand Down
31 changes: 26 additions & 5 deletions src/anemoi/training/commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,40 @@ class Train(Command):
def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
return parser

@staticmethod
def run(args: list[str], unknown_args: list[str] | None = None) -> None:
del args
def run(self, args: argparse.Namespace, unknown_args: list[str] | None = None) -> None:

# Merge the known subcommands with a non-whitespace character for hydra
new_sysargv = self._merge_sysargv(args)

# Add the unknown arguments (belonging to hydra) to sys.argv
if unknown_args is not None:
sys.argv = [sys.argv[0], *unknown_args]
sys.argv = [new_sysargv, *unknown_args]
else:
sys.argv = [sys.argv[0]]
sys.argv = [new_sysargv]

# Import and run the training command
LOGGER.info("Running anemoi training command with overrides: %s", sys.argv[1:])
from anemoi.training.train.train import main as anemoi_train

anemoi_train()

def _merge_sysargv(self, args: argparse.Namespace) -> str:
"""Merge the sys.argv with the known subcommands to pass to hydra.
Parameters
----------
args : argparse.Namespace
args from the command line
Returns
-------
str
Modified sys.argv as string
"""
modified_sysargv = f"{sys.argv[0]} {args.command}"
if hasattr(args, "subcommand"):
modified_sysargv += f" {args.subcommand}"
return modified_sysargv


command = Train

0 comments on commit 98b506d

Please sign in to comment.