Skip to content

Commit

Permalink
Merge pull request #405 from robertknight/sort-plan
Browse files Browse the repository at this point in the history
Re-order execution plan to enable more operations to run in-place
  • Loading branch information
robertknight authored Nov 13, 2024
2 parents 2faf5e3 + 937d3d0 commit 24da48c
Showing 1 changed file with 148 additions and 17 deletions.
165 changes: 148 additions & 17 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,21 @@ impl Graph {
.sum()
}

/// Return the sequence of operators from the current graph that would be
/// executed in order to compute `outputs` given `inputs`, without actually
/// running the model.
///
/// The result does not include nodes from any subgraphs that an operator
/// may run.
pub fn execution_plan(
&self,
inputs: &[NodeId],
outputs: &[NodeId],
) -> Result<Vec<NodeId>, RunError> {
let plan = self.get_cached_plan(inputs, outputs, false /* is_subgraph */)?;
Ok(plan.plan().to_vec())
}

/// Compute a set of output values given a set of inputs, using the
/// processing steps and constant values defined by the graph.
pub fn run(
Expand All @@ -909,7 +924,8 @@ impl Graph {
outputs: &[NodeId],
opts: Option<RunOptions>,
) -> Result<Vec<Output>, RunError> {
let plan = self.get_cached_plan(&inputs, outputs, false /* is_subgraph */)?;
let input_ids: Vec<_> = inputs.iter().map(|(node_id, _)| *node_id).collect();
let plan = self.get_cached_plan(&input_ids, outputs, false /* is_subgraph */)?;
threading::thread_pool().run(|| {
self.run_plan(
inputs,
Expand All @@ -934,13 +950,14 @@ impl Graph {
pool: Option<&TensorPool>,
opts: Option<RunOptions>,
) -> Result<Vec<Output>, RunError> {
let plan = self.get_cached_plan(&inputs, outputs, true /* is_subgraph */)?;
let input_ids: Vec<_> = inputs.iter().map(|(node_id, _)| *node_id).collect();
let plan = self.get_cached_plan(&input_ids, outputs, true /* is_subgraph */)?;
self.run_plan(inputs, plan.plan(), outputs, Some(captures), pool, opts)
}

fn get_cached_plan(
&self,
inputs: &[(NodeId, InputOrOutput)],
inputs: &[NodeId],
outputs: &[NodeId],
is_subgraph: bool,
) -> Result<Arc<CachedPlan>, RunError> {
Expand All @@ -950,9 +967,8 @@ impl Graph {
// Note that we only hold the plan lock while creating the plan,
// not while executing the model.
let mut cached_plan = self.cached_plan.lock().unwrap();
let input_ids: Vec<_> = inputs.iter().map(|(node_id, _)| *node_id).collect();
let plan = match cached_plan.as_ref() {
Some(plan) if plan.matches(&input_ids, outputs) => plan.clone(),
Some(plan) if plan.matches(inputs, outputs) => plan.clone(),
_ => {
let plan = self.create_plan(
inputs,
Expand All @@ -962,7 +978,7 @@ impl Graph {
captures_available: is_subgraph,
},
)?;
*cached_plan = Some(Arc::new(CachedPlan::new(&input_ids, outputs, plan)));
*cached_plan = Some(Arc::new(CachedPlan::new(inputs, outputs, plan)));
cached_plan.clone().unwrap()
}
};
Expand Down Expand Up @@ -1326,15 +1342,15 @@ impl Graph {
outputs: &[NodeId],
opts: Option<RunOptions>,
) -> Result<Vec<(NodeId, Output)>, RunError> {
let input_ids: Vec<_> = inputs.iter().map(|(id, _)| id).copied().collect();
let plan = self.create_plan(
&inputs,
&input_ids,
outputs,
PlanOptions {
allow_missing_inputs: true,
captures_available: false,
},
)?;
let input_ids: Vec<_> = inputs.iter().map(|(id, _)| id).copied().collect();
let (pruned_plan, pruned_plan_output_ids) = self.prune_plan(&plan, &input_ids, outputs);
let outputs = threading::thread_pool().run(|| {
self.run_plan(
Expand Down Expand Up @@ -1476,15 +1492,15 @@ impl Graph {
/// omitted from the plan.
fn create_plan(
&self,
inputs: &[(NodeId, InputOrOutput)],
inputs: &[NodeId],
outputs: &[NodeId],
options: PlanOptions,
) -> Result<Vec<NodeId>, RunError> {
if !all_unique(outputs, |x, y| x == y) {
return Err(RunError::PlanningError("output IDs are not unique".into()));
}

if !all_unique(inputs, |(x_id, _), (y_id, _)| x_id == y_id) {
if !all_unique(inputs, |x_id, y_id| x_id == y_id) {
return Err(RunError::PlanningError("input IDs are not unique".into()));
}

Expand Down Expand Up @@ -1529,9 +1545,89 @@ impl Graph {
Ok(())
}

/// Return a sequential plan to generate `outputs`. The plan is
/// a vec of `(op_node_id, operator)` tuples.
/// Take the current execution plan and re-order it for more
/// efficient execution.
fn sort_plan(self, mut resolved_values: FxHashSet<NodeId>) -> Vec<NodeId> {
// Build map of value node to operators that depend on the value.
let mut dependent_ops: FxHashMap<NodeId, Vec<(NodeId, &OperatorNode)>> =
FxHashMap::default();
for (op_node_id, op_node) in &self.plan {
for input_id in self.graph.operator_dependencies(op_node) {
if let Some(deps) = dependent_ops.get_mut(&input_id) {
deps.push((*op_node_id, op_node));
} else {
dependent_ops.insert(input_id, [(*op_node_id, *op_node)].into());
}
}
}

let mut output_plan = Vec::with_capacity(self.plan.len());

// Initialize frontier with all operators that can be executed
// from initially-available values.
let mut frontier: Vec<(NodeId, &OperatorNode)> = Vec::new();
for (op_node_id, op_node) in &self.plan {
if self
.graph
.operator_dependencies(op_node)
.all(|id| resolved_values.contains(&id))
{
frontier.push((*op_node_id, op_node));
}
}

debug_assert!(!frontier.is_empty(), "initial frontier is empty");

// Loop while we still have operators to compute.
while !frontier.is_empty() {
// Choose an operator to execute next and add it to the plan.
//
// We run non-in-place operators first, so that operators
// which can run in-place are more likely to have their
// inputs available for in-place execution.
let op_pos = frontier
.iter()
.position(|(_id, op)| !op.operator().can_run_in_place())
.unwrap_or(0);
let (next_op_id, op_node) = frontier.remove(op_pos);
output_plan.push(next_op_id);

// Mark the operator's outputs as computed.
resolved_values.extend(op_node.output_ids().iter().filter_map(|id| *id));

// Visit operators that depend on current op outputs. Add
// to frontier set if all dependencies have been resolved.
for output_id in op_node.output_ids() {
let Some(output_id) = output_id else {
continue;
};
let Some(deps) = dependent_ops.get(output_id) else {
continue;
};
for (candidate_op_id, candidate_op) in deps {
if frontier.iter().any(|(op_id, _)| op_id == candidate_op_id) {
continue;
}

if self
.graph
.operator_dependencies(candidate_op)
.all(|id| resolved_values.contains(&id))
{
frontier.push((*candidate_op_id, candidate_op));
}
}
}
}

output_plan
}

/// Return a sequential plan to generate `outputs`.
fn plan(mut self, outputs: &[NodeId]) -> Result<Vec<NodeId>, RunError> {
let initial_resolved_values = self.resolved_values.clone();

// Build initial plan by traversing graph backwards from outputs.
for output_id in outputs.iter() {
if self.resolved_values.contains(output_id) {
// Value is either a constant node or is produced by
Expand All @@ -1546,15 +1642,25 @@ impl Graph {
return Err(RunError::PlanningError(msg));
}
}
Ok(self.plan.into_iter().map(|(node_id, _)| node_id).collect())

// When doing partial evaluation, just return the initial plan.
// This avoids having to handle missing inputs when sorting the
// plan.
if self.options.allow_missing_inputs || self.plan.is_empty() {
return Ok(self.plan.into_iter().map(|(op_id, _)| op_id).collect());
}

// Re-order initial plan to get a more efficient execution
// order.
let sorted_plan = self.sort_plan(initial_resolved_values);

Ok(sorted_plan)
}
}

// Set of values that are available after executing the plan
let resolved_values: FxHashSet<NodeId> = self.init_resolved_values(
inputs.iter().map(|(node_id, _)| *node_id),
options.captures_available,
);
let resolved_values: FxHashSet<NodeId> =
self.init_resolved_values(inputs.iter().copied(), options.captures_available);

let builder = PlanBuilder {
graph: self,
Expand Down Expand Up @@ -1837,6 +1943,31 @@ mod tests {
Ok(())
}

#[test]
fn test_runs_non_in_place_ops_first() -> Result<(), Box<dyn Error>> {
let mut g = Graph::new();

let input_a_id = g.add_value(Some("input_a"), None);
let input_b_id = g.add_value(Some("input_b"), None);

let (add_op, add_out) = g.add_simple_op("add", Add {}, &[input_a_id, input_b_id]);
let (shape_op, shape_out) = g.add_simple_op("shape", Shape {}, &[input_a_id]);

// The execution plan could run operators in either order and produce
// the correct output. Since the `Add` op has the _potential_ to run in
// place (if the input is passed as an owned value) and the `Shape` op
// does not, the Shape op should be run first.
let plan = g.execution_plan(&[input_a_id, input_b_id], &[add_out, shape_out])?;
assert_eq!(plan, &[shape_op, add_op]);

// Make sure the results are the same if the order of outputs is
// swapped.
let plan = g.execution_plan(&[input_a_id, input_b_id], &[shape_out, add_out])?;
assert_eq!(plan, &[shape_op, add_op]);

Ok(())
}

// Perform a graph run where one of the outputs is also an input for other
// steps of the run.
#[test]
Expand Down

0 comments on commit 24da48c

Please sign in to comment.