diff --git a/src/graph.rs b/src/graph.rs
index 40f3ce6b..c3609e2a 100644
--- a/src/graph.rs
+++ b/src/graph.rs
@@ -901,6 +901,21 @@ impl Graph {
             .sum()
     }
 
+    /// Return the sequence of operators from the current graph that would be
+    /// executed in order to compute `outputs` given `inputs`, without actually
+    /// running the model.
+    ///
+    /// The result does not include nodes from any subgraphs that an operator
+    /// may run.
+    pub fn execution_plan(
+        &self,
+        inputs: &[NodeId],
+        outputs: &[NodeId],
+    ) -> Result<Vec<NodeId>, RunError> {
+        let plan = self.get_cached_plan(inputs, outputs, false /* is_subgraph */)?;
+        Ok(plan.plan().to_vec())
+    }
+
     /// Compute a set of output values given a set of inputs, using the
     /// processing steps and constant values defined by the graph.
     pub fn run(
@@ -909,7 +924,8 @@ impl Graph {
         outputs: &[NodeId],
         opts: Option<RunOptions>,
     ) -> Result<Vec<Output>, RunError> {
-        let plan = self.get_cached_plan(&inputs, outputs, false /* is_subgraph */)?;
+        let input_ids: Vec<_> = inputs.iter().map(|(node_id, _)| *node_id).collect();
+        let plan = self.get_cached_plan(&input_ids, outputs, false /* is_subgraph */)?;
         threading::thread_pool().run(|| {
             self.run_plan(
                 inputs,
@@ -934,13 +950,14 @@ impl Graph {
         pool: Option<&TensorPool>,
         opts: Option<RunOptions>,
     ) -> Result<Vec<Output>, RunError> {
-        let plan = self.get_cached_plan(&inputs, outputs, true /* is_subgraph */)?;
+        let input_ids: Vec<_> = inputs.iter().map(|(node_id, _)| *node_id).collect();
+        let plan = self.get_cached_plan(&input_ids, outputs, true /* is_subgraph */)?;
         self.run_plan(inputs, plan.plan(), outputs, Some(captures), pool, opts)
     }
 
     fn get_cached_plan(
         &self,
-        inputs: &[(NodeId, InputOrOutput)],
+        inputs: &[NodeId],
         outputs: &[NodeId],
         is_subgraph: bool,
     ) -> Result<Arc<CachedPlan>, RunError> {
@@ -950,9 +967,8 @@ impl Graph {
         // Note that we only hold the plan lock while creating the plan,
         // not while executing the model.
         let mut cached_plan = self.cached_plan.lock().unwrap();
-        let input_ids: Vec<_> = inputs.iter().map(|(node_id, _)| *node_id).collect();
         let plan = match cached_plan.as_ref() {
-            Some(plan) if plan.matches(&input_ids, outputs) => plan.clone(),
+            Some(plan) if plan.matches(inputs, outputs) => plan.clone(),
             _ => {
                 let plan = self.create_plan(
                     inputs,
@@ -962,7 +978,7 @@ impl Graph {
                         captures_available: is_subgraph,
                     },
                 )?;
-                *cached_plan = Some(Arc::new(CachedPlan::new(&input_ids, outputs, plan)));
+                *cached_plan = Some(Arc::new(CachedPlan::new(inputs, outputs, plan)));
                 cached_plan.clone().unwrap()
             }
         };
@@ -1326,15 +1342,15 @@ impl Graph {
         outputs: &[NodeId],
         opts: Option<RunOptions>,
     ) -> Result<Vec<(NodeId, Output)>, RunError> {
+        let input_ids: Vec<_> = inputs.iter().map(|(id, _)| id).copied().collect();
         let plan = self.create_plan(
-            &inputs,
+            &input_ids,
             outputs,
             PlanOptions {
                 allow_missing_inputs: true,
                 captures_available: false,
            },
         )?;
-        let input_ids: Vec<_> = inputs.iter().map(|(id, _)| id).copied().collect();
         let (pruned_plan, pruned_plan_output_ids) = self.prune_plan(&plan, &input_ids, outputs);
         let outputs = threading::thread_pool().run(|| {
             self.run_plan(
@@ -1476,7 +1492,7 @@ impl Graph {
     /// omitted from the plan.
     fn create_plan(
         &self,
-        inputs: &[(NodeId, InputOrOutput)],
+        inputs: &[NodeId],
         outputs: &[NodeId],
         options: PlanOptions,
     ) -> Result<Vec<NodeId>, RunError> {
@@ -1484,7 +1500,7 @@ impl Graph {
             return Err(RunError::PlanningError("output IDs are not unique".into()));
         }
 
-        if !all_unique(inputs, |(x_id, _), (y_id, _)| x_id == y_id) {
+        if !all_unique(inputs, |x_id, y_id| x_id == y_id) {
             return Err(RunError::PlanningError("input IDs are not unique".into()));
         }
 
@@ -1529,9 +1545,89 @@ impl Graph {
                 Ok(())
             }
 
-            /// Return a sequential plan to generate `outputs`. The plan is
-            /// a vec of `(op_node_id, operator)` tuples.
+            /// Take the current execution plan and re-order it for more
+            /// efficient execution.
+            fn sort_plan(self, mut resolved_values: FxHashSet<NodeId>) -> Vec<NodeId> {
+                // Build map of value node to operators that depend on the value.
+                let mut dependent_ops: FxHashMap<NodeId, Vec<(NodeId, &OperatorNode)>> =
+                    FxHashMap::default();
+                for (op_node_id, op_node) in &self.plan {
+                    for input_id in self.graph.operator_dependencies(op_node) {
+                        if let Some(deps) = dependent_ops.get_mut(&input_id) {
+                            deps.push((*op_node_id, op_node));
+                        } else {
+                            dependent_ops.insert(input_id, [(*op_node_id, *op_node)].into());
+                        }
+                    }
+                }
+
+                let mut output_plan = Vec::with_capacity(self.plan.len());
+
+                // Initialize frontier with all operators that can be executed
+                // from initially-available values.
+                let mut frontier: Vec<(NodeId, &OperatorNode)> = Vec::new();
+                for (op_node_id, op_node) in &self.plan {
+                    if self
+                        .graph
+                        .operator_dependencies(op_node)
+                        .all(|id| resolved_values.contains(&id))
+                    {
+                        frontier.push((*op_node_id, op_node));
+                    }
+                }
+
+                debug_assert!(!frontier.is_empty(), "initial frontier is empty");
+
+                // Loop while we still have operators to compute.
+                while !frontier.is_empty() {
+                    // Choose an operator to execute next and add it to the plan.
+                    //
+                    // We run non-in-place operators first, so that operators
+                    // which can run in-place are more likely to have their
+                    // inputs available for in-place execution.
+                    let op_pos = frontier
+                        .iter()
+                        .position(|(_id, op)| !op.operator().can_run_in_place())
+                        .unwrap_or(0);
+                    let (next_op_id, op_node) = frontier.remove(op_pos);
+                    output_plan.push(next_op_id);
+
+                    // Mark the operator's outputs as computed.
+                    resolved_values.extend(op_node.output_ids().iter().filter_map(|id| *id));
+
+                    // Visit operators that depend on current op outputs. Add
+                    // to frontier set if all dependencies have been resolved.
+                    for output_id in op_node.output_ids() {
+                        let Some(output_id) = output_id else {
+                            continue;
+                        };
+                        let Some(deps) = dependent_ops.get(output_id) else {
+                            continue;
+                        };
+                        for (candidate_op_id, candidate_op) in deps {
+                            if frontier.iter().any(|(op_id, _)| op_id == candidate_op_id) {
+                                continue;
+                            }
+
+                            if self
+                                .graph
+                                .operator_dependencies(candidate_op)
+                                .all(|id| resolved_values.contains(&id))
+                            {
+                                frontier.push((*candidate_op_id, candidate_op));
+                            }
+                        }
+                    }
+                }
+
+                output_plan
+            }
+
+            /// Return a sequential plan to generate `outputs`.
             fn plan(mut self, outputs: &[NodeId]) -> Result<Vec<NodeId>, RunError> {
+                let initial_resolved_values = self.resolved_values.clone();
+
+                // Build initial plan by traversing graph backwards from outputs.
                 for output_id in outputs.iter() {
                     if self.resolved_values.contains(output_id) {
                         // Value is either a constant node or is produced by
@@ -1546,15 +1642,25 @@ impl Graph {
                         return Err(RunError::PlanningError(msg));
                     }
                 }
-                Ok(self.plan.into_iter().map(|(node_id, _)| node_id).collect())
+
+                // When doing partial evaluation, just return the initial plan.
+                // This avoids having to handle missing inputs when sorting the
+                // plan.
+                if self.options.allow_missing_inputs || self.plan.is_empty() {
+                    return Ok(self.plan.into_iter().map(|(op_id, _)| op_id).collect());
+                }
+
+                // Re-order initial plan to get a more efficient execution
+                // order.
+                let sorted_plan = self.sort_plan(initial_resolved_values);
+
+                Ok(sorted_plan)
             }
         }
 
         // Set of values that are available after executing the plan
-        let resolved_values: FxHashSet<NodeId> = self.init_resolved_values(
-            inputs.iter().map(|(node_id, _)| *node_id),
-            options.captures_available,
-        );
+        let resolved_values: FxHashSet<NodeId> =
+            self.init_resolved_values(inputs.iter().copied(), options.captures_available);
 
         let builder = PlanBuilder {
             graph: self,
@@ -1837,6 +1943,31 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_runs_non_in_place_ops_first() -> Result<(), Box<dyn Error>> {
+        let mut g = Graph::new();
+
+        let input_a_id = g.add_value(Some("input_a"), None);
+        let input_b_id = g.add_value(Some("input_b"), None);
+
+        let (add_op, add_out) = g.add_simple_op("add", Add {}, &[input_a_id, input_b_id]);
+        let (shape_op, shape_out) = g.add_simple_op("shape", Shape {}, &[input_a_id]);
+
+        // The execution plan could run operators in either order and produce
+        // the correct output. Since the `Add` op has the _potential_ to run in
+        // place (if the input is passed as an owned value) and the `Shape` op
+        // does not, the Shape op should be run first.
+        let plan = g.execution_plan(&[input_a_id, input_b_id], &[add_out, shape_out])?;
+        assert_eq!(plan, &[shape_op, add_op]);
+
+        // Make sure the results are the same if the order of outputs is
+        // swapped.
+        let plan = g.execution_plan(&[input_a_id, input_b_id], &[shape_out, add_out])?;
+        assert_eq!(plan, &[shape_op, add_op]);
+
+        Ok(())
+    }
+
     // Perform a graph run where one of the outputs is also an input for other
     // steps of the run.
     #[test]
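
A minimal usage sketch of the new `Graph::execution_plan` API: it shows how a caller can inspect the planned operator order without executing the model. It assumes the test-module context from the patch (the `add_value`/`add_simple_op` helpers and the `Add`/`Shape` operators used in the new test); the test name `inspect_plan_order` is illustrative. Because both `execution_plan` and `run` go through `get_cached_plan`, a subsequent `run` with the same input and output IDs reuses the plan computed here.

```rust
use std::error::Error;

#[test]
fn inspect_plan_order() -> Result<(), Box<dyn Error>> {
    let mut g = Graph::new();
    let a = g.add_value(Some("a"), None);
    let b = g.add_value(Some("b"), None);

    // `Add` can potentially run in place, `Shape` cannot.
    let (add_op, add_out) = g.add_simple_op("add", Add {}, &[a, b]);
    let (shape_op, shape_out) = g.add_simple_op("shape", Shape {}, &[a]);

    // Planned operator order for the requested outputs, without running the model.
    let plan = g.execution_plan(&[a, b], &[add_out, shape_out])?;
    assert_eq!(plan, &[shape_op, add_op]);
    Ok(())
}
```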
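
The `sort_plan` hunk is essentially a Kahn-style topological sort with a tie-breaking rule: keep a frontier of operators whose inputs are all available and, when several could run, pick one whose operator cannot run in place. The sketch below restates that heuristic with hypothetical toy types (`Op`, `schedule`, plain `HashMap`/`HashSet`) rather than rten's `OperatorNode`, `operator_dependencies` or `FxHashMap`; it illustrates the scheduling technique only, not the patch's actual data structures.

```rust
use std::collections::{HashMap, HashSet};

type NodeId = usize;

/// Toy stand-in for an operator node (hypothetical; not rten's `OperatorNode`).
struct Op {
    id: NodeId,
    inputs: Vec<NodeId>,
    outputs: Vec<NodeId>,
    can_run_in_place: bool,
}

/// Kahn-style ordering of `ops` starting from the values in `resolved`,
/// preferring operators that cannot run in place whenever more than one
/// operator is runnable.
fn schedule(ops: &[Op], mut resolved: HashSet<NodeId>) -> Vec<NodeId> {
    // Map each value node to the operators that consume it.
    let mut consumers: HashMap<NodeId, Vec<&Op>> = HashMap::new();
    for op in ops {
        for &input in &op.inputs {
            consumers.entry(input).or_default().push(op);
        }
    }

    // Frontier of operators whose inputs are all available.
    let mut frontier: Vec<&Op> = ops
        .iter()
        .filter(|op| op.inputs.iter().all(|id| resolved.contains(id)))
        .collect();

    let mut order = Vec::with_capacity(ops.len());
    while !frontier.is_empty() {
        // Prefer an operator that cannot run in place; otherwise take the first.
        let pos = frontier
            .iter()
            .position(|op| !op.can_run_in_place)
            .unwrap_or(0);
        let op = frontier.remove(pos);
        order.push(op.id);
        resolved.extend(op.outputs.iter().copied());

        // Consumers of the newly produced values join the frontier once all
        // of their inputs are resolved.
        for output in &op.outputs {
            for &candidate in consumers.get(output).map(|v| v.as_slice()).unwrap_or(&[]) {
                let queued = frontier.iter().any(|op| op.id == candidate.id)
                    || order.contains(&candidate.id);
                if !queued && candidate.inputs.iter().all(|id| resolved.contains(id)) {
                    frontier.push(candidate);
                }
            }
        }
    }
    order
}

fn main() {
    // Mirrors the new test: op 10 ("add") may run in place, op 11 ("shape")
    // may not, and both read value 0, so "shape" is scheduled first.
    let ops = [
        Op { id: 10, inputs: vec![0, 1], outputs: vec![2], can_run_in_place: true },
        Op { id: 11, inputs: vec![0], outputs: vec![3], can_run_in_place: false },
    ];
    let order = schedule(&ops, HashSet::from([0, 1]));
    assert_eq!(order, vec![11, 10]);
    println!("execution order: {order:?}");
}
```

Deferring in-place-capable operators this way matches the comments in the patch: by the time `Add` runs, `Shape` has already read the shared input, so that input is more likely to arrive as an owned value whose buffer can be reused for the output.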