Allow models to use a lightweight sparse structure #3782
@@ -11,14 +11,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 from typing import Optional, Tuple, Union

 from cugraph.utilities.utils import import_optional

 torch = import_optional("torch")
 nn = import_optional("torch.nn")
 ops_torch = import_optional("pylibcugraphops.pytorch")


-class BaseConv(nn.Module):
+class BaseConv(torch.nn.Module):
     r"""An abstract base class for cugraph-ops nn module."""

     def __init__(self):
@@ -48,3 +49,170 @@ def pad_offsets(self, offsets: torch.Tensor, size: int) -> torch.Tensor:
        self._cached_offsets_fg[offsets.numel() : size] = offsets[-1]

        return self._cached_offsets_fg[:size]


def compress_ids(ids: torch.Tensor, size: int) -> torch.Tensor:
    return torch._convert_indices_from_coo_to_csr(
        ids, size, out_int32=ids.dtype == torch.int32
    )


def decompress_ids(c_ids: torch.Tensor) -> torch.Tensor:
    ids = torch.arange(c_ids.numel() - 1, dtype=c_ids.dtype, device=c_ids.device)
    return ids.repeat_interleave(c_ids[1:] - c_ids[:-1])
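For reference, the two helpers above form an exact round trip on sorted ids: `compress_ids` turns them into CSR-style offsets and `decompress_ids` inverts that. A minimal sketch, reimplementing the compression with public torch ops (`bincount` + `cumsum`) since `torch._convert_indices_from_coo_to_csr` is a private API; the `*_sketch` names are mine:

```python
import torch


def compress_ids_sketch(ids: torch.Tensor, size: int) -> torch.Tensor:
    # Public-API equivalent of torch._convert_indices_from_coo_to_csr:
    # count occurrences of each id, then prefix-sum into offsets.
    counts = torch.bincount(ids, minlength=size)
    return torch.cat(
        [torch.zeros(1, dtype=ids.dtype), counts.cumsum(0).to(ids.dtype)]
    )


def decompress_ids_sketch(c_ids: torch.Tensor) -> torch.Tensor:
    # Same body as decompress_ids in the diff.
    ids = torch.arange(c_ids.numel() - 1, dtype=c_ids.dtype, device=c_ids.device)
    return ids.repeat_interleave(c_ids[1:] - c_ids[:-1])


sorted_dst = torch.tensor([0, 0, 1, 1, 1, 2, 2])
offsets = compress_ids_sketch(sorted_dst, 3)  # tensor([0, 2, 5, 7])
roundtrip = decompress_ids_sketch(offsets)
assert torch.equal(roundtrip, sorted_dst)
```

Note that the round trip only holds when the input ids are sorted, which is why `_create_csc` below sorts `dst_ids` first.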
class SparseGraph(object):
    r"""A god-class to store different sparse formats needed by cugraph-ops
    and facilitate sparse format conversions.

    Parameters
    ----------
    size: tuple of int
        Size of the adjacency matrix: (num_src_nodes, num_dst_nodes).

    src_ids: torch.Tensor
        Source indices of the edges.

    dst_ids: torch.Tensor, optional
        Destination indices of the edges.

    csrc_ids: torch.Tensor, optional
        Compressed source indices. It is a monotonically increasing array of
        size (num_src_nodes + 1,). For the k-th source node, its neighborhood
        consists of the destinations between `dst_indices[csrc_indices[k]]` and
        `dst_indices[csrc_indices[k+1]]`.
Review thread on lines +80 to +84:

Reviewer: I have a question regarding when … Let's look at the case below:

    cdst_ids (compressed destination indices): 0, 2, 5, 7
    src_indices: 1, 2, 2, 3, 4, 4, 5

I believe the following will work (please correct me):

    num_src_nodes = 6
    num_dst_nodes = 3

And I guess the following will fail (please correct me):

    num_src_nodes = 6
    num_dst_nodes = 5  # modified to a higher value to ensure alignment for output nodes that are missing

Question: …

Author: It should be illegal when … In your example with num_src_nodes = 6, num_dst_nodes = 3, this translates to a COO of … With num_src_nodes = 6, num_dst_nodes = 5, the constructor should have failed, unless …

Reviewer: Thanks. Yup, this is what I was expecting. We will just make sure that the changes from @seunghwak's cugraph sampling ensure that all the MFGs line up.
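The scenario in this thread can be checked directly. With `cdst_ids = [0, 2, 5, 7]` there are exactly `cdst_ids.numel() - 1 = 3` destination nodes, so `num_dst_nodes = 3` decompresses cleanly, while `num_dst_nodes = 5` fails the size check the constructor performs (`cdst_ids.numel() != num_dst_nodes + 1`). A small sketch reusing `decompress_ids` from the diff:

```python
import torch


def decompress_ids(c_ids: torch.Tensor) -> torch.Tensor:
    # Same body as decompress_ids in the diff.
    ids = torch.arange(c_ids.numel() - 1, dtype=c_ids.dtype, device=c_ids.device)
    return ids.repeat_interleave(c_ids[1:] - c_ids[:-1])


cdst_ids = torch.tensor([0, 2, 5, 7])
src_ids = torch.tensor([1, 2, 2, 3, 4, 4, 5])

# num_dst_nodes = 3 works: the decompressed COO pairs are
# (src, dst) = (1,0) (2,0) (2,1) (3,1) (4,1) (4,2) (5,2)
dst_ids = decompress_ids(cdst_ids)
assert torch.equal(dst_ids, torch.tensor([0, 0, 1, 1, 1, 2, 2]))

# num_dst_nodes = 5 is rejected: the constructor requires
# cdst_ids.numel() == num_dst_nodes + 1, and 4 != 6.
assert cdst_ids.numel() != 5 + 1
```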
    cdst_ids: torch.Tensor, optional
        Compressed destination indices. It is a monotonically increasing array of
        size (num_dst_nodes + 1,). For the k-th destination node, its neighborhood
        consists of the sources between `src_indices[cdst_indices[k]]` and
        `src_indices[cdst_indices[k+1]]`.

    dst_ids_is_sorted: bool
        Whether `dst_ids` has been sorted in an ascending order. When sorted,
        creating CSC layout is much faster.

    formats: str or tuple of str, optional
        The desired sparse formats to create for the graph.

    reduce_memory: bool, optional
        When set, the tensors that are not required by the desired formats will
        be set to `None`.

    Notes
    -----
    For MFGs (sampled graphs), the node ids must have been renumbered.
    """

    supported_formats = {"coo": ("src_ids", "dst_ids"), "csc": ("cdst_ids", "src_ids")}

    all_tensors = set(["src_ids", "dst_ids", "csrc_ids", "cdst_ids"])

    def __init__(
        self,
        size: Tuple[int, int],
        src_ids: torch.Tensor,
        dst_ids: Optional[torch.Tensor] = None,
        csrc_ids: Optional[torch.Tensor] = None,
        cdst_ids: Optional[torch.Tensor] = None,
        dst_ids_is_sorted: bool = False,
        formats: Optional[Union[str, Tuple[str]]] = None,
        reduce_memory: bool = True,
    ):
        self._num_src_nodes, self._num_dst_nodes = size
        self._dst_ids_is_sorted = dst_ids_is_sorted

        if dst_ids is None and cdst_ids is None:
            raise ValueError("One of 'dst_ids' and 'cdst_ids' must be given.")

        if src_ids is not None:
            src_ids = src_ids.contiguous()

        if dst_ids is not None:
            dst_ids = dst_ids.contiguous()

        if csrc_ids is not None:
            if csrc_ids.numel() != self._num_src_nodes + 1:
                raise RuntimeError(
                    f"Size mismatch for 'csrc_ids': expected ({size[0]+1},), "
                    f"but got {tuple(csrc_ids.size())}"
                )
            csrc_ids = csrc_ids.contiguous()

        if cdst_ids is not None:
            if cdst_ids.numel() != self._num_dst_nodes + 1:
                raise RuntimeError(
                    f"Size mismatch for 'cdst_ids': expected ({size[1]+1},), "
                    f"but got {tuple(cdst_ids.size())}"
                )
            cdst_ids = cdst_ids.contiguous()

        self._src_ids = src_ids
        self._dst_ids = dst_ids
        self._csrc_ids = csrc_ids
        self._cdst_ids = cdst_ids
        self._perm = None

        if isinstance(formats, str):
            formats = (formats,)

        if formats is not None:
            for format_ in formats:
                assert format_ in SparseGraph.supported_formats
                self.__getattribute__(f"_create_{format_}")()
        self._formats = formats

        self._reduce_memory = reduce_memory
        if reduce_memory:
            self.reduce_memory()

    def reduce_memory(self):
        """Remove the tensors that are not necessary to create the desired sparse
        formats to reduce memory footprint."""

        self._perm = None
        if self._formats is None:
            return

        tensors_needed = []
        for f in self._formats:
            tensors_needed += SparseGraph.supported_formats[f]
        for t in SparseGraph.all_tensors.difference(set(tensors_needed)):
            # Tensors are stored with a leading underscore (e.g. self._src_ids).
            self.__dict__["_" + t] = None

    def _create_coo(self):
        if self._dst_ids is None:
            self._dst_ids = decompress_ids(self._cdst_ids)

    def _create_csc(self):
        if self._cdst_ids is None:
            if not self._dst_ids_is_sorted:
                self._dst_ids, self._perm = torch.sort(self._dst_ids)
                self._src_ids = self._src_ids[self._perm]
            self._cdst_ids = compress_ids(self._dst_ids, self._num_dst_nodes)
    def num_src_nodes(self):
        return self._num_src_nodes

    def num_dst_nodes(self):
        return self._num_dst_nodes

    def formats(self):
        return self._formats

    def coo(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if "coo" not in self.formats():
            raise RuntimeError(
                "The SparseGraph did not create a COO layout. "
                "Set 'formats' to include 'coo' when creating the graph."
            )
        return (self._src_ids, self._dst_ids)
Review thread:

Reviewer: Do you think we should remove …

Author: I think if we are forcing … We can probably borrow the convention from … We won't follow their default of … See the formats docs.

Reviewer: As we discussed via Slack, we will provide …
    def csc(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if "csc" not in self.formats():
            raise RuntimeError(
                "The SparseGraph did not create a CSC layout. "
                "Set 'formats' to include 'csc' when creating the graph."
            )
        return (self._cdst_ids, self._src_ids)
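Putting the class together, intended usage looks roughly like the following. The import path is an assumption (the review discussion notes SparseGraph may not be user-facing), so this sketch degrades gracefully when the package is absent:

```python
import torch

try:
    # Hypothetical import path; the class may not be publicly exported.
    from cugraph_dgl.nn import SparseGraph
except ImportError:
    SparseGraph = None

if SparseGraph is not None:
    g = SparseGraph(
        size=(6, 3),
        src_ids=torch.tensor([1, 2, 2, 3, 4, 4, 5]),
        cdst_ids=torch.tensor([0, 2, 5, 7]),
        formats=("coo", "csc"),
    )
    offsets, indices = g.csc()  # (cdst_ids, src_ids), the layout cugraph-ops expects
    src, dst = g.coo()          # dst_ids materialized via decompress_ids
```

Requesting only the formats a model actually consumes lets `reduce_memory` drop the unused tensors, which is the point of the lightweight structure this PR introduces.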
Review thread:

Reviewer: We should expose this in https://github.com/rapidsai/cugraph/blob/7102fe0a657946f3262894180acd10ef7587b121/python/cugraph-dgl/cugraph_dgl/nn/conv/__init__.py

Author: This is not a user-facing class. It is only used to handle the case where we fall back to the full-graph variant. In addition, with the recent cugraph-ops refactoring disabling the MFG variant, we might remove this class entirely.

Reviewer: Oh, I am sorry, I left the review at the wrong line. I meant the SparseGraph class.

Author: Resolved.