diff --git a/shpc/client/__init__.py b/shpc/client/__init__.py index 4309ae47e..6844ea3dc 100644 --- a/shpc/client/__init__.py +++ b/shpc/client/__init__.py @@ -130,6 +130,41 @@ def get_parser(): action="store_true", ) + # Reinstall already installed recipes + reinstall = subparsers.add_parser( + "reinstall", + description="reinstall a recipe.", + formatter_class=argparse.RawTextHelpFormatter, + ) + reinstall.add_argument( + "reinstall_recipe", + nargs="?", + help="recipe to reinstall\nshpc reinstall python\nshpc reinstall python:3.9.5-alpine", + default=None, + ) + reinstall.add_argument( + "--all", + dest="all", + help="reinstall all currently installed modules.", + action="store_true", + ) + reinstall.add_argument( + "--ignore-missing", + "-i", + dest="ignore_missing", + help="Ignore and leave intact the versions that don't exist in the registry anymore.", + default=False, + action="store_true", + ) + reinstall.add_argument( + "--uninstall-missing", + "-u", + dest="uninstall_missing", + help="Uninstall the versions that don't exist in the registry anymore.", + default=False, + action="store_true", + ) + # List installed modules listing = subparsers.add_parser("list", description="list installed modules.") listing.add_argument("pattern", help="filter to a pattern", nargs="?") @@ -372,6 +407,7 @@ def get_parser(): inspect, install, listing, + reinstall, shell, test, uninstall, @@ -491,6 +527,8 @@ def help(return_code=0): from .get import main elif args.command == "install": from .install import main + elif args.command == "reinstall": + from .reinstall import main elif args.command == "inspect": from .inspect import main elif args.command == "list": diff --git a/shpc/client/install.py b/shpc/client/install.py index 57aca791f..fd0e65d6d 100644 --- a/shpc/client/install.py +++ b/shpc/client/install.py @@ -27,7 +27,6 @@ def main(args, parser, extra, subparser): # And do the install cli.install( args.install_recipe, - force=args.force, container_image=args.container_image, keep_path=args.keep_path, ) @@ -36,5 +35,4 @@ def main(args, parser, extra, subparser): cli.settings.default_view, args.install_recipe, force=args.force, - container_image=args.container_image, ) diff --git a/shpc/client/reinstall.py b/shpc/client/reinstall.py new file mode 100644 index 000000000..d0187ecfa --- /dev/null +++ b/shpc/client/reinstall.py @@ -0,0 +1,46 @@ +__author__ = "Vanessa Sochat" +__copyright__ = "Copyright 2021-2022, Vanessa Sochat" +__license__ = "MPL 2.0" + +import shpc.utils +from shpc.logger import logger + + +def main(args, parser, extra, subparser): + + from shpc.main import get_client + + shpc.utils.ensure_no_extra(extra) + + cli = get_client( + quiet=args.quiet, + settings_file=args.settings_file, + module_sys=args.module_sys, + container_tech=args.container_tech, + ) + + # Update config settings on the fly + cli.settings.update_params(args.config_params) + + # It doesn't make sense to give a module name and --all + if args.reinstall_recipe and args.all: + logger.exit("Conflicting arguments reinstall_recipe and --all, choose one.") + # One option must be present + if not args.reinstall_recipe and not args.all: + logger.exit("Missing arguments: provide reinstall_recipe or --all.") + if args.ignore_missing and args.uninstall_missing: + logger.exit( + "Conflicting arguments --ignore-missing and --uninstall-missing, choose one." + ) + + # And do the reinstall + cli.reinstall( + args.reinstall_recipe, + when_missing=( + "ignore" + if args.ignore_missing + else "uninstall" + if args.uninstall_missing + else None + ), + ) diff --git a/shpc/main/modules/base.py b/shpc/main/modules/base.py index 30b9b0b39..a468799a8 100644 --- a/shpc/main/modules/base.py +++ b/shpc/main/modules/base.py @@ -375,7 +375,12 @@ def get_module(self, name, container_image=None, keep_path=False): return module def install( - self, name, force=False, container_image=None, keep_path=False, **kwargs + self, + name, + allow_reinstall=False, + container_image=None, + keep_path=False, + **kwargs ): """ Given a unique resource identifier, install a recipe. @@ -383,7 +388,6 @@ def install( For lmod, this means creating a subfolder in modules, pulling the container to it, and writing a module file there. We've already grabbed the name from docker (which is currently the only supported). - "force" is currently not used. """ # Create a new module module = self.get_module( @@ -393,6 +397,23 @@ def install( # We always load overrides for an install module.load_override_file() + # Check previous installations of this module + if os.path.exists(module.module_dir): + if not allow_reinstall: + logger.exit( + "%s is already installed. Do `shpc reinstall` to proceed with a reinstallation." + % module.tagged_name + ) + logger.info("%s is already installed. Reinstalling." % module.tagged_name) + # Don't explicitly remove the container, since we still need it, + # though it may still happen if shpc is configured to store + # containers and modules in the same directory + self._uninstall( + module.module_dir, + self.settings.module_base, + "$module_base/%s" % module.name, + ) + # Create the module and container directory utils.mkdirp([module.module_dir, module.container_dir]) @@ -421,12 +442,12 @@ def install( logger.info("Module %s was created." % module.tagged_name) return module.container_path - def view_install(self, view_name, name, force=False, container_image=None): + def view_install(self, view_name, name, force=False): """ Install a module in a view. The module must already be installed. Set "force" to True to allow overwriting existing symlinks. """ - module = self.get_module(name, container_image=container_image) + module = self.get_module(name) # A view is a symlink under views_base/$view/$module if view_name not in self.views: @@ -440,3 +461,58 @@ def view_install(self, view_name, name, force=False, container_image=None): # Don't continue if it exists, unless force is True view.confirm_install(module.module_dir, force=force) view.install(module.module_dir) + + def reinstall(self, name, when_missing=None): + """ + Reinstall the module, or all modules + """ + if name: + module_name, _, version = name.partition(":") + # Find all the versions currently installed + installed_modules = self._get_module_lookup( + self.settings.module_base, self.modulefile, module_name + ) + if (module_name not in installed_modules) or ( + version and version not in installed_modules[module_name] + ): + logger.exit("%s is not installed. Nothing to reinstall." % name) + versions = [version] if version else installed_modules[module_name] + # Reinstall the required version(s) one by one + for version in versions: + self._reinstall(module_name, version, when_missing) + else: + # Reinstall everything that is currently installed + installed_modules = self._get_module_lookup( + self.settings.module_base, self.modulefile + ) + for module_name, versions in installed_modules.items(): + for version in versions: + self._reinstall(module_name, version, when_missing) + + def _reinstall(self, module_name, version, when_missing): + """ + Reinstall (and possibly upgrade) all the current modules, possibly filtered by pattern. + """ + result = self.registry.find(module_name) + if result: + config = container.ContainerConfig(result) + if version in config.tags: + return self.install(module_name + ":" + version, allow_reinstall=True) + else: + missing = module_name + ":" + version + else: + missing = module_name + + if when_missing: + if when_missing == "ignore": + logger.info( + "%s is not in the Registry any more. Ignoring as instructed." + % missing + ) + elif when_missing == "uninstall": + self.uninstall(module_name + ":" + version, force=True) + else: + logger.exit( + "%s is not in the Registry any more. Add --uninstall-missing or --ignore-missing." + % missing + ) diff --git a/shpc/tests/test_client.py b/shpc/tests/test_client.py index c1b3e31e8..27a1f93dd 100644 --- a/shpc/tests/test_client.py +++ b/shpc/tests/test_client.py @@ -51,8 +51,6 @@ def test_install_get(tmp_path, module_sys, module_file, container_tech, remote): assert client.get("python:3.9.2-alpine") - client.install("python:3.9.2-alpine") - @pytest.mark.parametrize( "module_sys,module_file,remote", @@ -377,3 +375,41 @@ def test_add(tmp_path, module_sys, remote): client.get("dinosaur/salad:latest") client.install("dinosaur/salad:latest") assert client.get("dinosaur/salad:latest") + + +@pytest.mark.parametrize( + "module_sys,module_file,container_tech,remote", + [ + ("lmod", "module.lua", "singularity", False), + ("lmod", "module.lua", "podman", False), + ("tcl", "module.tcl", "singularity", False), + ("tcl", "module.tcl", "podman", False), + ("lmod", "module.lua", "singularity", True), + ("lmod", "module.lua", "podman", True), + ("tcl", "module.tcl", "singularity", True), + ("tcl", "module.tcl", "podman", True), + ], +) +def test_reinstall(tmp_path, module_sys, module_file, container_tech, remote): + """ + Test install and reinstall + """ + client = init_client(str(tmp_path), module_sys, container_tech, remote=remote) + + # Install known tag + client.install("python:3.9.2-alpine") + module_dir = os.path.join(client.settings.module_base, "python", "3.9.2-alpine") + env_file = os.path.join(module_dir, client.settings.environment_file) + dummy = os.path.join(module_dir, "dummy.sh") + # Ensure the content is initially as expected + assert os.path.exists(env_file) + assert not os.path.exists(dummy) + # Modify it + os.unlink(env_file) + shpc.utils.write_file(module_file, "") + assert not os.path.exists(env_file) + assert os.path.exists(dummy) + # The reinstallation should restore everything + client.install("python:3.9.2-alpine", force=True) + assert os.path.exists(env_file) + assert not os.path.exists(dummy)