From df6edd7f5a53cd6fa3a5c979fddf542fa926c671 Mon Sep 17 00:00:00 2001 From: takuoko Date: Thu, 28 Apr 2022 10:35:17 +0900 Subject: [PATCH] [Feature] Support VAN. (#739) * add van * fix config * add metafile * add test * model convert script * fix review * fix lint * fix the configs and improve docs * rm debug lines * add VAN into api Co-authored-by: Yu Zhaohui <1105212286@qq.com> --- .gitignore | 1 + README.md | 1 + README_zh-CN.md | 1 + configs/_base_/models/van/van_base.py | 13 + configs/_base_/models/van/van_large.py | 13 + configs/_base_/models/van/van_small.py | 21 + configs/_base_/models/van/van_tiny.py | 21 + configs/van/README.md | 37 ++ configs/van/metafile.yml | 70 +++ configs/van/van-base_8xb128_in1k.py | 61 +++ configs/van/van-large_8xb128_in1k.py | 61 +++ configs/van/van-small_8xb128_in1k.py | 61 +++ configs/van/van-tiny_8xb128_in1k.py | 61 +++ docs/en/api/models.rst | 1 + docs/en/model_zoo.md | 4 + docs/zh_CN/changelog.md | 2 +- docs/zh_CN/model_zoo.md | 2 +- mmcls/models/backbones/__init__.py | 47 +- mmcls/models/backbones/van.py | 434 +++++++++++++++++++ model-index.yml | 1 + tests/test_models/test_backbones/test_van.py | 188 ++++++++ tools/convert_models/van2mmcls.py | 65 +++ 22 files changed, 1126 insertions(+), 40 deletions(-) create mode 100644 configs/_base_/models/van/van_base.py create mode 100644 configs/_base_/models/van/van_large.py create mode 100644 configs/_base_/models/van/van_small.py create mode 100644 configs/_base_/models/van/van_tiny.py create mode 100644 configs/van/README.md create mode 100644 configs/van/metafile.yml create mode 100644 configs/van/van-base_8xb128_in1k.py create mode 100644 configs/van/van-large_8xb128_in1k.py create mode 100644 configs/van/van-small_8xb128_in1k.py create mode 100644 configs/van/van-tiny_8xb128_in1k.py create mode 100644 mmcls/models/backbones/van.py create mode 100644 tests/test_models/test_backbones/test_van.py create mode 100644 tools/convert_models/van2mmcls.py diff --git a/.gitignore b/.gitignore index ea4657e6458..190d49e743c 100644 --- a/.gitignore +++ b/.gitignore @@ -122,6 +122,7 @@ venv.bak/ *.log.json /work_dirs /mmcls/.mim +.DS_Store # Pytorch *.pth diff --git a/README.md b/README.md index 0e1ca76df0d..a0ff08fd49e 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea - [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet) - [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext) - [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet) +- [x] [VAN](https://github.com/open-mmlab/mmclassification/tree/master/configs/van) - [x] [ConvMixer](https://github.com/open-mmlab/mmclassification/tree/master/configs/convmixer) - [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/cspnet) - [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer) diff --git a/README_zh-CN.md b/README_zh-CN.md index f80d799228c..0354d57861c 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -136,6 +136,7 @@ pip3 install -e . 
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet) - [x] [ConvNeXt](https://github.com/open-mmlab/mmclassification/tree/master/configs/convnext) - [x] [HRNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hrnet) +- [x] [VAN](https://github.com/open-mmlab/mmclassification/tree/master/configs/van) - [x] [ConvMixer](https://github.com/open-mmlab/mmclassification/tree/master/configs/convmixer) - [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/cspnet) - [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer) diff --git a/configs/_base_/models/van/van_base.py b/configs/_base_/models/van/van_base.py new file mode 100644 index 00000000000..006459255f8 --- /dev/null +++ b/configs/_base_/models/van/van_base.py @@ -0,0 +1,13 @@ +# model settings +model = dict( + type='ImageClassifier', + backbone=dict(type='VAN', arch='base', drop_path_rate=0.1), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=512, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + cal_acc=False)) diff --git a/configs/_base_/models/van/van_large.py b/configs/_base_/models/van/van_large.py new file mode 100644 index 00000000000..4ebafabdaaf --- /dev/null +++ b/configs/_base_/models/van/van_large.py @@ -0,0 +1,13 @@ +# model settings +model = dict( + type='ImageClassifier', + backbone=dict(type='VAN', arch='large', drop_path_rate=0.2), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=512, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + cal_acc=False)) diff --git a/configs/_base_/models/van/van_small.py b/configs/_base_/models/van/van_small.py new file mode 100644 index 00000000000..320e90afdc8 --- /dev/null +++ b/configs/_base_/models/van/van_small.py @@ -0,0 +1,21 @@ +# model settings +model = dict( + type='ImageClassifier', + backbone=dict(type='VAN', arch='small', drop_path_rate=0.1), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=512, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. + loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + cal_acc=False), + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict(augments=[ + dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), + dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) + ])) diff --git a/configs/_base_/models/van/van_tiny.py b/configs/_base_/models/van/van_tiny.py new file mode 100644 index 00000000000..42791ac3beb --- /dev/null +++ b/configs/_base_/models/van/van_tiny.py @@ -0,0 +1,21 @@ +# model settings +model = dict( + type='ImageClassifier', + backbone=dict(type='VAN', arch='tiny', drop_path_rate=0.1), + neck=dict(type='GlobalAveragePooling'), + head=dict( + type='LinearClsHead', + num_classes=1000, + in_channels=256, + init_cfg=None, # suppress the default init_cfg of LinearClsHead. 
+ loss=dict( + type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), + cal_acc=False), + init_cfg=[ + dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.), + dict(type='Constant', layer='LayerNorm', val=1., bias=0.) + ], + train_cfg=dict(augments=[ + dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), + dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) + ])) diff --git a/configs/van/README.md b/configs/van/README.md new file mode 100644 index 00000000000..b356621c4b9 --- /dev/null +++ b/configs/van/README.md @@ -0,0 +1,37 @@ +# Visual Attention Network + +> [Visual Attention Network](https://arxiv.org/pdf/2202.09741v2.pdf) + + +## Abstract + +While originally designed for natural language processing (NLP) tasks, the self-attention mechanism has recently taken various computer vision areas by storm. However, the 2D nature of images brings three challenges for applying self-attention in computer vision. (1) Treating images as 1D sequences neglects their 2D structures. (2) The quadratic complexity is too expensive for high-resolution images. (3) It only captures spatial adaptability but ignores channel adaptability. In this paper, we propose a novel large kernel attention (LKA) module to enable self-adaptive and long-range correlations in self-attention while avoiding the above issues. We further introduce a novel neural network based on LKA, namely Visual Attention Network (VAN). While extremely simple and efficient, VAN outperforms the state-of-the-art vision transformers and convolutional neural networks with a large margin in extensive experiments, including image classification, object detection, semantic segmentation, instance segmentation, etc. + +
+ +
+
+
+## Results and models
+
+### ImageNet-1k
+
+| Model | Pretrain | Resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
+|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:------:|:--------:|
+| VAN-T\* | From scratch | 224x224 | 4.11 | 0.88 | 75.41 | 93.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-tiny_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth) |
+| VAN-S\* | From scratch | 224x224 | 13.86 | 2.52 | 81.01 | 95.63 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-small_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-small_8xb128_in1k_20220427-bd6a9edd.pth) |
+| VAN-B\* | From scratch | 224x224 | 26.58 | 5.03 | 82.80 | 96.21 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-base_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-base_8xb128_in1k_20220427-5275471d.pth) |
+| VAN-L\* | From scratch | 224x224 | 44.77 | 8.99 | 83.86 | 96.73 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-large_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-large_8xb128_in1k_20220427-56159105.pth) |
+
+*Models with \* are converted from [the official repo](https://github.com/Visual-Attention-Network/VAN-Classification). The config files of these models are provided for validation only; we have not verified their training accuracy, and reproduction results are welcome.*
+
+## Citation
+
+```
+@article{guo2022visual,
+  title={Visual Attention Network},
+  author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min},
+  journal={arXiv preprint arXiv:2202.09741},
+  year={2022}
+}
+```
diff --git a/configs/van/metafile.yml b/configs/van/metafile.yml
new file mode 100644
index 00000000000..26b40558cbe
--- /dev/null
+++ b/configs/van/metafile.yml
@@ -0,0 +1,70 @@
+Collections:
+  - Name: Visual-Attention-Network
+    Metadata:
+      Training Data: ImageNet-1k
+      Training Techniques:
+        - AdamW
+        - Weight Decay
+      Architecture:
+        - Visual Attention Network
+    Paper:
+      URL: https://arxiv.org/pdf/2202.09741v2.pdf
+      Title: "Visual Attention Network"
+    README: configs/van/README.md
+    Code:
+      URL: https://github.com/open-mmlab/mmclassification/blob/v0.23.0/mmcls/models/backbones/van.py
+      Version: v0.23.0
+
+Models:
+  - Name: van-tiny_8xb128_in1k
+    Metadata:
+      FLOPs: 880000000  # 0.88G
+      Parameters: 4110000  # 4.11M
+    In Collection: Visual-Attention-Network
+    Results:
+      - Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 75.41
+          Top 5 Accuracy: 93.02
+        Task: Image Classification
+    Weights: https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth
+    Config: configs/van/van-tiny_8xb128_in1k.py
+  - Name: van-small_8xb128_in1k
+    Metadata:
+      FLOPs: 2520000000  # 2.52G
+      Parameters: 13860000  # 13.86M
+    In Collection: Visual-Attention-Network
+    Results:
+      - Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 81.01
+          Top 5 Accuracy: 95.63
+        Task: Image Classification
+    Weights: https://download.openmmlab.com/mmclassification/v0/van/van-small_8xb128_in1k_20220427-bd6a9edd.pth
+    Config: configs/van/van-small_8xb128_in1k.py
+  - Name: van-base_8xb128_in1k
+    Metadata:
+      FLOPs: 5030000000  # 5.03G
+      Parameters: 26580000  # 26.58M
+    In Collection: Visual-Attention-Network
+    Results:
+      - Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 82.80
+          Top 5 Accuracy: 96.21
+        Task: Image Classification
+    Weights: https://download.openmmlab.com/mmclassification/v0/van/van-base_8xb128_in1k_20220427-5275471d.pth
+    Config: configs/van/van-base_8xb128_in1k.py
+  - Name: van-large_8xb128_in1k
+    Metadata:
+      FLOPs: 8990000000  # 8.99G
+      Parameters: 44770000  # 44.77M
+    In Collection: Visual-Attention-Network
+    Results:
+      - Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 83.86
+          Top 5 Accuracy: 96.73
+        Task: Image Classification
+    Weights: https://download.openmmlab.com/mmclassification/v0/van/van-large_8xb128_in1k_20220427-56159105.pth
+    Config: configs/van/van-large_8xb128_in1k.py
diff --git a/configs/van/van-base_8xb128_in1k.py b/configs/van/van-base_8xb128_in1k.py
new file mode 100644
index 00000000000..704f111bf51
--- /dev/null
+++ b/configs/van/van-base_8xb128_in1k.py
@@ -0,0 +1,61 @@
+_base_ = [
+    '../_base_/models/van/van_base.py',
+    '../_base_/datasets/imagenet_bs64_swin_224.py',
+    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+    '../_base_/default_runtime.py'
+]
+
+# Note that the mean and variance used here are different from other configs
+img_norm_cfg = dict(
+    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        size=224,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
+    dict(
+        type='RandAugment',
+        policies={{_base_.rand_increasing_policies}},
+        num_policies=2,
+        total_level=10,
+        magnitude_level=9,
+        magnitude_std=0.5,
+        hparams=dict(
+            pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
+            interpolation='bicubic')),
+    dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4),
+    dict(
+        type='RandomErasing',
+        erase_prob=0.25,
+        mode='rand',
+        min_area_ratio=0.02,
+        max_area_ratio=1 / 3,
+        fill_color=img_norm_cfg['mean'][::-1],
+        fill_std=img_norm_cfg['std'][::-1]),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='ToTensor', keys=['gt_label']),
+    dict(type='Collect', keys=['img', 'gt_label'])
+]
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='Resize',
+        size=(248, -1),
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='CenterCrop', crop_size=224),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='ImageToTensor', keys=['img']),
+    dict(type='Collect', keys=['img'])
+]
+
+data = dict(
+    samples_per_gpu=128,
+    train=dict(pipeline=train_pipeline),
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline))
diff --git a/configs/van/van-large_8xb128_in1k.py b/configs/van/van-large_8xb128_in1k.py
new file mode 100644
index 00000000000..b55aff165ef
--- /dev/null
+++ b/configs/van/van-large_8xb128_in1k.py
@@ -0,0 +1,61 @@
+_base_ = [
+    '../_base_/models/van/van_large.py',
+    '../_base_/datasets/imagenet_bs64_swin_224.py',
+    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+    '../_base_/default_runtime.py'
+]
+
+# Note that the mean and variance used here are different from other configs
+img_norm_cfg = dict(
+    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='RandomResizedCrop',
+        size=224,
+        backend='pillow',
+        interpolation='bicubic'),
+    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
+    dict(
+        type='RandAugment',
+
policies={{_base_.rand_increasing_policies}}, + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]], + interpolation='bicubic')), + dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=img_norm_cfg['mean'][::-1], + fill_std=img_norm_cfg['std'][::-1]), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=['gt_label']), + dict(type='Collect', keys=['img', 'gt_label']) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + size=(248, -1), + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=224), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) +] + +data = dict( + samples_per_gpu=128, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/configs/van/van-small_8xb128_in1k.py b/configs/van/van-small_8xb128_in1k.py new file mode 100644 index 00000000000..3b83e25ab8c --- /dev/null +++ b/configs/van/van-small_8xb128_in1k.py @@ -0,0 +1,61 @@ +_base_ = [ + '../_base_/models/van/van_small.py', + '../_base_/datasets/imagenet_bs64_swin_224.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py' +] + +# Note that the mean and variance used here are different from other configs +img_norm_cfg = dict( + mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + size=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies={{_base_.rand_increasing_policies}}, + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]], + interpolation='bicubic')), + dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=img_norm_cfg['mean'][::-1], + fill_std=img_norm_cfg['std'][::-1]), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=['gt_label']), + dict(type='Collect', keys=['img', 'gt_label']) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + size=(248, -1), + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=224), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) +] + +data = dict( + samples_per_gpu=128, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/configs/van/van-tiny_8xb128_in1k.py b/configs/van/van-tiny_8xb128_in1k.py new file mode 100644 index 00000000000..1e001c1c329 --- /dev/null +++ b/configs/van/van-tiny_8xb128_in1k.py @@ -0,0 +1,61 @@ +_base_ = [ + '../_base_/models/van/van_tiny.py', + '../_base_/datasets/imagenet_bs64_swin_224.py', + '../_base_/schedules/imagenet_bs1024_adamw_swin.py', + '../_base_/default_runtime.py' +] + +# Note that the mean and variance used 
here are different from other configs +img_norm_cfg = dict( + mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='RandomResizedCrop', + size=224, + backend='pillow', + interpolation='bicubic'), + dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), + dict( + type='RandAugment', + policies={{_base_.rand_increasing_policies}}, + num_policies=2, + total_level=10, + magnitude_level=9, + magnitude_std=0.5, + hparams=dict( + pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]], + interpolation='bicubic')), + dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), + dict( + type='RandomErasing', + erase_prob=0.25, + mode='rand', + min_area_ratio=0.02, + max_area_ratio=1 / 3, + fill_color=img_norm_cfg['mean'][::-1], + fill_std=img_norm_cfg['std'][::-1]), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='ToTensor', keys=['gt_label']), + dict(type='Collect', keys=['img', 'gt_label']) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='Resize', + size=(248, -1), + backend='pillow', + interpolation='bicubic'), + dict(type='CenterCrop', crop_size=224), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) +] + +data = dict( + samples_per_gpu=128, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst index 7b9530246e4..687b8009340 100644 --- a/docs/en/api/models.rst +++ b/docs/en/api/models.rst @@ -83,6 +83,7 @@ Backbones T2T_ViT TIMMBackbone TNT + VAN VGG VisionTransformer diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md index 707ad094d4b..6451b70add0 100644 --- a/docs/en/model_zoo.md +++ b/docs/en/model_zoo.md @@ -133,6 +133,10 @@ The ResNet family models below are trained by standard data augmentations, i.e., | CSPDarkNet50\* | 27.64 | 5.04 | 80.05 | 95.07 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/cspnet/cspdarknet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/cspnet/cspdarknet50_3rdparty_8xb32_in1k_20220329-bd275287.pth) | | CSPResNet50\* | 21.62 | 3.48 | 79.55 | 94.68 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/cspnet/cspresnet50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/cspnet/cspresnet50_3rdparty_8xb32_in1k_20220329-dd6dddfb.pth) | | CSPResNeXt50\* | 20.57 | 3.11 | 79.96 | 94.96 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/cspnet/cspresnext50_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/cspnet/cspresnext50_3rdparty_8xb32_in1k_20220329-2cc84d21.pth) | +| VAN-T\* | 4.11 | 0.88 | 75.41 | 93.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-tiny_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth) | +| VAN-S\* | 13.86 | 2.52 | 81.01 | 95.63 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-small_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-small_8xb128_in1k_20220427-bd6a9edd.pth) | +| VAN-B\* | 26.58 | 5.03 | 82.80 | 96.21 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-base_8xb128_in1k.py) | 
[model](https://download.openmmlab.com/mmclassification/v0/van/van-base_8xb128_in1k_20220427-5275471d.pth) | +| VAN-L\* | 44.77 | 8.99 | 83.86 | 96.73 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-large_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-large_8xb128_in1k_20220427-56159105.pth) | *Models with \* are converted from other repos, others are trained by ourselves.* diff --git a/docs/zh_CN/changelog.md b/docs/zh_CN/changelog.md index 6b731cd0d50..4444256e27e 120000 --- a/docs/zh_CN/changelog.md +++ b/docs/zh_CN/changelog.md @@ -1 +1 @@ -../en/changelog.md \ No newline at end of file +../en/changelog.md diff --git a/docs/zh_CN/model_zoo.md b/docs/zh_CN/model_zoo.md index 013a9acc839..df35077440d 120000 --- a/docs/zh_CN/model_zoo.md +++ b/docs/zh_CN/model_zoo.md @@ -1 +1 @@ -../en/model_zoo.md \ No newline at end of file +../en/model_zoo.md diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py index 2d72b662d97..13f565f4132 100644 --- a/mmcls/models/backbones/__init__.py +++ b/mmcls/models/backbones/__init__.py @@ -29,46 +29,17 @@ from .timm_backbone import TIMMBackbone from .tnt import TNT from .twins import PCPVT, SVT +from .van import VAN from .vgg import VGG from .vision_transformer import VisionTransformer __all__ = [ - 'LeNet5', - 'AlexNet', - 'VGG', - 'RegNet', - 'ResNet', - 'ResNeXt', - 'ResNetV1d', - 'ResNeSt', - 'ResNet_CIFAR', - 'SEResNet', - 'SEResNeXt', - 'ShuffleNetV1', - 'ShuffleNetV2', - 'MobileNetV2', - 'MobileNetV3', - 'VisionTransformer', - 'SwinTransformer', - 'TNT', - 'TIMMBackbone', - 'T2T_ViT', - 'Res2Net', - 'RepVGG', - 'Conformer', - 'MlpMixer', - 'DistilledVisionTransformer', - 'PCPVT', - 'SVT', - 'EfficientNet', - 'ConvNeXt', - 'HRNet', - 'ResNetV1c', - 'ConvMixer', - 'CSPDarkNet', - 'CSPResNet', - 'CSPResNeXt', - 'CSPNet', - 'RepMLPNet', - 'PoolFormer', + 'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d', + 'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', + 'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer', + 'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG', + 'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT', + 'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c', 'ConvMixer', + 'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet', 'RepMLPNet', + 'PoolFormer', 'VAN' ] diff --git a/mmcls/models/backbones/van.py b/mmcls/models/backbones/van.py new file mode 100644 index 00000000000..4022cc0d20b --- /dev/null +++ b/mmcls/models/backbones/van.py @@ -0,0 +1,434 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmcv.cnn.bricks.transformer import PatchEmbed +from mmcv.runner import BaseModule, ModuleList +from mmcv.utils.parrots_wrapper import _BatchNorm + +from ..builder import BACKBONES +from .base_backbone import BaseBackbone + + +class MixFFN(BaseModule): + """An implementation of MixFFN of VAN. Refer to + mmdetection/mmdet/models/backbones/pvt.py. + + The differences between MixFFN & FFN: + 1. Use 1X1 Conv to replace Linear layer. + 2. Introduce 3X3 Depth-wise Conv to encode positional information. + + Args: + embed_dims (int): The feature dimension. Same as + `MultiheadAttention`. + feedforward_channels (int): The hidden dimension of FFNs. + act_cfg (dict, optional): The activation config for FFNs. 
+            Default: dict(type='GELU').
+        ffn_drop (float, optional): Probability of an element to be
+            zeroed in FFN. Default: 0.0.
+        init_cfg (:obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 feedforward_channels,
+                 act_cfg=dict(type='GELU'),
+                 ffn_drop=0.,
+                 init_cfg=None):
+        super(MixFFN, self).__init__(init_cfg=init_cfg)
+
+        self.embed_dims = embed_dims
+        self.feedforward_channels = feedforward_channels
+        self.act_cfg = act_cfg
+
+        self.fc1 = Conv2d(
+            in_channels=embed_dims,
+            out_channels=feedforward_channels,
+            kernel_size=1)
+        self.dwconv = Conv2d(
+            in_channels=feedforward_channels,
+            out_channels=feedforward_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=True,
+            groups=feedforward_channels)
+        self.act = build_activation_layer(act_cfg)
+        self.fc2 = Conv2d(
+            in_channels=feedforward_channels,
+            out_channels=embed_dims,
+            kernel_size=1)
+        self.drop = nn.Dropout(ffn_drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.dwconv(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class LKA(BaseModule):
+    """Large Kernel Attention (LKA) of VAN.
+
+    .. code:: text
+
+        DW_conv (depth-wise convolution)
+                        |
+                        |
+        DW_D_conv (depth-wise dilation convolution)
+                        |
+                        |
+        Transition Convolution (1×1 convolution)
+
+    Args:
+        embed_dims (int): Number of input channels.
+        init_cfg (:obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self, embed_dims, init_cfg=None):
+        super(LKA, self).__init__(init_cfg=init_cfg)
+
+        # a spatial local convolution (depth-wise convolution)
+        self.DW_conv = Conv2d(
+            in_channels=embed_dims,
+            out_channels=embed_dims,
+            kernel_size=5,
+            padding=2,
+            groups=embed_dims)
+
+        # a spatial long-range convolution (depth-wise dilation convolution)
+        self.DW_D_conv = Conv2d(
+            in_channels=embed_dims,
+            out_channels=embed_dims,
+            kernel_size=7,
+            stride=1,
+            padding=9,
+            groups=embed_dims,
+            dilation=3)
+
+        self.conv1 = Conv2d(
+            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+
+    def forward(self, x):
+        u = x.clone()
+        attn = self.DW_conv(x)
+        attn = self.DW_D_conv(attn)
+        attn = self.conv1(attn)
+
+        return u * attn
+
+
+class SpatialAttention(BaseModule):
+    """Basic attention module in VANBlock.
+
+    Args:
+        embed_dims (int): Number of input channels.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='GELU').
+        init_cfg (:obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
+        super(SpatialAttention, self).__init__(init_cfg=init_cfg)
+
+        self.proj_1 = Conv2d(
+            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+        self.activation = build_activation_layer(act_cfg)
+        self.spatial_gating_unit = LKA(embed_dims)
+        self.proj_2 = Conv2d(
+            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
+
+    def forward(self, x):
+        shortcut = x.clone()
+        x = self.proj_1(x)
+        x = self.activation(x)
+        x = self.spatial_gating_unit(x)
+        x = self.proj_2(x)
+        x = x + shortcut
+        return x
+
+
+class VANBlock(BaseModule):
+    """A block of VAN.
+
+    Args:
+        embed_dims (int): Number of input channels.
+        ffn_ratio (float): The expansion ratio of feedforward network hidden
+            layer channels. Defaults to 4.
+        drop_rate (float): Dropout rate after embedding. Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        act_cfg (dict, optional): The activation config for FFNs.
+            Default: dict(type='GELU').
+        norm_cfg (dict): Config dict for the normalization layers in the
+            block. Default: dict(type='BN', eps=1e-5).
+        layer_scale_init_value (float): Init value for Layer Scale.
+            Defaults to 1e-2.
+        init_cfg (:obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 ffn_ratio=4.,
+                 drop_rate=0.,
+                 drop_path_rate=0.,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='BN', eps=1e-5),
+                 layer_scale_init_value=1e-2,
+                 init_cfg=None):
+        super(VANBlock, self).__init__(init_cfg=init_cfg)
+        self.out_channels = embed_dims
+
+        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
+        self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg)
+        self.drop_path = DropPath(
+            drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
+        mlp_hidden_dim = int(embed_dims * ffn_ratio)
+        self.mlp = MixFFN(
+            embed_dims=embed_dims,
+            feedforward_channels=mlp_hidden_dim,
+            act_cfg=act_cfg,
+            ffn_drop=drop_rate)
+        self.layer_scale_1 = nn.Parameter(
+            layer_scale_init_value * torch.ones((embed_dims)),
+            requires_grad=True) if layer_scale_init_value > 0 else None
+        self.layer_scale_2 = nn.Parameter(
+            layer_scale_init_value * torch.ones((embed_dims)),
+            requires_grad=True) if layer_scale_init_value > 0 else None
+
+    def forward(self, x):
+        identity = x
+        x = self.norm1(x)
+        x = self.attn(x)
+        if self.layer_scale_1 is not None:
+            x = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x
+        x = identity + self.drop_path(x)
+
+        identity = x
+        x = self.norm2(x)
+        x = self.mlp(x)
+        if self.layer_scale_2 is not None:
+            x = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * x
+        x = identity + self.drop_path(x)
+
+        return x
+
+
+class VANPatchEmbed(PatchEmbed):
+    """Image to Patch Embedding of VAN.
+
+    The differences between VANPatchEmbed & PatchEmbed:
+    1. Use BN.
+    2. Do not use 'flatten' and 'transpose'.
+    """
+
+    def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs):
+        super(VANPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs)
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.
+        Returns:
+            tuple: Contains the projected results and their spatial shape.
+            - x (Tensor): Has shape (B, embed_dims, out_h, out_w). Unlike
+                PatchEmbed, the output is not flattened.
+            - out_size (tuple[int]): Spatial shape of x, arranged as
+                (out_h, out_w).
+        """
+
+        if self.adaptive_padding:
+            x = self.adaptive_padding(x)
+
+        x = self.projection(x)
+        out_size = (x.shape[2], x.shape[3])
+        if self.norm is not None:
+            x = self.norm(x)
+        return x, out_size
+
+
+@BACKBONES.register_module()
+class VAN(BaseBackbone):
+    """Visual Attention Network.
+
+    A PyTorch implementation of `Visual Attention Network
+    <https://arxiv.org/pdf/2202.09741v2.pdf>`_, inspired by the official
+    implementation:
+    https://github.com/Visual-Attention-Network/VAN-Classification
+
+    Args:
+        arch (str | dict): Visual Attention Network architecture.
+            If a string, choose from 'tiny', 'small', 'base' and 'large'.
+            If a dict, it should have the keys below:
+
+            - **embed_dims** (List[int]): The dimensions of embedding.
+            - **depths** (List[int]): The number of blocks in each stage.
+            - **ffn_ratios** (List[int]): The expansion ratios of the
+              feedforward network hidden layer channels in each stage.
+
+            Defaults to 'tiny'.
+        patch_sizes (List[int | tuple]): The patch size in patch embeddings.
+            Defaults to [7, 3, 3, 3].
+        in_channels (int): The num of input channels. Defaults to 3.
+        drop_rate (float): Dropout rate after embedding. Defaults to 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: ``(3, )``.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters. Defaults to -1.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Defaults to False.
+        norm_cfg (dict): Config dict for normalization layer for all output
+            features. Defaults to ``dict(type='LN')``.
+        block_cfgs (Sequence[dict] | dict): The extra config of each block.
+            Defaults to empty dicts.
+        init_cfg (dict, optional): The Config for initialization.
+            Defaults to None.
+
+    Examples:
+        >>> from mmcls.models import VAN
+        >>> import torch
+        >>> cfg = dict(arch='tiny')
+        >>> model = VAN(**cfg)
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> outputs = model(inputs)
+        >>> for out in outputs:
+        >>>     print(out.size())
+        (1, 256, 7, 7)
+    """
+    arch_zoo = {
+        **dict.fromkeys(['t', 'tiny'],
+                        {'embed_dims': [32, 64, 160, 256],
+                         'depths': [3, 3, 5, 2],
+                         'ffn_ratios': [8, 8, 4, 4]}),
+        **dict.fromkeys(['s', 'small'],
+                        {'embed_dims': [64, 128, 320, 512],
+                         'depths': [2, 2, 4, 2],
+                         'ffn_ratios': [8, 8, 4, 4]}),
+        **dict.fromkeys(['b', 'base'],
+                        {'embed_dims': [64, 128, 320, 512],
+                         'depths': [3, 3, 12, 3],
+                         'ffn_ratios': [8, 8, 4, 4]}),
+        **dict.fromkeys(['l', 'large'],
+                        {'embed_dims': [64, 128, 320, 512],
+                         'depths': [3, 5, 27, 3],
+                         'ffn_ratios': [8, 8, 4, 4]}),
+    }  # yapf: disable
+
+    def __init__(self,
+                 arch='tiny',
+                 patch_sizes=[7, 3, 3, 3],
+                 in_channels=3,
+                 drop_rate=0.,
+                 drop_path_rate=0.,
+                 out_indices=(3, ),
+                 frozen_stages=-1,
+                 norm_eval=False,
+                 norm_cfg=dict(type='LN'),
+                 block_cfgs=dict(),
+                 init_cfg=None):
+        super(VAN, self).__init__(init_cfg=init_cfg)
+
+        if isinstance(arch, str):
+            arch = arch.lower()
+            assert arch in set(self.arch_zoo), \
+                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
+            self.arch_settings = self.arch_zoo[arch]
+        else:
+            essential_keys = {'embed_dims', 'depths', 'ffn_ratios'}
+            assert isinstance(arch, dict) and set(arch) == essential_keys, \
+                f'Custom arch needs a dict with keys {essential_keys}'
+            self.arch_settings = arch
+
+        self.embed_dims = self.arch_settings['embed_dims']
+        self.depths = self.arch_settings['depths']
+        self.ffn_ratios = self.arch_settings['ffn_ratios']
+        self.num_stages = len(self.depths)
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.norm_eval = norm_eval
+
+        total_depth = sum(self.depths)
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
+        ]  # stochastic depth decay rule
+
+        cur_block_idx = 0
+        for i, depth in enumerate(self.depths):
+            patch_embed = VANPatchEmbed(
+                in_channels=in_channels if i == 0 else self.embed_dims[i - 1],
+                input_size=None,
+                embed_dims=self.embed_dims[i],
+                kernel_size=patch_sizes[i],
+                stride=patch_sizes[i] // 2 + 1,
+                padding=(patch_sizes[i] // 2, patch_sizes[i] // 2),
+                norm_cfg=dict(type='BN'))
+
+            blocks = ModuleList([
+                VANBlock(
+                    embed_dims=self.embed_dims[i],
+                    ffn_ratio=self.ffn_ratios[i],
+                    drop_rate=drop_rate,
+                    drop_path_rate=dpr[cur_block_idx + j],
+                    **block_cfgs) for j in range(depth)
+            ])
+            cur_block_idx += depth
+            norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1]
+
+            self.add_module(f'patch_embed{i + 1}', patch_embed)
+            self.add_module(f'blocks{i + 1}', blocks)
+            self.add_module(f'norm{i + 1}', norm)
+
+    def train(self, mode=True):
+        super(VAN, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval() has an effect on BatchNorm only
+                if
isinstance(m, _BatchNorm): + m.eval() + + def _freeze_stages(self): + for i in range(0, self.frozen_stages + 1): + # freeze patch embed + m = getattr(self, f'patch_embed{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze blocks + m = getattr(self, f'blocks{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + # freeze norm + m = getattr(self, f'norm{i + 1}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + outs = [] + for i in range(self.num_stages): + patch_embed = getattr(self, f'patch_embed{i + 1}') + blocks = getattr(self, f'blocks{i + 1}') + norm = getattr(self, f'norm{i + 1}') + x, hw_shape = patch_embed(x) + for block in blocks: + x = block(x) + x = x.flatten(2).transpose(1, 2) + x = norm(x) + x = x.reshape(-1, *hw_shape, + block.out_channels).permute(0, 3, 1, 2).contiguous() + if i in self.out_indices: + outs.append(x) + + return tuple(outs) diff --git a/model-index.yml b/model-index.yml index 5d5c767a4d9..81932fd6ac5 100644 --- a/model-index.yml +++ b/model-index.yml @@ -22,6 +22,7 @@ Import: - configs/hrnet/metafile.yml - configs/repmlp/metafile.yml - configs/wrn/metafile.yml + - configs/van/metafile.yml - configs/cspnet/metafile.yml - configs/convmixer/metafile.yml - configs/poolformer/metafile.yml diff --git a/tests/test_models/test_backbones/test_van.py b/tests/test_models/test_backbones/test_van.py new file mode 100644 index 00000000000..136ce973737 --- /dev/null +++ b/tests/test_models/test_backbones/test_van.py @@ -0,0 +1,188 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from copy import deepcopy +from itertools import chain +from unittest import TestCase + +import torch +from mmcv.utils.parrots_wrapper import _BatchNorm +from torch import nn + +from mmcls.models.backbones import VAN + + +def check_norm_state(modules, train_state): + """Check if norm layer is in correct train state.""" + for mod in modules: + if isinstance(mod, _BatchNorm): + if mod.training != train_state: + return False + return True + + +class TestVAN(TestCase): + + def setUp(self): + self.cfg = dict(arch='t', drop_path_rate=0.1) + + def test_arch(self): + # Test invalid default arch + with self.assertRaisesRegex(AssertionError, 'not in default archs'): + cfg = deepcopy(self.cfg) + cfg['arch'] = 'unknown' + VAN(**cfg) + + # Test invalid custom arch + with self.assertRaisesRegex(AssertionError, 'Custom arch needs'): + cfg = deepcopy(self.cfg) + cfg['arch'] = { + 'embed_dims': [32, 64, 160, 256], + 'ffn_ratios': [8, 8, 4, 4], + } + VAN(**cfg) + + # Test custom arch + cfg = deepcopy(self.cfg) + embed_dims = [32, 64, 160, 256] + depths = [3, 3, 5, 2] + ffn_ratios = [8, 8, 4, 4] + cfg['arch'] = { + 'embed_dims': embed_dims, + 'depths': depths, + 'ffn_ratios': ffn_ratios + } + model = VAN(**cfg) + + for i in range(len(depths)): + stage = getattr(model, f'blocks{i + 1}') + self.assertEqual(stage[-1].out_channels, embed_dims[i]) + self.assertEqual(len(stage), depths[i]) + + def test_init_weights(self): + # test weight init cfg + cfg = deepcopy(self.cfg) + cfg['init_cfg'] = [ + dict( + type='Kaiming', + layer='Conv2d', + mode='fan_in', + nonlinearity='linear') + ] + model = VAN(**cfg) + ori_weight = model.patch_embed1.projection.weight.clone().detach() + + model.init_weights() + initialized_weight = model.patch_embed1.projection.weight + self.assertFalse(torch.allclose(ori_weight, initialized_weight)) + + def test_forward(self): + imgs = torch.randn(3, 3, 224, 224) + + 
cfg = deepcopy(self.cfg)
+        model = VAN(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        feat = outs[-1]
+        self.assertEqual(feat.shape, (3, 256, 7, 7))
+
+        # test with patch_sizes
+        cfg = deepcopy(self.cfg)
+        cfg['patch_sizes'] = [7, 5, 5, 5]
+        model = VAN(**cfg)
+        outs = model(torch.randn(3, 3, 224, 224))
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 1)
+        feat = outs[-1]
+        self.assertEqual(feat.shape, (3, 256, 3, 3))
+
+        # test multiple output indices
+        cfg = deepcopy(self.cfg)
+        cfg['out_indices'] = (0, 1, 2, 3)
+        model = VAN(**cfg)
+        outs = model(imgs)
+        self.assertIsInstance(outs, tuple)
+        self.assertEqual(len(outs), 4)
+        for emb_size, stride, out in zip([32, 64, 160, 256], [1, 2, 4, 8],
+                                         outs):
+            self.assertEqual(out.shape,
+                             (3, emb_size, 56 // stride, 56 // stride))
+
+        # test with dynamic input shape
+        imgs1 = torch.randn(3, 3, 224, 224)
+        imgs2 = torch.randn(3, 3, 256, 256)
+        imgs3 = torch.randn(3, 3, 256, 309)
+        cfg = deepcopy(self.cfg)
+        model = VAN(**cfg)
+        for imgs in [imgs1, imgs2, imgs3]:
+            outs = model(imgs)
+            self.assertIsInstance(outs, tuple)
+            self.assertEqual(len(outs), 1)
+            feat = outs[-1]
+            expect_feat_shape = (math.ceil(imgs.shape[2] / 32),
+                                 math.ceil(imgs.shape[3] / 32))
+            self.assertEqual(feat.shape, (3, 256, *expect_feat_shape))
+
+    def test_structure(self):
+        # test drop_path_rate decay
+        cfg = deepcopy(self.cfg)
+        cfg['drop_path_rate'] = 0.2
+        model = VAN(**cfg)
+        depths = model.arch_settings['depths']
+        stages = [model.blocks1, model.blocks2, model.blocks3, model.blocks4]
+        blocks = chain(*stages)
+        total_depth = sum(depths)
+        dpr = [
+            x.item()
+            for x in torch.linspace(0, cfg['drop_path_rate'], total_depth)
+        ]
+        for block, expect_prob in zip(blocks, dpr):
+            if expect_prob == 0:
+                self.assertIsInstance(block.drop_path, nn.Identity)
+            else:
+                self.assertAlmostEqual(block.drop_path.drop_prob, expect_prob)
+
+        # test VAN with norm_eval=True
+        cfg = deepcopy(self.cfg)
+        cfg['norm_eval'] = True
+        cfg['norm_cfg'] = dict(type='BN')
+        model = VAN(**cfg)
+        model.init_weights()
+        model.train()
+        self.assertTrue(check_norm_state(model.modules(), False))
+
+        # test VAN with first stage frozen.
+        cfg = deepcopy(self.cfg)
+        frozen_stages = 0
+        cfg['frozen_stages'] = frozen_stages
+        cfg['out_indices'] = (0, 1, 2, 3)
+        model = VAN(**cfg)
+        model.init_weights()
+        model.train()
+
+        # the patch_embed and first stage should not require grad.
+        self.assertFalse(model.patch_embed1.training)
+        for param in model.patch_embed1.parameters():
+            self.assertFalse(param.requires_grad)
+        for i in range(frozen_stages + 1):
+            patch = getattr(model, f'patch_embed{i + 1}')
+            for param in patch.parameters():
+                self.assertFalse(param.requires_grad)
+            blocks = getattr(model, f'blocks{i + 1}')
+            for param in blocks.parameters():
+                self.assertFalse(param.requires_grad)
+            norm = getattr(model, f'norm{i + 1}')
+            for param in norm.parameters():
+                self.assertFalse(param.requires_grad)
+
+        # the stages after the frozen ones should require grad.
+ for i in range(frozen_stages + 1, 4): + patch = getattr(model, f'patch_embed{i + 1}') + for param in patch.parameters(): + self.assertTrue(param.requires_grad) + blocks = getattr(model, f'blocks{i+1}') + for param in blocks.parameters(): + self.assertTrue(param.requires_grad) + norm = getattr(model, f'norm{i + 1}') + for param in norm.parameters(): + self.assertTrue(param.requires_grad) diff --git a/tools/convert_models/van2mmcls.py b/tools/convert_models/van2mmcls.py new file mode 100644 index 00000000000..5ea7d9ca75d --- /dev/null +++ b/tools/convert_models/van2mmcls.py @@ -0,0 +1,65 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os.path as osp +from collections import OrderedDict + +import mmcv +import torch +from mmcv.runner import CheckpointLoader + + +def convert_van(ckpt): + + new_ckpt = OrderedDict() + + for k, v in list(ckpt.items()): + new_v = v + if k.startswith('head'): + new_k = k.replace('head.', 'head.fc.') + new_ckpt[new_k] = new_v + continue + elif k.startswith('patch_embed'): + if 'proj.' in k: + new_k = k.replace('proj.', 'projection.') + else: + new_k = k + elif k.startswith('block'): + new_k = k.replace('block', 'blocks') + if 'attn.spatial_gating_unit' in new_k: + new_k = new_k.replace('conv0', 'DW_conv') + new_k = new_k.replace('conv_spatial', 'DW_D_conv') + if 'dwconv.dwconv' in new_k: + new_k = new_k.replace('dwconv.dwconv', 'dwconv') + else: + new_k = k + + if not new_k.startswith('head'): + new_k = 'backbone.' + new_k + new_ckpt[new_k] = new_v + return new_ckpt + + +def main(): + parser = argparse.ArgumentParser( + description='Convert keys in pretrained van models to mmcls style.') + parser.add_argument('src', help='src model path or url') + # The dst path must be a full path of the new checkpoint. + parser.add_argument('dst', help='save path') + args = parser.parse_args() + + checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') + + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + weight = convert_van(state_dict) + mmcv.mkdir_or_exist(osp.dirname(args.dst)) + torch.save(weight, args.dst) + + print('Done!!') + + +if __name__ == '__main__': + main()
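
A quick way to sanity-check the new backbone after applying this patch is to mirror the docstring example and `test_van.py` (a minimal sketch; the expected shapes follow from the tiny arch's `embed_dims` of `[32, 64, 160, 256]` and the cumulative stage strides of 4, 8, 16 and 32):

```python
import torch

from mmcls.models import VAN

# Build the tiny variant and request the features of all four stages.
model = VAN(arch='tiny', out_indices=(0, 1, 2, 3))
model.eval()

with torch.no_grad():
    outs = model(torch.rand(1, 3, 224, 224))

# A 224x224 input downsampled by strides 4/8/16/32 gives:
# (1, 32, 56, 56), (1, 64, 28, 28), (1, 160, 14, 14), (1, 256, 7, 7)
for out in outs:
    print(out.shape)
```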
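`arch` also accepts a dict, which `VAN.__init__` validates against the keys `embed_dims`, `depths` and `ffn_ratios`; the custom-arch test above exercises exactly this path. A sketch with a hypothetical stage layout (not one of the released configs):

```python
from mmcls.models import VAN

# A hypothetical variant for illustration only; all three keys are
# required, and each list needs one entry per stage.
model = VAN(
    arch=dict(
        embed_dims=[32, 64, 160, 256],
        depths=[2, 2, 4, 2],
        ffn_ratios=[8, 8, 4, 4]),
    drop_path_rate=0.1)
```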
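Loading a checkpoint converted by `tools/convert_models/van2mmcls.py` back into the bare backbone might look like the following (a sketch, not part of this patch; the file paths are placeholders, and the key filtering relies on the `backbone.`/`head.fc.` prefixes the converter writes):

```python
# First convert an official checkpoint (placeholder paths):
#   python tools/convert_models/van2mmcls.py official_van_tiny.pth van-tiny_mmcls.pth
import torch

from mmcls.models import VAN

state_dict = torch.load('van-tiny_mmcls.pth', map_location='cpu')
# Keep only the backbone weights; the classifier lives under 'head.fc.*'.
backbone_state = {
    k[len('backbone.'):]: v
    for k, v in state_dict.items() if k.startswith('backbone.')
}
model = VAN(arch='tiny')
missing, unexpected = model.load_state_dict(backbone_state, strict=False)
print('missing:', missing, 'unexpected:', unexpected)
```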