From 1bc21f82e3dac98d81c38c3b159d45422c525495 Mon Sep 17 00:00:00 2001 From: Fabioomega Date: Mon, 5 Aug 2024 16:15:35 -0300 Subject: [PATCH 01/11] Added support for pytorch.compile on >=Pytorch 2.4 version --- doctr/file_utils.py | 30 ++++++++- doctr/models/classification/zoo.py | 7 +- doctr/models/detection/_utils/__init__.py | 9 ++- doctr/models/detection/_utils/cv2_fallback.py | 21 ++++++ .../detection/_utils/pytorch_compile.py | 67 +++++++++++++++++++ doctr/models/detection/core.py | 5 +- .../differentiable_binarization/base.py | 13 ++-- doctr/models/detection/fast/base.py | 15 +++-- doctr/models/detection/linknet/base.py | 11 +-- doctr/models/detection/zoo.py | 7 +- doctr/models/recognition/zoo.py | 7 +- 11 files changed, 167 insertions(+), 25 deletions(-) create mode 100644 doctr/models/detection/_utils/cv2_fallback.py create mode 100644 doctr/models/detection/_utils/pytorch_compile.py diff --git a/doctr/file_utils.py b/doctr/file_utils.py index 68e9dfffac..10a0886c0f 100644 --- a/doctr/file_utils.py +++ b/doctr/file_utils.py @@ -14,13 +14,14 @@ CLASS_NAME: str = "words" -__all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"] +__all__ = ["is_tf_available", "is_torch_available", "does_torch_have_compile_capability", "is_triton_available", "requires_package", "CLASS_NAME"] ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() +USE_TRITON = os.environ.get("USE_TRITON", "NO").upper() if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: @@ -76,6 +77,25 @@ " is installed and that either USE_TF or USE_TORCH is enabled." 
) +if _torch_available: + import torch + _torch_has_compile = hasattr(torch, "compile") + _torch_has_triton = False + + if _torch_has_compile and hasattr(torch.library, 'custom_op'): + from torch.utils._triton import has_triton + if USE_TRITON in ENV_VARS_TRUE_AND_AUTO_VALUES: + if has_triton(): + logging.info("Triton detected!") + _torch_has_triton = True + elif USE_TRITON == 'AUTO': + logging.info("Triton was not found! Continuing without it!") + else: + logging.warn("Triton was not found even tough it was requested by the user!") +else: + _torch_has_compile = False + _torch_has_triton = False + def requires_package(name: str, extra_message: Optional[str] = None) -> None: # pragma: no cover """ @@ -104,3 +124,11 @@ def is_torch_available(): def is_tf_available(): """Whether TensorFlow is installed.""" return _tf_available + +def does_torch_have_compile_capability(): + """Whether Pytorch has compile support.""" + return _torch_has_compile + +def is_triton_available(): + """Whether Triton is installed.""" + return _torch_has_triton diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index 9368bb225d..9b536c5688 100644 --- a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -5,7 +5,7 @@ from typing import Any, List -from doctr.file_utils import is_tf_available +from doctr.file_utils import is_tf_available, is_triton_available from .. 
import classification from ..preprocessor import PreProcessor @@ -44,6 +44,11 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4) input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:] + + if is_triton_available(): + import torch + _model = torch.compile(_model, fullgraph=True, dynamic=False) + predictor = OrientationPredictor( PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model ) diff --git a/doctr/models/detection/_utils/__init__.py b/doctr/models/detection/_utils/__init__.py index fbeba301bc..4281dcacea 100644 --- a/doctr/models/detection/_utils/__init__.py +++ b/doctr/models/detection/_utils/__init__.py @@ -1,7 +1,14 @@ -from doctr.file_utils import is_tf_available +from doctr.file_utils import is_tf_available, is_triton_available from .base import * if is_tf_available(): from .tensorflow import * + from .cv2_fallback import * else: from .pytorch import * + import torch + if is_triton_available: + from .pytorch_compile import * + else: + from .cv2_fallback import * + del torch \ No newline at end of file diff --git a/doctr/models/detection/_utils/cv2_fallback.py b/doctr/models/detection/_utils/cv2_fallback.py new file mode 100644 index 0000000000..ed017f62e2 --- /dev/null +++ b/doctr/models/detection/_utils/cv2_fallback.py @@ -0,0 +1,21 @@ +import cv2 +from typing import Sequence +import numpy as np +from typing import Tuple + +__all__ = ['boundingRect', 'minAreaRect', 'fillPoly', 'morphologyEx'] + +def boundingRect(array: cv2.typing.MatLike) -> Sequence[int]: + return cv2.boundingRect(array) + +def minAreaRect(mat: cv2.typing.MatLike) -> Tuple[Sequence[float], Sequence[float], float]: + return cv2.minAreaRect(mat) + +def fillPoly(img: cv2.typing.MatLike, pts: Sequence[cv2.typing.MatLike], color: cv2.typing.Scalar) 
-> None: + return cv2.fillPoly(img, pts, color) + +def morphologyEx(src: np.ndarray, op: int, kernel: np.ndarray) -> np.ndarray: + return cv2.morphologyEx(src, op, kernel) + +# def boxPoints(box: cv2.typing.RotatedRect) -> np.ndarray: +# return cv2.boxPoints(box) \ No newline at end of file diff --git a/doctr/models/detection/_utils/pytorch_compile.py b/doctr/models/detection/_utils/pytorch_compile.py new file mode 100644 index 0000000000..870e788932 --- /dev/null +++ b/doctr/models/detection/_utils/pytorch_compile.py @@ -0,0 +1,67 @@ +import torch +from torch import Tensor +import cv2 +from typing import List, Sequence, Tuple +import numpy as np +import torch._dynamo.config + +__all__ = [ 'boundingRect', 'minAreaRect', 'fillPoly', 'morphologyEx'] + +torch._dynamo.config.cache_size_limit = 30 + +def morphologyEx(src: np.ndarray, op: int, kernel: np.ndarray) -> np.ndarray: + return _morphologyEx(torch.from_numpy(src), op, torch.from_numpy(kernel)).numpy() +# Register a custom_op for the morphologyEx +@torch.library.custom_op("cv2::morphologyEx", mutates_args=()) +def _morphologyEx(src: torch.Tensor, op: int, kernel: torch.Tensor) -> torch.Tensor: + return torch.from_numpy(cv2.morphologyEx(src.numpy(), op, kernel.numpy())) +# Register the FakeTensor as having the same size as the src +@_morphologyEx.register_fake +def _(src, op, kernel): + return src + +def boundingRect(array: cv2.typing.MatLike) -> Sequence[int]: + return tuple(_boundingRect(Tensor(array))) + +@torch.library.custom_op('cv2::boundingRect', mutates_args=()) +def _boundingRect(array: Tensor) -> Tensor: + return torch.LongTensor(cv2.boundingRect(array.numpy())) + +@_boundingRect.register_fake +def _(array): + return torch.empty((1, 4)) + +def minAreaRect(mat: cv2.typing.MatLike) -> Tuple[Sequence[float], Sequence[float], float]: + packed = _minAreaRect(torch.from_numpy(mat)) + k = list(map(lambda x: x.numpy(), packed.split_with_sizes((2, 2, 1)))) + k[-1] = k[-1].item() + return k + 
+@torch.library.custom_op('cv2::minAreaRect', mutates_args=()) +def _minAreaRect(mat: Tensor) -> Tensor: + point, size, rot = cv2.minAreaRect(mat.numpy()) + return torch.FloatTensor([point[0], point[1], size[0], size[1], rot]) + +@_minAreaRect.register_fake +def _(mat): + return torch.empty([5]) + +def fillPoly(img: cv2.typing.MatLike, pts: Sequence[cv2.typing.MatLike], color: cv2.typing.Scalar) -> None: + _fillPoly(torch.from_numpy(img), torch.from_numpy(np.array(pts)), color) + +@torch.library.custom_op('cv2::fillPoly', mutates_args=({'img'})) +def _fillPoly(img: Tensor, pts: Tensor, color: float) -> None: + cv2.fillPoly(img.numpy(), [p.numpy() for p in pts], color) + +# def boxPoints(box: cv2.typing.RotatedRect) -> cv2.typing.MatLike: +# point, size, rot = box +# return _boxPoints(torch.FloatTensor([point[0], point[1], size[0], size[1], rot])).numpy() + +# @torch.library.custom_op('cv2::boxPoints', mutates_args=()) +# def _boxPoints(box: Tensor) -> Tensor: +# b = box.tolist() +# return torch.from_numpy(cv2.boxPoints(((b[0], b[1]), (b[2], b[3]), b[4]))) + +# @_boxPoints.register_fake +# def _(box): +# return torch.empty([4, 2]) \ No newline at end of file diff --git a/doctr/models/detection/core.py b/doctr/models/detection/core.py index 63fa786151..3dc46a8daf 100644 --- a/doctr/models/detection/core.py +++ b/doctr/models/detection/core.py @@ -9,6 +9,7 @@ import numpy as np from doctr.utils.repr import NestedObject +from ._utils import morphologyEx, fillPoly __all__ = ["DetectionPostProcessor"] @@ -57,7 +58,7 @@ def box_score(pred: np.ndarray, points: np.ndarray, assume_straight_pages: bool else: mask: np.ndarray = np.zeros((h, w), np.int32) - cv2.fillPoly(mask, [points.astype(np.int32)], 1.0) # type: ignore[call-overload] + fillPoly(mask, [points.astype(np.int32)], 1.0) # type: ignore[call-overload] product = pred * mask return np.sum(product) / np.count_nonzero(product) @@ -89,7 +90,7 @@ def __call__( # Erosion + dilation on the binary map bin_map = [ [ - 
cv2.morphologyEx(bmap[..., idx], cv2.MORPH_OPEN, self._opening_kernel) + morphologyEx(bmap[..., idx], cv2.MORPH_OPEN, self._opening_kernel) for idx in range(proba_map.shape[-1]) ] for bmap in (proba_map >= self.bin_thresh).astype(np.uint8) diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index 21eceb7940..daf7c60d61 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -13,6 +13,7 @@ from shapely.geometry import Polygon from ..core import DetectionPostProcessor +from .._utils import boundingRect, minAreaRect, fillPoly __all__ = ["DBPostProcessor"] @@ -56,7 +57,7 @@ def polygon_to_box( """ if not self.assume_straight_pages: # Compute the rectangle polygon enclosing the raw polygon - rect = cv2.minAreaRect(points) + rect = minAreaRect(points) points = cv2.boxPoints(rect) # Add 1 pixel to correct cv2 approx area = (rect[1][0] + 1) * (1 + rect[1][1]) @@ -83,9 +84,9 @@ def polygon_to_box( if len(expanded_points) < 1: return None # type: ignore[return-value] return ( - cv2.boundingRect(expanded_points) # type: ignore[return-value] + boundingRect(expanded_points) # type: ignore[return-value] if self.assume_straight_pages - else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0) + else np.roll(cv2.boxPoints(minAreaRect(expanded_points)), -1, axis=0) ) def bitmap_to_boxes( @@ -118,7 +119,7 @@ def bitmap_to_boxes( continue # Compute objectness if self.assume_straight_pages: - x, y, w, h = cv2.boundingRect(contour) + x, y, w, h = boundingRect(contour) points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) score = self.box_score(pred, points, assume_straight_pages=True) else: @@ -235,7 +236,7 @@ def draw_thresh_map( padded_polygon: np.ndarray = np.array(padding.Execute(distance)[0]) # Fill the mask with 1 on the new padded polygon - cv2.fillPoly(mask, 
[padded_polygon.astype(np.int32)], 1.0) # type: ignore[call-overload] + fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) # type: ignore[call-overload] # Get min/max to recover polygon after distance computation xmin = padded_polygon[:, 0].min() @@ -354,7 +355,7 @@ def build_target( if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue - cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] + fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] # Draw on both thresh map and thresh mask poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map( diff --git a/doctr/models/detection/fast/base.py b/doctr/models/detection/fast/base.py index f98981a82e..386269acd4 100644 --- a/doctr/models/detection/fast/base.py +++ b/doctr/models/detection/fast/base.py @@ -15,6 +15,7 @@ from doctr.models.core import BaseModel from ..core import DetectionPostProcessor +from .._utils import boundingRect, minAreaRect, fillPoly __all__ = ["_FAST", "FASTPostProcessor"] @@ -54,7 +55,7 @@ def polygon_to_box( """ if not self.assume_straight_pages: # Compute the rectangle polygon enclosing the raw polygon - rect = cv2.minAreaRect(points) + rect = minAreaRect(points) points = cv2.boxPoints(rect) # Add 1 pixel to correct cv2 approx area = (rect[1][0] + 1) * (1 + rect[1][1]) @@ -81,9 +82,9 @@ def polygon_to_box( if len(expanded_points) < 1: return None # type: ignore[return-value] return ( - cv2.boundingRect(expanded_points) # type: ignore[return-value] + boundingRect(expanded_points) # type: ignore[return-value] if self.assume_straight_pages - else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0) + else np.roll(cv2.boxPoints(minAreaRect(expanded_points)), -1, axis=0) ) def bitmap_to_boxes( @@ -109,13 +110,13 @@ def bitmap_to_boxes( boxes: 
List[Union[np.ndarray, List[float]]] = [] # get contours from connected components on the bitmap contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - for contour in contours: + for _, contour in enumerate(contours): # Check whether smallest enclosing bounding box is not too small if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2): # type: ignore[index] continue # Compute objectness if self.assume_straight_pages: - x, y, w, h = cv2.boundingRect(contour) + x, y, w, h = boundingRect(contour) points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) score = self.box_score(pred, points, assume_straight_pages=True) else: @@ -244,9 +245,9 @@ def build_target( if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue - cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] + fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] # draw the original polygon on the segmentation target - cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0) # type: ignore[call-overload] + fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0) # type: ignore[call-overload] # Don't forget to switch back to channel last if Tensorflow is used if channels_last: diff --git a/doctr/models/detection/linknet/base.py b/doctr/models/detection/linknet/base.py index d677048c0e..af031032ff 100644 --- a/doctr/models/detection/linknet/base.py +++ b/doctr/models/detection/linknet/base.py @@ -15,6 +15,7 @@ from doctr.models.core import BaseModel from ..core import DetectionPostProcessor +from .._utils import boundingRect, minAreaRect, fillPoly __all__ = ["_LinkNet", "LinkNetPostProcessor"] @@ -54,7 +55,7 @@ def polygon_to_box( """ if not self.assume_straight_pages: # Compute the rectangle polygon 
enclosing the raw polygon - rect = cv2.minAreaRect(points) + rect = minAreaRect(points) points = cv2.boxPoints(rect) # Add 1 pixel to correct cv2 approx area = (rect[1][0] + 1) * (1 + rect[1][1]) @@ -81,9 +82,9 @@ def polygon_to_box( if len(expanded_points) < 1: return None # type: ignore[return-value] return ( - cv2.boundingRect(expanded_points) # type: ignore[return-value] + boundingRect(expanded_points) # type: ignore[return-value] if self.assume_straight_pages - else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0) + else np.roll(cv2.boxPoints(minAreaRect(expanded_points)), -1, axis=0) ) def bitmap_to_boxes( @@ -115,7 +116,7 @@ def bitmap_to_boxes( continue # Compute objectness if self.assume_straight_pages: - x, y, w, h = cv2.boundingRect(contour) + x, y, w, h = boundingRect(contour) points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) score = self.box_score(pred, points, assume_straight_pages=True) else: @@ -247,7 +248,7 @@ def build_target( if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue - cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] + fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] # Don't forget to switch back to channel last if Tensorflow is used if channels_last: diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index e89009fd8d..00888e9248 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -5,7 +5,7 @@ from typing import Any, List -from doctr.file_utils import is_tf_available, is_torch_available +from doctr.file_utils import is_tf_available, is_torch_available, is_triton_available from .. 
import detection from ..detection.fast import reparameterize @@ -68,6 +68,11 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 2) + + if is_triton_available(): + import torch + _model = torch.compile(_model, fullgraph=False) + predictor = DetectionPredictor( PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs), _model, diff --git a/doctr/models/recognition/zoo.py b/doctr/models/recognition/zoo.py index 0393240431..43af7cab1a 100644 --- a/doctr/models/recognition/zoo.py +++ b/doctr/models/recognition/zoo.py @@ -5,7 +5,7 @@ from typing import Any, List -from doctr.file_utils import is_tf_available +from doctr.file_utils import is_tf_available, is_triton_available from doctr.models.preprocessor import PreProcessor from .. import recognition @@ -47,6 +47,11 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 128) input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:] + + if is_triton_available(): + import torch + _model = torch.compile(_model) + predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model) return predictor From 479c8978bbb95f0bcdd4323ba1454b303c13dcad Mon Sep 17 00:00:00 2001 From: Fabioomega Date: Mon, 5 Aug 2024 16:20:01 -0300 Subject: [PATCH 02/11] Added support for controlling compiling behavior via predictors --- doctr/models/classification/zoo.py | 3 ++- doctr/models/detection/zoo.py | 3 ++- doctr/models/recognition/zoo.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index 
9b536c5688..5b273cf7f8 100644 --- a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -43,9 +43,10 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4) + kwargs["compile"] = kwargs.get("compile", True) input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:] - if is_triton_available(): + if is_triton_available() and kwargs["compile"]: import torch _model = torch.compile(_model, fullgraph=True, dynamic=False) diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index 00888e9248..1dd34ab521 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -68,8 +68,9 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 2) + kwargs["compile"] = kwargs.get("compile", True) - if is_triton_available(): + if is_triton_available() and kwargs["compile"]: import torch _model = torch.compile(_model, fullgraph=False) diff --git a/doctr/models/recognition/zoo.py b/doctr/models/recognition/zoo.py index 43af7cab1a..d8aadb8ae9 100644 --- a/doctr/models/recognition/zoo.py +++ b/doctr/models/recognition/zoo.py @@ -46,9 +46,10 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 128) + kwargs["compile"] = kwargs.get("compile", True) input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:] - if is_triton_available(): + if is_triton_available() and 
kwargs["compile"]: import torch _model = torch.compile(_model) From f8a5ee6683f9a9995082884f235ba3e1bd8b162b Mon Sep 17 00:00:00 2001 From: Fabioomega Date: Mon, 5 Aug 2024 16:22:19 -0300 Subject: [PATCH 03/11] Added fullgraph support for parseq --- doctr/models/recognition/parseq/pytorch.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/doctr/models/recognition/parseq/pytorch.py b/doctr/models/recognition/parseq/pytorch.py index 8fff062da9..b79cf2777c 100644 --- a/doctr/models/recognition/parseq/pytorch.py +++ b/doctr/models/recognition/parseq/pytorch.py @@ -180,7 +180,7 @@ def generate_permutations(self, seqlen: torch.Tensor) -> torch.Tensor: # Borrowed from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py # with small modifications - max_num_chars = int(seqlen.max().item()) # get longest sequence length in batch + max_num_chars = int(seqlen.max().numpy().item()) # get longest sequence length in batch perms = [torch.arange(max_num_chars, device=seqlen.device)] max_perms = math.factorial(max_num_chars) // 2 @@ -266,7 +266,8 @@ def decode_autoregressive(self, features: torch.Tensor, max_len: Optional[int] = ).int() pos_logits = [] - for i in range(max_length): + i = 0 + while i < max_length: # Decode one token at a time without providing information about the future tokens tgt_out = self.decode( ys[:, : i + 1], @@ -283,8 +284,9 @@ def decode_autoregressive(self, features: torch.Tensor, max_len: Optional[int] = # Stop decoding if all sequences have reached the EOS token # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export - if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all(): - break + i += (not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all())*max_length + + i += 1 logits = torch.cat(pos_logits, dim=1) # (N, max_length, vocab_size + 1) @@ -322,7 +324,7 @@ def forward( # Build target tensor 
_gt, _seq_len = self.build_target(target) gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long).to(x.device), torch.tensor(_seq_len).to(x.device) - gt = gt[:, : int(seq_len.max().item()) + 2] # slice up to the max length of the batch + 2 (SOS + EOS) + gt = gt[:, : int(seq_len.max().numpy().item()) + 2] # slice up to the max length of the batch + 2 (SOS + EOS) if self.training: # Generate permutations for the target sequences @@ -338,7 +340,7 @@ def forward( loss = torch.tensor(0.0, device=features.device) loss_numel: Union[int, float] = 0 - n = (gt_out != self.vocab_size + 2).sum().item() + n = (gt_out != self.vocab_size + 2).sum().numpy().item() for i, perm in enumerate(tgt_perms): _, target_mask = self.generate_permutations_attention_masks(perm) # (seq_len, seq_len) # combine both masks @@ -351,7 +353,7 @@ def forward( # remove the [EOS] tokens for the succeeding perms if i == 1: gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out) - n = (gt_out != self.vocab_size + 2).sum().item() + n = (gt_out != self.vocab_size + 2).sum().numpy().item() loss /= loss_numel @@ -406,7 +408,7 @@ def __call__( ] # compute probabilties for each word up to the EOS token probs = [ - preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values) + preds_prob[i, : len(word)].clip(0, 1).mean().tolist() if word else 0.0 for i, word in enumerate(word_values) ] return list(zip(word_values, probs)) From 60ca8eb640af9fc18a415d2072f31662e93f5695 Mon Sep 17 00:00:00 2001 From: Fabioomega Date: Mon, 5 Aug 2024 19:52:47 -0300 Subject: [PATCH 04/11] Fixed typo --- doctr/models/detection/_utils/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/doctr/models/detection/_utils/__init__.py b/doctr/models/detection/_utils/__init__.py index 4281dcacea..b83fd6fcef 100644 --- a/doctr/models/detection/_utils/__init__.py +++ b/doctr/models/detection/_utils/__init__.py @@ -6,9 +6,7 @@ from .cv2_fallback import 
* else: from .pytorch import * - import torch if is_triton_available: from .pytorch_compile import * else: - from .cv2_fallback import * - del torch \ No newline at end of file + from .cv2_fallback import * \ No newline at end of file From 22dd3d9db21a2f5abd5379bef78a2f917fa6495c Mon Sep 17 00:00:00 2001 From: Fabioomega Date: Thu, 8 Aug 2024 14:55:03 -0300 Subject: [PATCH 05/11] Fixed typo again --- doctr/models/detection/_utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/models/detection/_utils/__init__.py b/doctr/models/detection/_utils/__init__.py index b83fd6fcef..b70cb8dab2 100644 --- a/doctr/models/detection/_utils/__init__.py +++ b/doctr/models/detection/_utils/__init__.py @@ -6,7 +6,7 @@ from .cv2_fallback import * else: from .pytorch import * - if is_triton_available: + if is_triton_available(): from .pytorch_compile import * else: from .cv2_fallback import * \ No newline at end of file From 38e5a2d301087de0e51073e890cef4f03ba4c631 Mon Sep 17 00:00:00 2001 From: Fabioomega Date: Thu, 8 Aug 2024 15:22:55 -0300 Subject: [PATCH 06/11] Removed compile kwarg --- doctr/models/classification/zoo.py | 3 +-- doctr/models/detection/zoo.py | 3 +-- doctr/models/recognition/zoo.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index 5b273cf7f8..9b536c5688 100644 --- a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -43,10 +43,9 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4) - kwargs["compile"] = kwargs.get("compile", True) input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:] - if is_triton_available() and kwargs["compile"]: + if 
is_triton_available(): import torch _model = torch.compile(_model, fullgraph=True, dynamic=False) diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index 1dd34ab521..00888e9248 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -68,9 +68,8 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 2) - kwargs["compile"] = kwargs.get("compile", True) - if is_triton_available() and kwargs["compile"]: + if is_triton_available(): import torch _model = torch.compile(_model, fullgraph=False) diff --git a/doctr/models/recognition/zoo.py b/doctr/models/recognition/zoo.py index d8aadb8ae9..43af7cab1a 100644 --- a/doctr/models/recognition/zoo.py +++ b/doctr/models/recognition/zoo.py @@ -46,10 +46,9 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 128) - kwargs["compile"] = kwargs.get("compile", True) input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:] - if is_triton_available() and kwargs["compile"]: + if is_triton_available(): import torch _model = torch.compile(_model) From 5822d4a45e86ac9285505dbcb37fbd432e13aad0 Mon Sep 17 00:00:00 2001 From: Fabioomega Date: Tue, 13 Aug 2024 17:35:24 -0300 Subject: [PATCH 07/11] Fixed crash in boundingRect when assume_straight_pages=True --- doctr/models/detection/_utils/pytorch_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doctr/models/detection/_utils/pytorch_compile.py b/doctr/models/detection/_utils/pytorch_compile.py index 870e788932..5a11666eac 100644 --- a/doctr/models/detection/_utils/pytorch_compile.py +++ 
b/doctr/models/detection/_utils/pytorch_compile.py @@ -21,7 +21,7 @@ def _(src, op, kernel): return src def boundingRect(array: cv2.typing.MatLike) -> Sequence[int]: - return tuple(_boundingRect(Tensor(array))) + return tuple(_boundingRect(Tensor(array)).numpy().tolist()) @torch.library.custom_op('cv2::boundingRect', mutates_args=()) def _boundingRect(array: Tensor) -> Tensor: From 0ded837b245ea2c8ed8e931aa8b5925ec586378d Mon Sep 17 00:00:00 2001 From: Fabioomega Date: Tue, 13 Aug 2024 17:42:59 -0300 Subject: [PATCH 08/11] Added compile and compile_kwargs and removed explicit mention of triton for compatibility with future backends --- doctr/file_utils.py | 21 ++++++--------------- doctr/models/classification/zoo.py | 11 ++++++++--- doctr/models/detection/zoo.py | 11 ++++++++--- doctr/models/recognition/zoo.py | 11 ++++++++--- doctr/models/zoo.py | 15 ++++++++++++++- 5 files changed, 44 insertions(+), 25 deletions(-) diff --git a/doctr/file_utils.py b/doctr/file_utils.py index 10a0886c0f..a96529d6fb 100644 --- a/doctr/file_utils.py +++ b/doctr/file_utils.py @@ -14,15 +14,13 @@ CLASS_NAME: str = "words" -__all__ = ["is_tf_available", "is_torch_available", "does_torch_have_compile_capability", "is_triton_available", "requires_package", "CLASS_NAME"] +__all__ = ["is_tf_available", "is_torch_available", "does_torch_have_compile_capability", "is_pytorch_backend_available", "requires_package", "CLASS_NAME"] ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() -USE_TRITON = os.environ.get("USE_TRITON", "NO").upper() - if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: _torch_available = importlib.util.find_spec("torch") is not None @@ -80,21 +78,14 @@ if _torch_available: import torch _torch_has_compile = hasattr(torch, "compile") - _torch_has_triton = False + 
_torch_has_backend = False if _torch_has_compile and hasattr(torch.library, 'custom_op'): from torch.utils._triton import has_triton - if USE_TRITON in ENV_VARS_TRUE_AND_AUTO_VALUES: - if has_triton(): - logging.info("Triton detected!") - _torch_has_triton = True - elif USE_TRITON == 'AUTO': - logging.info("Triton was not found! Continuing without it!") - else: - logging.warn("Triton was not found even tough it was requested by the user!") + _torch_has_backend = has_triton() else: _torch_has_compile = False - _torch_has_triton = False + _torch_has_backend = False def requires_package(name: str, extra_message: Optional[str] = None) -> None: # pragma: no cover @@ -129,6 +120,6 @@ def does_torch_have_compile_capability(): """Whether Pytorch has compile support.""" return _torch_has_compile -def is_triton_available(): +def is_pytorch_backend_available(): """Whether Triton is installed.""" - return _torch_has_triton + return _torch_has_backend diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py index 9b536c5688..a0ed9a6a62 100644 --- a/doctr/models/classification/zoo.py +++ b/doctr/models/classification/zoo.py @@ -5,7 +5,7 @@ from typing import Any, List -from doctr.file_utils import is_tf_available, is_triton_available +from doctr.file_utils import is_tf_available, is_pytorch_backend_available from .. 
import classification from ..preprocessor import PreProcessor @@ -43,11 +43,16 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4) + kwargs["compile"] = kwargs.get("compile", False) + kwargs["compile_kwargs"] = kwargs.get("compile_kwargs", {'fullgraph': True, 'dynamic': False}) input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:] - if is_triton_available(): + if is_pytorch_backend_available() and kwargs["compile"]: import torch - _model = torch.compile(_model, fullgraph=True, dynamic=False) + _model = torch.compile(_model, **kwargs["compile_kwargs"]) + + del kwargs["compile"] + del kwargs["compile_kwargs"] predictor = OrientationPredictor( PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py index 00888e9248..1777b0052f 100644 --- a/doctr/models/detection/zoo.py +++ b/doctr/models/detection/zoo.py @@ -5,7 +5,7 @@ from typing import Any, List -from doctr.file_utils import is_tf_available, is_torch_available, is_triton_available +from doctr.file_utils import is_tf_available, is_torch_available, is_pytorch_backend_available from .. 
import detection from ..detection.fast import reparameterize @@ -68,10 +68,15 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 2) + kwargs["compile"] = kwargs.get("compile", False) + kwargs["compile_kwargs"] = kwargs.get("compile_kwargs", {}) - if is_triton_available(): + if is_pytorch_backend_available() and kwargs["compile"]: import torch - _model = torch.compile(_model, fullgraph=False) + _model = torch.compile(_model, **kwargs["compile_kwargs"]) + + del kwargs["compile"] + del kwargs["compile_kwargs"] predictor = DetectionPredictor( PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs), diff --git a/doctr/models/recognition/zoo.py b/doctr/models/recognition/zoo.py index 43af7cab1a..e78258983a 100644 --- a/doctr/models/recognition/zoo.py +++ b/doctr/models/recognition/zoo.py @@ -5,7 +5,7 @@ from typing import Any, List -from doctr.file_utils import is_tf_available, is_triton_available +from doctr.file_utils import is_tf_available, is_pytorch_backend_available from doctr.models.preprocessor import PreProcessor from .. 
import recognition @@ -46,11 +46,16 @@ def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredict kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"]) kwargs["std"] = kwargs.get("std", _model.cfg["std"]) kwargs["batch_size"] = kwargs.get("batch_size", 128) + kwargs["compile"] = kwargs.get("compile", False) + kwargs["compile_kwargs"] = kwargs.get("compile_kwargs", {}) input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:] - if is_triton_available(): + if is_pytorch_backend_available() and kwargs["compile"]: import torch - _model = torch.compile(_model) + _model = torch.compile(_model, **kwargs["compile_kwargs"]) + + del kwargs["compile"] + del kwargs["compile_kwargs"] predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model) diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py index eff5fe14c4..8fd5237239 100644 --- a/doctr/models/zoo.py +++ b/doctr/models/zoo.py @@ -3,7 +3,7 @@ # This program is licensed under the Apache License 2.0. # See LICENSE or go to for full license details. 
-from typing import Any +from typing import Any, Dict from .detection.zoo import detection_predictor from .kie_predictor import KIEPredictor @@ -26,6 +26,8 @@ def _predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, + compile: bool = False, + compile_kwargs: Dict[str, Any] = {}, **kwargs, ) -> OCRPredictor: # Detection @@ -37,6 +39,8 @@ def _predictor( assume_straight_pages=assume_straight_pages, preserve_aspect_ratio=preserve_aspect_ratio, symmetric_pad=symmetric_pad, + compile=compile, + compile_kwargs=compile_kwargs, ) # Recognition @@ -45,6 +49,8 @@ def _predictor( pretrained=pretrained, pretrained_backbone=pretrained_backbone, batch_size=reco_bs, + compile=compile, + compile_kwargs=compile_kwargs, ) return OCRPredictor( @@ -72,6 +78,8 @@ def ocr_predictor( detect_orientation: bool = False, straighten_pages: bool = False, detect_language: bool = False, + compile: bool = False, + compile_kwargs: Dict[str, Any] = {}, **kwargs: Any, ) -> OCRPredictor: """End-to-end OCR architecture using one model for localization, and another for text recognition. @@ -105,6 +113,9 @@ def ocr_predictor( Doing so will improve performances for documents with page-uniform rotations. detect_language: if True, the language prediction will be added to the predictions for each page. Doing so will slightly deteriorate the overall latency. + compile: if True, the predictor will try to compile the model. + May cause slowdowns on first use. Only available for pytorch. + compile_kwargs: the arguments to be passed if compile is enabled. 
kwargs: keyword args of `OCRPredictor` Returns: @@ -123,6 +134,8 @@ def ocr_predictor( detect_orientation=detect_orientation, straighten_pages=straighten_pages, detect_language=detect_language, + compile=compile, + compile_kwargs=compile_kwargs, **kwargs, ) From f7ae74d94043b88fc82fa7e3038141bac074d37d Mon Sep 17 00:00:00 2001 From: Fabioomega Date: Tue, 13 Aug 2024 17:55:18 -0300 Subject: [PATCH 09/11] Fixed some remaining name changes from is_triton_available to is_pytorch_backend_available --- doctr/models/detection/_utils/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doctr/models/detection/_utils/__init__.py b/doctr/models/detection/_utils/__init__.py index b70cb8dab2..18fe259f74 100644 --- a/doctr/models/detection/_utils/__init__.py +++ b/doctr/models/detection/_utils/__init__.py @@ -1,4 +1,4 @@ -from doctr.file_utils import is_tf_available, is_triton_available +from doctr.file_utils import is_tf_available, is_pytorch_backend_available from .base import * if is_tf_available(): @@ -6,7 +6,7 @@ from .cv2_fallback import * else: from .pytorch import * - if is_triton_available(): + if is_pytorch_backend_available(): from .pytorch_compile import * else: from .cv2_fallback import * \ No newline at end of file From 1e6869b9b1ba64aeefe17efdc084f45ff0a6674e Mon Sep 17 00:00:00 2001 From: Fabioomega Date: Fri, 30 Aug 2024 19:46:22 -0300 Subject: [PATCH 10/11] Removed changes to the postprocessing step and reverted back --- doctr/models/detection/_utils/__init__.py | 9 +-- doctr/models/detection/_utils/cv2_fallback.py | 21 ------ .../detection/_utils/pytorch_compile.py | 67 ------------------- doctr/models/detection/core.py | 5 +- .../differentiable_binarization/base.py | 13 ++-- doctr/models/detection/fast/base.py | 13 ++-- doctr/models/detection/linknet/base.py | 11 ++- 7 files changed, 21 insertions(+), 118 deletions(-) delete mode 100644 doctr/models/detection/_utils/cv2_fallback.py delete mode 100644 
doctr/models/detection/_utils/pytorch_compile.py diff --git a/doctr/models/detection/_utils/__init__.py b/doctr/models/detection/_utils/__init__.py index 18fe259f74..1ce2b036b8 100644 --- a/doctr/models/detection/_utils/__init__.py +++ b/doctr/models/detection/_utils/__init__.py @@ -1,12 +1,7 @@ -from doctr.file_utils import is_tf_available, is_pytorch_backend_available +from doctr.file_utils import is_tf_available from .base import * if is_tf_available(): from .tensorflow import * - from .cv2_fallback import * else: - from .pytorch import * - if is_pytorch_backend_available(): - from .pytorch_compile import * - else: - from .cv2_fallback import * \ No newline at end of file + from .pytorch import * \ No newline at end of file diff --git a/doctr/models/detection/_utils/cv2_fallback.py b/doctr/models/detection/_utils/cv2_fallback.py deleted file mode 100644 index ed017f62e2..0000000000 --- a/doctr/models/detection/_utils/cv2_fallback.py +++ /dev/null @@ -1,21 +0,0 @@ -import cv2 -from typing import Sequence -import numpy as np -from typing import Tuple - -__all__ = ['boundingRect', 'minAreaRect', 'fillPoly', 'morphologyEx'] - -def boundingRect(array: cv2.typing.MatLike) -> Sequence[int]: - return cv2.boundingRect(array) - -def minAreaRect(mat: cv2.typing.MatLike) -> Tuple[Sequence[float], Sequence[float], float]: - return cv2.minAreaRect(mat) - -def fillPoly(img: cv2.typing.MatLike, pts: Sequence[cv2.typing.MatLike], color: cv2.typing.Scalar) -> None: - return cv2.fillPoly(img, pts, color) - -def morphologyEx(src: np.ndarray, op: int, kernel: np.ndarray) -> np.ndarray: - return cv2.morphologyEx(src, op, kernel) - -# def boxPoints(box: cv2.typing.RotatedRect) -> np.ndarray: -# return cv2.boxPoints(box) \ No newline at end of file diff --git a/doctr/models/detection/_utils/pytorch_compile.py b/doctr/models/detection/_utils/pytorch_compile.py deleted file mode 100644 index 5a11666eac..0000000000 --- a/doctr/models/detection/_utils/pytorch_compile.py +++ /dev/null @@ 
-1,67 +0,0 @@ -import torch -from torch import Tensor -import cv2 -from typing import List, Sequence, Tuple -import numpy as np -import torch._dynamo.config - -__all__ = [ 'boundingRect', 'minAreaRect', 'fillPoly', 'morphologyEx'] - -torch._dynamo.config.cache_size_limit = 30 - -def morphologyEx(src: np.ndarray, op: int, kernel: np.ndarray) -> np.ndarray: - return _morphologyEx(torch.from_numpy(src), op, torch.from_numpy(kernel)).numpy() -# Register a custom_op for the morphologyEx -@torch.library.custom_op("cv2::morphologyEx", mutates_args=()) -def _morphologyEx(src: torch.Tensor, op: int, kernel: torch.Tensor) -> torch.Tensor: - return torch.from_numpy(cv2.morphologyEx(src.numpy(), op, kernel.numpy())) -# Register the FakeTensor as having the same size as the src -@_morphologyEx.register_fake -def _(src, op, kernel): - return src - -def boundingRect(array: cv2.typing.MatLike) -> Sequence[int]: - return tuple(_boundingRect(Tensor(array)).numpy().tolist()) - -@torch.library.custom_op('cv2::boundingRect', mutates_args=()) -def _boundingRect(array: Tensor) -> Tensor: - return torch.LongTensor(cv2.boundingRect(array.numpy())) - -@_boundingRect.register_fake -def _(array): - return torch.empty((1, 4)) - -def minAreaRect(mat: cv2.typing.MatLike) -> Tuple[Sequence[float], Sequence[float], float]: - packed = _minAreaRect(torch.from_numpy(mat)) - k = list(map(lambda x: x.numpy(), packed.split_with_sizes((2, 2, 1)))) - k[-1] = k[-1].item() - return k - -@torch.library.custom_op('cv2::minAreaRect', mutates_args=()) -def _minAreaRect(mat: Tensor) -> Tensor: - point, size, rot = cv2.minAreaRect(mat.numpy()) - return torch.FloatTensor([point[0], point[1], size[0], size[1], rot]) - -@_minAreaRect.register_fake -def _(mat): - return torch.empty([5]) - -def fillPoly(img: cv2.typing.MatLike, pts: Sequence[cv2.typing.MatLike], color: cv2.typing.Scalar) -> None: - _fillPoly(torch.from_numpy(img), torch.from_numpy(np.array(pts)), color) - -@torch.library.custom_op('cv2::fillPoly', 
mutates_args=({'img'})) -def _fillPoly(img: Tensor, pts: Tensor, color: float) -> None: - cv2.fillPoly(img.numpy(), [p.numpy() for p in pts], color) - -# def boxPoints(box: cv2.typing.RotatedRect) -> cv2.typing.MatLike: -# point, size, rot = box -# return _boxPoints(torch.FloatTensor([point[0], point[1], size[0], size[1], rot])).numpy() - -# @torch.library.custom_op('cv2::boxPoints', mutates_args=()) -# def _boxPoints(box: Tensor) -> Tensor: -# b = box.tolist() -# return torch.from_numpy(cv2.boxPoints(((b[0], b[1]), (b[2], b[3]), b[4]))) - -# @_boxPoints.register_fake -# def _(box): -# return torch.empty([4, 2]) \ No newline at end of file diff --git a/doctr/models/detection/core.py b/doctr/models/detection/core.py index 3dc46a8daf..63fa786151 100644 --- a/doctr/models/detection/core.py +++ b/doctr/models/detection/core.py @@ -9,7 +9,6 @@ import numpy as np from doctr.utils.repr import NestedObject -from ._utils import morphologyEx, fillPoly __all__ = ["DetectionPostProcessor"] @@ -58,7 +57,7 @@ def box_score(pred: np.ndarray, points: np.ndarray, assume_straight_pages: bool else: mask: np.ndarray = np.zeros((h, w), np.int32) - fillPoly(mask, [points.astype(np.int32)], 1.0) # type: ignore[call-overload] + cv2.fillPoly(mask, [points.astype(np.int32)], 1.0) # type: ignore[call-overload] product = pred * mask return np.sum(product) / np.count_nonzero(product) @@ -90,7 +89,7 @@ def __call__( # Erosion + dilation on the binary map bin_map = [ [ - morphologyEx(bmap[..., idx], cv2.MORPH_OPEN, self._opening_kernel) + cv2.morphologyEx(bmap[..., idx], cv2.MORPH_OPEN, self._opening_kernel) for idx in range(proba_map.shape[-1]) ] for bmap in (proba_map >= self.bin_thresh).astype(np.uint8) diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py index daf7c60d61..21eceb7940 100644 --- a/doctr/models/detection/differentiable_binarization/base.py +++ 
b/doctr/models/detection/differentiable_binarization/base.py @@ -13,7 +13,6 @@ from shapely.geometry import Polygon from ..core import DetectionPostProcessor -from .._utils import boundingRect, minAreaRect, fillPoly __all__ = ["DBPostProcessor"] @@ -57,7 +56,7 @@ def polygon_to_box( """ if not self.assume_straight_pages: # Compute the rectangle polygon enclosing the raw polygon - rect = minAreaRect(points) + rect = cv2.minAreaRect(points) points = cv2.boxPoints(rect) # Add 1 pixel to correct cv2 approx area = (rect[1][0] + 1) * (1 + rect[1][1]) @@ -84,9 +83,9 @@ def polygon_to_box( if len(expanded_points) < 1: return None # type: ignore[return-value] return ( - boundingRect(expanded_points) # type: ignore[return-value] + cv2.boundingRect(expanded_points) # type: ignore[return-value] if self.assume_straight_pages - else np.roll(cv2.boxPoints(minAreaRect(expanded_points)), -1, axis=0) + else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0) ) def bitmap_to_boxes( @@ -119,7 +118,7 @@ def bitmap_to_boxes( continue # Compute objectness if self.assume_straight_pages: - x, y, w, h = boundingRect(contour) + x, y, w, h = cv2.boundingRect(contour) points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) score = self.box_score(pred, points, assume_straight_pages=True) else: @@ -236,7 +235,7 @@ def draw_thresh_map( padded_polygon: np.ndarray = np.array(padding.Execute(distance)[0]) # Fill the mask with 1 on the new padded polygon - fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) # type: ignore[call-overload] + cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) # type: ignore[call-overload] # Get min/max to recover polygon after distance computation xmin = padded_polygon[:, 0].min() @@ -355,7 +354,7 @@ def build_target( if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue - fillPoly(seg_target[idx, class_idx], 
[shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] + cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] # Draw on both thresh map and thresh mask poly, thresh_target[idx, class_idx], thresh_mask[idx, class_idx] = self.draw_thresh_map( diff --git a/doctr/models/detection/fast/base.py b/doctr/models/detection/fast/base.py index 386269acd4..409dfcebe9 100644 --- a/doctr/models/detection/fast/base.py +++ b/doctr/models/detection/fast/base.py @@ -15,7 +15,6 @@ from doctr.models.core import BaseModel from ..core import DetectionPostProcessor -from .._utils import boundingRect, minAreaRect, fillPoly __all__ = ["_FAST", "FASTPostProcessor"] @@ -55,7 +54,7 @@ def polygon_to_box( """ if not self.assume_straight_pages: # Compute the rectangle polygon enclosing the raw polygon - rect = minAreaRect(points) + rect = cv2.minAreaRect(points) points = cv2.boxPoints(rect) # Add 1 pixel to correct cv2 approx area = (rect[1][0] + 1) * (1 + rect[1][1]) @@ -82,9 +81,9 @@ def polygon_to_box( if len(expanded_points) < 1: return None # type: ignore[return-value] return ( - boundingRect(expanded_points) # type: ignore[return-value] + cv2.boundingRect(expanded_points) # type: ignore[return-value] if self.assume_straight_pages - else np.roll(cv2.boxPoints(minAreaRect(expanded_points)), -1, axis=0) + else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0) ) def bitmap_to_boxes( @@ -116,7 +115,7 @@ def bitmap_to_boxes( continue # Compute objectness if self.assume_straight_pages: - x, y, w, h = boundingRect(contour) + x, y, w, h = cv2.boundingRect(contour) points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) score = self.box_score(pred, points, assume_straight_pages=True) else: @@ -245,9 +244,9 @@ def build_target( if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue - 
fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] + cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] # draw the original polygon on the segmentation target - fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0) # type: ignore[call-overload] + cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0) # type: ignore[call-overload] # Don't forget to switch back to channel last if Tensorflow is used if channels_last: diff --git a/doctr/models/detection/linknet/base.py b/doctr/models/detection/linknet/base.py index af031032ff..371e090605 100644 --- a/doctr/models/detection/linknet/base.py +++ b/doctr/models/detection/linknet/base.py @@ -15,7 +15,6 @@ from doctr.models.core import BaseModel from ..core import DetectionPostProcessor -from .._utils import boundingRect, minAreaRect, fillPoly __all__ = ["_LinkNet", "LinkNetPostProcessor"] @@ -55,7 +54,7 @@ def polygon_to_box( """ if not self.assume_straight_pages: # Compute the rectangle polygon enclosing the raw polygon - rect = minAreaRect(points) + rect = cv2.minAreaRect(points) points = cv2.boxPoints(rect) # Add 1 pixel to correct cv2 approx area = (rect[1][0] + 1) * (1 + rect[1][1]) @@ -82,9 +81,9 @@ def polygon_to_box( if len(expanded_points) < 1: return None # type: ignore[return-value] return ( - boundingRect(expanded_points) # type: ignore[return-value] + cv2.boundingRect(expanded_points) # type: ignore[return-value] if self.assume_straight_pages - else np.roll(cv2.boxPoints(minAreaRect(expanded_points)), -1, axis=0) + else np.roll(cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0) ) def bitmap_to_boxes( @@ -116,7 +115,7 @@ def bitmap_to_boxes( continue # Compute objectness if self.assume_straight_pages: - x, y, w, h = boundingRect(contour) + x, y, w, h = cv2.boundingRect(contour) points: np.ndarray = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) 
score = self.box_score(pred, points, assume_straight_pages=True) else: @@ -248,7 +247,7 @@ def build_target( if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid: seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False continue - fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] + cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0) # type: ignore[call-overload] # Don't forget to switch back to channel last if Tensorflow is used if channels_last: From 9850455ecd81eb798606b685f4a85cd495fa15d3 Mon Sep 17 00:00:00 2001 From: Fabioomega Date: Fri, 30 Aug 2024 20:43:45 -0300 Subject: [PATCH 11/11] Added test cases for classification, detection and recognition --- .../pytorch/test_models_classification_pt.py | 47 +++++++++++++++++++ tests/pytorch/test_models_detection_pt.py | 45 +++++++++++++++++- tests/pytorch/test_models_recognition_pt.py | 44 +++++++++++++++++ 3 files changed, 135 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py index d2dbe5087a..464ba2663c 100644 --- a/tests/pytorch/test_models_classification_pt.py +++ b/tests/pytorch/test_models_classification_pt.py @@ -7,6 +7,7 @@ import pytest import torch +from doctr.file_utils import is_pytorch_backend_available, does_torch_have_compile_capability from doctr.models import classification from doctr.models.classification.predictor import OrientationPredictor from doctr.models.utils import export_model_to_onnx @@ -193,3 +194,49 @@ def test_models_onnx_export(arch_name, input_shape, output_size): assert np.allclose(pt_logits, ort_outs[0], atol=1e-4) except AssertionError: pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}") + +@pytest.mark.skipif(not does_torch_have_compile_capability(), reason="requires pytorch >= 2.0.0") +@pytest.mark.skipif(not 
is_pytorch_backend_available(), reason="requires pytorch backend to be available") +@pytest.mark.parametrize("fullgraph", [True, False]) +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["vgg16_bn_r", (3, 32, 32)], + ["resnet18", (3, 32, 32)], + ["resnet31", (3, 32, 32)], + ["resnet34", (3, 32, 32)], + ["resnet34_wide", (3, 32, 32)], + ["resnet50", (3, 32, 32)], + ["magc_resnet31", (3, 32, 32)], + ["mobilenet_v3_small", (3, 32, 32)], + ["mobilenet_v3_large", (3, 32, 32)], + ["mobilenet_v3_small_crop_orientation", (3, 256, 256)], + ["mobilenet_v3_small_page_orientation", (3, 512, 512)], + ["vit_s", (3, 32, 32)], + ["vit_b", (3, 32, 32)], + ["textnet_tiny", (3, 32, 32)], + ["textnet_small", (3, 32, 32)], + ["textnet_base", (3, 32, 32)], + ], +) +def test_models_pytorch_compile(arch_name, input_shape, fullgraph): + # General Check that the model can be compiled + try: + assert torch.compile(classification.__dict__[arch_name](pretrained=True).eval(), fullgraph=fullgraph) + except: + pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nModel is failing pytorch compilation") + # Model + batch_size = 2 + model = classification.__dict__[arch_name](pretrained=True, exportable=True).eval() + dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) + pt_logits = model(dummy_input)["logits"].detach().cpu().numpy() + + compiled_model = torch.compile(model, fullgraph=fullgraph) + pt_logits_compiled = compiled_model(dummy_input)["logits"].detach().cpu().numpy() + + assert pt_logits_compiled.shape == pt_logits.shape + # Check that the output is close to the "original" output + try: + assert np.allclose(pt_logits, pt_logits_compiled, atol=1e-4) + except AssertionError: + pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nMax element-wise difference: {np.max(np.abs(pt_logits - pt_logits_compiled))}") \ No newline at end of file diff --git a/tests/pytorch/test_models_detection_pt.py b/tests/pytorch/test_models_detection_pt.py index 
247cdb2880..77b67a2ff0 100644 --- a/tests/pytorch/test_models_detection_pt.py +++ b/tests/pytorch/test_models_detection_pt.py @@ -7,7 +7,7 @@ import pytest import torch -from doctr.file_utils import CLASS_NAME +from doctr.file_utils import CLASS_NAME, is_pytorch_backend_available, does_torch_have_compile_capability from doctr.models import detection from doctr.models.detection._utils import dilate, erode from doctr.models.detection.fast.pytorch import reparameterize @@ -186,3 +186,46 @@ def test_models_onnx_export(arch_name, input_shape, output_size): assert np.allclose(pt_logits, ort_outs[0], atol=1e-4) except AssertionError: pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}") + +@pytest.mark.skipif(not does_torch_have_compile_capability(), reason="requires pytorch >= 2.0.0") +@pytest.mark.skipif(not is_pytorch_backend_available(), reason="requires pytorch backend to be available") +@pytest.mark.parametrize("fullgraph", [True, False]) +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["db_resnet34", (3, 512, 512)], + ["db_resnet50", (3, 512, 512)], + ["db_mobilenet_v3_large", (3, 512, 512)], + ["linknet_resnet18", (3, 512, 512)], + ["linknet_resnet34", (3, 512, 512)], + ["linknet_resnet50", (3, 512, 512)], + ["fast_tiny", (3, 512, 512)], + ["fast_small", (3, 512, 512)], + ["fast_base", (3, 512, 512)], + ["fast_tiny_rep", (3, 512, 512)], # Reparameterized model + ], +) +def test_models_pytorch_compile(arch_name, input_shape, fullgraph): + # General Check that the model can be compiled + try: + assert torch.compile(detection.__dict__[arch_name](pretrained=True).eval(), fullgraph=fullgraph) + except: + pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nModel is failing pytorch compilation") + # Model + batch_size = 2 + if arch_name == "fast_tiny_rep": + model = reparameterize(detection.fast_tiny(pretrained=True, exportable=True).eval()) + else: + model = 
detection.__dict__[arch_name](pretrained=True, exportable=True).eval() + dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) + pt_logits = model(dummy_input)["logits"].detach().cpu().numpy() + + compiled_model = torch.compile(model, fullgraph=fullgraph) + pt_logits_compiled = compiled_model(dummy_input)["logits"].detach().cpu().numpy() + + assert pt_logits_compiled.shape == pt_logits.shape + # Check that the output is close to the "original" output + try: + assert np.allclose(pt_logits, pt_logits_compiled, atol=1e-4) + except AssertionError: + pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nMax element-wise difference: {np.max(np.abs(pt_logits - pt_logits_compiled))}") \ No newline at end of file diff --git a/tests/pytorch/test_models_recognition_pt.py b/tests/pytorch/test_models_recognition_pt.py index e4df34060b..43e518ef9c 100644 --- a/tests/pytorch/test_models_recognition_pt.py +++ b/tests/pytorch/test_models_recognition_pt.py @@ -7,6 +7,7 @@ import pytest import torch +from doctr.file_utils import is_pytorch_backend_available, does_torch_have_compile_capability from doctr.models import recognition from doctr.models.recognition.crnn.pytorch import CTCPostProcessor from doctr.models.recognition.master.pytorch import MASTERPostProcessor @@ -154,3 +155,46 @@ def test_models_onnx_export(arch_name, input_shape): assert np.allclose(pt_logits, ort_outs[0], atol=1e-4) except AssertionError: pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}") + +@pytest.mark.skipif(not does_torch_have_compile_capability(), reason="requires pytorch >= 2.0.0") +@pytest.mark.skipif(not is_pytorch_backend_available(), reason="requires pytorch backend to be available") +@pytest.mark.parametrize("fullgraph", [True, False]) +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["crnn_vgg16_bn", (3, 32, 128)], + ["crnn_mobilenet_v3_small", (3, 32, 128)], + ["crnn_mobilenet_v3_large", (3, 32, 
128)], + pytest.param( + "sar_resnet31", + (3, 32, 128), + marks=pytest.mark.skipif(system_available_memory < 16, reason="not enough memory"), + ), + pytest.param( + "master", (3, 32, 128), marks=pytest.mark.skipif(system_available_memory < 16, reason="not enough memory") + ), + ["vitstr_small", (3, 32, 128)], # testing one vitstr version is enough + ["parseq", (3, 32, 128)], + ], +) +def test_models_pytorch_compile(arch_name, input_shape, fullgraph): + # General Check that the model can be compiled + try: + assert torch.compile(recognition.__dict__[arch_name](pretrained=True).eval(), fullgraph=fullgraph) + except: + pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nModel is failing pytorch compilation") + # Model + batch_size = 2 + model = recognition.__dict__[arch_name](pretrained=True, exportable=True).eval() + dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32) + pt_logits = model(dummy_input)["logits"].detach().cpu().numpy() + + compiled_model = torch.compile(model, fullgraph=fullgraph) + pt_logits_compiled = compiled_model(dummy_input)["logits"].detach().cpu().numpy() + + assert pt_logits_compiled.shape == pt_logits.shape + # Check that the output is close to the "original" output + try: + assert np.allclose(pt_logits, pt_logits_compiled, atol=1e-4) + except AssertionError: + pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nMax element-wise difference: {np.max(np.abs(pt_logits - pt_logits_compiled))}") \ No newline at end of file