Distributed Sampling in cuGraph-PyG (#4384)
Adds distributed sampling to cuGraph-PyG and renames the existing API to clarify that it is Dask-based.
Adds `tensordict` as a dependency of `cuGraph-PyG` to support the new `TensorDictFeatureStore`.
CI no longer installs `torch-cluster` and `torch-spline-conv` for testing, since installing them results in an `ImportError` and neither package is needed.

Requires PyG 2.5.  Should be merged after #4335 

Merge after #4355 

Closes #4248 
Closes #4249 
Closes #3383 
Closes #3942 
Closes #3836 
Closes #4202 
Closes #4051 
Closes #4326 
Closes #4252 
Partially addresses #3805

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Seunghwa Kang (https://github.com/seunghwak)
  - Tingyu Wang (https://github.com/tingyu66)
  - Ralph Liu (https://github.com/nv-rliu)

Approvers:
  - Tingyu Wang (https://github.com/tingyu66)
  - Brad Rees (https://github.com/BradReesWork)
  - Jake Awe (https://github.com/AyodeAwe)

URL: #4384
alexbarghi-nv authored May 30, 2024
1 parent 563c06e commit 797a036
Showing 47 changed files with 2,465 additions and 229 deletions.
5 changes: 4 additions & 1 deletion ci/run_cugraph_pyg_pytests.sh
@@ -6,7 +6,10 @@ set -euo pipefail
# Support invoking run_cugraph_pyg_pytests.sh outside the script directory
cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../python/cugraph-pyg/cugraph_pyg

pytest --cache-clear --ignore=tests/mg "$@" .
pytest --cache-clear --benchmark-disable "$@" .

# Used to skip certain examples in CI due to memory limitations
export CI_RUN=1

# Test examples
for e in "$(pwd)"/examples/*.py; do
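The hunk above exports `CI_RUN=1`, but this diff does not show how the example scripts consume it. The following is a minimal Python sketch of one way an example could check the flag. The helper name `running_in_ci` and the rule that only the value `1` enables skipping are assumptions made for illustration, not taken from the repository.

```python
import os


def running_in_ci(environ=None):
    """Return True when the CI_RUN flag is set to "1".

    `running_in_ci` is a hypothetical helper; the actual examples in the
    repository may read the variable differently.
    """
    environ = os.environ if environ is None else environ
    return environ.get("CI_RUN", "0") == "1"


def main(environ=None):
    if running_in_ci(environ):
        # Skip the memory-heavy workload when the CI scripts export CI_RUN=1.
        return "skipped"
    return "ran"
```

Passing the environment as an argument keeps the check easy to exercise without mutating `os.environ`.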
2 changes: 1 addition & 1 deletion ci/test.sh
@@ -103,7 +103,7 @@ if hasArg "--run-python-tests"; then
conda list
cd ${CUGRAPH_ROOT}/python/cugraph-pyg/cugraph_pyg
# rmat is not tested because of MG testing
pytest --cache-clear --junitxml=${CUGRAPH_ROOT}/junit-cugraph-pytests.xml -v --cov-config=.coveragerc --cov=cugraph_pyg --cov-report=xml:${WORKSPACE}/python/cugraph_pyg/cugraph-coverage.xml --cov-report term --ignore=raft --ignore=tests/mg --ignore=tests/int --ignore=tests/generators --benchmark-disable
pytest -sv -m sg --cache-clear --junitxml=${CUGRAPH_ROOT}/junit-cugraph-pytests.xml -v --cov-config=.coveragerc --cov=cugraph_pyg --cov-report=xml:${WORKSPACE}/python/cugraph_pyg/cugraph-coverage.xml --cov-report term --ignore=raft --benchmark-disable
echo "Ran Python pytest for cugraph_pyg : return code was: $?, test script exit code is now: $EXITCODE"

echo "Python pytest for cugraph-service (single-GPU only)..."
7 changes: 4 additions & 3 deletions ci/test_python.sh
@@ -215,13 +215,14 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then

# Install pyg dependencies (which requires pip)

pip install ogb
pip install \
ogb \
tensordict

pip install \
pyg_lib \
torch_scatter \
torch_sparse \
torch_cluster \
torch_spline_conv \
-f ${PYG_URL}

rapids-print-env
8 changes: 5 additions & 3 deletions ci/test_wheel_cugraph-pyg.sh
@@ -24,6 +24,9 @@ python -m pip install $(ls ./dist/${python_package_name}*.whl)[test]
# RAPIDS_DATASET_ROOT_DIR is used by test scripts
export RAPIDS_DATASET_ROOT_DIR="$(realpath datasets)"

# Used to skip certain examples in CI due to memory limitations
export CI_RUN=1

if [[ "${CUDA_VERSION}" == "11.8.0" ]]; then
PYTORCH_URL="https://download.pytorch.org/whl/cu118"
PYG_URL="https://data.pyg.org/whl/torch-2.1.0+cu118.html"
@@ -39,15 +42,14 @@ rapids-retry python -m pip install \
pyg_lib \
torch_scatter \
torch_sparse \
torch_cluster \
torch_spline_conv \
tensordict \
-f ${PYG_URL}

rapids-logger "pytest cugraph-pyg (single GPU)"
pushd python/cugraph-pyg/cugraph_pyg
python -m pytest \
--cache-clear \
--ignore=tests/mg \
--benchmark-disable \
tests
# Test examples
for e in "$(pwd)"/examples/*.py; do
1 change: 1 addition & 0 deletions conda/recipes/cugraph-pyg/meta.yaml
@@ -34,6 +34,7 @@ requirements:
- cupy >=12.0.0
- cugraph ={{ version }}
- pylibcugraphops ={{ minor_version }}
- tensordict >=0.1.2
- pyg >=2.5,<2.6

tests:
1 change: 1 addition & 0 deletions dependencies.yaml
@@ -565,6 +565,7 @@ dependencies:
- cugraph==24.6.*
- pytorch>=2.0
- pytorch-cuda==11.8
- tensordict>=0.1.2
- pyg>=2.5,<2.6

depends_on_rmm:
33 changes: 31 additions & 2 deletions docs/cugraph/source/api_docs/cugraph-pyg/cugraph_pyg.rst
@@ -6,8 +6,37 @@ cugraph-pyg

.. currentmodule:: cugraph_pyg

Graph Storage
-------------
.. autosummary::
:toctree: ../api/cugraph-pyg/

.. cugraph_pyg.data.cugraph_store.EXPERIMENTAL__CuGraphStore
.. cugraph_pyg.sampler.cugraph_sampler.EXPERIMENTAL__CuGraphSampler
cugraph_pyg.data.dask_graph_store.DaskGraphStore
cugraph_pyg.data.graph_store.GraphStore

Feature Storage
---------------
.. autosummary::
:toctree: ../api/cugraph-pyg/

cugraph_pyg.data.feature_store.TensorDictFeatureStore

Data Loaders
------------
.. autosummary::
:toctree: ../api/cugraph-pyg/

cugraph_pyg.loader.dask_node_loader.DaskNeighborLoader
cugraph_pyg.loader.dask_node_loader.BulkSampleLoader
cugraph_pyg.loader.node_loader.NodeLoader
cugraph_pyg.loader.neighbor_loader.NeighborLoader

Samplers
--------
.. autosummary::
:toctree: ../api/cugraph-pyg/

cugraph_pyg.sampler.sampler.BaseSampler
cugraph_pyg.sampler.sampler.SampleReader
cugraph_pyg.sampler.sampler.HomogeneousSampleReader
cugraph_pyg.sampler.sampler.SampleIterator
1 change: 1 addition & 0 deletions python/cugraph-pyg/conda/cugraph_pyg_dev_cuda-118.yaml
@@ -21,4 +21,5 @@ dependencies:
- pytorch-cuda==11.8
- pytorch>=2.0
- scipy
- tensordict>=0.1.2
name: cugraph_pyg_dev_cuda-118
13 changes: 11 additions & 2 deletions python/cugraph-pyg/cugraph_pyg/data/__init__.py
@@ -1,4 +1,4 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -11,4 +11,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from cugraph_pyg.data.cugraph_store import CuGraphStore
import warnings

from cugraph_pyg.data.dask_graph_store import DaskGraphStore
from cugraph_pyg.data.graph_store import GraphStore
from cugraph_pyg.data.feature_store import TensorDictFeatureStore


def CuGraphStore(*args, **kwargs):
    warnings.warn("CuGraphStore has been renamed to DaskGraphStore", FutureWarning)
    return DaskGraphStore(*args, **kwargs)
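
The wrapper above keeps the old name callable while steering users to the new one. The sketch below shows how such a shim behaves at call time. It uses a stand-in `DaskGraphStore` class and a hypothetical helper `build_with_old_name` so that it runs without cuGraph installed; neither is the library's real implementation.

```python
import warnings


class DaskGraphStore:
    """Stand-in for the real class, used only to make this sketch runnable."""

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


def CuGraphStore(*args, **kwargs):
    # Same shape as the wrapper in the diff: warn, then delegate.
    warnings.warn("CuGraphStore has been renamed to DaskGraphStore", FutureWarning)
    return DaskGraphStore(*args, **kwargs)


def build_with_old_name():
    """Call the old name and return the store plus any warnings it raised."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure the warning is recorded
        store = CuGraphStore("features", order="CSR")
    return store, caught
```

One consequence of this design is that the old name is now a function rather than a class, so code that used it in `isinstance` checks or as a base class would need to switch to `DaskGraphStore`.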
Original file line number Diff line number Diff line change
@@ -199,7 +199,7 @@ def cast(cls, *args, **kwargs):
return cls(*args, **kwargs)


class CuGraphStore:
class DaskGraphStore:
"""
Duck-typed version of PyG's GraphStore and FeatureStore.
"""
@@ -221,7 +221,7 @@ def __init__(
order: str = "CSR",
):
"""
Constructs a new CuGraphStore from the provided
Constructs a new DaskGraphStore from the provided
arguments.
Parameters
129 changes: 129 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/data/feature_store.py
@@ -0,0 +1,129 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

from typing import Optional, Tuple, List

from cugraph.utilities.utils import import_optional, MissingModule

torch = import_optional("torch")
torch_geometric = import_optional("torch_geometric")
tensordict = import_optional("tensordict")


class TensorDictFeatureStore(
    object
    if isinstance(torch_geometric, MissingModule)
    else torch_geometric.data.FeatureStore
):
    """
    A basic implementation of the PyG FeatureStore interface that stores
    feature data in a single TensorDict.  This type of feature store is
    not distributed, so each node will have to load the entire graph's
    features into memory.
    """

    def __init__(self):
        super().__init__()

        self.__features = {}

    def _put_tensor(
        self,
        tensor: "torch_geometric.typing.FeatureTensorType",
        attr: "torch_geometric.data.feature_store.TensorAttr",
    ) -> bool:
        if attr.group_name in self.__features:
            td = self.__features[attr.group_name]
            batch_size = td.batch_size[0]

            if attr.is_set("index"):
                if attr.attr_name in td.keys():
                    if attr.index.shape[0] != batch_size:
                        raise ValueError(
                            "Leading size of index tensor "
                            "does not match existing tensors for group name "
                            f"{attr.group_name}; Expected {batch_size}, "
                            f"got {attr.index.shape[0]}"
                        )
                    td[attr.attr_name][attr.index] = tensor
                    return True
                else:
                    warnings.warn(
                        "Ignoring index parameter "
                        f"(attribute does not exist for group {attr.group_name})"
                    )

            if tensor.shape[0] != batch_size:
                raise ValueError(
                    "Leading size of input tensor does not match "
                    f"existing tensors for group name {attr.group_name};"
                    f" Expected {batch_size}, got {tensor.shape[0]}"
                )
        else:
            batch_size = tensor.shape[0]
            self.__features[attr.group_name] = tensordict.TensorDict(
                {}, batch_size=batch_size
            )

        self.__features[attr.group_name][attr.attr_name] = tensor
        return True

    def _get_tensor(
        self, attr: "torch_geometric.data.feature_store.TensorAttr"
    ) -> Optional["torch_geometric.typing.FeatureTensorType"]:
        if attr.group_name not in self.__features:
            return None

        if attr.attr_name not in self.__features[attr.group_name].keys():
            return None

        tensor = self.__features[attr.group_name][attr.attr_name]
        return (
            tensor
            if (attr.index is None or (not attr.is_set("index")))
            else tensor[attr.index]
        )

    def _remove_tensor(
        self, attr: "torch_geometric.data.feature_store.TensorAttr"
    ) -> bool:
        if attr.group_name not in self.__features:
            return False

        if attr.attr_name not in self.__features[attr.group_name].keys():
            return False

        del self.__features[attr.group_name][attr.attr_name]
        return True

    def _get_tensor_size(
        self, attr: "torch_geometric.data.feature_store.TensorAttr"
    ) -> Tuple:
        return self._get_tensor(attr).size()

    def get_all_tensor_attrs(
        self,
    ) -> List["torch_geometric.data.feature_store.TensorAttr"]:
        attrs = []
        for group_name, td in self.__features.items():
            for attr_name in td.keys():
                attrs.append(
                    torch_geometric.data.feature_store.TensorAttr(
                        group_name,
                        attr_name,
                    )
                )

        return attrs
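
The central rule in `_put_tensor` is that the first tensor stored for a group fixes that group's leading size, and later tensors for the same group must match it. The sketch below mirrors that control flow with plain Python lists so it runs without `torch`, `torch_geometric`, or `tensordict`. `ToyFeatureStore` and its `put`/`get`/`remove` methods are hypothetical names for illustration only; the warning issued when an index is given for a missing attribute is omitted for brevity.

```python
class ToyFeatureStore:
    """List-based stand-in mirroring the put/get/remove checks in the diff."""

    def __init__(self):
        # group_name -> {"size": int, "attrs": {attr_name: list}}
        self._features = {}

    def put(self, values, group_name, attr_name, index=None):
        group = self._features.get(group_name)
        if group is None:
            # The first attribute stored for a group fixes its leading size.
            group = {"size": len(values), "attrs": {}}
            self._features[group_name] = group
        elif index is not None and attr_name in group["attrs"]:
            # Indexed put on an existing attribute: overwrite rows in place.
            if len(index) != group["size"]:
                raise ValueError("index length does not match group size")
            target = group["attrs"][attr_name]
            for i, value in zip(index, values):
                target[i] = value
            return True
        elif len(values) != group["size"]:
            raise ValueError("leading size does not match group size")
        group["attrs"][attr_name] = list(values)
        return True

    def get(self, group_name, attr_name, index=None):
        attrs = self._features.get(group_name, {}).get("attrs", {})
        if attr_name not in attrs:
            return None
        values = attrs[attr_name]
        return values if index is None else [values[i] for i in index]

    def remove(self, group_name, attr_name):
        attrs = self._features.get(group_name, {}).get("attrs", {})
        if attr_name not in attrs:
            return False
        del attrs[attr_name]
        return True
```

As in the diff, a lookup for a missing group or attribute returns `None` rather than raising, and removal reports success with a boolean.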