diff --git a/src/anemoi/training/commands/train.py b/src/anemoi/training/commands/train.py index 88cee39a..46052b90 100644 --- a/src/anemoi/training/commands/train.py +++ b/src/anemoi/training/commands/train.py @@ -29,19 +29,40 @@ class Train(Command): def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: return parser - @staticmethod - def run(args: list[str], unknown_args: list[str] | None = None) -> None: - del args + def run(self, args: argparse.Namespace, unknown_args: list[str] | None = None) -> None: + + # Merge the known subcommands with a non-whitespace character for hydra + new_sysargv = self._merge_sysargv(args) + # Add the unknown arguments (belonging to hydra) to sys.argv if unknown_args is not None: - sys.argv = [sys.argv[0], *unknown_args] + sys.argv = [new_sysargv, *unknown_args] else: - sys.argv = [sys.argv[0]] + sys.argv = [new_sysargv] + # Import and run the training command LOGGER.info("Running anemoi training command with overrides: %s", sys.argv[1:]) from anemoi.training.train.train import main as anemoi_train anemoi_train() + def _merge_sysargv(self, args: argparse.Namespace) -> str: + """Merge the sys.argv with the known subcommands to pass to hydra. + + Parameters + ---------- + args : argparse.Namespace + args from the command line + + Returns + ------- + str + Modified sys.argv as string + """ + modified_sysargv = f"{sys.argv[0]} {args.command}" + if hasattr(args, "subcommand"): + modified_sysargv += f" {args.subcommand}" + return modified_sysargv + command = Train