Skip to content

Commit

Permalink
Infinite loop countermeasure setzer22#21
Browse files Browse the repository at this point in the history
  • Loading branch information
kkngsm committed Aug 5, 2022
1 parent 6e3f9ce commit d3313cb
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 30 deletions.
32 changes: 25 additions & 7 deletions egui_node_graph/src/editor_ui.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,29 @@ where
/* Draw connections */
if let Some((_, ref locator)) = self.connection_in_progress {
let port_type = self.graph.any_param_type(*locator).unwrap();
let connection_color = port_type.data_type_color(&self.user_state);
let start_pos = port_locations[locator];
let (src_pos, dst_pos) = match locator {
AnyParameterId::Output(_) => (start_pos, cursor_pos),
AnyParameterId::Input(_) => (cursor_pos, start_pos),
};

let connection_color;
let src_pos;
let dst_pos;
match locator {
param_id @ AnyParameterId::Output(_) => {
connection_color = port_type.data_type_color(
&self.user_state,
PortConnection::ConnectionCursor(*param_id),
);
src_pos = start_pos;
dst_pos = cursor_pos;
}
param_id @ AnyParameterId::Input(_) => {
connection_color = port_type.data_type_color(
&self.user_state,
PortConnection::ConnectionCursor(*param_id),
);
src_pos = cursor_pos;
dst_pos = start_pos;
}
}
draw_connection(ui.painter(), src_pos, dst_pos, connection_color);
}

Expand All @@ -187,7 +204,8 @@ where
.graph
.any_param_type(AnyParameterId::Output(output))
.unwrap();
let connection_color = port_type.data_type_color(&self.user_state);
let connection_color = port_type
.data_type_color(&self.user_state, PortConnection::Connection(input, output));
let src_pos = port_locations[&AnyParameterId::Output(output)];
let dst_pos = port_locations[&AnyParameterId::Input(input)];
draw_connection(ui.painter(), src_pos, dst_pos, connection_color);
Expand Down Expand Up @@ -467,7 +485,7 @@ where
let port_color = if resp.hovered() {
Color32::WHITE
} else {
port_type.data_type_color(user_state)
port_type.data_type_color(user_state, PortConnection::Port(param_id))
};
ui.painter()
.circle(port_rect.center(), 5.0, port_color, Stroke::none());
Expand Down
7 changes: 7 additions & 0 deletions egui_node_graph/src/id_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,10 @@ impl AnyParameterId {
}
}
}

#[derive(Clone, Copy, Debug)]
pub enum PortConnection {
Port(AnyParameterId),
ConnectionCursor(AnyParameterId),
Connection(InputId, OutputId),
}
6 changes: 5 additions & 1 deletion egui_node_graph/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ pub trait WidgetValueTrait {
/// to the user.
pub trait DataTypeTrait<UserState>: PartialEq + Eq {
/// The associated port color of this datatype
fn data_type_color(&self, user_state: &UserState) -> egui::Color32;
fn data_type_color(
&self,
user_state: &UserState,
port_connection: PortConnection,
) -> egui::Color32;

/// The name of this datatype. Return type is specified as Cow<str> because
/// some implementations will need to allocate a new string to provide an
Expand Down
76 changes: 54 additions & 22 deletions egui_node_graph_example/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,27 @@ pub enum MyResponse {
#[derive(Default)]
pub struct MyGraphState {
pub active_node: Option<NodeId>,

pub infinity_loop: Option<OutputId>,
}

// =========== Then, you need to implement some traits ============

// A trait for the data types, to tell the library how to display them
impl DataTypeTrait<MyGraphState> for MyDataType {
fn data_type_color(&self, _user_state: &MyGraphState) -> egui::Color32 {
fn data_type_color(
&self,
user_state: &MyGraphState,
port_connection: PortConnection,
) -> egui::Color32 {
// Turns the color of connections in infinite loops red
if let Some(loop_id) = user_state.infinity_loop {
if let PortConnection::Connection(_, output_id) = port_connection {
if output_id == loop_id {
return egui::Color32::from_rgb(255, 0, 0);
}
}
}
match self {
MyDataType::Scalar => egui::Color32::from_rgb(38, 109, 211),
MyDataType::Vec2 => egui::Color32::from_rgb(238, 207, 109),
Expand Down Expand Up @@ -339,14 +353,14 @@ pub struct NodeGraphExample {
// The `GraphEditorState` is the top-level object. You "register" all your
// custom types by specifying it as its generic parameters.
state: MyEditorState,
visitor: BFSVisitor,
visitor: BFS,
}

impl Default for NodeGraphExample {
fn default() -> Self {
Self {
state: GraphEditorState::new(1.0, MyGraphState::default()),
visitor: BFSVisitor::default(),
visitor: BFS::default(),
}
}
}
Expand Down Expand Up @@ -381,10 +395,14 @@ impl eframe::App for NodeGraphExample {

if let Some(node) = self.state.user_state.active_node {
if self.state.graph.nodes.contains_key(node) {
let text = match self.visitor.compute(&self.state.graph, node) {
Ok(value) => format!("The result is: {:?}", value),
Err(err) => format!("Execution error: {}", err),
};
let text =
match self
.visitor
.compute(&self.state.graph, node, &mut self.state.user_state)
{
Ok(value) => format!("The result is: {:?}", value),
Err(err) => format!("Execution error: {}", err),
};
ctx.debug_painter().text(
egui::pos2(10.0, 35.0),
egui::Align2::LEFT_TOP,
Expand All @@ -404,7 +422,7 @@ use std::collections::{BTreeSet, HashMap, VecDeque};
type OutputsCache = HashMap<OutputId, MyValueType>;

#[derive(Default)]
struct BFSVisitor {
struct BFS {
explored: BTreeSet<NodeId>,
queue: VecDeque<NodeId>,

Expand All @@ -415,28 +433,34 @@ struct BFSVisitor {
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct InputPortId(pub usize);
/// Evaluate all dependencies of this node by BFS
impl BFSVisitor {
impl BFS {
/// Explore the nodes, evaluate, and return the result of the evaluation.
pub fn compute(&mut self, graph: &MyGraph, node_id: NodeId) -> anyhow::Result<MyValueType> {
pub fn compute(
&mut self,
graph: &MyGraph,
node_id: NodeId,
user_state: &mut MyGraphState,
) -> anyhow::Result<MyValueType> {
user_state.infinity_loop = None;
self.clear_cache();
self.explore(graph, node_id);
self.evaluate_all(graph)?;
self.evaluate_all(graph, user_state)?;
let ans = self.get_output(graph, node_id)?;
Ok(ans)
}

/// Return the result of the evaluation. This example returns the first output.
pub fn get_output(&self, graph: &MyGraph, node_id: NodeId) -> anyhow::Result<MyValueType> {
let output_id = graph[node_id].output_ids().next().unwrap();
// If there are multiple outputs:
// If there are multiple outputs:
// let output_id = graph[node_id].get_output(param_name);
let output = self
.outputs_cache
.get(&output_id)
.ok_or(anyhow::format_err!("It may be in an infinite loop."))?;
Ok(*output)
}

/// Always run before exploring.
pub fn clear_cache(&mut self) {
self.queue.clear();
Expand Down Expand Up @@ -466,12 +490,17 @@ impl BFSVisitor {

/// Evaluate based on the calculation sequence
/// Note that the order of computation is stored in reverse
fn evaluate_all(&mut self, graph: &MyGraph) -> anyhow::Result<()> {
fn evaluate_all(
&mut self,
graph: &MyGraph,
user_state: &mut MyGraphState,
) -> anyhow::Result<()> {
for node_id in self.sequence.iter().rev().copied() {
let mut evaluator = Evaluator {
graph,
node_id,
outputs_cache: &mut self.outputs_cache,
user_state,
};
evaluator.evaluate(graph)?;
}
Expand All @@ -484,6 +513,8 @@ struct Evaluator<'a> {
pub outputs_cache: &'a mut OutputsCache,
pub graph: &'a MyGraph,
pub node_id: NodeId,

pub user_state: &'a mut MyGraphState,
}
impl Evaluator<'_> {
fn input_vector(&mut self, param_name: &str) -> anyhow::Result<egui::Vec2> {
Expand All @@ -497,17 +528,18 @@ impl Evaluator<'_> {
fn get_input(&mut self, param_name: &str) -> anyhow::Result<MyValueType> {
let input_id = self.graph[self.node_id].get_input(param_name)?;
// The output of another node is connected.
let value = if let Some(output_id) = self.graph.connection(input_id) {
if let Some(output_id) = self.graph.connection(input_id) {
// Now that we know the value is cached, return it
*self
.outputs_cache
.get(&output_id)
.ok_or(anyhow::format_err!("It may be in an infinite loop."))?
if let Some(value) = self.outputs_cache.get(&output_id) {
Ok(*value)
} else {
self.user_state.infinity_loop = Some(output_id);
Err(anyhow::format_err!("It may be in an infinite loop."))
}
} else {
// No existing connection, take the inline value instead.
self.graph[input_id].value
};
Ok(value)
Ok(self.graph[input_id].value)
}
}

fn output_vector(&mut self, name: &str, value: egui::Vec2) -> Result<(), EguiGraphError> {
Expand Down

0 comments on commit d3313cb

Please sign in to comment.