-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_CAV_CBM.py
executable file
·107 lines (99 loc) · 3.5 KB
/
train_CAV_CBM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# -*- encoding: utf-8 -*-
"""
@Author : liuyang
@github : https://github.com/ly1998117/MMCBM
@Contact : [email protected]
"""
import os
import argparse
import torch.optim.lr_scheduler
from utils.EarlyStop import EarlyStopping
from monai.utils import set_determinism
from utils.metrics import *
from loss import Loss
from utils.decorator import decorator_args
from params import id_to_labels
from trainer.train_helper_mmcbm import TrainHelperMMCBM
from models.MMCBM.concepts_bank import ConceptBank
from models import get_backbone
def get_model_opti(args):
    """Build the concept bank, the concept bottleneck model and its optimizer.

    The backbone is created first and handed to the ``ConceptBank``; the bank
    is stored on ``args.concept_bank`` (a side effect on the caller's
    namespace) and then given to ``M2LinearCBM``.

    Args:
        args: Run configuration. Attributes read here: ``output_dir``,
            ``dir_name``, ``device``, ``clip_name``, ``cbm_location``,
            ``pos_samples``, ``neg_samples``, ``svm_C``, ``report_shot``,
            ``concept_shot``, ``cav_split``, ``out_channel``, ``fusion``,
            ``activation``, ``analysis_top_k``, ``analysis_threshold``,
            ``act_on_weight``, ``init_method``, ``bias`` and ``lr``.
            ``get_backbone`` also receives the whole namespace.

    Returns:
        A ``(model, optimizer)`` tuple: the ``M2LinearCBM`` instance and a
        ``torch.optim.Adam`` optimizer over all of its parameters.
    """
    backbone = get_backbone(args)
    concept_bank_dir = os.path.join(args.output_dir, args.dir_name)

    # initialize the concept bank
    bank_kwargs = {
        'device': args.device,
        'clip_name': args.clip_name,
        'location': args.cbm_location,
        'backbone': backbone,
        'n_samples': args.pos_samples,
        'neg_samples': args.neg_samples,
        'svm_C': args.svm_C,
        'bank_dir': concept_bank_dir,
        'report_shot': args.report_shot,
        'concept_shot': args.concept_shot,
        'cav_split': args.cav_split,
        'language': 'zh',
    }
    args.concept_bank = ConceptBank(**bank_kwargs)

    # Deferred import: kept at this point so it runs after the bank is built.
    from models.MMCBM.CBMs import M2LinearCBM

    # Options forwarded to the model under the same name they have on ``args``.
    forwarded_options = (
        'fusion',
        'activation',
        'analysis_top_k',
        'analysis_threshold',
        'act_on_weight',
        'init_method',
        'bias',
    )

    # initialize the Concept Bottleneck Model: FA_ICGA and US
    model = M2LinearCBM(
        idx_to_class=id_to_labels,
        concept_bank=args.concept_bank,
        n_classes=args.out_channel,
        **{option: getattr(args, option) for option in forwarded_options},
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    return model, optimizer
@decorator_args
def get_args(args) -> argparse.Namespace:
    """Apply this script's fixed training configuration to ``args``.

    ``args`` is mutated in place and is also returned, so the function
    honours its declared ``argparse.Namespace`` return type (previously the
    body fell off the end and returned ``None``).

    NOTE(review): ``decorator_args`` is defined in ``utils.decorator`` and is
    not visible here; presumably it builds the namespace and passes it in,
    since the entry point calls ``get_args()`` without arguments -- confirm
    how it uses this function's return value.

    Args:
        args: Namespace of run options; must already provide
            ``cudnn_nondet``.

    Returns:
        The same ``args`` object with the overrides below applied.
    """
    # enabling cudnn determinism appears to speed up training by a lot
    torch.backends.cudnn.deterministic = not args.cudnn_nondet
    args.down_sample = False

    ##################### debug #####################
    # Hard-coded overrides; the commented lines are optional toggles.
    # args.device = 0
    # args.k = 0
    args.wandb = False
    # args.clip_name = 'clip_ViT-L/14'
    args.cbm_model = 'b0'
    args.modality = 'MM'
    # args.name = 'MMCBM_2'
    args.fusion = 'max'
    args.cbm_location = 'report_strict'
    # args.mark = f'{args.cbm_location}'
    # args.infer = True
    # args.resume = True
    # args.test_only = True
    # args.idx = 130
    args.analysis_top_k = 15
    # args.test = False
    args.activation = 'sigmoid'
    args.act_on_weight = True
    args.num_worker = 2
    # args.backbone = 'Efficientb0_SCLS_TestOnly/fold_0'
    ##################### debug #####################

    # if 'clip' in args.clip_name:
    #     args.dir_name = 'CLip'
    # else:
    #     args.dir_name = f'CAV'
    # args.dir_name = args.clip_name.upper()

    # Metric objects attached to the configuration.
    args.metrics = [
        Accuracy(),
        Precision(),
        Recall(),
        F1(),
    ]
    args.mode = 'max'
    return args
if __name__ == "__main__":
    # PrintColor is not referenced below.
    # NOTE(review): possibly kept for import side effects -- confirm before removing.
    from utils.logger import PrintColor

    # Make CUDA kernel launches synchronous so errors surface at the failing call.
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    args = get_args()
    set_determinism(args.seed)

    model, optimizer = get_model_opti(args)
    # args.loss is passed in as the loss type, then overwritten with the Loss object.
    args.loss = Loss(loss_type=args.loss, model=model)

    # start training
    trainer = TrainHelperMMCBM(args, model, optimizer)
    trainer.start_train()