-
Notifications
You must be signed in to change notification settings - Fork 540
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEATURE] Support YOLOv6 v3.0 face detection (#812)
* support the usage of WIDERFace dataset * add YOLOv6FaceHead * add YOLOv6 face detection configs * add a checkpoint convertion script * add a face visualizer * fix a bug of YOLOv6CSPBep initialization
- Loading branch information
Showing
13 changed files
with
1,050 additions
and
7 deletions.
There are no files selected for viewing
27 changes: 27 additions & 0 deletions
27
configs/yolov6/yolov6-face/yolov6_v3_l_syncbn_fast_8xb32-300e_widerface.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
_base_ = './yolov6_v3_m_syncbn_fast_8xb32-300e_widerface.py' | ||
|
||
# ======================= Possible modified parameters ======================= | ||
# -----model related----- | ||
# The scaling factor that controls the depth of the network structure | ||
deepen_factor = 1 | ||
# The scaling factor that controls the width of the network structure | ||
widen_factor = 1 | ||
|
||
# ============================== Unmodified in most cases =================== | ||
model = dict( | ||
backbone=dict( | ||
use_cspsppf=False, | ||
deepen_factor=deepen_factor, | ||
widen_factor=widen_factor, | ||
block_cfg=dict( | ||
type='ConvWrapper', | ||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)), | ||
act_cfg=dict(type='SiLU', inplace=True)), | ||
neck=dict( | ||
deepen_factor=deepen_factor, | ||
widen_factor=widen_factor, | ||
block_cfg=dict( | ||
type='ConvWrapper', | ||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)), | ||
block_act_cfg=dict(type='SiLU', inplace=True)), | ||
bbox_head=dict(head_module=dict(reg_max=16, widen_factor=widen_factor))) |
20 changes: 20 additions & 0 deletions
20
configs/yolov6/yolov6-face/yolov6_v3_m_syncbn_fast_8xb32-300e_widerface.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
_base_ = './yolov6_v3_s_syncbn_fast_8xb32-300e_widerface.py' | ||
|
||
# ======================= Possible modified parameters ======================= | ||
# -----model related----- | ||
# The scaling factor that controls the depth of the network structure | ||
deepen_factor = 0.67 | ||
# The scaling factor that controls the width of the network structure | ||
widen_factor = 0.75 | ||
|
||
# -----train val related----- | ||
affine_scale = 0.9 # YOLOv5RandomAffine scaling ratio | ||
|
||
# ============================== Unmodified in most cases =================== | ||
model = dict( | ||
backbone=dict( | ||
use_cspsppf=False, | ||
deepen_factor=deepen_factor, | ||
widen_factor=widen_factor), | ||
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor), | ||
bbox_head=dict(head_module=dict(reg_max=16, widen_factor=widen_factor))) |
279 changes: 279 additions & 0 deletions
279
configs/yolov6/yolov6-face/yolov6_v3_n_syncbn_fast_8xb32-300e_widerface.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,279 @@ | ||
_base_ = ['../../_base_/default_runtime.py', '../../_base_/det_p5_tta.py'] | ||
|
||
# ======================= Frequently modified parameters ===================== | ||
# -----data related----- | ||
data_root = 'data/WIDERFace/' # Root path of data | ||
|
||
num_classes = 1 # Number of classes for classification | ||
# Batch size of a single GPU during training | ||
train_batch_size_per_gpu = 32 | ||
# Worker to pre-fetch data for each single GPU during training | ||
train_num_workers = 8 | ||
# persistent_workers must be False if num_workers is 0 | ||
persistent_workers = True | ||
|
||
# -----train val related----- | ||
# Base learning rate for optim_wrapper | ||
base_lr = 0.01 | ||
max_epochs = 300 # Maximum training epochs | ||
num_last_epochs = 15 # Last epoch number to switch training pipeline | ||
|
||
# ======================= Possible modified parameters ======================= | ||
# -----data related----- | ||
img_scale = (640, 640) # width, height | ||
# Dataset type, this will be used to define the dataset | ||
dataset_type = 'YOLOv6WIDERFaceDataset' | ||
# Batch size of a single GPU during validation | ||
val_batch_size_per_gpu = 1 | ||
# Worker to pre-fetch data for each single GPU during validation | ||
val_num_workers = 2 | ||
|
||
# Config of batch shapes. Only on val. | ||
# It means not used if batch_shapes_cfg is None. | ||
batch_shapes_cfg = dict( | ||
type='BatchShapePolicy', | ||
batch_size=val_batch_size_per_gpu, | ||
img_size=img_scale[0], | ||
size_divisor=32, | ||
extra_pad_ratio=0.5) | ||
|
||
# -----model related----- | ||
# The scaling factor that controls the depth of the network structure | ||
deepen_factor = 0.33 | ||
# The scaling factor that controls the width of the network structure | ||
widen_factor = 0.25 | ||
|
||
# -----train val related----- | ||
affine_scale = 0.5 # YOLOv5RandomAffine scaling ratio | ||
lr_factor = 0.01 # Learning rate scaling factor | ||
weight_decay = 0.0005 | ||
# Save model checkpoint and validation intervals | ||
save_epoch_intervals = 10 | ||
# The maximum checkpoints to keep. | ||
max_keep_ckpts = 3 | ||
# Single-scale training is recommended to | ||
# be turned on, which can speed up training. | ||
env_cfg = dict(cudnn_benchmark=True) | ||
|
||
# ============================== Unmodified in most cases =================== | ||
model = dict( | ||
type='YOLODetector', | ||
data_preprocessor=dict( | ||
type='YOLOv5DetDataPreprocessor', | ||
mean=[0., 0., 0.], | ||
std=[255., 255., 255.], | ||
bgr_to_rgb=True), | ||
backbone=dict( | ||
type='YOLOv6EfficientRep', | ||
out_indices=[1, 2, 3, 4], | ||
use_cspsppf=True, | ||
deepen_factor=deepen_factor, | ||
widen_factor=widen_factor, | ||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), | ||
act_cfg=dict(type='ReLU', inplace=True)), | ||
neck=dict( | ||
type='YOLOv6RepBiPAFPN', | ||
deepen_factor=deepen_factor, | ||
widen_factor=widen_factor, | ||
in_channels=[128, 256, 512, 1024], | ||
out_channels=[128, 256, 512], | ||
num_csp_blocks=12, | ||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), | ||
act_cfg=dict(type='ReLU', inplace=True), | ||
), | ||
bbox_head=dict( | ||
type='YOLOv6FaceHead', | ||
head_module=dict( | ||
type='YOLOv6FaceHeadModule', | ||
num_classes=num_classes, | ||
in_channels=[128, 256, 512], | ||
stemout_channels=[128, 256, 512], | ||
widen_factor=widen_factor, | ||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), | ||
act_cfg=dict(type='SiLU', inplace=True), | ||
featmap_strides=[8, 16, 32]), | ||
loss_bbox=dict( | ||
type='IoULoss', | ||
iou_mode='siou', | ||
bbox_format='xyxy', | ||
reduction='mean', | ||
loss_weight=2.5, | ||
return_iou=False)), | ||
train_cfg=dict( | ||
initial_epoch=4, | ||
initial_assigner=dict( | ||
type='BatchATSSAssigner', | ||
num_classes=num_classes, | ||
topk=9, | ||
iou_calculator=dict(type='mmdet.BboxOverlaps2D')), | ||
assigner=dict( | ||
type='BatchTaskAlignedAssigner', | ||
num_classes=num_classes, | ||
topk=13, | ||
alpha=1, | ||
beta=6), | ||
), | ||
test_cfg=dict( | ||
multi_label=True, | ||
nms_pre=30000, | ||
score_thr=0.4, | ||
nms=dict(type='nms', iou_threshold=0.45), | ||
max_per_img=1000)) | ||
|
||
# The training pipeline of YOLOv6 is basically the same as YOLOv5. | ||
# The difference is that Mosaic and RandomAffine will be closed in the last 15 epochs. # noqa | ||
pre_transform = [ | ||
dict(type='LoadImageFromFile', backend_args=_base_.backend_args), | ||
dict(type='LoadAnnotations', with_bbox=True) | ||
] | ||
|
||
train_pipeline = [ | ||
*pre_transform, | ||
dict( | ||
type='Mosaic', | ||
img_scale=img_scale, | ||
pad_val=114.0, | ||
pre_transform=pre_transform), | ||
dict( | ||
type='YOLOv5RandomAffine', | ||
max_rotate_degree=0.0, | ||
max_translate_ratio=0.1, | ||
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale), | ||
# img_scale is (width, height) | ||
border=(-img_scale[0] // 2, -img_scale[1] // 2), | ||
border_val=(114, 114, 114), | ||
max_shear_degree=0.0), | ||
dict(type='YOLOv5HSVRandomAug'), | ||
dict(type='mmdet.RandomFlip', prob=0.5), | ||
dict( | ||
type='mmdet.PackDetInputs', | ||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip', | ||
'flip_direction')) | ||
] | ||
|
||
train_pipeline_stage2 = [ | ||
*pre_transform, | ||
dict(type='YOLOv5KeepRatioResize', scale=img_scale), | ||
dict( | ||
type='LetterResize', | ||
scale=img_scale, | ||
allow_scale_up=True, | ||
pad_val=dict(img=114)), | ||
dict( | ||
type='YOLOv5RandomAffine', | ||
max_rotate_degree=0.0, | ||
max_translate_ratio=0.1, | ||
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale), | ||
max_shear_degree=0.0, | ||
), | ||
dict(type='YOLOv5HSVRandomAug'), | ||
dict(type='mmdet.RandomFlip', prob=0.5), | ||
dict( | ||
type='mmdet.PackDetInputs', | ||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip', | ||
'flip_direction')) | ||
] | ||
|
||
train_dataloader = dict( | ||
batch_size=2, | ||
num_workers=2, | ||
persistent_workers=True, | ||
drop_last=False, | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
batch_sampler=dict(type='AspectRatioBatchSampler'), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
ann_file='train.txt', | ||
data_prefix=dict(img='WIDER_train'), | ||
filter_cfg=dict(filter_empty_gt=True, bbox_min_size=17, min_size=32), | ||
pipeline=train_pipeline)) | ||
|
||
test_pipeline = [ | ||
dict(type='LoadImageFromFile', backend_args=_base_.backend_args), | ||
dict(type='YOLOv5KeepRatioResize', scale=img_scale), | ||
dict( | ||
type='LetterResize', | ||
scale=img_scale, | ||
allow_scale_up=False, | ||
pad_val=dict(img=114)), | ||
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'), | ||
dict( | ||
type='mmdet.PackDetInputs', | ||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', | ||
'scale_factor', 'pad_param')) | ||
] | ||
|
||
val_dataloader = dict( | ||
batch_size=1, | ||
num_workers=2, | ||
persistent_workers=True, | ||
drop_last=False, | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
ann_file='val.txt', | ||
data_prefix=dict(img='WIDER_val'), | ||
test_mode=True, | ||
pipeline=test_pipeline)) | ||
|
||
test_dataloader = val_dataloader | ||
|
||
# Optimizer and learning rate scheduler of YOLOv6 are basically the same as YOLOv5. # noqa | ||
# The difference is that the scheduler_type of YOLOv6 is cosine. | ||
optim_wrapper = dict( | ||
type='OptimWrapper', | ||
optimizer=dict( | ||
type='SGD', | ||
lr=base_lr, | ||
momentum=0.937, | ||
weight_decay=weight_decay, | ||
nesterov=True, | ||
batch_size_per_gpu=train_batch_size_per_gpu), | ||
constructor='YOLOv5OptimizerConstructor') | ||
|
||
default_hooks = dict( | ||
param_scheduler=dict( | ||
type='YOLOv5ParamSchedulerHook', | ||
scheduler_type='cosine', | ||
lr_factor=lr_factor, | ||
max_epochs=max_epochs), | ||
checkpoint=dict( | ||
type='CheckpointHook', | ||
interval=save_epoch_intervals, | ||
max_keep_ckpts=max_keep_ckpts, | ||
save_best='auto')) | ||
|
||
custom_hooks = [ | ||
dict( | ||
type='EMAHook', | ||
ema_type='ExpMomentumEMA', | ||
momentum=0.0001, | ||
update_buffers=True, | ||
strict_load=False, | ||
priority=49), | ||
dict( | ||
type='mmdet.PipelineSwitchHook', | ||
switch_epoch=max_epochs - num_last_epochs, | ||
switch_pipeline=train_pipeline_stage2) | ||
] | ||
|
||
val_evaluator = dict( | ||
# TODO: support WiderFace-Evaluation for easy, medium, hard cases | ||
type='mmdet.VOCMetric', | ||
metric='mAP', | ||
eval_mode='11points') | ||
test_evaluator = val_evaluator | ||
|
||
train_cfg = dict( | ||
type='EpochBasedTrainLoop', | ||
max_epochs=max_epochs, | ||
val_interval=save_epoch_intervals, | ||
dynamic_intervals=[(max_epochs - num_last_epochs, 1)]) | ||
val_cfg = dict(type='ValLoop') | ||
test_cfg = dict(type='TestLoop') | ||
|
||
vis_backends = [dict(type='LocalVisBackend')] | ||
visualizer = dict( | ||
type='FaceVisualizer', vis_backends=vis_backends, name='visualizer') |
25 changes: 25 additions & 0 deletions
25
configs/yolov6/yolov6-face/yolov6_v3_s_syncbn_fast_8xb32-300e_widerface.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
_base_ = ['./yolov6_v3_n_syncbn_fast_8xb32-300e_widerface.py'] | ||
|
||
deepen_factor = 0.70 | ||
# The scaling factor that controls the width of the network structure | ||
widen_factor = 0.50 | ||
|
||
model = dict( | ||
backbone=dict( | ||
type='YOLOv6CSPBep', | ||
deepen_factor=deepen_factor, | ||
widen_factor=widen_factor, | ||
block_cfg=dict(type='RepVGGBlock'), | ||
hidden_ratio=0.5, | ||
act_cfg=dict(type='ReLU', inplace=True)), | ||
neck=dict( | ||
type='YOLOv6CSPRepBiPAFPN', | ||
deepen_factor=deepen_factor, | ||
widen_factor=widen_factor, | ||
block_cfg=dict(type='RepVGGBlock'), | ||
block_act_cfg=dict(type='ReLU', inplace=True), | ||
hidden_ratio=0.5), | ||
bbox_head=dict( | ||
type='YOLOv6FaceHead', | ||
head_module=dict(stemout_channels=256, widen_factor=widen_factor), | ||
loss_bbox=dict(type='IoULoss', iou_mode='giou'))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmdet.datasets import WIDERFaceDataset | ||
|
||
from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset | ||
from ..registry import DATASETS | ||
|
||
|
||
@DATASETS.register_module() | ||
class YOLOv6WIDERFaceDataset(BatchShapePolicyDataset, WIDERFaceDataset): | ||
"""Dataset for YOLOv6 WIDERFace Dataset.""" | ||
pass |
Oops, something went wrong.