From 10782d25ca5f4ef9cb68eb5dd4927ba5eabb657c Mon Sep 17 00:00:00 2001 From: DJ Sterling Date: Mon, 11 Nov 2024 10:54:51 -0700 Subject: [PATCH] MVP support for Kaggle Packages Support tracking accessed datasources and writing / reading them to a requirements.yaml file. Support a "Package Scope" using that file which applies those datasource versions at runtime if the user requests a datasource without an explicit version. Support importing a Package. This uses a handle equivalent to a Notebook handle, and like Notebooks is currently limited to the latest version, with version support coming soon. Support Package Asset files whose path honors the current Package Scope. --- integration_tests/test_package_import.py | 47 +++++ pyproject.toml | 1 + src/kagglehub/__init__.py | 3 +- src/kagglehub/clients.py | 2 +- src/kagglehub/colab_cache_resolver.py | 61 ++++-- src/kagglehub/competition.py | 3 +- src/kagglehub/datasets.py | 3 +- src/kagglehub/handle.py | 31 ++- src/kagglehub/http_resolver.py | 53 +++-- src/kagglehub/kaggle_cache_resolver.py | 51 +++-- src/kagglehub/models.py | 3 +- src/kagglehub/notebooks.py | 7 +- src/kagglehub/packages.py | 195 ++++++++++++++++++ src/kagglehub/registry.py | 4 +- src/kagglehub/requirements.py | 128 ++++++++++++ src/kagglehub/resolver.py | 32 ++- src/kagglehub/utility_scripts.py | 2 +- tests/data/package-v1.zip | Bin 0 -> 722 bytes tests/data/package-v2.zip | Bin 0 -> 722 bytes tests/server_stubs/jwt_stub.py | 101 +++++++-- .../notebook_output_download_stub.py | 8 +- tests/test_http_package_import.py | 59 ++++++ tests/test_http_requirements.py | 54 +++++ tests/test_kaggle_cache_package_import.py | 78 +++++++ tests/test_kaggle_cache_requirements.py | 70 +++++++ tests/test_registry.py | 14 +- 26 files changed, 911 insertions(+), 99 deletions(-) create mode 100644 integration_tests/test_package_import.py create mode 100644 src/kagglehub/packages.py create mode 100644 src/kagglehub/requirements.py create mode 100644 
tests/data/package-v1.zip create mode 100644 tests/data/package-v2.zip create mode 100644 tests/test_http_package_import.py create mode 100644 tests/test_http_requirements.py create mode 100644 tests/test_kaggle_cache_package_import.py create mode 100644 tests/test_kaggle_cache_requirements.py diff --git a/integration_tests/test_package_import.py b/integration_tests/test_package_import.py new file mode 100644 index 0000000..6250a01 --- /dev/null +++ b/integration_tests/test_package_import.py @@ -0,0 +1,47 @@ +import sys +import unittest + +from kagglehub import package_import + +from .utils import create_test_cache, unauthenticated + +UNVERSIONED_HANDLE = "dster/package-test" +VERSIONED_HANDLE = "dster/package-test/versions/1" + + +class TestPackageImport(unittest.TestCase): + + def tearDown(self) -> None: + # Clear any imported packages from sys.modules. + for name in list(sys.modules.keys()): + if name.startswith("kagglehub_package"): + del sys.modules[name] + + def test_package_versioned_succeeds(self) -> None: + with create_test_cache(): + package = package_import(VERSIONED_HANDLE) + + self.assertIn("foo", dir(package)) + self.assertEqual("bar", package.foo()) + + def test_package_unversioned_succeeds(self) -> None: + with create_test_cache(): + package = package_import(UNVERSIONED_HANDLE) + + self.assertIn("foo", dir(package)) + self.assertEqual("baz", package.foo()) + + def test_download_private_package_succeeds(self) -> None: + with create_test_cache(): + package = package_import("integrationtester/kagglehub-test-private-package") + + self.assertIn("foo", dir(package)) + self.assertEqual("bar", package.foo()) + + def test_public_package_with_unauthenticated_succeeds(self) -> None: + with create_test_cache(): + with unauthenticated(): + package = package_import(UNVERSIONED_HANDLE) + + self.assertIn("foo", dir(package)) + self.assertEqual("baz", package.foo()) diff --git a/pyproject.toml b/pyproject.toml index f9194c5..6821a0d 100755 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "tqdm", "packaging", "model_signing", + "pyyaml", ] [project.urls] diff --git a/src/kagglehub/__init__.py b/src/kagglehub/__init__.py index 44b2681..dbfdfa8 100644 --- a/src/kagglehub/__init__.py +++ b/src/kagglehub/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.3.6" +__version__ = "0.3.7" import kagglehub.logger # configures the library logger. from kagglehub import colab_cache_resolver, http_resolver, kaggle_cache_resolver, registry @@ -7,6 +7,7 @@ from kagglehub.datasets import KaggleDatasetAdapter, dataset_download, dataset_upload, load_dataset from kagglehub.models import model_download, model_upload from kagglehub.notebooks import notebook_output_download +from kagglehub.packages import get_package_asset_path, package_import from kagglehub.utility_scripts import utility_script_install registry.model_resolver.add_implementation(http_resolver.ModelHttpResolver()) diff --git a/src/kagglehub/clients.py b/src/kagglehub/clients.py index 3f7510f..cb5f5fa 100644 --- a/src/kagglehub/clients.py +++ b/src/kagglehub/clients.py @@ -380,7 +380,7 @@ class ColabClient: MOUNT_PATH = "/kagglehub/models/mount" MODEL_MOUNT_PATH = "/kagglehub/models/mount" DATASET_MOUNT_PATH = "/kagglehub/datasets/mount" - # TBE_RUNTIME_ADDR serves requests made from `is_supported` and `__call__` + # TBE_RUNTIME_ADDR serves requests made from `is_supported` and `_resolve` # of ModelColabCacheResolver. 
TBE_RUNTIME_ADDR_ENV_VAR_NAME = "TBE_RUNTIME_ADDR" diff --git a/src/kagglehub/colab_cache_resolver.py b/src/kagglehub/colab_cache_resolver.py index c43e791..f3bc5a9 100644 --- a/src/kagglehub/colab_cache_resolver.py +++ b/src/kagglehub/colab_cache_resolver.py @@ -8,6 +8,7 @@ from kagglehub.exceptions import BackendError, NotFoundError from kagglehub.handle import DatasetHandle, ModelHandle from kagglehub.logger import EXTRA_CONSOLE_BLOCK +from kagglehub.packages import PackageScope from kagglehub.resolver import Resolver COLAB_CACHE_MOUNT_FOLDER_ENV_VAR_NAME = "COLAB_CACHE_MOUNT_FOLDER" @@ -29,9 +30,10 @@ def is_supported(self, handle: ModelHandle, *_, **__) -> bool: # noqa: ANN002, "variation": handle.variation, } - if handle.is_versioned(): + version = _get_model_version(handle) + if version: # Colab treats version as int in the request - data["version"] = handle.version # type: ignore + data["version"] = version # type: ignore try: api_client.post(data, ColabClient.IS_MODEL_SUPPORTED_PATH, handle) @@ -39,7 +41,9 @@ def is_supported(self, handle: ModelHandle, *_, **__) -> bool: # noqa: ANN002, return False return True - def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: + def _resolve( + self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + ) -> tuple[str, Optional[int]]: if force_download: logger.info( "Ignoring `force_download` argument when running inside the Colab notebook environment.", @@ -53,9 +57,11 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download "framework": h.framework, "variation": h.variation, } - if h.is_versioned(): + + version = _get_model_version(h) + if version: # Colab treats version as int in the request - data["version"] = h.version # type: ignore + data["version"] = version # type: ignore response = api_client.post(data, ColabClient.MODEL_MOUNT_PATH, h) @@ -85,8 +91,8 @@ def __call__(self, h: ModelHandle, 
path: Optional[str] = None, *, force_download f"You can access the other files of the attached model at '{cached_path}'" ) raise ValueError(msg) - return cached_filepath - return cached_path + return cached_filepath, version + return cached_path, version class DatasetColabCacheResolver(Resolver[DatasetHandle]): @@ -100,9 +106,10 @@ def is_supported(self, handle: DatasetHandle, *_, **__) -> bool: # noqa: ANN002 "dataset": handle.dataset, } - if handle.is_versioned(): + version = _get_dataset_version(handle) + if version: # Colab treats version as int in the request - data["version"] = handle.version # type: ignore + data["version"] = version # type: ignore try: api_client.post(data, ColabClient.IS_DATASET_SUPPORTED_PATH, handle) @@ -110,7 +117,9 @@ def is_supported(self, handle: DatasetHandle, *_, **__) -> bool: # noqa: ANN002 return False return True - def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: + def _resolve( + self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + ) -> tuple[str, Optional[int]]: if force_download: logger.info( "Ignoring `force_download` argument when running inside the Colab notebook environment.", @@ -122,9 +131,11 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo "owner": h.owner, "dataset": h.dataset, } - if h.is_versioned(): + + version = _get_dataset_version(h) + if version: # Colab treats version as int in the request - data["version"] = h.version # type: ignore + data["version"] = version # type: ignore response = api_client.post(data, ColabClient.DATASET_MOUNT_PATH, h) @@ -154,5 +165,27 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo f"You can access the other files of the attached dataset at '{cached_path}'" ) raise ValueError(msg) - return cached_filepath - return cached_path + return cached_filepath, version + return cached_path, version + + +def 
_get_model_version(h: ModelHandle) -> Optional[int]: + if h.is_versioned(): + return h.version + + version_from_package_scope = PackageScope.get_version(h) + if version_from_package_scope is not None: + return version_from_package_scope + + return None + + +def _get_dataset_version(h: DatasetHandle) -> Optional[int]: + if h.is_versioned(): + return h.version + + version_from_package_scope = PackageScope.get_version(h) + if version_from_package_scope is not None: + return version_from_package_scope + + return None diff --git a/src/kagglehub/competition.py b/src/kagglehub/competition.py index 7265c41..3050576 100644 --- a/src/kagglehub/competition.py +++ b/src/kagglehub/competition.py @@ -20,4 +20,5 @@ def competition_download(handle: str, path: Optional[str] = None, *, force_downl h = parse_competition_handle(handle) logger.info(f"Downloading competition: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK}) - return registry.competition_resolver(h, path, force_download=force_download) + path, _ = registry.competition_resolver(h, path, force_download=force_download) + return path diff --git a/src/kagglehub/datasets.py b/src/kagglehub/datasets.py index 8e630ba..3330dc9 100755 --- a/src/kagglehub/datasets.py +++ b/src/kagglehub/datasets.py @@ -37,7 +37,8 @@ def dataset_download(handle: str, path: Optional[str] = None, *, force_download: h = parse_dataset_handle(handle) logger.info(f"Downloading Dataset: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK}) - return registry.dataset_resolver(h, path, force_download=force_download) + path, _ = registry.dataset_resolver(h, path, force_download=force_download) + return path def dataset_upload( diff --git a/src/kagglehub/handle.py b/src/kagglehub/handle.py index d930f19..6fa7e49 100644 --- a/src/kagglehub/handle.py +++ b/src/kagglehub/handle.py @@ -12,11 +12,12 @@ NUM_VERSIONED_MODEL_PARTS = 5 # e.g.: //// NUM_UNVERSIONED_MODEL_PARTS = 4 # e.g.: /// +NUM_VERSIONED_NOTEBOOK_PARTS = 4 # e.g.: //versions/ 
NUM_UNVERSIONED_NOTEBOOK_PARTS = 2 # e.g.: / NUM_VERSIONED_NOTEBOOK_PARTS = 4 # e.g.: //versions/ -@dataclass +@dataclass(frozen=True) class ResourceHandle: @abc.abstractmethod def to_url(self) -> str: @@ -24,7 +25,7 @@ def to_url(self) -> str: pass -@dataclass +@dataclass(frozen=True) class ModelHandle(ResourceHandle): owner: str model: str @@ -35,6 +36,11 @@ class ModelHandle(ResourceHandle): def is_versioned(self) -> bool: return self.version is not None and self.version > 0 + def with_version(self, version: int) -> "ModelHandle": + return ModelHandle( + owner=self.owner, model=self.model, framework=self.framework, variation=self.variation, version=version + ) + def __str__(self) -> str: handle_str = f"{self.owner}/{self.model}/{self.framework}/{self.variation}" if self.is_versioned(): @@ -49,7 +55,7 @@ def to_url(self) -> str: return f"{endpoint}/models/{self.owner}/{self.model}/{self.framework}/{self.variation}" -@dataclass +@dataclass(frozen=True) class DatasetHandle(ResourceHandle): owner: str dataset: str @@ -58,6 +64,9 @@ class DatasetHandle(ResourceHandle): def is_versioned(self) -> bool: return self.version is not None and self.version > 0 + def with_version(self, version: int) -> "DatasetHandle": + return DatasetHandle(owner=self.owner, dataset=self.dataset, version=version) + def __str__(self) -> str: handle_str = f"{self.owner}/{self.dataset}" if self.is_versioned(): @@ -72,7 +81,7 @@ def to_url(self) -> str: return base_url -@dataclass +@dataclass(frozen=True) class CompetitionHandle(ResourceHandle): competition: str @@ -86,7 +95,7 @@ def to_url(self) -> str: return base_url -@dataclass +@dataclass(frozen=True) class NotebookHandle(ResourceHandle): owner: str notebook: str @@ -95,6 +104,9 @@ class NotebookHandle(ResourceHandle): def is_versioned(self) -> bool: return self.version is not None and self.version > 0 + def with_version(self, version: int) -> "NotebookHandle": + return NotebookHandle(owner=self.owner, notebook=self.notebook, 
version=version) + def __str__(self) -> str: handle_str = f"{self.owner}/{self.notebook}" if self.is_versioned(): @@ -113,6 +125,10 @@ class UtilityScriptHandle(NotebookHandle): pass +class PackageHandle(NotebookHandle): + pass + + def parse_dataset_handle(handle: str) -> DatasetHandle: parts = handle.split("/") @@ -217,3 +233,8 @@ def parse_notebook_handle(handle: str) -> NotebookHandle: def parse_utility_script_handle(handle: str) -> UtilityScriptHandle: notebook_handle = parse_notebook_handle(handle) return UtilityScriptHandle(**asdict(notebook_handle)) + + +def parse_package_handle(handle: str) -> PackageHandle: + notebook_handle = parse_notebook_handle(handle) + return PackageHandle(**asdict(notebook_handle)) diff --git a/src/kagglehub/http_resolver.py b/src/kagglehub/http_resolver.py index 6174ed8..f44a2a9 100644 --- a/src/kagglehub/http_resolver.py +++ b/src/kagglehub/http_resolver.py @@ -17,6 +17,7 @@ from kagglehub.clients import KaggleApiV1Client from kagglehub.exceptions import UnauthenticatedError from kagglehub.handle import CompetitionHandle, DatasetHandle, ModelHandle, NotebookHandle, ResourceHandle +from kagglehub.packages import PackageScope from kagglehub.resolver import Resolver DATASET_CURRENT_VERSION_FIELD = "currentVersionNumber" @@ -33,9 +34,9 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 # Downloading files over HTTP is supported in all environments for all handles / paths. 
return True - def __call__( + def _resolve( self, h: CompetitionHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False - ) -> str: + ) -> tuple[str, Optional[int]]: api_client = KaggleApiV1Client() cached_path = load_from_cache(h, path) @@ -45,7 +46,7 @@ def __call__( if not api_client.has_credentials(): if cached_path: - return cached_path + return cached_path, None raise UnauthenticatedError() out_path = get_cached_path(h, path) @@ -60,11 +61,11 @@ def __call__( ) except requests.exceptions.ConnectionError: if cached_path: - return cached_path + return cached_path, None raise if not download_needed and cached_path: - return cached_path + return cached_path, None else: # Download, extract, then delete the archive. url_path = _build_competition_download_all_url_path(h) @@ -77,19 +78,19 @@ def __call__( if cached_path: if os.path.exists(archive_path): os.remove(archive_path) - return cached_path + return cached_path, None raise if not download_needed and cached_path: if os.path.exists(archive_path): os.remove(archive_path) - return cached_path + return cached_path, None _extract_archive(archive_path, out_path) os.remove(archive_path) mark_as_complete(h, path) - return out_path + return out_path, None class DatasetHttpResolver(Resolver[DatasetHandle]): @@ -97,15 +98,17 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 # Downloading files over HTTP is supported in all environments for all handles / paths. 
return True - def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: + def _resolve( + self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + ) -> tuple[str, Optional[int]]: api_client = KaggleApiV1Client() if not h.is_versioned(): - h.version = _get_current_version(api_client, h) + h = h.with_version(_get_current_version(api_client, h)) dataset_path = load_from_cache(h, path) if dataset_path and not force_download: - return dataset_path # Already cached + return dataset_path, h.version # Already cached elif dataset_path and force_download: delete_from_cache(h, path) @@ -132,7 +135,7 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo os.remove(archive_path) mark_as_complete(h, path) - return out_path + return out_path, h.version class ModelHttpResolver(Resolver[ModelHandle]): @@ -140,15 +143,17 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 # Downloading files over HTTP is supported in all environments for all handles / path. 
return True - def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: + def _resolve( + self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + ) -> tuple[str, Optional[int]]: api_client = KaggleApiV1Client() if not h.is_versioned(): - h.version = _get_current_version(api_client, h) + h = h.with_version(_get_current_version(api_client, h)) model_path = load_from_cache(h, path) if model_path and not force_download: - return model_path # Already cached + return model_path, h.version # Already cached elif model_path and force_download: delete_from_cache(h, path) @@ -192,7 +197,7 @@ def _inner_download_file(file: str) -> None: ) mark_as_complete(h, path) - return out_path + return out_path, h.version class NotebookOutputHttpResolver(Resolver[NotebookHandle]): @@ -200,15 +205,17 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 # Downloading files over HTTP is supported in all environments for all handles / paths. 
return True - def __call__(self, h: NotebookHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: + def _resolve( + self, h: NotebookHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + ) -> tuple[str, Optional[int]]: api_client = KaggleApiV1Client() if not h.is_versioned(): - h.version = _get_current_version(api_client, h) + h = h.with_version(_get_current_version(api_client, h)) notebook_path = load_from_cache(h, path) if notebook_path and not force_download: - return notebook_path # Already cached + return notebook_path, h.version # Already cached elif notebook_path and force_download: delete_from_cache(h, path) @@ -233,7 +240,8 @@ def __call__(self, h: NotebookHandle, path: Optional[str] = None, *, force_downl os.remove(archive_path) mark_as_complete(h, path) - return out_path + + return out_path, h.version def _list_files(self, api_client: KaggleApiV1Client, h: NotebookHandle) -> tuple[list[str], bool]: query = f"kernels/output/list/{h.owner}/{h.notebook}?page_size={MAX_NUM_FILES_DIRECT_DOWNLOAD}" @@ -265,6 +273,11 @@ def _extract_archive(archive_path: str, out_path: str) -> None: def _get_current_version(api_client: KaggleApiV1Client, h: ResourceHandle) -> int: + # Check if there's a Package in scope which has stored a version number used when it was created. 
+ version_from_package_scope = PackageScope.get_version(h) + if version_from_package_scope is not None: + return version_from_package_scope + if isinstance(h, ModelHandle): json_response = api_client.get(_build_get_instance_url_path(h), h) if MODEL_INSTANCE_VERSION_FIELD not in json_response: diff --git a/src/kagglehub/kaggle_cache_resolver.py b/src/kagglehub/kaggle_cache_resolver.py index 1f537f6..52df8f1 100644 --- a/src/kagglehub/kaggle_cache_resolver.py +++ b/src/kagglehub/kaggle_cache_resolver.py @@ -12,6 +12,7 @@ from kagglehub.exceptions import BackendError from kagglehub.handle import CompetitionHandle, DatasetHandle, ModelHandle, NotebookHandle from kagglehub.logger import EXTRA_CONSOLE_BLOCK +from kagglehub.packages import PackageScope from kagglehub.resolver import Resolver KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME = "KAGGLE_CACHE_MOUNT_FOLDER" @@ -36,9 +37,9 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return True return False - def __call__( + def _resolve( self, h: CompetitionHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False - ) -> str: + ) -> tuple[str, Optional[int]]: client = KaggleJwtClient() if force_download: logger.info( @@ -86,8 +87,8 @@ def __call__( f"You can acces the other files othe attached competition at '{cached_path}'" ) raise ValueError(msg) - return cached_filepath - return cached_path + return cached_filepath, None + return cached_path, None class DatasetKaggleCacheResolver(Resolver[DatasetHandle]): @@ -100,7 +101,9 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return False - def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: + def _resolve( + self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + ) -> tuple[str, Optional[int]]: if force_download: logger.info( "Ignoring `force_download` argument when running inside the Kaggle notebook environment.", @@ -113,6 
+116,11 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo } if h.is_versioned(): dataset_ref["VersionNumber"] = str(h.version) + else: + # Check if there's a Package in scope which has stored a version number used when it was created. + version_from_package_scope = PackageScope.get_version(h) + if version_from_package_scope is not None: + dataset_ref["VersionNumber"] = str(version_from_package_scope) result = client.post( ATTACH_DATASOURCE_REQUEST_NAME, @@ -127,6 +135,7 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo base_mount_path = os.getenv(KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME, DEFAULT_KAGGLE_CACHE_MOUNT_FOLDER) cached_path = f"{base_mount_path}/{result['mountSlug']}" + version = result.get("versionNumber") # None if missing if not os.path.exists(cached_path): # Only print this if the dataset is not already mounted. @@ -153,8 +162,8 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo f"You can acces the other files othe attached dataset at '{cached_path}'" ) raise ValueError(msg) - return cached_filepath - return cached_path + return cached_filepath, version + return cached_path, version class ModelKaggleCacheResolver(Resolver[ModelHandle]): @@ -167,7 +176,9 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return False - def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: + def _resolve( + self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + ) -> tuple[str, Optional[int]]: if force_download: logger.info( "Ignoring `force_download` argument when running inside the Kaggle notebook environment.", @@ -182,6 +193,11 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download } if h.is_versioned(): model_ref["VersionNumber"] = str(h.version) + else: + # Check if there's a Package in scope which has stored a version number 
used when it was created. + version_from_package_scope = PackageScope.get_version(h) + if version_from_package_scope is not None: + model_ref["VersionNumber"] = str(version_from_package_scope) result = client.post( ATTACH_DATASOURCE_REQUEST_NAME, @@ -196,6 +212,7 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download base_mount_path = os.getenv(KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME, DEFAULT_KAGGLE_CACHE_MOUNT_FOLDER) cached_path = f"{base_mount_path}/{result['mountSlug']}" + version = result.get("versionNumber") # None if missing if not os.path.exists(cached_path): # Only print this if the model is not already mounted. @@ -222,8 +239,8 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download f"You can access the other files of the attached model at '{cached_path}'" ) raise ValueError(msg) - return cached_filepath - return cached_path + return cached_filepath, version + return cached_path, version class NotebookOutputKaggleCacheResolver(Resolver[NotebookHandle]): @@ -236,7 +253,9 @@ def is_supported(self, *_, **__) -> bool: # noqa: ANN002, ANN003 return False - def __call__(self, h: NotebookHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: + def _resolve( + self, h: NotebookHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False + ) -> tuple[str, Optional[int]]: if force_download: logger.info( "Ignoring `force_download` argument when running inside the Kaggle notebook environment.", @@ -249,6 +268,11 @@ def __call__(self, h: NotebookHandle, path: Optional[str] = None, *, force_downl } if h.is_versioned(): kernel_ref["VersionNumber"] = str(h.version) + else: + # Check if there's a Package in scope which has stored a version number used when it was created. 
+ version_from_package_scope = PackageScope.get_version(h) + if version_from_package_scope is not None: + kernel_ref["VersionNumber"] = str(version_from_package_scope) result = client.post( ATTACH_DATASOURCE_REQUEST_NAME, @@ -264,6 +288,7 @@ def __call__(self, h: NotebookHandle, path: Optional[str] = None, *, force_downl base_mount_path = os.getenv(KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME, DEFAULT_KAGGLE_CACHE_MOUNT_FOLDER) cached_path = f"{base_mount_path}/{result['mountSlug']}" + version = result.get("versionNumber") # None if missing if not os.path.exists(cached_path): # Only print this if the notebook output is not already mounted. @@ -290,5 +315,5 @@ def __call__(self, h: NotebookHandle, path: Optional[str] = None, *, force_downl f"You can access the other files of the attached notebook at '{cached_path}'" ) raise ValueError(msg) - return cached_filepath - return cached_path + return cached_filepath, version + return cached_path, version diff --git a/src/kagglehub/models.py b/src/kagglehub/models.py index 4fd0f53..3f2976f 100644 --- a/src/kagglehub/models.py +++ b/src/kagglehub/models.py @@ -34,7 +34,8 @@ def model_download( """ h = parse_model_handle(handle) logger.info(f"Downloading Model: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK}) - return registry.model_resolver(h, path, force_download=force_download) + path, _ = registry.model_resolver(h, path, force_download=force_download) + return path def model_upload( diff --git a/src/kagglehub/notebooks.py b/src/kagglehub/notebooks.py index 982a44d..d484a21 100644 --- a/src/kagglehub/notebooks.py +++ b/src/kagglehub/notebooks.py @@ -9,9 +9,7 @@ def notebook_output_download(handle: str, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: - """[WORK IN PROGRESS] - - Download notebook output files. + """Download notebook output files. Args: handle: (string) the notebook handle under https://kaggle.com/code. 
@@ -24,4 +22,5 @@ def notebook_output_download(handle: str, path: Optional[str] = None, *, force_d """ h = parse_notebook_handle(handle) logger.info(f"Downloading Notebook Output: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK}) - return registry.notebook_output_resolver(h, path, force_download=force_download) + path, _ = registry.notebook_output_resolver(h, path, force_download=force_download) + return path diff --git a/src/kagglehub/packages.py b/src/kagglehub/packages.py new file mode 100644 index 0000000..21556bd --- /dev/null +++ b/src/kagglehub/packages.py @@ -0,0 +1,195 @@ +import contextvars +import importlib +import inspect +import logging +import pathlib +import re +import sys +from functools import wraps +from types import ModuleType +from typing import Any, Callable, Optional + +from kagglehub import registry +from kagglehub.handle import ResourceHandle, parse_package_handle +from kagglehub.logger import EXTRA_CONSOLE_BLOCK +from kagglehub.requirements import VersionedDatasources, read_requirements + +logger = logging.getLogger(__name__) + +# Current version of the package format used by Kaggle Packages +PACKAGE_VERSION = "0.1.0" +# Name of module field referring to its package version +PACKAGE_VERSION_NAME = "__package_version__" + +# Expected name of the kagglehub requirements file +KAGGLEHUB_REQUIREMENTS_FILENAME = "kagglehub_requirements.yaml" + + +def package_import(handle: str, *, force_download: Optional[bool] = False) -> ModuleType: + """Download a Kaggle Package and import it. + + A Kaggle Package is a Kaggle Notebook which has exported code to a python package format. + + Args: + handle: (string) the notebook handle under https://kaggle.com/code. + force_download: (bool) Optional flag to force download notebook output, even if it's cached. + Returns: + The imported python package.
+ """ + h = parse_package_handle(handle) + + logger.info(f"Downloading Notebook Output for Package: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK}) + notebook_path, version = registry.notebook_output_resolver(h, path=None, force_download=force_download) + init_file_path = pathlib.Path(notebook_path) / "package" / "__init__.py" + if not init_file_path.exists(): + msg = f"Notebook '{h!s}' is not a Package, missing 'package/__init__.py' file." + raise ValueError(msg) + + # Unique module name based on handle + downloaded version + module_name = re.sub(r"[^a-zA-Z0-9_]", "_", f"kagglehub_package_{h.owner}_{h.notebook}_{version}") + + if module_name in sys.modules: + # If this module already exists and the user didn't re-download, just return it + if not force_download: + return sys.modules[module_name] + + # If already existing but re-downloaded, clear this module and any submodules and reload + logger.info( + f"Uninstalling existing package module {module_name} before re-installing.", extra={**EXTRA_CONSOLE_BLOCK} + ) + del sys.modules[module_name] + submodule_names = [name for name in sys.modules if name.startswith(f"{module_name}.")] + for name in submodule_names: + del sys.modules[name] + + spec = importlib.util.spec_from_file_location(module_name, init_file_path) + if spec is None or spec.loader is None: + msg = f"Could not load module from {init_file_path} as {module_name}." + raise ImportError(msg) + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + + return module + + +def get_package_asset_path(path: str) -> pathlib.Path: + """Returns a path referring to an asset file for use in Kaggle Packages. + + If within a PackageScope context, returns the path relative to it. This should be true + any time a Package has been imported, whether via `package_import` above or directly. 
+ Otherwise, assumes we're in an interactive Kaggle Notebook and creating a Package, + where package data should be written to a staging directory which then gets copied + into the exported package when the Notebook is saved.""" + scope = PackageScope.get() + + assets_dir = scope.path / "assets" if scope else pathlib.Path("/kaggle/package_assets") + assets_dir.mkdir(parents=True, exist_ok=True) + + return assets_dir / path + + +def import_submodules(package_module: ModuleType) -> list[str]: + """Complete the import of a Kaggle Package's python module by importing all submodules. + + Only intended for use by Kaggle auto-generated package __init__.py files. + + Imports all (non-underscore-prefixed) sibling .py files as submodules, scopes members + from their `__all__` onto the parent module (similar to `from X import *`), and decorates + them with our PackageScope (see that class for more details). + + Args: + package_module: (ModuleType) The python module of the Kaggle Package. + Returns: + The names of all public members which we scoped onto the module (similar to `__all__`). + """ + package_version = getattr(package_module, PACKAGE_VERSION_NAME, None) + if package_version != PACKAGE_VERSION: + msg = f"Unsupported Kaggle Package version: {package_version}" + raise ValueError(msg) + + all_names: set[str] = set() + with PackageScope(package_module) as package_scope: + for filepath in package_scope.path.glob("[!_]*.py"): + submodule = importlib.import_module(f".{filepath.stem}", package=package_module.__name__) + package_scope.apply_to_module(submodule) + for name in submodule.__all__: + setattr(package_module, name, getattr(submodule, name)) + all_names.update(submodule.__all__) + + return sorted(all_names) + + +class PackageScope: + """Captures data about a Kaggle Package. Use as Context Manager to apply scope. + + Only intended for use by Kaggle auto-generated package __init__.py files. 
+ + When scope is applied, certain `kagglehub` calls will utilize the Package info. + Specifically, downloading a datasource without a version specified will check + the PackageScope and use the Package's version if it finds a matching entry. + `kagglehub.get_package_asset_path` also relies on the current scope to pull + assets related to a package. + """ + + _ctx = contextvars.ContextVar("kagglehub_package_scope", default=None) + + def __init__(self, package_module: ModuleType): + """Only intended for use by Kaggle auto-generated package __init__.py files.""" + if not package_module.__file__: + msg = f"Package module '{package_module.__name__}' missing '__file__'." + raise Exception(msg) + + self.path: pathlib.Path = pathlib.Path(package_module.__file__).parent + self.datasources: VersionedDatasources = read_requirements(self.path / KAGGLEHUB_REQUIREMENTS_FILENAME) + + self._token_stack: list[contextvars.Token] = [] + + def __enter__(self): + token = PackageScope._ctx.set(self) + self._token_stack.append(token) + return self + + def __exit__(self, exc_type, exc_value, traceback): # noqa: ANN001 + token = self._token_stack.pop() + PackageScope._ctx.reset(token) + + @staticmethod + def get() -> Optional["PackageScope"]: + """Gets the currently applied PackageScope, or None if none applied.""" + return PackageScope._ctx.get() + + @staticmethod + def get_version(h: ResourceHandle) -> Optional[int]: + """Gets version number for given resource within current PackageScope (if any). 
+ + Returns None if no PackageScope is applied, or if it didn't contain the resource.""" + scope = PackageScope.get() + + return scope.datasources.get(h) if scope else None + + def apply_to_module(self, module: ModuleType) -> None: + """Decorates all functions/methods in the module to apply our scope.""" + + def decorate(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): # noqa: ANN002, ANN003, ANN202 + with self: + return func(*args, **kwargs) + + return wrapper + + stack: list[Any] = [module] + while stack: + obj = stack.pop() + for name, member in inspect.getmembers(obj): + # Only decorate things defined within the module. + if getattr(member, "__module__", None) != module.__name__: + continue + # Recurse on a class to decorate its functions / methods. + # These denylisted entries cause infinite loops. + elif inspect.isclass(member) and name not in ["__base__", "__class__"]: + stack.append(member) + elif inspect.isfunction(member) or inspect.ismethod(member): + setattr(obj, name, decorate(member)) diff --git a/src/kagglehub/registry.py b/src/kagglehub/registry.py index 0d31d66..5009834 100644 --- a/src/kagglehub/registry.py +++ b/src/kagglehub/registry.py @@ -1,4 +1,4 @@ -from typing import Generic, TypeVar +from typing import Generic, Optional, TypeVar from kagglehub.handle import CompetitionHandle, DatasetHandle, ModelHandle, NotebookHandle, ResourceHandle from kagglehub.resolver import Resolver @@ -21,7 +21,7 @@ def __init__(self, name: str) -> None: def add_implementation(self, impl: Resolver[T]) -> None: self._impls.append(impl) - def __call__(self, *args, **kwargs) -> str: # noqa: ANN002, ANN003 + def __call__(self, *args, **kwargs) -> tuple[str, Optional[int]]: # noqa: ANN002, ANN003 fails = [] for impl in reversed(self._impls): if impl.is_supported(*args, **kwargs): diff --git a/src/kagglehub/requirements.py b/src/kagglehub/requirements.py new file mode 100644 index 0000000..3b1bc1c --- /dev/null +++ 
b/src/kagglehub/requirements.py @@ -0,0 +1,128 @@ +import pathlib +from typing import Any, Optional, Union + +import yaml + +from kagglehub.handle import ( + CompetitionHandle, + DatasetHandle, + ModelHandle, + NotebookHandle, + PackageHandle, + ResourceHandle, + UtilityScriptHandle, + parse_competition_handle, + parse_dataset_handle, + parse_model_handle, + parse_notebook_handle, + parse_package_handle, + parse_utility_script_handle, +) + +# Current version of the file format written here +FORMAT_VERSION = "0.1.0" + +FORMAT_VERSION_FIELD = "format_version" +DATASOURCES_FIELD = "datasources" +DATASOURCE_TYPE_FIELD = "type" +DATASOURCE_REF_FIELD = "ref" +DATASOURCE_VERSION_FIELD = "version" + +HANDLE_TYPE_NAMES = { + CompetitionHandle: "Competition", + DatasetHandle: "Dataset", + ModelHandle: "Model", + NotebookHandle: "Notebook", + UtilityScriptHandle: "UtilityScript", + PackageHandle: "Package", +} + +HANDLE_TYPE_PARSERS = { + HANDLE_TYPE_NAMES[CompetitionHandle]: parse_competition_handle, + HANDLE_TYPE_NAMES[DatasetHandle]: parse_dataset_handle, + HANDLE_TYPE_NAMES[ModelHandle]: parse_model_handle, + HANDLE_TYPE_NAMES[NotebookHandle]: parse_notebook_handle, + HANDLE_TYPE_NAMES[UtilityScriptHandle]: parse_utility_script_handle, + HANDLE_TYPE_NAMES[PackageHandle]: parse_package_handle, +} + +# Maps requested ResourceHandle (which may include version) to version used +VersionedDatasources = dict[ResourceHandle, Optional[int]] + +# Tracks datasources accessed in the current session +_accessed_datasources: VersionedDatasources = {} + + +def register_accessed_datasource(handle: ResourceHandle, version: Optional[int]) -> None: + """Record that a datasource was accessed. 
+ + Link the user-requested handle to the version retrieved.""" + _accessed_datasources[handle] = version + + +def get_accessed_datasources() -> VersionedDatasources: + return _accessed_datasources.copy() + + +def write_requirements(filepath: str) -> None: + """Write the datasources accessed during this session to a yaml file.""" + data = { + FORMAT_VERSION_FIELD: FORMAT_VERSION, + DATASOURCES_FIELD: [_serialize_datasource(h, version) for h, version in _accessed_datasources.items()], + } + + with open(filepath, "w") as f: + yaml.dump(data, f, sort_keys=False) + + +def read_requirements(filepath: Union[str, pathlib.Path]) -> VersionedDatasources: + """Read a yaml file with datasource + version records.""" + with open(filepath) as f: + data = yaml.safe_load(f) + + format_version = data.get(FORMAT_VERSION_FIELD) + if format_version != FORMAT_VERSION: + msg = f"Unsupported requirements format version: {format_version}" + raise ValueError(msg) + + versioned_datasources: VersionedDatasources = {} + for datasource in data.get(DATASOURCES_FIELD, []): + h, version = _deserialize_datasource(datasource) + versioned_datasources[h] = version + + return versioned_datasources + + +def _serialize_datasource(h: ResourceHandle, version: Optional[int]) -> dict: + data: dict[str, Any] = { + DATASOURCE_TYPE_FIELD: HANDLE_TYPE_NAMES[type(h)], + DATASOURCE_REF_FIELD: str(h), + } + + if version is not None: + data[DATASOURCE_VERSION_FIELD] = version + + return data + + +def _deserialize_datasource(data: dict) -> tuple[ResourceHandle, Optional[int]]: + parser = HANDLE_TYPE_PARSERS[data[DATASOURCE_TYPE_FIELD]] + h = parser(data[DATASOURCE_REF_FIELD]) + version = _parse_version(data.get(DATASOURCE_VERSION_FIELD, None)) + + return h, version + + +def _parse_version(version: Any) -> Optional[int]: # noqa: ANN401 + if version is None or isinstance(version, int): + return version + + if isinstance(version, str): + try: + return int(version) + except: # noqa: E722, S110 + # Fall-through to the 
raise below + pass + + msg = f"Invalid version: '{version}'. Expected an integer or None." + raise ValueError(msg) diff --git a/src/kagglehub/resolver.py b/src/kagglehub/resolver.py index 1886e09..d92e9c2 100644 --- a/src/kagglehub/resolver.py +++ b/src/kagglehub/resolver.py @@ -1,7 +1,10 @@ import abc from typing import Generic, Optional, TypeVar -T = TypeVar("T") +from kagglehub.handle import ResourceHandle +from kagglehub.requirements import register_accessed_datasource + +T = TypeVar("T", bound=ResourceHandle) class Resolver(Generic[T]): @@ -9,8 +12,9 @@ class Resolver(Generic[T]): __metaclass__ = abc.ABCMeta - @abc.abstractmethod - def __call__(self, handle: T, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str: + def __call__( + self, handle: T, path: Optional[str] = None, *, force_download: Optional[bool] = False + ) -> tuple[str, Optional[int]]: """Resolves a handle into a path with the requested model files. Args: @@ -18,9 +22,29 @@ def __call__(self, handle: T, path: Optional[str] = None, *, force_download: Opt path: (string) Optional path to a file within the model bundle. force_download: (bool) Optional flag to force download a model, even if it's cached. + Returns: + String representing the path. + """ + path, version = self._resolve(handle, path, force_download=force_download) + + # Note handles are immutable, so resolve() could not have altered our reference + register_accessed_datasource(handle, version) + + return path, version + + @abc.abstractmethod + def _resolve( + self, handle: T, path: Optional[str] = None, *, force_download: Optional[bool] = False + ) -> tuple[str, Optional[int]]: + """Resolves a handle into a path with the requested model files. + + Args: + handle: (string) the model handle to resolve. + path: (string) Optional path to a file within the model bundle. + force_download: (bool) Optional flag to force download a model, even if it's cached. 
Returns: - A string representing the path + A tuple of: 1) string representing the path 2) version number of resolved datasource, if applicable. """ pass diff --git a/src/kagglehub/utility_scripts.py b/src/kagglehub/utility_scripts.py index d4e80ed..8112935 100644 --- a/src/kagglehub/utility_scripts.py +++ b/src/kagglehub/utility_scripts.py @@ -27,7 +27,7 @@ def utility_script_install(handle: str, *, force_download: Optional[bool] = Fals h = parse_utility_script_handle(handle) logger.info(f"Downloading Utility Script: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK}) - utility_script_path = registry.notebook_output_resolver(h, path=None, force_download=force_download) + utility_script_path, _ = registry.notebook_output_resolver(h, path=None, force_download=force_download) if not _is_notebook_utility_script(h): logger.info( diff --git a/tests/data/package-v1.zip b/tests/data/package-v1.zip new file mode 100644 index 0000000000000000000000000000000000000000..8e0fba0f655568073ade40e425d661085aa52df6 GIT binary patch literal 722 zcmWIWW@Zs#;A7xm__E46iUA1-1K9+$DojKp#AHE?p!2E*o72|;Ck4~M@J)y4^cmk;4yysaB-yjd2u#Wc6J9a!O{v7O8 z93+|}`slG@*5k*fj~;&(Dh>uRc0PJ6TwHWOP08e?qpr@i<}3|8-&3r!gVr#b`c39$ z2=Hd-Q2F#xq!8$y=|CKS)5ZKqE>6xbN(H($IIMt&_hWha*J9oxp2E_=kb;llA3pO3 z7LLY*E6#Xddi*FTEFd)0 zLr3>yxWR+2e*ZHkJ3(IV2=hOEKD@uxEBX4Fsgq_cn=ozu`h}TCmo8YaAYe*VP}G#= z3+7FUSaOaP?1LXE3#waz?wfzW&kuU&^05w4CHYbSkhR+h|r8D7y`Um S*+80@f$%(#ehbvbzyJWR((0-J literal 0 HcmV?d00001 diff --git a/tests/data/package-v2.zip b/tests/data/package-v2.zip new file mode 100644 index 0000000000000000000000000000000000000000..eeaf5c33067f66821dd3c5468af5a9c409ef3bd5 GIT binary patch literal 722 zcmWIWW@Zs#;A7xmSiQ+$DojKp#AHE?p!2E*o72|;Ck4~M@J)y4^cmk;4yysaB-yjd2u#Wc6J9a!O{v7O8 z93+|}`slG@*5k*fj~;&(Dh>uRc0PJ6TwHWOP08e?qpr@i<}3|8-&3r!gVr#b`c39$ z2=Hd-Q2F#xq!8$y=|CI+ckvoD7xN>zI61#473kXFumT?5kLBfGi+PK93QGe+3OTW`Le#A*EM}Vy;C|G7f`R1`oRW{m-221bMk5%>VTH@cve>>Fav=>23?^=zGcpwm3}raI*%`ST(T4j7%cT zi0DQRR#0@qz>-D~3rpk&cq0q|g+8*cK%ox 
ResponseReturnValue: @@ -36,6 +70,11 @@ def error(e: Exception): # noqa: ANN201 @app.route("/kaggle-jwt-handler/AttachDatasourceUsingJwtRequest", methods=["POST"]) def attach_datasource_using_jwt_request() -> ResponseReturnValue: + cache_mount_folder = os.getenv(KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME) + if not cache_mount_folder: + msg = f"Missing envvar '{KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME}'" + raise ValueError(msg) + data = request.get_json() if "modelRef" in data: model_ref = data["modelRef"] @@ -44,17 +83,17 @@ def attach_datasource_using_jwt_request() -> ResponseReturnValue: version_number = model_ref["VersionNumber"] mount_slug = f"{model_ref['ModelSlug']}/{model_ref['Framework']}/{model_ref['InstanceSlug']}/{version_number}" # Load the files - cache_mount_folder = os.getenv(KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME) - base_path = f"{cache_mount_folder}/{mount_slug}" - os.makedirs(base_path, exist_ok=True) - Path(f"{base_path}/config.json").touch() + base_path = Path(cache_mount_folder) / mount_slug + base_path.mkdir(parents=True, exist_ok=True) + (base_path / "config.json").touch() if version_number == LATEST_MODEL_VERSION: # The latest version has an extra file. 
- Path(f"{base_path}/model.keras").touch() + (base_path / "model.keras").touch() data = { "wasSuccessful": True, "result": { "mountSlug": mount_slug, + "versionNumber": version_number, }, } return jsonify(data), 200 @@ -65,17 +104,17 @@ def attach_datasource_using_jwt_request() -> ResponseReturnValue: version_number = dataset_ref["VersionNumber"] mount_slug = f"{dataset_ref['DatasetSlug']}" # Load the files - cache_mount_folder = os.getenv(KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME) - base_path = f"{cache_mount_folder}/{mount_slug}" - os.makedirs(base_path, exist_ok=True) - Path(f"{base_path}/foo.txt").touch() + base_path = Path(cache_mount_folder) / mount_slug + base_path.mkdir(parents=True, exist_ok=True) + (base_path / "foo.txt").touch() if version_number == LATEST_DATASET_VERSION: # The latest version has an extra file. - Path(f"{base_path}/bar.csv").touch() + (base_path / "bar.csv").touch() data = { "wasSuccessful": True, "result": { "mountSlug": mount_slug, + "versionNumber": version_number, }, } return jsonify(data), 200 @@ -83,11 +122,10 @@ def attach_datasource_using_jwt_request() -> ResponseReturnValue: competition_ref = data["competitionRef"] mount_slug = f"{competition_ref['CompetitionSlug']}" # Load the files - cache_mount_folder = os.getenv(KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME) - base_path = f"{cache_mount_folder}/{mount_slug}" - os.makedirs(base_path, exist_ok=True) - Path(f"{base_path}/foo.txt").touch() - Path(f"{base_path}/bar.csv").touch() + base_path = Path(cache_mount_folder) / mount_slug + base_path.mkdir(parents=True, exist_ok=True) + (base_path / "foo.txt").touch() + (base_path / "bar.csv").touch() data = { "wasSuccessful": True, "result": { @@ -102,17 +140,21 @@ def attach_datasource_using_jwt_request() -> ResponseReturnValue: version_number = kernel_ref["VersionNumber"] mount_slug = f"{kernel_ref['KernelSlug']}" # Load the files - cache_mount_folder = os.getenv(KAGGLE_CACHE_MOUNT_FOLDER_ENV_VAR_NAME) - base_path = 
f"{cache_mount_folder}/{mount_slug}" - os.makedirs(base_path, exist_ok=True) - Path(f"{base_path}/foo.txt").touch() - if version_number == LATEST_KERNEL_VERSION: - # The latest version has an extra file. - Path(f"{base_path}/bar.csv").touch() + base_path = Path(cache_mount_folder) / mount_slug + base_path.mkdir(parents=True, exist_ok=True) + latest_version = version_number == LATEST_KERNEL_VERSION + if mount_slug == "test-package": + _write_package_files(base_path, latest_version) + else: + (base_path / "foo.txt").touch() + if version_number == LATEST_KERNEL_VERSION: + # The latest version has an extra file. + (base_path / "bar.csv").touch() data = { "wasSuccessful": True, "result": { "mountSlug": mount_slug, + "versionNumber": version_number, }, } return jsonify(data), 200 @@ -120,6 +162,21 @@ def attach_datasource_using_jwt_request() -> ResponseReturnValue: return jsonify(data), 500 +def _write_package_files(base_path: Path, latest_version: bool) -> None: # noqa: FBT001 + package_path = base_path / "package" + package_path.mkdir(parents=True, exist_ok=True) + + (package_path / "__init__.py").write_text(PACKAGE_INIT_PY_TEXT) + (package_path / "kagglehub_requirements.yaml").write_text(PACKAGE_REQUIREMENTS_YAML_TEXT) + (package_path / "foo.py").write_text(PACKAGE_FOO_PY_TEXT) + + if latest_version: + (package_path / "bar.py").write_text(PACKAGE_BAR_PY_TEXT) + assets_path = package_path / "assets" + assets_path.mkdir(parents=True, exist_ok=True) + (package_path / "assets" / "asset.txt").write_text(PACKAGE_ASSET_TEXT) + + @contextmanager def create_env() -> Generator[Any, Any, Any]: with TemporaryDirectory() as cache_mount_folder: diff --git a/tests/server_stubs/notebook_output_download_stub.py b/tests/server_stubs/notebook_output_download_stub.py index 67b865e..0dc488c 100644 --- a/tests/server_stubs/notebook_output_download_stub.py +++ b/tests/server_stubs/notebook_output_download_stub.py @@ -37,8 +37,10 @@ def notebook_output_download(owner_slug: str, kernel_slug: 
str) -> ResponseRetur # First, determine if we're fetching a file or the whole notebook output file_name_query_param = request.args.get("file_path") - test_file_name = "foo.txt.zip" - if file_name_query_param: + if kernel_slug == "package-test": + version = request.args.get("version_number", type=int) + test_file_name = f"package-v{version}.zip" + elif file_name_query_param: # This mimics behavior for our file downloads, where users request a file, but # receive a zipped version of the file from GCS. test_file_name = ( @@ -46,6 +48,8 @@ def notebook_output_download(owner_slug: str, kernel_slug: str) -> ResponseRetur if file_name_query_param == AUTO_COMPRESSED_FILE_NAME else file_name_query_param ) + else: + test_file_name = "foo.txt.zip" return get_gcs_redirect_response(test_file_name) diff --git a/tests/test_http_package_import.py b/tests/test_http_package_import.py new file mode 100644 index 0000000..fbf61f6 --- /dev/null +++ b/tests/test_http_package_import.py @@ -0,0 +1,59 @@ +import os +import sys + +import kagglehub +from kagglehub.cache import NOTEBOOKS_CACHE_SUBFOLDER, get_cached_archive_path +from kagglehub.handle import parse_package_handle +from tests.fixtures import BaseTestCase + +from .server_stubs import notebook_output_download_stub as stub +from .server_stubs import serv +from .utils import create_test_cache + +INVALID_ARCHIVE_PACKAGE_HANDLE = "invalid/invalid/invalid/invalid/invalid" +VERSIONED_PACKAGE_HANDLE = "dster/package-test/versions/1" +UNVERSIONED_PACKAGE_HANDLE = "dster/package-test" + +EXPECTED_NOTEBOOK_SUBDIR = os.path.join(NOTEBOOKS_CACHE_SUBFOLDER, "dster", "package-test", "output", "versions", "1") + + +class TestHttpPackageImport(BaseTestCase): + + def tearDown(self) -> None: + # Clear any imported packages from sys.modules. 
+ for name in list(sys.modules.keys()): + if name.startswith("kagglehub_package"): + del sys.modules[name] + + @classmethod + def setUpClass(cls): + cls.server = serv.start_server(stub.app) + + @classmethod + def tearDownClass(cls): + cls.server.shutdown() + + def test_package_versioned_succeeds(self) -> None: + with create_test_cache(): + package = kagglehub.package_import(VERSIONED_PACKAGE_HANDLE) + + self.assertIn("foo", dir(package)) + self.assertEqual("bar", package.foo()) + + archive_path = get_cached_archive_path(parse_package_handle(VERSIONED_PACKAGE_HANDLE)) + self.assertFalse(os.path.exists(archive_path)) + + def test_package_unversioned_succeeds(self) -> None: + with create_test_cache(): + package = kagglehub.package_import(UNVERSIONED_PACKAGE_HANDLE) + + self.assertIn("foo", dir(package)) + self.assertEqual("baz", package.foo()) + + archive_path = get_cached_archive_path(parse_package_handle(UNVERSIONED_PACKAGE_HANDLE)) + self.assertFalse(os.path.exists(archive_path)) + + def test_notebook_download_bad_archive(self) -> None: + with create_test_cache(): + with self.assertRaises(ValueError): + kagglehub.package_import(INVALID_ARCHIVE_PACKAGE_HANDLE) diff --git a/tests/test_http_requirements.py b/tests/test_http_requirements.py new file mode 100644 index 0000000..77a819f --- /dev/null +++ b/tests/test_http_requirements.py @@ -0,0 +1,54 @@ +from pathlib import Path +from tempfile import TemporaryDirectory + +import kagglehub +from kagglehub.handle import parse_model_handle +from kagglehub.requirements import read_requirements, write_requirements +from tests.fixtures import BaseTestCase + +from .server_stubs import model_download_stub as stub +from .server_stubs import serv +from .utils import create_test_cache + +UNVERSIONED_DATASET_HANDLE = "sarahjeffreson/featured-spotify-artiststracks-with-metadata" +VERSIONED_MODEL_HANDLE = "metaresearch/llama-2/pyTorch/13b/3" +UNVERSIONED_MODEL_HANDLE = "metaresearch/llama-2/pyTorch/13b" + + +class 
TestHttpRequirements(BaseTestCase): + + def setUp(self) -> None: + # Clear out our `requirements` tracking between tests + kagglehub.requirements._accessed_datasources = {} + + @classmethod + def setUpClass(cls): + cls.server = serv.start_server(stub.app) + + @classmethod + def tearDownClass(cls): + cls.server.shutdown() + + def test_two_models(self) -> None: + with create_test_cache(): + kagglehub.model_download(UNVERSIONED_MODEL_HANDLE) + kagglehub.model_download(VERSIONED_MODEL_HANDLE) + + with TemporaryDirectory() as d: + requirements_path = str(Path(d) / "requirements.yaml") + write_requirements(requirements_path) + datasources = read_requirements(requirements_path) + + self.assertEqual(2, len(datasources)) + # Check the versions of each accessed datasource + self.assertEqual(3, datasources[parse_model_handle(UNVERSIONED_MODEL_HANDLE)]) + self.assertEqual(3, datasources[parse_model_handle(VERSIONED_MODEL_HANDLE)]) + + def test_no_datasources(self) -> None: + with create_test_cache(): + with TemporaryDirectory() as d: + requirements_path = str(Path(d) / "requirements.yaml") + write_requirements(requirements_path) + datasources = read_requirements(requirements_path) + + self.assertEqual(0, len(datasources)) diff --git a/tests/test_kaggle_cache_package_import.py b/tests/test_kaggle_cache_package_import.py new file mode 100644 index 0000000..46811fc --- /dev/null +++ b/tests/test_kaggle_cache_package_import.py @@ -0,0 +1,78 @@ +import os +import sys +from unittest import mock + +import requests + +import kagglehub +from kagglehub.config import DISABLE_KAGGLE_CACHE_ENV_VAR_NAME +from kagglehub.env import KAGGLE_DATA_PROXY_URL_ENV_VAR_NAME +from tests.fixtures import BaseTestCase + +from .server_stubs import jwt_stub as stub +from .server_stubs import serv + +INVALID_ARCHIVE_PACKAGE_HANDLE = "invalid/invalid/invalid/invalid/invalid" +VERSIONED_PACKAGE_HANDLE = "alexisbcook/test-package/versions/1" +UNVERSIONED_PACKAGE_HANDLE = "alexisbcook/test-package" + + +# Test 
cases for package_import and get_package_asset_path. +class TestKaggleCachePackageImport(BaseTestCase): + + def tearDown(self) -> None: + # Clear any imported packages from sys.modules. + for name in list(sys.modules.keys()): + if name.startswith("kagglehub_package"): + del sys.modules[name] + + @classmethod + def setUpClass(cls): + cls.server = serv.start_server(stub.app, KAGGLE_DATA_PROXY_URL_ENV_VAR_NAME, "http://localhost:7778") + + @classmethod + def tearDownClass(cls): + cls.server.shutdown() + + def test_unversioned_package_import(self) -> None: + with stub.create_env(): + package = kagglehub.package_import(UNVERSIONED_PACKAGE_HANDLE) + self.assertEqual("kaggle", package.foo()) + self.assertEqual("abcd", package.bar()) + + def test_versioned_package_import(self) -> None: + with stub.create_env(): + package = kagglehub.package_import(VERSIONED_PACKAGE_HANDLE) + self.assertEqual("kaggle", package.foo()) + self.assertFalse(hasattr(package, "bar")) + + def test_kaggle_resolver_skipped(self) -> None: + with mock.patch.dict(os.environ, {DISABLE_KAGGLE_CACHE_ENV_VAR_NAME: "true"}): + with stub.create_env(): + # Assert that a ConnectionError is set (uses HTTP server which is not set) + with self.assertRaises(requests.exceptions.ConnectionError): + kagglehub.package_import(UNVERSIONED_PACKAGE_HANDLE) + + def test_versioned_package_import_bad_handle_raises(self) -> None: + with self.assertRaises(ValueError): + kagglehub.package_import("bad handle") + + def test_versioned_package_import_returns_same(self) -> None: + with stub.create_env(): + # Importing the same package a second time returns the same exact package. + package = kagglehub.package_import(VERSIONED_PACKAGE_HANDLE) + package2 = kagglehub.package_import(VERSIONED_PACKAGE_HANDLE) + self.assertEqual(package, package2) + + def test_versioned_package_import_force_download_returns_different(self) -> None: + with stub.create_env(): + # Re-importing with force_download re-installs anew.
+ package = kagglehub.package_import(VERSIONED_PACKAGE_HANDLE) + package_forced = kagglehub.package_import(VERSIONED_PACKAGE_HANDLE, force_download=True) + self.assertNotEqual(package, package_forced) + + def test_versioned_package_import_with_force_download_explicitly_false(self) -> None: + with stub.create_env(): + package = kagglehub.package_import(VERSIONED_PACKAGE_HANDLE, force_download=False) + self.assertEqual("kaggle", package.foo()) + self.assertFalse(hasattr(package, "bar")) diff --git a/tests/test_kaggle_cache_requirements.py b/tests/test_kaggle_cache_requirements.py new file mode 100644 index 0000000..7473d60 --- /dev/null +++ b/tests/test_kaggle_cache_requirements.py @@ -0,0 +1,70 @@ +from pathlib import Path +from tempfile import TemporaryDirectory + +import kagglehub +from kagglehub.env import KAGGLE_DATA_PROXY_URL_ENV_VAR_NAME +from kagglehub.handle import parse_dataset_handle, parse_model_handle +from kagglehub.requirements import read_requirements, write_requirements +from tests.fixtures import BaseTestCase + +from .server_stubs import jwt_stub as stub +from .server_stubs import serv + +UNVERSIONED_DATASET_HANDLE = "sarahjeffreson/featured-spotify-artiststracks-with-metadata" +VERSIONED_MODEL_HANDLE = "metaresearch/llama-2/pyTorch/13b/1" +UNVERSIONED_MODEL_HANDLE = "metaresearch/llama-2/pyTorch/13b" + + +# Test cases for requirements.py submodule. 
+class TestKaggleCacheRequirements(BaseTestCase): + + def setUp(self) -> None: + # Clear out our `requirements` tracking between tests + kagglehub.requirements._accessed_datasources = {} + + @classmethod + def setUpClass(cls): + cls.server = serv.start_server(stub.app, KAGGLE_DATA_PROXY_URL_ENV_VAR_NAME, "http://localhost:7778") + + @classmethod + def tearDownClass(cls): + cls.server.shutdown() + + def test_dataset_and_model(self) -> None: + with stub.create_env(): + kagglehub.dataset_download(UNVERSIONED_DATASET_HANDLE) + kagglehub.model_download(VERSIONED_MODEL_HANDLE) + + with TemporaryDirectory() as d: + requirements_path = str(Path(d) / "requirements.yaml") + write_requirements(requirements_path) + datasources = read_requirements(requirements_path) + + self.assertEqual(2, len(datasources)) + # Check the versions of each accessed datasource + self.assertEqual(2, datasources[parse_dataset_handle(UNVERSIONED_DATASET_HANDLE)]) + self.assertEqual(1, datasources[parse_model_handle(VERSIONED_MODEL_HANDLE)]) + + def test_two_models(self) -> None: + with stub.create_env(): + kagglehub.model_download(UNVERSIONED_MODEL_HANDLE) + kagglehub.model_download(VERSIONED_MODEL_HANDLE) + + with TemporaryDirectory() as d: + requirements_path = str(Path(d) / "requirements.yaml") + write_requirements(requirements_path) + datasources = read_requirements(requirements_path) + + self.assertEqual(2, len(datasources)) + # Check the versions of each accessed datasource + self.assertEqual(2, datasources[parse_model_handle(UNVERSIONED_MODEL_HANDLE)]) + self.assertEqual(1, datasources[parse_model_handle(VERSIONED_MODEL_HANDLE)]) + + def test_no_datasources(self) -> None: + with stub.create_env(): + with TemporaryDirectory() as d: + requirements_path = str(Path(d) / "requirements.yaml") + write_requirements(requirements_path) + datasources = read_requirements(requirements_path) + + self.assertEqual(0, len(datasources)) diff --git a/tests/test_registry.py b/tests/test_registry.py index 
458b29f..88b3ef4 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -1,11 +1,11 @@ -from typing import Any, Callable +from typing import Any, Callable, Optional from kagglehub import registry from kagglehub.handle import ResourceHandle from kagglehub.resolver import Resolver from tests.fixtures import BaseTestCase -SOME_VALUE: str = "test" +SOME_VALUE: tuple[str, Optional[int]] = ("test", 1) class FakeHandle(ResourceHandle): @@ -17,19 +17,19 @@ class FakeImpl(Resolver[FakeHandle]): def __init__( self, is_supported_fn: Callable[[FakeHandle], bool], - call_fn: Callable[[FakeHandle], str], + resolve_fn: Callable[[FakeHandle], tuple[str, Optional[int]]], ): self._is_supported_fn = is_supported_fn - self._call_fn = call_fn + self._resolve_fn = resolve_fn def is_supported(self, *args: Any, **kwargs: Any) -> bool: # noqa: ANN401 return self._is_supported_fn(*args, **kwargs) - def __call__(self, *args: Any, **kwargs: Any) -> str: # noqa: ANN401 - return self._call_fn(*args, **kwargs) + def _resolve(self, *args: Any, **kwargs: Any) -> tuple[str, Optional[int]]: # noqa: ANN401 + return self._resolve_fn(*args, **kwargs) -def fail_fn(*_, **__) -> str: # noqa: ANN002, ANN003 +def fail_fn(*_, **__) -> tuple[str, Optional[int]]: # noqa: ANN002, ANN003 msg = "fail_fn should not be called" raise AssertionError(msg)