Add 8-Bit models
felixdittrich92 committed May 13, 2024
1 parent 16c2825 commit 9f45e7d
Showing 22 changed files with 382 additions and 76 deletions.
15 changes: 9 additions & 6 deletions README.md
@@ -90,6 +90,8 @@ model = ocr_predictor(
resolve_lines=True, # whether words should be automatically grouped into lines (default: True)
resolve_blocks=True, # whether lines should be automatically grouped into blocks (default: True)
paragraph_break=0.035, # relative length of the minimum space separating paragraphs (default: 0.035)
# OnnxTR-specific parameters
load_in_8_bit=False, # set to `True` to load 8-bit quantized models instead of the full precision ones (default: False)
)
# PDF
doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
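
For context, a minimal end-to-end sketch of the new flag (it assumes the `onnxtr.io` / `onnxtr.models` imports used elsewhere in the README; the architecture names and the `render()` call are illustrative, not prescribed by this commit):

```python
from onnxtr.io import DocumentFile
from onnxtr.models import ocr_predictor

# Request the 8-bit quantized ONNX weights instead of the full precision ones
model = ocr_predictor(
    det_arch="db_resnet50",      # FAST detection archs have no 8-bit variant yet
    reco_arch="crnn_vgg16_bn",
    load_in_8_bit=True,
)

doc = DocumentFile.from_pdf("path/to/your/doc.pdf")
result = model(doc)
print(result.render())
```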
@@ -170,9 +172,9 @@ predictor.list_archs()
'linknet_resnet18',
'linknet_resnet34',
'linknet_resnet50',
'fast_tiny',
'fast_small',
'fast_base'
'fast_tiny', # No 8-bit support
'fast_small', # No 8-bit support
'fast_base' # No 8-bit support
],
'recognition archs':
[
@@ -202,21 +204,22 @@ NOTE:
### Benchmarks

The CPU benchmarks were measured on an `i7-14700K Intel CPU`.

The GPU benchmarks were measured on an `RTX 4080 Nvidia GPU`.

Benchmarking performed on the FUNSD dataset and the CORD dataset.
Benchmarking was performed on the FUNSD and CORD datasets.

docTR / OnnxTR models used for the benchmarks are `fast_base` for detection and `crnn_vgg16_bn` for recognition.

For comparison, the smallest combination in OnnxTR (docTR), `db_mobilenet_v3_large` with `crnn_mobilenet_v3_small`, takes `~0.17s / Page` on the FUNSD dataset and `~0.12s / Page` on the CORD dataset.
For comparison, the smallest combination in OnnxTR (docTR), `db_mobilenet_v3_large` with `crnn_mobilenet_v3_small`, takes `~0.17s / Page` on the FUNSD dataset and `~0.12s / Page` on the CORD dataset in **full precision**.

- CPU benchmarks:

|Library |FUNSD (199 pages) |CORD (900 pages) |
|--------------------------------|-------------------------------|-------------------------------|
|docTR (CPU) - v0.8.1 | ~1.29s / Page | ~0.60s / Page |
|**OnnxTR (CPU)** - v0.1.2 | ~0.57s / Page | **~0.25s / Page** |
|OnnxTR (CPU) 8-bit - v0.1.2 | in progress | in progress |
|**OnnxTR (CPU) 8-bit** - v0.1.2 | **~0.38s / Page** | **~0.14s / Page** |
|EasyOCR (CPU) - v1.7.1 | ~1.96s / Page | ~1.75s / Page |
|**PyTesseract (CPU)** - v0.3.10 | **~0.50s / Page** | ~0.52s / Page |
|Surya (line) (CPU) - v0.4.4 | ~48.76s / Page | ~35.49s / Page |
19 changes: 15 additions & 4 deletions onnxtr/models/classification/models/mobilenet.py
@@ -24,13 +24,15 @@
"input_shape": (3, 256, 256),
"classes": [0, -90, 180, 90],
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/mobilenet_v3_small_crop_orientation-5620cf7e.onnx",
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/mobilenet_v3_small_crop_orientation_static_8_bit-4cfaa621.onnx",
},
"mobilenet_v3_small_page_orientation": {
"mean": (0.694, 0.695, 0.693),
"std": (0.299, 0.296, 0.301),
"input_shape": (3, 512, 512),
"classes": [0, -90, 180, 90],
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/mobilenet_v3_small_page_orientation-d3f76d79.onnx",
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/mobilenet_v3_small_page_orientation_static_8_bit-3e5ef3dc.onnx",
},
}

@@ -64,14 +66,19 @@ def __call__(
def _mobilenet_v3(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
**kwargs: Any,
) -> MobileNetV3:
# Patch the url
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
_cfg = deepcopy(default_cfgs[arch])
return MobileNetV3(model_path, cfg=_cfg, **kwargs)


def mobilenet_v3_small_crop_orientation(
model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"], **kwargs: Any
model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"],
load_in_8_bit: bool = False,
**kwargs: Any,
) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
`"Searching for MobileNetV3",
@@ -86,17 +93,20 @@ def mobilenet_v3_small_crop_orientation(
Args:
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
**kwargs: keyword arguments of the MobileNetV3 architecture
Returns:
-------
MobileNetV3
"""
return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, **kwargs)
return _mobilenet_v3("mobilenet_v3_small_crop_orientation", model_path, load_in_8_bit, **kwargs)


def mobilenet_v3_small_page_orientation(
model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"], **kwargs: Any
model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"],
load_in_8_bit: bool = False,
**kwargs: Any,
) -> MobileNetV3:
"""MobileNetV3-Small architecture as described in
`"Searching for MobileNetV3",
@@ -111,10 +121,11 @@ def mobilenet_v3_small_page_orientation(
Args:
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
**kwargs: keyword arguments of the MobileNetV3 architecture
Returns:
-------
MobileNetV3
"""
return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, **kwargs)
return _mobilenet_v3("mobilenet_v3_small_page_orientation", model_path, load_in_8_bit, **kwargs)
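
A hedged sketch of calling the patched constructors directly (it assumes they are re-exported at `onnxtr.models.classification`, mirroring docTR; the URL swap only triggers while `model_path` is still the default download URL):

```python
from onnxtr.models.classification import mobilenet_v3_small_crop_orientation

# Default model_path is the release URL, so load_in_8_bit=True swaps in the
# mobilenet_v3_small_crop_orientation_static_8_bit-*.onnx asset before download.
model_8bit = mobilenet_v3_small_crop_orientation(load_in_8_bit=True)

# An explicit local path contains no "http", so it is left untouched.
model_local = mobilenet_v3_small_crop_orientation("my_exported_model.onnx", load_in_8_bit=True)

print(model_8bit.cfg["input_shape"])  # (3, 256, 256) per default_cfgs
```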
1 change: 1 addition & 0 deletions onnxtr/models/classification/predictor/base.py
@@ -22,6 +22,7 @@ class OrientationPredictor(NestedObject):
----
pre_processor: transform inputs for easier batched model inference
model: core classification architecture (backbone + classification head)
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
"""

_children_names: List[str] = ["pre_processor", "model"]
17 changes: 10 additions & 7 deletions onnxtr/models/classification/zoo.py
@@ -14,24 +14,25 @@
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]


def _orientation_predictor(arch: str, **kwargs: Any) -> OrientationPredictor:
def _orientation_predictor(arch: str, load_in_8_bit: bool = False, **kwargs: Any) -> OrientationPredictor:
if arch not in ORIENTATION_ARCHS:
raise ValueError(f"unknown architecture '{arch}'")

# Load the classifier directly from the backbone
_model = classification.__dict__[arch]()
_model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit)
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
input_shape = _model.cfg["input_shape"][1:]
predictor = OrientationPredictor(
PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs),
_model,
)
return predictor


def crop_orientation_predictor(
arch: Any = "mobilenet_v3_small_crop_orientation", **kwargs: Any
arch: Any = "mobilenet_v3_small_crop_orientation", load_in_8_bit: bool = False, **kwargs: Any
) -> OrientationPredictor:
"""Crop orientation classification architecture.
@@ -44,17 +45,18 @@ def crop_orientation_predictor(
Args:
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
load_in_8_bit: load the 8-bit quantized version of the model
**kwargs: keyword arguments to be passed to the OrientationPredictor
Returns:
-------
OrientationPredictor
"""
return _orientation_predictor(arch, **kwargs)
return _orientation_predictor(arch, load_in_8_bit, **kwargs)


def page_orientation_predictor(
arch: Any = "mobilenet_v3_small_page_orientation", **kwargs: Any
arch: Any = "mobilenet_v3_small_page_orientation", load_in_8_bit: bool = False, **kwargs: Any
) -> OrientationPredictor:
"""Page orientation classification architecture.
@@ -67,10 +69,11 @@ def page_orientation_predictor(
Args:
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
**kwargs: keyword arguments to be passed to the OrientationPredictor
Returns:
-------
OrientationPredictor
"""
return _orientation_predictor(arch, **kwargs)
return _orientation_predictor(arch, load_in_8_bit, **kwargs)
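
A short usage sketch of the updated zoo entry points (it assumes both functions are exported under `onnxtr.models`, as their docTR counterparts are; the dummy input and printed output are only illustrative):

```python
import numpy as np

from onnxtr.models import crop_orientation_predictor, page_orientation_predictor

# load_in_8_bit is forwarded to the underlying MobileNetV3 constructor
crop_clf = crop_orientation_predictor(load_in_8_bit=True)
page_clf = page_orientation_predictor(load_in_8_bit=True)

dummy_crop = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
# Output structure follows OrientationPredictor.__call__ (class indices, angles, confidences)
print(crop_clf([dummy_crop]))
```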
27 changes: 21 additions & 6 deletions onnxtr/models/detection/models/differentiable_binarization.py
@@ -20,18 +20,21 @@
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_resnet50-69ba0015.onnx",
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_resnet50_static_8_bit-09a6104f.onnx",
},
"db_resnet34": {
"input_shape": (3, 1024, 1024),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_resnet34-b4873198.onnx",
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_resnet34_static_8_bit-027e2c7f.onnx",
},
"db_mobilenet_v3_large": {
"input_shape": (3, 1024, 1024),
"mean": (0.798, 0.785, 0.772),
"std": (0.264, 0.2749, 0.287),
"url": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.0.1/db_mobilenet_v3_large-1866973f.onnx",
"url_8_bit": "https://github.com/felixdittrich92/OnnxTR/releases/download/v0.1.2/db_mobilenet_v3_large_static_8_bit-51659bb9.onnx",
},
}

@@ -87,13 +90,18 @@ def __call__(
def _dbnet(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
**kwargs: Any,
) -> DBNet:
# Patch the url
model_path = default_cfgs[arch]["url_8_bit"] if load_in_8_bit and "http" in model_path else model_path
# Build the model
return DBNet(model_path, cfg=default_cfgs[arch], **kwargs)


def db_resnet34(model_path: str = default_cfgs["db_resnet34"]["url"], **kwargs: Any) -> DBNet:
def db_resnet34(
model_path: str = default_cfgs["db_resnet34"]["url"], load_in_8_bit: bool = False, **kwargs: Any
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-34 backbone.
@@ -106,16 +114,19 @@ def db_resnet34(model_path: str = default_cfgs["db_resnet34"]["url"], **kwargs:
Args:
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
return _dbnet("db_resnet34", model_path, **kwargs)
return _dbnet("db_resnet34", model_path, load_in_8_bit, **kwargs)


def db_resnet50(model_path: str = default_cfgs["db_resnet50"]["url"], **kwargs: Any) -> DBNet:
def db_resnet50(
model_path: str = default_cfgs["db_resnet50"]["url"], load_in_8_bit: bool = False, **kwargs: Any
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a ResNet-50 backbone.
@@ -128,16 +139,19 @@ def db_resnet50(model_path: str = default_cfgs["db_resnet50"]["url"], **kwargs:
Args:
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
return _dbnet("db_resnet50", model_path, **kwargs)
return _dbnet("db_resnet50", model_path, load_in_8_bit, **kwargs)


def db_mobilenet_v3_large(model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], **kwargs: Any) -> DBNet:
def db_mobilenet_v3_large(
model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], load_in_8_bit: bool = False, **kwargs: Any
) -> DBNet:
"""DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
<https://arxiv.org/pdf/1911.08947.pdf>`_, using a MobileNet V3 Large backbone.
@@ -150,10 +164,11 @@ def db_mobilenet_v3_large(model_path: str = default_cfgs["db_mobilenet_v3_large"
Args:
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
**kwargs: keyword arguments of the DBNet architecture
Returns:
-------
text detection architecture
"""
return _dbnet("db_mobilenet_v3_large", model_path, **kwargs)
return _dbnet("db_mobilenet_v3_large", model_path, load_in_8_bit, **kwargs)
19 changes: 13 additions & 6 deletions onnxtr/models/detection/models/fast.py
@@ -3,6 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import logging
from typing import Any, Dict, Optional

import numpy as np
@@ -88,13 +89,16 @@ def _fast(
def _fast(
arch: str,
model_path: str,
load_in_8_bit: bool = False,
**kwargs: Any,
) -> FAST:
if load_in_8_bit:
logging.warning("FAST models do not support 8-bit quantization yet. Loading full precision model...")
# Build the model
return FAST(model_path, cfg=default_cfgs[arch], **kwargs)


def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], **kwargs: Any) -> FAST:
def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST:
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
<https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.
@@ -107,16 +111,17 @@ def fast_tiny(model_path: str = default_cfgs["fast_tiny"]["url"], **kwargs: Any)
Args:
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
**kwargs: keyword arguments of the FAST architecture
Returns:
-------
text detection architecture
"""
return _fast("fast_tiny", model_path, **kwargs)
return _fast("fast_tiny", model_path, load_in_8_bit, **kwargs)


def fast_small(model_path: str = default_cfgs["fast_small"]["url"], **kwargs: Any) -> FAST:
def fast_small(model_path: str = default_cfgs["fast_small"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST:
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
<https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.
@@ -129,16 +134,17 @@ def fast_small(model_path: str = default_cfgs["fast_small"]["url"], **kwargs: An
Args:
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
**kwargs: keyword arguments of the FAST architecture
Returns:
-------
text detection architecture
"""
return _fast("fast_small", model_path, **kwargs)
return _fast("fast_small", model_path, load_in_8_bit, **kwargs)


def fast_base(model_path: str = default_cfgs["fast_base"]["url"], **kwargs: Any) -> FAST:
def fast_base(model_path: str = default_cfgs["fast_base"]["url"], load_in_8_bit: bool = False, **kwargs: Any) -> FAST:
"""FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
<https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.
@@ -151,10 +157,11 @@ def fast_base(model_path: str = default_cfgs["fast_base"]["url"], **kwargs: Any)
Args:
----
model_path: path to onnx model file, defaults to url in default_cfgs
load_in_8_bit: whether to load the 8-bit quantized model, defaults to False
**kwargs: keyword arguments of the FAST architecture
Returns:
-------
text detection architecture
"""
return _fast("fast_base", model_path, **kwargs)
return _fast("fast_base", model_path, load_in_8_bit, **kwargs)