This repository has been archived by the owner on Sep 21, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_graph.py
49 lines (42 loc) · 1.77 KB
/
load_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import dgl
import torch as th
def load_reddit():
    """Load the Reddit dataset graph and alias its node feature/label fields.

    The dataset is built with ``self_loop=True`` and its first graph is used.
    The node data stored under ``'feat'`` and ``'label'`` is additionally
    exposed under ``'features'`` and ``'labels'``.

    Returns:
        A ``(graph, num_classes)`` tuple, where ``num_classes`` is read from
        the dataset object.
    """
    # Imported lazily so the module can be imported without touching dgl.data.
    from dgl.data import RedditDataset

    dataset = RedditDataset(self_loop=True)
    graph = dataset[0]
    # Expose the dataset's own field names under the aliases used by callers.
    for alias, source_key in (('features', 'feat'), ('labels', 'label')):
        graph.ndata[alias] = graph.ndata[source_key]
    return graph, dataset.num_classes
def load_ogb(name):
    """Load an OGB node-property-prediction dataset as a graph with masks.

    The first ``(graph, labels)`` pair of the dataset is used. Column 0 of
    the label tensor is stored as ``ndata['labels']``, ``ndata['feat']`` is
    aliased to ``ndata['features']``, and the dataset's train/valid/test
    index split is converted into boolean node masks stored under
    ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``.

    Args:
        name: Dataset name forwarded to ``DglNodePropPredDataset``.

    Returns:
        A ``(graph, num_labels)`` tuple, where ``num_labels`` is the number
        of distinct non-NaN values found in the label column.
    """
    # Imported lazily so the module can be imported without ogb installed.
    from ogb.nodeproppred import DglNodePropPredDataset

    print('load', name)
    data = DglNodePropPredDataset(name=name)
    print('finish loading', name)

    splitted_idx = data.get_idx_split()
    graph, labels = data[0]
    # Keep only the first label column so labels are one value per node.
    labels = labels[:, 0]

    graph.ndata['features'] = graph.ndata['feat']
    graph.ndata['labels'] = labels

    # Count distinct classes, ignoring NaN entries in the label column.
    num_labels = len(th.unique(labels[th.logical_not(th.isnan(labels))]))

    # Turn each split's node IDs into a boolean mask over all nodes.
    num_nodes = graph.number_of_nodes()
    split_to_mask = (
        ('train', 'train_mask'),
        ('valid', 'val_mask'),
        ('test', 'test_mask'),
    )
    for split_key, mask_name in split_to_mask:
        mask = th.zeros((num_nodes,), dtype=th.bool)
        mask[splitted_idx[split_key]] = True
        graph.ndata[mask_name] = mask

    print('finish constructing', name)
    return graph, num_labels
def inductive_split(g):
    """Derive training, validation and test graphs for inductive models.

    The training graph is the subgraph selected by ``g.ndata['train_mask']``;
    the validation graph is the subgraph selected by the union of
    ``'train_mask'`` and ``'val_mask'``; the test graph is ``g`` itself.

    Args:
        g: Graph whose ``ndata`` holds ``'train_mask'`` and ``'val_mask'``.

    Returns:
        A ``(train_g, val_g, test_g)`` tuple.
    """
    train_selection = g.ndata['train_mask']
    train_graph = g.subgraph(train_selection)

    # Validation sees the training nodes plus the validation nodes.
    train_or_val = g.ndata['train_mask'] | g.ndata['val_mask']
    val_graph = g.subgraph(train_or_val)

    # The full graph is returned as-is for testing (no copy is made).
    return train_graph, val_graph, g