add lightweight sparse tensor wrapper
tingyu66 committed Aug 11, 2023
1 parent c55151c commit f47c163
Showing 2 changed files with 76 additions and 20 deletions.
50 changes: 49 additions & 1 deletion python/cugraph-dgl/cugraph_dgl/nn/conv/base.py
@@ -11,9 +11,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

from cugraph.utilities.utils import import_optional

from pyg_lib.ops import index_sort

# torch = import_optional("torch")
import torch

nn = import_optional("torch.nn")
ops_torch = import_optional("pylibcugraphops.pytorch")

@@ -48,3 +54,45 @@ def pad_offsets(self, offsets: torch.Tensor, size: int) -> torch.Tensor:
        self._cached_offsets_fg[offsets.numel() : size] = offsets[-1]

        return self._cached_offsets_fg[:size]


class SparseGraph(object):
    r"""A thin wrapper to facilitate sparse format conversion.

    Parameters
    ----------
    src_ids: torch.Tensor
        Tensor of source indices.
    dst_ids: torch.Tensor
        Tensor of destination indices.
    shape: Tuple[int, int]
        Shape of the adjacency matrix.
    is_sorted: bool
        Whether `dst_ids` has been sorted in ascending order.
    """

    def __init__(
        self,
        src_ids: torch.Tensor,
        dst_ids: torch.Tensor,
        shape: Tuple[int, int],
        is_sorted: bool = False,
    ):
        self.row = src_ids.contiguous()
        self.col = dst_ids.contiguous()
        self.num_src_nodes, self.num_dst_nodes = shape
        self.rowptr = None
        self.colptr = None
        self.perm = None

        if not is_sorted:
            self.col, self.perm = index_sort(self.col, max_value=self.num_dst_nodes)
            self.row = self.row[self.perm]

    def to_csc(self) -> Tuple[torch.Tensor, torch.Tensor, int]:
        if self.colptr is None:
            self.colptr = torch._convert_indices_from_coo_to_csr(
                self.col, self.num_dst_nodes, out_int32=self.col.dtype == torch.int32
            )

        return (self.row, self.colptr, self.num_src_nodes)
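
For context, a minimal usage sketch of the new SparseGraph wrapper (not part of the commit): it builds a toy COO graph and converts it to CSC. The toy edge tensors and the printed pointer values are illustrative, and torch plus pyg_lib are assumed to be installed.

import torch

from cugraph_dgl.nn.conv.base import SparseGraph

# A toy bipartite graph with 4 source and 3 destination nodes, given as
# unsorted COO edges (src -> dst); SparseGraph sorts by destination on init.
src = torch.tensor([3, 0, 1, 2, 2], dtype=torch.int32)
dst = torch.tensor([1, 0, 2, 0, 1], dtype=torch.int32)

sg = SparseGraph(src_ids=src, dst_ids=dst, shape=(4, 3))

# to_csc() lazily builds and caches the column pointers; it returns
# (row indices, column pointers, number of source nodes).
indices, colptr, num_src_nodes = sg.to_csc()
print(colptr)  # tensor([0, 2, 4, 5], dtype=torch.int32) for the toy edges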
46 changes: 27 additions & 19 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/sageconv.py
@@ -14,9 +14,9 @@
cugraph-ops"""
# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments
from __future__ import annotations
from typing import Optional, Union

from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph
from cugraph.utilities.utils import import_optional

dgl = import_optional("dgl")
@@ -98,7 +98,7 @@ def reset_parameters(self):

    def forward(
        self,
        g: Union[SparseGraph, dgl.DGLHeteroGraph],
        feat: torch.Tensor,
        max_in_degree: Optional[int] = None,
    ) -> torch.Tensor:
@@ -122,26 +122,34 @@ def forward(
        torch.Tensor
            Output node features. Shape: :math:`(|V|, D_{out})`.
        """
        if isinstance(g, dgl.DGLHeteroGraph):
            offsets, indices, _ = g.adj_tensors("csc")

            if g.is_block:
                if max_in_degree is None:
                    max_in_degree = g.in_degrees().max().item()

                if max_in_degree < self.MAX_IN_DEGREE_MFG:
                    _graph = ops_torch.SampledCSC(
                        offsets, indices, max_in_degree, g.num_src_nodes()
                    )
                else:
                    offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
                    _graph = ops_torch.StaticCSC(offsets_fg, indices)
            else:
                _graph = ops_torch.StaticCSC(offsets, indices)

            num_dst_nodes = g.num_dst_nodes()
        else:
            assert isinstance(g, SparseGraph)
            indices, offsets, num_src_nodes = g.to_csc()
            _graph = ops_torch.CSC(
                offsets=offsets, indices=indices, num_src_nodes=num_src_nodes
            )
            num_dst_nodes = g.num_dst_nodes

        feat = self.feat_drop(feat)
        h = ops_torch.operators.agg_concat_n2n(feat, _graph, self.aggr)[:num_dst_nodes]
        h = self.linear(h)

        return h
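
To show the new code path end to end, here is a hedged sketch of calling the updated forward() with a SparseGraph instead of a DGL block. It assumes a CUDA-capable environment with pylibcugraphops installed; the edges, feature sizes, and constructor arguments below are illustrative only.

import torch

from cugraph_dgl.nn import SAGEConv
from cugraph_dgl.nn.conv.base import SparseGraph

device = "cuda"
# 4 source nodes, 3 destination nodes, int32 indices as expected by cugraph-ops.
src = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device=device)
dst = torch.tensor([1, 0, 1, 2], dtype=torch.int32, device=device)
sg = SparseGraph(src_ids=src, dst_ids=dst, shape=(4, 3))

conv = SAGEConv(in_feats=8, out_feats=4).to(device)
feat = torch.randn(4, 8, device=device)  # one feature row per source node
out = conv(sg, feat)  # shape (3, 4): one row per destination node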
