Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ORTModule GraphTransitionManager (microsoft#19007)
### Problem Currently, the codebase contains some logics pertaining to model re-export checks and graph_builder reinitialization checks. Ideally, these operations should function akin to a state machine. However, upon inspecting the implementation, it becomes apparent that certain states are checked or set in various scattered locations. This fragmentation makes it challenging to comprehend when a re-export or re-initialization will be triggered. For optimal clarity and maintainability, it is advisable to consolidate these states into a cohesive component, rather than dispersing them within the current graph execution manager. Furthermore, the process of model exports and post-export processing for stage 3 support or memory-efficient gradient management introduces considerable complexity. To enhance the codebase's structure, it would be beneficial to extract these intricate functionalities into a dedicated component, divorcing them from the current graph execution manager. As part of the effort to improve the codebase, it's essential to address inconsistencies in handling input/output flatten/unflatten operations. Currently, there are several functions performing these operations recursively, each with slightly different implementations. This inconsistency leads to varying support for input/output data types and structures in different parts of the code. To rectify this, the proposed pull request simplifies these operations into a set of primitive functions, ensuring uniformity. This not only streamlines the code but also facilitates the maintenance of consistency when introducing bug fixes or supporting new data types. One thing to mention here: input output handling is deeply bound to the graph transition mentioned above, so it is difficult to make this change separately. While acknowledging the complexity of these logics, it is reassuring that the codebase benefits from an extensive suite of unit tests that cover all possible branches. Despite the intricacies, ensuring the passage of all tests has been a time-intensive but necessary aspect of this development effort. ### Design Introduce `GraphTransitionManager` and put all model export and post-export processing logics in it. 1. Re-export check 2. Do export 3. Re-post-export process check 4. Do post-export process 5. Return `PostExportProcessedModelInfo`, which contains all the information we need, to pass to ORT to build gradient graph (currently we do the same for training or evaluating, but ideally we should not do it for evaluating, let's keep this behavior as it is now, and make the change later). ``` # Input names for the pre-gradient-build graph. # This may be different with the one in ExportedGraph since we may modify the graph inputs as needed # for example when memory efficient gradient management is enabled. self.onnx_graph_input_names: list[str] = onnx_graph_input_names # A subset of onnx_graph_input_names. # Input names that require gradients for the pre-gradient-build graph. self.onnx_graph_input_names_require_grad: list[str] = onnx_graph_input_names_require_grad # Create symbolic names for each dimension of the graph input (e.g. onnx_graph_input_names). # The key is the input name, the value is a dict of {dim_index: symbolic_dim_name} # e.g. {"input1": {0: "input1_dim0", 1: "input1_dim1"}, "input2": {0: "input2_dim0"}} self.onnx_graph_input_dynamic_axes_map: dict[str, dict[int, str]] = onnx_graph_input_dynamic_axes_map self.buffer_for_ort_runs: dict[str, torch.Tensor] = OrderedDict() self.onnx_graph_input_names_user_defined = ( onnx_graph_input_names_user_defined # The ONNX graph input names excluding the parameters, buffers. ) # The ONNX graph input names excluding the parameters, buffers. self.onnx_graph_input_names_require_grad_user_defined = onnx_graph_input_names_require_grad_user_defined self._post_export_processed_model: onnx.ModelProto | None = post_export_processed_model # A function to access the input data from the args and kwargs. # If it is not None, the length is same as onnx_graph_input_names. # For i-th input name, we can use the i-th function to get the input data from args and kwargs. self.data_accessor: list[callable] | None = data_accessor # Used for unflattening the outputs from the ORT forward run. self.module_forward_output_schema: ORTModelInputOutputSchemaType | None = module_forward_output_schema``` The `GraphTransitionManager` instance is a property of `GraphExecutionManager` (e.g. `TrainingManager` or ``InferenceManager), 1. Use 'self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs, kwargs)' to check whether the PyTorch module need a re-export or re-post-export-process. 2. Use `self._graph_transition_manager._post_export_processed_model_info.construct_inputs` to construct the list of inputs used for ORT runs. 3. Use `self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs)` to restore the outputs in original PyTorch output structure. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
- Loading branch information