From 11bf9d9a322355df53f66bfe2d4d943b0365cc12 Mon Sep 17 00:00:00 2001 From: Siyuan Chen Date: Fri, 11 Oct 2024 12:54:09 -0700 Subject: [PATCH] poc --- .../_plugins/spcs/image_registry/manager.py | 9 +- .../_plugins/spcs/image_registry/registry.py | 127 ++++++++++++++++ .../_plugins/spcs/image_repository/manager.py | 2 +- .../cli/_plugins/spcs/services/commands.py | 32 +++- .../cli/_plugins/spcs/services/manager.py | 9 +- .../_plugins/spcs/services/project_model.py | 122 +++++++++++++++ .../_plugins/spcs/services/spcs_processor.py | 140 ++++++++++++++++++ .../api/project/schemas/project_definition.py | 5 + .../api/project/schemas/v1/spcs/__init__.py | 0 .../api/project/schemas/v1/spcs/service.py | 95 ++++++++++++ src/snowflake/cli/api/project/util.py | 11 ++ 11 files changed, 545 insertions(+), 7 deletions(-) create mode 100644 src/snowflake/cli/_plugins/spcs/image_registry/registry.py create mode 100644 src/snowflake/cli/_plugins/spcs/services/project_model.py create mode 100644 src/snowflake/cli/_plugins/spcs/services/spcs_processor.py create mode 100644 src/snowflake/cli/api/project/schemas/v1/spcs/__init__.py create mode 100644 src/snowflake/cli/api/project/schemas/v1/spcs/service.py diff --git a/src/snowflake/cli/_plugins/spcs/image_registry/manager.py b/src/snowflake/cli/_plugins/spcs/image_registry/manager.py index 493b434d9d..d3294dcd4a 100644 --- a/src/snowflake/cli/_plugins/spcs/image_registry/manager.py +++ b/src/snowflake/cli/_plugins/spcs/image_registry/manager.py @@ -80,9 +80,12 @@ def get_registry_url(self) -> str: if len(results) == 0: raise NoImageRepositoriesFoundError() sample_repository_url = results[0]["repository_url"] - if not self._has_url_scheme(sample_repository_url): - sample_repository_url = f"//{sample_repository_url}" - return urlparse(sample_repository_url).netloc + return self.get_registry_url_from_repo(sample_repository_url) + + def get_registry_url_from_repo(self, repo_url) -> str: + if not self._has_url_scheme(repo_url): + repo_url = f"//{repo_url}" + return urlparse(repo_url).netloc def docker_registry_login(self) -> str: registry_url = self.get_registry_url() diff --git a/src/snowflake/cli/_plugins/spcs/image_registry/registry.py b/src/snowflake/cli/_plugins/spcs/image_registry/registry.py new file mode 100644 index 0000000000..02322c4f31 --- /dev/null +++ b/src/snowflake/cli/_plugins/spcs/image_registry/registry.py @@ -0,0 +1,127 @@ +import datetime +import json +import os +import subprocess +import uuid + +import docker + +TARGET_ARCH = "amd64" +DOCKER_BUILDER = "snowflake-cli-builder" + + +class Registry: + def __init__(self, registry_url, logger) -> None: + self._registry_url = registry_url + self._logger = logger + self._docker_client = docker.from_env(timeout=300) + self._is_arm = self._is_arch_arm() + if self._is_arm: + if os.system(f"docker buildx use {DOCKER_BUILDER}") != 0: + os.system(f"docker buildx create --name {DOCKER_BUILDER} --use") + + def _is_arch_arm(self): + result = subprocess.run(["uname", "-m"], stdout=subprocess.PIPE) + arch = result.stdout.strip().decode("UTF-8") + self._logger.info("Detected machine architecture: %s", arch) + return arch == "arm64" or arch == "aarch64" + + def _raise_error_from_output(self, output: str): + for line in output.splitlines(): + try: + jsline = json.loads(line) + if "error" in jsline: + raise docker.errors.APIError(jsline["error"]) + except json.JSONDecodeError: + pass # not a json, don't parse, assume no error + + def _gen_image_tag(self) -> str: + ts = 
datetime.datetime.now().strftime("%Y%m%d%H%M%S") + uid = str(uuid.uuid4()).split("-")[0] + return f"{ts}-{uid}" + + def push(self, image_name): + self._logger.info("Uploading image %s", image_name) + output = self._docker_client.images.push(image_name) + self._raise_error_from_output(output) + return output + + def pull(self, image_name): + self._logger.info("Pulling image %s", image_name) + + n = image_name.rindex(":") + if n >= 0 and "/" not in image_name[n + 1 :]: + # if ':' is present in the last part in image name (separated by '/') + image = image_name[0:n] + tag = image_name[n + 1 :] + else: + image = image_name + tag = None + return self._docker_client.images.pull(image, tag, platform=TARGET_ARCH) + + def build_and_push_image( + self, + image_source_local_path: str, + image_path: str, + tag: str = "", + generate_tag: bool = False, + ): + """ + builds an image and push it to sf image registry + """ + + docker_file_path = os.path.join(image_source_local_path, "Dockerfile") + + if not tag and generate_tag: + tag = self._gen_image_tag() + + # build and upload image to registry if running remotely + self._logger.info("registry: %s", self._registry_url) + tagged = self._registry_url + image_path + if tag is not None: + tagged = f"{tagged}:{tag}" + + if self._is_arm: + self._logger.info("Using docker buildx for building image %s", tagged) + + # emulate intel environment on arm - see https://github.com/docker/buildx/issues/464 + # os.system( + # "docker run -it --rm --privileged tonistiigi/binfmt --install all" + # ) + + docker_build_cmd = f""" + docker buildx build --tag {tagged} + --load + --platform linux/amd64 + {image_source_local_path} + -f {docker_file_path} + --builder {DOCKER_BUILDER} + --rm + """ + + parts = list( + filter( + lambda part: part != "", + [part.strip() for part in docker_build_cmd.split("\n")], + ) + ) + docker_cmd = " ".join(parts) + self._logger.info("Executing: %s", docker_cmd) + if 0 != os.system(docker_cmd): + assert False, f"failed : unable to build image {tagged} with buildx" + + push_output = self.push(tagged) + self._logger.info(push_output) + else: + # build and upload image to registry if running remotely + self._logger.info("Building image %s with docker python sdk", tagged) + _, output = self._docker_client.images.build( + path=image_source_local_path, + dockerfile=docker_file_path, + rm=True, + tag=tagged, + ) + for o in output: + self._logger.info(o) + push_output = self.push(tagged) + self._logger.info(push_output) diff --git a/src/snowflake/cli/_plugins/spcs/image_repository/manager.py b/src/snowflake/cli/_plugins/spcs/image_repository/manager.py index 865740b788..4e41b97906 100644 --- a/src/snowflake/cli/_plugins/spcs/image_repository/manager.py +++ b/src/snowflake/cli/_plugins/spcs/image_repository/manager.py @@ -15,6 +15,7 @@ from urllib.parse import urlparse from snowflake.cli._plugins.spcs.common import handle_object_already_exists +from snowflake.cli._plugins.spcs.image_registry.registry import Registry from snowflake.cli.api.constants import ObjectType from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.sql_execution import SqlExecutionMixin @@ -32,7 +33,6 @@ def get_role(self): return self._conn.role def get_repository_url(self, repo_name: str, with_scheme: bool = True): - repo_row = self.show_specific_object( "image repositories", repo_name, check_schema=True ) diff --git a/src/snowflake/cli/_plugins/spcs/services/commands.py b/src/snowflake/cli/_plugins/spcs/services/commands.py index 058d40dad5..4ec1455d7b 100644 --- 
a/src/snowflake/cli/_plugins/spcs/services/commands.py +++ b/src/snowflake/cli/_plugins/spcs/services/commands.py @@ -20,6 +20,7 @@ import typer from click import ClickException + from snowflake.cli._plugins.object.command_aliases import ( add_object_command_aliases, scope_option, @@ -30,6 +31,9 @@ validate_and_set_instances, ) from snowflake.cli._plugins.spcs.services.manager import ServiceManager +from snowflake.cli._plugins.spcs.services.spcs_processor import SpcsProcessor +from snowflake.cli.api.cli_global_context import get_cli_context +from snowflake.cli.api.commands.decorators import with_project_definition from snowflake.cli.api.commands.flags import ( IfNotExistsOption, OverrideableOption, @@ -41,6 +45,7 @@ from snowflake.cli.api.identifiers import FQN from snowflake.cli.api.output.types import ( CommandResult, + MessageResult, QueryJsonValueResult, QueryResult, SingleQueryResult, @@ -200,7 +205,16 @@ def status(name: FQN = ServiceNameArgument, **options) -> CommandResult: Retrieves the status of a service. """ cursor = ServiceManager().status(service_name=name.identifier) - return QueryJsonValueResult(cursor) + return SingleQueryResult(cursor) + + +@app.command(requires_connection=True) +def container_status(name: FQN = ServiceNameArgument, **options) -> CommandResult: + """ + Retrieves the container status of a service. + """ + cursor = ServiceManager().container_status(service_name=name.identifier) + return QueryResult(cursor) @app.command(requires_connection=True) @@ -343,3 +357,19 @@ def unset_property( comment=comment, ) return SingleQueryResult(cursor) + + +@app.command("deploy", requires_connection=True) +@with_project_definition() +def service_deploy( + **options, +) -> CommandResult: + """ + Deploys the service in the current schema or creates a new service if it does not exist. 
+ """ + cli_context = get_cli_context() + processor = SpcsProcessor( + project_definition=cli_context.project_definition.spcs, + project_root=cli_context.project_root, + ) + return SingleQueryResult(processor.deploy()) diff --git a/src/snowflake/cli/_plugins/spcs/services/manager.py b/src/snowflake/cli/_plugins/spcs/services/manager.py index 3f83ad21db..3c6eb7c984 100644 --- a/src/snowflake/cli/_plugins/spcs/services/manager.py +++ b/src/snowflake/cli/_plugins/spcs/services/manager.py @@ -31,6 +31,8 @@ from snowflake.connector.cursor import SnowflakeCursor from snowflake.connector.errors import ProgrammingError +from snowflake.cli.api.project.schemas.v1.spcs.service import Service + class ServiceManager(SqlExecutionMixin): def create( @@ -74,7 +76,7 @@ def create( query.append(f"QUERY_WAREHOUSE = {query_warehouse}") if comment: - query.append(f"COMMENT = {comment}") + query.append(f"COMMENT = $${comment}$$") if tags: tag_list = ",".join(f"{t.name}={t.value_string_literal()}" for t in tags) @@ -130,7 +132,10 @@ def _read_yaml(self, path: Path) -> str: return json.dumps(data) def status(self, service_name: str) -> SnowflakeCursor: - return self._execute_query(f"CALL SYSTEM$GET_SERVICE_STATUS('{service_name}')") + return self._execute_query(f"DESC SERVICE {service_name}") + + def container_status(self, service_name: str) -> SnowflakeCursor: + return self._execute_query(f"SHOW SERVICE CONTAINERS IN SERVICE {service_name}") def logs( self, service_name: str, instance_id: str, container_name: str, num_lines: int diff --git a/src/snowflake/cli/_plugins/spcs/services/project_model.py b/src/snowflake/cli/_plugins/spcs/services/project_model.py new file mode 100644 index 0000000000..5656aeb2fc --- /dev/null +++ b/src/snowflake/cli/_plugins/spcs/services/project_model.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from functools import cached_property +from pathlib import Path +from typing import List + +from snowflake.cli._plugins.nativeapp.artifacts import resolve_without_follow +from snowflake.cli.api.project.schemas.v1.native_app.path_mapping import PathMapping +from snowflake.cli.api.project.schemas.v1.spcs.service import Service +from snowflake.cli.api.project.util import ( + to_identifier, +) + + +class ServiceProjectModel: + def __init__( + self, + project_definition: Service, + project_root: Path, + ): + self._project_definition = project_definition + self._project_root = resolve_without_follow(project_root) + + @property + def project_root(self) -> Path: + return self._project_root + + @property + def definition(self) -> Service: + return self._project_definition + + @cached_property + def service_name(self) -> str: + return self._project_definition.name + + @cached_property + def spec(self) -> PathMapping: + return self.definition.spec + + @cached_property + def images(self) -> List[PathMapping]: + return self.definition.images + + @cached_property + def image_sources(self) -> List[PathMapping]: + source_image_paths = [] + for image in self.images: + source_image_paths.append(PathMapping(src=image.src)) + return source_image_paths + + @cached_property + def source_repo_path(self) -> str: + return self.definition.source_repo + + @cached_property + def source_repo_fqn(self) -> str: + repo_path = self.definition.source_repo + return repo_path.strip("/").replace("/", ".") + + @cached_property + def source_stage_fqn(self) -> str: + return self.definition.source_stage + + @cached_property + def bundle_root(self) -> Path: + return self.project_root / self.definition.bundle_root + + @cached_property + def deploy_root(self) -> Path: + return self.project_root / self.definition.deploy_root + + @cached_property + def generated_root(self) -> Path: + return self.deploy_root / self.definition.generated_root + + @cached_property + def project_identifier(self) -> str: + return to_identifier(self.definition.name) + + @cached_property + def query_warehouse(self) -> str: + return to_identifier(self.definition.query_warehouse) + + @cached_property + def compute_pool(self) -> str: + return to_identifier(self.definition.compute_pool) + + @cached_property + def min_instances(self) -> int: + return self.definition.min_instances + + @cached_property + def max_instances(self) -> int: + return self.definition.max_instances + + @cached_property + def comment(self) -> str: + return self.definition.comment + + # def get_bundle_context(self) -> BundleContext: + # return BundleContext( + # package_name=self.package_name, + # artifacts=self.artifacts, + # project_root=self.project_root, + # bundle_root=self.bundle_root, + # deploy_root=self.deploy_root, + # generated_root=self.generated_root, + # ) diff --git a/src/snowflake/cli/_plugins/spcs/services/spcs_processor.py b/src/snowflake/cli/_plugins/spcs/services/spcs_processor.py new file mode 100644 index 0000000000..894cb1a539 --- /dev/null +++ b/src/snowflake/cli/_plugins/spcs/services/spcs_processor.py @@ -0,0 +1,140 @@ +import logging +import os +from pathlib import Path +import time +from typing import List + +from snowflake.cli._plugins.nativeapp.artifacts import build_bundle +from snowflake.cli._plugins.spcs.image_registry.manager import RegistryManager +from snowflake.cli._plugins.spcs.image_registry.registry import Registry +from snowflake.cli._plugins.spcs.image_repository.manager import ImageRepositoryManager +from 
snowflake.cli._plugins.spcs.services.manager import ServiceManager +from snowflake.cli._plugins.spcs.services.project_model import ServiceProjectModel +from snowflake.cli._plugins.stage.diff import DiffResult +from snowflake.cli.api.console import cli_console as cc +from snowflake.cli.api.entities.utils import sync_deploy_root_with_stage +from snowflake.cli.api.identifiers import FQN +from snowflake.cli.api.project.schemas.v1.spcs.service import Service +from snowflake.cli.api.project.util import extract_database, extract_schema + +logger = logging.getLogger(__name__) + + +class SpcsProcessor(ServiceManager): + def __init__(self, project_definition: Service, project_root: Path): + self._project_definition = ServiceProjectModel(project_definition, project_root) + self._registry_manager = RegistryManager() + self._repo_manager = ImageRepositoryManager() + self._service_manager = ServiceManager() + + def deploy( + self, + prune: bool = True, + recursive: bool = True, + paths: List[Path] = None, + print_diff: bool = True, + ): + # 1. Upload service spec and source code from deploy root local folder to stage + artifacts = [self._project_definition.spec] + artifacts.extend(self._project_definition.image_sources) + + stage_fqn = self._project_definition.source_stage_fqn + stage_schema = extract_schema(stage_fqn) + stage_database = extract_database(stage_fqn) + + diff = sync_deploy_root_with_stage( + console=cc, + deploy_root=self._project_definition.deploy_root, + package_name=stage_database, + stage_schema=stage_schema, + bundle_map=build_bundle( + self._project_definition.project_root, + self._project_definition.deploy_root, + artifacts=artifacts, + ), + role="SYSADMIN", # TODO: Use the correct role + prune=prune, + recursive=recursive, + stage_fqn=stage_fqn, + local_paths_to_sync=paths, # sync all + print_diff=print_diff, + ) + + # 2. Rebuild images if source code has changed + self._sync_images() + + # 3. Deploy service + spec_path = self._project_definition.deploy_root.joinpath( + self._project_definition.spec.src + ) + print("Deploying service with spec " + str(spec_path)) + res = self._service_manager.create( + service_name=self._project_definition.service_name, + compute_pool=self._project_definition.compute_pool, + spec_path=spec_path, + min_instances=self._project_definition.min_instances, + max_instances=self._project_definition.max_instances, + query_warehouse=self._project_definition.query_warehouse, + comment=self._project_definition.comment, + auto_resume=True, + external_access_integrations=None, + tags=None, + if_not_exists=True, + ).fetchone() + print(str(res[0])) + + if diff and diff.has_changes(): + print("Source change detected. Upgrading service.") + self._service_manager.upgrade_spec( + self._project_definition.service_name, spec_path + ) + + status = self._service_manager.status( + self._project_definition.service_name + ).fetchone()[1] + retry = 10 + while status != "RUNNING" and retry > 0: + print("Waiting for service to be ready... 
Status: " + status) + time.sleep(2) + retry -= 1 + status = self._service_manager.status( + self._project_definition.service_name + ).fetchone()[1] + return self._service_manager.status(self._project_definition.service_name) + + def _sync_images(self): + repo_fqn = FQN.from_string(self._project_definition.source_repo_fqn) + + self._repo_manager.create( + name=repo_fqn.identifier, if_not_exists=True, replace=False + ) + repo_url = self._repo_manager.get_repository_url( + repo_name=repo_fqn.identifier, with_scheme=False + ) + + registry_url = self._registry_manager.get_registry_url_from_repo(repo_url) + self._registry_manager.login_to_registry("https://" + registry_url) + registry = Registry(registry_url, logger) + + print("Syncing images in repository " + repo_url) + for image in self._project_definition.images: + # image_src = self._project_definition.deploy_root.joinpath( + # image.src.strip("*").strip("/") + # ) + image_src = image.src.strip("*").strip("/") + print("image_src: " + str(image_src)) + + image_src_folder_name = os.path.basename(image_src) + image_path = os.path.join( + self._project_definition.source_repo_path, image_src_folder_name + ) + if image.dest: + image_path = os.path.join( + self._project_definition.source_repo_path, + image.dest, + image_src_folder_name, + ) + + registry.build_and_push_image( + image_src, image_path, "latest", False + ) # TODO: fix tag diff --git a/src/snowflake/cli/api/project/schemas/project_definition.py b/src/snowflake/cli/api/project/schemas/project_definition.py index 26faaf9a05..f596ada00b 100644 --- a/src/snowflake/cli/api/project/schemas/project_definition.py +++ b/src/snowflake/cli/api/project/schemas/project_definition.py @@ -38,6 +38,8 @@ from snowflake.cli.api.utils.types import Context from typing_extensions import Annotated +from snowflake.cli.api.project.schemas.v1.spcs.service import Service + AnnotatedEntity = Annotated[EntityModel, Field(discriminator="type")] scalar = str | int | float | bool @@ -102,6 +104,9 @@ class DefinitionV10(_ProjectDefinitionBase): streamlit: Optional[Streamlit] = Field( title="Streamlit definitions for the project", default=None ) + spcs: Optional[Service] = Field( + title="SPCS definitions for the project", default=None + ) class DefinitionV11(DefinitionV10): diff --git a/src/snowflake/cli/api/project/schemas/v1/spcs/__init__.py b/src/snowflake/cli/api/project/schemas/v1/spcs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/snowflake/cli/api/project/schemas/v1/spcs/service.py b/src/snowflake/cli/api/project/schemas/v1/spcs/service.py new file mode 100644 index 0000000000..1587108fa6 --- /dev/null +++ b/src/snowflake/cli/api/project/schemas/v1/spcs/service.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import re +from typing import List, Optional, Union + +from pydantic import Field, field_validator +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel +from snowflake.cli.api.project.schemas.v1.native_app.path_mapping import PathMapping +from snowflake.cli.api.project.util import ( + DB_SCHEMA_AND_NAME, +) + + +class Service(UpdatableModel): + name: str = Field( + title="Project identifier", + ) + source_stage: str = Field( + title="Identifier of the stage that stores service source code.", + ) + spec: str = Field( + title="Service spec file path", + ) + source_repo: str = Field( + title="Identifier of the image repo that stores image source code.", + ) + images: List[Union[PathMapping, str]] = Field( + title="List of image source and 
destination pairs to add to the deploy root", + ) + compute_pool: str = Field( + title="Compute pool where the service will be deployed.", + ) + min_instances: int = Field( + title="Service min instances", + ) + max_instances: int = Field( + title="Service max instances", + ) + query_warehouse: Optional[str] = Field( + title="Default warehouse to run queries in the service.", + ) + comment: Optional[str] = Field( + title="Comment", + ) + bundle_root: Optional[str] = Field( + title="Folder at the root of your project where artifacts necessary to perform the bundle step are stored.", + default="output/bundle/", + ) + deploy_root: Optional[str] = Field( + title="Folder at the root of your project where the bundle step copies the artifacts.", + default="output/deploy/", + ) + generated_root: Optional[str] = Field( + title="Subdirectory of the deploy root where files generated by the Snowflake CLI will be written.", + default="__generated/", + ) + scratch_stage: Optional[str] = Field( + title="Identifier of the stage that stores temporary scratch data used by the Snowflake CLI.", + default="app_src.stage_snowflake_cli_scratch", + ) + + @field_validator("source_stage") + @classmethod + def validate_source_stage(cls, input_value: str): + if not re.match(DB_SCHEMA_AND_NAME, input_value): + raise ValueError("Incorrect value for source_stage value") + return input_value + + @field_validator("spec") + @classmethod + def transform_artifacts( + cls, orig_artifacts: Union[PathMapping, str] + ) -> PathMapping: + return ( + PathMapping(src=orig_artifacts) + if orig_artifacts and isinstance(orig_artifacts, str) + else orig_artifacts + ) + + @field_validator("images") + @classmethod + def transform_images( + cls, orig_artifacts: List[Union[PathMapping, str]] + ) -> List[PathMapping]: + transformed_artifacts = [] + if orig_artifacts is None: + return transformed_artifacts + + for artifact in orig_artifacts: + if isinstance(artifact, PathMapping): + transformed_artifacts.append(artifact) + else: + transformed_artifacts.append(PathMapping(src=artifact)) + + return transformed_artifacts diff --git a/src/snowflake/cli/api/project/util.py b/src/snowflake/cli/api/project/util.py index 564ebbd867..15a30d471c 100644 --- a/src/snowflake/cli/api/project/util.py +++ b/src/snowflake/cli/api/project/util.py @@ -194,6 +194,17 @@ def extract_schema(qualified_name: str): return None +def extract_database(qualified_name: str): + """ + Extracts the database from either a two-part or three-part qualified name + (i.e. schema.object or database.schema.object). If qualified_name is not + qualified with a database, returns None. + """ + if match := re.fullmatch(DB_SCHEMA_AND_NAME, qualified_name): + return match.group(1) + return None + + def first_set_env(*keys: str): for k in keys: v = os.getenv(k)
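
A minimal, illustrative instance of the new Service schema, written directly against the pydantic model rather than as a snowflake.yml snippet. Every identifier and path below (echo_service, mydb.public.*, spec/echo_spec.yml, and so on) is a placeholder, and the sketch assumes DB_SCHEMA_AND_NAME accepts an ordinary three-part identifier:

    from snowflake.cli.api.project.schemas.v1.spcs.service import Service

    # Placeholder values: source_stage must be a db.schema.name identifier, and
    # source_repo is the /db/schema/repo path that source_repo_fqn later rewrites
    # into a dotted FQN.
    service = Service(
        name="echo_service",
        source_stage="mydb.public.service_src",
        spec="spec/echo_spec.yml",                  # coerced to PathMapping by transform_artifacts
        source_repo="/mydb/public/service_images",
        images=["images/echo"],                     # each entry coerced to PathMapping(src=...)
        compute_pool="my_compute_pool",
        min_instances=1,
        max_instances=1,
        query_warehouse="my_warehouse",
        comment="poc service",
    )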
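
The new deploy command in commands.py is a thin wrapper around SpcsProcessor: once the project definition is loaded it does roughly the equivalent of the sketch below and wraps the returned cursor in SingleQueryResult. The sketch reuses the service model from the previous example, assumes an active CLI connection context (the processor issues SQL via SqlExecutionMixin), and uses a placeholder project root:

    from pathlib import Path

    from snowflake.cli._plugins.spcs.services.spcs_processor import SpcsProcessor

    processor = SpcsProcessor(
        project_definition=service,           # the `spcs` section of the project definition
        project_root=Path("./my_project"),    # folder holding the spec and image sources
    )
    # deploy() bundles the spec and image sources, syncs them to source_stage,
    # rebuilds and pushes the images through the image registry, creates the
    # service with IF NOT EXISTS, upgrades its spec when the stage diff reports
    # changes, and then polls the service status (now DESC SERVICE) until it
    # reports RUNNING or the retry budget is exhausted.
    cursor = processor.deploy()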
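
The new extract_database helper in util.py complements the existing extract_schema. Assuming DB_SCHEMA_AND_NAME matches a standard three-part identifier, its behavior on the qualified names used above looks like this:

    from snowflake.cli.api.project.util import extract_database

    extract_database("mydb.public.service_src")   # -> "mydb"
    extract_database("public.service_src")        # -> None, name is not database-qualified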