Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MVP support for Kaggle Packages #196

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions integration_tests/test_package_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import sys
import unittest

from kagglehub import package_import

from .utils import create_test_cache, unauthenticated

UNVERSIONED_HANDLE = "dster/package-test"
VERSIONED_HANDLE = "dster/package-test/versions/1"


class TestPackageImport(unittest.TestCase):

def tearDown(self) -> None:
# Clear any imported packages from sys.modules.
for name in list(sys.modules.keys()):
if name.startswith("kagglehub_package"):
del sys.modules[name]

def test_package_versioned_succeeds(self) -> None:
with create_test_cache():
package = package_import(VERSIONED_HANDLE)

self.assertIn("foo", dir(package))
self.assertEqual("bar", package.foo())

def test_package_unversioned_succeeds(self) -> None:
with create_test_cache():
package = package_import(UNVERSIONED_HANDLE)

self.assertIn("foo", dir(package))
self.assertEqual("baz", package.foo())

def test_download_private_package_succeeds(self) -> None:
with create_test_cache():
package = package_import("integrationtester/kagglehub-test-private-package")

self.assertIn("foo", dir(package))
self.assertEqual("bar", package.foo())

def test_public_package_with_unauthenticated_succeeds(self) -> None:
with create_test_cache():
with unauthenticated():
package = package_import(UNVERSIONED_HANDLE)

self.assertIn("foo", dir(package))
self.assertEqual("baz", package.foo())
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"tqdm",
"packaging",
"model_signing",
"pyyaml",
]

[project.urls]
Expand Down
3 changes: 2 additions & 1 deletion src/kagglehub/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.3.6"
__version__ = "0.3.7"

import kagglehub.logger # configures the library logger.
from kagglehub import colab_cache_resolver, http_resolver, kaggle_cache_resolver, registry
Expand All @@ -7,6 +7,7 @@
from kagglehub.datasets import KaggleDatasetAdapter, dataset_download, dataset_upload, load_dataset
from kagglehub.models import model_download, model_upload
from kagglehub.notebooks import notebook_output_download
from kagglehub.packages import get_package_asset_path, package_import
from kagglehub.utility_scripts import utility_script_install

registry.model_resolver.add_implementation(http_resolver.ModelHttpResolver())
Expand Down
2 changes: 1 addition & 1 deletion src/kagglehub/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ class ColabClient:
MOUNT_PATH = "/kagglehub/models/mount"
MODEL_MOUNT_PATH = "/kagglehub/models/mount"
DATASET_MOUNT_PATH = "/kagglehub/datasets/mount"
# TBE_RUNTIME_ADDR serves requests made from `is_supported` and `__call__`
# TBE_RUNTIME_ADDR serves requests made from `is_supported` and `_resolve`
# of ModelColabCacheResolver.
TBE_RUNTIME_ADDR_ENV_VAR_NAME = "TBE_RUNTIME_ADDR"

Expand Down
61 changes: 47 additions & 14 deletions src/kagglehub/colab_cache_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from kagglehub.exceptions import BackendError, NotFoundError
from kagglehub.handle import DatasetHandle, ModelHandle
from kagglehub.logger import EXTRA_CONSOLE_BLOCK
from kagglehub.packages import PackageScope
from kagglehub.resolver import Resolver

COLAB_CACHE_MOUNT_FOLDER_ENV_VAR_NAME = "COLAB_CACHE_MOUNT_FOLDER"
Expand All @@ -29,17 +30,20 @@ def is_supported(self, handle: ModelHandle, *_, **__) -> bool: # noqa: ANN002,
"variation": handle.variation,
}

if handle.is_versioned():
version = _get_model_version(handle)
if version:
# Colab treats version as int in the request
data["version"] = handle.version # type: ignore
data["version"] = version # type: ignore

try:
api_client.post(data, ColabClient.IS_MODEL_SUPPORTED_PATH, handle)
except NotFoundError:
return False
return True

def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
def _resolve(
self, h: ModelHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Colab notebook environment.",
Expand All @@ -53,9 +57,11 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download
"framework": h.framework,
"variation": h.variation,
}
if h.is_versioned():

version = _get_model_version(h)
if version:
# Colab treats version as int in the request
data["version"] = h.version # type: ignore
data["version"] = version # type: ignore

response = api_client.post(data, ColabClient.MODEL_MOUNT_PATH, h)

Expand Down Expand Up @@ -85,8 +91,8 @@ def __call__(self, h: ModelHandle, path: Optional[str] = None, *, force_download
f"You can access the other files of the attached model at '{cached_path}'"
)
raise ValueError(msg)
return cached_filepath
return cached_path
return cached_filepath, version
return cached_path, version


class DatasetColabCacheResolver(Resolver[DatasetHandle]):
Expand All @@ -100,17 +106,20 @@ def is_supported(self, handle: DatasetHandle, *_, **__) -> bool: # noqa: ANN002
"dataset": handle.dataset,
}

if handle.is_versioned():
version = _get_dataset_version(handle)
if version:
# Colab treats version as int in the request
data["version"] = handle.version # type: ignore
data["version"] = version # type: ignore

try:
api_client.post(data, ColabClient.IS_DATASET_SUPPORTED_PATH, handle)
except NotFoundError:
return False
return True

def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False) -> str:
def _resolve(
self, h: DatasetHandle, path: Optional[str] = None, *, force_download: Optional[bool] = False
) -> tuple[str, Optional[int]]:
if force_download:
logger.info(
"Ignoring `force_download` argument when running inside the Colab notebook environment.",
Expand All @@ -122,9 +131,11 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo
"owner": h.owner,
"dataset": h.dataset,
}
if h.is_versioned():

version = _get_dataset_version(h)
if version:
# Colab treats version as int in the request
data["version"] = h.version # type: ignore
data["version"] = version # type: ignore

response = api_client.post(data, ColabClient.DATASET_MOUNT_PATH, h)

Expand Down Expand Up @@ -154,5 +165,27 @@ def __call__(self, h: DatasetHandle, path: Optional[str] = None, *, force_downlo
f"You can access the other files of the attached dataset at '{cached_path}'"
)
raise ValueError(msg)
return cached_filepath
return cached_path
return cached_filepath, version
return cached_path, version


def _get_model_version(h: ModelHandle) -> Optional[int]:
if h.is_versioned():
return h.version

version_from_package_scope = PackageScope.get_version(h)
if version_from_package_scope is not None:
return version_from_package_scope

return None


def _get_dataset_version(h: DatasetHandle) -> Optional[int]:
if h.is_versioned():
return h.version

version_from_package_scope = PackageScope.get_version(h)
if version_from_package_scope is not None:
return version_from_package_scope

return None
3 changes: 2 additions & 1 deletion src/kagglehub/competition.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ def competition_download(handle: str, path: Optional[str] = None, *, force_downl

h = parse_competition_handle(handle)
logger.info(f"Downloading competition: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK})
return registry.competition_resolver(h, path, force_download=force_download)
path, _ = registry.competition_resolver(h, path, force_download=force_download)
return path
3 changes: 2 additions & 1 deletion src/kagglehub/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def dataset_download(handle: str, path: Optional[str] = None, *, force_download:

h = parse_dataset_handle(handle)
logger.info(f"Downloading Dataset: {h.to_url()} ...", extra={**EXTRA_CONSOLE_BLOCK})
return registry.dataset_resolver(h, path, force_download=force_download)
path, _ = registry.dataset_resolver(h, path, force_download=force_download)
return path


def dataset_upload(
Expand Down
31 changes: 26 additions & 5 deletions src/kagglehub/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@
NUM_VERSIONED_MODEL_PARTS = 5 # e.g.: <owner>/<model>/<framework>/<variation>/<version>
NUM_UNVERSIONED_MODEL_PARTS = 4 # e.g.: <owner>/<model>/<framework>/<variation>

NUM_VERSIONED_NOTEBOOK_PARTS = 4 # e.g.: <owner>/<notebook>/versions/<version>
NUM_UNVERSIONED_NOTEBOOK_PARTS = 2 # e.g.: <owner>/<notebook>
NUM_VERSIONED_NOTEBOOK_PARTS = 4 # e.g.: <owner>/<notebook>/versions/<version>


@dataclass
@dataclass(frozen=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was frozen=True suggested automatically by the linter?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is required to make it a hashable type to use as key in the VersionedDatasources dictionary.

class ResourceHandle:
@abc.abstractmethod
def to_url(self) -> str:
"""Returns URL to the resource detail page."""
pass


@dataclass
@dataclass(frozen=True)
class ModelHandle(ResourceHandle):
owner: str
model: str
Expand All @@ -35,6 +36,11 @@ class ModelHandle(ResourceHandle):
def is_versioned(self) -> bool:
return self.version is not None and self.version > 0

def with_version(self, version: int) -> "ModelHandle":
return ModelHandle(
owner=self.owner, model=self.model, framework=self.framework, variation=self.variation, version=version
)

def __str__(self) -> str:
handle_str = f"{self.owner}/{self.model}/{self.framework}/{self.variation}"
if self.is_versioned():
Expand All @@ -49,7 +55,7 @@ def to_url(self) -> str:
return f"{endpoint}/models/{self.owner}/{self.model}/{self.framework}/{self.variation}"


@dataclass
@dataclass(frozen=True)
class DatasetHandle(ResourceHandle):
owner: str
dataset: str
Expand All @@ -58,6 +64,9 @@ class DatasetHandle(ResourceHandle):
def is_versioned(self) -> bool:
return self.version is not None and self.version > 0

def with_version(self, version: int) -> "DatasetHandle":
return DatasetHandle(owner=self.owner, dataset=self.dataset, version=version)

def __str__(self) -> str:
handle_str = f"{self.owner}/{self.dataset}"
if self.is_versioned():
Expand All @@ -72,7 +81,7 @@ def to_url(self) -> str:
return base_url


@dataclass
@dataclass(frozen=True)
class CompetitionHandle(ResourceHandle):
competition: str

Expand All @@ -86,7 +95,7 @@ def to_url(self) -> str:
return base_url


@dataclass
@dataclass(frozen=True)
class NotebookHandle(ResourceHandle):
owner: str
notebook: str
Expand All @@ -95,6 +104,9 @@ class NotebookHandle(ResourceHandle):
def is_versioned(self) -> bool:
return self.version is not None and self.version > 0

def with_version(self, version: int) -> "NotebookHandle":
return NotebookHandle(owner=self.owner, notebook=self.notebook, version=version)

def __str__(self) -> str:
handle_str = f"{self.owner}/{self.notebook}"
if self.is_versioned():
Expand All @@ -113,6 +125,10 @@ class UtilityScriptHandle(NotebookHandle):
pass


class PackageHandle(NotebookHandle):
pass


def parse_dataset_handle(handle: str) -> DatasetHandle:
parts = handle.split("/")

Expand Down Expand Up @@ -217,3 +233,8 @@ def parse_notebook_handle(handle: str) -> NotebookHandle:
def parse_utility_script_handle(handle: str) -> UtilityScriptHandle:
notebook_handle = parse_notebook_handle(handle)
return UtilityScriptHandle(**asdict(notebook_handle))


def parse_package_handle(handle: str) -> PackageHandle:
notebook_handle = parse_notebook_handle(handle)
return PackageHandle(**asdict(notebook_handle))
Loading