ORTModule GraphTransitionManager (microsoft#19007)
### Problem

Currently, the codebase contains logic for model re-export checks and
graph_builder re-initialization checks. Ideally, these operations should
behave like a state machine, but in the current implementation the relevant
states are checked or set in scattered locations. This fragmentation makes
it hard to tell when a re-export or a re-initialization will be triggered.
For clarity and maintainability, these states should be consolidated into a
single cohesive component rather than dispersed within the current graph
execution manager.

Furthermore, model export and the post-export processing required for ZeRO
stage 3 support or memory-efficient gradient management introduce
considerable complexity. The codebase's structure would benefit from
extracting these intricate functionalities into a dedicated component,
separate from the current graph execution manager.

As part of this effort, it's also essential to address inconsistencies in
the handling of input/output flatten/unflatten operations. Currently,
several functions perform these operations recursively, each with a
slightly different implementation, so different parts of the code support
different input/output data types and structures. This pull request
simplifies these operations into a small set of primitive functions,
ensuring uniformity. This not only streamlines the code but also makes it
easier to stay consistent when fixing bugs or adding support for new data
types; a sketch of the idea follows this paragraph. One thing to mention:
input/output handling is deeply bound to the graph transition work
described above, so it is difficult to make this change separately.
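
For illustration, here is a minimal sketch of what such a primitive pair can look like. The names (`flatten_data`, `unflatten_data`) and the schema encoding are hypothetical, not the actual ORTModule implementation; the point is a single shared recursion that every call site reuses:

```python
from typing import Any


def flatten_data(data: Any, prefix: str = ""):
    """Recursively flatten nested lists/tuples/dicts into a flat list of
    leaves, plus a schema recording how to rebuild the structure."""
    if isinstance(data, (list, tuple)):
        flat, children = [], []
        for i, item in enumerate(data):
            item_flat, item_schema = flatten_data(item, f"{prefix}_{i}")
            flat.extend(item_flat)
            children.append(item_schema)
        return flat, (type(data).__name__, children)
    if isinstance(data, dict):
        flat, children = [], {}
        for key in sorted(data):
            item_flat, item_schema = flatten_data(data[key], f"{prefix}_{key}")
            flat.extend(item_flat)
            children[key] = item_schema
        return flat, ("dict", children)
    # Leaf (e.g. a torch.Tensor): a single flat entry.
    return [data], ("leaf", prefix)


def unflatten_data(flat: list, schema) -> Any:
    """Inverse of flatten_data: rebuild the original nested structure."""
    it = iter(flat)

    def build(node):
        kind, payload = node
        if kind == "leaf":
            return next(it)
        if kind == "dict":
            return {key: build(child) for key, child in payload.items()}
        seq = [build(child) for child in payload]
        return tuple(seq) if kind == "tuple" else seq

    return build(schema)
```

For example, `flatten_data(((a, b), {"x": c}))` yields `[a, b, c]` plus a schema, and `unflatten_data` applied to that pair restores the original nesting exactly; any data type the pair supports is automatically supported everywhere the pair is used.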

While these logics are complex, it is reassuring that the codebase has an
extensive suite of unit tests covering all possible branches. Despite the
intricacies, getting all of these tests to pass was a time-intensive but
necessary part of this development effort.

### Design


Introduce `GraphTransitionManager` and move all model export and
post-export processing logic into it (a state-machine sketch of this flow
follows the field list below):
1. Re-export check
2. Do export
3. Re-post-export-process check
4. Do post-export process
5. Return `PostExportProcessedModelInfo`, which contains all the
information we need to pass to ORT to build the gradient graph. (Currently
we do the same for training and evaluating, but ideally we should not do it
for evaluating; let's keep this behavior as it is now and make the change
later.) Its main fields:
```python
# Input names for the pre-gradient-build graph.
# This may be different from the one in ExportedGraph, since we may modify
# the graph inputs as needed, for example when memory-efficient gradient
# management is enabled.
self.onnx_graph_input_names: list[str] = onnx_graph_input_names

# A subset of onnx_graph_input_names.
# Input names that require gradients for the pre-gradient-build graph.
self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad

# Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names).
# The key is the input name, the value is a dict of {dim_index: symbolic_dim_name},
# e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}}.
self.onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]] = onnx_graph_input_dynamic_axes_map

self.buffer_for_ort_runs: dict[str, torch.Tensor] = OrderedDict()

# The ONNX graph input names excluding the parameters and buffers.
self.onnx_graph_input_names_user_defined = onnx_graph_input_names_user_defined

# Of those user-defined input names, the ones that require gradients.
self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined

self._post_export_processed_model: onnx.ModelProto | None = post_export_processed_model

# Functions to access the input data from the args and kwargs.
# If it is not None, the length is the same as onnx_graph_input_names:
# for the i-th input name, we can use the i-th function to get the input
# data from args and kwargs.
self.data_accessor: list[callable] | None = data_accessor

# Used for unflattening the outputs from the ORT forward run.
self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema
```
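
Conceptually, the manager's re-export/re-process flow reduces to the sketch below. `get_post_processed_model` and `_exported_model_info` are the names that appear in the diff further down; the `_need_reexport`, `_need_repost_export_process`, `_export_model`, and `_post_export_process` helpers are illustrative stand-ins, not the actual implementation:

```python
class GraphTransitionManager:
    # Sketch only; the real class lives in _graph_transition_manager.py and
    # tracks more state (device, runtime options, export parameters, ...).

    def get_post_processed_model(self, args, kwargs):
        model_updated = False

        # Steps 1 & 2: re-export only if there is no exported model yet, or
        # the module's inputs/schema no longer match what was exported.
        if self._exported_model_info is None or self._need_reexport(args, kwargs):
            self._exported_model_info = self._export_model(args, kwargs)
            model_updated = True

        # Steps 3 & 4: re-run post-export processing (e.g. ZeRO stage 3 or
        # memory-efficient gradient rewrites) only when needed.
        if model_updated or self._need_repost_export_process():
            self._post_export_processed_model_info = self._post_export_process(
                self._exported_model_info
            )
            model_updated = True

        # Step 5: everything ORT needs to build the gradient graph.
        return model_updated, self._post_export_processed_model_info
```

Keeping all of this state behind one method is what makes the trigger conditions for re-export and re-initialization auditable in a single place.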




The `GraphTransitionManager` instance is a property of
`GraphExecutionManager` (e.g. `TrainingManager` or `InferenceManager`). A
simplified usage sketch follows this list.
1. Use
`self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs, kwargs)`
to check whether the PyTorch module needs a re-export or a
re-post-export process.
2. Use
`self._graph_transition_manager._post_export_processed_model_info.construct_inputs`
to construct the list of inputs used for ORT runs.
3. Use
`self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs)`
to restore the outputs in the original PyTorch output structure.
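
Put together, a forward pass in an execution manager uses these calls roughly as follows (a simplified sketch condensed from the `InferenceManager` diff below; skip-checks, session creation, and fallback handling are omitted, and `get_post_processed_model` is the method name used in the diff):

```python
# Check/re-export/re-process as needed; build_graph tells us whether the
# post-processed model changed.
build_graph, model_info = self._graph_transition_manager.get_post_processed_model(inputs, kwargs)
if build_graph:
    self._initialize_graph_builder(model_info)

# Flatten the user's args/kwargs into the ordered inputs ORT expects.
prepared_input_map = model_info.construct_inputs(inputs, kwargs, True, self._device)

user_outputs, _ = InferenceManager.execution_session_run_forward(
    self._execution_agent,
    self._onnx_models.optimized_model,
    self._device,
    *prepared_input_map.values(),
)

# Restore the original PyTorch output structure (nested tuples/dicts/...).
return model_info.restore_outputs(user_outputs)
```

Here `model_info` stands for `self._graph_transition_manager._post_export_processed_model_info`.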



### Motivation and Context
pengwa authored Jul 3, 2024
1 parent 116398c · commit 4932e04
Showing 16 changed files with 1,642 additions and 1,071 deletions.


New file: orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py (1,058 additions, 0 deletions; large diff not rendered by default).

Excerpt from the changes to the `InferenceManager.forward` path:

```diff
@@ -11,11 +11,10 @@
 
 from onnxruntime.capi import _pybind_state as C
 
-from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_algorithms, _utils
+from . import _are_deterministic_algorithms_enabled, _use_deterministic_algorithms, _utils
 from ._execution_agent import InferenceAgent
 from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy
 from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo
-from ._io import unflatten_user_output
 from ._logger import ORTModuleInitPhase, TrackTime
 from ._utils import save_tuning_results, set_tuning_results
 from .options import DebugOptions, _SkipCheck
@@ -109,15 +108,19 @@ def forward(self, *inputs, **kwargs):
             build_graph = False
             if (
                 self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
-                or not self._onnx_models.exported_model
+                or not self._graph_transition_manager._exported_model_info
             ):
                 self.time_tracker.start(ORTModuleInitPhase.EndToEnd)
 
-                # Exporting module to ONNX for the first time
-                build_graph = self._export_model(*inputs, **kwargs)
-
+                (
+                    build_graph,
+                    post_export_processed_model_info,
+                ) = self._graph_transition_manager.get_post_processed_model(inputs, kwargs)
                 if build_graph:
-                    # If model was exported, then initialize the graph builder.
-                    self._initialize_graph_builder()
+                    # TODO(): do we need call it for inferencing mode???
+                    self._initialize_graph_builder(post_export_processed_model_info)
 
                 # Build the inference graph
                 if build_graph:
@@ -134,7 +137,7 @@ def forward(self, *inputs, **kwargs):
                 self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False
                 or not self._execution_agent
             ):
-                module_device = _utils.get_device_from_module(self._original_module)
+                module_device = _utils.get_device_from_module_and_inputs(self._original_module, inputs, kwargs)
 
                 create_execution_session = (
                     build_graph
@@ -144,7 +147,7 @@ def forward(self, *inputs, **kwargs):
                 _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled())
 
                 if self._device != module_device:
-                    self._device = module_device
+                    self._graph_transition_manager._device = module_device
 
                 if create_execution_session:
                     # Create execution session creates the inference_session
@@ -160,23 +163,15 @@ def forward(self, *inputs, **kwargs):
             if self._runtime_options.enable_zero_stage3_support:
                 self._append_pull_weight_trigger_as_input(kwargs, self._device)
 
-            prepared_input_list = _io._combine_input_buffers_initializers(
-                self._graph_initializers,
-                self._graph_info.user_input_names,
-                self._input_info,
-                self._flattened_module.named_buffers(),
-                inputs,
-                kwargs,
-                self._device,
-                self._runtime_inspector,
-                self._zero_stage3_param_map,
+            prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs(
+                inputs, kwargs, True, self._device
             )
 
             user_outputs, _ = InferenceManager.execution_session_run_forward(
                 self._execution_agent,
                 self._onnx_models.optimized_model,
                 self._device,
-                *prepared_input_list,
+                *prepared_input_map.values(),
             )
 
             if (
@@ -188,7 +183,8 @@ def forward(self, *inputs, **kwargs):
                 self._execution_agent._inference_session, False, self._runtime_options.tuning_results_path
             )
 
-            return unflatten_user_output(self._module_output_schema, user_outputs)
+            return self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs)
+
         except ORTModuleFallbackException as e:
             # Exceptions subject to fallback are handled here
             self._fallback_manager.handle_exception(exception=e, log_level=self._debug_options.logging.log_level)
```
