From 184f44782cb5c970048f8465f6bda7a2f846cb05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 18 Apr 2024 09:25:52 +0200 Subject: [PATCH] Update for Curated Transformers 2.0 API changes (#37) --- requirements.txt | 4 ++-- setup.cfg | 4 ++-- spacy_curated_transformers/models/architectures.py | 4 ++-- spacy_curated_transformers/models/hf_loader.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/requirements.txt b/requirements.txt index 0e3c0a2..a6b2919 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -curated-transformers>=2.0.0.dev3,<3.0.0 -curated-tokenizers>=2.0.0.dev0,<3.0.0 +curated-transformers>=2.0.0,<3.0.0 +curated-tokenizers>=2.0.0,<3.0.0 fsspec>=2023.5.0 spacy>=4.0.0.dev2,<5.0.0 thinc>=9.0.0.dev4,<9.1.0 diff --git a/setup.cfg b/setup.cfg index 047f92e..291e53e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,8 +14,8 @@ zip_safe = true include_package_data = true python_requires = >=3.9 install_requires = - curated-transformers>=2.0.0.dev3,<3.0.0 - curated-tokenizers>=2.0.0.dev0,<3.0.0 + curated-transformers>=2.0.0,<3.0.0 + curated-tokenizers>=2.0.0,<3.0.0 fsspec>=2023.5.0 spacy>=4.0.0.dev2,<5.0.0 thinc>=9.0.0.dev4,<9.1.0 diff --git a/spacy_curated_transformers/models/architectures.py b/spacy_curated_transformers/models/architectures.py index 2345991..b4e9b14 100644 --- a/spacy_curated_transformers/models/architectures.py +++ b/spacy_curated_transformers/models/architectures.py @@ -12,7 +12,7 @@ BERTEncoder, CamemBERTEncoder, EncoderModule, - FromHFHub, + FromHF, ModelOutput, RoBERTaConfig, RoBERTaEncoder, @@ -1313,7 +1313,7 @@ def build_pytorch_checkpoint_loader_v2(*, path: Path) -> Callable[ def load(model, X=None, Y=None): device = get_torch_default_device() encoder = model.shims[0]._model.curated_encoder - assert isinstance(encoder, FromHFHub) + assert isinstance(encoder, FromHF) fs = LocalFileSystem() encoder.from_fsspec_(fs=fs, model_path=path, device=device) return model diff --git a/spacy_curated_transformers/models/hf_loader.py b/spacy_curated_transformers/models/hf_loader.py index 0748c15..ee4c0fd 100644 --- a/spacy_curated_transformers/models/hf_loader.py +++ b/spacy_curated_transformers/models/hf_loader.py @@ -1,6 +1,6 @@ from typing import Callable, List, Optional -from curated_transformers.models import FromHFHub +from curated_transformers.models import FromHF from spacy.tokens import Doc from .types import TorchTransformerModelT @@ -25,7 +25,7 @@ def build_hf_transformer_encoder_loader_v1( def load(model, X=None, Y=None): encoder = model.shims[0]._model.curated_encoder - assert isinstance(encoder, FromHFHub) + assert isinstance(encoder, FromHF) device = model.shims[0].device encoder.from_hf_hub_(name=name, revision=revision, device=device) return model