- First, set up the required environment as described in the DeiT repository, or follow the instructions below:
```
# Create virtual env
conda create -n spvit-deit python=3.7 -y
conda activate spvit-deit

# Install PyTorch 1.7.0+, torchvision 0.8.1+ and pytorch-image-models 0.3.2
conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
```
- Next, install the additional dependencies required by SPViT:
```
pip install tensorboardX tensorboard
```
- Please refer to the DeiT repository to prepare the standard ImageNet dataset, then link the ImageNet dataset under the `data` folder:

```
$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
```
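Before launching a run, it can save time to verify the linked dataset matches the layout above. The following is a minimal sketch (the helper is ours, not part of the SPViT repository) that checks for non-empty `train` and `val` folders:

```python
# Hypothetical helper (not part of SPViT): verify the expected ImageNet layout
from pathlib import Path

def check_imagenet_layout(root):
    """Return True if root contains non-empty train/ and val/ class folders."""
    root = Path(root)
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            return False
        # every split should contain at least one class folder
        if not any(split_dir.iterdir()):
            return False
    return True
```

Once the dataset is linked, `check_imagenet_layout("data/imagenet")` should return `True`.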
- We start both searching and fine-tuning from the pre-trained models.
- We provide training scripts for three DeiT models (DeiT-Ti, DeiT-S, and DeiT-B), so please download the corresponding three pre-trained models from the DeiT repository as well.
- Next, move the downloaded pre-trained models into the following file structure:

```
$ tree model
├── deit_base_patch16_224-b5f2ef4d.pth
├── deit_small_patch16_224-cd65a155.pth
└── deit_tiny_patch16_224-a1311bcf.pth
```
- Note: do not change the filenames of the pre-trained models, as these filenames are hard-coded when tailoring and loading them. Feel free to modify the hard-coded parts when pruning from other pre-trained models.
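Since the filenames are hard-coded, a quick check that all three checkpoints are in place can catch problems before a run starts. A minimal sketch (our own helper, assuming the `model` folder above):

```python
# Hypothetical helper (not part of SPViT): report missing pre-trained checkpoints
from pathlib import Path

# Filenames as hard-coded by the training scripts
EXPECTED_CHECKPOINTS = (
    "deit_base_patch16_224-b5f2ef4d.pth",
    "deit_small_patch16_224-cd65a155.pth",
    "deit_tiny_patch16_224-a1311bcf.pth",
)

def missing_checkpoints(model_dir="model"):
    """Return the expected checkpoint filenames absent from model_dir."""
    model_dir = Path(model_dir)
    return [name for name in EXPECTED_CHECKPOINTS
            if not (model_dir / name).is_file()]
```

An empty list from `missing_checkpoints()` means all three pre-trained models are in place.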
To search architectures with SPViT-DeiT-Ti, run:
```
python -m torch.distributed.launch --nproc_per_node=4 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_ti_l200_t10_search.json
```
To search architectures with SPViT-DeiT-S, run:
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_sm_l30_t32_search.json
```
To search architectures with SPViT-DeiT-B, run:
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_bs_l006_t100_search.json
```
You can start fine-tuning either from your own searched architectures or from our provided architectures by modifying and assigning the MSA indicators in `assigned_indicators` and the FFN indicators in `searching_model`.
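To illustrate where these settings live, a config for fine-tuning might contain a fragment like the one below. Only the two key names come from the text above; the values are purely illustrative placeholders (the actual indicator shapes depend on the model depth and on your search results):

```json
{
  "assigned_indicators": [[1, 0, 1], [1, 1, 0]],
  "searching_model": "PATH_TO_YOUR_SEARCHED_MODEL"
}
```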
To fine-tune the architectures searched by SPViT-DeiT-Ti, run:
```
python -m torch.distributed.launch --nproc_per_node=4 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_ti_l200_t10_ft.json
```
To fine-tune the architectures searched by SPViT-DeiT-S, run:
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_sm_l30_t32_ft.json
```
To fine-tune the architectures searched by SPViT-DeiT-B, run:
```
python -m torch.distributed.launch --nproc_per_node=8 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_bs_l006_t100_ft.json
```
We provide several examples for evaluating pre-trained SPViT models.
To evaluate SPViT-DeiT-Ti pre-trained models, run:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_ti_l200_t10_ft.json --resume [PRE-TRAINED MODEL PATH] --eval
```
To evaluate SPViT-DeiT-S pre-trained models, run:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_sm_l30_t32_ft.json --resume [PRE-TRAINED MODEL PATH] --eval
```
To evaluate SPViT-DeiT-B pre-trained models, run:
```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=3146 --use_env main_pruning.py --config config/spvit_deit_bs_l006_t100_ft.json --resume [PRE-TRAINED MODEL PATH] --eval
```
After fine-tuning, you can optimize your checkpoint to a smaller size with the following command:

```
python post_training_optimize_checkpoint.py YOUR_CHECKPOINT_PATH
```
The optimized checkpoint can be evaluated by replacing `UnifiedAttention` with `UnifiedAttentionParamOpt`; we provide an example in `SPViT_DeiT/config/spvit_deit_bs_l008_t60_ft_param_opt.json`.
- [x] Release code.
- [x] Release pre-trained models.