Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-order execution plan to enable more operations to run in-place #405

Merged
merged 1 commit into from
Nov 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 148 additions & 17 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,21 @@ impl Graph {
.sum()
}

/// Return the sequence of operators from the current graph that would be
/// executed in order to compute `outputs` given `inputs`, without actually
/// running the model.
///
/// The result does not include nodes from any subgraphs that an operator
/// may run.
pub fn execution_plan(
&self,
inputs: &[NodeId],
outputs: &[NodeId],
) -> Result<Vec<NodeId>, RunError> {
let plan = self.get_cached_plan(inputs, outputs, false /* is_subgraph */)?;
Ok(plan.plan().to_vec())
}

/// Compute a set of output values given a set of inputs, using the
/// processing steps and constant values defined by the graph.
pub fn run(
Expand All @@ -909,7 +924,8 @@ impl Graph {
outputs: &[NodeId],
opts: Option<RunOptions>,
) -> Result<Vec<Output>, RunError> {
let plan = self.get_cached_plan(&inputs, outputs, false /* is_subgraph */)?;
let input_ids: Vec<_> = inputs.iter().map(|(node_id, _)| *node_id).collect();
let plan = self.get_cached_plan(&input_ids, outputs, false /* is_subgraph */)?;
threading::thread_pool().run(|| {
self.run_plan(
inputs,
Expand All @@ -934,13 +950,14 @@ impl Graph {
pool: Option<&TensorPool>,
opts: Option<RunOptions>,
) -> Result<Vec<Output>, RunError> {
let plan = self.get_cached_plan(&inputs, outputs, true /* is_subgraph */)?;
let input_ids: Vec<_> = inputs.iter().map(|(node_id, _)| *node_id).collect();
let plan = self.get_cached_plan(&input_ids, outputs, true /* is_subgraph */)?;
self.run_plan(inputs, plan.plan(), outputs, Some(captures), pool, opts)
}

fn get_cached_plan(
&self,
inputs: &[(NodeId, InputOrOutput)],
inputs: &[NodeId],
outputs: &[NodeId],
is_subgraph: bool,
) -> Result<Arc<CachedPlan>, RunError> {
Expand All @@ -950,9 +967,8 @@ impl Graph {
// Note that we only hold the plan lock while creating the plan,
// not while executing the model.
let mut cached_plan = self.cached_plan.lock().unwrap();
let input_ids: Vec<_> = inputs.iter().map(|(node_id, _)| *node_id).collect();
let plan = match cached_plan.as_ref() {
Some(plan) if plan.matches(&input_ids, outputs) => plan.clone(),
Some(plan) if plan.matches(inputs, outputs) => plan.clone(),
_ => {
let plan = self.create_plan(
inputs,
Expand All @@ -962,7 +978,7 @@ impl Graph {
captures_available: is_subgraph,
},
)?;
*cached_plan = Some(Arc::new(CachedPlan::new(&input_ids, outputs, plan)));
*cached_plan = Some(Arc::new(CachedPlan::new(inputs, outputs, plan)));
cached_plan.clone().unwrap()
}
};
Expand Down Expand Up @@ -1326,15 +1342,15 @@ impl Graph {
outputs: &[NodeId],
opts: Option<RunOptions>,
) -> Result<Vec<(NodeId, Output)>, RunError> {
let input_ids: Vec<_> = inputs.iter().map(|(id, _)| id).copied().collect();
let plan = self.create_plan(
&inputs,
&input_ids,
outputs,
PlanOptions {
allow_missing_inputs: true,
captures_available: false,
},
)?;
let input_ids: Vec<_> = inputs.iter().map(|(id, _)| id).copied().collect();
let (pruned_plan, pruned_plan_output_ids) = self.prune_plan(&plan, &input_ids, outputs);
let outputs = threading::thread_pool().run(|| {
self.run_plan(
Expand Down Expand Up @@ -1476,15 +1492,15 @@ impl Graph {
/// omitted from the plan.
fn create_plan(
&self,
inputs: &[(NodeId, InputOrOutput)],
inputs: &[NodeId],
outputs: &[NodeId],
options: PlanOptions,
) -> Result<Vec<NodeId>, RunError> {
if !all_unique(outputs, |x, y| x == y) {
return Err(RunError::PlanningError("output IDs are not unique".into()));
}

if !all_unique(inputs, |(x_id, _), (y_id, _)| x_id == y_id) {
if !all_unique(inputs, |x_id, y_id| x_id == y_id) {
return Err(RunError::PlanningError("input IDs are not unique".into()));
}

Expand Down Expand Up @@ -1529,9 +1545,89 @@ impl Graph {
Ok(())
}

/// Return a sequential plan to generate `outputs`. The plan is
/// a vec of `(op_node_id, operator)` tuples.
/// Take the current execution plan and re-order it for more
/// efficient execution.
fn sort_plan(self, mut resolved_values: FxHashSet<NodeId>) -> Vec<NodeId> {
// Build map of value node to operators that depend on the value.
let mut dependent_ops: FxHashMap<NodeId, Vec<(NodeId, &OperatorNode)>> =
FxHashMap::default();
for (op_node_id, op_node) in &self.plan {
for input_id in self.graph.operator_dependencies(op_node) {
if let Some(deps) = dependent_ops.get_mut(&input_id) {
deps.push((*op_node_id, op_node));
} else {
dependent_ops.insert(input_id, [(*op_node_id, *op_node)].into());
}
}
}

let mut output_plan = Vec::with_capacity(self.plan.len());

// Initialize frontier with all operators that can be executed
// from initially-available values.
let mut frontier: Vec<(NodeId, &OperatorNode)> = Vec::new();
for (op_node_id, op_node) in &self.plan {
if self
.graph
.operator_dependencies(op_node)
.all(|id| resolved_values.contains(&id))
{
frontier.push((*op_node_id, op_node));
}
}

debug_assert!(!frontier.is_empty(), "initial frontier is empty");

// Loop while we still have operators to compute.
while !frontier.is_empty() {
// Choose an operator to execute next and add it to the plan.
//
// We run non-in-place operators first, so that operators
// which can run in-place are more likely to have their
// inputs available for in-place execution.
let op_pos = frontier
.iter()
.position(|(_id, op)| !op.operator().can_run_in_place())
.unwrap_or(0);
let (next_op_id, op_node) = frontier.remove(op_pos);
output_plan.push(next_op_id);

// Mark the operator's outputs as computed.
resolved_values.extend(op_node.output_ids().iter().filter_map(|id| *id));

// Visit operators that depend on current op outputs. Add
// to frontier set if all dependencies have been resolved.
for output_id in op_node.output_ids() {
let Some(output_id) = output_id else {
continue;
};
let Some(deps) = dependent_ops.get(output_id) else {
continue;
};
for (candidate_op_id, candidate_op) in deps {
if frontier.iter().any(|(op_id, _)| op_id == candidate_op_id) {
continue;
}

if self
.graph
.operator_dependencies(candidate_op)
.all(|id| resolved_values.contains(&id))
{
frontier.push((*candidate_op_id, candidate_op));
}
}
}
}

output_plan
}

/// Return a sequential plan to generate `outputs`.
fn plan(mut self, outputs: &[NodeId]) -> Result<Vec<NodeId>, RunError> {
let initial_resolved_values = self.resolved_values.clone();

// Build initial plan by traversing graph backwards from outputs.
for output_id in outputs.iter() {
if self.resolved_values.contains(output_id) {
// Value is either a constant node or is produced by
Expand All @@ -1546,15 +1642,25 @@ impl Graph {
return Err(RunError::PlanningError(msg));
}
}
Ok(self.plan.into_iter().map(|(node_id, _)| node_id).collect())

// When doing partial evaluation, just return the initial plan.
// This avoids having to handle missing inputs when sorting the
// plan.
if self.options.allow_missing_inputs || self.plan.is_empty() {
return Ok(self.plan.into_iter().map(|(op_id, _)| op_id).collect());
}

// Re-order initial plan to get a more efficient execution
// order.
let sorted_plan = self.sort_plan(initial_resolved_values);

Ok(sorted_plan)
}
}

// Set of values that are available after executing the plan
let resolved_values: FxHashSet<NodeId> = self.init_resolved_values(
inputs.iter().map(|(node_id, _)| *node_id),
options.captures_available,
);
let resolved_values: FxHashSet<NodeId> =
self.init_resolved_values(inputs.iter().copied(), options.captures_available);

let builder = PlanBuilder {
graph: self,
Expand Down Expand Up @@ -1837,6 +1943,31 @@ mod tests {
Ok(())
}

#[test]
fn test_runs_non_in_place_ops_first() -> Result<(), Box<dyn Error>> {
let mut g = Graph::new();

let input_a_id = g.add_value(Some("input_a"), None);
let input_b_id = g.add_value(Some("input_b"), None);

let (add_op, add_out) = g.add_simple_op("add", Add {}, &[input_a_id, input_b_id]);
let (shape_op, shape_out) = g.add_simple_op("shape", Shape {}, &[input_a_id]);

// The execution plan could run operators in either order and produce
// the correct output. Since the `Add` op has the _potential_ to run in
// place (if the input is passed as an owned value) and the `Shape` op
// does not, the Shape op should be run first.
let plan = g.execution_plan(&[input_a_id, input_b_id], &[add_out, shape_out])?;
assert_eq!(plan, &[shape_op, add_op]);

// Make sure the results are the same if the order of outputs is
// swapped.
let plan = g.execution_plan(&[input_a_id, input_b_id], &[shape_out, add_out])?;
assert_eq!(plan, &[shape_op, add_op]);

Ok(())
}

// Perform a graph run where one of the outputs is also an input for other
// steps of the run.
#[test]
Expand Down
Loading