diff --git a/releasenotes/notes/0.13/fix-removed-nodes-attr-d1829e1f4462d96a.yaml b/releasenotes/notes/0.13/fix-removed-nodes-attr-d1829e1f4462d96a.yaml new file mode 100644 index 000000000..7900c2139 --- /dev/null +++ b/releasenotes/notes/0.13/fix-removed-nodes-attr-d1829e1f4462d96a.yaml @@ -0,0 +1,7 @@ +--- +fixes: + - | + Fixed an issue with several :class:`~.PyDiGraph` and :class:`~.PyGraph` + methods that removed nodes where previously when calling + these methods the :attr:`.PyDiGraph.node_removed` attribute would not be + updated to reflect that nodes were removed. diff --git a/src/digraph.rs b/src/digraph.rs index a48f2d8b6..6c177e775 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -29,7 +29,7 @@ use rustworkx_core::dictmap::*; use pyo3::exceptions::PyIndexError; use pyo3::gc::PyVisit; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyDict, PyList, PyLong, PyString, PyTuple}; +use pyo3::types::{PyBool, PyDict, PyList, PyString, PyTuple}; use pyo3::PyTraverseError; use pyo3::Python; @@ -44,7 +44,7 @@ use petgraph::prelude::*; use petgraph::visit::{ EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, - Visitable, + NodeIndexable, Visitable, }; use super::dot_utils::build_dot; @@ -318,7 +318,6 @@ impl PyDiGraph { }; edges.push(edge); } - let out_dict = PyDict::new(py); let nodes_lst: PyObject = PyList::new(py, nodes).into(); let edges_lst: PyObject = PyList::new(py, edges).into(); @@ -398,55 +397,22 @@ impl PyDiGraph { .downcast::() .unwrap(); - // use a pointer to iter the node list - let mut pointer = 0; - let mut next_node_idx: usize = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap() - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - // list of temporary nodes that will be removed later to re-create holes let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap(); let mut tmp_nodes: Vec = Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len()); - for i in 0..nodes_lst.len() + 1 { - if i < next_node_idx { + for item in nodes_lst { + let item = item.downcast::().unwrap(); + let next_index: usize = item.get_item(0).unwrap().extract().unwrap(); + let weight: PyObject = item.get_item(1).unwrap().extract().unwrap(); + while next_index > self.graph.node_bound() { // node does not exist let tmp_node = self.graph.add_node(py.None()); tmp_nodes.push(tmp_node); - } else { - // add node to the graph, and update the next available node index - let item = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap(); - - let node_w = item.get_item(1).unwrap().extract().unwrap(); - self.graph.add_node(node_w); - pointer += 1; - if pointer < nodes_lst.len() { - next_node_idx = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap() - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - } } + // add node to the graph, and update the next available node index + self.graph.add_node(weight); } // Remove any temporary nodes we added for tmp_node in tmp_nodes { @@ -463,20 +429,8 @@ impl PyDiGraph { self.graph.add_edge(tmp_node, tmp_node, py.None()); } else { let triple = item.downcast::().unwrap(); - let edge_p: usize = triple - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - let edge_c: usize = triple - .get_item(1) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); + let edge_p: usize = triple.get_item(0).unwrap().extract().unwrap(); + let edge_c: usize = triple.get_item(1).unwrap().extract().unwrap(); let edge_w = triple.get_item(2).unwrap().extract().unwrap(); self.graph .add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w); @@ -1760,8 +1714,8 @@ impl PyDiGraph { /// the graph. #[pyo3(text_signature = "(self, index_list, /)")] pub fn remove_nodes_from(&mut self, index_list: Vec) -> PyResult<()> { - for node in index_list.iter().map(|x| NodeIndex::new(*x)) { - self.graph.remove_node(node); + for node in index_list { + self.remove_node(node)?; } Ok(()) } @@ -2389,7 +2343,7 @@ impl PyDiGraph { // If no nodes are copied bail here since there is nothing left // to do. if out_map.is_empty() { - self.graph.remove_node(node_index); + self.remove_node(node_index.index())?; // Return a new empty map to clear allocation from out_map return Ok(NodeMap { node_map: DictMap::new(), @@ -2450,7 +2404,7 @@ impl PyDiGraph { self._add_edge(source_out, target, weight)?; } // Remove node - self.graph.remove_node(node_index); + self.remove_node(node_index.index())?; Ok(NodeMap { node_map: out_map }) } @@ -2559,7 +2513,7 @@ impl PyDiGraph { // Remove nodes that will be replaced. for index in indices_to_remove { - self.graph.remove_node(index); + self.remove_node(index.index())?; } // If `weight_combo_fn` was specified, merge edges according @@ -2912,7 +2866,10 @@ impl PyDiGraph { fn __delitem__(&mut self, idx: usize) -> PyResult<()> { match self.graph.remove_node(NodeIndex::new(idx)) { - Some(_) => Ok(()), + Some(_) => { + self.node_removed = true; + Ok(()) + } None => Err(PyIndexError::new_err("No node found for index")), } } diff --git a/src/graph.rs b/src/graph.rs index 75165cc4c..04c90c4c7 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -26,7 +26,7 @@ use rustworkx_core::dictmap::*; use pyo3::exceptions::PyIndexError; use pyo3::gc::PyVisit; use pyo3::prelude::*; -use pyo3::types::{PyBool, PyDict, PyList, PyLong, PyString, PyTuple}; +use pyo3::types::{PyBool, PyDict, PyList, PyString, PyTuple}; use pyo3::PyTraverseError; use pyo3::Python; @@ -47,6 +47,7 @@ use petgraph::graph::{EdgeIndex, NodeIndex}; use petgraph::prelude::*; use petgraph::visit::{ EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, + NodeIndexable, }; /// A class for creating undirected graphs @@ -284,56 +285,24 @@ impl PyGraph { .downcast::() .unwrap(); - // use a pointer to iter the node list - let mut pointer = 0; - let mut next_node_idx: usize = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap() - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - // list of temporary nodes that will be removed later to re-create holes let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap(); let mut tmp_nodes: Vec = Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len()); - for i in 0..nodes_lst.len() + 1 { - if i < next_node_idx { + for item in nodes_lst { + let item = item.downcast::().unwrap(); + let next_index: usize = item.get_item(0).unwrap().extract().unwrap(); + let weight: PyObject = item.get_item(1).unwrap().extract().unwrap(); + while next_index > self.graph.node_bound() { // node does not exist let tmp_node = self.graph.add_node(py.None()); tmp_nodes.push(tmp_node); - } else { - // add node to the graph, and update the next available node index - let item = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap(); - - let node_w = item.get_item(1).unwrap().extract().unwrap(); - self.graph.add_node(node_w); - pointer += 1; - if pointer < nodes_lst.len() { - next_node_idx = nodes_lst - .get_item(pointer) - .unwrap() - .downcast::() - .unwrap() - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - } } + // add node to the graph, and update the next available node index + self.graph.add_node(weight); } + // Remove any temporary nodes we added for tmp_node in tmp_nodes { self.graph.remove_node(tmp_node); } @@ -348,20 +317,8 @@ impl PyGraph { self.graph.add_edge(tmp_node, tmp_node, py.None()); } else { let triple = item.downcast::().unwrap(); - let edge_p: usize = triple - .get_item(0) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); - let edge_c: usize = triple - .get_item(1) - .unwrap() - .downcast::() - .unwrap() - .extract() - .unwrap(); + let edge_p: usize = triple.get_item(0).unwrap().extract().unwrap(); + let edge_c: usize = triple.get_item(1).unwrap().extract().unwrap(); let edge_w = triple.get_item(2).unwrap().extract().unwrap(); self.graph .add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w); @@ -1062,8 +1019,8 @@ impl PyGraph { /// the graph #[pyo3(text_signature = "(self, index_list, /)")] pub fn remove_nodes_from(&mut self, index_list: Vec) -> PyResult<()> { - for node in index_list.iter().map(|x| NodeIndex::new(*x)) { - self.graph.remove_node(node); + for node in index_list { + self.remove_node(node)?; } Ok(()) } @@ -1695,7 +1652,7 @@ impl PyGraph { // Remove nodes that will be replaced. for index in indices_to_remove { - self.graph.remove_node(index); + self.remove_node(index.index())?; } // If `weight_combo_fn` was specified, merge edges according @@ -1846,7 +1803,10 @@ impl PyGraph { fn __delitem__(&mut self, idx: usize) -> PyResult<()> { match self.graph.remove_node(NodeIndex::new(idx)) { - Some(_) => Ok(()), + Some(_) => { + self.node_removed = true; + Ok(()) + } None => Err(PyIndexError::new_err("No node found for index")), } } diff --git a/tests/rustworkx_tests/digraph/test_deepcopy.py b/tests/rustworkx_tests/digraph/test_deepcopy.py index bd296a5a5..542251273 100644 --- a/tests/rustworkx_tests/digraph/test_deepcopy.py +++ b/tests/rustworkx_tests/digraph/test_deepcopy.py @@ -70,3 +70,29 @@ def test_deepcopy_different_objects(self): self.assertIsNot( graph_a.get_edge_data(node_a, node_b), graph_b.get_edge_data(node_a, node_b) ) + + def test_deepcopy_multinode_hole_in_middle(self): + graph = rustworkx.PyDiGraph() + graph.add_nodes_from(range(20)) + graph.remove_nodes_from([10, 11, 12, 13, 14]) + graph.add_edges_from_no_data( + [ + (4, 5), + (16, 18), + (2, 19), + (0, 15), + (15, 16), + (16, 17), + (6, 17), + (8, 18), + (17, 1), + (17, 7), + (18, 3), + (18, 9), + (19, 16), + ] + ) + copied_graph = copy.deepcopy(graph) + self.assertEqual( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 18, 19], copied_graph.node_indices() + ) diff --git a/tests/rustworkx_tests/graph/test_deepcopy.py b/tests/rustworkx_tests/graph/test_deepcopy.py index 074d03211..6941d2dcb 100644 --- a/tests/rustworkx_tests/graph/test_deepcopy.py +++ b/tests/rustworkx_tests/graph/test_deepcopy.py @@ -48,3 +48,29 @@ def test_deepcopy_attrs(self): graph = rustworkx.PyGraph(attrs="abc") graph_copy = copy.deepcopy(graph) self.assertEqual(graph.attrs, graph_copy.attrs) + + def test_deepcopy_multinode_hole_in_middle(self): + graph = rustworkx.PyGraph() + graph.add_nodes_from(range(20)) + graph.remove_nodes_from([10, 11, 12, 13, 14]) + graph.add_edges_from_no_data( + [ + (4, 5), + (16, 18), + (2, 19), + (0, 15), + (15, 16), + (16, 17), + (6, 17), + (8, 18), + (17, 1), + (17, 7), + (18, 3), + (18, 9), + (19, 16), + ] + ) + copied_graph = copy.deepcopy(graph) + self.assertEqual( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 18, 19], copied_graph.node_indices() + )