Skip to content

Commit

Permalink
predictors change kwargs to args
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Nov 4, 2024
1 parent 676a457 commit 7e09517
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 6 deletions.
22 changes: 20 additions & 2 deletions onnxtr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def _orientation_predictor(

def crop_orientation_predictor(
arch: Any = "mobilenet_v3_small_crop_orientation",
batch_size: int = 512,
load_in_8_bit: bool = False,
engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
Expand All @@ -66,6 +67,7 @@ def crop_orientation_predictor(
Args:
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
batch_size: number of samples the model processes in parallel
load_in_8_bit: load the 8-bit quantized version of the model
engine_cfg: configuration of inference engine
**kwargs: keyword arguments to be passed to the OrientationPredictor
Expand All @@ -75,11 +77,19 @@ def crop_orientation_predictor(
OrientationPredictor
"""
model_type = "crop"
return _orientation_predictor(arch, model_type, load_in_8_bit, engine_cfg, **kwargs)
return _orientation_predictor(
arch=arch,
batch_size=batch_size,
model_type=model_type,
load_in_8_bit=load_in_8_bit,
engine_cfg=engine_cfg,
**kwargs,
)


def page_orientation_predictor(
arch: Any = "mobilenet_v3_small_page_orientation",
batch_size: int = 2,
load_in_8_bit: bool = False,
engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
Expand All @@ -95,6 +105,7 @@ def page_orientation_predictor(
Args:
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
batch_size: number of samples the model processes in parallel
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments to be passed to the OrientationPredictor
Expand All @@ -104,4 +115,11 @@ def page_orientation_predictor(
OrientationPredictor
"""
model_type = "page"
return _orientation_predictor(arch, model_type, load_in_8_bit, engine_cfg, **kwargs)
return _orientation_predictor(
arch=arch,
batch_size=batch_size,
model_type=model_type,
load_in_8_bit=load_in_8_bit,
engine_cfg=engine_cfg,
**kwargs,
)
18 changes: 17 additions & 1 deletion onnxtr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def _predictor(
def detection_predictor(
arch: Any = "fast_base",
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
batch_size: int = 2,
load_in_8_bit: bool = False,
engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
Expand All @@ -75,6 +78,10 @@ def detection_predictor(
----
arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
assume_straight_pages: If True, fit straight boxes to the page
preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
running the detection model on it
symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
batch_size: number of samples the model processes in parallel
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: optional keyword arguments passed to the architecture
Expand All @@ -83,4 +90,13 @@ def detection_predictor(
-------
Detection predictor
"""
return _predictor(arch, assume_straight_pages, load_in_8_bit, engine_cfg=engine_cfg, **kwargs)
return _predictor(
arch=arch,
assume_straight_pages=assume_straight_pages,
preserve_aspect_ratio=preserve_aspect_ratio,
symmetric_pad=symmetric_pad,
batch_size=batch_size,
load_in_8_bit=load_in_8_bit,
engine_cfg=engine_cfg,
**kwargs,
)
1 change: 1 addition & 0 deletions onnxtr/models/preprocessor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class PreProcessor(NestedObject):
batch_size: the size of page batches
mean: mean value of the training distribution by channel
std: standard deviation of the training distribution by channel
**kwargs: additional arguments for the resizing operation
"""

_children_names: List[str] = ["resize", "normalize"]
Expand Down
18 changes: 16 additions & 2 deletions onnxtr/models/recognition/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ def _predictor(


def recognition_predictor(
arch: Any = "crnn_vgg16_bn", load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any
arch: Any = "crnn_vgg16_bn",
symmetric_pad: bool = False,
batch_size: int = 128,
load_in_8_bit: bool = False,
engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> RecognitionPredictor:
"""Text recognition architecture.
Expand All @@ -64,6 +69,8 @@ def recognition_predictor(
Args:
----
arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
batch_size: number of samples the model processes in parallel
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
engine_cfg: configuration of inference engine
**kwargs: optional parameters to be passed to the architecture
Expand All @@ -72,4 +79,11 @@ def recognition_predictor(
-------
Recognition predictor
"""
return _predictor(arch, load_in_8_bit, engine_cfg, **kwargs)
return _predictor(
arch=arch,
symmetric_pad=symmetric_pad,
batch_size=batch_size,
load_in_8_bit=load_in_8_bit,
engine_cfg=engine_cfg,
**kwargs,
)
4 changes: 3 additions & 1 deletion tests/common/test_models_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def test_detection_models(arch_name, input_shape, output_size, out_prob, quantiz
)
def test_detection_zoo(arch_name, quantized):
# Model
predictor = detection.zoo.detection_predictor(arch_name, load_in_8_bit=quantized)
predictor = detection.zoo.detection_predictor(
arch_name, load_in_8_bit=quantized, preserve_aspect_ratio=False, symmetric_pad=False
)
# object check
assert isinstance(predictor, DetectionPredictor)
input_array = np.random.rand(2, 3, 1024, 1024).astype(np.float32)
Expand Down

0 comments on commit 7e09517

Please sign in to comment.