Skip to content

Commit

Permalink
[FEATURE] Support YOLOv6 v3.0 face detection (#812)
Browse files Browse the repository at this point in the history
* support the usage of WIDERFace dataset

* add YOLOv6FaceHead

* add YOLOv6 face detection configs

* add a checkpoint convertion script

* add a face visualizer

* fix a bug of YOLOv6CSPBep initialization
  • Loading branch information
Qingrenn authored Aug 23, 2023
1 parent 8c4d9dc commit 6b52f81
Show file tree
Hide file tree
Showing 13 changed files with 1,050 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
_base_ = './yolov6_v3_m_syncbn_fast_8xb32-300e_widerface.py'

# ======================= Possible modified parameters =======================
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 1
# The scaling factor that controls the width of the network structure
widen_factor = 1

# ============================== Unmodified in most cases ===================
model = dict(
backbone=dict(
use_cspsppf=False,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
block_cfg=dict(
type='ConvWrapper',
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
act_cfg=dict(type='SiLU', inplace=True)),
neck=dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
block_cfg=dict(
type='ConvWrapper',
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
block_act_cfg=dict(type='SiLU', inplace=True)),
bbox_head=dict(head_module=dict(reg_max=16, widen_factor=widen_factor)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
_base_ = './yolov6_v3_s_syncbn_fast_8xb32-300e_widerface.py'

# ======================= Possible modified parameters =======================
# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.67
# The scaling factor that controls the width of the network structure
widen_factor = 0.75

# -----train val related-----
affine_scale = 0.9 # YOLOv5RandomAffine scaling ratio

# ============================== Unmodified in most cases ===================
model = dict(
backbone=dict(
use_cspsppf=False,
deepen_factor=deepen_factor,
widen_factor=widen_factor),
neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
bbox_head=dict(head_module=dict(reg_max=16, widen_factor=widen_factor)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
_base_ = ['../../_base_/default_runtime.py', '../../_base_/det_p5_tta.py']

# ======================= Frequently modified parameters =====================
# -----data related-----
data_root = 'data/WIDERFace/' # Root path of data

num_classes = 1 # Number of classes for classification
# Batch size of a single GPU during training
train_batch_size_per_gpu = 32
# Worker to pre-fetch data for each single GPU during training
train_num_workers = 8
# persistent_workers must be False if num_workers is 0
persistent_workers = True

# -----train val related-----
# Base learning rate for optim_wrapper
base_lr = 0.01
max_epochs = 300 # Maximum training epochs
num_last_epochs = 15 # Last epoch number to switch training pipeline

# ======================= Possible modified parameters =======================
# -----data related-----
img_scale = (640, 640) # width, height
# Dataset type, this will be used to define the dataset
dataset_type = 'YOLOv6WIDERFaceDataset'
# Batch size of a single GPU during validation
val_batch_size_per_gpu = 1
# Worker to pre-fetch data for each single GPU during validation
val_num_workers = 2

# Config of batch shapes. Only on val.
# It means not used if batch_shapes_cfg is None.
batch_shapes_cfg = dict(
type='BatchShapePolicy',
batch_size=val_batch_size_per_gpu,
img_size=img_scale[0],
size_divisor=32,
extra_pad_ratio=0.5)

# -----model related-----
# The scaling factor that controls the depth of the network structure
deepen_factor = 0.33
# The scaling factor that controls the width of the network structure
widen_factor = 0.25

# -----train val related-----
affine_scale = 0.5 # YOLOv5RandomAffine scaling ratio
lr_factor = 0.01 # Learning rate scaling factor
weight_decay = 0.0005
# Save model checkpoint and validation intervals
save_epoch_intervals = 10
# The maximum checkpoints to keep.
max_keep_ckpts = 3
# Single-scale training is recommended to
# be turned on, which can speed up training.
env_cfg = dict(cudnn_benchmark=True)

# ============================== Unmodified in most cases ===================
model = dict(
type='YOLODetector',
data_preprocessor=dict(
type='YOLOv5DetDataPreprocessor',
mean=[0., 0., 0.],
std=[255., 255., 255.],
bgr_to_rgb=True),
backbone=dict(
type='YOLOv6EfficientRep',
out_indices=[1, 2, 3, 4],
use_cspsppf=True,
deepen_factor=deepen_factor,
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='ReLU', inplace=True)),
neck=dict(
type='YOLOv6RepBiPAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[128, 256, 512, 1024],
out_channels=[128, 256, 512],
num_csp_blocks=12,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='ReLU', inplace=True),
),
bbox_head=dict(
type='YOLOv6FaceHead',
head_module=dict(
type='YOLOv6FaceHeadModule',
num_classes=num_classes,
in_channels=[128, 256, 512],
stemout_channels=[128, 256, 512],
widen_factor=widen_factor,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True),
featmap_strides=[8, 16, 32]),
loss_bbox=dict(
type='IoULoss',
iou_mode='siou',
bbox_format='xyxy',
reduction='mean',
loss_weight=2.5,
return_iou=False)),
train_cfg=dict(
initial_epoch=4,
initial_assigner=dict(
type='BatchATSSAssigner',
num_classes=num_classes,
topk=9,
iou_calculator=dict(type='mmdet.BboxOverlaps2D')),
assigner=dict(
type='BatchTaskAlignedAssigner',
num_classes=num_classes,
topk=13,
alpha=1,
beta=6),
),
test_cfg=dict(
multi_label=True,
nms_pre=30000,
score_thr=0.4,
nms=dict(type='nms', iou_threshold=0.45),
max_per_img=1000))

# The training pipeline of YOLOv6 is basically the same as YOLOv5.
# The difference is that Mosaic and RandomAffine will be closed in the last 15 epochs. # noqa
pre_transform = [
dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
dict(type='LoadAnnotations', with_bbox=True)
]

train_pipeline = [
*pre_transform,
dict(
type='Mosaic',
img_scale=img_scale,
pad_val=114.0,
pre_transform=pre_transform),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_translate_ratio=0.1,
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
# img_scale is (width, height)
border=(-img_scale[0] // 2, -img_scale[1] // 2),
border_val=(114, 114, 114),
max_shear_degree=0.0),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]

train_pipeline_stage2 = [
*pre_transform,
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=True,
pad_val=dict(img=114)),
dict(
type='YOLOv5RandomAffine',
max_rotate_degree=0.0,
max_translate_ratio=0.1,
scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
max_shear_degree=0.0,
),
dict(type='YOLOv5HSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
'flip_direction'))
]

train_dataloader = dict(
batch_size=2,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=True),
batch_sampler=dict(type='AspectRatioBatchSampler'),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='train.txt',
data_prefix=dict(img='WIDER_train'),
filter_cfg=dict(filter_empty_gt=True, bbox_min_size=17, min_size=32),
pipeline=train_pipeline))

test_pipeline = [
dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]

val_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='val.txt',
data_prefix=dict(img='WIDER_val'),
test_mode=True,
pipeline=test_pipeline))

test_dataloader = val_dataloader

# Optimizer and learning rate scheduler of YOLOv6 are basically the same as YOLOv5. # noqa
# The difference is that the scheduler_type of YOLOv6 is cosine.
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(
type='SGD',
lr=base_lr,
momentum=0.937,
weight_decay=weight_decay,
nesterov=True,
batch_size_per_gpu=train_batch_size_per_gpu),
constructor='YOLOv5OptimizerConstructor')

default_hooks = dict(
param_scheduler=dict(
type='YOLOv5ParamSchedulerHook',
scheduler_type='cosine',
lr_factor=lr_factor,
max_epochs=max_epochs),
checkpoint=dict(
type='CheckpointHook',
interval=save_epoch_intervals,
max_keep_ckpts=max_keep_ckpts,
save_best='auto'))

custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0001,
update_buffers=True,
strict_load=False,
priority=49),
dict(
type='mmdet.PipelineSwitchHook',
switch_epoch=max_epochs - num_last_epochs,
switch_pipeline=train_pipeline_stage2)
]

val_evaluator = dict(
# TODO: support WiderFace-Evaluation for easy, medium, hard cases
type='mmdet.VOCMetric',
metric='mAP',
eval_mode='11points')
test_evaluator = val_evaluator

train_cfg = dict(
type='EpochBasedTrainLoop',
max_epochs=max_epochs,
val_interval=save_epoch_intervals,
dynamic_intervals=[(max_epochs - num_last_epochs, 1)])
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='FaceVisualizer', vis_backends=vis_backends, name='visualizer')
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
_base_ = ['./yolov6_v3_n_syncbn_fast_8xb32-300e_widerface.py']

deepen_factor = 0.70
# The scaling factor that controls the width of the network structure
widen_factor = 0.50

model = dict(
backbone=dict(
type='YOLOv6CSPBep',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
block_cfg=dict(type='RepVGGBlock'),
hidden_ratio=0.5,
act_cfg=dict(type='ReLU', inplace=True)),
neck=dict(
type='YOLOv6CSPRepBiPAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
block_cfg=dict(type='RepVGGBlock'),
block_act_cfg=dict(type='ReLU', inplace=True),
hidden_ratio=0.5),
bbox_head=dict(
type='YOLOv6FaceHead',
head_module=dict(stemout_channels=256, widen_factor=widen_factor),
loss_bbox=dict(type='IoULoss', iou_mode='giou')))
3 changes: 2 additions & 1 deletion mmyolo/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from .yolov5_crowdhuman import YOLOv5CrowdHumanDataset
from .yolov5_dota import YOLOv5DOTADataset
from .yolov5_voc import YOLOv5VOCDataset
from .yolov6_widerface import YOLOv6WIDERFaceDataset

__all__ = [
'YOLOv5CocoDataset', 'YOLOv5VOCDataset', 'BatchShapePolicy',
'yolov5_collate', 'YOLOv5CrowdHumanDataset', 'YOLOv5DOTADataset',
'PoseCocoDataset'
'PoseCocoDataset', 'YOLOv6WIDERFaceDataset'
]
11 changes: 11 additions & 0 deletions mmyolo/datasets/yolov6_widerface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.datasets import WIDERFaceDataset

from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
from ..registry import DATASETS


@DATASETS.register_module()
class YOLOv6WIDERFaceDataset(BatchShapePolicyDataset, WIDERFaceDataset):
"""Dataset for YOLOv6 WIDERFace Dataset."""
pass
Loading

0 comments on commit 6b52f81

Please sign in to comment.