-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
159 lines (127 loc) · 5.16 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import argparse
import os
import shlex

from yacs.config import CfgNode as CN
def set_cfg(cfg):
    """Attach every default option to *cfg* and return it.

    Sections: top-level basics, GNN (model + training), language model
    (model + training), and graph-transformer (``gt``) training.
    """
    # ---- Basic options --------------------------------------------------- #
    cfg.dataset = 'arxiv-2023'  # dataset name
    cfg.device = 'cuda:0'       # cuda device number, for multi-gpu machines
    cfg.seed = None             # fix the running seed to remove randomness
    cfg.runs = 4                # number of runs with random init

    # ---- GNN options ----------------------------------------------------- #
    cfg.gnn = CN()
    cfg.gnn.model = CN({
        'name': 'GCN',      # GNN model name
        'num_layers': 4,    # number of gnn layers
        'hidden_dim': 128,  # hidden size of the model
    })
    cfg.gnn.train = CN({
        # weight decay applied to all layers except bias and LayerNorm weights
        'weight_decay': 0.0,
        'epochs': 200,             # maximal number of epochs
        'feature_type': 'TA_P_E',  # node feature type, options: ogb, TA, P, E
        'early_stop': 25,          # epochs with no improvement before stopping
        'lr': 0.01,                # base learning rate
        'wd': 0.0,                 # L2 regularization, weight decay
        'dropout': 0.0,            # dropout rate
    })

    # ---- LM options ------------------------------------------------------ #
    cfg.lm = CN()
    cfg.lm.model = CN({
        'name': 'deberta-base',  # LM model name
        'feat_shrink': "",
        'path': '/gpfsnyu/scratch/ys6310/deberta-base',
    })
    cfg.lm.train = CN({
        'batch_size': 9,        # samples computed once per batch per device
        'grad_acc_steps': 1,    # steps over which gradients are accumulated
        'lr': 2e-5,             # base learning rate
        'epochs': 4,            # maximal number of epochs
        'warmup_epochs': 0.6,   # number of warmup steps
        'eval_patience': 50000, # update steps between two evaluations
        # weight decay applied to all layers except bias and LayerNorm weights
        'weight_decay': 0.0,
        # dropout for fully connected layers in embeddings, encoder and pooler
        'dropout': 0.3,
        'att_dropout': 0.1,     # dropout ratio for the attention probabilities
        'cla_dropout': 0.4,     # dropout ratio for the classifier
        # Use gpt responses (explanation + prediction) as text input; if
        # False, use the original text attributes (title and abstract).
        'use_gpt': False,
    })

    # ---- GraphT training options ----------------------------------------- #
    cfg.gt = CN()
    cfg.gt.train = CN({
        'batch_size': 100,
        'epochs': 100,
        'n_layers': 9,
        'dim_hidden': 256,
        'dim_qk': 256,
        'dim_v': 256,
        'dim_ff': 256,
        'n_heads': 16,
        'last_layer_n_heads': 16,
        'input_dropout_rate': 0.2,
        'dropout_rate': 0.2,
        'weight_decay': 0.0,
        'lr': 1e-4,
    })
    return cfg
# YACS principle: if an option is defined in a YACS config object, the
# program should set it via cfg.merge_from_list(opts), not by defining a
# dedicated command-line argument (e.g. --train-scales) that is then used
# to set cfg.TRAIN.SCALES.
def update_cfg(cfg, args_str=None):
    """Return a clone of *cfg* merged with a YAML file and CLI overrides.

    Parameters
    ----------
    cfg : CfgNode
        Base configuration. It is cloned, never mutated.
    args_str : str, optional
        When given, options are parsed from this string instead of
        ``sys.argv`` (useful for notebooks and tests).

    Raises
    ------
    FileNotFoundError
        If ``--config`` names a path that does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default="",
                        metavar="FILE", help="Path to config file")
    # opts arg needs to match set_cfg
    parser.add_argument("opts", default=[], nargs=argparse.REMAINDER,
                        help="Modify config options using the command-line")
    if isinstance(args_str, str):
        # Parse from a string. shlex honors shell-style quoting, so values
        # containing spaces survive (plain str.split would break them).
        args = parser.parse_args(shlex.split(args_str))
    else:
        # Parse from the command line.
        args = parser.parse_args()

    # Clone so the caller's cfg (and the module-level default) stay untouched.
    cfg = cfg.clone()

    # Update from config file. Fail loudly on a mistyped path instead of
    # silently ignoring it; the empty default still skips the merge.
    if args.config:
        if not os.path.isfile(args.config):
            raise FileNotFoundError(f"Config file not found: {args.config}")
        cfg.merge_from_file(args.config)
    # Update from command line.
    cfg.merge_from_list(args.opts)
    return cfg
"""
Global variable
"""
cfg = set_cfg(CN())