Merge branch 'PaddlePaddle:main' into pna
dongZheX committed Sep 28, 2022
2 parents 0779eaf + 28fc540 commit 5906457
Showing 6 changed files with 47 additions and 12 deletions.
37 changes: 37 additions & 0 deletions examples/pna/README.md
@@ -0,0 +1,37 @@
# Principal Neighbourhood Aggregation for Graph Nets (PNA)

[Principal Neighbourhood Aggregation for Graph Nets \(PNA\)](https://arxiv.org/abs/2004.05718) is a graph learning model combining multiple aggregators with degree-scalers.


### Datasets

We run graph classification experiments on [OGB](https://ogb.stanford.edu/) datasets to reproduce the results reported in the paper.

### Dependencies

- paddlepaddle >= 2.2.0
- pgl >= 2.2.4

### How to run


```
python main.py --config config.yaml # train on ogbg-molhiv
python main.py --config config_pcba.yaml # train on ogbg-molpcba
```


### Important Hyperparameters

- aggregators: a list of aggregator names, chosen from "mean", "sum", "max", "min", "var", "std".
- scalers: a list of scaler names, chosen from "identity", "amplification", "attenuation", "linear", "inverse_linear".
- tower: the number of towers.
- divide_input: whether the input features should be split between towers.
- pre_layers: the number of MLP layers before aggregation.
- post_layers: the number of MLP layers after aggregation.
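
How `aggregators` and `scalers` combine can be sketched in plain Python. This is a minimal illustration on scalar messages, assuming the scaler definitions from the PNA paper; the helper names (`aggregate`, `scale`, `pna_combine`) and the `avg_log` / `avg_lin` arguments are hypothetical and are not taken from PGL's `PNAConv`, which may differ in detail.

```python
import math


def aggregate(values, name):
    """Apply one named aggregator to a non-empty list of neighbour messages."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # population variance
    if name == "mean":
        return mean
    if name == "sum":
        return sum(values)
    if name == "max":
        return max(values)
    if name == "min":
        return min(values)
    if name == "var":
        return var
    if name == "std":
        return math.sqrt(var)
    raise ValueError(f"unknown aggregator: {name}")


def scale(value, name, degree, avg_log, avg_lin):
    """Rescale an aggregated value by a function of the node's in-degree."""
    d = max(degree, 1)  # assumption: clamp so log(d + 1) and d are never zero
    if name == "identity":
        return value
    if name == "amplification":
        return value * (math.log(d + 1) / avg_log)
    if name == "attenuation":
        return value * (avg_log / math.log(d + 1))
    if name == "linear":
        return value * (d / avg_lin)
    if name == "inverse_linear":
        return value * (avg_lin / d)
    raise ValueError(f"unknown scaler: {name}")


def pna_combine(values, aggregators, scalers, avg_log, avg_lin):
    """Return one output per (scaler, aggregator) pair for a single node."""
    degree = len(values)
    return [
        scale(aggregate(values, a), s, degree, avg_log, avg_lin)
        for s in scalers
        for a in aggregators
    ]
```

With 2 aggregators and 2 scalers, `pna_combine` returns 4 values per node, which is why the layer width after aggregation grows with `len(aggregators) * len(scalers)`.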

### Experiment results (ROC-AUC on HIV, AP on PCBA)

| Dataset | GIN    | PNA (paper result) | PNA (ours) |
|---------|--------|--------------------|------------|
| HIV     | 0.7778 | 0.7905             | 0.7929     |
| PCBA    | 0.2266 | 0.2838             | 0.2801     |
3 changes: 0 additions & 3 deletions examples/pna/config_pcba.yaml
@@ -40,6 +40,3 @@ log_filename: log.txt
save_dir: ./checkpoints
output_dir: ./outputs
files2saved: ["*.yaml", "*.py", "./utils"]

-# python main_PCBA.py --type_net="complex" --batch_size=512 --lap_norm="none" --weight_decay=3e-6 --L=4 --hidden_dim=510 --out_dim=510 --residual=True
-# --edge_feat=True --readout=sum --graph_norm=True --batch_norm=True --aggregators="mean sum max" --scalers="identity" --config "configs/molecules_graph_classification_DGN_PCBA.json" --lr_schedule_patience=4 --towers=5 --dropout=0.2 --init_lr=0.0005 --min_lr=0.00002 --edge_dim=16 --lr_reduce_factor=0.8
2 changes: 1 addition & 1 deletion examples/pna/main.py
@@ -87,7 +87,7 @@ def main(config):
num_workers=config.num_workers,
collate_fn=fn)

-deg_hog = get_degree_histogram(train_loader)
+deg_hog = paddle.to_tensor(get_degree_histogram(train_loader))

model = PNAModel(
config.hidden_size,
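
For context, the histogram passed to the model counts how many training nodes have each in-degree. The sketch below shows one way such a histogram could be built; `degree_histogram` and its `(num_nodes, edges)` input format are hypothetical and do not reproduce the repository's `get_degree_histogram`, which reads from the data loader.

```python
from collections import Counter


def degree_histogram(graphs):
    """Count nodes per in-degree over a collection of graphs.

    graphs: iterable of (num_nodes, edges), where edges is a list of
    (src, dst) pairs. Returns a list `hist` with hist[d] equal to the
    number of nodes whose in-degree is d.
    """
    counts = Counter()
    for num_nodes, edges in graphs:
        in_degree = [0] * num_nodes
        for _, dst in edges:
            in_degree[dst] += 1
        counts.update(in_degree)
    max_degree = max(counts) if counts else 0
    return [counts.get(d, 0) for d in range(max_degree + 1)]
```

The returned list can then be wrapped with `paddle.to_tensor`, as the change above does, before it is handed to the model.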
2 changes: 1 addition & 1 deletion examples/sag_pool/README.md
@@ -26,7 +26,7 @@ python main.py --use_cuda --dataset_name PROTEINS --lr 0.005 --batch_size 128 --

- data\_path: the root path of your dataset
- dataset\_name: the name of the dataset. ("MUTAG", "IMDBBINARY", "IMDBMULTI", "COLLAB", "PROTEINS", "NCI1", "PTC", "REDDITBINARY", "REDDITMULTI5K")
-- fold\_idx: The $fold\_idx^{th}$ fold of dataset splited. Here we use 10 fold cross-validation
+- fold\_idx: The $fold\_{idx}^{th}$ fold of dataset splited. Here we use 10 fold cross-validation
- min\_score: parameter for SAGPool which indicates minimal node score. (When min\_score is not None, pool\_ratio is ignored)
- pool\_ratio: parameter for SAGPool which decides how many nodes will be removed.

13 changes: 7 additions & 6 deletions pgl/nn/pna_conv.py
@@ -37,7 +37,7 @@ class PNAConv(nn.Layer):
scalers: (list): List of scaler function keyword,
choices in ["identity", "amplification",
"attenuation", "linear", "inverse_linear"]
-deg (numpy array): Histogram of in-degrees of nodes in the training set for computed avg_deg for scalers
+deg (Tensor): Histogram of in-degrees of nodes in the training set for computed avg_deg for scalers
towers (int, optional): Number of towers. Default: 1
pre_layers (int, optional): Number of transformation layers before
aggregation. Default: 1
@@ -77,12 +77,13 @@ def __init__(self,

deg = deg.astype("float32")
total_no_vertices = deg.sum()
-bin_degrees = np.arange(len(deg)).astype("float32")
+bin_degrees = paddle.arange(len(deg), dtype="float32")
self.avg_deg = {
-'lin': ((bin_degrees * deg).sum() / total_no_vertices),
-'log': ((np.log(
-(bin_degrees + 1)) * deg).sum() / total_no_vertices),
-'exp': ((np.exp(bin_degrees) * deg).sum() / total_no_vertices),
+'lin': ((bin_degrees * deg).sum() / total_no_vertices).item(),
+'log':
+(((bin_degrees + 1).log() * deg).sum() / total_no_vertices).item(),
+'exp': (
+(bin_degrees.exp() * deg).sum() / total_no_vertices).item(),
}
if use_edge:
self.edge_mlp = paddle.nn.Linear(input_size, self.F_in)
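
The three averages in this hunk are histogram-weighted means over the training nodes: with `deg[d]` the number of nodes of in-degree `d`, each entry is `sum(f(d) * deg[d]) / sum(deg)` for `f(d) = d`, `log(d + 1)`, and `exp(d)`. A dependency-free sketch of the same arithmetic follows; `average_degrees` is a hypothetical helper name, and the input reuses the histogram from the unit test in this commit.

```python
import math


def average_degrees(hist):
    """Histogram-weighted averages of d, log(d + 1) and exp(d).

    hist[d] is the number of nodes whose in-degree is d.
    """
    total = sum(hist)
    return {
        "lin": sum(d * c for d, c in enumerate(hist)) / total,
        "log": sum(math.log(d + 1) * c for d, c in enumerate(hist)) / total,
        "exp": sum(math.exp(d) * c for d, c in enumerate(hist)) / total,
    }


# Histogram used in tests/test_conv.py: 5 nodes with in-degrees 1, 2, 3, 4, 4.
avg = average_degrees([0, 1, 1, 1, 2])
```

For this input the linear average is (1 + 2 + 3 + 4 + 4) / 5 = 2.8. Calling `.item()` in the Paddle version converts each 0-d tensor result into a plain Python float, matching what this sketch returns.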
2 changes: 1 addition & 1 deletion tests/test_conv.py
@@ -109,7 +109,7 @@ def test_pna_conv(self):
"identity", "amplification", "attenuation", "linear",
"inverse_linear"
],
-deg=np.asarray([0, 1, 1, 1, 2]),
+deg=paddle.to_tensor([0, 1, 1, 1, 2]),
towers=2,
pre_layers=1,
post_layers=2,