From 890ae43bf83825cc49c92eed59ae445c8c337d7f Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Fri, 28 Jun 2024 08:14:10 +0200 Subject: [PATCH] [Fix] fix default cuda config (#21) --- README.md | 10 +++++----- onnxtr/contrib/base.py | 5 +---- .../models/classification/models/mobilenet.py | 8 ++++---- onnxtr/models/classification/zoo.py | 10 +++++----- .../models/differentiable_binarization.py | 10 +++++----- onnxtr/models/detection/models/fast.py | 10 +++++----- onnxtr/models/detection/models/linknet.py | 10 +++++----- onnxtr/models/detection/zoo.py | 8 ++++---- onnxtr/models/engine.py | 6 +++--- onnxtr/models/predictor/base.py | 2 +- onnxtr/models/predictor/predictor.py | 4 ++-- onnxtr/models/recognition/models/crnn.py | 10 +++++----- onnxtr/models/recognition/models/master.py | 6 +++--- onnxtr/models/recognition/models/parseq.py | 6 +++--- onnxtr/models/recognition/models/sar.py | 6 +++--- onnxtr/models/recognition/models/vitstr.py | 8 ++++---- onnxtr/models/recognition/zoo.py | 6 +++--- onnxtr/models/zoo.py | 18 +++++++++--------- 18 files changed, 70 insertions(+), 73 deletions(-) diff --git a/README.md b/README.md index 895281f..53c5d9e 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![codecov](https://codecov.io/gh/felixdittrich92/OnnxTR/graph/badge.svg?token=WVFRCQBOLI)](https://codecov.io/gh/felixdittrich92/OnnxTR) [![Codacy Badge](https://app.codacy.com/project/badge/Grade/4fff4d764bb14fb8b4f4afeb9587231b)](https://app.codacy.com/gh/felixdittrich92/OnnxTR/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade) [![CodeFactor](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr/badge)](https://www.codefactor.io/repository/github/felixdittrich92/onnxtr) -[![Pypi](https://img.shields.io/badge/pypi-v0.3.0-blue.svg)](https://pypi.org/project/OnnxTR/) +[![Pypi](https://img.shields.io/badge/pypi-v0.3.1-blue.svg)](https://pypi.org/project/OnnxTR/) > :warning: Please note that this is a wrapper around the [doctr](https://github.com/mindee/doctr) library to provide a Onnx pipeline for docTR. For feature requests, which are not directly related to the Onnx pipeline, please refer to the base project. @@ -77,8 +77,8 @@ from onnxtr.models import ocr_predictor, EngineConfig model = ocr_predictor( det_arch='fast_base', # detection architecture reco_arch='vitstr_base', # recognition architecture - det_bs=4, # detection batch size - reco_bs=1024, # recognition batch size + det_bs=2, # detection batch size + reco_bs=512, # recognition batch size assume_straight_pages=True, # set to `False` if the pages are not straight (rotation, perspective, etc.) (default: True) straighten_pages=False, # set to `True` if the pages should be straightened before final processing (default: False) # Preprocessing related parameters @@ -151,7 +151,7 @@ general_options.enable_cpu_mem_arena = False # NOTE: The following would force to run only on the GPU if no GPU is available it will raise an error # List of strings e.g. ["CUDAExecutionProvider", "CPUExecutionProvider"] or a list of tuples with the provider and its options e.g. # [("CUDAExecutionProvider", {"device_id": 0}), ("CPUExecutionProvider", {"arena_extend_strategy": "kSameAsRequested"})] -providers = [("CUDAExecutionProvider", {"device_id": 0})] # For available providers see: https://onnxruntime.ai/docs/execution-providers/ +providers = [("CUDAExecutionProvider", {"device_id": 0, "cudnn_conv_algo_search": "DEFAULT"})] # For available providers see: https://onnxruntime.ai/docs/execution-providers/ engine_config = EngineConfig( session_options=general_options, @@ -183,7 +183,7 @@ model = ocr_predictor(det_arch=det_model, reco_arch=reco_model) ## Models architectures -Credits where it's due: this repository is implementing, among others, architectures from published research papers. +Credits where it's due: this repository provides ONNX models for the following architectures, converted from the docTR models: ### Text Detection diff --git a/onnxtr/contrib/base.py b/onnxtr/contrib/base.py index 08eb449..8990bad 100644 --- a/onnxtr/contrib/base.py +++ b/onnxtr/contrib/base.py @@ -6,8 +6,8 @@ from typing import Any, List, Optional import numpy as np +import onnxruntime as ort -from onnxtr.file_utils import requires_package from onnxtr.utils.data import download_from_url @@ -44,9 +44,6 @@ def _init_model(self, url: Optional[str] = None, model_path: Optional[str] = Non ------- Any: the ONNX loaded model """ - requires_package("onnxruntime", "`.contrib` module requires `onnxruntime` to be installed.") - import onnxruntime as ort - if not url and not model_path: raise ValueError("You must provide either a url or a model_path") onnx_model_path = model_path if model_path else str(download_from_url(url, cache_subdir="models", **kwargs)) # type: ignore[arg-type] diff --git a/onnxtr/models/classification/models/mobilenet.py b/onnxtr/models/classification/models/mobilenet.py index 583f146..fbd94d2 100644 --- a/onnxtr/models/classification/models/mobilenet.py +++ b/onnxtr/models/classification/models/mobilenet.py @@ -51,7 +51,7 @@ class MobileNetV3(Engine): def __init__( self, model_path: str, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: @@ -69,7 +69,7 @@ def _mobilenet_v3( arch: str, model_path: str, load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> MobileNetV3: # Patch the url @@ -81,7 +81,7 @@ def _mobilenet_v3( def mobilenet_v3_small_crop_orientation( model_path: str = default_cfgs["mobilenet_v3_small_crop_orientation"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> MobileNetV3: """MobileNetV3-Small architecture as described in @@ -111,7 +111,7 @@ def mobilenet_v3_small_crop_orientation( def mobilenet_v3_small_page_orientation( model_path: str = default_cfgs["mobilenet_v3_small_page_orientation"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> MobileNetV3: """MobileNetV3-Small architecture as described in diff --git a/onnxtr/models/classification/zoo.py b/onnxtr/models/classification/zoo.py index ad0b50b..cf8b8be 100644 --- a/onnxtr/models/classification/zoo.py +++ b/onnxtr/models/classification/zoo.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, List +from typing import Any, List, Optional from onnxtr.models.engine import EngineConfig @@ -17,7 +17,7 @@ def _orientation_predictor( - arch: str, load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any + arch: str, load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any ) -> OrientationPredictor: if arch not in ORIENTATION_ARCHS: raise ValueError(f"unknown architecture '{arch}'") @@ -26,7 +26,7 @@ def _orientation_predictor( _model = classification.__dict__[arch](load_in_8_bit=load_in_8_bit, engine_cfg=engine_cfg) kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) - kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4) + kwargs["batch_size"] = kwargs.get("batch_size", 512 if "crop" in arch else 2) input_shape = _model.cfg["input_shape"][1:] predictor = OrientationPredictor( PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), @@ -38,7 +38,7 @@ def _orientation_predictor( def crop_orientation_predictor( arch: Any = "mobilenet_v3_small_crop_orientation", load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> OrientationPredictor: """Crop orientation classification architecture. @@ -66,7 +66,7 @@ def crop_orientation_predictor( def page_orientation_predictor( arch: Any = "mobilenet_v3_small_page_orientation", load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> OrientationPredictor: """Page orientation classification architecture. diff --git a/onnxtr/models/detection/models/differentiable_binarization.py b/onnxtr/models/detection/models/differentiable_binarization.py index 1747bcb..af774bc 100644 --- a/onnxtr/models/detection/models/differentiable_binarization.py +++ b/onnxtr/models/detection/models/differentiable_binarization.py @@ -56,7 +56,7 @@ class DBNet(Engine): def __init__( self, model_path: str, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, bin_thresh: float = 0.3, box_thresh: float = 0.1, assume_straight_pages: bool = True, @@ -93,7 +93,7 @@ def _dbnet( arch: str, model_path: str, load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> DBNet: # Patch the url @@ -105,7 +105,7 @@ def _dbnet( def db_resnet34( model_path: str = default_cfgs["db_resnet34"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> DBNet: """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" @@ -134,7 +134,7 @@ def db_resnet34( def db_resnet50( model_path: str = default_cfgs["db_resnet50"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> DBNet: """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" @@ -163,7 +163,7 @@ def db_resnet50( def db_mobilenet_v3_large( model_path: str = default_cfgs["db_mobilenet_v3_large"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> DBNet: """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" diff --git a/onnxtr/models/detection/models/fast.py b/onnxtr/models/detection/models/fast.py index d8ee844..527ecb6 100644 --- a/onnxtr/models/detection/models/fast.py +++ b/onnxtr/models/detection/models/fast.py @@ -54,7 +54,7 @@ class FAST(Engine): def __init__( self, model_path: str, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, bin_thresh: float = 0.1, box_thresh: float = 0.1, assume_straight_pages: bool = True, @@ -92,7 +92,7 @@ def _fast( arch: str, model_path: str, load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> FAST: if load_in_8_bit: @@ -104,7 +104,7 @@ def _fast( def fast_tiny( model_path: str = default_cfgs["fast_tiny"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> FAST: """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" @@ -133,7 +133,7 @@ def fast_tiny( def fast_small( model_path: str = default_cfgs["fast_small"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> FAST: """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" @@ -162,7 +162,7 @@ def fast_small( def fast_base( model_path: str = default_cfgs["fast_base"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> FAST: """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation" diff --git a/onnxtr/models/detection/models/linknet.py b/onnxtr/models/detection/models/linknet.py index 852d1be..bc24f68 100644 --- a/onnxtr/models/detection/models/linknet.py +++ b/onnxtr/models/detection/models/linknet.py @@ -56,7 +56,7 @@ class LinkNet(Engine): def __init__( self, model_path: str, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, bin_thresh: float = 0.1, box_thresh: float = 0.1, assume_straight_pages: bool = True, @@ -94,7 +94,7 @@ def _linknet( arch: str, model_path: str, load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> LinkNet: # Patch the url @@ -106,7 +106,7 @@ def _linknet( def linknet_resnet18( model_path: str = default_cfgs["linknet_resnet18"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> LinkNet: """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" @@ -135,7 +135,7 @@ def linknet_resnet18( def linknet_resnet34( model_path: str = default_cfgs["linknet_resnet34"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> LinkNet: """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" @@ -164,7 +164,7 @@ def linknet_resnet34( def linknet_resnet50( model_path: str = default_cfgs["linknet_resnet50"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> LinkNet: """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" diff --git a/onnxtr/models/detection/zoo.py b/onnxtr/models/detection/zoo.py index cda0eed..ca42cfe 100644 --- a/onnxtr/models/detection/zoo.py +++ b/onnxtr/models/detection/zoo.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any +from typing import Any, Optional from .. import detection from ..engine import EngineConfig @@ -29,7 +29,7 @@ def _predictor( arch: Any, assume_straight_pages: bool = True, load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> DetectionPredictor: if isinstance(arch, str): @@ -48,7 +48,7 @@ def _predictor( kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) - kwargs["batch_size"] = kwargs.get("batch_size", 4) + kwargs["batch_size"] = kwargs.get("batch_size", 2) predictor = DetectionPredictor( PreProcessor(_model.cfg["input_shape"][1:], **kwargs), _model, @@ -60,7 +60,7 @@ def detection_predictor( arch: Any = "fast_base", assume_straight_pages: bool = True, load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> DetectionPredictor: """Text detection architecture. diff --git a/onnxtr/models/engine.py b/onnxtr/models/engine.py index b8035aa..215d4bf 100644 --- a/onnxtr/models/engine.py +++ b/onnxtr/models/engine.py @@ -49,7 +49,7 @@ def _init_providers(self) -> List[Tuple[str, Dict[str, Any]]]: { "device_id": 0, "arena_extend_strategy": "kNextPowerOfTwo", - "cudnn_conv_algo_search": "EXHAUSTIVE", + "cudnn_conv_algo_search": "DEFAULT", "do_copy_in_default_stream": True, }, ), @@ -87,8 +87,8 @@ class Engine: **kwargs: additional arguments to be passed to `download_from_url` """ - def __init__(self, url: str, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any) -> None: - engine_cfg = engine_cfg or EngineConfig() + def __init__(self, url: str, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any) -> None: + engine_cfg = engine_cfg if isinstance(engine_cfg, EngineConfig) else EngineConfig() archive_path = download_from_url(url, cache_subdir="models", **kwargs) if "http" in url else url self.session_options = engine_cfg.session_options self.providers = engine_cfg.providers diff --git a/onnxtr/models/predictor/base.py b/onnxtr/models/predictor/base.py index d31deb7..bb2ff55 100644 --- a/onnxtr/models/predictor/base.py +++ b/onnxtr/models/predictor/base.py @@ -50,7 +50,7 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, load_in_8_bit: bool = False, - clf_engine_cfg: EngineConfig = EngineConfig(), + clf_engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> None: self.assume_straight_pages = assume_straight_pages diff --git a/onnxtr/models/predictor/predictor.py b/onnxtr/models/predictor/predictor.py index 20c8b79..2fb5964 100644 --- a/onnxtr/models/predictor/predictor.py +++ b/onnxtr/models/predictor/predictor.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, List +from typing import Any, List, Optional import numpy as np @@ -52,7 +52,7 @@ def __init__( symmetric_pad: bool = True, detect_orientation: bool = False, detect_language: bool = False, - clf_engine_cfg: EngineConfig = EngineConfig(), + clf_engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> None: self.det_predictor = det_predictor diff --git a/onnxtr/models/recognition/models/crnn.py b/onnxtr/models/recognition/models/crnn.py index 3ce0181..30fde10 100644 --- a/onnxtr/models/recognition/models/crnn.py +++ b/onnxtr/models/recognition/models/crnn.py @@ -124,7 +124,7 @@ def __init__( self, model_path: str, vocab: str, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: @@ -154,7 +154,7 @@ def _crnn( arch: str, model_path: str, load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> CRNN: kwargs["vocab"] = kwargs.get("vocab", default_cfgs[arch]["vocab"]) @@ -172,7 +172,7 @@ def _crnn( def crnn_vgg16_bn( model_path: str = default_cfgs["crnn_vgg16_bn"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> CRNN: """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based @@ -201,7 +201,7 @@ def crnn_vgg16_bn( def crnn_mobilenet_v3_small( model_path: str = default_cfgs["crnn_mobilenet_v3_small"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> CRNN: """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based @@ -230,7 +230,7 @@ def crnn_mobilenet_v3_small( def crnn_mobilenet_v3_large( model_path: str = default_cfgs["crnn_mobilenet_v3_large"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> CRNN: """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based diff --git a/onnxtr/models/recognition/models/master.py b/onnxtr/models/recognition/models/master.py index 10cadbc..91c399c 100644 --- a/onnxtr/models/recognition/models/master.py +++ b/onnxtr/models/recognition/models/master.py @@ -45,7 +45,7 @@ def __init__( self, model_path: str, vocab: str, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: @@ -116,7 +116,7 @@ def _master( arch: str, model_path: str, load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> MASTER: # Patch the config @@ -134,7 +134,7 @@ def _master( def master( model_path: str = default_cfgs["master"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> MASTER: """MASTER as described in paper: `_. diff --git a/onnxtr/models/recognition/models/parseq.py b/onnxtr/models/recognition/models/parseq.py index 8ccdec6..4253fca 100644 --- a/onnxtr/models/recognition/models/parseq.py +++ b/onnxtr/models/recognition/models/parseq.py @@ -44,7 +44,7 @@ def __init__( self, model_path: str, vocab: str, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: @@ -104,7 +104,7 @@ def _parseq( arch: str, model_path: str, load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> PARSeq: # Patch the config @@ -123,7 +123,7 @@ def _parseq( def parseq( model_path: str = default_cfgs["parseq"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> PARSeq: """PARSeq architecture from diff --git a/onnxtr/models/recognition/models/sar.py b/onnxtr/models/recognition/models/sar.py index 7ea2415..efbc61c 100644 --- a/onnxtr/models/recognition/models/sar.py +++ b/onnxtr/models/recognition/models/sar.py @@ -44,7 +44,7 @@ def __init__( self, model_path: str, vocab: str, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: @@ -103,7 +103,7 @@ def _sar( arch: str, model_path: str, load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> SAR: # Patch the config @@ -122,7 +122,7 @@ def _sar( def sar_resnet31( model_path: str = default_cfgs["sar_resnet31"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> SAR: """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong diff --git a/onnxtr/models/recognition/models/vitstr.py b/onnxtr/models/recognition/models/vitstr.py index ab37d3e..f91c807 100644 --- a/onnxtr/models/recognition/models/vitstr.py +++ b/onnxtr/models/recognition/models/vitstr.py @@ -52,7 +52,7 @@ def __init__( self, model_path: str, vocab: str, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, cfg: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: @@ -114,7 +114,7 @@ def _vitstr( arch: str, model_path: str, load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> ViTSTR: # Patch the config @@ -133,7 +133,7 @@ def _vitstr( def vitstr_small( model_path: str = default_cfgs["vitstr_small"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> ViTSTR: """ViTSTR-Small as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition" @@ -162,7 +162,7 @@ def vitstr_small( def vitstr_base( model_path: str = default_cfgs["vitstr_base"]["url"], load_in_8_bit: bool = False, - engine_cfg: EngineConfig = EngineConfig(), + engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> ViTSTR: """ViTSTR-Base as described in `"Vision Transformer for Fast and Efficient Scene Text Recognition" diff --git a/onnxtr/models/recognition/zoo.py b/onnxtr/models/recognition/zoo.py index d237290..58e1706 100644 --- a/onnxtr/models/recognition/zoo.py +++ b/onnxtr/models/recognition/zoo.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any, List +from typing import Any, List, Optional from .. import recognition from ..engine import EngineConfig @@ -26,7 +26,7 @@ def _predictor( - arch: Any, load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any + arch: Any, load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any ) -> RecognitionPredictor: if isinstance(arch, str): if arch not in ARCHS: @@ -50,7 +50,7 @@ def _predictor( def recognition_predictor( - arch: Any = "crnn_vgg16_bn", load_in_8_bit: bool = False, engine_cfg: EngineConfig = EngineConfig(), **kwargs: Any + arch: Any = "crnn_vgg16_bn", load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any ) -> RecognitionPredictor: """Text recognition architecture. diff --git a/onnxtr/models/zoo.py b/onnxtr/models/zoo.py index f681306..5700bb8 100644 --- a/onnxtr/models/zoo.py +++ b/onnxtr/models/zoo.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. -from typing import Any +from typing import Any, Optional from .detection.zoo import detection_predictor from .engine import EngineConfig @@ -19,15 +19,15 @@ def _predictor( assume_straight_pages: bool = True, preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, - det_bs: int = 4, - reco_bs: int = 1024, + det_bs: int = 2, + reco_bs: int = 512, detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, load_in_8_bit: bool = False, - det_engine_cfg: EngineConfig = EngineConfig(), - reco_engine_cfg: EngineConfig = EngineConfig(), - clf_engine_cfg: EngineConfig = EngineConfig(), + det_engine_cfg: Optional[EngineConfig] = None, + reco_engine_cfg: Optional[EngineConfig] = None, + clf_engine_cfg: Optional[EngineConfig] = None, **kwargs, ) -> OCRPredictor: # Detection @@ -74,9 +74,9 @@ def ocr_predictor( straighten_pages: bool = False, detect_language: bool = False, load_in_8_bit: bool = False, - det_engine_cfg: EngineConfig = EngineConfig(), - reco_engine_cfg: EngineConfig = EngineConfig(), - clf_engine_cfg: EngineConfig = EngineConfig(), + det_engine_cfg: Optional[EngineConfig] = None, + reco_engine_cfg: Optional[EngineConfig] = None, + clf_engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, ) -> OCRPredictor: """End-to-end OCR architecture using one model for localization, and another for text recognition.