This repository is the official implementation of RegionViT: Regional-to-Local Attention for Vision Transformers. ArXiv
We provide code for image classification and object detection.
If you use the code and models from this repo, please cite our work. Thanks!
@inproceedings{
chen2021regionvit,
title={{RegionViT: Regional-to-Local Attention for Vision Transformers}},
author={Chun-Fu (Richard) Chen and Rameswar Panda and Quanfu Fan},
booktitle={ArXiv},
year={2021}
}
To install requirements:
pip install -r requirements.txt
Download and extract ImageNet train and val images from http://image-net.org/.
The directory structure is the standard layout for torchvision's datasets.ImageFolder, and the training and validation data are expected to be in the train/ folder and the val/ folder, respectively:
/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
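Before launching a long training run, it can help to confirm that the dataset root follows this layout. The following is a minimal sketch using only the Python standard library; the function name `check_imagefolder_layout` is hypothetical and not part of this repository, and the checks are an assumption about what is worth verifying, not a description of what `main.py` does.

```python
from pathlib import Path


def check_imagefolder_layout(root):
    """Return a list of problems found under an ImageFolder-style dataset root.

    Hypothetical helper, not part of this repository.
    """
    root = Path(root)
    problems = []
    class_names = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            problems.append(f"missing split directory: {split_dir}")
            continue
        # Each class is a sub-directory that should hold at least one file.
        classes = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
        if not classes:
            problems.append(f"no class folders in {split_dir}")
        for name in classes:
            if not any(f.is_file() for f in (split_dir / name).iterdir()):
                problems.append(f"empty class folder: {split_dir / name}")
        class_names[split] = classes
    # Class indices are derived from folder names, so both splits should
    # expose the same set of class folders.
    if len(class_names) == 2 and class_names["train"] != class_names["val"]:
        problems.append("train/ and val/ contain different class folders")
    return problems
```

Calling `check_imagefolder_layout("/path/to/imagenet")` returns an empty list when both splits exist, every class folder contains at least one file, and the class folder names match across splits.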
We provide models trained on ImageNet1K. Models can be found here.
Name | Acc@1 (%) | #FLOPs (G) | #Params | URL |
---|---|---|---|---|
RegionViT-Ti | 80.4 | 2.4 | 13.8M | model |
RegionViT-S | 82.6 | 5.3 | 30.6M | model |
RegionViT-M | 83.1 | 7.4 | 41.2M | model |
RegionViT-B | 83.2 | 13.0 | 72.7M | model |
To train RegionViT-S on ImageNet on a single node with 8 GPUs for 300 epochs, run:
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model regionvit_small_224 --batch-size 256 --data-path /path/to/imagenet
The other model names are regionvit_tiny_224, regionvit_medium_224, and regionvit_base_224.
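To run the same recipe for several models, the launch command can be generated rather than edited by hand. The sketch below only assembles the command string shown above for one of the four listed model names; the helper `build_train_command` is hypothetical, and reusing `--batch-size 256` for the other models is an assumption, since this README only states it for RegionViT-S.

```python
import shlex

# Model names listed in this README.
MODEL_NAMES = [
    "regionvit_tiny_224",
    "regionvit_small_224",
    "regionvit_medium_224",
    "regionvit_base_224",
]


def build_train_command(model, data_path, nproc_per_node=8, batch_size=256):
    """Build the single-node training command as a shell-safe string.

    Hypothetical helper; it uses only the flags that appear in this README.
    """
    if model not in MODEL_NAMES:
        raise ValueError(f"unknown model name: {model}")
    args = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc_per_node}", "--use_env", "main.py",
        "--model", model,
        "--batch-size", str(batch_size),
        "--data-path", str(data_path),
    ]
    # shlex.join quotes arguments that contain spaces or shell metacharacters.
    return shlex.join(args)


if __name__ == "__main__":
    for name in MODEL_NAMES:
        print(build_train_command(name, "/path/to/imagenet"))
```

Each printed line can be pasted into a shell or passed to a job scheduler; paths containing spaces are quoted automatically.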
Distributed training is available via Slurm and submitit.
To train the RegionViT-S model on ImageNet on 4 nodes with 8 GPUs each for 300 epochs:
python run_with_submitit.py --model regionvit_small_224 --data-path /path/to/imagenet --batch-size 256 --warmup-epochs 50
Note that some Slurm configurations may need to be changed for your cluster.
To evaluate a pretrained RegionViT-S model:
python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model regionvit_small_224 --batch-size 256 --data-path /path/to/imagenet --eval --initial_checkpoint /path/to/checkpoint
Object detection is built on Detectron2 with some modifications.
The modified version can be found at https://github.com/chunfuchen/detectron2. The major difference is the data augmentation pipeline.
Follow the installation guide in Install.md.
Follow the Detectron2 instructions to set up the MS COCO dataset. Link
Before training, you need to convert the pretrained model into Detectron2 format. We provide the script tools/convert_cls_model_to_d2.py for the conversion:
python3 tools/convert_cls_model_to_d2.py --model /path/to/pretrained/model --ows 7 --nws 7
Then, to train RetinaNet with RegionViT-S on MS COCO with the 1x schedule:
python main_detection.py --num-gpus 8 --resume --config-file detection/configs/retinanet_regionvit_FPN_1x.yaml MODEL.BACKBONE.REGIONVIT regionvit_small_224 MODEL.WEIGHTS /path/to/pretrained_model OUTPUT_DIR /path/to/log_folder
Other model names include regionvit_base_224, regionvit_small_w14_224, etc. Supported models can be found here.
We provide models trained on MS COCO with MaskRCNN and RetinaNet. Models can be found here.
MaskRCNN:
Name | #Params (M) | #FLOPs (G) | box mAP (1x) | mask mAP (1x) | box mAP (3x) | mask mAP (3x) | url |
---|---|---|---|---|---|---|---|
RegionViT-S | 50.1 | 171.3 | 42.5 | 39.5 | 46.3 | 42.3 | 1x model 3x model |
RegionViT-S+ | 50.9 | 182.9 | 43.5 | 40.4 | 47.3 | 43.4 | 1x model 3x model |
RegionViT-S+ (w/ PEG) | 50.9 | 183.0 | 44.2 | 40.8 | 47.6 | 43.4 | 1x model 3x model |
RegionViT-B | 92.2 | 287.9 | 43.5 | 40.1 | 47.2 | 43.0 | 1x model 3x model |
RegionViT-B+ | 93.2 | 307.1 | 44.5 | 41.0 | 48.1 | 43.5 | 1x model 3x model |
RegionViT-B+ (w/ PEG) | 93.2 | 307.2 | 45.4 | 41.6 | 48.3 | 43.5 | 1x model 3x model |
RegionViT-B+ (w/ PEG) † | 93.2 | 464.4 | 46.3 | 42.4 | 49.2 | 44.5 | 1x model 3x model |
RetinaNet:
Name | #Params (M) | #FLOPs (G) | box mAP (1x) | box mAP (3x) | url |
---|---|---|---|---|---|
RegionViT-S | 40.8 | 192.6 | 42.2 | 45.8 | 1x model 3x model |
RegionViT-S+ | 41.5 | 204.2 | 43.1 | 46.9 | 1x model 3x model |
RegionViT-S+ (w/ PEG) | 41.6 | 204.3 | 43.9 | 46.7 | 1x model 3x model |
RegionViT-B | 83.4 | 308.9 | 43.3 | 46.1 | 1x model 3x model |
RegionViT-B+ | 84.4 | 328.1 | 44.2 | 46.9 | 1x model 3x model |
RegionViT-B+ (w/ PEG) | 84.5 | 328.2 | 44.6 | 46.9 | 1x model 3x model |
RegionViT-B+ (w/ PEG) † | 84.5 | 506.4 | 46.1 | 48.2 | 1x model 3x model |
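To compare the two schedules in the box-only table above, the rows can be parsed and the 3x minus 1x box mAP gain computed per model. This is a minimal sketch over a subset of the rows copied from that table; the helper names `parse_rows` and `schedule_gain` are hypothetical and not part of this repository.

```python
# A subset of rows copied verbatim from the box-only table above.
ROWS = """\
RegionViT-S | 40.8 | 192.6 | 42.2 | 45.8 | 1x model 3x model |
RegionViT-S+ | 41.5 | 204.2 | 43.1 | 46.9 | 1x model 3x model |
RegionViT-B | 83.4 | 308.9 | 43.3 | 46.1 | 1x model 3x model |
RegionViT-B+ | 84.4 | 328.1 | 44.2 | 46.9 | 1x model 3x model |
"""


def parse_rows(text):
    """Parse pipe-separated table rows into dictionaries."""
    rows = []
    for line in text.strip().splitlines():
        cells = [c.strip() for c in line.split("|")]
        # Columns: name, #Params (M), #FLOPs (G), box mAP (1x), box mAP (3x).
        name, params, flops, box_1x, box_3x = cells[:5]
        rows.append({
            "name": name,
            "params_m": float(params),
            "flops_g": float(flops),
            "box_1x": float(box_1x),
            "box_3x": float(box_3x),
        })
    return rows


def schedule_gain(row):
    """Box mAP gained by moving from the 1x to the 3x schedule."""
    # Round to one decimal to match the precision of the table.
    return round(row["box_3x"] - row["box_1x"], 1)


if __name__ == "__main__":
    for row in parse_rows(ROWS):
        print(f'{row["name"]}: +{schedule_gain(row)} box mAP (1x to 3x)')
```

The same parser works for the remaining rows of that table, since they share the column order.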