Skip to content

Commit

Permalink
fix: modify sysargv with subcommands
Browse files Browse the repository at this point in the history
  • Loading branch information
JesperDramsch committed Sep 25, 2024
1 parent e808f64 commit 8ee9a28
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions src/anemoi/training/commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,40 @@ class Train(Command):
def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
return parser

@staticmethod
def run(args: list[str], unknown_args: list[str] | None = None) -> None:
del args
def run(self, args: argparse.Namespace, unknown_args: list[str] | None = None) -> None:

# Merge the known subcommands with a non-whitespace character for hydra
new_sysargv = self._merge_sysargv(args)

# Add the unknown arguments (belonging to hydra) to sys.argv
if unknown_args is not None:
sys.argv = [sys.argv[0], *unknown_args]
sys.argv = [new_sysargv, *unknown_args]
else:
sys.argv = [sys.argv[0]]
sys.argv = [new_sysargv]

# Import and run the training command
LOGGER.info("Running anemoi training command with overrides: %s", sys.argv[1:])
from anemoi.training.train.train import main as anemoi_train

anemoi_train()

def _merge_sysargv(self, args: argparse.Namespace) -> str:
"""Merge the sys.argv with the known subcommands to pass to hydra.
Parameters
----------
args : argparse.Namespace
args from the command line
Returns
-------
str
Modified sys.argv as string
"""
modified_sysargv = f"{sys.argv[0]} {args.command}"
if hasattr(args, "subcommand"):
modified_sysargv += f" {args.subcommand}"
return modified_sysargv


command = Train

0 comments on commit 8ee9a28

Please sign in to comment.