
Allow models to use a lightweight sparse structure #3782

Merged 7 commits on Aug 17, 2023
3 changes: 3 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/__init__.py
@@ -10,12 +10,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import SparseGraph
from .gatconv import GATConv
from .relgraphconv import RelGraphConv
from .sageconv import SAGEConv
from .transformerconv import TransformerConv

__all__ = [
"SparseGraph",
"GATConv",
"RelGraphConv",
"SAGEConv",
172 changes: 170 additions & 2 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/base.py
@@ -11,14 +11,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

from cugraph.utilities.utils import import_optional

torch = import_optional("torch")
nn = import_optional("torch.nn")
ops_torch = import_optional("pylibcugraphops.pytorch")


class BaseConv(nn.Module):
class BaseConv(torch.nn.Module):
Member Author: This is not a user-facing class. It is only used to handle the case where we fall back to the full-graph variant. In addition, with the recent cugraph-ops refactoring disabling the MFG variant, we might remove this class entirely.

Member: Oh, I am sorry, I left the review at the wrong line. I meant the SparseGraph class.

Member: Resolved

r"""An abstract base class for cugraph-ops nn module."""

def __init__(self):
@@ -48,3 +49,170 @@ def pad_offsets(self, offsets: torch.Tensor, size: int) -> torch.Tensor:
self._cached_offsets_fg[offsets.numel() : size] = offsets[-1]

return self._cached_offsets_fg[:size]


def compress_ids(ids: torch.Tensor, size: int) -> torch.Tensor:
return torch._convert_indices_from_coo_to_csr(
ids, size, out_int32=ids.dtype == torch.int32
)


def decompress_ids(c_ids: torch.Tensor) -> torch.Tensor:
ids = torch.arange(c_ids.numel() - 1, dtype=c_ids.dtype, device=c_ids.device)
return ids.repeat_interleave(c_ids[1:] - c_ids[:-1])


class SparseGraph(object):
r"""A god-class to store different sparse formats needed by cugraph-ops
and facilitate sparse format conversions.

Parameters
----------
size: tuple of int
Size of the adjacency matrix: (num_src_nodes, num_dst_nodes).

src_ids: torch.Tensor
Source indices of the edges.

dst_ids: torch.Tensor, optional
Destination indices of the edges.

csrc_ids: torch.Tensor, optional
Compressed source indices. It is a monotonically increasing array of
size (num_src_nodes + 1,). For the k-th source node, its neighborhood
consists of the destinations between `dst_ids[csrc_ids[k]]` and
`dst_ids[csrc_ids[k+1]]`.
Comment on lines +80 to +84
Member: I have a question regarding the case where num_dst_nodes > len(cdst_ids) - 1.

Let's look at the case below:

cdst_ids (compressed destination indices): 0, 2, 5, 7
src_indices: 1, 2, 2, 3, 4, 4, 5

I believe the following will work (please correct me):

num_src_nodes = 6
num_dst_nodes = 3

And I guess the following will fail (please correct me):

num_src_nodes = 6
num_dst_nodes = 5  # modified to a higher value to ensure alignment for output nodes that are missing

Question: so this will have to be handled by ensuring correct creation, because we want to handle the alignment problem between blocks.

Member Author: It should be illegal when num_dst_nodes != len(cdst_ids) - 1. I will improve the error handling in this case. For example, PyG does lots of assertions to check sizes; we should throw proper exceptions.

cdst_ids (compressed destination indices): 0, 2, 5, 7
src_indices: 1, 2, 2, 3, 4, 4, 5

In your example with num_src_nodes = 6 and num_dst_nodes = 3, this translates to a COO of

(1, 2, 2, 3, 4, 4, 5)
(0, 0, 1, 1, 1, 2, 2)

With num_src_nodes = 6 and num_dst_nodes = 5, the constructor should fail, unless cdst_ids is augmented (cdst_ids = 0, 2, 5, 7, 7, 7).

Member: Thanks. Yup, this is what I was expecting. We will just make sure that the changes from @seunghwak on cugraph sampling ensure that all the MFGs line up.

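Following the decompress_ids logic above, a quick check of the augmented-offsets case from this thread (illustrative):

import torch

# Augmented so that len(cdst_ids) - 1 == num_dst_nodes == 5.
cdst_ids = torch.tensor([0, 2, 5, 7, 7, 7])
ids = torch.arange(cdst_ids.numel() - 1, dtype=cdst_ids.dtype)
dst_ids = ids.repeat_interleave(cdst_ids[1:] - cdst_ids[:-1])
# dst_ids == tensor([0, 0, 1, 1, 1, 2, 2]); destinations 3 and 4 have no incoming edges.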

cdst_ids: torch.Tensor, optional
Compressed destination indices. It is a monotonically increasing array of
size (num_dst_nodes + 1,). For the k-th destination node, its neighborhood
consists of the sources between `src_ids[cdst_ids[k]]` and
`src_ids[cdst_ids[k+1]]`.

dst_ids_is_sorted: bool
Whether `dst_ids` has been sorted in ascending order. When sorted,
creating CSC layout is much faster.

formats: str or tuple of str, optional
The desired sparse formats to create for the graph.

reduce_memory: bool, optional
When set, tensors that are not required by the desired formats will be
set to `None`.

Notes
-----
For MFGs (sampled graphs), the node ids must have been renumbered.
"""

supported_formats = {"coo": ("src_ids", "dst_ids"), "csc": ("cdst_ids", "src_ids")}

all_tensors = set(["src_ids", "dst_ids", "csrc_ids", "cdst_ids"])

def __init__(
self,
size: Tuple[int, int],
src_ids: torch.Tensor,
dst_ids: Optional[torch.Tensor] = None,
csrc_ids: Optional[torch.Tensor] = None,
cdst_ids: Optional[torch.Tensor] = None,
dst_ids_is_sorted: bool = False,
formats: Optional[Union[str, Tuple[str]]] = None,
reduce_memory: bool = True,
):
self._num_src_nodes, self._num_dst_nodes = size
self._dst_ids_is_sorted = dst_ids_is_sorted

if dst_ids is None and cdst_ids is None:
raise ValueError("One of 'dst_ids' and 'cdst_ids' must be given.")

if src_ids is not None:
src_ids = src_ids.contiguous()

if dst_ids is not None:
dst_ids = dst_ids.contiguous()

if csrc_ids is not None:
if csrc_ids.numel() != self._num_src_nodes + 1:
raise RuntimeError(
f"Size mismatch for 'csrc_ids': expected ({size[0]+1},), "
f"but got {tuple(csrc_ids.size())}"
)
csrc_ids = csrc_ids.contiguous()

if cdst_ids is not None:
if cdst_ids.numel() != self._num_dst_nodes + 1:
raise RuntimeError(
f"Size mismatch for 'cdst_ids': expected ({size[1]+1},), "
f"but got {tuple(cdst_ids.size())}"
)
cdst_ids = cdst_ids.contiguous()

self._src_ids = src_ids
self._dst_ids = dst_ids
self._csrc_ids = csrc_ids
self._cdst_ids = cdst_ids
self._perm = None

if isinstance(formats, str):
formats = (formats,)

if formats is not None:
for format_ in formats:
assert format_ in SparseGraph.supported_formats
self.__getattribute__(f"_create_{format_}")()
self._formats = formats

self._reduce_memory = reduce_memory
if reduce_memory:
self.reduce_memory()

def reduce_memory(self):
"""Remove the tensors that are not necessary to create the desired sparse
formats to reduce memory footprint."""

self._perm = None
if self._formats is None:
return

tensors_needed = []
for f in self._formats:
tensors_needed += SparseGraph.supported_formats[f]
for t in SparseGraph.all_tensors.difference(set(tensors_needed)):
self.__dict__[t] = None

def _create_coo(self):
if self._dst_ids is None:
self._dst_ids = decompress_ids(self._cdst_ids)

def _create_csc(self):
if self._cdst_ids is None:
if not self._dst_ids_is_sorted:
self._dst_ids, self._perm = torch.sort(self._dst_ids)
self._src_ids = self._src_ids[self._perm]
self._cdst_ids = compress_ids(self._dst_ids, self._num_dst_nodes)

def num_src_nodes(self):
return self._num_src_nodes

def num_dst_nodes(self):
return self._num_dst_nodes

def formats(self):
return self._formats

def coo(self) -> Tuple[torch.Tensor, torch.Tensor]:
if "coo" not in self.formats():
raise RuntimeError(
"The SparseGraph did not create a COO layout. "
"Set 'formats' to include 'coo' when creating the graph."
)
return (self._src_ids, self._dst_ids)

Member: Do you think we should remove dst_ids if we are forcing CSC conversion? Otherwise we pay the memory overhead of always maintaining it.

Member: We are forcing CSC conversion for now, but in the future we may want to expand to other formats, so we might want to make this configurable via a class variable. We can probably borrow the convention from DGL's formats(); we won't follow their default of 'coo' -> 'csr' -> 'csc', but have our own version. See the formats docs.

Member Author: As we discussed via Slack, we will provide input_format and output_format to help specify which tensor is needed.

def csc(self) -> Tuple[torch.Tensor, torch.Tensor]:
if "csc" not in self.formats():
raise RuntimeError(
"The SparseGraph did not create a CSC layout. "
"Set 'formats' to include 'csc' when creating the graph."
)
return (self._cdst_ids, self._src_ids)
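A minimal usage sketch of the new class (assumes a CUDA device; the edge list mirrors the test fixture further below):

import torch
from cugraph_dgl.nn.conv.base import SparseGraph

src_ids = torch.IntTensor([0, 1, 2, 3, 2, 5]).cuda()
dst_ids = torch.IntTensor([1, 2, 3, 4, 0, 3]).cuda()

# Build from COO and request a CSC layout; with the default
# reduce_memory=True, tensors not needed for CSC are set to None.
sg = SparseGraph(size=(6, 5), src_ids=src_ids, dst_ids=dst_ids, formats="csc")
offsets, indices = sg.csc()  # offsets: (num_dst_nodes + 1,); indices: source ids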
55 changes: 32 additions & 23 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/sageconv.py
@@ -14,9 +14,9 @@
cugraph-ops"""
# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments
from __future__ import annotations
from typing import Optional
from typing import Optional, Union

from cugraph_dgl.nn.conv.base import BaseConv
from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph
from cugraph.utilities.utils import import_optional

dgl = import_optional("dgl")
@@ -98,45 +98,54 @@ def reset_parameters(self):

def forward(
self,
g: dgl.DGLHeteroGraph,
g: Union[SparseGraph, dgl.DGLHeteroGraph],
feat: torch.Tensor,
max_in_degree: Optional[int] = None,
) -> torch.Tensor:
r"""Forward computation.

Parameters
----------
g : DGLGraph
g : DGLGraph or SparseGraph
The graph.
feat : torch.Tensor
Node features. Shape: :math:`(|V|, D_{in})`.
max_in_degree : int
-Maximum in-degree of destination nodes. It is only effective when
-:attr:`g` is a :class:`DGLBlock`, i.e., bipartite graph. When
-:attr:`g` is generated from a neighbor sampler, the value should be
-set to the corresponding :attr:`fanout`. If not given,
-:attr:`max_in_degree` will be calculated on-the-fly.
+Maximum in-degree of destination nodes. When :attr:`g` is generated
+from a neighbor sampler, the value should be set to the corresponding
+:attr:`fanout`. This option is used to invoke the MFG variant of the
+cugraph-ops kernel.

Returns
-------
torch.Tensor
Output node features. Shape: :math:`(|V|, D_{out})`.
"""
-        offsets, indices, _ = g.adj_tensors("csc")
-
-        if g.is_block:
-            if max_in_degree is None:
-                max_in_degree = g.in_degrees().max().item()
-
-            if max_in_degree < self.MAX_IN_DEGREE_MFG:
-                _graph = ops_torch.SampledCSC(
-                    offsets, indices, max_in_degree, g.num_src_nodes()
-                )
-            else:
-                offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
-                _graph = ops_torch.StaticCSC(offsets_fg, indices)
+        if max_in_degree is None:
+            max_in_degree = -1
+
+        if isinstance(g, SparseGraph):
+            assert "csc" in g.formats()
+            offsets, indices = g.csc()
+            _graph = ops_torch.CSC(
+                offsets=offsets,
+                indices=indices,
+                num_src_nodes=g.num_src_nodes(),
+                dst_max_in_degree=max_in_degree,
+            )
+        elif isinstance(g, dgl.DGLHeteroGraph):
+            offsets, indices, _ = g.adj_tensors("csc")
+            _graph = ops_torch.CSC(
+                offsets=offsets,
+                indices=indices,
+                num_src_nodes=g.num_src_nodes(),
+                dst_max_in_degree=max_in_degree,
+            )
         else:
-            _graph = ops_torch.StaticCSC(offsets, indices)
+            raise TypeError(
+                f"The graph has to be either a 'SparseGraph' or "
+                f"'dgl.DGLHeteroGraph', but got '{type(g)}'."
+            )

feat = self.feat_drop(feat)
h = ops_torch.operators.agg_concat_n2n(feat, _graph, self.aggr)[
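With this change, the same module accepts either graph type. A sketch of a forward pass over the SparseGraph built earlier (feature sizes are hypothetical; the aggregator_type keyword matches the test below):

import torch
from cugraph_dgl.nn import SAGEConv as CuGraphSAGEConv

in_feat, out_feat = 5, 2
feat = torch.rand(sg.num_src_nodes(), in_feat).cuda()
conv = CuGraphSAGEConv(in_feat, out_feat, aggregator_type="mean").cuda()

# When max_in_degree is omitted, forward() maps None to -1 ("unknown").
out = conv(sg, feat, max_in_degree=8)  # output features for destination nodes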
24 changes: 24 additions & 0 deletions python/cugraph-dgl/tests/conftest.py
@@ -13,6 +13,8 @@

import pytest

import torch

from cugraph.testing.mg_utils import (
start_dask_client,
stop_dask_client,
@@ -31,3 +33,25 @@ def dask_client():
yield dask_client

stop_dask_client(dask_client, dask_cluster)


class SparseGraphData1:
size = (6, 5)
nnz = 6
src_ids = torch.IntTensor([0, 1, 2, 3, 2, 5]).cuda()
dst_ids = torch.IntTensor([1, 2, 3, 4, 0, 3]).cuda()

# CSR
src_ids_sorted_by_src = torch.IntTensor([0, 1, 2, 2, 3, 5]).cuda()
dst_ids_sorted_by_src = torch.IntTensor([1, 2, 0, 3, 4, 3]).cuda()
csrc_ids = torch.IntTensor([0, 1, 2, 4, 5, 5, 6]).cuda()

# CSC
src_ids_sorted_by_dst = torch.IntTensor([2, 0, 1, 5, 2, 3]).cuda()
dst_ids_sorted_by_dst = torch.IntTensor([0, 1, 2, 3, 3, 4]).cuda()
cdst_ids = torch.IntTensor([0, 1, 2, 3, 5, 6]).cuda()


@pytest.fixture
def sparse_graph_1():
return SparseGraphData1()
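As a quick sanity check, the compressed arrays in this fixture are consistent with the COO ids; a sketch (illustrative):

import torch

src_ids = torch.tensor([0, 1, 2, 3, 2, 5])
csrc_ids = torch.tensor([0, 1, 2, 4, 5, 5, 6])

# Each source node's edge count equals the difference of consecutive offsets.
counts = torch.bincount(src_ids, minlength=6)
assert torch.equal(counts, csrc_ids[1:] - csrc_ids[:-1])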
27 changes: 19 additions & 8 deletions python/cugraph-dgl/tests/nn/test_sageconv.py
@@ -14,12 +14,9 @@

import pytest

-try:
-    import cugraph_dgl
-except ModuleNotFoundError:
-    pytest.skip("cugraph_dgl not available", allow_module_level=True)
-
 from cugraph.utilities.utils import import_optional
+from cugraph_dgl.nn.conv.base import SparseGraph
+from cugraph_dgl.nn import SAGEConv as CuGraphSAGEConv
 from .common import create_graph1

torch = import_optional("torch")
@@ -30,20 +27,31 @@
@pytest.mark.parametrize("idtype_int", [False, True])
@pytest.mark.parametrize("max_in_degree", [None, 8])
@pytest.mark.parametrize("to_block", [False, True])
-def test_SAGEConv_equality(bias, idtype_int, max_in_degree, to_block):
+@pytest.mark.parametrize("sparse_format", ["coo", "csc", None])
+def test_SAGEConv_equality(bias, idtype_int, max_in_degree, to_block, sparse_format):
     SAGEConv = dgl.nn.SAGEConv
-    CuGraphSAGEConv = cugraph_dgl.nn.SAGEConv
device = "cuda"

in_feat, out_feat = 5, 2
kwargs = {"aggregator_type": "mean", "bias": bias}
g = create_graph1().to(device)

if idtype_int:
g = g.int()
if to_block:
g = dgl.to_block(g)

size = (g.num_src_nodes(), g.num_dst_nodes())
feat = torch.rand(g.num_src_nodes(), in_feat).to(device)

if sparse_format == "coo":
sg = SparseGraph(
size=size, src_ids=g.edges()[0], dst_ids=g.edges()[1], formats="csc"
)
elif sparse_format == "csc":
offsets, indices, _ = g.adj_tensors("csc")
sg = SparseGraph(size=size, src_ids=indices, cdst_ids=offsets, formats="csc")

torch.manual_seed(0)
conv1 = SAGEConv(in_feat, out_feat, **kwargs).to(device)

@@ -57,7 +65,10 @@ def test_SAGEConv_equality(bias, idtype_int, max_in_degree, to_block):
conv2.linear.bias.data[:] = conv1.fc_self.bias.data

     out1 = conv1(g, feat)
-    out2 = conv2(g, feat, max_in_degree=max_in_degree)
+    if sparse_format is not None:
+        out2 = conv2(sg, feat, max_in_degree=max_in_degree)
+    else:
+        out2 = conv2(g, feat, max_in_degree=max_in_degree)
assert torch.allclose(out1, out2, atol=1e-06)

grad_out = torch.rand_like(out1)