Allow models to use a lightweight sparse structure #3782
@@ -11,14 +11,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

from cugraph.utilities.utils import import_optional

torch = import_optional("torch")
nn = import_optional("torch.nn")
ops_torch = import_optional("pylibcugraphops.pytorch")


-class BaseConv(nn.Module):
+class BaseConv(torch.nn.Module):
    r"""An abstract base class for cugraph-ops nn module."""

    def __init__(self):
@@ -48,3 +49,85 @@ def pad_offsets(self, offsets: torch.Tensor, size: int) -> torch.Tensor:
        self._cached_offsets_fg[offsets.numel() : size] = offsets[-1]

        return self._cached_offsets_fg[:size]


class SparseGraph(object):
    r"""A god-class to store different sparse formats needed by cugraph-ops
    and facilitate sparse format conversions.

    Parameters
    ----------
    size: tuple of int
        Size of the adjacency matrix: (num_src_nodes, num_dst_nodes).

    src_ids: torch.Tensor
        Source indices of the edges.

    dst_ids: torch.Tensor, optional
        Destination indices of the edges.

    csrc_ids: torch.Tensor, optional
        Compressed source indices. It is a monotonically increasing array of
        size (num_src_nodes + 1,). For the k-th source node, its neighborhood
        consists of the destinations between `dst_indices[csrc_indices[k]]` and
        `dst_indices[csrc_indices[k+1]]`.
Comment on lines +80 to +84:

I have a question regarding when the compressed format is given. Let's look at the case below:

cdst_ids (compressed destination indices): 0, 2, 5, 7
src_indices: 1, 2, 2, 3, 4, 4, 5

I believe the following will work (please correct me):
num_src_nodes = 6
num_dst_nodes = 3

And I guess the following will fail (please correct me):
num_src_nodes = 6
num_dst_nodes = 5  # modified to a higher value to ensure alignment for output nodes that are missing

Question: ...

Reply: It should be illegal when the size of `cdst_ids` does not match `num_dst_nodes + 1`. In your example with num_src_nodes = 6, num_dst_nodes = 3, this translates to a COO of dst_ids = [0, 0, 1, 1, 1, 2, 2] paired with src_ids = [1, 2, 2, 3, 4, 4, 5]. With num_src_nodes = 6, num_dst_nodes = 5, the constructor should have failed, unless `cdst_ids` had size num_dst_nodes + 1 = 6.

Reply: Thanks. Yup, this is what I was expecting. We will just make sure that the changes from @seunghwak in cugraph sampling ensure that all the MFGs line up.
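To make the exchange above concrete, here is a minimal sketch (not part of the PR; plain PyTorch only, using the private torch._convert_indices_from_coo_to_csr helper that the diff itself relies on) that expands the compressed destination indices from the example back into COO form and round-trips them:

# Editorial sketch, not PR code: verify the CSC <-> COO correspondence above.
import torch

cdst_ids = torch.tensor([0, 2, 5, 7])          # compressed destination indices, size num_dst_nodes + 1
src_ids = torch.tensor([1, 2, 2, 3, 4, 4, 5])  # source index of every edge

num_dst_nodes = cdst_ids.numel() - 1           # 3

# Expand CSC back to COO: destination k owns the edges cdst_ids[k] .. cdst_ids[k+1].
dst_ids = torch.repeat_interleave(torch.arange(num_dst_nodes), cdst_ids.diff())
print(dst_ids)  # tensor([0, 0, 1, 1, 1, 2, 2])

# Round-trip: compressing the sorted COO destinations recovers cdst_ids.
print(torch._convert_indices_from_coo_to_csr(dst_ids, num_dst_nodes))  # tensor([0, 2, 5, 7])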
    cdst_ids: torch.Tensor, optional
        Compressed destination indices. It is a monotonically increasing array of
        size (num_dst_nodes + 1,). For the k-th destination node, its neighborhood
        consists of the sources between `src_indices[cdst_indices[k]]` and
        `src_indices[cdst_indices[k+1]]`.

    dst_ids_is_sorted: bool
        Whether `dst_ids` has been sorted in an ascending order.

    Notes
    -----
    COO-format requires `src_ids` and `dst_ids`.
    CSC-format requires `cdst_ids` and `src_ids`.
    CSR-format requires `csrc_ids` and `dst_ids`.
Review comment: I think we should force the user to provide the format requirement to prevent confusion, e.g. add a format variable. Then we can raise errors according to the input the user provided. Also, I don't like ...
    For MFGs (sampled graphs), the node ids must have been renumbered.
    """

    def __init__(
        self,
        size: Tuple[int, int],
        src_ids: torch.Tensor,
        dst_ids: Optional[torch.Tensor] = None,
        csrc_ids: Optional[torch.Tensor] = None,
        cdst_ids: Optional[torch.Tensor] = None,
        dst_ids_is_sorted: bool = False,
    ):
        if dst_ids is None and cdst_ids is None:
            raise ValueError("One of 'dst_ids' and 'cdst_ids' must be given.")

        if src_ids is not None:
            src_ids = src_ids.contiguous()
        if dst_ids is not None:
            dst_ids = dst_ids.contiguous()
        if csrc_ids is not None:
            csrc_ids = csrc_ids.contiguous()
        if cdst_ids is not None:
            cdst_ids = cdst_ids.contiguous()

        self._src_ids = src_ids
        self._dst_ids = dst_ids
        self._csrc_ids = csrc_ids
        self._cdst_ids = cdst_ids
        self.num_src_nodes, self.num_dst_nodes = size

        # Force create CSC format.
        if self._cdst_ids is None:
            if not dst_ids_is_sorted:
                self._dst_ids, self._perm = torch.sort(self._dst_ids)
                self._src_ids = self._src_ids[self._perm]
            self._cdst_ids = torch._convert_indices_from_coo_to_csr(
                self._dst_ids,
                self.num_dst_nodes,
                out_int32=self._dst_ids.dtype == torch.int32,
            )
Review comment thread on the CSC conversion:

Do you think we should remove ...?

Reply: I think if we are forcing ..., we can probably borrow the convention from ... We won't follow their default of ... See the formats docs.

Reply: As we discussed via Slack, we will provide ...
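For illustration, here is a rough sketch of the kind of explicit format check the reviewers are discussing (hypothetical helper and argument names, not code from this PR), with the per-format requirements taken from the docstring Notes above:

# Editorial sketch, not PR code: a hypothetical explicit-format check.
from typing import Optional, Tuple
import torch

def check_format_inputs(
    formats: Tuple[str, ...],
    src_ids: Optional[torch.Tensor] = None,
    dst_ids: Optional[torch.Tensor] = None,
    csrc_ids: Optional[torch.Tensor] = None,
    cdst_ids: Optional[torch.Tensor] = None,
) -> None:
    """Raise if the tensors required by each requested format are missing."""
    required = {
        "coo": ("src_ids", "dst_ids"),
        "csc": ("cdst_ids", "src_ids"),
        "csr": ("csrc_ids", "dst_ids"),
    }
    given = {
        "src_ids": src_ids,
        "dst_ids": dst_ids,
        "csrc_ids": csrc_ids,
        "cdst_ids": cdst_ids,
    }
    for fmt in formats:
        missing = [name for name in required[fmt] if given[name] is None]
        if missing:
            raise ValueError(f"Format '{fmt}' requires {missing} to be provided.")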
    def csc(self) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Return CSC format."""
        return (self._cdst_ids, self._src_ids)
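For orientation, a usage sketch of the class as it appears in this diff. The import path (the class sits next to BaseConv, per the __init__.py link in the comment below) and the use of CUDA tensors are assumptions on my part, not spelled out in the diff:

# Editorial usage sketch; import path and device are assumptions, not PR facts.
import torch
from cugraph_dgl.nn.conv.base import SparseGraph  # assumed location alongside BaseConv

# A small MFG-style graph with 6 renumbered source nodes and 3 destination nodes.
src_ids = torch.tensor([2, 1, 4, 2, 3, 5, 4], device="cuda")
dst_ids = torch.tensor([0, 0, 1, 1, 1, 2, 2], device="cuda")

g = SparseGraph(size=(6, 3), src_ids=src_ids, dst_ids=dst_ids)

# The constructor eagerly builds the CSC representation needed by cugraph-ops.
cdst_ids, csc_src_ids = g.csc()
print(cdst_ids)     # tensor([0, 2, 5, 7], device='cuda:0')
print(csc_src_ids)  # the same edges' sources, grouped contiguously per destination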
Review comment: We should expose this in https://github.com/rapidsai/cugraph/blob/7102fe0a657946f3262894180acd10ef7587b121/python/cugraph-dgl/cugraph_dgl/nn/conv/__init__.py

Reply: This is not a user-facing class. It is only used to handle the case where we fall back to the full-graph variant. In addition, with the recent cugraph-ops refactoring disabling the MFG variant, we might remove this class entirely.

Reply: Oh, I am sorry, I left the review at the wrong line. I meant the SparseGraph class.

Reply: Resolved.