Training data-efficient image transformers & distillation through attention, arxiv

PaddlePaddle training/validation code and pretrained models for DeiT.

The official PyTorch implementation is here.

This implementation is developed by PaddleViT.

DeiT Model Overview

Update

  • Update (2022-04-02): Add model weights trained from scratch using PaddleViT.
  • Update (2022-03-16): Code is refactored.
  • Update (2021-09-27): More weights are uploaded.
  • Update (2021-08-11): Code is released and ported weights are uploaded.

Models Zoo

| Model | Acc@1 | Acc@5 | #Params | FLOPs | Image Size | Crop_pct | Interpolation | Link |
|---|---|---|---|---|---|---|---|---|
| deit_tiny_distilled_224 | 74.52 | 91.90 | 5.9M | 1.1G | 224 | 0.875 | bicubic | google/baidu |
| deit_small_distilled_224 | 81.17 | 95.41 | 22.4M | 4.3G | 224 | 0.875 | bicubic | google/baidu |
| deit_base_distilled_224 | 83.32 | 96.49 | 87.2M | 17.0G | 224 | 0.875 | bicubic | google/baidu |
| deit_base_distilled_384 | 85.43 | 97.33 | 87.2M | 49.9G | 384 | 1.0 | bicubic | google/baidu |

| Teacher Model | Link |
|---|---|
| RegNet_Y_160 | google/baidu |

*The results are evaluated on ImageNet2012 validation set.
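The Crop_pct column can be read as a ratio between the final crop size and an intermediate resize. The sketch below assumes the common evaluation convention (resize the shorter side to `image_size / crop_pct`, then center-crop to `image_size`); the function name `eval_resize_size` is hypothetical, and the exact preprocessing used by this repository should be checked in its dataset code.

```python
import math


def eval_resize_size(image_size: int, crop_pct: float) -> int:
    """Shorter-side resize target before the center crop (assumed convention)."""
    return int(math.floor(image_size / crop_pct))


# (image size, crop_pct) pairs taken from the table above
for image_size, crop_pct in [(224, 0.875), (384, 1.0)]:
    resize = eval_resize_size(image_size, crop_pct)
    print(f"resize shorter side to {resize}, then center-crop to {image_size}")
```

Under this assumption, the 224 models resize to 256 before cropping, while the 384 model is resized directly to 384 with nothing cropped away.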

Models trained from scratch by PaddleViT

| Model | Acc@1 | Acc@5 | #Params | FLOPs | Image Size | Crop_pct | Interpolation | Link | Log |
|---|---|---|---|---|---|---|---|---|---|
| deit_tiny_distilled_224 | 74.26 | 91.85 | 5.9M | 1.1G | 224 | 0.875 | bicubic | google/baidu | baidu |

Data Preparation

The ImageNet2012 dataset is expected in the following file structure:

│imagenet/
├──train_list.txt
├──val_list.txt
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
  • train_list.txt: list of relative paths and labels of training images. You can download it from: google/baidu
  • val_list.txt: list of relative paths and labels of validation images. You can download it from: google/baidu
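As a rough illustration of how such a list file could be read, the sketch below assumes each line holds a relative image path followed by an integer label, separated by whitespace. That line format, the helper name `parse_list_lines`, and the label values in the example are assumptions; check the downloaded files for the actual layout.

```python
from typing import Iterable, List, Tuple


def parse_list_lines(lines: Iterable[str]) -> List[Tuple[str, int]]:
    """Parse 'relative/path label' lines into (path, label) pairs."""
    samples = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        # Split on the last whitespace run so paths containing spaces survive.
        path, label = line.rsplit(maxsplit=1)
        samples.append((path, int(label)))
    return samples


# Hypothetical example lines; the labels are placeholders.
example = [
    "train/n01440764/n01440764_10026.JPEG 0",
    "train/n01440764/n01440764_10027.JPEG 0",
]
for path, label in parse_list_lines(example):
    print(label, path)
```

To read a real file, pass an open file handle, for example `parse_list_lines(open('imagenet/train_list.txt'))`.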

Usage

To use a model with pretrained weights, download the .pdparams weight file and update the file paths in the following Python script accordingly. The model config files are located in ./configs/.

For example, assuming the weight file has been downloaded to ./deit_base_patch16_224.pdparams, the deit_base_patch16_224 model can be used in Python as follows:

import paddle
from config import get_config
from deit import build_vit as build_model
# config files in ./configs/
config = get_config('./configs/deit_base_patch16_224.yaml')
# build model
model = build_model(config)
# load pretrained weights
model_state_dict = paddle.load('./deit_base_patch16_224.pdparams')
model.set_state_dict(model_state_dict)

Evaluation

To evaluate DeiT model performance on ImageNet2012, run the following script from the command line:

sh run_eval_multi.sh

or

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python main_multi_gpu.py \
-cfg='./configs/deit_tiny_patch16_224.yaml' \
-dataset='imagenet2012' \
-batch_size=256 \
-data_path='/dataset/imagenet' \
-eval \
-pretrained='./deit_tiny_patch16_224.pdparams' \
-amp

Note: if you have only one GPU, set CUDA_VISIBLE_DEVICES=0 to run the evaluation on a single GPU.
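If the evaluation is launched from Python rather than from a shell script, the same command can be assembled programmatically. The sketch below only builds the environment and argument list from the flags shown above; the helper name `build_eval_command` is hypothetical and not part of this repository.

```python
import os
from typing import Dict, List, Sequence, Tuple


def build_eval_command(
    gpu_ids: Sequence[int],
    cfg: str,
    data_path: str,
    pretrained: str,
    batch_size: int = 256,
    amp: bool = True,
) -> Tuple[Dict[str, str], List[str]]:
    """Return (env, argv) mirroring the evaluation command shown above."""
    # Restrict the visible GPUs, e.g. [0] for a single-GPU run.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=",".join(str(g) for g in gpu_ids))
    argv = [
        "python", "main_multi_gpu.py",
        f"-cfg={cfg}",
        "-dataset=imagenet2012",
        f"-batch_size={batch_size}",
        f"-data_path={data_path}",
        "-eval",
        f"-pretrained={pretrained}",
    ]
    if amp:
        argv.append("-amp")
    return env, argv


env, argv = build_eval_command(
    gpu_ids=[0],
    cfg="./configs/deit_tiny_patch16_224.yaml",
    data_path="/dataset/imagenet",
    pretrained="./deit_tiny_patch16_224.pdparams",
)
print("CUDA_VISIBLE_DEVICES=" + env["CUDA_VISIBLE_DEVICES"], " ".join(argv))
```

The returned pair can then be passed to `subprocess.run(argv, env=env, check=True)` from the directory that contains `main_multi_gpu.py`.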

Training

To train the DeiT model on ImageNet2012 with distillation, run the following script from the command line:

sh run_train_multi_distill.sh

or

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python main_multi_gpu_distill.py \
-cfg='./configs/deit_tiny_distilled_patch16_224.yaml' \
-dataset='imagenet2012' \
-batch_size=256 \
-data_path='/dataset/imagenet' \
-amp

Note: it is highly recommended to run the training on multiple GPUs or multi-node GPUs.
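For intuition about what training "with distillation" optimizes, here is a minimal pure-Python sketch of the hard-label distillation objective described in the DeiT paper: the class-token output is trained against the true label, the distillation-token output against the teacher's predicted label, and the two cross-entropy terms are averaged. This is an illustration for a single sample, not the code in `main_multi_gpu_distill.py`; the distillation type and weighting actually used there are set by its config, and all function names below are hypothetical.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy of softmax(logits) against an integer target class."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def hard_distillation_loss(
    cls_logits: Sequence[float],
    dist_logits: Sequence[float],
    teacher_logits: Sequence[float],
    label: int,
) -> float:
    """Average of CE(class token, true label) and CE(distill token, teacher argmax)."""
    teacher_label = max(range(len(teacher_logits)), key=teacher_logits.__getitem__)
    return 0.5 * cross_entropy(cls_logits, label) + 0.5 * cross_entropy(dist_logits, teacher_label)


# Toy 3-class example: the teacher predicts class 1, the true label is 0.
loss = hard_distillation_loss(
    cls_logits=[2.0, 0.5, 0.1],
    dist_logits=[0.3, 1.5, 0.2],
    teacher_logits=[0.1, 3.0, 0.4],
    label=0,
)
print(f"hard distillation loss: {loss:.4f}")
```

In this setup the teacher (RegNet_Y_160 in the table above) only supplies target labels; its weights are not updated.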

Finetuning

To finetune the DeiT model on ImageNet2012, run the following script from the command line:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python main_multi_gpu_distill.py \
-cfg='./configs/deit_base_distilled_patch16_384.yaml' \
-dataset='imagenet2012' \
-batch_size=16 \
-data_path='/dataset/imagenet' \
-pretrained='./deit_base_distilled_patch16_224.pdparams' \
-amp

Note: use the -pretrained argument to set the path of the pretrained model. You may also need to modify the hyperparameters defined in the config file.
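One reason the config matters when finetuning at a higher resolution is that the token sequence length changes with the image size. The arithmetic below assumes a patch size of 16 (as the `patch16` model names suggest) and two extra tokens for the distilled models (class and distillation tokens, as in the DeiT paper); the helper `num_tokens` is hypothetical, and whether the loading code resizes position embeddings automatically is not stated here.

```python
def num_tokens(image_size: int, patch_size: int = 16, extra_tokens: int = 2) -> int:
    """Sequence length: one token per patch plus class and distillation tokens."""
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + extra_tokens


for image_size in (224, 384):
    print(f"{image_size}x{image_size}: {num_tokens(image_size)} tokens")
```

Going from 224 to 384 therefore grows the sequence from 198 to 578 tokens, which also helps explain the much smaller `-batch_size=16` in the finetuning command.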

Reference

@inproceedings{touvron2021training,
  title={Training data-efficient image transformers \& distillation through attention},
  author={Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and J{\'e}gou, Herv{\'e}},
  booktitle={International Conference on Machine Learning},
  pages={10347--10357},
  year={2021},
  organization={PMLR}
}