-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathConfigures.py
82 lines (63 loc) · 2.44 KB
/
Configures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import torch
from typing import List
class DataParser():
    """Dataset settings: which dataset to load, from where, and how to split it."""

    def __init__(self):
        super().__init__()
        # Alternatives noted by the author: 'MUTAG', 'DHFR', 'ogbg-molhiv',
        # 'ogbg-molpcba', 'ZINC'.
        self.dataset_name = 'ogbg-molhiv'
        self.dataset_dir = './datasets'
        self.random_split: bool = True
        # Fractions for the training / validation / testing sets when the
        # random split is used; the three parts sum to 756 / 756 == 1.
        denom = 756
        self.data_split_ratio: List[float] = [120 / denom, 120 / denom, 516 / denom]
        self.seed = 1
        self.imb = False  # flag controlling whether the dataset is made imbalanced
        self.imb_ratio = 0.1  # imbalance ratio
        self.num_train = 150
        self.num_val = 150
class ModelParser():
    """Model settings: architecture, readout, prototype options and target device."""

    def __init__(self):
        super().__init__()
        # GPU index; process_args() replaces this with a torch.device.
        self.device: int = 0
        self.model_name: str = 'pna'
        self.checkpoint: str = './checkpoint'
        self.readout: str = 'max'  # the graph pooling method
        self.enable_prot = True  # whether to enable prototype training
        self.num_prototypes_per_class = 1  # number of prototypes per class
        self.deg = None
        self.edge_dim = None
        self.pe_dim = 20
        self.prot_dim = 32
        self.single_target = True
        self.mlp_out_dim = 0

    def process_args(self) -> None:
        """Resolve ``self.device`` into a ``torch.device``.

        When CUDA is available the result is ``cuda:<index>``. The index comes
        from ``self.device_id`` if a caller set that attribute, otherwise from
        the integer stored in ``self.device``. Without CUDA the result is the
        CPU device. Calling this again after resolution is a no-op.
        """
        if isinstance(self.device, torch.device):
            # Already resolved; torch.device('cuda', <torch.device>) would raise.
            return
        if torch.cuda.is_available():
            # __init__ never defines ``device_id`` (reading it unconditionally
            # raised AttributeError), so default to the index in ``self.device``.
            index = getattr(self, 'device_id', self.device)
            self.device = torch.device('cuda', index)
        else:
            # Leaving the int 0 here would make tensor.to(0) target cuda:0.
            self.device = torch.device('cpu')
class TrainParser():
    """Training-loop hyperparameters and data-loading options."""

    def __init__(self):
        super().__init__()
        # Optimizer settings.
        self.learning_rate = 0.001
        self.batch_size = 24
        self.weight_decay = 0
        # Epoch budget and checkpoint cadence.
        self.max_epochs = 200
        self.save_epoch = 20
        self.early_stopping = 50000
        # Pruning: keep this fraction, starting once epoch >= pruning_epochs.
        self.retain_ratio = 1.0
        self.pruning_epochs = 0
        # Imbalance handling; other values listed by the author are
        # 'upsampling', 'smote' and 'overall_reweight'.
        self.imb_solve: str = 'no'
        self.visualize = 0
        self.pre_transform = None
        self.num_workers = 0
        self.cnt = 1
# Module-level configuration singletons shared by importers of this file
# (created in order: data, model, train).
data_args, model_args, train_args = DataParser(), ModelParser(), TrainParser()
import torch
import random
import numpy as np
random_seed = 1234


def _seed_all_rngs(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU plus every CUDA device) RNGs."""
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


# Runs at import time so every importer starts from the same RNG state.
_seed_all_rngs(random_seed)