From 6b52f81c07999bce2c6b2485c6e969ea8491be4d Mon Sep 17 00:00:00 2001
From: Qingren
Date: Wed, 23 Aug 2023 15:17:25 +0800
Subject: [PATCH] [FEATURE] Support YOLOv6 v3.0 face detection (#812)

* support the usage of WIDERFace dataset
* add YOLOv6FaceHead
* add YOLOv6 face detection configs
* add a checkpoint conversion script
* add a face visualizer
* fix a bug of YOLOv6CSPBep initialization
---
 ...6_v3_l_syncbn_fast_8xb32-300e_widerface.py |  27 ++
 ...6_v3_m_syncbn_fast_8xb32-300e_widerface.py |  20 +
 ...6_v3_n_syncbn_fast_8xb32-300e_widerface.py | 279 +++++++++++
 ...6_v3_s_syncbn_fast_8xb32-300e_widerface.py |  25 +
 mmyolo/datasets/__init__.py                   |   3 +-
 mmyolo/datasets/yolov6_widerface.py           |  11 +
 mmyolo/models/backbones/efficient_rep.py      |   8 +-
 mmyolo/models/dense_heads/__init__.py         |  10 +-
 mmyolo/models/dense_heads/yolov6_face_head.py | 434 ++++++++++++++++++
 mmyolo/utils/__init__.py                      |   3 +-
 mmyolo/utils/face_visualizer.py               |  45 ++
 .../test_dense_heads/test_yolov6-face_head.py |  65 +++
 .../yolov6-face_v3_to_mmyolo.py               | 127 +++++
 13 files changed, 1050 insertions(+), 7 deletions(-)
 create mode 100644 configs/yolov6/yolov6-face/yolov6_v3_l_syncbn_fast_8xb32-300e_widerface.py
 create mode 100644 configs/yolov6/yolov6-face/yolov6_v3_m_syncbn_fast_8xb32-300e_widerface.py
 create mode 100644 configs/yolov6/yolov6-face/yolov6_v3_n_syncbn_fast_8xb32-300e_widerface.py
 create mode 100644 configs/yolov6/yolov6-face/yolov6_v3_s_syncbn_fast_8xb32-300e_widerface.py
 create mode 100644 mmyolo/datasets/yolov6_widerface.py
 create mode 100644 mmyolo/models/dense_heads/yolov6_face_head.py
 create mode 100644 mmyolo/utils/face_visualizer.py
 create mode 100644 tests/test_models/test_dense_heads/test_yolov6-face_head.py
 create mode 100644 tools/model_converters/yolov6-face_v3_to_mmyolo.py

diff --git a/configs/yolov6/yolov6-face/yolov6_v3_l_syncbn_fast_8xb32-300e_widerface.py b/configs/yolov6/yolov6-face/yolov6_v3_l_syncbn_fast_8xb32-300e_widerface.py
new file mode 100644
index 000000000..6e69c0a3e
--- /dev/null
+++ b/configs/yolov6/yolov6-face/yolov6_v3_l_syncbn_fast_8xb32-300e_widerface.py
@@ -0,0 +1,27 @@
+_base_ = './yolov6_v3_m_syncbn_fast_8xb32-300e_widerface.py'
+
+# ======================= Possible modified parameters =======================
+# -----model related-----
+# The scaling factor that controls the depth of the network structure
+deepen_factor = 1
+# The scaling factor that controls the width of the network structure
+widen_factor = 1
+
+# ============================== Unmodified in most cases ===================
+model = dict(
+    backbone=dict(
+        use_cspsppf=False,
+        deepen_factor=deepen_factor,
+        widen_factor=widen_factor,
+        block_cfg=dict(
+            type='ConvWrapper',
+            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
+        act_cfg=dict(type='SiLU', inplace=True)),
+    neck=dict(
+        deepen_factor=deepen_factor,
+        widen_factor=widen_factor,
+        block_cfg=dict(
+            type='ConvWrapper',
+            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001)),
+        block_act_cfg=dict(type='SiLU', inplace=True)),
+    bbox_head=dict(head_module=dict(reg_max=16, widen_factor=widen_factor)))
diff --git a/configs/yolov6/yolov6-face/yolov6_v3_m_syncbn_fast_8xb32-300e_widerface.py b/configs/yolov6/yolov6-face/yolov6_v3_m_syncbn_fast_8xb32-300e_widerface.py
new file mode 100644
index 000000000..cd7260e57
--- /dev/null
+++ b/configs/yolov6/yolov6-face/yolov6_v3_m_syncbn_fast_8xb32-300e_widerface.py
@@ -0,0 +1,20 @@
+_base_ = './yolov6_v3_s_syncbn_fast_8xb32-300e_widerface.py'
+
+# ======================= Possible modified parameters =======================
+# -----model related-----
+# The scaling factor that controls the depth of the network structure
+deepen_factor = 0.67
+# The scaling factor that controls the width of the network structure
+widen_factor = 0.75
+
+# -----train val related-----
+# YOLOv5RandomAffine scaling ratio. Note: for this value to take effect,
+# train_pipeline must be re-declared in this file; the base config has
+# already baked affine_scale = 0.5 into its pipeline at parse time.
+affine_scale = 0.9
+
+# ============================== Unmodified in most cases ===================
+model = dict(
+    backbone=dict(
+        use_cspsppf=False,
+        deepen_factor=deepen_factor,
+        widen_factor=widen_factor),
+    neck=dict(deepen_factor=deepen_factor, widen_factor=widen_factor),
+    bbox_head=dict(head_module=dict(reg_max=16, widen_factor=widen_factor)))
diff --git a/configs/yolov6/yolov6-face/yolov6_v3_n_syncbn_fast_8xb32-300e_widerface.py b/configs/yolov6/yolov6-face/yolov6_v3_n_syncbn_fast_8xb32-300e_widerface.py
new file mode 100644
index 000000000..6e3ed18a2
--- /dev/null
+++ b/configs/yolov6/yolov6-face/yolov6_v3_n_syncbn_fast_8xb32-300e_widerface.py
@@ -0,0 +1,279 @@
+_base_ = ['../../_base_/default_runtime.py', '../../_base_/det_p5_tta.py']
+
+# ======================= Frequently modified parameters =====================
+# -----data related-----
+data_root = 'data/WIDERFace/'  # Root path of data
+
+num_classes = 1  # Number of classes for classification
+# Batch size of a single GPU during training
+train_batch_size_per_gpu = 32
+# Worker to pre-fetch data for each single GPU during training
+train_num_workers = 8
+# persistent_workers must be False if num_workers is 0
+persistent_workers = True
+
+# -----train val related-----
+# Base learning rate for optim_wrapper
+base_lr = 0.01
+max_epochs = 300  # Maximum training epochs
+num_last_epochs = 15  # Last epoch number to switch training pipeline
+
+# ======================= Possible modified parameters =======================
+# -----data related-----
+img_scale = (640, 640)  # width, height
+# Dataset type, this will be used to define the dataset
+dataset_type = 'YOLOv6WIDERFaceDataset'
+# Batch size of a single GPU during validation
+val_batch_size_per_gpu = 1
+# Worker to pre-fetch data for each single GPU during validation
+val_num_workers = 2
+
+# Config of batch shapes. Only used during val;
+# not used if batch_shapes_cfg is None.
+batch_shapes_cfg = dict(
+    type='BatchShapePolicy',
+    batch_size=val_batch_size_per_gpu,
+    img_size=img_scale[0],
+    size_divisor=32,
+    extra_pad_ratio=0.5)
+
+# -----model related-----
+# The scaling factor that controls the depth of the network structure
+deepen_factor = 0.33
+# The scaling factor that controls the width of the network structure
+widen_factor = 0.25
+
+# -----train val related-----
+affine_scale = 0.5  # YOLOv5RandomAffine scaling ratio
+lr_factor = 0.01  # Learning rate scaling factor
+weight_decay = 0.0005
+# Save model checkpoint and validation intervals
+save_epoch_intervals = 10
+# The maximum checkpoints to keep.
+max_keep_ckpts = 3
+# Single-scale training is recommended to
+# be turned on, which can speed up training.
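+# Editor's note: a worked example of how the two scaling factors above act,
+# assuming mmyolo's usual make_divisible / make_round helpers (illustrative
+# numbers, not values read from this patch):
+#   channels: 512 * widen_factor 0.25 -> 128 (kept divisible by 8)
+#   repeats:  num_csp_blocks 12 * deepen_factor 0.33 -> round(3.96) = 4
+# The fixed 640x640 single-scale input is also what makes the
+# cudnn_benchmark flag below effective.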
+env_cfg = dict(cudnn_benchmark=True) + +# ============================== Unmodified in most cases =================== +model = dict( + type='YOLODetector', + data_preprocessor=dict( + type='YOLOv5DetDataPreprocessor', + mean=[0., 0., 0.], + std=[255., 255., 255.], + bgr_to_rgb=True), + backbone=dict( + type='YOLOv6EfficientRep', + out_indices=[1, 2, 3, 4], + use_cspsppf=True, + deepen_factor=deepen_factor, + widen_factor=widen_factor, + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='ReLU', inplace=True)), + neck=dict( + type='YOLOv6RepBiPAFPN', + deepen_factor=deepen_factor, + widen_factor=widen_factor, + in_channels=[128, 256, 512, 1024], + out_channels=[128, 256, 512], + num_csp_blocks=12, + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='ReLU', inplace=True), + ), + bbox_head=dict( + type='YOLOv6FaceHead', + head_module=dict( + type='YOLOv6FaceHeadModule', + num_classes=num_classes, + in_channels=[128, 256, 512], + stemout_channels=[128, 256, 512], + widen_factor=widen_factor, + norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + act_cfg=dict(type='SiLU', inplace=True), + featmap_strides=[8, 16, 32]), + loss_bbox=dict( + type='IoULoss', + iou_mode='siou', + bbox_format='xyxy', + reduction='mean', + loss_weight=2.5, + return_iou=False)), + train_cfg=dict( + initial_epoch=4, + initial_assigner=dict( + type='BatchATSSAssigner', + num_classes=num_classes, + topk=9, + iou_calculator=dict(type='mmdet.BboxOverlaps2D')), + assigner=dict( + type='BatchTaskAlignedAssigner', + num_classes=num_classes, + topk=13, + alpha=1, + beta=6), + ), + test_cfg=dict( + multi_label=True, + nms_pre=30000, + score_thr=0.4, + nms=dict(type='nms', iou_threshold=0.45), + max_per_img=1000)) + +# The training pipeline of YOLOv6 is basically the same as YOLOv5. +# The difference is that Mosaic and RandomAffine will be closed in the last 15 epochs. 
# noqa
+pre_transform = [
+    dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
+    dict(type='LoadAnnotations', with_bbox=True)
+]
+
+train_pipeline = [
+    *pre_transform,
+    dict(
+        type='Mosaic',
+        img_scale=img_scale,
+        pad_val=114.0,
+        pre_transform=pre_transform),
+    dict(
+        type='YOLOv5RandomAffine',
+        max_rotate_degree=0.0,
+        max_translate_ratio=0.1,
+        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
+        # img_scale is (width, height)
+        border=(-img_scale[0] // 2, -img_scale[1] // 2),
+        border_val=(114, 114, 114),
+        max_shear_degree=0.0),
+    dict(type='YOLOv5HSVRandomAug'),
+    dict(type='mmdet.RandomFlip', prob=0.5),
+    dict(
+        type='mmdet.PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
+                   'flip_direction'))
+]
+
+train_pipeline_stage2 = [
+    *pre_transform,
+    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
+    dict(
+        type='LetterResize',
+        scale=img_scale,
+        allow_scale_up=True,
+        pad_val=dict(img=114)),
+    dict(
+        type='YOLOv5RandomAffine',
+        max_rotate_degree=0.0,
+        max_translate_ratio=0.1,
+        scaling_ratio_range=(1 - affine_scale, 1 + affine_scale),
+        max_shear_degree=0.0,
+    ),
+    dict(type='YOLOv5HSVRandomAug'),
+    dict(type='mmdet.RandomFlip', prob=0.5),
+    dict(
+        type='mmdet.PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'flip',
+                   'flip_direction'))
+]
+
+train_dataloader = dict(
+    batch_size=train_batch_size_per_gpu,
+    num_workers=train_num_workers,
+    persistent_workers=persistent_workers,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    batch_sampler=dict(type='AspectRatioBatchSampler'),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='train.txt',
+        data_prefix=dict(img='WIDER_train'),
+        filter_cfg=dict(filter_empty_gt=True, bbox_min_size=17, min_size=32),
+        pipeline=train_pipeline))
+
+test_pipeline = [
+    dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
+    dict(type='YOLOv5KeepRatioResize', scale=img_scale),
+    dict(
+        type='LetterResize',
+        scale=img_scale,
+        allow_scale_up=False,
+        pad_val=dict(img=114)),
+    dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
+    dict(
+        type='mmdet.PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor', 'pad_param'))
+]
+
+val_dataloader = dict(
+    batch_size=val_batch_size_per_gpu,
+    num_workers=val_num_workers,
+    persistent_workers=persistent_workers,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='val.txt',
+        data_prefix=dict(img='WIDER_val'),
+        test_mode=True,
+        batch_shapes_cfg=batch_shapes_cfg,
+        pipeline=test_pipeline))
+
+test_dataloader = val_dataloader
+
+# Optimizer and learning rate scheduler of YOLOv6 are basically the same as YOLOv5. # noqa
+# The difference is that the scheduler_type of YOLOv6 is cosine.
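+# Editor's sketch of the cosine law applied by YOLOv5ParamSchedulerHook when
+# scheduler_type='cosine' (assuming the YOLOv5-style one-cycle formula):
+#
+#   import math
+#   def lr_lambda(epoch):  # per-epoch LR multiplier
+#       return ((1 - math.cos(epoch * math.pi / max_epochs)) / 2) * \
+#           (lr_factor - 1) + 1
+#
+# With base_lr = 0.01 and lr_factor = 0.01, the learning rate anneals from
+# 0.01 at epoch 0 towards roughly 1e-4 by epoch 300.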
+optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='SGD', + lr=base_lr, + momentum=0.937, + weight_decay=weight_decay, + nesterov=True, + batch_size_per_gpu=train_batch_size_per_gpu), + constructor='YOLOv5OptimizerConstructor') + +default_hooks = dict( + param_scheduler=dict( + type='YOLOv5ParamSchedulerHook', + scheduler_type='cosine', + lr_factor=lr_factor, + max_epochs=max_epochs), + checkpoint=dict( + type='CheckpointHook', + interval=save_epoch_intervals, + max_keep_ckpts=max_keep_ckpts, + save_best='auto')) + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0001, + update_buffers=True, + strict_load=False, + priority=49), + dict( + type='mmdet.PipelineSwitchHook', + switch_epoch=max_epochs - num_last_epochs, + switch_pipeline=train_pipeline_stage2) +] + +val_evaluator = dict( + # TODO: support WiderFace-Evaluation for easy, medium, hard cases + type='mmdet.VOCMetric', + metric='mAP', + eval_mode='11points') +test_evaluator = val_evaluator + +train_cfg = dict( + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_interval=save_epoch_intervals, + dynamic_intervals=[(max_epochs - num_last_epochs, 1)]) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='FaceVisualizer', vis_backends=vis_backends, name='visualizer') diff --git a/configs/yolov6/yolov6-face/yolov6_v3_s_syncbn_fast_8xb32-300e_widerface.py b/configs/yolov6/yolov6-face/yolov6_v3_s_syncbn_fast_8xb32-300e_widerface.py new file mode 100644 index 000000000..daf2abfaa --- /dev/null +++ b/configs/yolov6/yolov6-face/yolov6_v3_s_syncbn_fast_8xb32-300e_widerface.py @@ -0,0 +1,25 @@ +_base_ = ['./yolov6_v3_n_syncbn_fast_8xb32-300e_widerface.py'] + +deepen_factor = 0.70 +# The scaling factor that controls the width of the network structure +widen_factor = 0.50 + +model = dict( + backbone=dict( + type='YOLOv6CSPBep', + deepen_factor=deepen_factor, + widen_factor=widen_factor, + block_cfg=dict(type='RepVGGBlock'), + hidden_ratio=0.5, + act_cfg=dict(type='ReLU', inplace=True)), + neck=dict( + type='YOLOv6CSPRepBiPAFPN', + deepen_factor=deepen_factor, + widen_factor=widen_factor, + block_cfg=dict(type='RepVGGBlock'), + block_act_cfg=dict(type='ReLU', inplace=True), + hidden_ratio=0.5), + bbox_head=dict( + type='YOLOv6FaceHead', + head_module=dict(stemout_channels=256, widen_factor=widen_factor), + loss_bbox=dict(type='IoULoss', iou_mode='giou'))) diff --git a/mmyolo/datasets/__init__.py b/mmyolo/datasets/__init__.py index 9db439045..25ecb9a0d 100644 --- a/mmyolo/datasets/__init__.py +++ b/mmyolo/datasets/__init__.py @@ -6,9 +6,10 @@ from .yolov5_crowdhuman import YOLOv5CrowdHumanDataset from .yolov5_dota import YOLOv5DOTADataset from .yolov5_voc import YOLOv5VOCDataset +from .yolov6_widerface import YOLOv6WIDERFaceDataset __all__ = [ 'YOLOv5CocoDataset', 'YOLOv5VOCDataset', 'BatchShapePolicy', 'yolov5_collate', 'YOLOv5CrowdHumanDataset', 'YOLOv5DOTADataset', - 'PoseCocoDataset' + 'PoseCocoDataset', 'YOLOv6WIDERFaceDataset' ] diff --git a/mmyolo/datasets/yolov6_widerface.py b/mmyolo/datasets/yolov6_widerface.py new file mode 100644 index 000000000..706bdca77 --- /dev/null +++ b/mmyolo/datasets/yolov6_widerface.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
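+# Editor's note: the class below only composes two existing behaviours:
+# BatchShapePolicyDataset contributes the aspect-ratio-grouped "batch
+# shapes" logic used at validation time (see batch_shapes_cfg in the n
+# config), and mmdet's WIDERFaceDataset supplies the annotation parsing.
+# A hypothetical instantiation, mirroring this patch's config defaults:
+#
+#   dataset = YOLOv6WIDERFaceDataset(
+#       data_root='data/WIDERFace/',
+#       ann_file='val.txt',
+#       data_prefix=dict(img='WIDER_val'),
+#       test_mode=True,
+#       pipeline=[])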
+from mmdet.datasets import WIDERFaceDataset
+
+from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset
+from ..registry import DATASETS
+
+
+@DATASETS.register_module()
+class YOLOv6WIDERFaceDataset(BatchShapePolicyDataset, WIDERFaceDataset):
+    """WIDERFace dataset for YOLOv6 face detection."""
+    pass
diff --git a/mmyolo/models/backbones/efficient_rep.py b/mmyolo/models/backbones/efficient_rep.py
index 32e455f06..ff8fdff11 100644
--- a/mmyolo/models/backbones/efficient_rep.py
+++ b/mmyolo/models/backbones/efficient_rep.py
@@ -32,6 +32,9 @@ class YOLOv6EfficientRep(BaseBackbone):
         Defaults to (2, 3, 4).
     frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
         -1 means not freezing any parameters. Defaults to -1.
+    use_cspsppf (bool): Whether to use CSPSPPFBottleneck. It is only valid
+        when `use_spp`=True, i.e. it may be used in the last stage of the
+        backbone. Defaults to False.
     norm_cfg (dict): Dictionary to construct and config norm layer.
         Defaults to dict(type='BN', requires_grad=True).
     act_cfg (dict): Config dict for activation layer.
@@ -188,6 +191,9 @@ class YOLOv6CSPBep(YOLOv6EfficientRep):
         Defaults to (2, 3, 4).
     frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
         -1 means not freezing any parameters. Defaults to -1.
+    use_cspsppf (bool): Whether to use CSPSPPFBottleneck. It is only valid
+        when `use_spp`=True, i.e. it may be used in the last stage of the
+        backbone. Defaults to False.
     norm_cfg (dict): Dictionary to construct and config norm layer.
         Defaults to dict(type='BN', requires_grad=True).
     act_cfg (dict): Config dict for activation layer.
@@ -239,7 +245,6 @@ def __init__(self,
                  block_cfg: ConfigType = dict(type='ConvWrapper'),
                  init_cfg: OptMultiConfig = None):
         self.hidden_ratio = hidden_ratio
-        self.use_cspsppf = use_cspsppf
         super().__init__(
             arch=arch,
             deepen_factor=deepen_factor,
@@ -248,6 +253,7 @@ def __init__(self,
             out_indices=out_indices,
             plugins=plugins,
             frozen_stages=frozen_stages,
+            use_cspsppf=use_cspsppf,
             norm_cfg=norm_cfg,
             act_cfg=act_cfg,
             norm_eval=norm_eval,
diff --git a/mmyolo/models/dense_heads/__init__.py b/mmyolo/models/dense_heads/__init__.py
index 90587c3fb..d60d2a9bb 100644
--- a/mmyolo/models/dense_heads/__init__.py
+++ b/mmyolo/models/dense_heads/__init__.py
@@ -6,6 +6,7 @@
                                 RTMDetRotatedSepBNHeadModule)
 from .yolov5_head import YOLOv5Head, YOLOv5HeadModule
 from .yolov5_ins_head import YOLOv5InsHead, YOLOv5InsHeadModule
+from .yolov6_face_head import YOLOv6FaceHead, YOLOv6FaceHeadModule
 from .yolov6_head import YOLOv6Head, YOLOv6HeadModule
 from .yolov7_head import YOLOv7Head, YOLOv7HeadModule, YOLOv7p6HeadModule
 from .yolov8_head import YOLOv8Head, YOLOv8HeadModule
@@ -13,10 +14,11 @@
 from .yolox_pose_head import YOLOXPoseHead, YOLOXPoseHeadModule
 
 __all__ = [
-    'YOLOv5Head', 'YOLOv6Head', 'YOLOXHead', 'YOLOv5HeadModule',
-    'YOLOv6HeadModule', 'YOLOXHeadModule', 'RTMDetHead',
-    'RTMDetSepBNHeadModule', 'YOLOv7Head', 'PPYOLOEHead', 'PPYOLOEHeadModule',
-    'YOLOv7HeadModule', 'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule',
+    'YOLOv5Head', 'YOLOv6Head', 'YOLOv6FaceHead', 'YOLOXHead',
+    'YOLOv5HeadModule', 'YOLOv6HeadModule', 'YOLOv6FaceHeadModule',
+    'YOLOXHeadModule', 'RTMDetHead', 'RTMDetSepBNHeadModule', 'YOLOv7Head',
+    'PPYOLOEHead', 'PPYOLOEHeadModule', 'YOLOv7HeadModule',
+    'YOLOv7p6HeadModule', 'YOLOv8Head', 'YOLOv8HeadModule',
     'RTMDetRotatedHead', 'RTMDetRotatedSepBNHeadModule', 'RTMDetInsSepBNHead',
     'RTMDetInsSepBNHeadModule', 'YOLOv5InsHead', 'YOLOv5InsHeadModule',
     'YOLOXPoseHead', 'YOLOXPoseHeadModule'
 ]
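Editor's note: the new head file below implements two decode steps that are easier to review with the arithmetic spelled out. The sketch beneath illustrates that math under this patch's defaults (reg_max=16 for the s/m/l configs, point priors already in image units); it is an illustration, not code from the patch.

    import torch

    # DFL-style bbox decode (reg_max > 1): softmax over the reg_max + 1 bins
    # of each box side, then the expectation gives the side distance.
    proj = torch.arange(17, dtype=torch.float)    # reg_max 16 + 1 base prior
    dist = torch.randn(1, 4, 17).softmax(-1)      # (batch, 4 sides, bins)
    side = dist.matmul(proj)                      # expected distances, (1, 4)

    # Keypoint decode (predict_by_feat): offsets are in feature-grid units,
    # so they are scaled by the stride and shifted to the prior location.
    prior = torch.tensor([84., 52.])              # grid point, image units
    offsets = torch.randn(10)                     # 5 x (dx, dy)
    keypoints = offsets * 8. + prior.repeat(5)    # stride-8 level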
diff --git a/mmyolo/models/dense_heads/yolov6_face_head.py b/mmyolo/models/dense_heads/yolov6_face_head.py
new file mode 100644
index 000000000..ef8a1c52a
--- /dev/null
+++ b/mmyolo/models/dense_heads/yolov6_face_head.py
@@ -0,0 +1,434 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmdet.models.utils import filter_scores_and_topk
+from mmdet.utils import (ConfigType, OptConfigType, OptInstanceList,
+                         OptMultiConfig)
+from mmengine.config import ConfigDict
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from mmyolo.registry import MODELS
+from .yolov6_head import YOLOv6Head, YOLOv6HeadModule
+
+
+@MODELS.register_module()
+class YOLOv6FaceHeadModule(YOLOv6HeadModule):
+    """YOLOv6FaceHead head module used in `YOLOv6
+    <https://arxiv.org/abs/2209.02976>`_.
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (Union[int, Sequence]): Number of channels in the input
+            feature map.
+        stemout_channels (Union[int, Sequence]): Number of channels of the
+            feature map output by the stem module. Defaults to None, which
+            means reusing ``in_channels``.
+        widen_factor (float): Width multiplier, multiply number of
+            channels in each layer by this amount. Defaults to 1.0.
+        num_base_priors (int): The number of priors (points) at a point
+            on the feature grid.
+        featmap_strides (Sequence[int]): Downsample factor of each feature
+            map. Defaults to [8, 16, 32].
+        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
+            layer. Defaults to dict(type='BN', momentum=0.03, eps=0.001).
+        act_cfg (:obj:`ConfigDict` or dict): Config dict for activation
+            layer. Defaults to dict(type='SiLU', inplace=True).
+        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
+            list[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 num_classes: int,
+                 in_channels: Union[int, Sequence],
+                 stemout_channels: Union[int, Sequence] = None,
+                 widen_factor: float = 1.0,
+                 num_base_priors: int = 1,
+                 reg_max: int = 0,
+                 featmap_strides: Sequence[int] = (8, 16, 32),
+                 norm_cfg: ConfigType = dict(
+                     type='BN', momentum=0.03, eps=0.001),
+                 act_cfg: ConfigType = dict(type='SiLU', inplace=True),
+                 init_cfg: OptMultiConfig = None):
+
+        # ``self.in_channels`` does not exist before ``super().__init__``
+        # runs, so the fallback must use the argument, not the attribute.
+        if stemout_channels is None:
+            stemout_channels = in_channels
+
+        if isinstance(stemout_channels, int):
+            num_levels = len(featmap_strides)
+            self.stemout_channels = [int(stemout_channels * widen_factor)
+                                     ] * num_levels
+        else:
+            self.stemout_channels = [
+                int(i * widen_factor) for i in stemout_channels
+            ]
+
+        super().__init__(
+            num_classes=num_classes,
+            in_channels=in_channels,
+            widen_factor=widen_factor,
+            num_base_priors=num_base_priors,
+            reg_max=reg_max,
+            featmap_strides=featmap_strides,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+            init_cfg=init_cfg)
+
+    def _init_layers(self):
+        """Initialize conv layers in the YOLOv6 face head."""
+        # Init decoupled head
+        self.cls_convs = nn.ModuleList()
+        self.reg_convs = nn.ModuleList()
+        self.cls_preds = nn.ModuleList()
+        self.reg_preds = nn.ModuleList()
+        self.stems = nn.ModuleList()
+
+        if self.reg_max > 1:
+            proj = torch.arange(
+                self.reg_max + self.num_base_priors, dtype=torch.float)
+            self.register_buffer('proj', proj, persistent=False)
+
+        for i in range(self.num_levels):
+            self.stems.append(
+                ConvModule(
+                    in_channels=self.in_channels[i],
+                    out_channels=self.stemout_channels[i],
+                    kernel_size=1,
+                    stride=1,
+                    padding=1 // 2,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg))
+            self.cls_convs.append(
+                ConvModule(
+                    in_channels=self.stemout_channels[i],
+                    out_channels=self.stemout_channels[i],
+                    kernel_size=3,
+                    stride=1,
+                    padding=3 // 2,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg))
+            self.reg_convs.append(
+                ConvModule(
+                    in_channels=self.stemout_channels[i],
+                    out_channels=self.stemout_channels[i],
+                    kernel_size=3,
+                    stride=1,
+                    padding=3 // 2,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg))
+            self.cls_preds.append(
+                nn.Conv2d(
+                    in_channels=self.stemout_channels[i],
+                    out_channels=self.num_base_priors * self.num_classes,
+                    kernel_size=1))
+            # 4 box sides (x DFL bins when reg_max > 1) plus 10 channels
+            # for the 5 facial keypoints (an x and a y each).
+            self.reg_preds.append(
+                nn.Conv2d(
+                    in_channels=self.stemout_channels[i],
+                    out_channels=(self.num_base_priors + self.reg_max) * 4 +
+                    10,
+                    kernel_size=1))
+
+    def forward_single(self, x: Tensor, stem: nn.Module, cls_conv: nn.Module,
+                       cls_pred: nn.Module, reg_conv: nn.Module,
+                       reg_pred: nn.Module) -> Tuple[Tensor, Tensor]:
+        """Forward feature of a single scale level."""
+        b, _, h, w = x.shape
+        y = stem(x)
+        cls_x = y
+        reg_x = y
+        cls_feat = cls_conv(cls_x)
+        reg_feat = reg_conv(reg_x)
+
+        cls_score = cls_pred(cls_feat)
+        bbox_dist_preds = reg_pred(reg_feat)
+        # The last 10 channels of the regression branch are the keypoint
+        # offsets; the rest are the bbox distance distributions.
+        keypoint_preds = bbox_dist_preds[:, -10:, :, :]
+        bbox_dist_preds = bbox_dist_preds[:, :-10, :, :]
+
+        if self.reg_max > 1:
+            bbox_dist_preds = bbox_dist_preds.reshape(
+                [-1, 4, self.reg_max + self.num_base_priors,
+                 h * w]).permute(0, 3, 1, 2)
+
+            # TODO: The get_flops script cannot handle the situation of
+            #  matmul, and needs to be fixed later
+            # bbox_preds = bbox_dist_preds.softmax(3).matmul(self.proj)
+            bbox_preds = bbox_dist_preds.softmax(3).matmul(
+                self.proj.view([-1, 1])).squeeze(-1)
+            bbox_preds = bbox_preds.transpose(1, 2).reshape(b, -1, h, w)
+        else:
+            bbox_preds = bbox_dist_preds
+
+        if self.training:
+            return cls_score, bbox_preds, bbox_dist_preds, keypoint_preds
+        else:
+            return cls_score, bbox_preds, keypoint_preds
+
+
+@MODELS.register_module()
+class YOLOv6FaceHead(YOLOv6Head):
+    """YOLOv6FaceHead head used in `YOLOv6
+    <https://arxiv.org/abs/2209.02976>`_.
+
+    Args:
+        head_module (ConfigType): Base module used for YOLOv6FaceHead.
+        prior_generator (dict): Points generator feature maps
+            in 2D points-based detectors.
+        loss_cls (:obj:`ConfigDict` or dict): Config of classification loss.
+        loss_bbox (:obj:`ConfigDict` or dict): Config of localization loss.
+        train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
+            anchor head. Defaults to None.
+        test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
+            anchor head. Defaults to None.
+        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
+            list[dict], optional): Initialization config dict.
+            Defaults to None.
+    """
+
+    def __init__(self,
+                 head_module: ConfigType,
+                 prior_generator: ConfigType = dict(
+                     type='mmdet.MlvlPointGenerator',
+                     offset=0.5,
+                     strides=[8, 16, 32]),
+                 bbox_coder: ConfigType = dict(type='DistancePointBBoxCoder'),
+                 loss_cls: ConfigType = dict(
+                     type='mmdet.VarifocalLoss',
+                     use_sigmoid=True,
+                     alpha=0.75,
+                     gamma=2.0,
+                     iou_weighted=True,
+                     reduction='sum',
+                     loss_weight=1.0),
+                 loss_bbox: ConfigType = dict(
+                     type='IoULoss',
+                     iou_mode='giou',
+                     bbox_format='xyxy',
+                     reduction='mean',
+                     loss_weight=2.5,
+                     return_iou=False),
+                 train_cfg: OptConfigType = None,
+                 test_cfg: OptConfigType = None,
+                 init_cfg: OptMultiConfig = None):
+        super().__init__(
+            head_module=head_module,
+            prior_generator=prior_generator,
+            bbox_coder=bbox_coder,
+            loss_cls=loss_cls,
+            loss_bbox=loss_bbox,
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+            init_cfg=init_cfg)
+
+    def predict_by_feat(self,
+                        cls_scores: List[Tensor],
+                        bbox_preds: List[Tensor],
+                        keypoint_preds: List[Tensor],
+                        objectnesses: Optional[List[Tensor]] = None,
+                        batch_img_metas: Optional[List[dict]] = None,
+                        cfg: Optional[ConfigDict] = None,
+                        rescale: bool = True,
+                        with_nms: bool = True) -> List[InstanceData]:
+        """Transform a batch of output features extracted by the head into
+        bbox results.
+
+        Args:
+            cls_scores (list[Tensor]): Classification scores for all
+                scale levels, each is a 4D-tensor, has shape
+                (batch_size, num_priors * num_classes, H, W).
+            bbox_preds (list[Tensor]): Box energies / deltas for all
+                scale levels, each is a 4D-tensor, has shape
+                (batch_size, num_priors * 4, H, W).
+            keypoint_preds (list[Tensor]): Face keypoints of the bboxes
+                in all scale levels, each is a 4D-tensor, has shape
+                (batch_size, 10, H, W).
+            objectnesses (list[Tensor], Optional): Score factor for
+                all scale level, each is a 4D-tensor, has shape
+                (batch_size, 1, H, W).
+            batch_img_metas (list[dict], Optional): Batch image meta info.
+                Defaults to None.
+            cfg (ConfigDict, optional): Test / postprocessing
+                configuration, if None, test_cfg would be used.
+                Defaults to None.
+            rescale (bool): If True, return boxes in original image space.
+                Defaults to True.
+            with_nms (bool): If True, do nms before return boxes.
+                Defaults to True.
+
+        Returns:
+            list[:obj:`InstanceData`]: Object detection results of each image
+            after the post process. Each item usually contains following keys.
+
+            - scores (Tensor): Classification scores, has a shape
+              (num_instance, )
+            - labels (Tensor): Labels of bboxes, has a shape
+              (num_instances, ).
+            - bboxes (Tensor): Has a shape (num_instances, 4),
+              the last dimension 4 arranged as (x1, y1, x2, y2).
+            - keypoints (Tensor): Has a shape (num_instances, 10),
+              five (x, y) pairs in (x1, y1, ..., x5, y5) order.
+ """ + assert len(cls_scores) == len(bbox_preds) \ + and len(keypoint_preds) == len(bbox_preds) + if objectnesses is None: + with_objectnesses = False + else: + with_objectnesses = True + assert len(cls_scores) == len(objectnesses) + + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + + multi_label = cfg.multi_label + multi_label &= self.num_classes > 1 + cfg.multi_label = multi_label + + num_imgs = len(batch_img_metas) + featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores] + + # If the shape does not change, use the previous mlvl_priors + if featmap_sizes != self.featmap_sizes: + self.mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=cls_scores[0].dtype, + device=cls_scores[0].device) + self.featmap_sizes = featmap_sizes + flatten_priors = torch.cat(self.mlvl_priors) + + mlvl_strides = [ + flatten_priors.new_full( + (featmap_size.numel() * self.num_base_priors, ), stride) for + featmap_size, stride in zip(featmap_sizes, self.featmap_strides) + ] + flatten_stride = torch.cat(mlvl_strides) + + # flatten cls_scores, bbox_preds, keypoint_preds and objectness + flatten_cls_scores = [ + cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.num_classes) + for cls_score in cls_scores + ] + flatten_bbox_preds = [ + bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) + for bbox_pred in bbox_preds + ] + flatten_keypoint_preds = [ + keypoint_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 10) + for keypoint_pred in keypoint_preds + ] + + flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid() + flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1) + flatten_decoded_bboxes = self.bbox_coder.decode( + flatten_priors[None], flatten_bbox_preds, flatten_stride) + flatten_keypoint_preds = torch.cat(flatten_keypoint_preds, dim=1) + flatten_keypoint_preds = flatten_keypoint_preds * \ + flatten_stride[None, :, None] + \ + flatten_priors.repeat(1, 5) + + if with_objectnesses: + flatten_objectness = [ + objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1) + for objectness in objectnesses + ] + flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid() + else: + flatten_objectness = [None for _ in range(num_imgs)] + + results_list = [] + for (bboxes, scores, keypoints, objectness, + img_meta) in zip(flatten_decoded_bboxes, flatten_cls_scores, + flatten_keypoint_preds, flatten_objectness, + batch_img_metas): + ori_shape = img_meta['ori_shape'] + scale_factor = img_meta['scale_factor'] + if 'pad_param' in img_meta: + pad_param = img_meta['pad_param'] + else: + pad_param = None + + score_thr = cfg.get('score_thr', -1) + # yolox_style does not require the following operations + if objectness is not None and score_thr > 0 and not cfg.get( + 'yolox_style', False): + conf_inds = objectness > score_thr + bboxes = bboxes[conf_inds, :] + scores = scores[conf_inds, :] + objectness = objectness[conf_inds] + + if objectness is not None: + # conf = obj_conf * cls_conf + scores *= objectness[:, None] + + if scores.shape[0] == 0: + empty_results = InstanceData() + empty_results.bboxes = bboxes + empty_results.scores = scores[:, 0] + empty_results.labels = scores[:, 0].int() + results_list.append(empty_results) + continue + + nms_pre = cfg.get('nms_pre', 100000) + if cfg.multi_label is False: + scores, labels = scores.max(1, keepdim=True) + scores, _, keep_idxs, results = filter_scores_and_topk( + scores, + score_thr, + nms_pre, + results=dict(labels=labels[:, 0])) + labels = results['labels'] + else: + scores, labels, keep_idxs, _ = 
filter_scores_and_topk(
+                    scores, score_thr, nms_pre)
+
+            results = InstanceData(
+                scores=scores,
+                labels=labels,
+                bboxes=bboxes[keep_idxs],
+                keypoints=keypoints[keep_idxs])
+
+            if rescale:
+                if pad_param is not None:
+                    results.bboxes -= results.bboxes.new_tensor([
+                        pad_param[2], pad_param[0], pad_param[2], pad_param[0]
+                    ])
+                    results.keypoints -= results.keypoints.new_tensor(
+                        [pad_param[2], pad_param[0]]).repeat(5)
+                results.bboxes /= results.bboxes.new_tensor(
+                    scale_factor).repeat((1, 2))
+                results.keypoints /= results.keypoints.new_tensor(
+                    scale_factor).repeat((1, 5))
+
+            if cfg.get('yolox_style', False):
+                # do not need max_per_img
+                cfg.max_per_img = len(results)
+
+            results = self._bbox_post_process(
+                results=results,
+                cfg=cfg,
+                rescale=False,
+                with_nms=with_nms,
+                img_meta=img_meta)
+            results.bboxes[:, 0::2].clamp_(0, ori_shape[1])
+            results.bboxes[:, 1::2].clamp_(0, ori_shape[0])
+
+            results_list.append(results)
+        return results_list
+
+    def loss_by_feat(
+            self,
+            cls_scores: Sequence[Tensor],
+            bbox_preds: Sequence[Tensor],
+            bbox_dist_preds: Sequence[Tensor],
+            keypoint_preds: Sequence[Tensor],
+            batch_gt_instances: Sequence[InstanceData],
+            batch_img_metas: Sequence[dict],
+            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
+        # TODO: add a dedicated keypoint loss; for now only the cls/bbox
+        #  losses of the parent head are computed and keypoint_preds is
+        #  not supervised.
+        return super().loss_by_feat(cls_scores, bbox_preds, bbox_dist_preds,
+                                    batch_gt_instances, batch_img_metas,
+                                    batch_gt_instances_ignore)
diff --git a/mmyolo/utils/__init__.py b/mmyolo/utils/__init__.py
index f4e968494..1afdd8c2a 100644
--- a/mmyolo/utils/__init__.py
+++ b/mmyolo/utils/__init__.py
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .collect_env import collect_env
+from .face_visualizer import FaceVisualizer
 from .misc import is_metainfo_lower, switch_to_deploy
 from .setup_env import register_all_modules
 
 __all__ = [
     'register_all_modules', 'collect_env', 'switch_to_deploy',
-    'is_metainfo_lower'
+    'is_metainfo_lower', 'FaceVisualizer'
 ]
diff --git a/mmyolo/utils/face_visualizer.py b/mmyolo/utils/face_visualizer.py
new file mode 100644
index 000000000..27cdcf0bf
--- /dev/null
+++ b/mmyolo/utils/face_visualizer.py
@@ -0,0 +1,45 @@
+# Copyright (c) OpenMMLab. All rights reserved.
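+# Editor's note: the head emits 10 keypoint values per face, i.e. five
+# (x, y) pairs -- presumably the usual 5-point face layout (two eyes, nose
+# tip, two mouth corners); the patch itself never names the five points.
+# The drawing loop in ``_draw_instances`` below is equivalent to:
+#
+#   for i in range(5):  # one colour per landmark
+#       xy = keypoints[:, i * 2:(i + 1) * 2]
+#       self.draw_points(positions=xy, colors=self.keypoint_color[i], sizes=5)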
+from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +from mmdet.visualization import DetLocalVisualizer +from mmengine.structures import InstanceData + +from mmyolo.registry import VISUALIZERS + + +@VISUALIZERS.register_module() +class FaceVisualizer(DetLocalVisualizer): + + def __init__(self, + name: str = 'visualizer', + image: Optional[np.ndarray] = None, + vis_backends: Optional[Dict] = None, + save_dir: Optional[str] = None, + bbox_color: Optional[Union[str, Tuple[int]]] = None, + text_color: Optional[Union[str, + Tuple[int]]] = (200, 200, 200), + mask_color: Optional[Union[str, Tuple[int]]] = None, + keypoint_color: Optional[Union[str, + Tuple[int]]] = ('blue', + 'green', 'red', + 'cyan', + 'yellow'), + line_width: Union[int, float] = 3, + alpha: float = 0.8) -> None: + super().__init__(name, image, vis_backends, save_dir, bbox_color, + text_color, mask_color, line_width, alpha) + self.keypoint_color = keypoint_color + + def _draw_instances(self, image: np.ndarray, instances: List[InstanceData], + classes: Optional[List[str]], + palette: Optional[List[tuple]]) -> np.ndarray: + super()._draw_instances(image, instances, classes, palette) + if 'keypoints' in instances: + keypoints = instances.keypoints + for i in range(5): + self.draw_points( + positions=keypoints[:, i * 2:(i + 1) * 2], + colors=self.keypoint_color[i], + sizes=5) + return self.get_image() diff --git a/tests/test_models/test_dense_heads/test_yolov6-face_head.py b/tests/test_models/test_dense_heads/test_yolov6-face_head.py new file mode 100644 index 000000000..7cbe0c5d7 --- /dev/null +++ b/tests/test_models/test_dense_heads/test_yolov6-face_head.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +from mmengine.config import Config + +from mmyolo.models.dense_heads import YOLOv6FaceHead +from mmyolo.utils import register_all_modules + +register_all_modules() + + +class TestYOLOv6FaceHead(TestCase): + + def setUp(self): + self.head_module = dict( + type='YOLOv6FaceHeadModule', + num_classes=2, + in_channels=[32, 64, 128], + stemout_channels=64, + featmap_strides=[8, 16, 32]) + + def test_predict_by_feat(self): + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'ori_shape': (s, s, 3), + 'scale_factor': (1.0, 1.0), + }] + test_cfg = Config( + dict( + multi_label=True, + max_per_img=300, + score_thr=0.01, + nms=dict(type='nms', iou_threshold=0.65))) + + head = YOLOv6FaceHead(head_module=self.head_module, test_cfg=test_cfg) + head.eval() + + feat = [] + for i in range(len(self.head_module['in_channels'])): + in_channel = self.head_module['in_channels'][i] + feat_size = self.head_module['featmap_strides'][i] + feat.append( + torch.rand(1, in_channel, s // feat_size, s // feat_size)) + + cls_scores, bbox_preds, keypoint_preds = head.forward(feat) + head.predict_by_feat( + cls_scores, + bbox_preds, + keypoint_preds, + None, + img_metas, + cfg=test_cfg, + rescale=True, + with_nms=True) + head.predict_by_feat( + cls_scores, + bbox_preds, + keypoint_preds, + None, + img_metas, + cfg=test_cfg, + rescale=False, + with_nms=False) diff --git a/tools/model_converters/yolov6-face_v3_to_mmyolo.py b/tools/model_converters/yolov6-face_v3_to_mmyolo.py new file mode 100644 index 000000000..3c8a98856 --- /dev/null +++ b/tools/model_converters/yolov6-face_v3_to_mmyolo.py @@ -0,0 +1,127 @@ +import argparse +from collections import OrderedDict + +import torch + + +def convert(src, dst): + import sys + sys.path.append('yolov6') + try: + ckpt = 
torch.load(src, map_location=torch.device('cpu'))
+    except ModuleNotFoundError:
+        raise RuntimeError(
+            'This script must be placed under the meituan/YOLOv6 repo,'
+            ' because loading the official pretrained model needs'
+            ' some Python files to build the model.')
+    # The saved model is the model before reparameterization
+    model = ckpt['ema' if ckpt.get('ema') else 'model'].float()
+    new_state_dict = OrderedDict()
+
+    for k, v in model.state_dict().items():
+        name = k
+        if 'detect' in k:
+            if 'proj' in k:
+                continue
+            name = k.replace('detect', 'bbox_head.head_module')
+        if k.find('anchors') >= 0 or k.find('anchor_grid') >= 0:
+            continue
+
+        if 'ERBlock_2' in k:
+            name = k.replace('ERBlock_2', 'stage1.0')
+            if '.cv' in k:
+                name = name.replace('.cv', '.conv')
+            if '.m.' in k:
+                name = name.replace('.m.', '.block.')
+        elif 'ERBlock_3' in k:
+            name = k.replace('ERBlock_3', 'stage2.0')
+            if '.cv' in k:
+                name = name.replace('.cv', '.conv')
+            if '.m.' in k:
+                name = name.replace('.m.', '.block.')
+        elif 'ERBlock_4' in k:
+            name = k.replace('ERBlock_4', 'stage3.0')
+            if '.cv' in k:
+                name = name.replace('.cv', '.conv')
+            if '.m.' in k:
+                name = name.replace('.m.', '.block.')
+        elif 'ERBlock_5' in k:
+            name = k.replace('ERBlock_5', 'stage4.0')
+            if '.cv' in k:
+                name = name.replace('.cv', '.conv')
+            if '.m.' in k:
+                name = name.replace('.m.', '.block.')
+            if 'stage4.0.2' in name:
+                name = name.replace('stage4.0.2', 'stage4.1')
+                name = name.replace('cv', 'conv')
+        elif 'reduce_layer0' in k:
+            name = k.replace('reduce_layer0', 'reduce_layers.2')
+        elif 'Rep_p4' in k:
+            name = k.replace('Rep_p4', 'top_down_layers.0.0')
+            if '.cv' in k:
+                name = name.replace('.cv', '.conv')
+            if '.m.' in k:
+                name = name.replace('.m.', '.block.')
+        elif 'reduce_layer1' in k:
+            name = k.replace('reduce_layer1', 'top_down_layers.0.1')
+            if '.cv' in k:
+                name = name.replace('.cv', '.conv')
+            if '.m.' in k:
+                name = name.replace('.m.', '.block.')
+        elif 'Rep_p3' in k:
+            name = k.replace('Rep_p3', 'top_down_layers.1')
+            if '.cv' in k:
+                name = name.replace('.cv', '.conv')
+            if '.m.' in k:
+                name = name.replace('.m.', '.block.')
+        elif 'Bifusion0' in k:
+            name = k.replace('Bifusion0', 'upsample_layers.0')
+            if '.cv' in k:
+                name = name.replace('.cv', '.conv')
+            if '.m.' in k:
+                name = name.replace('.m.', '.block.')
+            if '.upsample_transpose.' in k:
+                name = name.replace('.upsample_transpose.', '.')
+        elif 'Bifusion1' in k:
+            name = k.replace('Bifusion1', 'upsample_layers.1')
+            if '.cv' in k:
+                name = name.replace('.cv', '.conv')
+            if '.m.' in k:
+                name = name.replace('.m.', '.block.')
+            if '.upsample_transpose.' in k:
+                name = name.replace('.upsample_transpose.', '.')
+        elif 'Rep_n3' in k:
+            name = k.replace('Rep_n3', 'bottom_up_layers.0')
+            if '.cv' in k:
+                name = name.replace('.cv', '.conv')
+            if '.m.' in k:
+                name = name.replace('.m.', '.block.')
+        elif 'Rep_n4' in k:
+            name = k.replace('Rep_n4', 'bottom_up_layers.1')
+            if '.cv' in k:
+                name = name.replace('.cv', '.conv')
+            if '.m.' in k:
+                name = name.replace('.m.', '.block.')
+        elif 'downsample2' in k:
+            name = k.replace('downsample2', 'downsample_layers.0')
+        elif 'downsample1' in k:
+            name = k.replace('downsample1', 'downsample_layers.1')
+
+        new_state_dict[name] = v
+
+    data = {'state_dict': new_state_dict}
+    torch.save(data, dst)
+
+
+# Note: This script must be placed under the meituan/YOLOv6 repo to run.
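+# Editor's sketch of the intended workflow (file names are illustrative):
+#
+#   cd /path/to/meituan-YOLOv6   # so torch.load can unpickle model classes
+#   python yolov6-face_v3_to_mmyolo.py --src yolov6n_face.pt \
+#       --dst mmyolov6n_face.pt
+#
+# The output contains only {'state_dict': ...} with mmyolo-style key names,
+# so it can be used as the `load_from` checkpoint for the configs above.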
+def main(): + parser = argparse.ArgumentParser(description='Convert model keys') + parser.add_argument( + '--src', default='yolov6s.pt', help='src yolov6 model path') + parser.add_argument('--dst', default='mmyolov6.pt', help='save path') + args = parser.parse_args() + convert(args.src, args.dst) + + +if __name__ == '__main__': + main()
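Editor's appendix: a minimal end-to-end smoke test for the converted checkpoint, using mmdet's standard inference helpers. The config path comes from this patch; the checkpoint and image names are placeholders.

    from mmdet.apis import inference_detector, init_detector

    from mmyolo.utils import register_all_modules

    register_all_modules()  # registers YOLOv6FaceHead, FaceVisualizer, etc.

    config = ('configs/yolov6/yolov6-face/'
              'yolov6_v3_n_syncbn_fast_8xb32-300e_widerface.py')
    checkpoint = 'mmyolov6n_face.pt'  # output of yolov6-face_v3_to_mmyolo.py

    model = init_detector(config, checkpoint, device='cpu')
    result = inference_detector(model, 'demo/demo.jpg')

    # pred_instances carries bboxes, scores, labels and, for this head, a
    # (num_instances, 10) keypoints tensor decoded in predict_by_feat.
    print(result.pred_instances.keypoints.shape)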