Add torch.compile support for pytorch 2.4 #1690

Draft · wants to merge 11 commits into base: main
23 changes: 21 additions & 2 deletions doctr/file_utils.py
@@ -14,15 +14,14 @@
CLASS_NAME: str = "words"


__all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"]
__all__ = ["is_tf_available", "is_torch_available", "does_torch_have_compile_capability", "is_pytorch_backend_available", "requires_package", "CLASS_NAME"]

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()


if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
    _torch_available = importlib.util.find_spec("torch") is not None
    if _torch_available:
@@ -76,6 +75,18 @@
" is installed and that either USE_TF or USE_TORCH is enabled."
)

if _torch_available:
Comment from @felixdittrich92 (Contributor), Sep 2, 2024:

@Fabioomega We can remove this.

Two options:

1. We pin the lower boundary to >= 2.0.0 here:

   "torch>=1.12.0,<3.0.0",

   and torchvision>=0.15.0

2. Or we mention in the docs that this requires >= 2.0.0 for compile and >= 2.4.0 for compile + fullgraph.

@odulcy-mindee wdyt?
We are already at 2.4.0, so I would prefer the >= 2.0.0 pin (and in that case only mention >= 2.4.0 for fullgraph (triton) support).
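For illustration, option 1 would amount to bumping those pins to something like the following (a sketch only; the surrounding dependency entries are not part of this diff):

"torch>=2.0.0,<3.0.0",
"torchvision>=0.15.0",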

    import torch
    _torch_has_compile = hasattr(torch, "compile")
    _torch_has_backend = False

    if _torch_has_compile and hasattr(torch.library, 'custom_op'):
        from torch.utils._triton import has_triton
        _torch_has_backend = has_triton()
else:
    _torch_has_compile = False
    _torch_has_backend = False


def requires_package(name: str, extra_message: Optional[str] = None) -> None: # pragma: no cover
"""
Expand Down Expand Up @@ -104,3 +115,11 @@ def is_torch_available():
def is_tf_available():
"""Whether TensorFlow is installed."""
return _tf_available

def does_torch_have_compile_capability():
Comment from a Contributor: remove :)

"""Whether Pytorch has compile support."""
return _torch_has_compile

def is_pytorch_backend_available():
"""Whether Triton is installed."""
return _torch_has_backend
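As a quick illustration of how these helpers are meant to be consumed (a hypothetical caller, not part of this diff; fast_base is used only as a placeholder model):

import torch

from doctr.file_utils import does_torch_have_compile_capability, is_pytorch_backend_available
from doctr.models import fast_base

model = fast_base(pretrained=False)
if does_torch_have_compile_capability():
    # request a full graph only when a Triton backend is available (PyTorch >= 2.4)
    model = torch.compile(model, fullgraph=is_pytorch_backend_available())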
12 changes: 11 additions & 1 deletion doctr/models/classification/zoo.py
Comment from a Contributor: can be reverted completely.

@@ -5,7 +5,7 @@

from typing import Any, List

from doctr.file_utils import is_tf_available
from doctr.file_utils import is_tf_available, is_pytorch_backend_available

from .. import classification
from ..preprocessor import PreProcessor
@@ -43,7 +43,17 @@ def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> Orient
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 128 if "crop" in arch else 4)
kwargs["compile"] = kwargs.get("compile", False)
kwargs["compile_kwargs"] = kwargs.get("compile_kwargs", {'fullgraph': True, 'dynamic': False})
input_shape = _model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:]

if is_pytorch_backend_available() and kwargs["compile"]:
import torch
_model = torch.compile(_model, **kwargs["compile_kwargs"])

del kwargs["compile"]
del kwargs["compile_kwargs"]

predictor = OrientationPredictor(
PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), _model
)
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/detection/_utils/__init__.py
Fabioomega marked this conversation as resolved.
@@ -4,4 +4,4 @@
if is_tf_available():
    from .tensorflow import *
else:
    from .pytorch import *
    from .pytorch import *
2 changes: 1 addition & 1 deletion doctr/models/detection/fast/base.py
Comment from a Contributor: revert :)

@@ -109,7 +109,7 @@ def bitmap_to_boxes(
        boxes: List[Union[np.ndarray, List[float]]] = []
        # get contours from connected components on the bitmap
        contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
        for _, contour in enumerate(contours):
            # Check whether smallest enclosing bounding box is not too small
            if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < 2):  # type: ignore[index]
                continue
2 changes: 1 addition & 1 deletion doctr/models/detection/linknet/base.py
Comment from a Contributor: revert :)

@@ -54,7 +54,7 @@ def polygon_to_box(
"""
if not self.assume_straight_pages:
# Compute the rectangle polygon enclosing the raw polygon
rect = cv2.minAreaRect(points)
rect = cv2.minAreaRect,(points)
points = cv2.boxPoints(rect)
# Add 1 pixel to correct cv2 approx
area = (rect[1][0] + 1) * (1 + rect[1][1])
Expand Down
12 changes: 11 additions & 1 deletion doctr/models/detection/zoo.py
Comment from a Contributor: revert :)

@@ -5,7 +5,7 @@

from typing import Any, List

from doctr.file_utils import is_tf_available, is_torch_available
from doctr.file_utils import is_tf_available, is_torch_available, is_pytorch_backend_available

from .. import detection
from ..detection.fast import reparameterize
@@ -68,6 +68,16 @@
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 2)
kwargs["compile"] = kwargs.get("compile", False)
kwargs["compile_kwargs"] = kwargs.get("compile_kwargs", {})

if is_pytorch_backend_available() and kwargs["compile"]:
import torch
_model = torch.compile(_model, **kwargs["compile_kwargs"])

Check notice on line 77 in doctr/models/detection/zoo.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/detection/zoo.py#L77

Trailing whitespace
del kwargs["compile"]
del kwargs["compile_kwargs"]

predictor = DetectionPredictor(
PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs),
_model,
Expand Down
18 changes: 10 additions & 8 deletions doctr/models/recognition/parseq/pytorch.py
Comment from a Contributor: revert :)

@@ -180,7 +180,7 @@
        # Borrowed from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
        # with small modifications

        max_num_chars = int(seqlen.max().item())  # get longest sequence length in batch
        max_num_chars = int(seqlen.max().numpy().item())  # get longest sequence length in batch
        perms = [torch.arange(max_num_chars, device=seqlen.device)]

        max_perms = math.factorial(max_num_chars) // 2
@@ -266,7 +266,8 @@
        ).int()

        pos_logits = []
        for i in range(max_length):
        i = 0
        while i < max_length:
Comment from a Contributor: I remember there was an issue with while loops when exporting to ONNX, so we have to be careful here (needs to be checked).

Comment from the Contributor Author: I changed it because there was some unnecessary complication related to breaks in torch.compile. Changing to a while loop and adjusting the logic a bit helped. Hopefully it works for the ONNX export as well.
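To make the intent concrete, the break-free pattern boils down to something like this standalone sketch (not the PARSeq code itself; stop_flags is a hypothetical per-step "all sequences done" flag):

import torch

def count_decode_steps(stop_flags: torch.Tensor, max_length: int) -> int:
    """Loop without `break`: when the stop condition fires, jump `i` past `max_length`."""
    steps = 0
    i = 0
    while i < max_length:
        steps += 1
        # equivalent to `if stop_flags[steps - 1]: break`, but compile/export friendly
        i += int(stop_flags[steps - 1].item()) * max_length
        i += 1
    return steps

# e.g. all sequences reach EOS at the third step -> 3 iterations instead of max_length
print(count_decode_steps(torch.tensor([False, False, True, False, False]), max_length=5))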

            # Decode one token at a time without providing information about the future tokens
            tgt_out = self.decode(
                ys[:, : i + 1],
@@ -283,8 +284,9 @@

            # Stop decoding if all sequences have reached the EOS token
            # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
                break
            i += (not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all())*max_length

Check notice (Codacy Static Code Analysis) on line 288 in doctr/models/recognition/parseq/pytorch.py: Trailing whitespace

            i += 1

        logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)

@@ -322,7 +324,7 @@
            # Build target tensor
            _gt, _seq_len = self.build_target(target)
            gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long).to(x.device), torch.tensor(_seq_len).to(x.device)
            gt = gt[:, : int(seq_len.max().item()) + 2]  # slice up to the max length of the batch + 2 (SOS + EOS)
            gt = gt[:, : int(seq_len.max().numpy().item()) + 2]  # slice up to the max length of the batch + 2 (SOS + EOS)

        if self.training:
            # Generate permutations for the target sequences
@@ -338,7 +340,7 @@

            loss = torch.tensor(0.0, device=features.device)
            loss_numel: Union[int, float] = 0
            n = (gt_out != self.vocab_size + 2).sum().item()
            n = (gt_out != self.vocab_size + 2).sum().numpy().item()
            for i, perm in enumerate(tgt_perms):
                _, target_mask = self.generate_permutations_attention_masks(perm)  # (seq_len, seq_len)
                # combine both masks
@@ -351,7 +353,7 @@
                # remove the [EOS] tokens for the succeeding perms
                if i == 1:
                    gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out)
                    n = (gt_out != self.vocab_size + 2).sum().item()
                    n = (gt_out != self.vocab_size + 2).sum().numpy().item()

            loss /= loss_numel

@@ -406,7 +408,7 @@
        ]
        # compute probabilties for each word up to the EOS token
        probs = [
            preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
            preds_prob[i, : len(word)].clip(0, 1).mean().tolist() if word else 0.0 for i, word in enumerate(word_values)
        ]

        return list(zip(word_values, probs))
12 changes: 11 additions & 1 deletion doctr/models/recognition/zoo.py
Comment from a Contributor: revert :)

@@ -5,7 +5,7 @@

from typing import Any, List

from doctr.file_utils import is_tf_available
from doctr.file_utils import is_tf_available, is_pytorch_backend_available
from doctr.models.preprocessor import PreProcessor

from .. import recognition
@@ -46,7 +46,17 @@
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 128)
kwargs["compile"] = kwargs.get("compile", False)
kwargs["compile_kwargs"] = kwargs.get("compile_kwargs", {})
input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:]

if is_pytorch_backend_available() and kwargs["compile"]:
import torch
_model = torch.compile(_model, **kwargs["compile_kwargs"])

Check notice on line 56 in doctr/models/recognition/zoo.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

doctr/models/recognition/zoo.py#L56

Trailing whitespace
del kwargs["compile"]
del kwargs["compile_kwargs"]

predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)

return predictor
Expand Down
15 changes: 14 additions & 1 deletion doctr/models/zoo.py
Comment from a Contributor: revert :)

Comment from the Contributor Author: I have some questions about that! Wasn't the original idea to add a new argument to enable compilation? Did I misunderstand?

Comment from @felixT2K (Contributor), Sep 18, 2024:

That was the first thought, as your code looked like changes to the pipeline/models were needed. However, we then saw that these were not needed, which is why we only add tests here and a section on how to use it. The compilation therefore remains on the user side, which is at the same time much more flexible. :)
Additionally, this avoids adding an argument which in the end only does model = torch.compile(model, ...) and is backend-dependent (PyTorch).

Comment from a Contributor:

A full sample would then look like this, for example:

import requests
import torch
from doctr.models import ocr_predictor, parseq, fast_base
from doctr.io import DocumentFile

bytes_data = requests.get(
    "https://i1.rgstatic.net/publication/231831562_Another_Boring_Day_in_Paradise_Rock_and_Roll_and_the_Empowerment_of_Everyday_Life/links/57d02a2408ae601b39a05636/largepreview.png"
).content

doc = DocumentFile.from_images([bytes_data])

rec_model = torch.compile(parseq(pretrained=True))
det_model = torch.compile(fast_base(pretrained=True))
predictor = ocr_predictor(det_arch=det_model, reco_arch=rec_model, pretrained=True)

res = predictor(doc)
res.show()

The only required change here would be to also allow torch._dynamo.eval_frame.OptimizedModule in:

isinstance(arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq))

and

if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):

and

if not isinstance(arch, classification.MobileNetV3):
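For illustration, that allowance could look roughly like this (a sketch only; the helper name is hypothetical, and the real checks live inline in the zoo modules):

import torch._dynamo
from doctr.models import recognition

def _is_supported_reco_arch(arch) -> bool:
    # also accept a torch.compile'd model, which torch wraps in an OptimizedModule
    return isinstance(
        arch,
        (
            recognition.CRNN,
            recognition.SAR,
            recognition.MASTER,
            recognition.ViTSTR,
            recognition.PARSeq,
            torch._dynamo.eval_frame.OptimizedModule,
        ),
    )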

@@ -3,7 +3,7 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any
from typing import Any, Dict

from .detection.zoo import detection_predictor
from .kie_predictor import KIEPredictor
@@ -26,6 +26,8 @@
    detect_orientation: bool = False,
    straighten_pages: bool = False,
    detect_language: bool = False,
    compile: bool = False,

Check warning (Codacy Static Code Analysis) on line 29 in doctr/models/zoo.py: Redefining built-in 'compile'

    compile_kwargs: Dict[str, Any] = {},
    **kwargs,
) -> OCRPredictor:
    # Detection
@@ -37,6 +39,8 @@
        assume_straight_pages=assume_straight_pages,
        preserve_aspect_ratio=preserve_aspect_ratio,
        symmetric_pad=symmetric_pad,
        compile=compile,
        compile_kwargs=compile_kwargs,
    )

    # Recognition
@@ -45,6 +49,8 @@
        pretrained=pretrained,
        pretrained_backbone=pretrained_backbone,
        batch_size=reco_bs,
        compile=compile,
        compile_kwargs=compile_kwargs,
    )

    return OCRPredictor(
@@ -72,6 +78,8 @@
    detect_orientation: bool = False,
    straighten_pages: bool = False,
    detect_language: bool = False,
    compile: bool = False,
    compile_kwargs: Dict[str, Any] = {},
    **kwargs: Any,
) -> OCRPredictor:
    """End-to-end OCR architecture using one model for localization, and another for text recognition.
@@ -105,6 +113,9 @@
            Doing so will improve performances for documents with page-uniform rotations.
        detect_language: if True, the language prediction will be added to the predictions for each
            page. Doing so will slightly deteriorate the overall latency.
        compile: if True, the predictor will try to compile the model.
            May cause slowdowns on first use. Only available for pytorch.
        compile_kwargs: the arguments to be passed if compile is enabled.
        kwargs: keyword args of `OCRPredictor`

    Returns:
@@ -123,6 +134,8 @@
        detect_orientation=detect_orientation,
        straighten_pages=straighten_pages,
        detect_language=detect_language,
        compile=compile,
        compile_kwargs=compile_kwargs,
        **kwargs,
    )

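Taken together, end-to-end usage of the new flags would look roughly like this (a sketch based on the ocr_predictor signature added above; the image path is a placeholder):

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

doc = DocumentFile.from_images(["sample_page.jpg"])  # placeholder input
predictor = ocr_predictor(
    pretrained=True,
    compile=True,  # forwarded to the detection and recognition predictors
    compile_kwargs={"fullgraph": False, "dynamic": False},
)
result = predictor(doc)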
47 changes: 47 additions & 0 deletions tests/pytorch/test_models_classification_pt.py
@@ -7,6 +7,7 @@
import pytest
import torch

from doctr.file_utils import is_pytorch_backend_available, does_torch_have_compile_capability
from doctr.models import classification
from doctr.models.classification.predictor import OrientationPredictor
from doctr.models.utils import export_model_to_onnx
@@ -193,3 +194,49 @@
        assert np.allclose(pt_logits, ort_outs[0], atol=1e-4)
    except AssertionError:
        pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}")

@pytest.mark.skipif(not does_torch_have_compile_capability(), reason="requires pytorch >= 2.0.0")
@pytest.mark.skipif(not is_pytorch_backend_available(), reason="requires pytorch backend to be available")
@pytest.mark.parametrize("fullgraph", [True, False])
@pytest.mark.parametrize(
"arch_name, input_shape",
[
["vgg16_bn_r", (3, 32, 32)],
["resnet18", (3, 32, 32)],
["resnet31", (3, 32, 32)],
["resnet34", (3, 32, 32)],
["resnet34_wide", (3, 32, 32)],
["resnet50", (3, 32, 32)],
["magc_resnet31", (3, 32, 32)],
["mobilenet_v3_small", (3, 32, 32)],
["mobilenet_v3_large", (3, 32, 32)],
["mobilenet_v3_small_crop_orientation", (3, 256, 256)],
["mobilenet_v3_small_page_orientation", (3, 512, 512)],
["vit_s", (3, 32, 32)],
["vit_b", (3, 32, 32)],
["textnet_tiny", (3, 32, 32)],
["textnet_small", (3, 32, 32)],
["textnet_base", (3, 32, 32)],
],
)
def test_models_pytorch_compile(arch_name, input_shape, fullgraph):
    # General Check that the model can be compiled
    try:
        assert torch.compile(classification.__dict__[arch_name](pretrained=True).eval(), fullgraph=fullgraph)
    except:

Check warning (Codacy Static Code Analysis) on line 226 in tests/pytorch/test_models_classification_pt.py: No exception type(s) specified

        pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nModel is failing pytorch compilation")
    # Model
    batch_size = 2
    model = classification.__dict__[arch_name](pretrained=True, exportable=True).eval()
    dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
    pt_logits = model(dummy_input)["logits"].detach().cpu().numpy()

    compiled_model = torch.compile(model, fullgraph=fullgraph)
    pt_logits_compiled = compiled_model(dummy_input)["logits"].detach().cpu().numpy()

    assert pt_logits_compiled.shape == pt_logits.shape
    # Check that the output is close to the "original" output
    try:
        assert np.allclose(pt_logits, pt_logits_compiled, atol=1e-4)
    except AssertionError:
        pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nMax element-wise difference: {np.max(np.abs(pt_logits - pt_logits_compiled))}")
45 changes: 44 additions & 1 deletion tests/pytorch/test_models_detection_pt.py
@@ -7,7 +7,7 @@
import pytest
import torch

from doctr.file_utils import CLASS_NAME
from doctr.file_utils import CLASS_NAME, is_pytorch_backend_available, does_torch_have_compile_capability
from doctr.models import detection
from doctr.models.detection._utils import dilate, erode
from doctr.models.detection.fast.pytorch import reparameterize
@@ -186,3 +186,46 @@
        assert np.allclose(pt_logits, ort_outs[0], atol=1e-4)
    except AssertionError:
        pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}")

@pytest.mark.skipif(not does_torch_have_compile_capability(), reason="requires pytorch >= 2.0.0")
@pytest.mark.skipif(not is_pytorch_backend_available(), reason="requires pytorch backend to be available")
Comment from a Contributor: remove the first two skipifs; CI always runs on the latest PyTorch. Same for the other tests 👍

@pytest.mark.parametrize("fullgraph", [True, False])
@pytest.mark.parametrize(
"arch_name, input_shape",
[
["db_resnet34", (3, 512, 512)],
["db_resnet50", (3, 512, 512)],
["db_mobilenet_v3_large", (3, 512, 512)],
["linknet_resnet18", (3, 512, 512)],
["linknet_resnet34", (3, 512, 512)],
["linknet_resnet50", (3, 512, 512)],
["fast_tiny", (3, 512, 512)],
["fast_small", (3, 512, 512)],
["fast_base", (3, 512, 512)],
["fast_tiny_rep", (3, 512, 512)], # Reparameterized model
],
)
def test_models_pytorch_compile(arch_name, input_shape, fullgraph):
    # General Check that the model can be compiled
    try:
        assert torch.compile(detection.__dict__[arch_name](pretrained=True).eval(), fullgraph=fullgraph)
    except:

Check warning (Codacy Static Code Analysis) on line 212 in tests/pytorch/test_models_detection_pt.py: No exception type(s) specified

        pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nModel is failing pytorch compilation")
    # Model
    batch_size = 2
    if arch_name == "fast_tiny_rep":
        model = reparameterize(detection.fast_tiny(pretrained=True, exportable=True).eval())
    else:
        model = detection.__dict__[arch_name](pretrained=True, exportable=True).eval()
    dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
    pt_logits = model(dummy_input)["logits"].detach().cpu().numpy()

    compiled_model = torch.compile(model, fullgraph=fullgraph)
    pt_logits_compiled = compiled_model(dummy_input)["logits"].detach().cpu().numpy()

    assert pt_logits_compiled.shape == pt_logits.shape
    # Check that the output is close to the "original" output
    try:
        assert np.allclose(pt_logits, pt_logits_compiled, atol=1e-4)
    except AssertionError:
        pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nMax element-wise difference: {np.max(np.abs(pt_logits - pt_logits_compiled))}")