[ONNX] Update API to torch.onnx.export(..., dynamo=True) #3223

Open · wants to merge 7 commits into base: 2.6-RC-TEST
Binary file removed _static/img/onnx/custom_addandround.png
Binary file removed _static/img/onnx/custom_aten_gelu_model.png
6 changes: 5 additions & 1 deletion beginner_source/onnx/README.txt
@@ -10,5 +10,9 @@ ONNX
https://pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html

3. onnx_registry_tutorial.py
Extending the ONNX Registry
Extending the ONNX exporter operator support
https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html

4. export_control_flow_model_to_onnx_tutorial.py
Export a model with control flow to ONNX
https://pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html
171 changes: 171 additions & 0 deletions beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py
@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
"""
`Introduction to ONNX <intro_onnx.html>`_ ||
`Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
`Extending the ONNX exporter operator support <onnx_registry_tutorial.html>`_ ||
**Export a model with control flow to ONNX**

Export a model with control flow to ONNX
==========================================

**Author**: `Xavier Dupré <https://github.com/xadupre>`_.
"""


###############################################################################
# Overview
# --------
#
# This tutorial demonstrates how to handle control flow logic while exporting
# a PyTorch model to ONNX. It highlights the challenges of exporting
# conditional statements directly and provides solutions to circumvent them.
#
# Conditional logic cannot be exported to ONNX unless it is refactored
# to use :func:`torch.cond`. Let's start with a simple model
# implementing a test.

import torch

###############################################################################
# Define the Models
# -----------------
#
# Two models are defined:
#
# * ``ForwardWithControlFlowTest``: a model whose forward method contains an
#   if-else conditional.
# * ``ModelWithControlFlowTest``: a model that incorporates ``ForwardWithControlFlowTest``
#   as part of a simple multi-layer perceptron (MLP).
#
# The models are tested with a random input tensor to confirm they execute as expected.

class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlowTest()


###############################################################################
# Exporting the Model: First Attempt
# ----------------------------------
#
# Exporting this model using torch.export.export fails because the control
# flow logic in the forward pass creates a graph break that the exporter cannot
# handle. This behavior is expected, as conditional logic not written using
# torch.cond is unsupported.
#
# A try-except block is used to capture the expected failure during the export
# process. If the export unexpectedly succeeds, an AssertionError is raised.

x = torch.randn(3)
model(x)

try:
    torch.export.export(model, (x,), strict=False)
    raise AssertionError("This export should fail unless PyTorch now supports this model.")
except Exception as e:
    print(e)

###############################################################################
# Using torch.onnx.export with JIT Tracing
# ----------------------------------------
#
# When exporting the model using torch.onnx.export with the dynamo=True
# argument, the exporter defaults to using JIT tracing. This fallback allows
# the model to export, but the resulting ONNX graph may not faithfully represent
# the original model logic due to the limitations of tracing.


onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
@justinchuby commented on Jan 25, 2025:
<
    ir_version=10,
    opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18},
    producer_name='pytorch',
    producer_version='2.6.0',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"input_1"<FLOAT,[3]>
    ),
    outputs=(
        %"mul"<FLOAT,[1]>
    ),
    initializers=(
        %"model.mlp.0.bias"<FLOAT,[2]>,
        %"model.mlp.0.weight"<FLOAT,[2,3]>,
        %"model.mlp.1.bias"<FLOAT,[1]>,
        %"model.mlp.1.weight"<FLOAT,[1,2]>
    ),
) {
    0 |  # node_Transpose_0
         %"val_0"<?,?> ⬅️ ::Transpose(%"model.mlp.0.weight") {perm=[1, 0]}
    1 |  # node_MatMul_1
         %"val_1"<?,?> ⬅️ ::MatMul(%"input_1", %"val_0")
    2 |  # node_Add_2
         %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"model.mlp.0.bias")
    3 |  # node_Transpose_3
         %"val_2"<?,?> ⬅️ ::Transpose(%"model.mlp.1.weight") {perm=[1, 0]}
    4 |  # node_MatMul_4
         %"val_3"<?,?> ⬅️ ::MatMul(%"linear", %"val_2")
    5 |  # node_Add_5
         %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"model.mlp.1.bias")
    6 |  # node_Constant_6
         %"val_4"<?,?> ⬅️ ::Constant() {value=Tensor<INT64,[]>(array(2), name=None)}
    7 |  # node_Cast_7
         %"scalar_tensor_default"<FLOAT,[]> ⬅️ ::Cast(%"val_4") {to=FLOAT}
    8 |  # node_Mul_8
         %"mul"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default")
    return %"mul"<FLOAT,[1]>
}
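To make the tracing limitation concrete, here is a small sketch (not part of the original tutorial; the toy module ``Branchy`` is illustrative) using ``torch.jit.trace`` directly. Tracing records only the branch taken for the example input and bakes it into the graph, which is why the exported ONNX model above always multiplies by 2:

```python
import torch

class Branchy(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return -x

# Tracing with an input whose sum is positive records only the `x * 2` branch
# (PyTorch emits a TracerWarning about the tensor-to-bool conversion).
traced = torch.jit.trace(Branchy(), torch.ones(3))

# The traced graph always multiplies by 2, even where eager mode would negate:
neg_input = -torch.ones(3)
print(Branchy()(neg_input))  # eager: tensor([1., 1., 1.])
print(traced(neg_input))     # traced: tensor([-2., -2., -2.])
```

The divergence on `neg_input` is exactly the silent behavior change the rest of this tutorial avoids by refactoring with ``torch.cond``.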



###############################################################################
# Suggested Patch: Refactoring with torch.cond
# --------------------------------------------
#
# To make the control flow exportable, the tutorial demonstrates replacing the
# forward method in ForwardWithControlFlowTest with a refactored version that
# uses torch.cond.
#
# Details of the refactoring:
#
# * Two helper functions (``identity2`` and ``neg``) represent the branches of
#   the conditional logic.
# * :func:`torch.cond` specifies the condition and the two branches, along with
#   the input arguments.
# * The updated forward method is then dynamically assigned to the
#   ``ForwardWithControlFlowTest`` instance within the model. A list of
#   submodules is printed to confirm the replacement.

def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward

###############################################################################
# Let's see what the fx graph looks like.

print(torch.export.export(model, (x,), strict=False))
@justinchuby commented on Jan 25, 2025:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_mlp_0_weight: "f32[2, 3]", p_mlp_0_bias: "f32[2]", p_mlp_1_weight: "f32[1, 2]", p_mlp_1_bias: "f32[1]", x: "f32[3]"):
            linear: "f32[2]" = torch.ops.aten.linear.default(x, p_mlp_0_weight, p_mlp_0_bias);  x = p_mlp_0_weight = p_mlp_0_bias = None
            linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, p_mlp_1_weight, p_mlp_1_bias);  linear = p_mlp_1_weight = p_mlp_1_bias = None
            
            sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None
            
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [linear_1]);  gt = true_graph_0 = false_graph_0 = linear_1 = None
            getitem: "f32[1]" = cond[0];  cond = None
            return (getitem,)
            
        class true_graph_0(torch.nn.Module):
            def forward(self, linear_1: "f32[1]"):
                mul: "f32[1]" = torch.ops.aten.mul.Tensor(linear_1, 2);  linear_1 = None
                return (mul,)
                
        class false_graph_0(torch.nn.Module):
            def forward(self, linear_1: "f32[1]"):
                neg: "f32[1]" = torch.ops.aten.neg.default(linear_1);  linear_1 = None
                return (neg,)


###############################################################################
# Let's export again.

onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)


<
    ir_version=10,
    opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18, 'pkg.torch.__subgraph__': 1},
    producer_name='pytorch',
    producer_version='2.6.0',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"x"<FLOAT,[3]>
    ),
    outputs=(
        %"getitem"<FLOAT,[1]>
    ),
    initializers=(
        %"mlp.0.weight"<FLOAT,[2,3]>,
        %"mlp.0.bias"<FLOAT,[2]>,
        %"mlp.1.weight"<FLOAT,[1,2]>,
        %"mlp.1.bias"<FLOAT,[1]>
    ),
) {
     0 |  # node_Transpose_0
          %"val_0"<?,?> ⬅️ ::Transpose(%"mlp.0.weight") {perm=[1, 0]}
     1 |  # node_MatMul_1
          %"val_1"<?,?> ⬅️ ::MatMul(%"x", %"val_0")
     2 |  # node_Add_2
          %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias")
     3 |  # node_Transpose_3
          %"val_2"<?,?> ⬅️ ::Transpose(%"mlp.1.weight") {perm=[1, 0]}
     4 |  # node_MatMul_4
          %"val_3"<?,?> ⬅️ ::MatMul(%"linear", %"val_2")
     5 |  # node_Add_5
          %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias")
     6 |  # node_ReduceSum_6
          %"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {noop_with_empty_axes=0, keepdims=False}
     7 |  # node_Constant_7
          %"val_4"<?,?> ⬅️ ::Constant() {value=Tensor<INT64,[]>(array(0), name=None)}
     8 |  # node_Cast_8
          %"scalar_tensor_default"<FLOAT,[]> ⬅️ ::Cast(%"val_4") {to=FLOAT}
     9 |  # node_Greater_9
          %"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default")
    10 |  # node_If_10
          %"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
              graph(
                  name=true_graph_0,
                  inputs=(

                  ),
                  outputs=(
                      %"mul_true_graph_0"<?,?>
                  ),
              ) {
                  0 |  # node_true_graph_0_0
                       %"mul_true_graph_0"<?,?> ⬅️ pkg.torch.__subgraph__::true_graph_0(%"linear_1")
                  return %"mul_true_graph_0"<?,?>
              }, else_branch=
              graph(
                  name=false_graph_0,
                  inputs=(

                  ),
                  outputs=(
                      %"neg_false_graph_0"<?,?>
                  ),
              ) {
                  0 |  # node_false_graph_0_0
                       %"neg_false_graph_0"<?,?> ⬅️ pkg.torch.__subgraph__::false_graph_0(%"linear_1")
                  return %"neg_false_graph_0"<?,?>
              }}
    return %"getitem"<FLOAT,[1]>
}

<
    opset_imports={'': 18},
>
def pkg.torch.__subgraph__::false_graph_0(
    inputs=(
        %"linear_1"<FLOAT,[1]>
    ),
    outputs=(
        %"neg"<FLOAT,[1]>
    ),
) {
    0 |  # node_Neg_0
         %"neg"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
    return %"neg"<FLOAT,[1]>
}

<
    opset_imports={'': 18},
>
def pkg.torch.__subgraph__::true_graph_0(
    inputs=(
        %"linear_1"<FLOAT,[1]>
    ),
    outputs=(
        %"mul"<FLOAT,[1]>
    ),
) {
    0 |  # node_Constant_0
         %"val_0"<?,?> ⬅️ ::Constant() {value=Tensor<INT64,[]>(array(2), name=None)}
    1 |  # node_Cast_1
         %"scalar_tensor_default"<FLOAT,[]> ⬅️ ::Cast(%"val_0") {to=FLOAT}
    2 |  # node_Mul_2
         %"mul"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default")
    return %"mul"<FLOAT,[1]>
}



Reviewer: We should include the result of the print here I think

Author: What do you mean by result of the print? More than the model proto?

Reviewer: I looked at other examples in the tutorials; output was not included. The generation should take care of that, otherwise we would have to update them every time PyTorch is released. I guess they do something very similar to what sphinx-gallery does.

Author: Yes, it should be printed in the page. The code will be executed.

Reviewer: I see. Strange, I was looking at the registry page and there was no printout.

###############################################################################
# We can optimize the model and get rid of the model-local functions created to
# capture the control flow branches.

onnx_program.optimize()
print(onnx_program.model)
@justinchuby commented on Jan 25, 2025:
<
    ir_version=10,
    opset_imports={'pkg.onnxscript.torch_lib.common': 1, '': 18, 'pkg.torch.__subgraph__': 1},
    producer_name='pytorch',
    producer_version='2.6.0',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"x"<FLOAT,[3]>
    ),
    outputs=(
        %"getitem"<FLOAT,[1]>
    ),
    initializers=(
        %"mlp.0.bias"<FLOAT,[2]>,
        %"mlp.1.bias"<FLOAT,[1]>
    ),
) {
     0 |  # node_Constant_11
          %"val_0"<FLOAT,[3,2]> ⬅️ ::Constant() {value=Tensor<FLOAT,[3,2]>(array([[ 0.32409453,  0.09968598],
                 [ 0.23967852, -0.04969374],
                 [-0.09462868,  0.34749857]], dtype=float32), name='val_0')}
     1 |  # node_MatMul_1
          %"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0")
     2 |  # node_Add_2
          %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias")
     3 |  # node_Constant_12
          %"val_2"<FLOAT,[2,1]> ⬅️ ::Constant() {value=Tensor<FLOAT,[2,1]>(array([[0.19137527],
                 [0.29681835]], dtype=float32), name='val_2')}
     4 |  # node_MatMul_4
          %"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2")
     5 |  # node_Add_5
          %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias")
     6 |  # node_ReduceSum_6
          %"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {noop_with_empty_axes=0, keepdims=False}
     7 |  # node_Constant_13
          %"scalar_tensor_default"<FLOAT,[]> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')}
     8 |  # node_Greater_9
          %"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default")
     9 |  # node_If_10
          %"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
              graph(
                  name=true_graph_0,
                  inputs=(

                  ),
                  outputs=(
                      %"mul_true_graph_0"<FLOAT,[1]>
                  ),
              ) {
                  0 |  # node_Constant_1
                       %"scalar_tensor_default_2"<FLOAT,[]> ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')}
                  1 |  # node_Mul_2
                       %"mul_true_graph_0"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2")
                  return %"mul_true_graph_0"<FLOAT,[1]>
              }, else_branch=
              graph(
                  name=false_graph_0,
                  inputs=(

                  ),
                  outputs=(
                      %"neg_false_graph_0"<FLOAT,[1]>
                  ),
              ) {
                  0 |  # node_Neg_0
                       %"neg_false_graph_0"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
                  return %"neg_false_graph_0"<FLOAT,[1]>
              }}
    return %"getitem"<FLOAT,[1]>
}


###############################################################################
# Conclusion
# ----------
# This tutorial demonstrates the challenges of exporting models with conditional
# logic to ONNX and presents a practical solution using torch.cond.
# While the default exporters may fail or produce imperfect graphs, refactoring the
# model's logic ensures compatibility and generates a faithful ONNX representation.
#
# By understanding these techniques, we can overcome common pitfalls when
# working with control flow in PyTorch models and ensure smooth integration with ONNX workflows.
#
# Further reading
# ---------------
#
# The list below refers to tutorials that range from basic examples to advanced scenarios,
# not necessarily in the order they are listed.
# Feel free to jump directly to specific topics of your interest or
# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
#
# .. include:: /beginner_source/onnx/onnx_toc.txt
#
# .. toctree::
#    :hidden:
22 changes: 16 additions & 6 deletions beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
@@ -2,17 +2,18 @@
"""
`Introduction to ONNX <intro_onnx.html>`_ ||
**Exporting a PyTorch model to ONNX** ||
`Extending the ONNX Registry <onnx_registry_tutorial.html>`_
`Extending the ONNX exporter operator support <onnx_registry_tutorial.html>`_ ||
`Export a model with control flow to ONNX <export_control_flow_model_to_onnx_tutorial.html>`_

Export a PyTorch model to ONNX
==============================

**Author**: `Ti-Tai Wang <https://github.com/titaiwangms>`_ and `Xavier Dupré <https://github.com/xadupre>`_
**Author**: `Ti-Tai Wang <https://github.com/titaiwangms>`_ and `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_.

.. note::
As of PyTorch 2.1, there are two versions of ONNX Exporter.
As of PyTorch 2.5, there are two versions of ONNX Exporter.

* ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
* ``torch.onnx.export(..., dynamo=True)`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
* ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0

"""
@@ -21,7 +22,7 @@
# In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_,
# we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
# In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter.
# ONNX format using TorchDynamo and the ``torch.onnx.export(..., dynamo=True)`` ONNX exporter.
#
# While PyTorch is great for iterating on the development of models, the model can be deployed to production
# using different formats, including `ONNX <https://onnx.ai/>`_ (Open Neural Network Exchange)!
@@ -90,7 +91,16 @@ def forward(self, x):

torch_model = MyModel()
torch_input = torch.randn(1, 1, 32, 32)
onnx_program = torch.onnx.dynamo_export(torch_model, torch_input)
onnx_program = torch.onnx.export(torch_model, torch_input, dynamo=True)

######################################################################
# 3.5. (Optional) Optimize the ONNX model
# ---------------------------------------
#
# The ONNX model can be optimized with constant folding and elimination of
# redundant nodes. The optimization is done in place, so the original ONNX
# model is modified.

onnx_program.optimize()

######################################################################
# As we can see, we didn't need any code change to the model.
13 changes: 8 additions & 5 deletions beginner_source/onnx/intro_onnx.py
@@ -1,13 +1,14 @@
"""
**Introduction to ONNX** ||
`Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
`Extending the ONNX Registry <onnx_registry_tutorial.html>`_
`Extending the ONNX exporter operator support <onnx_registry_tutorial.html>`_ ||
`Export a model with control flow to ONNX <export_control_flow_model_to_onnx_tutorial.html>`_

Introduction to ONNX
====================

Authors:
`Ti-Tai Wang <https://github.com/titaiwangms>`_ and `Xavier Dupré <https://github.com/xadupre>`_
`Ti-Tai Wang <https://github.com/titaiwangms>`_ and `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_.

`Open Neural Network eXchange (ONNX) <https://onnx.ai/>`_ is an open standard
format for representing machine learning models. The ``torch.onnx`` module provides APIs to
@@ -19,8 +20,10 @@
including Microsoft's `ONNX Runtime <https://www.onnxruntime.ai>`_.

.. note::
Currently, there are two flavors of ONNX exporter APIs,
but this tutorial will focus on the ``torch.onnx.dynamo_export``.
    Currently, you can choose either `TorchScript <https://pytorch.org/docs/stable/jit.html>`_ or
    `ExportedProgram <https://pytorch.org/docs/stable/export.html>`_ to export the model to ONNX via the
    boolean parameter ``dynamo`` in `torch.onnx.export <https://pytorch.org/docs/stable/generated/torch.onnx.export.html>`_.
    In this tutorial, we will focus on the ``ExportedProgram`` approach.

The TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its
bytecode into an `FX graph <https://pytorch.org/docs/stable/fx.html>`_.
Expand All @@ -33,7 +36,7 @@
Dependencies
------------

PyTorch 2.1.0 or newer is required.
PyTorch 2.5.0 or newer is required.

The ONNX exporter depends on extra Python packages:
