diff --git a/onnxtr/models/classification/zoo.py b/onnxtr/models/classification/zoo.py index 313d11c..40644f4 100644 --- a/onnxtr/models/classification/zoo.py +++ b/onnxtr/models/classification/zoo.py @@ -51,6 +51,7 @@ def _orientation_predictor( def crop_orientation_predictor( arch: Any = "mobilenet_v3_small_crop_orientation", + batch_size: int = 512, load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, @@ -66,6 +67,7 @@ def crop_orientation_predictor( Args: ---- arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation') + batch_size: number of samples the model processes in parallel load_in_8_bit: load the 8-bit quantized version of the model engine_cfg: configuration of inference engine **kwargs: keyword arguments to be passed to the OrientationPredictor @@ -75,11 +77,19 @@ def crop_orientation_predictor( OrientationPredictor """ model_type = "crop" - return _orientation_predictor(arch, model_type, load_in_8_bit, engine_cfg, **kwargs) + return _orientation_predictor( + arch=arch, + batch_size=batch_size, + model_type=model_type, + load_in_8_bit=load_in_8_bit, + engine_cfg=engine_cfg, + **kwargs, + ) def page_orientation_predictor( arch: Any = "mobilenet_v3_small_page_orientation", + batch_size: int = 2, load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, @@ -95,6 +105,7 @@ def page_orientation_predictor( Args: ---- arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation') + batch_size: number of samples the model processes in parallel load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: keyword arguments to be passed to the OrientationPredictor @@ -104,4 +115,11 @@ def page_orientation_predictor( OrientationPredictor """ model_type = "page" - return _orientation_predictor(arch, model_type, load_in_8_bit, engine_cfg, **kwargs) + return _orientation_predictor( + arch=arch, + batch_size=batch_size, + model_type=model_type, + load_in_8_bit=load_in_8_bit, + engine_cfg=engine_cfg, + **kwargs, + ) diff --git a/onnxtr/models/detection/zoo.py b/onnxtr/models/detection/zoo.py index ca42cfe..38e7e56 100644 --- a/onnxtr/models/detection/zoo.py +++ b/onnxtr/models/detection/zoo.py @@ -59,6 +59,9 @@ def _predictor( def detection_predictor( arch: Any = "fast_base", assume_straight_pages: bool = True, + preserve_aspect_ratio: bool = True, + symmetric_pad: bool = True, + batch_size: int = 2, load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any, @@ -75,6 +78,10 @@ def detection_predictor( ---- arch: name of the architecture or model itself to use (e.g. 'db_resnet50') assume_straight_pages: If True, fit straight boxes to the page + preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before + running the detection model on it + symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right + batch_size: number of samples the model processes in parallel load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration for the inference engine **kwargs: optional keyword arguments passed to the architecture @@ -83,4 +90,13 @@ def detection_predictor( ------- Detection predictor """ - return _predictor(arch, assume_straight_pages, load_in_8_bit, engine_cfg=engine_cfg, **kwargs) + return _predictor( + arch=arch, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + symmetric_pad=symmetric_pad, + batch_size=batch_size, + load_in_8_bit=load_in_8_bit, + engine_cfg=engine_cfg, + **kwargs, + ) diff --git a/onnxtr/models/preprocessor/base.py b/onnxtr/models/preprocessor/base.py index 546a56a..a1c31d8 100644 --- a/onnxtr/models/preprocessor/base.py +++ b/onnxtr/models/preprocessor/base.py @@ -25,6 +25,7 @@ class PreProcessor(NestedObject): batch_size: the size of page batches mean: mean value of the training distribution by channel std: standard deviation of the training distribution by channel + **kwargs: additional arguments for the resizing operation """ _children_names: List[str] = ["resize", "normalize"] diff --git a/onnxtr/models/recognition/zoo.py b/onnxtr/models/recognition/zoo.py index 58e1706..8eb8bd3 100644 --- a/onnxtr/models/recognition/zoo.py +++ b/onnxtr/models/recognition/zoo.py @@ -50,7 +50,12 @@ def _predictor( def recognition_predictor( - arch: Any = "crnn_vgg16_bn", load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any + arch: Any = "crnn_vgg16_bn", + symmetric_pad: bool = False, + batch_size: int = 128, + load_in_8_bit: bool = False, + engine_cfg: Optional[EngineConfig] = None, + **kwargs: Any, ) -> RecognitionPredictor: """Text recognition architecture. @@ -64,6 +69,8 @@ def recognition_predictor( Args: ---- arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn') + symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right + batch_size: number of samples the model processes in parallel load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False engine_cfg: configuration of inference engine **kwargs: optional parameters to be passed to the architecture @@ -72,4 +79,11 @@ def recognition_predictor( ------- Recognition predictor """ - return _predictor(arch, load_in_8_bit, engine_cfg, **kwargs) + return _predictor( + arch=arch, + symmetric_pad=symmetric_pad, + batch_size=batch_size, + load_in_8_bit=load_in_8_bit, + engine_cfg=engine_cfg, + **kwargs, + ) diff --git a/tests/common/test_models_detection.py b/tests/common/test_models_detection.py index 3b7ce0b..dffb361 100644 --- a/tests/common/test_models_detection.py +++ b/tests/common/test_models_detection.py @@ -116,7 +116,9 @@ def test_detection_models(arch_name, input_shape, output_size, out_prob, quantiz ) def test_detection_zoo(arch_name, quantized): # Model - predictor = detection.zoo.detection_predictor(arch_name, load_in_8_bit=quantized) + predictor = detection.zoo.detection_predictor( + arch_name, load_in_8_bit=quantized, preserve_aspect_ratio=False, symmetric_pad=False + ) # object check assert isinstance(predictor, DetectionPredictor) input_array = np.random.rand(2, 3, 1024, 1024).astype(np.float32)