# main_train.py
import torch
import wandb
from dataset import get_data, get_data_manager, update_data
from IF import UserAIF
from model import get_model
from model.trainer import CAT, train
from utils.config import load_config
def _start_wandb_run(cfg, name_suffix):
    """Open a W&B run named ``cfg.wandb.exp_name + name_suffix`` if enabled.

    Does nothing when ``cfg.wandb.use_wandb`` is falsy, so every training
    stage honours the same logging switch.
    """
    if cfg.wandb.use_wandb:
        wandb.init(
            project=cfg.wandb.project_name,
            name=cfg.wandb.exp_name + name_suffix,
        )


def _finish_wandb_run(cfg):
    """Close the current W&B run if ``cfg.wandb.use_wandb`` is truthy."""
    if cfg.wandb.use_wandb:
        wandb.finish()


def main(cfg):
    """Run the training pipeline followed by the CAT evaluation.

    Steps, in order:

    1. Seed torch with ``cfg.random_seed`` and load the data via ``get_data``.
    2. If ``cfg.user_aif`` is truthy:
       a. train a fresh model on the ``"unbiased_train"`` dataloader,
       b. continue from that result on the ``"biased_train"`` dataloader
          with ``init_item_param_training=False``,
       c. pass the resulting model to ``UserAIF`` and feed its output to
          ``update_data`` to refresh the data dict / datasets / dataloaders.
    3. Train a fresh model on the ``"unbiased_train+biased"`` dataloader.
    4. Build a data manager from the ``test_*`` entries of ``data_dict`` and
       run ``CAT`` with the model from step 3.

    W&B runs are opened (and, for the step-2 stages, closed) only when
    ``cfg.wandb.use_wandb`` is truthy.

    Args:
        cfg: Configuration object. This function reads ``random_seed``,
            ``user_aif``, ``wandb.use_wandb``, ``wandb.project_name`` and
            ``wandb.exp_name``; it is also forwarded to every helper.
    """
    torch.manual_seed(cfg.random_seed)

    ## get data
    data_dict, datasets, dataloaders = get_data(cfg)

    if cfg.user_aif:
        print("user aif")

        ## train on unbiased set
        _start_wandb_run(cfg, "_unbiased_training")
        model = get_model(
            cfg,
            init_ckpt=None,
        )
        init_ckpt = train(
            cfg,
            model,
            saved_model_name="unbiased_training",
            init_item_param_training=True,
            train_dataloader=dataloaders["unbiased_train"],
            val_train_dataloader=dataloaders["unbiased_val_train"],
            val_eval_dataloader=dataloaders["unbiased_val_eval"],
        )
        _finish_wandb_run(cfg)

        ## train biased user's params
        _start_wandb_run(cfg, "_biased_training")
        # Start from whatever the unbiased stage's train() call returned.
        model = get_model(
            cfg,
            init_ckpt=init_ckpt,
        )
        init_ckpt = train(
            cfg,
            model,
            saved_model_name="biased_training",
            init_item_param_training=False,
            train_dataloader=dataloaders["biased_train"],
            val_train_dataloader=None,
            val_eval_dataloader=dataloaders["biased_eval"],
        )
        _finish_wandb_run(cfg)

        # de-biasing using UserAIF
        model = get_model(
            cfg,
            init_ckpt=init_ckpt,
        )
        selected_biased_df = UserAIF(
            model, data_dict["biased"], data_dict["unbiased"], data_dict["items"]
        )
        data_dict, datasets, dataloaders = update_data(
            cfg, selected_biased_df, data_dict, datasets, dataloaders
        )

    ## Training
    # No wandb.finish() follows this run: it stays open through the CAT
    # evaluation below.
    _start_wandb_run(cfg, "_final_training")
    model = get_model(
        cfg,
        init_ckpt=None,
    )
    # The return value is not needed here; CAT receives the model object.
    train(
        cfg,
        model,
        saved_model_name="final_training",
        init_item_param_training=True,
        train_dataloader=dataloaders["unbiased_train+biased"],
        val_train_dataloader=dataloaders["unbiased_val_train"],
        val_eval_dataloader=dataloaders["unbiased_val_eval"],
    )

    ## CAT
    data_manager = get_data_manager(
        cfg,
        cat_users=data_dict["test_users"],
        feature_df=data_dict["test_feature"],
        pool_df=data_dict["test_pool"],
        eval_df=data_dict["test_eval"],
    )
    CAT(cfg, model, data_manager)
if __name__ == "__main__":
    # Experiment selection is hard-coded: YAML config path first, then the
    # model / dataset identifiers that load_config receives alongside it.
    config_type = "./configs/train.yaml"
    model_type, dataset_type = "irt-2pl", "enem"

    # Build the config object from the three selectors and start the run.
    configs = load_config(config_type, model_type, dataset_type)
    main(cfg=configs)