diff --git a/doctr/file_utils.py b/doctr/file_utils.py
index 68e9dfffac..a96529d6fb 100644
--- a/doctr/file_utils.py
+++ b/doctr/file_utils.py
@@ -14,7 +14,7 @@
 CLASS_NAME: str = "words"
-__all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"]
+__all__ = ["is_tf_available", "is_torch_available", "does_torch_have_compile_capability", "is_pytorch_backend_available", "requires_package", "CLASS_NAME"]
 ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
 ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
@@ -22,7 +22,6 @@
 USE_TF = os.environ.get("USE_TF", "AUTO").upper()
 USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
-
 if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
     _torch_available = importlib.util.find_spec("torch") is not None
     if _torch_available:
@@ -76,6 +75,18 @@
         " is installed and that either USE_TF or USE_TORCH is enabled."
     )
+if _torch_available:
+    import torch
+    _torch_has_compile = hasattr(torch, "compile")
+    _torch_has_backend = False
+
+    if _torch_has_compile and hasattr(torch.library, "custom_op"):
+        from torch.utils._triton import has_triton
+        _torch_has_backend = has_triton()
+else:
+    _torch_has_compile = False
+    _torch_has_backend = False
+
 def requires_package(name: str, extra_message: Optional[str] = None) -> None:  # pragma: no cover
     """
@@ -104,3 +115,11 @@ def is_torch_available():
 def is_tf_available():
     """Whether TensorFlow is installed."""
     return _tf_available
+
+def does_torch_have_compile_capability():
+    """Whether PyTorch has `torch.compile` support."""
+    return _torch_has_compile
+
+def is_pytorch_backend_available():
+    """Whether a `torch.compile` backend (Triton) is available."""
+    return _torch_has_backend
diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py
index 9368bb225d..a0ed9a6a62 100644
--- a/doctr/models/classification/zoo.py
+++ b/doctr/models/classification/zoo.py
@@ -5,7 +5,7 @@
 from typing import Any, List
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_pytorch_backend_available
 from .. import classification
 from ..preprocessor import PreProcessor
@@ -43,7 +43,17 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
     kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
+    kwargs["compile"] = kwargs.get("compile", False)
+    kwargs["compile_kwargs"] = kwargs.get("compile_kwargs", {"fullgraph": True, "dynamic": False})
     input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]
+
+    if is_pytorch_backend_available() and kwargs["compile"]:
+        import torch
+        _model = torch.compile(_model, **kwargs["compile_kwargs"])
+
+    del kwargs["compile"]
+    del kwargs["compile_kwargs"]
+
     predictor = OrientationPredictor(
         PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
     )
diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py
index e89009fd8d..1777b0052f 100644
--- a/doctr/models/detection/zoo.py
+++ b/doctr/models/detection/zoo.py
@@ -5,7 +5,7 @@
 from typing import Any, List
-from doctr.file_utils import is_tf_available, is_torch_available
+from doctr.file_utils import is_tf_available, is_torch_available, is_pytorch_backend_available
 from .. import detection
 from ..detection.fast import reparameterize
@@ -68,6 +68,16 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
     kwargs["batch_size"] = kwargs.get("batch_size", 2)
+    kwargs["compile"] = kwargs.get("compile", False)
+    kwargs["compile_kwargs"] = kwargs.get("compile_kwargs", {})
+
+    if is_pytorch_backend_available() and kwargs["compile"]:
+        import torch
+        _model = torch.compile(_model, **kwargs["compile_kwargs"])
+
+    del kwargs["compile"]
+    del kwargs["compile_kwargs"]
+
     predictor = DetectionPredictor(
         PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs),
         _model,
diff --git a/doctr/models/recognition/zoo.py b/doctr/models/recognition/zoo.py
index 0393240431..e78258983a 100644
--- a/doctr/models/recognition/zoo.py
+++ b/doctr/models/recognition/zoo.py
@@ -5,7 +5,7 @@
 from typing import Any, List
-from doctr.file_utils import is_tf_available
+from doctr.file_utils import is_tf_available, is_pytorch_backend_available
 from doctr.models.preprocessor import PreProcessor
 from .. import recognition
@@ -46,7 +46,17 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict
     kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
     kwargs["batch_size"] = kwargs.get("batch_size", 128)
+    kwargs["compile"] = kwargs.get("compile", False)
+    kwargs["compile_kwargs"] = kwargs.get("compile_kwargs", {})
     input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:]
+
+    if is_pytorch_backend_available() and kwargs["compile"]:
+        import torch
+        _model = torch.compile(_model, **kwargs["compile_kwargs"])
+
+    del kwargs["compile"]
+    del kwargs["compile_kwargs"]
+
     predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)

     return predictor
diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py
index eff5fe14c4..8fd5237239 100644
--- a/doctr/models/zoo.py
+++ b/doctr/models/zoo.py
@@ -3,7 +3,7 @@
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
-from typing import Any
+from typing import Any, Dict
 from .detection.zoo import detection_predictor
 from .kie_predictor import KIEPredictor
@@ -26,6 +26,8 @@ def _predictor(
     detect_orientation: bool = False,
     straighten_pages: bool = False,
     detect_language: bool = False,
+    compile: bool = False,
+    compile_kwargs: Dict[str, Any] = {},
     **kwargs,
 ) -> OCRPredictor:
     # Detection
@@ -37,6 +39,8 @@
         assume_straight_pages=assume_straight_pages,
         preserve_aspect_ratio=preserve_aspect_ratio,
         symmetric_pad=symmetric_pad,
+        compile=compile,
+        compile_kwargs=compile_kwargs,
     )

     # Recognition
@@ -45,6 +49,8 @@
         pretrained=pretrained,
         pretrained_backbone=pretrained_backbone,
         batch_size=reco_bs,
+        compile=compile,
+        compile_kwargs=compile_kwargs,
     )

     return OCRPredictor(
@@ -72,6 +78,8 @@ def ocr_predictor(
     detect_orientation: bool = False,
     straighten_pages: bool = False,
     detect_language: bool = False,
+    compile: bool = False,
+    compile_kwargs: Dict[str, Any] = {},
     **kwargs: Any,
 ) -> OCRPredictor:
     """End-to-end OCR architecture using one model for localization, and another for text recognition.
@@ -105,6 +113,9 @@
             Doing so will improve performances for documents with page-uniform rotations.
         detect_language: if True, the language prediction will be added to the predictions for each page.
             Doing so will slightly deteriorate the overall latency.
+        compile: if True, the predictor will try to compile the underlying models with `torch.compile`.
+            The first call may be noticeably slower. Only available with the PyTorch backend.
+        compile_kwargs: keyword arguments forwarded to `torch.compile` when compilation is enabled.
         kwargs: keyword args of `OCRPredictor`

     Returns:
@@ -123,6 +134,8 @@
         detect_orientation=detect_orientation,
         straighten_pages=straighten_pages,
         detect_language=detect_language,
+        compile=compile,
+        compile_kwargs=compile_kwargs,
         **kwargs,
     )
diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py
index d2dbe5087a..464ba2663c 100644
--- a/tests/pytorch/test_models_classification_pt.py
+++ b/tests/pytorch/test_models_classification_pt.py
@@ -7,6 +7,7 @@
 import pytest
 import torch
+from doctr.file_utils import is_pytorch_backend_available, does_torch_have_compile_capability
 from doctr.models import classification
 from doctr.models.classification.predictor import OrientationPredictor
 from doctr.models.utils import export_model_to_onnx
@@ -193,3 +194,49 @@ def test_models_onnx_export(arch_name, input_shape, output_size):
         assert np.allclose(pt_logits, ort_outs[0], atol=1e-4)
     except AssertionError:
         pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}")
+
+@pytest.mark.skipif(not does_torch_have_compile_capability(), reason="requires pytorch >= 2.0.0")
+@pytest.mark.skipif(not is_pytorch_backend_available(), reason="requires pytorch backend to be available")
+@pytest.mark.parametrize("fullgraph", [True, False])
+@pytest.mark.parametrize(
+    "arch_name, input_shape",
+    [
+        ["vgg16_bn_r", (3, 32, 32)],
+        ["resnet18", (3, 32, 32)],
+        ["resnet31", (3, 32, 32)],
+        ["resnet34", (3, 32, 32)],
+        ["resnet34_wide", (3, 32, 32)],
+        ["resnet50", (3, 32, 32)],
+        ["magc_resnet31", (3, 32, 32)],
+        ["mobilenet_v3_small", (3, 32, 32)],
+        ["mobilenet_v3_large", (3, 32, 32)],
+        ["mobilenet_v3_small_crop_orientation", (3, 256, 256)],
+        ["mobilenet_v3_small_page_orientation", (3, 512, 512)],
+        ["vit_s", (3, 32, 32)],
+        ["vit_b", (3, 32, 32)],
+        ["textnet_tiny", (3, 32, 32)],
+        ["textnet_small", (3, 32, 32)],
+        ["textnet_base", (3, 32, 32)],
+    ],
+)
+def test_models_pytorch_compile(arch_name, input_shape, fullgraph):
+    # General check that the model can be compiled
+    try:
+        assert torch.compile(classification.__dict__[arch_name](pretrained=True).eval(), fullgraph=fullgraph)
+    except Exception:
+        pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nModel is failing pytorch compilation")
+    # Model
+    batch_size = 2
+    model = classification.__dict__[arch_name](pretrained=True, exportable=True).eval()
+    dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
+    pt_logits = model(dummy_input)["logits"].detach().cpu().numpy()
+
+    compiled_model = torch.compile(model, fullgraph=fullgraph)
+    pt_logits_compiled = compiled_model(dummy_input)["logits"].detach().cpu().numpy()
+
+    assert pt_logits_compiled.shape == pt_logits.shape
+    # Check that the output is close to the "original" output
+    try:
+        assert np.allclose(pt_logits, pt_logits_compiled, atol=1e-4)
+    except AssertionError:
+        pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nMax element-wise difference: {np.max(np.abs(pt_logits - pt_logits_compiled))}")
\ No newline at end of file
diff --git a/tests/pytorch/test_models_detection_pt.py b/tests/pytorch/test_models_detection_pt.py
index 247cdb2880..77b67a2ff0 100644
--- a/tests/pytorch/test_models_detection_pt.py
+++ b/tests/pytorch/test_models_detection_pt.py
@@ -7,7 +7,7 @@
 import pytest
 import torch
-from doctr.file_utils import CLASS_NAME
+from doctr.file_utils import CLASS_NAME, is_pytorch_backend_available, does_torch_have_compile_capability
 from doctr.models import detection
 from doctr.models.detection._utils import dilate, erode
 from doctr.models.detection.fast.pytorch import reparameterize
@@ -186,3 +186,46 @@ def test_models_onnx_export(arch_name, input_shape, output_size):
         assert np.allclose(pt_logits, ort_outs[0], atol=1e-4)
     except AssertionError:
         pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}")
+
+@pytest.mark.skipif(not does_torch_have_compile_capability(), reason="requires pytorch >= 2.0.0")
+@pytest.mark.skipif(not is_pytorch_backend_available(), reason="requires pytorch backend to be available")
+@pytest.mark.parametrize("fullgraph", [True, False])
+@pytest.mark.parametrize(
+    "arch_name, input_shape",
+    [
+        ["db_resnet34", (3, 512, 512)],
+        ["db_resnet50", (3, 512, 512)],
+        ["db_mobilenet_v3_large", (3, 512, 512)],
+        ["linknet_resnet18", (3, 512, 512)],
+        ["linknet_resnet34", (3, 512, 512)],
+        ["linknet_resnet50", (3, 512, 512)],
+        ["fast_tiny", (3, 512, 512)],
+        ["fast_small", (3, 512, 512)],
+        ["fast_base", (3, 512, 512)],
+        ["fast_tiny_rep", (3, 512, 512)],  # Reparameterized model
+    ],
+)
+def test_models_pytorch_compile(arch_name, input_shape, fullgraph):
+    # General check that the model can be compiled
+    try:
+        assert torch.compile(detection.__dict__[arch_name](pretrained=True).eval(), fullgraph=fullgraph)
+    except Exception:
+        pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nModel is failing pytorch compilation")
+    # Model
+    batch_size = 2
+    if arch_name == "fast_tiny_rep":
+        model = reparameterize(detection.fast_tiny(pretrained=True, exportable=True).eval())
+    else:
+        model = detection.__dict__[arch_name](pretrained=True, exportable=True).eval()
+    dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
+    pt_logits = model(dummy_input)["logits"].detach().cpu().numpy()
+
+    compiled_model = torch.compile(model, fullgraph=fullgraph)
+    pt_logits_compiled = compiled_model(dummy_input)["logits"].detach().cpu().numpy()
+
+    assert pt_logits_compiled.shape == pt_logits.shape
+    # Check that the output is close to the "original" output
+    try:
+        assert np.allclose(pt_logits, pt_logits_compiled, atol=1e-4)
+    except AssertionError:
+        pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nMax element-wise difference: {np.max(np.abs(pt_logits - pt_logits_compiled))}")
\ No newline at end of file
diff --git a/tests/pytorch/test_models_recognition_pt.py b/tests/pytorch/test_models_recognition_pt.py
index e4df34060b..43e518ef9c 100644
--- a/tests/pytorch/test_models_recognition_pt.py
+++ b/tests/pytorch/test_models_recognition_pt.py
@@ -7,6 +7,7 @@
 import pytest
 import torch
+from doctr.file_utils import is_pytorch_backend_available, does_torch_have_compile_capability
 from doctr.models import recognition
 from doctr.models.recognition.crnn.pytorch import CTCPostProcessor
 from doctr.models.recognition.master.pytorch import MASTERPostProcessor
@@ -154,3 +155,46 @@ def test_models_onnx_export(arch_name, input_shape):
         assert np.allclose(pt_logits, ort_outs[0], atol=1e-4)
     except AssertionError:
         pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}")
+
+@pytest.mark.skipif(not does_torch_have_compile_capability(), reason="requires pytorch >= 2.0.0")
+@pytest.mark.skipif(not is_pytorch_backend_available(), reason="requires pytorch backend to be available")
+@pytest.mark.parametrize("fullgraph", [True, False])
+@pytest.mark.parametrize(
+    "arch_name, input_shape",
+    [
+        ["crnn_vgg16_bn", (3, 32, 128)],
["crnn_mobilenet_v3_small", (3, 32, 128)], + ["crnn_mobilenet_v3_large", (3, 32, 128)], + pytest.param( + "sar_resnet31", + (3, 32, 128), + marks=pytest.mark.skipif(system_available_memory < 16, reason="not enough memory"), + ), + pytest.param( + "master", (3, 32, 128), marks=pytest.mark.skipif(system_available_memory < 16, reason="not enough memory") + ), + ["vitstr_small", (3, 32, 128)], # testing one vitstr version is enough + ["parseq", (3, 32, 128)], + ], +) +def test_models_pytorch_compile(arch_name, input_shape, fullgraph): + # General Check that the model can be compiled + try: + assert torch.compile(recognition.__dict__[arch_name](pretrained=True).eval(), fullgraph=fullgraph) + except: + pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nModel is failing pytorch compilation") + # Model + batch_size = 2 + model = recognition.__dict__[arch_name](pretrained=True, exportable=True).eval() + dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) + pt_logits = model(dummy_input)["logits"].detach().cpu().numpy() + + compiled_model = torch.compile(model, fullgraph=fullgraph) + pt_logits_compiled = compiled_model(dummy_input)["logits"].detach().cpu().numpy() + + assert pt_logits_compiled.shape == pt_logits.shape + # Check that the output is close to the "original" output + try: + assert np.allclose(pt_logits, pt_logits_compiled, atol=1e-4) + except AssertionError: + pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nMax element-wise difference: {np.max(np.abs(pt_logits - pt_logits_compiled))}") \ No newline at end of file