
[HuggingFace] Add HF hub integration #27

Merged · 8 commits · Aug 16, 2024
1 change: 1 addition & 0 deletions .conda/meta.yaml
@@ -29,6 +29,7 @@ requirements:
- shapely >=1.6.0, <3.0.0
- langdetect >=1.0.9, <2.0.0
- rapidfuzz >=3.0.0, <4.0.0
- huggingface_hub >=0.20.0, <1.0.0
- defusedxml >=0.7.0
- anyascii >=0.3.2
- tqdm >=4.30.0
33 changes: 33 additions & 0 deletions README.md
@@ -181,6 +181,39 @@
model = ocr_predictor(det_arch=det_model, reco_arch=reco_model)
```

<details>
<summary>Loading models from HuggingFace Hub</summary>

You can also load models from the HuggingFace Hub:

```python
from onnxtr.io import DocumentFile
from onnxtr.models import ocr_predictor, from_hub

img = DocumentFile.from_images(['<image_path>'])
# Load your model from the hub
model = from_hub('onnxtr/my-model')

# Pass it to the predictor
# If your model is a recognition model:
predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
                          reco_arch=model)

# If your model is a detection model:
predictor = ocr_predictor(det_arch=model,
                          reco_arch='crnn_mobilenet_v3_small')

# Get your predictions
res = predictor(img)
```

You can find the available models on the HuggingFace Hub [here](https://huggingface.co/models?search=onnxtr), grouped in the [OnnxTR collection](https://huggingface.co/collections/Felix92/onnxtr-66bf213a9f88f7346c90e842). One example is the [multilingual PARSeq](https://huggingface.co/Felix92/onnxtr-parseq-multilingual-v1) recognition model.
</details>
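
You can also push your own converted models to the Hub. A minimal sketch, adapted from the `push_to_hf_hub` docstring added in this PR (`'my-model'` is a placeholder repository name):

```python
from onnxtr.models import login_to_hub, push_to_hf_hub
from onnxtr.models.recognition import crnn_mobilenet_v3_small

# Authenticate against the Hub (requires a token with write permission)
login_to_hub()

# Instantiate the model you want to share and push it
model = crnn_mobilenet_v3_small()
push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
```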

## Models architectures

Credit where it's due: this repository provides ONNX models for the following architectures, converted from the docTR models:
1 change: 1 addition & 0 deletions onnxtr/models/__init__.py
@@ -3,3 +3,4 @@
from .detection import *
from .recognition import *
from .zoo import *
from .factory import *

1 change: 1 addition & 0 deletions onnxtr/models/classification/models/mobilenet.py
@@ -56,6 +56,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)

self.cfg = cfg

def __call__(
2 changes: 2 additions & 0 deletions onnxtr/models/detection/models/differentiable_binarization.py
@@ -64,8 +64,10 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)

self.cfg = cfg
self.assume_straight_pages = assume_straight_pages

self.postprocessor = GeneralDetectionPostProcessor(
assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
)
1 change: 1 addition & 0 deletions onnxtr/models/detection/models/fast.py
@@ -62,6 +62,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)

self.cfg = cfg
self.assume_straight_pages = assume_straight_pages

1 change: 1 addition & 0 deletions onnxtr/models/detection/models/linknet.py
@@ -64,6 +64,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)

self.cfg = cfg
self.assume_straight_pages = assume_straight_pages

2 changes: 2 additions & 0 deletions onnxtr/models/engine.py
@@ -90,6 +90,8 @@ class Engine:
def __init__(self, url: str, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any) -> None:
engine_cfg = engine_cfg if isinstance(engine_cfg, EngineConfig) else EngineConfig()
archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url
# Store model path for each model
self.model_path = archive_path
self.session_options = engine_cfg.session_options
self.providers = engine_cfg.providers
self.runtime = InferenceSession(archive_path, providers=self.providers, sess_options=self.session_options)
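(The `model_path` stored here is what `_save_model_and_config_for_hf_hub` in the new `onnxtr/models/factory/hub.py` copies to `model.onnx` when a model is exported to the Hub.)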
1 change: 1 addition & 0 deletions onnxtr/models/factory/__init__.py
@@ -0,0 +1 @@
from .hub import *

224 changes: 224 additions & 0 deletions onnxtr/models/factory/hub.py
@@ -0,0 +1,224 @@
# Copyright (C) 2021-2024, Mindee | Felix Dittrich.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

# Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/hub.py

import json
import logging
import os
import shutil
import subprocess

import textwrap
from pathlib import Path
from typing import Any, Optional

from huggingface_hub import (
HfApi,
Repository,
get_token,
get_token_permission,
hf_hub_download,
login,
)

from onnxtr import models
from onnxtr.models.engine import EngineConfig

__all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]


AVAILABLE_ARCHS = {
"classification": models.classification.zoo.ORIENTATION_ARCHS,
"detection": models.detection.zoo.ARCHS,
"recognition": models.recognition.zoo.ARCHS,
}


def login_to_hub() -> None: # pragma: no cover
"""Login to huggingface hub"""
access_token = get_token()
if access_token is not None and get_token_permission(access_token):
logging.info("Huggingface Hub token found and valid")
login(token=access_token, write_permission=True)
else:
login()
# check if git lfs is installed
try:
subprocess.call(["git", "lfs", "version"])

except FileNotFoundError:
raise OSError(
"Looks like you do not have git-lfs installed, please install. \
You can install from https://git-lfs.github.com/. \
Then run `git lfs install` (you only have to do this once)."
)


def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task: str) -> None:
"""Save model and config to disk for pushing to huggingface hub

Args:
----
model: Onnx model to be saved
save_dir: directory to save model and config
arch: architecture name
task: task name
"""
save_directory = Path(save_dir)
shutil.copy2(model.model_path, save_directory / "model.onnx")

config_path = save_directory / "config.json"

# add model configuration
model_config = model.cfg
model_config["arch"] = arch
model_config["task"] = task

with config_path.open("w") as f:
json.dump(model_config, f, indent=2, ensure_ascii=False)


def push_to_hf_hub(
model: Any, model_name: str, task: str, override: bool = False, **kwargs
) -> None: # pragma: no cover
"""Save model and its configuration on HF hub

>>> from onnxtr.models import login_to_hub, push_to_hf_hub
>>> from onnxtr.models.recognition import crnn_mobilenet_v3_small
>>> login_to_hub()
>>> model = crnn_mobilenet_v3_small()
>>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')

Args:
----
model: Onnx model to be saved
model_name: name of the model which is also the repository name
task: task name
override: whether to override the existing model / repo on HF hub
**kwargs: keyword arguments for push_to_hf_hub
"""
run_config = kwargs.get("run_config", None)
arch = kwargs.get("arch", None)

if run_config is None and arch is None:
raise ValueError("run_config or arch must be specified")
if task not in ["classification", "detection", "recognition"]:
raise ValueError("task must be one of classification, detection, recognition")

# default readme
readme = textwrap.dedent(
f"""
---
language:
- en
- fr
license: apache-2.0
---

<p align="center">
<img src="https://github.com/felixdittrich92/OnnxTR/raw/main/docs/images/logo.jpg" width="40%">
</p>

**Optical Character Recognition made seamless & accessible to anyone, powered by Onnxruntime**

## Task: {task}

https://github.com/felixdittrich92/OnnxTR

### Example usage:

```python
>>> from onnxtr.io import DocumentFile
>>> from onnxtr.models import ocr_predictor, from_hub

>>> img = DocumentFile.from_images(['<image_path>'])
>>> # Load your model from the hub
>>> model = from_hub('onnxtr/my-model')

>>> # Pass it to the predictor
>>> # If your model is a recognition model:
>>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
>>> reco_arch=model)

>>> # If your model is a detection model:
>>> predictor = ocr_predictor(det_arch=model,
>>> reco_arch='crnn_mobilenet_v3_small')

>>> # Get your predictions
>>> res = predictor(img)
```
"""
)

# add run configuration to readme if available
if run_config is not None:
arch = run_config.arch
readme += textwrap.dedent(
f"""### Run Configuration
\n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}"""
)

if arch not in AVAILABLE_ARCHS[task]:
raise ValueError(
f"Architecture: {arch} for task: {task} not found.\
\nAvailable architectures: {AVAILABLE_ARCHS}"
)

commit_message = f"Add {model_name} model"

local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=override)
repo = Repository(local_dir=local_cache_dir, clone_from=repo_url)

with repo.commit(commit_message):
_save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
readme_path = Path(repo.local_dir) / "README.md"
readme_path.write_text(readme)

repo.git_push()


def from_hub(repo_id: str, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any):
"""Instantiate & load a pretrained model from HF hub.

>>> from onnxtr.models import from_hub
>>> model = from_hub("onnxtr/my-model")

Args:
----
repo_id: HuggingFace model hub repo
engine_cfg: configuration for the inference engine (optional)
kwargs: kwargs of `hf_hub_download`

Returns:
-------
Model loaded with the checkpoint
"""
# Get the config
with open(hf_hub_download(repo_id, filename="config.json", **kwargs), "rb") as f:
cfg = json.load(f)
model_path = hf_hub_download(repo_id, filename="model.onnx", **kwargs)

arch = cfg["arch"]
task = cfg["task"]
cfg.pop("arch")
cfg.pop("task")

if task == "classification":
model = models.classification.__dict__[arch](model_path, classes=cfg["classes"], engine_cfg=engine_cfg)
elif task == "detection":
model = models.detection.__dict__[arch](model_path, engine_cfg=engine_cfg)
elif task == "recognition":
model = models.recognition.__dict__[arch](
model_path, input_shape=cfg["input_shape"], vocab=cfg["vocab"], engine_cfg=engine_cfg
)

# convert all values which are lists to tuples
for key, value in cfg.items():
if isinstance(value, list):
cfg[key] = tuple(value)
# update model cfg
model.cfg = cfg

return model
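
A quick usage sketch of the new loader, using the multilingual PARSeq repo linked in the README (passing an `EngineConfig` via `engine_cfg` is optional):

```python
from onnxtr.models import from_hub

# Downloads config.json and model.onnx from the Hub, then rebuilds the model
model = from_hub("Felix92/onnxtr-parseq-multilingual-v1")

# arch/task are popped from the config and list values are cast to tuples
print(model.cfg)
```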
2 changes: 2 additions & 0 deletions onnxtr/models/recognition/models/crnn.py
@@ -129,8 +129,10 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)

self.vocab = vocab
self.cfg = cfg

self.postprocessor = CRNNPostProcessor(self.vocab)

def __call__(
1 change: 1 addition & 0 deletions onnxtr/models/recognition/models/master.py
@@ -53,6 +53,7 @@ def __init__(

self.vocab = vocab
self.cfg = cfg

self.postprocessor = MASTERPostProcessor(vocab=self.vocab)

def __call__(
2 changes: 2 additions & 0 deletions onnxtr/models/recognition/models/parseq.py
@@ -49,8 +49,10 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)

self.vocab = vocab
self.cfg = cfg

self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)

def __call__(
2 changes: 2 additions & 0 deletions onnxtr/models/recognition/models/sar.py
@@ -49,8 +49,10 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)

self.vocab = vocab
self.cfg = cfg

self.postprocessor = SARPostProcessor(self.vocab)

def __call__(
1 change: 1 addition & 0 deletions onnxtr/models/recognition/models/vitstr.py
@@ -57,6 +57,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(url=model_path, engine_cfg=engine_cfg, **kwargs)

self.vocab = vocab
self.cfg = cfg

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -39,6 +39,7 @@ dependencies = [
"shapely>=1.6.0,<3.0.0",
"rapidfuzz>=3.0.0,<4.0.0",
"langdetect>=1.0.9,<2.0.0",
"huggingface-hub>=0.23.0,<1.0.0",
"Pillow>=9.2.0",
"defusedxml>=0.7.0",
"anyascii>=0.3.2",
@@ -126,6 +127,7 @@ module = [
"weasyprint.*",
"pypdfium2.*",
"langdetect.*",
"huggingface_hub.*",
"rapidfuzz.*",
"anyascii.*",
"tqdm.*",