This repository is an official implementation of "NOAH: Learning Pairwise Object Category Attentions for Image Classification".
The macro-structure of DNNs with a Non-glObal Attentive Head (NOAH): unlike popular heads based on global feature encoding, NOAH relies on Pairwise Object Category Attentions (POCAs) learnt at local to global scales via a neat association of feature split (at two levels), interaction, and aggregation operations, taking the feature maps from the last layer of a CNN, ViT, or MLP backbone as input.
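As a rough illustration, here is a minimal PyTorch sketch of a POCA-style head. This is a conceptual approximation, not the authors' implementation: the class name NOAHHead and its internal layout are ours, with num_heads and key_ratio mirroring the --head_num and --key_ratio options used in the training commands below.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NOAHHead(nn.Module):
    """Sketch of a pairwise per-category attention head (illustrative only)."""

    def __init__(self, in_channels, num_classes, num_heads=4, key_ratio=0.5):
        super().__init__()
        self.num_heads = num_heads
        self.num_classes = num_classes
        key_channels = max(1, int(in_channels * key_ratio))
        # "Key" branch: a reduced-channel projection followed by a 1x1 conv
        # emitting one spatial attention logit map per (head, category) pair.
        self.key = nn.Sequential(
            nn.Conv2d(in_channels, key_channels, kernel_size=1),
            nn.Conv2d(key_channels, num_heads * num_classes, kernel_size=1),
        )
        # "Value" branch: per-(head, category) evidence maps.
        self.value = nn.Conv2d(in_channels, num_heads * num_classes, kernel_size=1)

    def forward(self, x):  # x: (B, C, H, W) feature maps from the backbone
        b, _, h, w = x.shape
        att = self.key(x).view(b, self.num_heads, self.num_classes, h * w)
        att = F.softmax(att, dim=-1)  # normalize attention over spatial positions
        val = self.value(x).view(b, self.num_heads, self.num_classes, h * w)
        # Attention-weighted spatial aggregation, then average over heads.
        return (att * val).sum(dim=-1).mean(dim=1)  # (B, num_classes) logits
```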
To prepare the ImageNet dataset, following this repository:
- Download the ImageNet dataset from http://www.image-net.org/.
- Then, move and extract the training and validation images into labeled subfolders, e.g. with a script like the sketch below.
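A minimal sketch (not the repository's actual script) of sorting the validation images into per-class subfolders; it assumes a hypothetical plain-text mapping file val_labels.txt with lines of the form "ILSVRC2012_val_00000001.JPEG n01751748" (file name, WordNet ID):

```python
import os
import shutil

val_dir = "./datasets/ILSVRC2012/val"

# Move each validation image into a subfolder named after its WordNet ID.
with open("val_labels.txt") as f:
    for line in f:
        filename, wnid = line.split()
        class_dir = os.path.join(val_dir, wnid)
        os.makedirs(class_dir, exist_ok=True)
        shutil.move(os.path.join(val_dir, filename), os.path.join(class_dir, filename))
```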
Requirements:
- Python >= 3.7.0
- PyTorch >= 1.8.1
- torchvision >= 0.9.1
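A quick sanity check that the installed versions satisfy these requirements:

```python
import sys
import torch
import torchvision

print(sys.version.split()[0])   # expect >= 3.7.0
print(torch.__version__)        # expect >= 1.8.1
print(torchvision.__version__)  # expect >= 0.9.1
```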
In the experiments, we construct our networks by replacing the existing head of each selected DNN architecture with a NOAH.
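For instance, a hypothetical way to wire this up with a torchvision ResNet-18 (the repository does this internally via its --use_noah flag) is to strip the GAP + FC head and append the NOAHHead sketched above, using the settings from the ResNet18 row of the table below:

```python
import torch
import torch.nn as nn
import torchvision

resnet = torchvision.models.resnet18()
# Keep everything up to the last convolutional stage; drop GAP and the FC layer.
features = nn.Sequential(*list(resnet.children())[:-2])
head = NOAHHead(in_channels=512, num_classes=1000, num_heads=4, key_ratio=0.5)

x = torch.randn(2, 3, 224, 224)
logits = head(features(x))  # shape (2, 1000)
```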
Please refer to the README.md files in the vit and mlp folders for how to train and evaluate ViT and MLP backbones with NOAH.
Here, we show the results and models for CNN backbones with NOAH trained on ImageNet. In the tables below, Heads and Key Ratio correspond to the --head_num and --key_ratio options used in the training commands that follow.
Backbones | Params | Heads | Key Ratio | Top-1 Acc (%) | Google Drive
---|---|---|---|---|---
ResNet18 | 11.69M | - | - | 70.25 | model
+ NOAH | 11.70M | 4 | 1/2 | 71.81 | model
ResNet50 | 25.56M | - | - | 76.23 | model
+ NOAH | 25.56M | 4 | 1/8 | 77.25 | model
ResNet101 | 44.55M | - | - | 77.41 | model
+ NOAH | 44.56M | 4 | 1/8 | 78.22 | model
ResNet152 | 60.19M | - | - | 78.16 | model
+ NOAH | 60.20M | 4 | 1/8 | 78.57 | model
Backbones | Params | Heads | Key Ratio | Top-1 Acc (%) | Google Drive
---|---|---|---|---|---
MobileNetV2 (1.0×) | 3.50M | - | - | 72.02 | model
+ NOAH | 3.52M | 8 | 1/4 | 73.35 | model
MobileNetV2 (0.75×) | 2.64M | - | - | 69.65 | model
+ NOAH | 2.65M | 8 | 1/4 | 71.44 | model
MobileNetV2 (0.5×) | 1.97M | - | - | 64.30 | model
+ NOAH | 1.98M | 8 | 1/4 | 67.44 | model
MobileNetV2 (0.35×) | 1.68M | - | - | 59.62 | model
+ NOAH | 1.69M | 8 | 1/4 | 63.40 | model
MobileNetV3-Small | 2.94M | - | - | 67.11 | model
+ NOAH | 2.95M | 8 | 1/4 | 68.92 | model
ShuffleNetV2 (1.0×) | 2.28M | - | - | 69.43 | model
+ NOAH | 2.29M | 8 | 1/4 | 70.72 | model
To train ResNet18 with NOAH:
```
python -m torch.distributed.launch --nproc_per_node=8 main.py --arch resnet18 --epochs 100 --lr 0.1 --wd 1e-4 \
    --lr-decay schedule --schedule 30 60 90 --use_noah --head_num 4 --key_ratio 0.5 --dropout 0.1 \
    --data ./datasets/ILSVRC2012 --checkpoint ./checkpoints/noah_resnet18
```
To train ResNet18 with a standard head (GAP + FC):
```
python -m torch.distributed.launch --nproc_per_node=8 main.py --arch resnet18 --epochs 100 --lr 0.1 --wd 1e-4 \
    --lr-decay schedule --schedule 30 60 90 --dropout 0.1 --data ./datasets/ILSVRC2012 --checkpoint ./checkpoints/resnet18
```
To train MobileNetV2 (1.0×) with NOAH:
```
python -m torch.distributed.launch --nproc_per_node=8 main.py --arch mobilenetv2_100 --epochs 150 --lr 0.05 --wd 4e-5 \
    --lr-decay cos --use_noah --head_num 8 --key_ratio 0.25 --dropout 0.2 \
    --data ./datasets/ILSVRC2012 --checkpoint ./checkpoints/noah_mobilenetv2_100
```
To train MobileNetV3-Small with NOAH:
```
python -m torch.distributed.launch --nproc_per_node=8 main.py --arch mobilenetv3_small --epochs 150 --lr 0.05 --wd 4e-5 \
    --lr-decay cos --use_noah --head_num 8 --key_ratio 0.25 --dropout 0.2 --nowd-bn \
    --data ./datasets/ILSVRC2012 --checkpoint ./checkpoints/noah_mobilenetv3_small
```
To train ShuffleNetV2 (1.0×) with NOAH:
```
python -m torch.distributed.launch --nproc_per_node=8 main.py --arch shufflenetv2_100 --epochs 240 --lr 0.5 --wd 4e-5 \
    --train-batch 1024 --lr-decay linear --use_noah --head_num 8 --key_ratio 0.25 --dropout 0 --nowd-bn \
    --data ./datasets/ILSVRC2012 --checkpoint ./checkpoints/noah_shufflenetv2_100
```
To evaluate a pre-trained model:
```
python -m torch.distributed.launch --nproc_per_node={ngpus} main.py \
    --arch {model name} --data {path to dataset} --use_noah --head_num {number of heads} \
    --key_ratio {key ratio of POCA} --evaluate --resume {path to model}
```
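For example, to evaluate the NOAH ResNet-18 model from the table above on 8 GPUs (the checkpoint filename here is illustrative; substitute the path of the downloaded model):
```
python -m torch.distributed.launch --nproc_per_node=8 main.py \
    --arch resnet18 --data ./datasets/ILSVRC2012 --use_noah --head_num 4 \
    --key_ratio 0.5 --evaluate --resume ./checkpoints/noah_resnet18/model_best.pth.tar
```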
NOAH is released under the Apache license. We encourage use for both research and commercial purposes, as long as proper attribution is given.
This repository is built on top of the pytorch-image-models, deit, and pvt repositories. We thank the authors for releasing their amazing code.