diff --git a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py index 90db2c6b1f5..5ab5935290d 100644 --- a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py +++ b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py @@ -11,36 +11,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -from cugraph.structure import graph_primtypes_wrapper -from cugraph.structure.graph_primtypes_wrapper import Direction -from cugraph.structure.number_map import NumberMap -from cugraph.structure.symmetrize import symmetrize -import cudf +import gc +from typing import Union import warnings -import dask_cudf + +import cudf import cupy as cp import dask -from typing import Union +import dask_cudf +from dask import delayed +from dask.distributed import wait, default_client import numpy as np -import gc from pylibcugraph import ( MGGraph, ResourceHandle, GraphProperties, + get_two_hop_neighbors as pylibcugraph_get_two_hop_neighbors, + select_random_vertices as pylibcugraph_select_random_vertices, ) -from dask.distributed import wait, default_client +from cugraph.structure import graph_primtypes_wrapper +from cugraph.structure.graph_primtypes_wrapper import Direction +from cugraph.structure.number_map import NumberMap +from cugraph.structure.symmetrize import symmetrize from cugraph.dask.common.part_utils import ( get_persisted_df_worker_map, persist_dask_df_equal_parts_per_worker, ) -from cugraph.dask.common.input_utils import get_distributed_data -from pylibcugraph import ( - get_two_hop_neighbors as pylibcugraph_get_two_hop_neighbors, - select_random_vertices as pylibcugraph_select_random_vertices, -) +from cugraph.dask import get_n_workers import cugraph.dask.comms.comms as Comms -from dask import delayed class simpleDistributedGraphImpl: @@ -784,6 +783,15 @@ def get_two_hop_neighbors(self, start_vertices=None): the second vertex id of a pair, if an external vertex id is defined by only one column """ + _client = default_client() + + def _call_plc_two_hop_neighbors(sID, mg_graph_x, start_vertices): + return pylibcugraph_get_two_hop_neighbors( + resource_handle=ResourceHandle(Comms.get_handle(sID).getHandle()), + graph=mg_graph_x, + start_vertices=start_vertices, + do_expensive_check=False, + ) if isinstance(start_vertices, int): start_vertices = [start_vertices] @@ -805,20 +813,13 @@ def get_two_hop_neighbors(self, start_vertices=None): ) start_vertices = start_vertices.astype(start_vertices_type) - start_vertices = get_distributed_data(start_vertices) - wait(start_vertices) - start_vertices = start_vertices.worker_to_parts - - def _call_plc_two_hop_neighbors(sID, mg_graph_x, start_vertices): - return pylibcugraph_get_two_hop_neighbors( - resource_handle=ResourceHandle(Comms.get_handle(sID).getHandle()), - graph=mg_graph_x, - start_vertices=start_vertices, - do_expensive_check=False, + n_workers = get_n_workers() + start_vertices = start_vertices.repartition(npartitions=n_workers) + start_vertices = persist_dask_df_equal_parts_per_worker( + start_vertices, _client ) + start_vertices = get_persisted_df_worker_map(start_vertices, _client) - _client = default_client() - if start_vertices is not None: result = [ _client.submit( _call_plc_two_hop_neighbors, @@ -828,7 +829,7 @@ def _call_plc_two_hop_neighbors(sID, mg_graph_x, start_vertices): workers=[w], allow_other_workers=False, ) - for w in Comms.get_workers() + for w in start_vertices.keys() ] else: result = [ @@ -855,7 +856,6 @@ def convert_to_cudf(cp_arrays): df["second"] = second return df - _client = default_client() cudf_result = [ _client.submit(convert_to_cudf, cp_arrays) for cp_arrays in result ]