-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy pathtrain.py
101 lines (90 loc) · 2.22 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os

import paddle.fluid as fluid

from models.scan import SCAN
from utils.runner import Runnner
from dataset.faceforensics import FaceForensics
# Architecture and loss configuration, unpacked into SCAN(**model_cfg).
model_cfg = {
    # Feature extractor. The `pretrained` path below suggests ResNet-18
    # weights, matching depth 18.
    'backbone': {
        'depth': 18,
        'out_indices': (0, 1, 2, 3),  # presumably which stages' features are returned
        'frozen_stages': -1,  # presumably -1 = freeze nothing (mmdet convention) -- confirm in SCAN
    },
    'neck': {
        'norm_cfg': {'type': 'IN'},  # presumably instance normalisation
    },
    'head': {
        'depth': 18,
        'out_indices': (3,),
        'norm_cfg': {'type': 'BN'},  # presumably batch normalisation
        'dropout': 0.5,
    },
    # Training-time options. The w_* names suggest loss weights
    # (classification / triplet / regression) -- TODO confirm against SCAN.
    'train_cfg': {
        'w_cls': 5.0,
        'w_tri': 1.0,
        'w_reg': 5.0,
        'with_mask': False,
    },
    # Inference options; `thr` is presumably the decision threshold.
    'test_cfg': {
        'thr': 0.5,
    },
    'pretrained': './pretrained/resnet18-torch',
}
# Checkpointing, logging and evaluation schedule passed to the runner
# as `checkpoint_config`.
checkpoint_cfg = {
    'work_dir': './work_dir/ff_add_val',  # presumably where this run writes its outputs
    'load_from': './work_dir/ff_c23/Best_model',  # presumably initial weights from an earlier run
    # Interval units (iterations vs. epochs) are defined by utils.runner --
    # presumably iterations; confirm there.
    'save_interval': 10000,
    'eval_interval': 200,
    'log_interval': 10,
    'eval_type': 'acc',  # presumably selects accuracy as the evaluation metric
}
# Optimiser and learning-rate schedule passed to the runner as
# `optimizer_config`.
optimizer_cfg = {
    'lr': 0.0005,
    'type': 'Adam',
    'warmup_iter': 1000,  # presumably iterations of LR warm-up before the base LR
    # Presumably the LR is multiplied by `decay` at each epoch listed here --
    # confirm in utils.runner.
    'decay_epoch': [5, 8, 12],
    'decay': 0.3,
    'regularization': 0.0005,  # presumably an L2 weight-decay coefficient
}
# Extra augmentations handed to the training FaceForensics dataset (the
# validation dataset gets an empty dict). Each key presumably enables one
# transform in the dataset's augmentation pipeline -- confirm semantics there.
extra_aug = {
    # Brightness / contrast / saturation / hue jitter.
    'photo_metric_distortion': {
        'brightness_delta': 32,
        'contrast_range': (0.5, 1.5),
        'saturation_range': (0.8, 1.2),
        'hue_delta': 16,
    },
    'random_erasing': {
        'probability': 0.5,
        'area': (0.01, 0.03),
        'mean': (80, 80, 80),
    },
    'random_cutout': {
        'probability': 0.5,
        'max_edge': 20,
    },
    # NOTE(review): "ramdom" is misspelled, but the keys are kept verbatim
    # because the dataset code presumably looks them up by this exact name --
    # verify before renaming.
    'ramdom_rotate': {
        'probability': 0.5,
        'angle': 30,
    },
    'ramdom_crop': {
        'probability': 0.5,
        'w_h': (0.12, 0.12),
    },
}
# Root directory holding the FaceForensics images and annotation lists.
# NOTE(review): placeholder path -- point it at the local dataset copy.
data_root = 'Path/FaceForensics/data/'

# Preprocessing shared by the training and validation splits. Defining it
# once guarantees validation images are resized/normalised exactly like
# training images instead of relying on two hand-kept copies.
IMG_SCALE = (224, 224)
IMG_NORM_CFG = dict(mean=(100, 100, 100), std=(80, 80, 80))
CROP_FACE = 0.1  # presumably the face-crop margin ratio -- confirm in FaceForensics

# Training split, with the extra augmentations defined in `extra_aug`.
train_dataset = FaceForensics(
    img_prefix=data_root,
    # os.path.join yields the same path for the current data_root, but also
    # works if data_root is later set without a trailing slash (plain `+`
    # would silently produce e.g. '.../datatrain_add_train.txt').
    ann_file=os.path.join(data_root, 'train_add_train.txt'),
    mask_file=None,
    img_scale=IMG_SCALE,
    # Each dataset gets its own copy so an in-place change by one dataset
    # cannot leak into the other.
    img_norm_cfg=dict(IMG_NORM_CFG),
    extra_aug=extra_aug,
    crop_face=CROP_FACE,
)

# Validation split: empty extra_aug (no extra augmentation), test mode on.
val_dataset = FaceForensics(
    img_prefix=data_root,
    ann_file=os.path.join(data_root, 'train_val_train.txt'),
    img_scale=IMG_SCALE,
    img_norm_cfg=dict(IMG_NORM_CFG),
    extra_aug=dict(),
    test_mode=True,
    crop_face=CROP_FACE,
)
# Settings of the training loop itself.
BATCH_SIZE = 96
MAX_EPOCHS = 15

# Build the model and run training inside Paddle's dygraph guard so the
# network is created and executed in imperative (dynamic-graph) mode.
with fluid.dygraph.guard():
    model = SCAN(**model_cfg)
    # `Runnner` (sic) is the class name exported by utils.runner.
    runner = Runnner(
        model,
        train_dataset,
        val_dataset=val_dataset,
        batch_size=BATCH_SIZE,
        checkpoint_config=checkpoint_cfg,
        optimizer_config=optimizer_cfg,
    )
    runner.train(max_epochs=MAX_EPOCHS)