diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c461b497ce1..13e4cc406c4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,3 +43,9 @@ repos: hooks: - id: docformatter args: ["--in-place", "--wrap-descriptions", "79"] + - repo: https://github.com/open-mmlab/pre-commit-hooks + rev: master # Use the ref you want to point at + hooks: + - id: check-algo-readme + - id: check-copyright + args: ["mmdet"] # replace "mmdet" with the directory you want to check diff --git a/README.md b/README.md index 09231c7603d..8fd8f10234a 100644 --- a/README.md +++ b/README.md @@ -13,10 +13,10 @@ -[📘Documentation](https://mmdetection.readthedocs.io/en/v2.19.1/) | -[🛠️Installation](https://mmdetection.readthedocs.io/en/v2.19.1/get_started.html) | -[👀Model Zoo](https://mmdetection.readthedocs.io/en/v2.19.1/model_zoo.html) | -[🆕Update News](https://mmdetection.readthedocs.io/en/v2.19.1/changelog.html) | +[📘Documentation](https://mmdetection.readthedocs.io/en/v2.20.0/) | +[🛠️Installation](https://mmdetection.readthedocs.io/en/v2.20.0/get_started.html) | +[👀Model Zoo](https://mmdetection.readthedocs.io/en/v2.20.0/model_zoo.html) | +[🆕Update News](https://mmdetection.readthedocs.io/en/v2.20.0/changelog.html) | [🚀Ongoing Projects](https://github.com/open-mmlab/mmdetection/projects) | [🤔Reporting Issues](https://github.com/open-mmlab/mmdetection/issues/new/choose) @@ -60,11 +60,10 @@ This project is released under the [Apache 2.0 license](LICENSE). ## Changelog -**2.19.1** was released in 14/12/2021: +**2.20.0** was released on 27/12/2021: -- Release [YOLOX](configs/yolox/README.md) COCO pretrained models -- Add abstract and sketch of the papers in readmes -- Fix some weight initialization bugs +- Support [TOOD](configs/tood/README.md): Task-aligned One-stage Object Detection (ICCV 2021 Oral) +- Support resuming from the latest checkpoint automatically Please refer to [changelog.md](docs/en/changelog.md) for details and release history. @@ -149,6 +148,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md). - [x] [YOLOX (ArXiv'2021)](configs/yolox/README.md) - [x] [SOLO (ECCV'2020)](configs/solo/README.md) - [x] [QueryInst (ICCV'2021)](configs/queryinst/README.md) +- [x] [TOOD (ICCV'2021)](configs/tood/README.md) Some other methods are also supported in [projects using MMDetection](./docs/en/projects.md). @@ -209,3 +209,5 @@ If you use this toolbox or benchmark in your research, please cite this project. - [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark. - [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark. - [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark. +- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark. +- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark.
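The "resuming from the latest checkpoint automatically" entry in the changelog above works by locating the newest checkpoint in `cfg.work_dir`; the `mmdet/apis/train.py` hunk further down wires `find_latest_checkpoint` into `train_detector`. Below is a minimal sketch of such a lookup, assuming checkpoints are saved as `epoch_<N>.pth` with an optional `latest.pth` link; it is illustrative only, not the shipped `mmdet.utils` helper:

```python
# Illustrative sketch of latest-checkpoint discovery for auto-resume.
# Assumption: checkpoints live in work_dir as epoch_<N>.pth, optionally
# with a latest.pth link; the real helper is mmdet.utils.find_latest_checkpoint.
import glob
import os.path as osp


def find_latest_checkpoint_sketch(work_dir, suffix='pth'):
    """Return the newest checkpoint path under ``work_dir``, or None."""
    if work_dir is None or not osp.exists(work_dir):
        return None
    latest = osp.join(work_dir, f'latest.{suffix}')
    if osp.exists(latest):
        return latest
    checkpoints = glob.glob(osp.join(work_dir, f'*.{suffix}'))
    if not checkpoints:
        return None

    def epoch_of(path):
        # epoch_24.pth -> 24; anything unparsable sorts first
        tail = osp.splitext(osp.basename(path))[0].split('_')[-1]
        return int(tail) if tail.isdigit() else -1

    return max(checkpoints, key=epoch_of)
```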
diff --git a/README_zh-CN.md b/README_zh-CN.md index 846b4f4a93d..bca9a2a4e65 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -13,10 +13,10 @@ -[📘使用文档](https://mmdetection.readthedocs.io/zh_CN/v2.19.1/) | -[🛠️安装教程](https://mmdetection.readthedocs.io/zh_CN/v2.19.1/get_started.html) | -[👀模型库](https://mmdetection.readthedocs.io/zh_CN/v2.19.1/model_zoo.html) | -[🆕更新日志](https://mmdetection.readthedocs.io/en/v2.19.1/changelog.html) | +[📘使用文档](https://mmdetection.readthedocs.io/zh_CN/v2.20.0/) | +[🛠️安装教程](https://mmdetection.readthedocs.io/zh_CN/v2.20.0/get_started.html) | +[👀模型库](https://mmdetection.readthedocs.io/zh_CN/v2.20.0/model_zoo.html) | +[🆕更新日志](https://mmdetection.readthedocs.io/en/v2.20.0/changelog.html) | [🚀进行中的项目](https://github.com/open-mmlab/mmdetection/projects) | [🤔报告问题](https://github.com/open-mmlab/mmdetection/issues/new/choose) @@ -59,10 +59,9 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope ## 更新日志 -最新的 **2.19.1** 版本已经在 2021.12.14 发布: -- 发布 [YOLOX](configs/yolox/README.md) COCO 预训练模型 -- 在自述文件中添加论文的摘要和草图 -- 修复一些权重初始化错误 +最新的 **2.20.0** 版本已经在 2021.12.27 发布: +- 支持了 ICCV 2021 Oral 方法 [TOOD](configs/tood/README.md): Task-aligned One-stage Object Detection +- 支持了自动从最新的存储参数节点恢复训练 如果想了解更多版本更新细节和历史信息,请阅读[更新日志](docs/changelog.md)。 @@ -146,6 +145,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope - [x] [YOLOX (ArXiv'2021)](configs/yolox/README.md) - [x] [SOLO (ECCV'2020)](configs/solo/README.md) - [x] [QueryInst (ICCV'2021)](configs/queryinst/README.md) +- [x] [TOOD (ICCV'2021)](configs/tood/README.md) 我们在[基于 MMDetection 的项目](./docs/zh_cn/projects.md)中列举了一些其他的支持的算法。 @@ -206,6 +206,8 @@ MMDetection 是一款由来自不同高校和企业的研发人员共同参与 - [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab 光流估计工具箱与测试基准 - [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab 少样本学习工具箱与测试基准 - [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 人体参数化模型工具箱与测试基准 +- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab 自监督学习工具箱与测试基准 +- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab 模型压缩工具箱与测试基准 ## 欢迎加入 OpenMMLab 社区 diff --git a/configs/resnest/metafile.yml b/configs/resnest/metafile.yml index d7f68e5cd8b..3323fad027a 100644 --- a/configs/resnest/metafile.yml +++ b/configs/resnest/metafile.yml @@ -11,7 +11,7 @@ Collections: Paper: URL: https://arxiv.org/abs/2004.08955 Title: 'ResNeSt: Split-Attention Networks' - README: configs/renest/README.md + README: configs/resnest/README.md Code: URL: https://github.com/open-mmlab/mmdetection/blob/v2.7.0/mmdet/models/backbones/resnest.py#L273 Version: v2.7.0 diff --git a/configs/strong_baselines/README.md b/configs/strong_baselines/README.md index c1487ef99a3..5ada104bbe2 100644 --- a/configs/strong_baselines/README.md +++ b/configs/strong_baselines/README.md @@ -1,6 +1,6 @@ # Strong Baselines -We train Mask R-CNN with large-scale jittor and longer schedule as strong baselines. +We train Mask R-CNN with large-scale jitter and longer schedule as strong baselines. The modifications follow those in [Detectron2](https://github.com/facebookresearch/detectron2/tree/master/configs/new_baselines). 
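Large-scale jitter here means randomly rescaling each training image over a wide ratio range and then cropping or padding to a fixed size. A hedged sketch of such a pipeline in MMDetection 2.x config style follows; the `(1024, 1024)` size and `(0.1, 2.0)` ratio range mirror the Detectron2 new baselines and are assumptions for illustration, not a copy of the configs in this directory:

```python
# Hedged sketch of a large-scale jitter (LSJ) pipeline in MMDetection 2.x
# config style. The (1024, 1024) size and (0.1, 2.0) ratio range follow
# the Detectron2 new baselines and are assumptions for illustration.
image_size = (1024, 1024)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    dict(
        type='Resize',
        img_scale=image_size,
        ratio_range=(0.1, 2.0),  # rescale randomly between 0.1x and 2.0x
        multiscale_mode='range',
        keep_ratio=True),
    dict(
        type='RandomCrop',
        crop_type='absolute_range',
        crop_size=image_size,
        recompute_bbox=True,
        allow_negative_crop=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Pad', size=image_size),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
```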
## Results and models diff --git a/configs/tood/README.md b/configs/tood/README.md new file mode 100644 index 00000000000..b1522e78565 --- /dev/null +++ b/configs/tood/README.md @@ -0,0 +1,44 @@ +# TOOD: Task-aligned One-stage Object Detection + +## Abstract + + + +One-stage object detection is commonly implemented by optimizing two sub-tasks: object classification and localization, using heads with two parallel branches, which might lead to a certain level of spatial misalignment in predictions between the two tasks. In this work, we propose a Task-aligned One-stage Object Detection (TOOD) that explicitly aligns the two tasks in a learning-based manner. First, we design a novel Task-aligned Head (T-Head) which offers a better balance between learning task-interactive and task-specific features, as well as a greater flexibility to learn the alignment via a task-aligned predictor. Second, we propose Task Alignment Learning (TAL) to explicitly pull closer (or even unify) the optimal anchors for the two tasks during training via a designed sample assignment scheme and a task-aligned loss. Extensive experiments are conducted on MS-COCO, where TOOD achieves a 51.1 AP at single-model single-scale testing. This surpasses the recent one-stage detectors by a large margin, such as ATSS (47.7 AP), GFL (48.2 AP), and PAA (49.0 AP), with fewer parameters and FLOPs. Qualitative results also demonstrate the effectiveness of TOOD for better aligning the tasks of object classification and localization. + + +
+ +
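The task alignment metric at the heart of TAL is t = s^α · u^β, where s is the predicted classification score for the ground-truth class and u is the IoU between the predicted and ground-truth boxes; the `TaskAlignedAssigner` added later in this diff computes it as `bbox_scores**alpha * overlaps**beta`. A self-contained sketch with toy shapes (α=1, β=6, matching the `train_cfg` of the configs below):

```python
# Sketch of TOOD's task alignment metric t = s**alpha * u**beta.
# s: predicted score for each gt's class, u: IoU with that gt.
# alpha=1, beta=6 match the train_cfg defaults in the TOOD configs.
import torch


def task_alignment_metric(pred_scores, overlaps, gt_labels, alpha=1, beta=6):
    """pred_scores: (n, num_classes); overlaps: (n, k); gt_labels: (k,)."""
    bbox_scores = pred_scores[:, gt_labels]  # (n, k) score of each gt class
    return bbox_scores**alpha * overlaps**beta  # (n, k) alignment metric


# Toy usage: 3 candidate boxes, 2 ground-truth objects, 80 classes.
t = task_alignment_metric(
    torch.rand(3, 80), torch.rand(3, 2), torch.tensor([5, 17]))
print(t.shape)  # torch.Size([3, 2])
```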
+ + + + +## Citation + + + +```latex +@inproceedings{feng2021tood, + title={TOOD: Task-aligned One-stage Object Detection}, + author={Feng, Chengjian and Zhong, Yujie and Gao, Yu and Scott, Matthew R and Huang, Weilin}, + booktitle={ICCV}, + year={2021} +} +``` + +## Results and Models + +| Backbone | Style | Anchor Type | Lr schd | Multi-scale Training | Mem (GB) | Inf time (fps) | box AP | Config | Download | |:-----------------:|:-------:|:------------:|:-------:|:-------------------:|:-------:|:--------------:|:------:|:------:|:--------:| | R-50 | pytorch | Anchor-free | 1x | N | 4.1 | | 42.4 | [config](./tood_r50_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_1x_coco/tood_r50_fpn_1x_coco_20211210_103425-20e20746.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_1x_coco/tood_r50_fpn_1x_coco_20211210_103425.log) | | R-50 | pytorch | Anchor-based | 1x | N | 4.1 | | 42.4 | [config](./tood_r50_fpn_anchor_based_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_anchor_based_1x_coco/tood_r50_fpn_anchor_based_1x_coco_20211214_100105-b776c134.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_anchor_based_1x_coco/tood_r50_fpn_anchor_based_1x_coco_20211214_100105.log) | | R-50 | pytorch | Anchor-free | 2x | Y | 4.1 | | 44.5 | [config](./tood_r50_fpn_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_mstrain_2x_coco/tood_r50_fpn_mstrain_2x_coco_20211210_144231-3b23174c.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_mstrain_2x_coco/tood_r50_fpn_mstrain_2x_coco_20211210_144231.log) | | R-101 | pytorch | Anchor-free | 2x | Y | 6.0 | | 46.1 | [config](./tood_r101_fpn_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_mstrain_2x_coco/tood_r101_fpn_mstrain_2x_coco_20211210_144232-a18f53c8.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_mstrain_2x_coco/tood_r101_fpn_mstrain_2x_coco_20211210_144232.log) | | R-101-dcnv2 | pytorch | Anchor-free | 2x | Y | 6.2 | | 49.3 | [config](./tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco_20211210_213728-4a824142.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco_20211210_213728.log) | | X-101-64x4d | pytorch | Anchor-free | 2x | Y | 10.2 | | 47.6 | [config](./tood_x101_64x4d_fpn_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_x101_64x4d_fpn_mstrain_2x_coco/tood_x101_64x4d_fpn_mstrain_2x_coco_20211211_003519-a4f36113.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/tood/tood_x101_64x4d_fpn_mstrain_2x_coco/tood_x101_64x4d_fpn_mstrain_2x_coco_20211211_003519.log) | | X-101-64x4d-dcnv2 | pytorch | Anchor-free | 2x | Y | | | | [config](./tood_x101_64x4d_fpn_dconv_c4-c5_mstrain_2x_coco.py) | [model]() | [log]() | + +[1] *1x and 2x mean the model is trained for 90K and 180K iterations, respectively.* \ +[2] *All results are obtained with a single model and without any test-time augmentation such as multi-scale testing or flipping.* \ +[3] *`dcnv2` denotes deformable convolutional networks v2.* \ diff --git a/configs/tood/metafile.yml b/configs/tood/metafile.yml new file mode 100644
index 00000000000..27a0f8dbfc5 --- /dev/null +++ b/configs/tood/metafile.yml @@ -0,0 +1,95 @@ +Collections: + - Name: TOOD + Metadata: + Training Data: COCO + Training Techniques: + - SGD + Training Resources: 8x V100 GPUs + Architecture: + - TOOD + Paper: + URL: https://arxiv.org/abs/2108.07755 + Title: 'TOOD: Task-aligned One-stage Object Detection' + README: configs/tood/README.md + Code: + URL: https://github.com/open-mmlab/mmdetection/blob/v2.20.0/mmdet/models/detectors/tood.py#L7 + Version: v2.20.0 + +Models: + - Name: tood_r101_fpn_mstrain_2x_coco + In Collection: TOOD + Config: configs/tood/tood_r101_fpn_mstrain_2x_coco.py + Metadata: + Training Memory (GB): 6.0 + Epochs: 24 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 46.1 + Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_mstrain_2x_coco/tood_r101_fpn_mstrain_2x_coco_20211210_144232-a18f53c8.pth + + - Name: tood_x101_64x4d_fpn_mstrain_2x_coco + In Collection: TOOD + Config: configs/tood/tood_x101_64x4d_fpn_mstrain_2x_coco.py + Metadata: + Training Memory (GB): 10.2 + Epochs: 24 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 47.6 + Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_x101_64x4d_fpn_mstrain_2x_coco/tood_x101_64x4d_fpn_mstrain_2x_coco_20211211_003519-a4f36113.pth + + - Name: tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco + In Collection: TOOD + Config: configs/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py + Metadata: + Training Memory (GB): 6.2 + Epochs: 24 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 49.3 + Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco_20211210_213728-4a824142.pth + + - Name: tood_r50_fpn_anchor_based_1x_coco + In Collection: TOOD + Config: configs/tood/tood_r50_fpn_anchor_based_1x_coco.py + Metadata: + Training Memory (GB): 4.1 + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 42.4 + Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_anchor_based_1x_coco/tood_r50_fpn_anchor_based_1x_coco_20211214_100105-b776c134.pth + + - Name: tood_r50_fpn_1x_coco + In Collection: TOOD + Config: configs/tood/tood_r50_fpn_1x_coco.py + Metadata: + Training Memory (GB): 4.1 + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 42.4 + Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_1x_coco/tood_r50_fpn_1x_coco_20211210_103425-20e20746.pth + + - Name: tood_r50_fpn_mstrain_2x_coco + In Collection: TOOD + Config: configs/tood/tood_r50_fpn_mstrain_2x_coco.py + Metadata: + Training Memory (GB): 4.1 + Epochs: 24 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 44.5 + Weights: https://download.openmmlab.com/mmdetection/v2.0/tood/tood_r50_fpn_mstrain_2x_coco/tood_r50_fpn_mstrain_2x_coco_20211210_144231-3b23174c.pth diff --git a/configs/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py b/configs/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py new file mode 100644 index 00000000000..c7f1bbcbaf1 --- /dev/null +++ b/configs/tood/tood_r101_fpn_dconv_c3-c5_mstrain_2x_coco.py @@ -0,0 +1,7 @@ +_base_ = './tood_r101_fpn_mstrain_2x_coco.py' + +model = dict( + backbone=dict( + dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False), + stage_with_dcn=(False, True, True, True)), + bbox_head=dict(num_dcn=2)) diff --git 
a/configs/tood/tood_r101_fpn_mstrain_2x_coco.py b/configs/tood/tood_r101_fpn_mstrain_2x_coco.py new file mode 100644 index 00000000000..d9d2c32d8ce --- /dev/null +++ b/configs/tood/tood_r101_fpn_mstrain_2x_coco.py @@ -0,0 +1,7 @@ +_base_ = './tood_r50_fpn_mstrain_2x_coco.py' + +model = dict( + backbone=dict( + depth=101, + init_cfg=dict(type='Pretrained', + checkpoint='torchvision://resnet101'))) diff --git a/configs/tood/tood_r50_fpn_1x_coco.py b/configs/tood/tood_r50_fpn_1x_coco.py new file mode 100644 index 00000000000..35a77a400e1 --- /dev/null +++ b/configs/tood/tood_r50_fpn_1x_coco.py @@ -0,0 +1,74 @@ +_base_ = [ + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +model = dict( + type='TOOD', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='TOODHead', + num_classes=80, + in_channels=256, + stacked_convs=6, + feat_channels=256, + anchor_type='anchor_free', + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + initial_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + activated=True, # use probability instead of logit as input + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0)), + train_cfg=dict( + initial_epoch=4, + initial_assigner=dict(type='ATSSAssigner', topk=9), + assigner=dict(type='TaskAlignedAssigner', topk=13), + alpha=1, + beta=6, + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) + +# custom hooks +custom_hooks = [dict(type='SetEpochInfoHook')] diff --git a/configs/tood/tood_r50_fpn_anchor_based_1x_coco.py b/configs/tood/tood_r50_fpn_anchor_based_1x_coco.py new file mode 100644 index 00000000000..c7fbf6aff19 --- /dev/null +++ b/configs/tood/tood_r50_fpn_anchor_based_1x_coco.py @@ -0,0 +1,2 @@ +_base_ = './tood_r50_fpn_1x_coco.py' +model = dict(bbox_head=dict(anchor_type='anchor_based')) diff --git a/configs/tood/tood_r50_fpn_mstrain_2x_coco.py b/configs/tood/tood_r50_fpn_mstrain_2x_coco.py new file mode 100644 index 00000000000..157d13a4a17 --- /dev/null +++ b/configs/tood/tood_r50_fpn_mstrain_2x_coco.py @@ -0,0 +1,22 @@ +_base_ = './tood_r50_fpn_1x_coco.py' +# learning policy +lr_config = dict(step=[16, 22]) +runner = dict(type='EpochBasedRunner', max_epochs=24) +# multi-scale training +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='Resize', + img_scale=[(1333, 480), (1333, 800)], + multiscale_mode='range', + 
keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), +] +data = dict(train=dict(pipeline=train_pipeline)) diff --git a/configs/tood/tood_x101_64x4d_fpn_dconv_c4-c5_mstrain_2x_coco.py b/configs/tood/tood_x101_64x4d_fpn_dconv_c4-c5_mstrain_2x_coco.py new file mode 100644 index 00000000000..47c92695a92 --- /dev/null +++ b/configs/tood/tood_x101_64x4d_fpn_dconv_c4-c5_mstrain_2x_coco.py @@ -0,0 +1,7 @@ +_base_ = './tood_x101_64x4d_fpn_mstrain_2x_coco.py' +model = dict( + backbone=dict( + dcn=dict(type='DCNv2', deformable_groups=1, fallback_on_stride=False), + stage_with_dcn=(False, False, True, True), + ), + bbox_head=dict(num_dcn=2)) diff --git a/configs/tood/tood_x101_64x4d_fpn_mstrain_2x_coco.py b/configs/tood/tood_x101_64x4d_fpn_mstrain_2x_coco.py new file mode 100644 index 00000000000..842f320e839 --- /dev/null +++ b/configs/tood/tood_x101_64x4d_fpn_mstrain_2x_coco.py @@ -0,0 +1,16 @@ +_base_ = './tood_r50_fpn_mstrain_2x_coco.py' + +model = dict( + backbone=dict( + type='ResNeXt', + depth=101, + groups=64, + base_width=4, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict( + type='Pretrained', checkpoint='open-mmlab://resnext101_64x4d'))) diff --git a/docker/serve/Dockerfile b/docker/serve/Dockerfile index 2a3fc0f5581..dcadde7076c 100644 --- a/docker/serve/Dockerfile +++ b/docker/serve/Dockerfile @@ -4,7 +4,7 @@ ARG CUDNN="7" FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel ARG MMCV="1.3.17" -ARG MMDET="2.19.1" +ARG MMDET="2.20.0" ENV PYTHONUNBUFFERED TRUE diff --git a/docs/en/changelog.md b/docs/en/changelog.md index 7a6d81dd637..27673af5aa3 100644 --- a/docs/en/changelog.md +++ b/docs/en/changelog.md @@ -1,5 +1,32 @@ ## Changelog +### v2.20.0 (27/12/2021) + +#### New Features + +- Support [TOOD](configs/tood/README.md): Task-aligned One-stage Object Detection (ICCV 2021 Oral) (#6746) +- Support resuming from the latest checkpoint automatically (#6727) + +#### Bug Fixes + +- Fix wrong bbox `loss_weight` of the PAA head (#6744) +- Fix the padding value of `gt_semantic_seg` in batch collating (#6837) +- Fix test error of lvis when using `classwise` (#6845) +- Avoid BC-breaking of `get_local_path` (#6719) +- Fix bug in `sync_norm_hook` when the BN layer does not exist (#6852) +- Use pycocotools directly regardless of the platform (#6838) + +#### Improvements + +- Add unit test for SimOTA with no valid bbox (#6770) +- Use pre-commit to check the READMEs (#6802) +- Support selecting GPU ids in non-distributed testing (#6781) + +#### Contributors + +A total of 12 developers contributed to this release.
+Thanks @ZwwWayne, @Czm369, @jshilong, @RangiLyu, @BIGWangYuDong, @hhaAndroid, @jamiechoi1995, @AronLin, @Keiku, @gkagkos, @fcakyon, @www516717402 + ### v2.19.1 (14/12/2021) #### New Features diff --git a/docs/en/conf.py b/docs/en/conf.py index 13b7e81b550..9227f98095b 100644 --- a/docs/en/conf.py +++ b/docs/en/conf.py @@ -87,117 +87,10 @@ def get_version(): 'name': 'GitHub', 'url': 'https://github.com/open-mmlab/mmdetection' }, - { - 'name': - 'Projects', - 'children': [{ - 'name': - 'MMCV', - 'url': - 'https://mmcv.readthedocs.io/en/latest/', - 'description': - 'Foundational library for computer vision' - }, { - 'name': - 'MMDetection', - 'url': - 'https://mmdetection.readthedocs.io/en/latest/', - 'description': - 'Object detection toolbox and benchmark' - }, { - 'name': - 'MMAction2', - 'url': - 'https://mmaction2.readthedocs.io/en/latest/', - 'description': - 'Action understanding toolbox and benchmark' - }, { - 'name': - 'MMClassification', - 'url': - 'https://mmclassification.readthedocs.io/en/latest/', - 'description': - 'Image classification toolbox and benchmark' - }, { - 'name': - 'MMSegmentation', - 'url': - 'https://mmsegmentation.readthedocs.io/en/latest/', - 'description': - 'Semantic segmentation toolbox and benchmark' - }, { - 'name': 'MMDetection3D', - 'url': 'https://mmdetection3d.readthedocs.io/en/latest/', - 'description': 'General 3D object detection platform' - }, { - 'name': 'MMEditing', - 'url': 'https://mmediting.readthedocs.io/en/latest/', - 'description': 'Image and video editing toolbox' - }, { - 'name': - 'MMOCR', - 'url': - 'https://mmocr.readthedocs.io/en/latest/', - 'description': - 'Text detection, recognition and understanding toolbox' - }, { - 'name': 'MMPose', - 'url': 'https://mmpose.readthedocs.io/en/latest/', - 'description': 'Pose estimation toolbox and benchmark' - }, { - 'name': - 'MMTracking', - 'url': - 'https://mmtracking.readthedocs.io/en/latest/', - 'description': - 'Video perception toolbox and benchmark' - }, { - 'name': 'MMGeneration', - 'url': 'https://mmgeneration.readthedocs.io/en/latest/', - 'description': 'Generative model toolbox' - }, { - 'name': 'MMFlow', - 'url': 'https://mmflow.readthedocs.io/en/latest/', - 'description': 'Optical flow toolbox and benchmark' - }, { - 'name': - 'MMFewShot', - 'url': - 'https://mmfewshot.readthedocs.io/en/latest/', - 'description': - 'FewShot learning toolbox and benchmark' - }, { - 'name': - 'MMHuman3D', - 'url': - 'https://mmhuman3d.readthedocs.io/en/latest/', - 'description': - '3D human parametric model toolbox and benchmark.' - }] - }, - { - 'name': - 'OpenMMLab', - 'children': [ - { - 'name': 'Homepage', - 'url': 'https://openmmlab.com/' - }, - { - 'name': 'GitHub', - 'url': 'https://github.com/open-mmlab/' - }, - { - 'name': 'Twitter', - 'url': 'https://twitter.com/OpenMMLab' - }, - { - 'name': 'Zhihu', - 'url': 'https://zhihu.com/people/openmmlab' - }, - ] - }, - ] + ], + # Specify the language of shared menu + 'menu_lang': + 'en' } # Add any paths that contain custom static files (such as style sheets) here, diff --git a/docs/en/get_started.md b/docs/en/get_started.md index 36771ffcfd7..8e28e6d4039 100644 --- a/docs/en/get_started.md +++ b/docs/en/get_started.md @@ -12,6 +12,7 @@ Compatible MMDetection and MMCV versions are shown as below. 
Please install the | MMDetection version | MMCV version | |:-------------------:|:-------------------:| | master | mmcv-full>=1.3.17, <1.5.0 | +| 2.20.0 | mmcv-full>=1.3.17, <1.5.0 | | 2.19.1 | mmcv-full>=1.3.17, <1.5.0 | | 2.19.0 | mmcv-full>=1.3.17, <1.5.0 | | 2.18.0 | mmcv-full>=1.3.17, <1.4.0 | diff --git a/docs/zh_cn/conf.py b/docs/zh_cn/conf.py index e6b1fc62c25..2dc42daae66 100644 --- a/docs/zh_cn/conf.py +++ b/docs/zh_cn/conf.py @@ -87,90 +87,10 @@ def get_version(): 'name': 'GitHub', 'url': 'https://github.com/open-mmlab/mmdetection' }, - { - 'name': - '算法库', - 'children': [{ - 'name': 'MMCV', - 'url': 'https://mmcv.readthedocs.io/zh_CN/latest/', - 'description': '计算机视觉基础库' - }, { - 'name': 'MMDetection', - 'url': 'https://mmdetection.readthedocs.io/zh_CN/latest/', - 'description': '检测工具箱与测试基准' - }, { - 'name': 'MMAction2', - 'url': 'https://mmaction2.readthedocs.io/zh_CN/latest/', - 'description': '视频理解工具箱与测试基准' - }, { - 'name': 'MMClassification', - 'url': 'https://mmclassification.readthedocs.io/zh_CN/latest/', - 'description': '图像分类工具箱与测试基准' - }, { - 'name': 'MMSegmentation', - 'url': 'https://mmsegmentation.readthedocs.io/zh_CN/latest/', - 'description': '语义分割工具箱与测试基准' - }, { - 'name': 'MMDetection3D', - 'url': 'https://mmdetection3d.readthedocs.io/zh_CN/latest/', - 'description': '通用3D目标检测平台' - }, { - 'name': 'MMEditing', - 'url': 'https://mmediting.readthedocs.io/zh_CN/latest/', - 'description': '图像视频编辑工具箱' - }, { - 'name': 'MMOCR', - 'url': 'https://mmocr.readthedocs.io/zh_CN/latest/', - 'description': '全流程文字检测识别理解工具包' - }, { - 'name': 'MMPose', - 'url': 'https://mmpose.readthedocs.io/zh_CN/latest/', - 'description': '姿态估计工具箱与测试基准' - }, { - 'name': 'MMTracking', - 'url': 'https://mmtracking.readthedocs.io/zh_CN/latest/', - 'description': '一体化视频目标感知平台' - }, { - 'name': 'MMGeneration', - 'url': 'https://mmgeneration.readthedocs.io/zh_CN/latest/', - 'description': '生成模型工具箱' - }, { - 'name': 'MMFlow', - 'url': 'https://mmflow.readthedocs.io/zh_CN/latest/', - 'description': '光流估计工具箱与测试基准' - }, { - 'name': 'MMFewShot', - 'url': 'https://mmfewshot.readthedocs.io/zh_CN/latest/', - 'description': '少样本学习工具箱与测试基准' - }, { - 'name': 'MMHuman3D', - 'url': 'https://mmhuman3d.readthedocs.io/en/latest/', - 'description': 'OpenMMLab 人体参数化模型工具箱与测试基准.' 
- }] - }, - { - 'name': - 'OpenMMLab', - 'children': [ - { - 'name': '官网', - 'url': 'https://openmmlab.com/' - }, - { - 'name': 'GitHub', - 'url': 'https://github.com/open-mmlab/' - }, - { - 'name': '推特', - 'url': 'https://twitter.com/OpenMMLab' - }, - { - 'name': '知乎', - 'url': 'https://zhihu.com/people/openmmlab' - }, - ] - }, - ] + ], + # Specify the language of shared menu + 'menu_lang': + 'cn', } # Add any paths that contain custom static files (such as style sheets) here, diff --git a/mmdet/__init__.py b/mmdet/__init__.py index e52beb9ddfd..ca5518fc31d 100644 --- a/mmdet/__init__.py +++ b/mmdet/__init__.py @@ -16,7 +16,7 @@ def digit_version(version_str): return digit_version -mmcv_minimum_version = '1.3.8' +mmcv_minimum_version = '1.3.17' mmcv_maximum_version = '1.5.0' mmcv_version = digit_version(mmcv.__version__) diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index d1aca853377..d9c8deebaca 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -13,7 +13,7 @@ from mmdet.core import DistEvalHook, EvalHook from mmdet.datasets import (build_dataloader, build_dataset, replace_ImageToTensor) -from mmdet.utils import get_root_logger +from mmdet.utils import find_latest_checkpoint, get_root_logger def init_random_seed(seed=None, device='cuda'): @@ -196,6 +196,12 @@ def train_detector(model, runner.register_hook( eval_hook(val_dataloader, **eval_cfg), priority='LOW') + resume_from = None + if cfg.resume_from is None and cfg.get('auto_resume'): + resume_from = find_latest_checkpoint(cfg.work_dir) + if resume_from is not None: + cfg.resume_from = resume_from + if cfg.resume_from: runner.resume(cfg.resume_from) elif cfg.load_from: diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py index 2fef4be8128..a182686491d 100644 --- a/mmdet/core/bbox/assigners/__init__.py +++ b/mmdet/core/bbox/assigners/__init__.py @@ -10,10 +10,12 @@ from .point_assigner import PointAssigner from .region_assigner import RegionAssigner from .sim_ota_assigner import SimOTAAssigner +from .task_aligned_assigner import TaskAlignedAssigner from .uniform_assigner import UniformAssigner __all__ = [ 'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult', 'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner', - 'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner' + 'HungarianAssigner', 'RegionAssigner', 'UniformAssigner', 'SimOTAAssigner', + 'TaskAlignedAssigner' ] diff --git a/mmdet/core/bbox/assigners/task_aligned_assigner.py b/mmdet/core/bbox/assigners/task_aligned_assigner.py new file mode 100644 index 00000000000..1872de4a780 --- /dev/null +++ b/mmdet/core/bbox/assigners/task_aligned_assigner.py @@ -0,0 +1,151 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from ..builder import BBOX_ASSIGNERS +from ..iou_calculators import build_iou_calculator +from .assign_result import AssignResult +from .base_assigner import BaseAssigner + +INF = 100000000 + + +@BBOX_ASSIGNERS.register_module() +class TaskAlignedAssigner(BaseAssigner): + """Task-aligned assigner used in the paper: + `TOOD: Task-aligned One-stage Object Detection. + <https://arxiv.org/abs/2108.07755>`_. + + Assign a corresponding gt bbox or background to each predicted bbox. + Each bbox will be assigned with `0` or a positive integer + indicating the ground truth index.
+ + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + topk (int): Number of candidate bboxes selected for each gt. + iou_calculator (dict): Config dict for iou calculator. + Default: dict(type='BboxOverlaps2D') + """ + + def __init__(self, topk, iou_calculator=dict(type='BboxOverlaps2D')): + assert topk >= 1 + self.topk = topk + self.iou_calculator = build_iou_calculator(iou_calculator) + + def assign(self, + pred_scores, + decode_bboxes, + anchors, + gt_bboxes, + gt_bboxes_ignore=None, + gt_labels=None, + alpha=1, + beta=6): + """Assign gt to bboxes. + + The assignment is done in the following steps + + 1. compute the alignment metric between all bboxes (bboxes of all + pyramid levels) and gts + 2. select top-k bboxes as candidates for each gt + 3. limit the positive samples' centers to the gt (because an + anchor-free detector can only predict positive distances) + + + Args: + pred_scores (Tensor): predicted class probability, + shape (n, num_classes) + decode_bboxes (Tensor): predicted bounding boxes, shape (n, 4) + anchors (Tensor): pre-defined anchors, shape (n, 4). + gt_bboxes (Tensor): Ground truth boxes, shape (k, 4). + gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are + labelled as `ignored`, e.g., crowd boxes in COCO. + gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). + + Returns: + :obj:`AssignResult`: The assign result. + """ + anchors = anchors[:, :4] + num_gt, num_bboxes = gt_bboxes.size(0), anchors.size(0) + # compute alignment metric between all bbox and gt + overlaps = self.iou_calculator(decode_bboxes, gt_bboxes).detach() + bbox_scores = pred_scores[:, gt_labels].detach() + # assign 0 by default + assigned_gt_inds = anchors.new_full((num_bboxes, ), + 0, + dtype=torch.long) + assign_metrics = anchors.new_zeros((num_bboxes, )) + + if num_gt == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + max_overlaps = anchors.new_zeros((num_bboxes, )) + if num_gt == 0: + # No gt boxes, assign everything to background + assigned_gt_inds[:] = 0 + if gt_labels is None: + assigned_labels = None + else: + assigned_labels = anchors.new_full((num_bboxes, ), + -1, + dtype=torch.long) + assign_result = AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + assign_result.assign_metrics = assign_metrics + return assign_result + + # select top-k bboxes as candidates for each gt + alignment_metrics = bbox_scores**alpha * overlaps**beta + topk = min(self.topk, alignment_metrics.size(0)) + _, candidate_idxs = alignment_metrics.topk(topk, dim=0, largest=True) + candidate_metrics = alignment_metrics[candidate_idxs, + torch.arange(num_gt)] + is_pos = candidate_metrics > 0 + + # limit the positive samples' centers to the gt + anchors_cx = (anchors[:, 0] + anchors[:, 2]) / 2.0 + anchors_cy = (anchors[:, 1] + anchors[:, 3]) / 2.0 + for gt_idx in range(num_gt): + candidate_idxs[:, gt_idx] += gt_idx * num_bboxes + ep_anchors_cx = anchors_cx.view(1, -1).expand( + num_gt, num_bboxes).contiguous().view(-1) + ep_anchors_cy = anchors_cy.view(1, -1).expand( + num_gt, num_bboxes).contiguous().view(-1) + candidate_idxs = candidate_idxs.view(-1) + + # calculate the left, top, right, bottom distance between positive + # bbox center and gt side + l_ = ep_anchors_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0] + t_ = ep_anchors_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1] + r_ = gt_bboxes[:, 2] - ep_anchors_cx[candidate_idxs].view(-1, num_gt) + b_ = gt_bboxes[:, 3] -
ep_anchors_cy[candidate_idxs].view(-1, num_gt) + is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01 + is_pos = is_pos & is_in_gts + + # if an anchor box is assigned to multiple gts, + # the one with the highest iou will be selected. + overlaps_inf = torch.full_like(overlaps, + -INF).t().contiguous().view(-1) + index = candidate_idxs.view(-1)[is_pos.view(-1)] + overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index] + overlaps_inf = overlaps_inf.view(num_gt, -1).t() + + max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1) + assigned_gt_inds[ + max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1 + assign_metrics[max_overlaps != -INF] = alignment_metrics[ + max_overlaps != -INF, argmax_overlaps[max_overlaps != -INF]] + + if gt_labels is not None: + assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1) + pos_inds = torch.nonzero( + assigned_gt_inds > 0, as_tuple=False).squeeze() + if pos_inds.numel() > 0: + assigned_labels[pos_inds] = gt_labels[ + assigned_gt_inds[pos_inds] - 1] + else: + assigned_labels = None + assign_result = AssignResult( + num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) + assign_result.assign_metrics = assign_metrics + return assign_result diff --git a/mmdet/core/bbox/coder/distance_point_bbox_coder.py b/mmdet/core/bbox/coder/distance_point_bbox_coder.py index 19499e3e270..9f308a8419c 100644 --- a/mmdet/core/bbox/coder/distance_point_bbox_coder.py +++ b/mmdet/core/bbox/coder/distance_point_bbox_coder.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. from ..builder import BBOX_CODERS from ..transforms import bbox2distance, distance2bbox from .base_bbox_coder import BaseBBoxCoder diff --git a/mmdet/core/hook/__init__.py b/mmdet/core/hook/__init__.py index 31d69a0d0c8..4d4dca6b557 100644 --- a/mmdet/core/hook/__init__.py +++ b/mmdet/core/hook/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .checkloss_hook import CheckInvalidLossHook from .ema import ExpMomentumEMAHook, LinearMomentumEMAHook +from .set_epoch_info_hook import SetEpochInfoHook from .sync_norm_hook import SyncNormHook from .sync_random_size_hook import SyncRandomSizeHook from .yolox_lrupdater_hook import YOLOXLrUpdaterHook @@ -9,5 +10,5 @@ __all__ = [ 'SyncRandomSizeHook', 'YOLOXModeSwitchHook', 'SyncNormHook', 'ExpMomentumEMAHook', 'LinearMomentumEMAHook', 'YOLOXLrUpdaterHook', - 'CheckInvalidLossHook' + 'CheckInvalidLossHook', 'SetEpochInfoHook' ] diff --git a/mmdet/core/hook/set_epoch_info_hook.py b/mmdet/core/hook/set_epoch_info_hook.py new file mode 100644 index 00000000000..c2b134ceb69 --- /dev/null +++ b/mmdet/core/hook/set_epoch_info_hook.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmcv.parallel import is_module_wrapper +from mmcv.runner import HOOKS, Hook + + +@HOOKS.register_module() +class SetEpochInfoHook(Hook): + """Set runner's epoch information to the model.""" + + def before_train_epoch(self, runner): + epoch = runner.epoch + model = runner.model + if is_module_wrapper(model): + model = model.module + model.set_epoch(epoch) diff --git a/mmdet/datasets/lvis.py b/mmdet/datasets/lvis.py index d7a2e919f5f..591f5c831f0 100644 --- a/mmdet/datasets/lvis.py +++ b/mmdet/datasets/lvis.py @@ -424,8 +424,10 @@ def evaluate(self, for idx, catId in enumerate(self.cat_ids): # area range index 0: all area ranges # max dets index -1: typically 100 per image - nm = self.coco.load_cats(catId)[0] - precision = precisions[:, :, idx, 0, -1] + # the dimensions of precisions are + # [num_thrs, num_recalls, num_cats, num_area_rngs] + nm = self.coco.load_cats([catId])[0] + precision = precisions[:, :, idx, 0] precision = precision[precision > -1] if precision.size: ap = np.mean(precision) diff --git a/mmdet/datasets/pipelines/formating.py b/mmdet/datasets/pipelines/formating.py index df037610b24..45ca69cfc6f 100644 --- a/mmdet/datasets/pipelines/formating.py +++ b/mmdet/datasets/pipelines/formating.py @@ -191,10 +191,17 @@ class DefaultFormatBundle: Args: img_to_float (bool): Whether to force the image to be converted to float type. Default: True. + pad_val (dict): A dict for padding value in batch collating, + the default value is `dict(img=0, masks=0, seg=255)`. + Without this argument, the padding value of "gt_semantic_seg" + will be set to 0 by default, which should be 255. """ - def __init__(self, img_to_float=True): + def __init__(self, + img_to_float=True, + pad_val=dict(img=0, masks=0, seg=255)): self.img_to_float = img_to_float + self.pad_val = pad_val def __call__(self, results): """Call function to transform and format common fields in results. @@ -220,16 +227,22 @@ def __call__(self, results): if len(img.shape) < 3: img = np.expand_dims(img, -1) img = np.ascontiguousarray(img.transpose(2, 0, 1)) - results['img'] = DC(to_tensor(img), stack=True) + results['img'] = DC( + to_tensor(img), padding_value=self.pad_val['img'], stack=True) for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']: if key not in results: continue results[key] = DC(to_tensor(results[key])) if 'gt_masks' in results: - results['gt_masks'] = DC(results['gt_masks'], cpu_only=True) + results['gt_masks'] = DC( + results['gt_masks'], + padding_value=self.pad_val['masks'], + cpu_only=True) if 'gt_semantic_seg' in results: results['gt_semantic_seg'] = DC( - to_tensor(results['gt_semantic_seg'][None, ...]), stack=True) + to_tensor(results['gt_semantic_seg'][None, ...]), + padding_value=self.pad_val['seg'], + stack=True) return results def _add_default_meta_keys(self, results): diff --git a/mmdet/datasets/samplers/infinite_sampler.py b/mmdet/datasets/samplers/infinite_sampler.py index 6bc32a1b235..421c0de3369 100644 --- a/mmdet/datasets/samplers/infinite_sampler.py +++ b/mmdet/datasets/samplers/infinite_sampler.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. import itertools import numpy as np diff --git a/mmdet/models/backbones/pvt.py b/mmdet/models/backbones/pvt.py index 1680dd69d3f..9443273a62e 100644 --- a/mmdet/models/backbones/pvt.py +++ b/mmdet/models/backbones/pvt.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
import math import warnings diff --git a/mmdet/models/backbones/swin.py b/mmdet/models/backbones/swin.py index 777b34d09d0..c9f1455ae4b 100644 --- a/mmdet/models/backbones/swin.py +++ b/mmdet/models/backbones/swin.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. import warnings from collections import OrderedDict from copy import deepcopy diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py index e78444aca59..81d6ec2f74d 100644 --- a/mmdet/models/dense_heads/__init__.py +++ b/mmdet/models/dense_heads/__init__.py @@ -31,6 +31,7 @@ from .sabl_retina_head import SABLRetinaHead from .solo_head import DecoupledSOLOHead, DecoupledSOLOLightHead, SOLOHead from .ssd_head import SSDHead +from .tood_head import TOODHead from .vfnet_head import VFNetHead from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead from .yolo_head import YOLOV3Head @@ -48,5 +49,5 @@ 'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead', 'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead', 'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead', - 'DecoupledSOLOLightHead', 'LADHead' + 'DecoupledSOLOLightHead', 'LADHead', 'TOODHead' ] diff --git a/mmdet/models/dense_heads/lad_head.py b/mmdet/models/dense_heads/lad_head.py index ce518ad62fe..85273bcb243 100644 --- a/mmdet/models/dense_heads/lad_head.py +++ b/mmdet/models/dense_heads/lad_head.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. import torch from mmcv.runner import force_fp32 diff --git a/mmdet/models/dense_heads/paa_head.py b/mmdet/models/dense_heads/paa_head.py index a2bc0af5229..fef2d5237fc 100644 --- a/mmdet/models/dense_heads/paa_head.py +++ b/mmdet/models/dense_heads/paa_head.py @@ -248,7 +248,7 @@ def get_pos_loss(self, anchors, cls_score, bbox_pred, label, label_weight, pos_bbox_pred, pos_bbox_target, pos_bbox_weight, - avg_factor=self.loss_cls.loss_weight, + avg_factor=self.loss_bbox.loss_weight, reduction_override='none') loss_cls = loss_cls.sum(-1) diff --git a/mmdet/models/dense_heads/tood_head.py b/mmdet/models/dense_heads/tood_head.py new file mode 100644 index 00000000000..90bc57e23ab --- /dev/null +++ b/mmdet/models/dense_heads/tood_head.py @@ -0,0 +1,768 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, Scale, bias_init_with_prob, normal_init +from mmcv.ops import deform_conv2d +from mmcv.runner import force_fp32 + +from mmdet.core import (anchor_inside_flags, build_assigner, distance2bbox, + images_to_levels, multi_apply, reduce_mean, unmap) +from mmdet.core.utils import filter_scores_and_topk +from ..builder import HEADS, build_loss +from .atss_head import ATSSHead + + +class TaskDecomposition(nn.Module): + """Task decomposition module in task-aligned predictor of TOOD. + + Args: + feat_channels (int): Number of feature channels in TOOD head. + stacked_convs (int): Number of conv layers in TOOD head. + la_down_rate (int): Downsample rate of layer attention. + conv_cfg (dict): Config dict for convolution layer. + norm_cfg (dict): Config dict for normalization layer. 
+ """ + + def __init__(self, + feat_channels, + stacked_convs, + la_down_rate=8, + conv_cfg=None, + norm_cfg=None): + super(TaskDecomposition, self).__init__() + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + self.in_channels = self.feat_channels * self.stacked_convs + self.norm_cfg = norm_cfg + self.layer_attention = nn.Sequential( + nn.Conv2d(self.in_channels, self.in_channels // la_down_rate, 1), + nn.ReLU(inplace=True), + nn.Conv2d( + self.in_channels // la_down_rate, + self.stacked_convs, + 1, + padding=0), nn.Sigmoid()) + + self.reduction_conv = ConvModule( + self.in_channels, + self.feat_channels, + 1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=norm_cfg is None) + + def init_weights(self): + for m in self.layer_attention.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + normal_init(self.reduction_conv.conv, std=0.01) + + def forward(self, feat, avg_feat=None): + b, c, h, w = feat.shape + if avg_feat is None: + avg_feat = F.adaptive_avg_pool2d(feat, (1, 1)) + weight = self.layer_attention(avg_feat) + + # here we first compute the product between the layer attention weight + # and the conv weight, and then compute the convolution between the + # new conv weight and the feature map, in order to save memory and + # FLOPs. + conv_weight = weight.reshape( + b, 1, self.stacked_convs, + 1) * self.reduction_conv.conv.weight.reshape( + 1, self.feat_channels, self.stacked_convs, self.feat_channels) + conv_weight = conv_weight.reshape(b, self.feat_channels, + self.in_channels) + feat = feat.reshape(b, self.in_channels, h * w) + feat = torch.bmm(conv_weight, feat).reshape(b, self.feat_channels, h, + w) + if self.norm_cfg is not None: + feat = self.reduction_conv.norm(feat) + feat = self.reduction_conv.activate(feat) + + return feat + + +@HEADS.register_module() +class TOODHead(ATSSHead): + """TOODHead used in `TOOD: Task-aligned One-stage Object Detection. + <https://arxiv.org/abs/2108.07755>`_. + + TOOD uses Task-aligned head (T-head) and is optimized by Task Alignment + Learning (TAL). + + Args: + num_dcn (int): Number of deformable convolutions in the head. + Default: 0. + anchor_type (str): If set to `anchor_free`, the head will use centers + to regress bboxes. If set to `anchor_based`, the head will + regress bboxes based on anchors. Default: `anchor_free`. + initial_loss_cls (dict): Config of initial loss. + + Example: + >>> self = TOODHead(11, 7) + >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]] + >>> cls_score, bbox_pred = self.forward(feats) + >>> assert len(cls_score) == len(self.scales) + """ + + def __init__(self, + num_classes, + in_channels, + num_dcn=0, + anchor_type='anchor_free', + initial_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + activated=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + **kwargs): + assert anchor_type in ['anchor_free', 'anchor_based'] + self.num_dcn = num_dcn + self.anchor_type = anchor_type + self.epoch = 0 # updated by SetEpochInfoHook before each training epoch
+ super(TOODHead, self).__init__(num_classes, in_channels, **kwargs) + + if self.train_cfg: + self.initial_epoch = self.train_cfg.initial_epoch + self.initial_assigner = build_assigner( + self.train_cfg.initial_assigner) + self.initial_loss_cls = build_loss(initial_loss_cls) + self.assigner = self.initial_assigner + self.alignment_assigner = build_assigner(self.train_cfg.assigner) + self.alpha = self.train_cfg.alpha + self.beta = self.train_cfg.beta + + def _init_layers(self): + """Initialize layers of the head.""" + self.relu = nn.ReLU(inplace=True) + self.inter_convs = nn.ModuleList() + for i in range(self.stacked_convs): + if i < self.num_dcn: + conv_cfg = dict(type='DCNv2', deform_groups=4) + else: + conv_cfg = self.conv_cfg + chn = self.in_channels if i == 0 else self.feat_channels + self.inter_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg)) + + self.cls_decomp = TaskDecomposition(self.feat_channels, + self.stacked_convs, + self.stacked_convs * 8, + self.conv_cfg, self.norm_cfg) + self.reg_decomp = TaskDecomposition(self.feat_channels, + self.stacked_convs, + self.stacked_convs * 8, + self.conv_cfg, self.norm_cfg) + + self.tood_cls = nn.Conv2d( + self.feat_channels, + self.num_base_priors * self.cls_out_channels, + 3, + padding=1) + self.tood_reg = nn.Conv2d( + self.feat_channels, self.num_base_priors * 4, 3, padding=1) + + self.cls_prob_module = nn.Sequential( + nn.Conv2d(self.feat_channels * self.stacked_convs, + self.feat_channels // 4, 1), nn.ReLU(inplace=True), + nn.Conv2d(self.feat_channels // 4, 1, 3, padding=1)) + self.reg_offset_module = nn.Sequential( + nn.Conv2d(self.feat_channels * self.stacked_convs, + self.feat_channels // 4, 1), nn.ReLU(inplace=True), + nn.Conv2d(self.feat_channels // 4, 4 * 2, 3, padding=1)) + + self.scales = nn.ModuleList( + [Scale(1.0) for _ in self.prior_generator.strides]) + + def init_weights(self): + """Initialize weights of the head.""" + bias_cls = bias_init_with_prob(0.01) + for m in self.inter_convs: + normal_init(m.conv, std=0.01) + for m in self.cls_prob_module: + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.01) + for m in self.reg_offset_module: + if isinstance(m, nn.Conv2d): + normal_init(m, std=0.001) + normal_init(self.cls_prob_module[-1], std=0.01, bias=bias_cls) + + self.cls_decomp.init_weights() + self.reg_decomp.init_weights() + + normal_init(self.tood_cls, std=0.01, bias=bias_cls) + normal_init(self.tood_reg, std=0.01) + + def forward(self, feats): + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + cls_scores (list[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_anchors * num_classes. + bbox_preds (list[Tensor]): Decoded box for all scale levels, + each is a 4D-tensor, the channels number is + num_anchors * 4. In [tl_x, tl_y, br_x, br_y] format. 
+ """ + cls_scores = [] + bbox_preds = [] + for idx, (x, scale, stride) in enumerate( + zip(feats, self.scales, self.prior_generator.strides)): + b, c, h, w = x.shape + anchor = self.prior_generator.single_level_grid_priors( + (h, w), idx, device=x.device) + anchor = torch.cat([anchor for _ in range(b)]) + # extract task interactive features + inter_feats = [] + for inter_conv in self.inter_convs: + x = inter_conv(x) + inter_feats.append(x) + feat = torch.cat(inter_feats, 1) + + # task decomposition + avg_feat = F.adaptive_avg_pool2d(feat, (1, 1)) + cls_feat = self.cls_decomp(feat, avg_feat) + reg_feat = self.reg_decomp(feat, avg_feat) + + # cls prediction and alignment + cls_logits = self.tood_cls(cls_feat) + cls_prob = self.cls_prob_module(feat) + cls_score = (cls_logits.sigmoid() * cls_prob.sigmoid()).sqrt() + + # reg prediction and alignment + if self.anchor_type == 'anchor_free': + reg_dist = scale(self.tood_reg(reg_feat).exp()).float() + reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4) + reg_bbox = distance2bbox( + self.anchor_center(anchor) / stride[0], + reg_dist).reshape(b, h, w, 4).permute(0, 3, 1, + 2) # (b, c, h, w) + elif self.anchor_type == 'anchor_based': + reg_dist = scale(self.tood_reg(reg_feat)).float() + reg_dist = reg_dist.permute(0, 2, 3, 1).reshape(-1, 4) + reg_bbox = self.bbox_coder.decode(anchor, reg_dist).reshape( + b, h, w, 4).permute(0, 3, 1, 2) / stride[0] + else: + raise NotImplementedError( + f'Unknown anchor type: {self.anchor_type}. ' + f'Please use `anchor_free` or `anchor_based`.') + reg_offset = self.reg_offset_module(feat) + bbox_pred = self.deform_sampling(reg_bbox.contiguous(), + reg_offset.contiguous()) + cls_scores.append(cls_score) + bbox_preds.append(bbox_pred) + return tuple(cls_scores), tuple(bbox_preds) + + def deform_sampling(self, feat, offset): + """Sample the feature according to the offset. + + Args: + feat (Tensor): Feature map. + offset (Tensor): Spatial offset for feature sampling. + """ + # it is an equivalent implementation of bilinear interpolation + b, c, h, w = feat.shape + weight = feat.new_ones(c, 1, 1, 1) + y = deform_conv2d(feat, offset, weight, 1, 0, 1, c, c) + return y + + def anchor_center(self, anchors): + """Get anchor centers from anchors. + + Args: + anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format. + + Returns: + Tensor: Anchor centers with shape (N, 2), "xy" format. + """ + anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2 + anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2 + return torch.stack([anchors_cx, anchors_cy], dim=-1) + + def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights, + bbox_targets, alignment_metrics, stride): + """Compute loss of a single scale level. + + Args: + anchors (Tensor): Box reference for each scale level with shape + (N, num_total_anchors, 4). + cls_score (Tensor): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W). + bbox_pred (Tensor): Decoded bboxes for each scale + level with shape (N, num_anchors * 4, H, W). + labels (Tensor): Labels of each anchor with shape + (N, num_total_anchors). + label_weights (Tensor): Label weights of each anchor with shape + (N, num_total_anchors). + bbox_targets (Tensor): BBox regression targets of each anchor with + shape (N, num_total_anchors, 4). + alignment_metrics (Tensor): Alignment metrics with shape + (N, num_total_anchors). + stride (tuple[int]): Downsample stride of the feature map. + + Returns: + tuple[Tensor]: Classification loss, bbox loss, and the sums of + alignment metrics and bbox weights used as normalization + factors.
+ """ + assert stride[0] == stride[1], 'h stride is not equal to w stride!' + anchors = anchors.reshape(-1, 4) + cls_score = cls_score.permute(0, 2, 3, 1).reshape( + -1, self.cls_out_channels).contiguous() + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) + bbox_targets = bbox_targets.reshape(-1, 4) + labels = labels.reshape(-1) + alignment_metrics = alignment_metrics.reshape(-1) + label_weights = label_weights.reshape(-1) + targets = labels if self.epoch < self.initial_epoch else ( + labels, alignment_metrics) + cls_loss_func = self.initial_loss_cls \ + if self.epoch < self.initial_epoch else self.loss_cls + + loss_cls = cls_loss_func( + cls_score, targets, label_weights, avg_factor=1.0) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + + if len(pos_inds) > 0: + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_anchors = anchors[pos_inds] + + pos_decode_bbox_pred = pos_bbox_pred + pos_decode_bbox_targets = pos_bbox_targets / stride[0] + + # regression loss + pos_bbox_weight = self.centerness_target( + pos_anchors, pos_bbox_targets + ) if self.epoch < self.initial_epoch else alignment_metrics[ + pos_inds] + + loss_bbox = self.loss_bbox( + pos_decode_bbox_pred, + pos_decode_bbox_targets, + weight=pos_bbox_weight, + avg_factor=1.0) + else: + loss_bbox = bbox_pred.sum() * 0 + pos_bbox_weight = bbox_targets.new_tensor(0.) + + return loss_cls, loss_bbox, alignment_metrics.sum( + ), pos_bbox_weight.sum() + + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) + def loss(self, + cls_scores, + bbox_preds, + gt_bboxes, + gt_labels, + img_metas, + gt_bboxes_ignore=None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Decoded box for each scale + level with shape (N, num_anchors * 4, H, W) in + [tl_x, tl_y, br_x, br_y] format. + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): class indices corresponding to each box + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (list[Tensor] | None): specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + num_imgs = len(img_metas) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, img_metas, device=device) + label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + + flatten_cls_scores = torch.cat([ + cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.cls_out_channels) + for cls_score in cls_scores + ], 1) + flatten_bbox_preds = torch.cat([ + bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) * stride[0] + for bbox_pred, stride in zip(bbox_preds, + self.prior_generator.strides) + ], 1) + + cls_reg_targets = self.get_targets( + flatten_cls_scores, + flatten_bbox_preds, + anchor_list, + valid_flag_list, + gt_bboxes, + img_metas, + gt_bboxes_ignore_list=gt_bboxes_ignore, + gt_labels_list=gt_labels, + label_channels=label_channels) + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + alignment_metrics_list) = cls_reg_targets + + losses_cls, losses_bbox,\ + cls_avg_factors, bbox_avg_factors = multi_apply( + self.loss_single, + anchor_list, + cls_scores, + bbox_preds, + labels_list, + label_weights_list, + bbox_targets_list, + alignment_metrics_list, + self.prior_generator.strides) + + cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item() + losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls)) + + bbox_avg_factor = reduce_mean( + sum(bbox_avg_factors)).clamp_(min=1).item() + losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) + return dict(loss_cls=losses_cls, loss_bbox=losses_bbox) + + def _get_bboxes_single(self, + cls_score_list, + bbox_pred_list, + score_factor_list, + mlvl_priors, + img_meta, + cfg, + rescale=False, + with_nms=True, + **kwargs): + """Transform outputs of a single image into bbox predictions. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (mmcv.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default: False. + with_nms (bool): If True, do nms before return boxes. + Default: True. + + Returns: + tuple[Tensor]: Results of detected bboxes and labels. If with_nms + is False and mlvl_score_factor is None, return mlvl_bboxes and + mlvl_scores, else return mlvl_bboxes, mlvl_scores and + mlvl_score_factor. Usually with_nms is False is used for aug + test. If with_nms is True, then return the following format + + - det_bboxes (Tensor): Predicted bboxes with shape \ + [num_bboxes, 5], where the first 4 columns are bounding \ + box positions (tl_x, tl_y, br_x, br_y) and the 5-th \ + column are scores between 0 and 1. 
+ - det_labels (Tensor): Predicted labels of the corresponding \ + box with shape [num_bboxes]. + """ + + cfg = self.test_cfg if cfg is None else cfg + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_labels = [] + for cls_score, bbox_pred, priors, stride in zip( + cls_score_list, bbox_pred_list, mlvl_priors, + self.prior_generator.strides): + + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) * stride[0] + scores = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + + # After https://github.com/open-mmlab/mmdetection/pull/6268/, + # this operation keeps fewer bboxes under the same `nms_pre`. + # There is no difference in performance for most models. If you + # find a slight drop in performance, you can set a larger + # `nms_pre` than before. + results = filter_scores_and_topk( + scores, cfg.score_thr, nms_pre, + dict(bbox_pred=bbox_pred, priors=priors)) + scores, labels, keep_idxs, filtered_results = results + + bboxes = filtered_results['bbox_pred'] + + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_labels.append(labels) + + return self._bbox_post_process(mlvl_scores, mlvl_labels, mlvl_bboxes, + img_meta['scale_factor'], cfg, rescale, + with_nms, None, **kwargs) + + def get_targets(self, + cls_scores, + bbox_preds, + anchor_list, + valid_flag_list, + gt_bboxes_list, + img_metas, + gt_bboxes_ignore_list=None, + gt_labels_list=None, + label_channels=1, + unmap_outputs=True): + """Compute regression and classification targets for anchors in + multiple images. + + Args: + cls_scores (Tensor): Classification predictions of images, + a 3D-Tensor with shape [num_imgs, num_priors, num_classes]. + bbox_preds (Tensor): Decoded bboxes predictions of one image, + a 3D-Tensor with shape [num_imgs, num_priors, 4] in [tl_x, + tl_y, br_x, br_y] format. + anchor_list (list[list[Tensor]]): Multi level anchors of each + image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_anchors, 4). + valid_flag_list (list[list[Tensor]]): Multi level valid flags of + each image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_anchors, ) + gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. + img_metas (list[dict]): Meta info of each image. + gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be + ignored. + gt_labels_list (list[Tensor]): Ground truth labels of each box. + label_channels (int): Channel of label. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: a tuple containing learning targets. + + - anchors_list (list[list[Tensor]]): Anchors of each level. + - labels_list (list[Tensor]): Labels of each level. + - label_weights_list (list[Tensor]): Label weights of each + level. + - bbox_targets_list (list[Tensor]): BBox targets of each level. + - norm_alignment_metrics_list (list[Tensor]): Normalized + alignment metrics of each level. 
+ """ + num_imgs = len(img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + num_level_anchors_list = [num_level_anchors] * num_imgs + + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list[i] = torch.cat(anchor_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + # compute targets for each image + if gt_bboxes_ignore_list is None: + gt_bboxes_ignore_list = [None for _ in range(num_imgs)] + if gt_labels_list is None: + gt_labels_list = [None for _ in range(num_imgs)] + # anchor_list: list(b * [-1, 4]) + + if self.epoch < self.initial_epoch: + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply( + super()._get_target_single, + anchor_list, + valid_flag_list, + num_level_anchors_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + img_metas, + label_channels=label_channels, + unmap_outputs=unmap_outputs) + all_assign_metrics = [ + weight[..., 0] for weight in all_bbox_weights + ] + else: + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_assign_metrics) = multi_apply( + self._get_target_single, + cls_scores, + bbox_preds, + anchor_list, + valid_flag_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + img_metas, + label_channels=label_channels, + unmap_outputs=unmap_outputs) + # no valid anchors + if any([labels is None for labels in all_labels]): + return None + + # split targets to a list w.r.t. multiple levels + anchors_list = images_to_levels(all_anchors, num_level_anchors) + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors) + norm_alignment_metrics_list = images_to_levels(all_assign_metrics, + num_level_anchors) + + return (anchors_list, labels_list, label_weights_list, + bbox_targets_list, norm_alignment_metrics_list) + + def _get_target_single(self, + cls_scores, + bbox_preds, + flat_anchors, + valid_flags, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + img_meta, + label_channels=1, + unmap_outputs=True): + """Compute regression, classification targets for anchors in a single + image. + + Args: + cls_scores (list(Tensor)): Box scores for each image. + bbox_preds (list(Tensor)): Box energies / deltas for each image. + flat_anchors (Tensor): Multi-level anchors of the image, which are + concatenated into a single tensor of shape (num_anchors ,4) + valid_flags (Tensor): Multi level valid flags of the image, + which are concatenated into a single tensor of + shape (num_anchors,). + gt_bboxes (Tensor): Ground truth bboxes of the image, + shape (num_gts, 4). + gt_bboxes_ignore (Tensor): Ground truth bboxes to be + ignored, shape (num_ignored_gts, 4). + gt_labels (Tensor): Ground truth labels of each box, + shape (num_gts,). + img_meta (dict): Meta info of the image. + label_channels (int): Channel of label. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: N is the number of total anchors in the image. + anchors (Tensor): All anchors in the image with shape (N, 4). + labels (Tensor): Labels of all anchors in the image with shape + (N,). 
+ label_weights (Tensor): Label weights of all anchor in the + image with shape (N,). + bbox_targets (Tensor): BBox targets of all anchors in the + image with shape (N, 4). + norm_alignment_metrics (Tensor): Normalized alignment metrics + of all priors in the image with shape (N,). + """ + inside_flags = anchor_inside_flags(flat_anchors, valid_flags, + img_meta['img_shape'][:2], + self.train_cfg.allowed_border) + if not inside_flags.any(): + return (None, ) * 7 + # assign gt and sample anchors + anchors = flat_anchors[inside_flags, :] + assign_result = self.alignment_assigner.assign( + cls_scores[inside_flags, :], bbox_preds[inside_flags, :], anchors, + gt_bboxes, gt_bboxes_ignore, gt_labels, self.alpha, self.beta) + assign_ious = assign_result.max_overlaps + assign_metrics = assign_result.assign_metrics + + sampling_result = self.sampler.sample(assign_result, anchors, + gt_bboxes) + + num_valid_anchors = anchors.shape[0] + bbox_targets = torch.zeros_like(anchors) + labels = anchors.new_full((num_valid_anchors, ), + self.num_classes, + dtype=torch.long) + label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float) + norm_alignment_metrics = anchors.new_zeros( + num_valid_anchors, dtype=torch.float) + + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + if len(pos_inds) > 0: + # point-based + pos_bbox_targets = sampling_result.pos_gt_bboxes + bbox_targets[pos_inds, :] = pos_bbox_targets + + if gt_labels is None: + # Only rpn gives gt_labels as None + # Foreground is the first class since v2.5.0 + labels[pos_inds] = 0 + else: + labels[pos_inds] = gt_labels[ + sampling_result.pos_assigned_gt_inds] + if self.train_cfg.pos_weight <= 0: + label_weights[pos_inds] = 1.0 + else: + label_weights[pos_inds] = self.train_cfg.pos_weight + if len(neg_inds) > 0: + label_weights[neg_inds] = 1.0 + + class_assigned_gt_inds = torch.unique( + sampling_result.pos_assigned_gt_inds) + for gt_inds in class_assigned_gt_inds: + gt_class_inds = pos_inds[sampling_result.pos_assigned_gt_inds == + gt_inds] + pos_alignment_metrics = assign_metrics[gt_class_inds] + pos_ious = assign_ious[gt_class_inds] + pos_norm_alignment_metrics = pos_alignment_metrics / ( + pos_alignment_metrics.max() + 10e-8) * pos_ious.max() + norm_alignment_metrics[gt_class_inds] = pos_norm_alignment_metrics + + # map up to original set of anchors + if unmap_outputs: + num_total_anchors = flat_anchors.size(0) + anchors = unmap(anchors, num_total_anchors, inside_flags) + labels = unmap( + labels, num_total_anchors, inside_flags, fill=self.num_classes) + label_weights = unmap(label_weights, num_total_anchors, + inside_flags) + bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) + norm_alignment_metrics = unmap(norm_alignment_metrics, + num_total_anchors, inside_flags) + return (anchors, labels, label_weights, bbox_targets, + norm_alignment_metrics) diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index 637434a60b5..456b8d424fb 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -32,6 +32,7 @@ from .single_stage import SingleStageDetector from .solo import SOLO from .sparse_rcnn import SparseRCNN +from .tood import TOOD from .trident_faster_rcnn import TridentFasterRCNN from .two_stage import TwoStageDetector from .vfnet import VFNet @@ -48,5 +49,5 @@ 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA', 'YOLOV3', 'YOLACT', 'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO', 'DeformableDETR', 'AutoAssign', 
     'YOLOF', 'CenterNet', 'YOLOX',
-    'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD'
+    'TwoStagePanopticSegmentor', 'PanopticFPN', 'QueryInst', 'LAD', 'TOOD'
 ]
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index f1b450cd5f1..bf64bce63e8 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -63,7 +63,7 @@ def extract_feats(self, imgs):
     def forward_train(self, imgs, img_metas, **kwargs):
         """
         Args:
-            img (list[Tensor]): List of tensors of shape (1, C, H, W).
+            img (Tensor): of shape (N, C, H, W) encoding input images.
                 Typically these should be mean centered and std scaled.
             img_metas (list[dict]): List of image info dict where each dict
                 has: 'img_shape', 'scale_factor', 'flip', and may also contain
diff --git a/mmdet/models/detectors/lad.py b/mmdet/models/detectors/lad.py
index 6d232197e8f..c6cc1e0b2d9 100644
--- a/mmdet/models/detectors/lad.py
+++ b/mmdet/models/detectors/lad.py
@@ -1,3 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 import torch
 import torch.nn as nn
 from mmcv.runner import load_checkpoint
diff --git a/mmdet/models/detectors/queryinst.py b/mmdet/models/detectors/queryinst.py
index 6618c2f7756..5fc216c4734 100644
--- a/mmdet/models/detectors/queryinst.py
+++ b/mmdet/models/detectors/queryinst.py
@@ -1,3 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 from ..builder import DETECTORS
 from .sparse_rcnn import SparseRCNN
diff --git a/mmdet/models/detectors/solo.py b/mmdet/models/detectors/solo.py
index 9f45d314eed..df6f6de0162 100644
--- a/mmdet/models/detectors/solo.py
+++ b/mmdet/models/detectors/solo.py
@@ -1,3 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 from ..builder import DETECTORS
 from .single_stage_instance_seg import SingleStageInstanceSegmentor
diff --git a/mmdet/models/detectors/tood.py b/mmdet/models/detectors/tood.py
new file mode 100644
index 00000000000..7dd18c3c96a
--- /dev/null
+++ b/mmdet/models/detectors/tood.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class TOOD(SingleStageDetector):
+    r"""Implementation of `TOOD: Task-aligned One-stage Object Detection.
+    <https://arxiv.org/abs/2108.07755>`_."""
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 bbox_head,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None,
+                 init_cfg=None):
+        super(TOOD, self).__init__(backbone, neck, bbox_head, train_cfg,
+                                   test_cfg, pretrained, init_cfg)
+
+    def set_epoch(self, epoch):
+        self.bbox_head.epoch = epoch
diff --git a/mmdet/models/losses/focal_loss.py b/mmdet/models/losses/focal_loss.py
index 92909117ced..6c20fddd56f 100644
--- a/mmdet/models/losses/focal_loss.py
+++ b/mmdet/models/losses/focal_loss.py
@@ -57,6 +57,59 @@ def py_sigmoid_focal_loss(pred,
     return loss
 
 
+def py_focal_loss_with_prob(pred,
+                            target,
+                            weight=None,
+                            gamma=2.0,
+                            alpha=0.25,
+                            reduction='mean',
+                            avg_factor=None):
+    """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.
+    Different from `py_sigmoid_focal_loss`, this function accepts probability
+    as input.
+
+    Args:
+        pred (torch.Tensor): The prediction probability with shape (N, C),
+            C is the number of classes.
+        target (torch.Tensor): The learning label of the prediction.
+        weight (torch.Tensor, optional): Sample-wise loss weight.
+        gamma (float, optional): The gamma for calculating the modulating
+            factor. Defaults to 2.0.
+        alpha (float, optional): A balanced form for Focal Loss.
+            Defaults to 0.25.
+        reduction (str, optional): The method used to reduce the loss into
+            a scalar. Defaults to 'mean'.
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+    """
+    num_classes = pred.size(1)
+    target = F.one_hot(target, num_classes=num_classes + 1)
+    target = target[:, :num_classes]
+
+    target = target.type_as(pred)
+    pt = (1 - pred) * target + pred * (1 - target)
+    focal_weight = (alpha * target + (1 - alpha) *
+                    (1 - target)) * pt.pow(gamma)
+    loss = F.binary_cross_entropy(
+        pred, target, reduction='none') * focal_weight
+    if weight is not None:
+        if weight.shape != loss.shape:
+            if weight.size(0) == loss.size(0):
+                # For most cases, weight is of shape (num_priors, ),
+                # which means it does not have the second axis num_class
+                weight = weight.view(-1, 1)
+            else:
+                # Sometimes, weight per anchor per class is also needed. e.g.
+                # in FSAF. But it may be flattened of shape
+                # (num_priors x num_class, ), while loss is still of shape
+                # (num_priors, num_class).
+                assert weight.numel() == loss.numel()
+                weight = weight.view(loss.size(0), -1)
+        assert weight.ndim == loss.ndim
+    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+    return loss
+
+
 def sigmoid_focal_loss(pred,
                        target,
                        weight=None,
@@ -111,7 +164,8 @@ def __init__(self,
                  gamma=2.0,
                  alpha=0.25,
                  reduction='mean',
-                 loss_weight=1.0):
+                 loss_weight=1.0,
+                 activated=False):
         """`Focal Loss <https://arxiv.org/abs/1708.02002>`_
 
         Args:
@@ -125,6 +179,10 @@ def __init__(self,
                 a scalar. Defaults to 'mean'. Options are "none", "mean" and
                 "sum".
             loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+            activated (bool, optional): Whether the input is activated.
+                If True, it means the input has been activated and can be
+                treated as probabilities. Else, it should be treated as logits.
+                Defaults to False.
         """
         super(FocalLoss, self).__init__()
         assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
@@ -133,6 +191,7 @@ def __init__(self,
         self.alpha = alpha
         self.reduction = reduction
         self.loss_weight = loss_weight
+        self.activated = activated
 
     def forward(self,
                 pred,
@@ -160,13 +219,16 @@ def forward(self,
         reduction = (
             reduction_override if reduction_override else self.reduction)
         if self.use_sigmoid:
-            if torch.cuda.is_available() and pred.is_cuda:
-                calculate_loss_func = sigmoid_focal_loss
+            if self.activated:
+                calculate_loss_func = py_focal_loss_with_prob
             else:
-                num_classes = pred.size(1)
-                target = F.one_hot(target, num_classes=num_classes + 1)
-                target = target[:, :num_classes]
-                calculate_loss_func = py_sigmoid_focal_loss
+                if torch.cuda.is_available() and pred.is_cuda:
+                    calculate_loss_func = sigmoid_focal_loss
+                else:
+                    num_classes = pred.size(1)
+                    target = F.one_hot(target, num_classes=num_classes + 1)
+                    target = target[:, :num_classes]
+                    calculate_loss_func = py_sigmoid_focal_loss
 
             loss_cls = self.loss_weight * calculate_loss_func(
                 pred,
diff --git a/mmdet/models/losses/gfocal_loss.py b/mmdet/models/losses/gfocal_loss.py
index a7a1b765f4f..0e8d26373f8 100644
--- a/mmdet/models/losses/gfocal_loss.py
+++ b/mmdet/models/losses/gfocal_loss.py
@@ -52,6 +52,52 @@ def quality_focal_loss(pred, target, beta=2.0):
     return loss
 
 
+@weighted_loss
+def quality_focal_loss_with_prob(pred, target, beta=2.0):
+    r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
+    Qualified and Distributed Bounding Boxes for Dense Object Detection
+    <https://arxiv.org/abs/2006.04388>`_.
+    Different from `quality_focal_loss`, this function accepts probability
+    as input.
+ + Args: + pred (torch.Tensor): Predicted joint representation of classification + and quality (IoU) estimation with shape (N, C), C is the number of + classes. + target (tuple([torch.Tensor])): Target category label with shape (N,) + and target quality label with shape (N,). + beta (float): The beta parameter for calculating the modulating factor. + Defaults to 2.0. + + Returns: + torch.Tensor: Loss tensor with shape (N,). + """ + assert len(target) == 2, """target for QFL must be a tuple of two elements, + including category label and quality label, respectively""" + # label denotes the category id, score denotes the quality score + label, score = target + + # negatives are supervised by 0 quality score + pred_sigmoid = pred + scale_factor = pred_sigmoid + zerolabel = scale_factor.new_zeros(pred.shape) + loss = F.binary_cross_entropy( + pred, zerolabel, reduction='none') * scale_factor.pow(beta) + + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + bg_class_ind = pred.size(1) + pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1) + pos_label = label[pos].long() + # positives are supervised by bbox quality (IoU) score + scale_factor = score[pos] - pred_sigmoid[pos, pos_label] + loss[pos, pos_label] = F.binary_cross_entropy( + pred[pos, pos_label], score[pos], + reduction='none') * scale_factor.abs().pow(beta) + + loss = loss.sum(dim=1, keepdim=False) + return loss + + @mmcv.jit(derivate=True, coderize=True) @weighted_loss def distribution_focal_loss(pred, label): @@ -91,19 +137,25 @@ class QualityFocalLoss(nn.Module): Defaults to 2.0. reduction (str): Options are "none", "mean" and "sum". loss_weight (float): Loss weight of current loss. + activated (bool, optional): Whether the input is activated. + If True, it means the input has been activated and can be + treated as probabilities. Else, it should be treated as logits. + Defaults to False. """ def __init__(self, use_sigmoid=True, beta=2.0, reduction='mean', - loss_weight=1.0): + loss_weight=1.0, + activated=False): super(QualityFocalLoss, self).__init__() assert use_sigmoid is True, 'Only sigmoid in QFL supported now.' self.use_sigmoid = use_sigmoid self.beta = beta self.reduction = reduction self.loss_weight = loss_weight + self.activated = activated def forward(self, pred, @@ -131,7 +183,11 @@ def forward(self, reduction = ( reduction_override if reduction_override else self.reduction) if self.use_sigmoid: - loss_cls = self.loss_weight * quality_focal_loss( + if self.activated: + calculate_loss_func = quality_focal_loss_with_prob + else: + calculate_loss_func = quality_focal_loss + loss_cls = self.loss_weight * calculate_loss_func( pred, target, weight, diff --git a/mmdet/models/roi_heads/mask_heads/dynamic_mask_head.py b/mmdet/models/roi_heads/mask_heads/dynamic_mask_head.py index c670f35d263..5bbe7eea49c 100644 --- a/mmdet/models/roi_heads/mask_heads/dynamic_mask_head.py +++ b/mmdet/models/roi_heads/mask_heads/dynamic_mask_head.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. import torch import torch.nn as nn from mmcv.runner import auto_fp16, force_fp32 diff --git a/mmdet/models/utils/brick_wrappers.py b/mmdet/models/utils/brick_wrappers.py index b95a099633c..fa0279ab60d 100644 --- a/mmdet/models/utils/brick_wrappers.py +++ b/mmdet/models/utils/brick_wrappers.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py
index 3f15580521a..64f3173d53f 100644
--- a/mmdet/utils/__init__.py
+++ b/mmdet/utils/__init__.py
@@ -1,5 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .collect_env import collect_env
 from .logger import get_root_logger
+from .misc import find_latest_checkpoint
 
-__all__ = ['get_root_logger', 'collect_env']
+__all__ = [
+    'get_root_logger',
+    'collect_env',
+    'find_latest_checkpoint',
+]
diff --git a/mmdet/utils/misc.py b/mmdet/utils/misc.py
new file mode 100644
index 00000000000..f5c425300e4
--- /dev/null
+++ b/mmdet/utils/misc.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import glob
+import os.path as osp
+import warnings
+
+
+def find_latest_checkpoint(path, suffix='pth'):
+    """Find the latest checkpoint from the working directory.
+
+    Args:
+        path (str): The path to find checkpoints.
+        suffix (str): File extension.
+            Defaults to pth.
+
+    Returns:
+        latest_path (str | None): File path of the latest checkpoint.
+    References:
+        .. [1] https://github.com/microsoft/SoftTeacher
+              /blob/main/ssod/utils/patch.py
+    """
+    if not osp.exists(path):
+        warnings.warn('The path of checkpoints does not exist.')
+        return None
+    if osp.exists(osp.join(path, f'latest.{suffix}')):
+        return osp.join(path, f'latest.{suffix}')
+
+    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
+    if len(checkpoints) == 0:
+        warnings.warn('There are no checkpoints in the path.')
+        return None
+    latest = -1
+    latest_path = None
+    for checkpoint in checkpoints:
+        count = int(osp.basename(checkpoint).split('_')[-1].split('.')[0])
+        if count > latest:
+            latest = count
+            latest_path = checkpoint
+    return latest_path
diff --git a/mmdet/version.py b/mmdet/version.py
index 4643a0c468f..e4da1588b1c 100644
--- a/mmdet/version.py
+++ b/mmdet/version.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-__version__ = '2.19.1' +__version__ = '2.20.0' short_version = __version__ diff --git a/model-index.yml b/model-index.yml index 23131003e91..900b9385242 100644 --- a/model-index.yml +++ b/model-index.yml @@ -1,6 +1,7 @@ Import: - configs/atss/metafile.yml - configs/autoassign/metafile.yml + - configs/carafe/metafile.yml - configs/cascade_rcnn/metafile.yml - configs/centernet/metafile.yml - configs/centripetalnet/metafile.yml @@ -15,7 +16,6 @@ Import: - configs/faster_rcnn/metafile.yml - configs/fcos/metafile.yml - configs/foveabox/metafile.yml - - configs/fp16/metafile.yml - configs/fpg/metafile.yml - configs/free_anchor/metafile.yml - configs/fsaf/metafile.yml @@ -38,9 +38,11 @@ Import: - configs/nas_fpn/metafile.yml - configs/paa/metafile.yml - configs/pafpn/metafile.yml + - configs/panoptic_fpn/metafile.yml - configs/pvt/metafile.yml - configs/pisa/metafile.yml - configs/point_rend/metafile.yml + - configs/queryinst/metafile.yml - configs/regnet/metafile.yml - configs/reppoints/metafile.yml - configs/res2net/metafile.yml @@ -49,14 +51,15 @@ Import: - configs/sabl/metafile.yml - configs/scnet/metafile.yml - configs/scratch/metafile.yml + - configs/seesaw_loss/metafile.yml - configs/sparse_rcnn/metafile.yml - configs/solo/metafile.yml - configs/ssd/metafile.yml + - configs/swin/metafile.yml - configs/tridentnet/metafile.yml + - configs/tood/metafile.yml - configs/vfnet/metafile.yml - configs/yolact/metafile.yml - configs/yolo/metafile.yml - configs/yolof/metafile.yml - configs/yolox/metafile.yml - - configs/swin/metafile.yml - - configs/strong_baselines/metafile.yml diff --git a/setup.cfg b/setup.cfg index b0171d35b2f..18adf687165 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,4 +15,4 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true [codespell] skip = *.ipynb quiet-level = 3 -ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids +ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood diff --git a/tests/test_models/test_dense_heads/test_tood_head.py b/tests/test_models/test_dense_heads/test_tood_head.py new file mode 100644 index 00000000000..2174cc0e4db --- /dev/null +++ b/tests/test_models/test_dense_heads/test_tood_head.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import mmcv
+import torch
+
+from mmdet.models.dense_heads import TOODHead
+
+
+def test_tood_head_loss():
+    """Tests tood head loss when truth is empty and non-empty."""
+
+    s = 256
+    img_metas = [{
+        'img_shape': (s, s, 3),
+        'scale_factor': 1,
+        'pad_shape': (s, s, 3)
+    }]
+    train_cfg = mmcv.Config(
+        dict(
+            initial_epoch=4,
+            initial_assigner=dict(type='ATSSAssigner', topk=9),
+            assigner=dict(type='TaskAlignedAssigner', topk=13),
+            alpha=1,
+            beta=6,
+            allowed_border=-1,
+            pos_weight=-1,
+            debug=False))
+    test_cfg = mmcv.Config(
+        dict(
+            nms_pre=1000,
+            min_bbox_size=0,
+            score_thr=0.05,
+            nms=dict(type='nms', iou_threshold=0.6),
+            max_per_img=100))
+    # since Focal Loss is not supported on CPU
+    self = TOODHead(
+        num_classes=80,
+        in_channels=1,
+        stacked_convs=6,
+        feat_channels=256,
+        anchor_type='anchor_free',
+        anchor_generator=dict(
+            type='AnchorGenerator',
+            ratios=[1.0],
+            octave_base_scale=8,
+            scales_per_octave=1,
+            strides=[8, 16, 32, 64, 128]),
+        bbox_coder=dict(
+            type='DeltaXYWHBBoxCoder',
+            target_means=[.0, .0, .0, .0],
+            target_stds=[0.1, 0.1, 0.2, 0.2]),
+        initial_loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            activated=True,  # use probability instead of logit as input
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_cls=dict(
+            type='QualityFocalLoss',
+            use_sigmoid=True,
+            activated=True,  # use probability instead of logit as input
+            beta=2.0,
+            loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
+        train_cfg=train_cfg,
+        test_cfg=test_cfg)
+    self.init_weights()
+    feat = [
+        torch.rand(1, 1, s // feat_size, s // feat_size)
+        for feat_size in [8, 16, 32, 64, 128]
+    ]
+    cls_scores, bbox_preds = self(feat)
+
+    # test initial assigner and losses
+    self.epoch = 0
+    # Test that empty ground truth encourages the network to predict background
+    gt_bboxes = [torch.empty((0, 4))]
+    gt_labels = [torch.LongTensor([])]
+    gt_bboxes_ignore = None
+    empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
+                                img_metas, gt_bboxes_ignore)
+    # When there is no truth, the cls loss should be nonzero but there should
+    # be no box loss.
+    empty_cls_loss = empty_gt_losses['loss_cls']
+    empty_box_loss = empty_gt_losses['loss_bbox']
+    assert sum(empty_cls_loss).item() > 0, 'cls loss should be non-zero'
+    assert sum(empty_box_loss).item() == 0, (
+        'there should be no box loss when there are no true boxes')
+    # When truth is non-empty then both cls and box loss should be nonzero for
+    # random inputs
+    gt_bboxes = [
+        torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
+    ]
+    gt_labels = [torch.LongTensor([2])]
+    one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
+                              img_metas, gt_bboxes_ignore)
+    onegt_cls_loss = one_gt_losses['loss_cls']
+    onegt_box_loss = one_gt_losses['loss_bbox']
+    assert sum(onegt_cls_loss).item() > 0, 'cls loss should be non-zero'
+    assert sum(onegt_box_loss).item() > 0, 'box loss should be non-zero'
+
+    # test task alignment assigner and losses
+    self.epoch = 10
+    # Test that empty ground truth encourages the network to predict background
+    gt_bboxes = [torch.empty((0, 4))]
+    gt_labels = [torch.LongTensor([])]
+    gt_bboxes_ignore = None
+    empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
+                                img_metas, gt_bboxes_ignore)
+    # When there is no truth, the cls loss should be nonzero but there should
+    # be no box loss.
+    empty_cls_loss = empty_gt_losses['loss_cls']
+    empty_box_loss = empty_gt_losses['loss_bbox']
+    assert sum(empty_cls_loss).item() > 0, 'cls loss should be non-zero'
+    assert sum(empty_box_loss).item() == 0, (
+        'there should be no box loss when there are no true boxes')
+    # When truth is non-empty then both cls and box loss should be nonzero for
+    # random inputs
+    gt_bboxes = [
+        torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
+    ]
+    gt_labels = [torch.LongTensor([2])]
+    one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels,
+                              img_metas, gt_bboxes_ignore)
+    onegt_cls_loss = one_gt_losses['loss_cls']
+    onegt_box_loss = one_gt_losses['loss_bbox']
+    assert sum(onegt_cls_loss).item() > 0, 'cls loss should be non-zero'
+    assert sum(onegt_box_loss).item() > 0, 'box loss should be non-zero'
diff --git a/tests/test_models/test_dense_heads/test_yolox_head.py b/tests/test_models/test_dense_heads/test_yolox_head.py
index cb63527e97a..f82c8a0b860 100644
--- a/tests/test_models/test_dense_heads/test_yolox_head.py
+++ b/tests/test_models/test_dense_heads/test_yolox_head.py
@@ -70,3 +70,19 @@ def test_yolox_head_loss():
     assert onegt_box_loss.item() > 0, 'box loss should be non-zero'
     assert onegt_obj_loss.item() > 0, 'obj loss should be non-zero'
     assert onegt_l1_loss.item() > 0, 'l1 loss should be non-zero'
+
+    # Test ground truth out of bounds
+    gt_bboxes = [torch.Tensor([[s * 4, s * 4, s * 4 + 10, s * 4 + 10]])]
+    gt_labels = [torch.LongTensor([2])]
+    empty_gt_losses = self.loss(cls_scores, bbox_preds, objectnesses,
+                                gt_bboxes, gt_labels, img_metas)
+    # When gt_bboxes are out of bounds, the assign results should be empty,
+    # so the cls and bbox loss should be zero.
+    empty_cls_loss = empty_gt_losses['loss_cls'].sum()
+    empty_box_loss = empty_gt_losses['loss_bbox'].sum()
+    empty_obj_loss = empty_gt_losses['loss_obj'].sum()
+    assert empty_cls_loss.item() == 0, (
+        'there should be no cls loss when gt_bboxes are out of bounds')
+    assert empty_box_loss.item() == 0, (
+        'there should be no box loss when gt_bboxes are out of bounds')
+    assert empty_obj_loss.item() > 0, 'objectness loss should be non-zero'
diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py
index c0611be9996..9e51c17176a 100644
--- a/tests/test_models/test_forward.py
+++ b/tests/test_models/test_forward.py
@@ -698,10 +698,4 @@ def test_yolox_random_size():
         gt_bboxes=gt_bboxes,
         gt_labels=gt_labels,
         return_loss=True)
-    detector.forward(
-        imgs,
-        img_metas,
-        gt_bboxes=gt_bboxes,
-        gt_labels=gt_labels,
-        return_loss=True)
-    assert detector._input_size == (64, 64)
+    assert detector._input_size == (64, 96)
diff --git a/tests/test_utils/test_assigner.py b/tests/test_utils/test_assigner.py
index 3a42d35dcc8..ca82aeda127 100644
--- a/tests/test_utils/test_assigner.py
+++ b/tests/test_utils/test_assigner.py
@@ -5,12 +5,13 @@
 pytest tests/test_utils/test_assigner.py
 xdoctest tests/test_utils/test_assigner.py zero
 """
+import pytest
 import torch
 
 from mmdet.core.bbox.assigners import (ApproxMaxIoUAssigner,
                                        CenterRegionAssigner, HungarianAssigner,
                                        MaxIoUAssigner, PointAssigner,
-                                       UniformAssigner)
+                                       TaskAlignedAssigner, UniformAssigner)
 
 
 def test_max_iou_assigner():
@@ -496,3 +497,45 @@ def test_uniform_assigner_with_empty_boxes():
     # Test without gt_labels
     assign_result = self.assign(pred_bbox, anchor, gt_bboxes, gt_labels=None)
     assert len(assign_result.gt_inds) == 0
+
+
+def test_task_aligned_assigner():
+    with pytest.raises(AssertionError):
+        TaskAlignedAssigner(topk=0)
+
+    self = TaskAlignedAssigner(topk=13)
+
pred_score = torch.FloatTensor([[0.1, 0.2], [0.2, 0.3], [0.3, 0.4], + [0.4, 0.5]]) + pred_bbox = torch.FloatTensor([ + [1, 1, 12, 8], + [4, 4, 20, 20], + [1, 5, 15, 15], + [30, 5, 32, 42], + ]) + anchor = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + gt_labels = torch.LongTensor([0, 1]) + assign_result = self.assign( + pred_score, + pred_bbox, + anchor, + gt_bboxes=gt_bboxes, + gt_labels=gt_labels) + assert len(assign_result.gt_inds) == 4 + assert len(assign_result.labels) == 4 + + # test empty gt + gt_bboxes = torch.empty(0, 4) + gt_labels = torch.empty(0, 2) + assign_result = self.assign( + pred_score, pred_bbox, anchor, gt_bboxes=gt_bboxes) + expected_gt_inds = torch.LongTensor([0, 0, 0, 0]) + assert torch.all(assign_result.gt_inds == expected_gt_inds) diff --git a/tests/test_utils/test_hook.py b/tests/test_utils/test_hook.py index afd176788f0..43fab670a05 100644 --- a/tests/test_utils/test_hook.py +++ b/tests/test_utils/test_hook.py @@ -323,3 +323,32 @@ def train_step(self, x, optimizer, **kwargs): with pytest.raises(AssertionError): runner.run([loader], [('train', 1)]) shutil.rmtree(runner.work_dir) + + +def test_set_epoch_info_hook(): + """Test SetEpochInfoHook.""" + + class DemoModel(nn.Module): + + def __init__(self): + super().__init__() + self.epoch = 0 + self.linear = nn.Linear(2, 1) + + def forward(self, x): + return self.linear(x) + + def train_step(self, x, optimizer, **kwargs): + return dict(loss=self(x)) + + def set_epoch(self, epoch): + self.epoch = epoch + + loader = DataLoader(torch.ones((5, 2))) + runner = _build_demo_runner(max_epochs=3) + + demo_model = DemoModel() + runner.model = demo_model + runner.register_hook_from_cfg(dict(type='SetEpochInfoHook')) + runner.run([loader], [('train', 1)]) + assert demo_model.epoch == 2 diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index 5c1d2b1ed4a..de22ad6cb32 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import tempfile + import numpy as np import pytest import torch @@ -7,6 +9,7 @@ from mmdet.core.mask.structures import BitmapMasks, PolygonMasks from mmdet.core.utils import (center_of_mass, filter_scores_and_topk, flip_tensor, mask2ndarray, select_single_mlvl) +from mmdet.utils import find_latest_checkpoint def dummy_raw_polygon_masks(size): @@ -160,3 +163,41 @@ def test_filter_scores_and_topk(): assert keep_idxs.allclose(torch.tensor([1, 2, 1, 3])) assert results['bbox_pred'].allclose( torch.tensor([[0.4, 0.7], [0.1, 0.1], [0.4, 0.7], [0.5, 0.1]])) + + +def test_find_latest_checkpoint(): + with tempfile.TemporaryDirectory() as tmpdir: + path = tmpdir + latest = find_latest_checkpoint(path) + # There are no checkpoints in the path. + assert latest is None + + path = tmpdir + '/none' + latest = find_latest_checkpoint(path) + # The path does not exist. 
+        assert latest is None
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with open(tmpdir + '/latest.pth', 'w') as f:
+            f.write('latest')
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        assert latest == tmpdir + '/latest.pth'
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with open(tmpdir + '/iter_4000.pth', 'w') as f:
+            f.write('iter_4000')
+        with open(tmpdir + '/iter_8000.pth', 'w') as f:
+            f.write('iter_8000')
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        assert latest == tmpdir + '/iter_8000.pth'
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with open(tmpdir + '/epoch_1.pth', 'w') as f:
+            f.write('epoch_1')
+        with open(tmpdir + '/epoch_2.pth', 'w') as f:
+            f.write('epoch_2')
+        path = tmpdir
+        latest = find_latest_checkpoint(path)
+        assert latest == tmpdir + '/epoch_2.pth'
diff --git a/tools/test.py b/tools/test.py
index 2287aeb3060..da9821de117 100644
--- a/tools/test.py
+++ b/tools/test.py
@@ -33,6 +33,12 @@ def parse_args():
         action='store_true',
         help='Whether to fuse conv and bn, this will slightly increase'
         'the inference speed')
+    parser.add_argument(
+        '--gpu-ids',
+        type=int,
+        nargs='+',
+        help='ids of gpus to use '
+        '(only applicable to non-distributed testing)')
     parser.add_argument(
         '--format-only',
         action='store_true',
@@ -155,9 +161,20 @@ def main():
         for ds_cfg in cfg.data.test:
             ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
 
+    if args.gpu_ids is not None:
+        cfg.gpu_ids = args.gpu_ids
+    else:
+        cfg.gpu_ids = range(1)
+
     # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
         distributed = False
+        if len(cfg.gpu_ids) > 1:
+            warnings.warn(
+                f'We treat {cfg.gpu_ids} as gpu-ids, and reset to '
+                f'{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential error in '
+                'non-distributed testing.')
+            cfg.gpu_ids = cfg.gpu_ids[0:1]
     else:
         distributed = True
         init_dist(args.launcher, **cfg.dist_params)
@@ -195,7 +212,7 @@ def main():
     model.CLASSES = dataset.CLASSES
 
     if not distributed:
-        model = MMDataParallel(model, device_ids=[0])
+        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
         outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
                                   args.show_score_thr)
     else:
diff --git a/tools/train.py b/tools/train.py
index 95db9fbbc01..8be81775b2f 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -25,6 +25,10 @@ def parse_args():
     parser.add_argument('--work-dir', help='the dir to save logs and models')
     parser.add_argument(
         '--resume-from', help='the checkpoint file to resume from')
+    parser.add_argument(
+        '--auto-resume',
+        action='store_true',
+        help='resume from the latest checkpoint automatically')
     parser.add_argument(
         '--no-validate',
         action='store_true',
@@ -104,6 +108,7 @@ def main():
                                 osp.splitext(osp.basename(args.config))[0])
     if args.resume_from is not None:
         cfg.resume_from = args.resume_from
+    cfg.auto_resume = args.auto_resume
     if args.gpu_ids is not None:
         cfg.gpu_ids = args.gpu_ids
     else:
@@ -112,6 +117,12 @@ def main():
     # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
         distributed = False
+        if len(cfg.gpu_ids) > 1:
+            warnings.warn(
+                f'We treat {cfg.gpu_ids} as gpu-ids, and reset to '
+                f'{cfg.gpu_ids[0:1]} as gpu-ids to avoid potential error in '
+                'non-distributed training.')
+            cfg.gpu_ids = cfg.gpu_ids[0:1]
     else:
         distributed = True
         init_dist(args.launcher, **cfg.dist_params)
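
A note on the task-aligned targets above: TaskAlignedAssigner scores each anchor against a gt box with the alignment metric from the TOOD paper, t = s**alpha * u**beta, where s is the predicted classification score for the gt category and u is the IoU between the predicted and gt boxes (alpha=1, beta=6 in the test config above). `_get_target_single` then rescales the per-gt metrics so that their maximum equals the best IoU achieved for that gt. A minimal standalone sketch of that normalization (the function name is ours, not part of the diff; the 10e-8 epsilon matches the one used in the head):

import torch


def normalized_alignment_metric(scores, ious, alpha=1.0, beta=6.0, eps=10e-8):
    """Normalize t = s**alpha * u**beta for the anchors of one gt box."""
    t = scores.pow(alpha) * ious.pow(beta)  # raw task alignment metric
    # Rescale so the best-aligned anchor gets exactly the best IoU,
    # mirroring _get_target_single above.
    return t / (t.max() + eps) * ious.max()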
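The new `activated=True` switch in FocalLoss and QualityFocalLoss exists because the TOOD head applies sigmoid inside the head and hands probabilities, not logits, to its classification losses. For FocalLoss the probability path should agree numerically with the usual logit path; a quick sanity check, not part of the diff, assuming this patch is installed:

import torch

from mmdet.models.losses import FocalLoss

logits = torch.randn(8, 80)
labels = torch.randint(0, 80, (8, ))

# Logit input: falls back to py_sigmoid_focal_loss on CPU.
loss_from_logits = FocalLoss(use_sigmoid=True)(logits, labels)
# Probability input: routed to the new py_focal_loss_with_prob.
loss_from_probs = FocalLoss(
    use_sigmoid=True, activated=True)(logits.sigmoid(), labels)
assert torch.allclose(loss_from_logits, loss_from_probs, atol=1e-6)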
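`TOOD.set_epoch` together with the `SetEpochInfoHook` exercised in test_hook.py is how the head learns the current epoch, so it can switch from the initial ATSS-style assignment to task-aligned assignment once `epoch >= initial_epoch`. The hook implementation itself is outside this diff; judging from the test, it plausibly reduces to the sketch below (the class name is ours, to avoid clashing with the real registered hook):

from mmcv.parallel import is_module_wrapper
from mmcv.runner import Hook


class SetEpochInfoHookSketch(Hook):
    """Illustrative stand-in for the SetEpochInfoHook used in the test."""

    def before_train_epoch(self, runner):
        model = runner.model
        if is_module_wrapper(model):
            # unwrap MMDataParallel / MMDistributedDataParallel
            model = model.module
        # TOOD.set_epoch forwards this to the bbox head (see tood.py above).
        model.set_epoch(runner.epoch)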
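Likewise, `--auto-resume` only sets `cfg.auto_resume` in tools/train.py; the lookup itself is expected to call `find_latest_checkpoint` from mmdet/utils/misc.py inside the training entry point, whose call site is not shown in this diff. A minimal sketch of the wiring, with a hypothetical helper name and an mmcv Config-like `cfg`:

from mmdet.utils import find_latest_checkpoint


def maybe_auto_resume(cfg):
    # An explicit --resume-from always wins; otherwise fall back to
    # latest.pth or the highest-numbered iter_*.pth / epoch_*.pth
    # found in the work directory.
    if cfg.get('resume_from') is None and cfg.get('auto_resume'):
        latest = find_latest_checkpoint(cfg.work_dir)
        if latest is not None:
            cfg.resume_from = latest
    return cfg

From the command line this would look like `python tools/train.py configs/tood/tood_r50_fpn_1x_coco.py --auto-resume` (config path for illustration), which resumes training from the newest checkpoint in the work dir if one exists.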