Add torch.compile support for pytorch 2.4 #1690
base: main
Changes to `doctr/file_utils.py`:
```diff
@@ -14,15 +14,14 @@
 CLASS_NAME: str = "words"

-__all__ = ["is_tf_available", "is_torch_available", "requires_package", "CLASS_NAME"]
+__all__ = ["is_tf_available", "is_torch_available", "does_torch_have_compile_capability", "is_pytorch_backend_available", "requires_package", "CLASS_NAME"]

 ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
 ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})

 USE_TF = os.environ.get("USE_TF", "AUTO").upper()
 USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

 if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
     _torch_available = importlib.util.find_spec("torch") is not None
     if _torch_available:
@@ -76,6 +75,18 @@
         " is installed and that either USE_TF or USE_TORCH is enabled."
     )

+if _torch_available:
+    import torch
+    _torch_has_compile = hasattr(torch, "compile")
+    _torch_has_backend = False
+
+    if _torch_has_compile and hasattr(torch.library, 'custom_op'):
+        from torch.utils._triton import has_triton
+        _torch_has_backend = has_triton()
+else:
+    _torch_has_compile = False
+    _torch_has_backend = False
+
+
 def requires_package(name: str, extra_message: Optional[str] = None) -> None:  # pragma: no cover
     """
@@ -104,3 +115,11 @@ def is_torch_available():
 def is_tf_available():
     """Whether TensorFlow is installed."""
     return _tf_available
+
+
+def does_torch_have_compile_capability():
+    """Whether Pytorch has compile support."""
+    return _torch_has_compile
+
+
+def is_pytorch_backend_available():
+    """Whether Triton is installed."""
+    return _torch_has_backend
```

Review comment (on `does_torch_have_compile_capability`): remove :)

Review comment: can be reverted completely. (Fabioomega marked this conversation as resolved.)
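For orientation, a minimal sketch of how these two helpers could gate a `torch.compile` call downstream (the model choice, input size, and gating policy are illustrative, not part of this diff):

```python
import torch

from doctr.file_utils import does_torch_have_compile_capability, is_pytorch_backend_available
from doctr.models import detection

# Any docTR model works here; db_resnet50 is just an example.
model = detection.db_resnet50(pretrained=True).eval()

if does_torch_have_compile_capability():
    # Only request fullgraph capture when a compile backend (Triton) is reported.
    model = torch.compile(model, fullgraph=is_pytorch_backend_available())

out = model(torch.rand((1, 3, 512, 512), dtype=torch.float32))
```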
Changes to one of the backend-selection `__init__.py` files:
```diff
@@ -4,4 +4,4 @@
 if is_tf_available():
     from .tensorflow import *
 else:
-    from .pytorch import *
+    from .pytorch import *
```
Review comments (four, on changes like this one): revert :)
Changes to the PARSeq recognition model (`doctr/models/recognition/parseq/pytorch.py`):
```diff
@@ -180,7 +180,7 @@
         # Borrowed from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
         # with small modifications

-        max_num_chars = int(seqlen.max().item())  # get longest sequence length in batch
+        max_num_chars = int(seqlen.max().numpy().item())  # get longest sequence length in batch
         perms = [torch.arange(max_num_chars, device=seqlen.device)]

         max_perms = math.factorial(max_num_chars) // 2
```
```diff
@@ -266,7 +266,8 @@
         ).int()

         pos_logits = []
-        for i in range(max_length):
+        i = 0
+        while i < max_length:
             # Decode one token at a time without providing information about the future tokens
             tgt_out = self.decode(
                 ys[:, : i + 1],
```

Review comment (on the `while` loop): I remember there was an issue with …

Reply (PR author): I changed it because of some unnecessary complication related to `break` statements in torch.compile. Changing to a while loop and changing the logic a bit helped. Hopefully it also works for the ONNX export.
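To make this concrete, a small self-contained sketch of the two loop shapes (a toy predicate stands in for the EOS check, and the index-jump line mirrors the `i += (...)*max_length` change in the next hunk):

```python
def decode_with_break(max_length: int) -> int:
    # Original shape: early exit via `break`, i.e. data-dependent control flow,
    # which fullgraph torch.compile and ONNX export handle poorly.
    steps = 0
    for i in range(max_length):
        steps += 1
        if i >= 3:  # stand-in for "all sequences reached EOS"
            break
    return steps


def decode_without_break(max_length: int) -> int:
    # Rewritten shape from this PR: instead of breaking, jump the index past
    # max_length so the while condition ends the loop on its next check.
    steps = 0
    i = 0
    while i < max_length:
        steps += 1
        done = i >= 3  # stand-in for the EOS condition
        i += int(done) * max_length  # jump past the end instead of `break`
        i += 1
    return steps


# Both variants run the same number of decode steps.
assert decode_with_break(25) == decode_without_break(25)
```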
```diff
@@ -283,8 +284,9 @@

             # Stop decoding if all sequences have reached the EOS token
             # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
-            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
-                break
+            i += (not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all())*max_length
+
+            i += 1

         logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)
```
```diff
@@ -322,7 +324,7 @@
         # Build target tensor
         _gt, _seq_len = self.build_target(target)
         gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long).to(x.device), torch.tensor(_seq_len).to(x.device)
-        gt = gt[:, : int(seq_len.max().item()) + 2]  # slice up to the max length of the batch + 2 (SOS + EOS)
+        gt = gt[:, : int(seq_len.max().numpy().item()) + 2]  # slice up to the max length of the batch + 2 (SOS + EOS)

         if self.training:
             # Generate permutations for the target sequences
@@ -338,7 +340,7 @@

             loss = torch.tensor(0.0, device=features.device)
             loss_numel: Union[int, float] = 0
-            n = (gt_out != self.vocab_size + 2).sum().item()
+            n = (gt_out != self.vocab_size + 2).sum().numpy().item()
             for i, perm in enumerate(tgt_perms):
                 _, target_mask = self.generate_permutations_attention_masks(perm)  # (seq_len, seq_len)
                 # combine both masks
@@ -351,7 +353,7 @@
                 # remove the [EOS] tokens for the succeeding perms
                 if i == 1:
                     gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out)
-                    n = (gt_out != self.vocab_size + 2).sum().item()
+                    n = (gt_out != self.vocab_size + 2).sum().numpy().item()

             loss /= loss_numel
@@ -406,7 +408,7 @@
         ]
         # compute probabilties for each word up to the EOS token
         probs = [
-            preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
+            preds_prob[i, : len(word)].clip(0, 1).mean().tolist() if word else 0.0 for i, word in enumerate(word_values)
         ]

         return list(zip(word_values, probs))
```
Review comments (two, on these changes): revert :)

PR author: I have some questions about that! Wasn't the original idea to add a new argument to enable compilation? Did I misunderstand?

Reviewer: That was the first thought, as your code looked like changes to the pipeline/models were needed. However, we then saw that these were not needed. A full sample would then look like, for example:
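A plausible sketch of such a sample, assuming the zoo functions accept already-compiled model instances (the architecture choices below are only examples, not prescribed by the PR):

```python
import torch

from doctr.models import detection, ocr_predictor, recognition

# Compile the detection and recognition models up front...
det_model = torch.compile(detection.fast_base(pretrained=True).eval())
reco_model = torch.compile(recognition.crnn_vgg16_bn(pretrained=True).eval())

# ...then hand the compiled modules to the predictor instead of architecture
# names. This relies on the zoo checks referenced below accepting such modules.
predictor = ocr_predictor(det_arch=det_model, reco_arch=reco_model)
```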
The only required change here would be to also allow this in:

- doctr/models/recognition/zoo.py, line 39 (at 9045dcf)
- doctr/models/detection/zoo.py, line 59 (at 9045dcf)
- doctr/models/classification/zoo.py, line 45 (at 9045dcf)
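Roughly, the kind of relaxation meant there, written as a hypothetical helper (this paraphrases the shape of the zoo.py check rather than quoting it):

```python
from torch import nn

from doctr.models import recognition


def _resolve_reco_arch(arch, pretrained: bool = False) -> nn.Module:
    # Hypothetical helper mirroring the check in doctr/models/recognition/zoo.py:
    # accept an architecture name, a plain nn.Module, or a torch.compile-wrapped
    # module (OptimizedModule subclasses nn.Module, so the isinstance covers it).
    if isinstance(arch, str):
        return recognition.__dict__[arch](pretrained=pretrained)
    if isinstance(arch, nn.Module):
        return arch
    raise ValueError(f"unknown architecture: {type(arch)}")
```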
Changes to the PyTorch detection model tests:
```diff
@@ -7,7 +7,7 @@
 import pytest
 import torch

-from doctr.file_utils import CLASS_NAME
+from doctr.file_utils import CLASS_NAME, is_pytorch_backend_available, does_torch_have_compile_capability
 from doctr.models import detection
 from doctr.models.detection._utils import dilate, erode
 from doctr.models.detection.fast.pytorch import reparameterize
```
```diff
@@ -186,3 +186,46 @@
         assert np.allclose(pt_logits, ort_outs[0], atol=1e-4)
     except AssertionError:
         pytest.skip(f"Output of {arch_name}:\nMax element-wise difference: {np.max(np.abs(pt_logits - ort_outs[0]))}")
+
+
+@pytest.mark.skipif(not does_torch_have_compile_capability(), reason="requires pytorch >= 2.0.0")
+@pytest.mark.skipif(not is_pytorch_backend_available(), reason="requires pytorch backend to be available")
+@pytest.mark.parametrize("fullgraph", [True, False])
+@pytest.mark.parametrize(
+    "arch_name, input_shape",
+    [
+        ["db_resnet34", (3, 512, 512)],
+        ["db_resnet50", (3, 512, 512)],
+        ["db_mobilenet_v3_large", (3, 512, 512)],
+        ["linknet_resnet18", (3, 512, 512)],
+        ["linknet_resnet34", (3, 512, 512)],
+        ["linknet_resnet50", (3, 512, 512)],
+        ["fast_tiny", (3, 512, 512)],
+        ["fast_small", (3, 512, 512)],
+        ["fast_base", (3, 512, 512)],
+        ["fast_tiny_rep", (3, 512, 512)],  # Reparameterized model
+    ],
+)
+def test_models_pytorch_compile(arch_name, input_shape, fullgraph):
+    # General Check that the model can be compiled
+    try:
+        assert torch.compile(detection.__dict__[arch_name](pretrained=True).eval(), fullgraph=fullgraph)
+    except:
+        pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nModel is failing pytorch compilation")
+    # Model
+    batch_size = 2
+    if arch_name == "fast_tiny_rep":
+        model = reparameterize(detection.fast_tiny(pretrained=True, exportable=True).eval())
+    else:
+        model = detection.__dict__[arch_name](pretrained=True, exportable=True).eval()
+    dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
+    pt_logits = model(dummy_input)["logits"].detach().cpu().numpy()
+
+    compiled_model = torch.compile(model, fullgraph=fullgraph)
+    pt_logits_compiled = compiled_model(dummy_input)["logits"].detach().cpu().numpy()
+
+    assert pt_logits_compiled.shape == pt_logits.shape
+    # Check that the output is close to the "original" output
+    try:
+        assert np.allclose(pt_logits, pt_logits_compiled, atol=1e-4)
+    except AssertionError:
+        pytest.skip(f"Output of {arch_name}:\n-fullgraph: {fullgraph}\nMax element-wise difference: {np.max(np.abs(pt_logits - pt_logits_compiled))}")
```

Review comment (on the two `skipif` markers): remove the first two
Review comment: @Fabioomega We can remove this. Two options:

1. We pin the lower boundary to `>= 2.0.0` in `pyproject.toml` (line 61 and line 105 at 9045dcf), together with `torchvision>=0.15.0`.
2. Or we mention in the docs that this requires `>= 2.0.0` for compile and `>= 2.4.0` for compile + fullgraph.

@odulcy-mindee wdyt?

Follow-up: We are already at 2.4.0, so I would prefer the `>= 2.0.0` pin (and in that case only mention `>= 2.4.0` for fullgraph (Triton) support).
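If the documentation route is taken, the thresholds translate to a simple runtime check; a sketch (the helper name and return values are illustrative, only the version numbers come from the comment above):

```python
import torch


def compile_support_level() -> str:
    # Thresholds from the discussion above: >= 2.0.0 for torch.compile,
    # >= 2.4.0 for compile + fullgraph (Triton).
    major, minor = (int(part) for part in torch.__version__.split("+")[0].split(".")[:2])
    if (major, minor) >= (2, 4):
        return "compile + fullgraph"
    if (major, minor) >= (2, 0):
        return "compile"
    return "none"


print(compile_support_level())
```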