Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Sync] Change Resize kwargs to args for each zoo predictor #45

Merged
merged 1 commit into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions onnxtr/models/classification/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def _orientation_predictor(

def crop_orientation_predictor(
arch: Any = "mobilenet_v3_small_crop_orientation",
batch_size: int = 512,
load_in_8_bit: bool = False,
engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
Expand All @@ -66,6 +67,7 @@ def crop_orientation_predictor(
Args:
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
batch_size: number of samples the model processes in parallel
load_in_8_bit: load the 8-bit quantized version of the model
engine_cfg: configuration of inference engine
**kwargs: keyword arguments to be passed to the OrientationPredictor
Expand All @@ -75,11 +77,19 @@ def crop_orientation_predictor(
OrientationPredictor
"""
model_type = "crop"
return _orientation_predictor(arch, model_type, load_in_8_bit, engine_cfg, **kwargs)
return _orientation_predictor(
arch=arch,
batch_size=batch_size,
model_type=model_type,
load_in_8_bit=load_in_8_bit,
engine_cfg=engine_cfg,
**kwargs,
)


def page_orientation_predictor(
arch: Any = "mobilenet_v3_small_page_orientation",
batch_size: int = 2,
load_in_8_bit: bool = False,
engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
Expand All @@ -95,6 +105,7 @@ def page_orientation_predictor(
Args:
----
arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
batch_size: number of samples the model processes in parallel
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: keyword arguments to be passed to the OrientationPredictor
Expand All @@ -104,4 +115,11 @@ def page_orientation_predictor(
OrientationPredictor
"""
model_type = "page"
return _orientation_predictor(arch, model_type, load_in_8_bit, engine_cfg, **kwargs)
return _orientation_predictor(
arch=arch,
batch_size=batch_size,
model_type=model_type,
load_in_8_bit=load_in_8_bit,
engine_cfg=engine_cfg,
**kwargs,
)
18 changes: 17 additions & 1 deletion onnxtr/models/detection/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def _predictor(
def detection_predictor(
arch: Any = "fast_base",
assume_straight_pages: bool = True,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
batch_size: int = 2,
load_in_8_bit: bool = False,
engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
Expand All @@ -75,6 +78,10 @@ def detection_predictor(
----
arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
assume_straight_pages: If True, fit straight boxes to the page
preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before
running the detection model on it
symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
batch_size: number of samples the model processes in parallel
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
engine_cfg: configuration for the inference engine
**kwargs: optional keyword arguments passed to the architecture
Expand All @@ -83,4 +90,13 @@ def detection_predictor(
-------
Detection predictor
"""
return _predictor(arch, assume_straight_pages, load_in_8_bit, engine_cfg=engine_cfg, **kwargs)
return _predictor(
arch=arch,
assume_straight_pages=assume_straight_pages,
preserve_aspect_ratio=preserve_aspect_ratio,
symmetric_pad=symmetric_pad,
batch_size=batch_size,
load_in_8_bit=load_in_8_bit,
engine_cfg=engine_cfg,
**kwargs,
)
1 change: 1 addition & 0 deletions onnxtr/models/preprocessor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class PreProcessor(NestedObject):
batch_size: the size of page batches
mean: mean value of the training distribution by channel
std: standard deviation of the training distribution by channel
**kwargs: additional arguments for the resizing operation
"""

_children_names: List[str] = ["resize", "normalize"]
Expand Down
18 changes: 16 additions & 2 deletions onnxtr/models/recognition/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ def _predictor(


def recognition_predictor(
arch: Any = "crnn_vgg16_bn", load_in_8_bit: bool = False, engine_cfg: Optional[EngineConfig] = None, **kwargs: Any
arch: Any = "crnn_vgg16_bn",
symmetric_pad: bool = False,
batch_size: int = 128,
load_in_8_bit: bool = False,
engine_cfg: Optional[EngineConfig] = None,
**kwargs: Any,
) -> RecognitionPredictor:
"""Text recognition architecture.

Expand All @@ -64,6 +69,8 @@ def recognition_predictor(
Args:
----
arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
symmetric_pad: if True, pad the image symmetrically instead of padding at the bottom-right
batch_size: number of samples the model processes in parallel
load_in_8_bit: whether to load the the 8-bit quantized model, defaults to False
engine_cfg: configuration of inference engine
**kwargs: optional parameters to be passed to the architecture
Expand All @@ -72,4 +79,11 @@ def recognition_predictor(
-------
Recognition predictor
"""
return _predictor(arch, load_in_8_bit, engine_cfg, **kwargs)
return _predictor(
arch=arch,
symmetric_pad=symmetric_pad,
batch_size=batch_size,
load_in_8_bit=load_in_8_bit,
engine_cfg=engine_cfg,
**kwargs,
)
4 changes: 3 additions & 1 deletion tests/common/test_models_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def test_detection_models(arch_name, input_shape, output_size, out_prob, quantiz
)
def test_detection_zoo(arch_name, quantized):
# Model
predictor = detection.zoo.detection_predictor(arch_name, load_in_8_bit=quantized)
predictor = detection.zoo.detection_predictor(
arch_name, load_in_8_bit=quantized, preserve_aspect_ratio=False, symmetric_pad=False
)
# object check
assert isinstance(predictor, DetectionPredictor)
input_array = np.random.rand(2, 3, 1024, 1024).astype(np.float32)
Expand Down
Loading