diff --git a/src/snowflake/cli/_app/__main__.py b/src/snowflake/cli/_app/__main__.py index d0777941e3..80fcb16465 100644 --- a/src/snowflake/cli/_app/__main__.py +++ b/src/snowflake/cli/_app/__main__.py @@ -16,11 +16,11 @@ import sys -from snowflake.cli._app.cli_app import app_factory +from snowflake.cli._app.cli_app import CliAppFactory def main(*args): - app = app_factory() + app = CliAppFactory().create_or_get_app() app(*args) diff --git a/src/snowflake/cli/_app/cli_app.py b/src/snowflake/cli/_app/cli_app.py index e85521cff9..71fee6097f 100644 --- a/src/snowflake/cli/_app/cli_app.py +++ b/src/snowflake/cli/_app/cli_app.py @@ -18,13 +18,12 @@ import os import platform import sys -from dataclasses import dataclass from pathlib import Path from typing import Optional import click import typer -from click import Context +from click import Context as ClickContext from snowflake.cli import __about__ from snowflake.cli._app.api_impl.plugin.plugin_config_provider_impl import ( PluginConfigProviderImpl, @@ -44,7 +43,7 @@ show_new_version_banner_callback, ) from snowflake.cli.api import Api, api_provider -from snowflake.cli.api.config import config_init +from snowflake.cli.api.config import config_init, get_feature_flags_section from snowflake.cli.api.output.formats import OutputFormat from snowflake.cli.api.output.types import CollectionResult from snowflake.cli.api.secure_path import SecurePath @@ -52,25 +51,6 @@ log = logging.getLogger(__name__) -_api = Api(plugin_config_provider=PluginConfigProviderImpl()) -api_provider.register_api(_api) - -_commands_registration = CommandsRegistrationWithCallbacks(_api.plugin_config_provider) - - -@dataclass -class AppContextHolder: - # needed to access the context from tests - app_context: Optional[Context] = None - - -app_context_holder = AppContextHolder() - - -def _exit_with_cleanup(): - _commands_registration.reset_running_instance_registration_state() - raise typer.Exit() - def _do_not_execute_on_completion(callback): def enriched_callback(value): @@ -81,197 +61,227 @@ def enriched_callback(value): return enriched_callback -def _commands_registration_callback(value: bool): - if value: - _commands_registration.register_commands_if_ready_and_not_registered_yet() - # required to make the tests working - # because a single test can execute multiple commands using always the same "app" instance - _commands_registration.reset_running_instance_registration_state() - app_context_holder.app_context = click.get_current_context() - - -@_commands_registration.before -def _config_init_callback(configuration_file: Optional[Path]): - config_init(configuration_file) - - -@_commands_registration.before -def _disable_external_command_plugins_callback(value: bool): - if value: - _commands_registration.disable_external_command_plugins() - - -@_do_not_execute_on_completion -@_commands_registration.after -def _docs_callback(value: bool): - if value: - ctx = click.get_current_context() - generate_docs(SecurePath("gen_docs"), ctx.command) - _exit_with_cleanup() - - -@_do_not_execute_on_completion -@_commands_registration.after -def _help_callback(value: bool): - if value: - ctx = click.get_current_context() - typer.echo(ctx.get_help()) - _exit_with_cleanup() - - -@_do_not_execute_on_completion -@_commands_registration.after -def _commands_structure_callback(value: bool): - if value: - ctx = click.get_current_context() - generate_commands_structure(ctx.command).print_node() - _exit_with_cleanup() - - -@_do_not_execute_on_completion -def _version_callback(value: bool): - if value: - print_result(MessageResult(f"Snowflake CLI version: {__about__.VERSION}")) - _exit_with_cleanup() - - -from snowflake.cli.api.config import get_feature_flags_section - - -@_do_not_execute_on_completion -def _info_callback(value: bool): - if value: - result = CollectionResult( - [ - {"key": "version", "value": __about__.VERSION}, - { - "key": "default_config_file_path", - "value": str(CONFIG_MANAGER.file_path), - }, - {"key": "python_version", "value": sys.version}, - {"key": "system_info", "value": platform.platform()}, - {"key": "feature_flags", "value": get_feature_flags_section()}, - {"key": "SNOWFLAKE_HOME", "value": os.getenv("SNOWFLAKE_HOME")}, - ], +class CliAppFactory: + def __init__(self): + api = Api(plugin_config_provider=PluginConfigProviderImpl()) + self._api = api + self._commands_registration = CommandsRegistrationWithCallbacks( + api.plugin_config_provider ) - print_result(result, output_format=OutputFormat.JSON) - _exit_with_cleanup() - - -def app_factory() -> SnowCliMainTyper: - app = SnowCliMainTyper() - new_version_msg = get_new_version_msg() - - @app.callback( - invoke_without_command=True, - epilog=new_version_msg, - result_callback=show_new_version_banner_callback(new_version_msg), - add_help_option=False, # custom_help option added below - help=f"Snowflake CLI tool for developers [v{__about__.VERSION}]", - ) - def default( - ctx: typer.Context, - # We need a custom help option with _help_callback called after command registration - # to have all commands visible in the help. - # This is required since click 8.1.8, when the default help option - # has started to being executed before our eager options, including command registration. - custom_help: bool = typer.Option( - None, - "--help", - "-h", - help="Show this message and exit.", - callback=_help_callback, - is_eager=True, - ), - version: bool = typer.Option( - None, - "--version", - help="Shows version of the Snowflake CLI", - callback=_version_callback, - is_eager=True, - ), - docs: bool = typer.Option( - None, - "--docs", - hidden=True, - help="Generates Snowflake CLI documentation", - callback=_docs_callback, - is_eager=True, - ), - structure: bool = typer.Option( - None, - "--structure", - hidden=True, - help="Prints Snowflake CLI structure of commands", - callback=_commands_structure_callback, - is_eager=True, - ), - info: bool = typer.Option( - None, - "--info", - help="Shows information about the Snowflake CLI", - callback=_info_callback, - ), - configuration_file: Path = typer.Option( - None, - "--config-file", - help="Specifies Snowflake CLI configuration file that should be used", - exists=True, - dir_okay=False, - is_eager=True, - callback=_config_init_callback, - ), - pycharm_debug_library_path: str = typer.Option( - None, - "--pycharm-debug-library-path", - hidden=True, - ), - pycharm_debug_server_host: str = typer.Option( - "localhost", - "--pycharm-debug-server-host", - hidden=True, - ), - pycharm_debug_server_port: int = typer.Option( - 12345, - "--pycharm-debug-server-port", - hidden=True, - ), - disable_external_command_plugins: bool = typer.Option( - None, - "--disable-external-command-plugins", - help="Disable external command plugins", - callback=_disable_external_command_plugins_callback, - is_eager=True, - hidden=True, - ), - # THIS OPTION SHOULD BE THE LAST OPTION IN THE LIST! - # --- - # This is a hidden artificial option used only to guarantee execution of commands registration - # and make this guaranty not dependent on other callbacks. - # Commands registration is invoked as soon as all callbacks - # decorated with "_commands_registration.before" are executed - # but if there are no such callbacks (at the result of possible future changes) - # then we need to invoke commands registration manually. - # - # This option is also responsible for resetting registration state for test purposes. - commands_registration: bool = typer.Option( - True, - "--commands-registration", - help="Commands registration", - hidden=True, - is_eager=True, - callback=_commands_registration_callback, - ), - ) -> None: - """ - Snowflake CLI tool for developers. - """ - if not ctx.invoked_subcommand: - typer.echo(ctx.get_help()) - setup_pycharm_remote_debugger_if_provided( - pycharm_debug_library_path=pycharm_debug_library_path, - pycharm_debug_server_host=pycharm_debug_server_host, - pycharm_debug_server_port=pycharm_debug_server_port, + api_provider.register_api(api) + self._app: Optional[SnowCliMainTyper] = None + self._click_context: Optional[ClickContext] = None + + def _exit_with_cleanup(self): + self._commands_registration.reset_running_instance_registration_state() + raise typer.Exit() + + def _commands_registration_callback(self): + def callback(value: bool): + self._click_context = click.get_current_context() + if value: + self._commands_registration.register_commands_from_plugins() + # required to make the tests working + # because a single test can execute multiple commands using always the same "app" instance + self._commands_registration.reset_running_instance_registration_state() + + return callback + + @staticmethod + def _config_init_callback(): + def callback(configuration_file: Optional[Path]): + config_init(configuration_file) + + return callback + + def _disable_external_command_plugins_callback(self): + def callback(value: bool): + if value: + self._commands_registration.disable_external_command_plugins() + + return callback + + def _docs_callback(self): + @_do_not_execute_on_completion + @self._commands_registration.after + def callback(value: bool): + if value: + ctx = click.get_current_context() + generate_docs(SecurePath("gen_docs"), ctx.command) + self._exit_with_cleanup() + + return callback + + def _help_callback(self): + @_do_not_execute_on_completion + @self._commands_registration.after + def callback(value: bool): + if value: + ctx = click.get_current_context() + typer.echo(ctx.get_help()) + self._exit_with_cleanup() + + return callback + + def _commands_structure_callback(self): + @_do_not_execute_on_completion + @self._commands_registration.after + def callback(value: bool): + if value: + ctx = click.get_current_context() + generate_commands_structure(ctx.command).print_node() + self._exit_with_cleanup() + + return callback + + def _version_callback(self): + @_do_not_execute_on_completion + def callback(value: bool): + if value: + print_result( + MessageResult(f"Snowflake CLI version: {__about__.VERSION}") + ) + self._exit_with_cleanup() + + return callback + + def _info_callback(self): + @_do_not_execute_on_completion + def callback(value: bool): + if value: + result = CollectionResult( + [ + {"key": "version", "value": __about__.VERSION}, + { + "key": "default_config_file_path", + "value": str(CONFIG_MANAGER.file_path), + }, + {"key": "python_version", "value": sys.version}, + {"key": "system_info", "value": platform.platform()}, + {"key": "feature_flags", "value": get_feature_flags_section()}, + {"key": "SNOWFLAKE_HOME", "value": os.getenv("SNOWFLAKE_HOME")}, + ], + ) + print_result(result, output_format=OutputFormat.JSON) + self._exit_with_cleanup() + + return callback + + def create_or_get_app(self) -> SnowCliMainTyper: + if self._app: + return self._app + + app = SnowCliMainTyper() + new_version_msg = get_new_version_msg() + + @app.callback( + invoke_without_command=True, + epilog=new_version_msg, + result_callback=show_new_version_banner_callback(new_version_msg), + add_help_option=False, # custom_help option added below + help=f"Snowflake CLI tool for developers [v{__about__.VERSION}]", ) - - return app + def default( + ctx: typer.Context, + # We need a custom help option with _help_callback called after command registration + # to have all commands visible in the help. + # This is required since click 8.1.8, when the default help option + # has started to being executed before our eager options, including command registration. + custom_help: bool = typer.Option( + None, + "--help", + "-h", + help="Show this message and exit.", + callback=self._help_callback, + is_eager=True, + ), + version: bool = typer.Option( + None, + "--version", + help="Shows version of the Snowflake CLI", + callback=self._version_callback(), + is_eager=True, + ), + docs: bool = typer.Option( + None, + "--docs", + hidden=True, + help="Generates Snowflake CLI documentation", + callback=self._docs_callback(), + is_eager=True, + ), + structure: bool = typer.Option( + None, + "--structure", + hidden=True, + help="Prints Snowflake CLI structure of commands", + callback=self._commands_structure_callback(), + is_eager=True, + ), + info: bool = typer.Option( + None, + "--info", + help="Shows information about the Snowflake CLI", + callback=self._info_callback(), + ), + configuration_file: Path = typer.Option( + None, + "--config-file", + help="Specifies Snowflake CLI configuration file that should be used", + exists=True, + dir_okay=False, + is_eager=True, + callback=self._config_init_callback(), + ), + pycharm_debug_library_path: str = typer.Option( + None, + "--pycharm-debug-library-path", + hidden=True, + ), + pycharm_debug_server_host: str = typer.Option( + "localhost", + "--pycharm-debug-server-host", + hidden=True, + ), + pycharm_debug_server_port: int = typer.Option( + 12345, + "--pycharm-debug-server-port", + hidden=True, + ), + disable_external_command_plugins: bool = typer.Option( + None, + "--disable-external-command-plugins", + help="Disable external command plugins", + callback=self._disable_external_command_plugins_callback(), + is_eager=True, + hidden=True, + ), + # THIS OPTION SHOULD BE THE LAST OPTION IN THE LIST! + # --- + # This is a hidden artificial option used only to guarantee execution of commands registration. + # This option is also responsible for resetting registration state for test purposes. + commands_registration: bool = typer.Option( + True, + "--commands-registration", + help="Commands registration", + hidden=True, + is_eager=True, + callback=self._commands_registration_callback(), + ), + ) -> None: + """ + Snowflake CLI tool for developers. + """ + if not ctx.invoked_subcommand: + typer.echo(ctx.get_help()) + setup_pycharm_remote_debugger_if_provided( + pycharm_debug_library_path=pycharm_debug_library_path, + pycharm_debug_server_host=pycharm_debug_server_host, + pycharm_debug_server_port=pycharm_debug_server_port, + ) + + self._app = app + return app + + def get_click_context(self): + return self._click_context diff --git a/src/snowflake/cli/_app/commands_registration/commands_registration_with_callbacks.py b/src/snowflake/cli/_app/commands_registration/commands_registration_with_callbacks.py index e344ae90a4..43c4ea85f5 100644 --- a/src/snowflake/cli/_app/commands_registration/commands_registration_with_callbacks.py +++ b/src/snowflake/cli/_app/commands_registration/commands_registration_with_callbacks.py @@ -21,7 +21,6 @@ load_builtin_and_external_command_plugins, load_only_builtin_command_plugins, ) -from snowflake.cli._app.commands_registration.threadsafe import ThreadsafeCounter from snowflake.cli._app.commands_registration.typer_registration import ( register_commands_from_plugins, ) @@ -36,27 +35,13 @@ class CommandRegistrationConfig: class CommandsRegistrationWithCallbacks: def __init__(self, plugin_config_provider: PluginConfigProvider): self._plugin_config_provider = plugin_config_provider - self._counter_of_callbacks_required_before_registration: ThreadsafeCounter = ( - ThreadsafeCounter(0) - ) - self._counter_of_callbacks_invoked_before_registration: ThreadsafeCounter = ( - ThreadsafeCounter(0) - ) self._callbacks_after_registration: List[Callable[[], None]] = [] self._commands_registration_config: CommandRegistrationConfig = ( CommandRegistrationConfig(enable_external_command_plugins=True) ) self._commands_already_registered: bool = False - def register_commands_if_ready_and_not_registered_yet(self): - all_required_callbacks_executed = ( - self._counter_of_callbacks_required_before_registration.value - == self._counter_of_callbacks_invoked_before_registration.value - ) - if all_required_callbacks_executed and not self._commands_already_registered: - self._register_commands_from_plugins() - - def _register_commands_from_plugins(self) -> None: + def register_commands_from_plugins(self) -> None: if self._commands_registration_config.enable_external_command_plugins: self._register_builtin_and_enabled_external_plugin_commands() else: @@ -83,15 +68,6 @@ def _register_builtin_and_enabled_external_plugin_commands(self): def disable_external_command_plugins(self): self._commands_registration_config.enable_external_command_plugins = False - def before(self, callback): - def enriched_callback(value): - self._counter_of_callbacks_invoked_before_registration.increment() - callback(value) - self.register_commands_if_ready_and_not_registered_yet() - - self._counter_of_callbacks_required_before_registration.increment() - return enriched_callback - def after(self, callback): def delayed_callback(value): self._callbacks_after_registration.append(lambda: callback(value)) @@ -99,7 +75,5 @@ def delayed_callback(value): return delayed_callback def reset_running_instance_registration_state(self): - self._commands_already_registered = False - self._counter_of_callbacks_invoked_before_registration.set(0) self._callbacks_after_registration.clear() self._commands_registration_config.enable_external_command_plugins = True diff --git a/tests/test_connection.py b/tests/test_connection.py index 32cf54a2ee..f4d3531cb9 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1159,6 +1159,7 @@ def test_new_connection_is_added_to_connections_toml( @mock.patch( "snowflake.cli._plugins.connection.commands.connector.auth.get_token_from_private_key" ) +@mock.patch.dict(os.environ, {}, clear=True) def test_generate_jwt_without_passphrase( mocked_get_token, runner, named_temporary_file ): @@ -1190,6 +1191,7 @@ def test_generate_jwt_without_passphrase( @mock.patch( "snowflake.cli._plugins.connection.commands.connector.auth.get_token_from_private_key" ) +@mock.patch.dict(os.environ, {}, clear=True) def test_generate_jwt_with_passphrase( mocked_get_token, runner, named_temporary_file, passphrase ): @@ -1269,6 +1271,7 @@ def test_generate_jwt_with_pass_phrase_from_env( @mock.patch( "snowflake.cli._plugins.connection.commands.connector.auth.get_token_from_private_key" ) +@mock.patch.dict(os.environ, {}, clear=True) def test_generate_jwt_uses_config(mocked_get_token, runner, named_temporary_file): mocked_get_token.return_value = "funny token" diff --git a/tests/test_docs_generation_output.py b/tests/test_docs_generation_output.py index e0edf7a5a0..0f7449d8ed 100644 --- a/tests/test_docs_generation_output.py +++ b/tests/test_docs_generation_output.py @@ -18,7 +18,6 @@ from click import Command from pydantic.json_schema import GenerateJsonSchema, model_json_schema -from snowflake.cli._app.cli_app import app_context_holder from snowflake.cli.api.project.schemas.project_definition import DefinitionV11 @@ -128,14 +127,9 @@ def test_files_generated_for_each_optional_project_definition_property( assert len(errors) == 0, "\n".join(errors) -def test_all_commands_have_generated_files(runner, temp_dir): +def test_all_commands_have_generated_files(runner, temp_dir, get_click_context): runner.invoke(["--docs"]) - # invoke help command to populate app context (plugins registration) - runner.invoke([""]) - - ctx = app_context_holder.app_context - commands_path = Path(temp_dir) / "gen_docs" / "commands" errors = [] @@ -159,7 +153,11 @@ def _check(command: Command, directory_path: Path, command_path=None): f"Command `{' '.join(command_path)}` documentation was not properly generated" ) - _check(ctx.command, commands_path) + app = get_click_context().command + assert ( + len(app.commands) >= 1 + ) # confirm that test is actually checking some commands + _check(get_click_context().command, commands_path) assert len(errors) == 0, "\n".join(errors) diff --git a/tests/test_main.py b/tests/test_main.py index 2a2428fc15..b51ca56da0 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -26,7 +26,6 @@ import pytest from click import Command -from snowflake.cli._app.cli_app import app_context_holder from snowflake.connector.config_manager import CONFIG_MANAGER from typer.core import TyperArgument, TyperOption @@ -89,11 +88,10 @@ def test_docs_callback(runner): assert result.exit_code == 0, result.output -def test_all_commands_have_proper_documentation(runner): +def test_all_commands_have_proper_documentation(runner, get_click_context): # invoke any command to populate app context (plugins registration) runner.invoke("--help") - ctx = app_context_holder.app_context errors = [] def _check(command: Command, path: t.Optional[t.List] = None): @@ -128,15 +126,14 @@ def _check(command: Command, path: t.Optional[t.List] = None): f"Command `snow {' '.join(path)}` is missing help for `{param.name}` option" ) - _check(ctx.command) + _check(get_click_context().command) assert len(errors) == 0, "\n".join(errors) -def test_if_there_are_no_option_duplicates(runner): +def test_if_there_are_no_option_duplicates(runner, get_click_context): runner.invoke("--help") - ctx = app_context_holder.app_context duplicates = {} def _check(command: Command, path: t.Optional[t.List] = None): @@ -153,7 +150,7 @@ def check_options_for_duplicates(params: t.List[TyperOption]) -> t.Set[str]: flags = [flag for param in params for flag in param.opts] return set([flag for flag in flags if (flags.count(flag) > 1)]) - _check(ctx.command) + _check(get_click_context().command) assert duplicates == {}, "\n".join(duplicates) diff --git a/tests/testing_utils/fixtures.py b/tests/testing_utils/fixtures.py index 192776fcc9..0e369bbb09 100644 --- a/tests/testing_utils/fixtures.py +++ b/tests/testing_utils/fixtures.py @@ -28,7 +28,7 @@ import pytest import yaml -from snowflake.cli._app.cli_app import app_factory +from snowflake.cli._app.cli_app import CliAppFactory from snowflake.cli._plugins.nativeapp.codegen.snowpark.models import ( NativeAppExtensionFunction, ) @@ -233,15 +233,24 @@ def package_file(): @pytest.fixture(scope="function") -def runner(test_snowcli_config): - app = app_factory() - yield SnowCLIRunner(app, test_snowcli_config) +def app_factory(): + yield CliAppFactory() @pytest.fixture(scope="function") -def build_runner(test_snowcli_config): +def get_click_context(app_factory): + yield lambda: app_factory.get_click_context() + + +@pytest.fixture(scope="function") +def runner(build_runner): + yield build_runner() + + +@pytest.fixture(scope="function") +def build_runner(app_factory, test_snowcli_config): def func(): - app = app_factory() + app = app_factory.create_or_get_app() return SnowCLIRunner(app, test_snowcli_config) return func diff --git a/tests_integration/conftest.py b/tests_integration/conftest.py index 1ac9bda801..21571c5668 100644 --- a/tests_integration/conftest.py +++ b/tests_integration/conftest.py @@ -32,7 +32,7 @@ from typer import Typer from typer.testing import CliRunner -from snowflake.cli._app.cli_app import app_factory +from snowflake.cli._app.cli_app import CliAppFactory from snowflake.cli.api.cli_global_context import ( fork_cli_context, get_cli_context_manager, @@ -189,7 +189,7 @@ def invoke_with_connection( @pytest.fixture def runner(test_snowcli_config_provider, default_username, resource_suffix): - app = app_factory() + app = CliAppFactory().create_or_get_app() yield SnowCLIRunner( app, test_snowcli_config_provider, diff --git a/tests_integration/test_external_plugins.py b/tests_integration/test_external_plugins.py index 9fc88a9ad4..9980cafc55 100644 --- a/tests_integration/test_external_plugins.py +++ b/tests_integration/test_external_plugins.py @@ -31,21 +31,9 @@ def install_plugins(): subprocess.check_call(["pip", "install", path / "snowpark_hello_single_command"]) -@pytest.fixture() -def reset_command_registration_state(): - def _reset_command_registration_state(): - from snowflake.cli._app.cli_app import _commands_registration - - _commands_registration.reset_running_instance_registration_state() - - yield _reset_command_registration_state - - _reset_command_registration_state() - - @pytest.mark.integration def test_loading_of_installed_plugins_if_all_plugins_enabled( - runner, install_plugins, caplog, reset_command_registration_state + runner, install_plugins, caplog ): runner.use_config("config_with_enabled_all_external_plugins.toml") @@ -110,7 +98,6 @@ def test_loading_of_installed_plugins_if_only_one_plugin_is_enabled( runner, install_plugins, caplog, - reset_command_registration_state, ): runner.use_config("config_with_enabled_only_one_external_plugin.toml") @@ -141,7 +128,6 @@ def test_enabled_value_must_be_boolean( config_value, runner, snowflake_home, - reset_command_registration_state, ): def _use_config_with_value(value): config = Path(snowflake_home) / "config.toml" @@ -163,8 +149,6 @@ def _use_config_with_value(value): ), second assert "boolean" in third, third - reset_command_registration_state() - def _assert_that_no_error_logs(caplog): error_logs = [