forked from open-mmlab/mmpretrain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support VAN. (open-mmlab#739)
* add van * fix config * add metafile * add test * model convert script * fix review * fix lint * fix the configs and improve docs * rm debug lines * add VAN into api Co-authored-by: Yu Zhaohui <[email protected]>
- Loading branch information
Showing
22 changed files
with
1,126 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -122,6 +122,7 @@ venv.bak/ | |
*.log.json | ||
/work_dirs | ||
/mmcls/.mim | ||
.DS_Store | ||
|
||
# Pytorch | ||
*.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# model settings | ||
model = dict( | ||
type='ImageClassifier', | ||
backbone=dict(type='VAN', arch='base', drop_path_rate=0.1), | ||
neck=dict(type='GlobalAveragePooling'), | ||
head=dict( | ||
type='LinearClsHead', | ||
num_classes=1000, | ||
in_channels=512, | ||
init_cfg=None, # suppress the default init_cfg of LinearClsHead. | ||
loss=dict( | ||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), | ||
cal_acc=False)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# model settings | ||
model = dict( | ||
type='ImageClassifier', | ||
backbone=dict(type='VAN', arch='large', drop_path_rate=0.2), | ||
neck=dict(type='GlobalAveragePooling'), | ||
head=dict( | ||
type='LinearClsHead', | ||
num_classes=1000, | ||
in_channels=512, | ||
init_cfg=None, # suppress the default init_cfg of LinearClsHead. | ||
loss=dict( | ||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), | ||
cal_acc=False)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# model settings | ||
model = dict( | ||
type='ImageClassifier', | ||
backbone=dict(type='VAN', arch='small', drop_path_rate=0.1), | ||
neck=dict(type='GlobalAveragePooling'), | ||
head=dict( | ||
type='LinearClsHead', | ||
num_classes=1000, | ||
in_channels=512, | ||
init_cfg=None, # suppress the default init_cfg of LinearClsHead. | ||
loss=dict( | ||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), | ||
cal_acc=False), | ||
init_cfg=[ | ||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.), | ||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.) | ||
], | ||
train_cfg=dict(augments=[ | ||
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), | ||
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) | ||
])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# model settings | ||
model = dict( | ||
type='ImageClassifier', | ||
backbone=dict(type='VAN', arch='tiny', drop_path_rate=0.1), | ||
neck=dict(type='GlobalAveragePooling'), | ||
head=dict( | ||
type='LinearClsHead', | ||
num_classes=1000, | ||
in_channels=256, | ||
init_cfg=None, # suppress the default init_cfg of LinearClsHead. | ||
loss=dict( | ||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), | ||
cal_acc=False), | ||
init_cfg=[ | ||
dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.), | ||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.) | ||
], | ||
train_cfg=dict(augments=[ | ||
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5), | ||
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5) | ||
])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Visual Attention Network | ||
|
||
> [Visual Attention Network](https://arxiv.org/pdf/2202.09741v2.pdf) | ||
<!-- [ALGORITHM] --> | ||
## Abstract | ||
|
||
While originally designed for natural language processing (NLP) tasks, the self-attention mechanism has recently taken various computer vision areas by storm. However, the 2D nature of images brings three challenges for applying self-attention in computer vision. (1) Treating images as 1D sequences neglects their 2D structures. (2) The quadratic complexity is too expensive for high-resolution images. (3) It only captures spatial adaptability but ignores channel adaptability. In this paper, we propose a novel large kernel attention (LKA) module to enable self-adaptive and long-range correlations in self-attention while avoiding the above issues. We further introduce a novel neural network based on LKA, namely Visual Attention Network (VAN). While extremely simple and efficient, VAN outperforms the state-of-the-art vision transformers and convolutional neural networks with a large margin in extensive experiments, including image classification, object detection, semantic segmentation, instance segmentation, etc. | ||
|
||
<div align=center> | ||
<img src="https://user-images.githubusercontent.com/24734142/157409411-2f622ba7-553c-4702-91be-eba03f9ea04f.png" width="80%"/> | ||
</div> | ||
|
||
|
||
## Results and models | ||
|
||
### ImageNet-1k | ||
|
||
| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | | ||
|:---------:|:------------:|:-----------:|:---------:|:---------:|:---------:|:---------:|:------:|:--------:| | ||
| VAN-T\* | From scratch | 224x224 | 4.11 | 0.88 | 75.41 | 93.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-tiny_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth) | | ||
| VAN-S\* | From scratch | 224x224 | 13.86 | 2.52 | 81.01 | 95.63 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-small_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-small_8xb128_in1k_20220427-bd6a9edd.pth) | | ||
| VAN-B\* | From scratch | 224x224 | 26.58 | 5.03 | 82.80 | 96.21 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-base_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-base_8xb128_in1k_20220427-5275471d.pth) | | ||
| VAN-L\* | From scratch | 224x224 | 44.77 | 8.99 | 83.86 | 96.73 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/van/van-large_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/van/van-large_8xb128_in1k_20220427-56159105.pth) | | ||
|
||
*Models with \* are converted from [the official repo](https://github.com/Visual-Attention-Network/VAN-Classification). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results. | ||
|
||
## Citation | ||
|
||
``` | ||
@article{guo2022visual, | ||
title={Visual Attention Network}, | ||
author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min}, | ||
journal={arXiv preprint arXiv:2202.09741}, | ||
year={2022} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
Collections: | ||
- Name: Visual-Attention-Network | ||
Metadata: | ||
Training Data: ImageNet-1k | ||
Training Techniques: | ||
- AdamW | ||
- Weight Decay | ||
Architecture: | ||
- Visual Attention Network | ||
Paper: | ||
URL: https://arxiv.org/pdf/2202.09741v2.pdf | ||
Title: "Visual Attention Network" | ||
README: configs/van/README.md | ||
Code: | ||
URL: https://github.com/open-mmlab/mmclassification/blob/v0.23.0/mmcls/models/backbones/van.py | ||
Version: v0.23.0 | ||
|
||
Models: | ||
- Name: van-tiny_8xb128_in1k | ||
Metadata: | ||
FLOPs: 4110000 # 4.11M | ||
Parameters: 880000000 # 0.88G | ||
In Collection: Visual-Attention-Network | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 75.41 | ||
Top 5 Accuracy: 93.02 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/van/van-tiny_8xb128_in1k_20220427-8ac0feec.pth | ||
Config: configs/van/van-tiny_8xb128_in1k.py | ||
- Name: van-small_8xb128_in1k | ||
Metadata: | ||
FLOPs: 13860000 # 13.86M | ||
Parameters: 2520000000 # 2.52G | ||
In Collection: Visual-Attention-Network | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 81.01 | ||
Top 5 Accuracy: 95.63 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/van/van-small_8xb128_in1k_20220427-bd6a9edd.pth | ||
Config: configs/van/van-small_8xb128_in1k.py | ||
- Name: van-base_8xb128_in1k | ||
Metadata: | ||
FLOPs: 26580000 # 26.58M | ||
Parameters: 5030000000 # 5.03G | ||
In Collection: Visual-Attention-Network | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 82.80 | ||
Top 5 Accuracy: 96.21 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/van/van-base_8xb128_in1k_20220427-5275471d.pth | ||
Config: configs/van/van-base_8xb128_in1k.py | ||
- Name: van-large_8xb128_in1k | ||
Metadata: | ||
FLOPs: 44770000 # 44.77 M | ||
Parameters: 8990000000 # 8.99G | ||
In Collection: Visual-Attention-Network | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 83.86 | ||
Top 5 Accuracy: 96.73 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/van/van-large_8xb128_in1k_20220427-56159105.pth | ||
Config: configs/van/van-large_8xb128_in1k.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
_base_ = [ | ||
'../_base_/models/van/van_base.py', | ||
'../_base_/datasets/imagenet_bs64_swin_224.py', | ||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py', | ||
'../_base_/default_runtime.py' | ||
] | ||
|
||
# Note that the mean and variance used here are different from other configs | ||
img_norm_cfg = dict( | ||
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='RandomResizedCrop', | ||
size=224, | ||
backend='pillow', | ||
interpolation='bicubic'), | ||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), | ||
dict( | ||
type='RandAugment', | ||
policies={{_base_.rand_increasing_policies}}, | ||
num_policies=2, | ||
total_level=10, | ||
magnitude_level=9, | ||
magnitude_std=0.5, | ||
hparams=dict( | ||
pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]], | ||
interpolation='bicubic')), | ||
dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), | ||
dict( | ||
type='RandomErasing', | ||
erase_prob=0.25, | ||
mode='rand', | ||
min_area_ratio=0.02, | ||
max_area_ratio=1 / 3, | ||
fill_color=img_norm_cfg['mean'][::-1], | ||
fill_std=img_norm_cfg['std'][::-1]), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='ToTensor', keys=['gt_label']), | ||
dict(type='Collect', keys=['img', 'gt_label']) | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='Resize', | ||
size=(248, -1), | ||
backend='pillow', | ||
interpolation='bicubic'), | ||
dict(type='CenterCrop', crop_size=224), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']) | ||
] | ||
|
||
data = dict( | ||
samples_per_gpu=128, | ||
train=dict(pipeline=train_pipeline), | ||
val=dict(pipeline=test_pipeline), | ||
test=dict(pipeline=test_pipeline)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
_base_ = [ | ||
'../_base_/models/van/van_large.py', | ||
'../_base_/datasets/imagenet_bs64_swin_224.py', | ||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py', | ||
'../_base_/default_runtime.py' | ||
] | ||
|
||
# Note that the mean and variance used here are different from other configs | ||
img_norm_cfg = dict( | ||
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='RandomResizedCrop', | ||
size=224, | ||
backend='pillow', | ||
interpolation='bicubic'), | ||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), | ||
dict( | ||
type='RandAugment', | ||
policies={{_base_.rand_increasing_policies}}, | ||
num_policies=2, | ||
total_level=10, | ||
magnitude_level=9, | ||
magnitude_std=0.5, | ||
hparams=dict( | ||
pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]], | ||
interpolation='bicubic')), | ||
dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), | ||
dict( | ||
type='RandomErasing', | ||
erase_prob=0.25, | ||
mode='rand', | ||
min_area_ratio=0.02, | ||
max_area_ratio=1 / 3, | ||
fill_color=img_norm_cfg['mean'][::-1], | ||
fill_std=img_norm_cfg['std'][::-1]), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='ToTensor', keys=['gt_label']), | ||
dict(type='Collect', keys=['img', 'gt_label']) | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='Resize', | ||
size=(248, -1), | ||
backend='pillow', | ||
interpolation='bicubic'), | ||
dict(type='CenterCrop', crop_size=224), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']) | ||
] | ||
|
||
data = dict( | ||
samples_per_gpu=128, | ||
train=dict(pipeline=train_pipeline), | ||
val=dict(pipeline=test_pipeline), | ||
test=dict(pipeline=test_pipeline)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
_base_ = [ | ||
'../_base_/models/van/van_small.py', | ||
'../_base_/datasets/imagenet_bs64_swin_224.py', | ||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py', | ||
'../_base_/default_runtime.py' | ||
] | ||
|
||
# Note that the mean and variance used here are different from other configs | ||
img_norm_cfg = dict( | ||
mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='RandomResizedCrop', | ||
size=224, | ||
backend='pillow', | ||
interpolation='bicubic'), | ||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), | ||
dict( | ||
type='RandAugment', | ||
policies={{_base_.rand_increasing_policies}}, | ||
num_policies=2, | ||
total_level=10, | ||
magnitude_level=9, | ||
magnitude_std=0.5, | ||
hparams=dict( | ||
pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]], | ||
interpolation='bicubic')), | ||
dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), | ||
dict( | ||
type='RandomErasing', | ||
erase_prob=0.25, | ||
mode='rand', | ||
min_area_ratio=0.02, | ||
max_area_ratio=1 / 3, | ||
fill_color=img_norm_cfg['mean'][::-1], | ||
fill_std=img_norm_cfg['std'][::-1]), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='ToTensor', keys=['gt_label']), | ||
dict(type='Collect', keys=['img', 'gt_label']) | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='Resize', | ||
size=(248, -1), | ||
backend='pillow', | ||
interpolation='bicubic'), | ||
dict(type='CenterCrop', crop_size=224), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']) | ||
] | ||
|
||
data = dict( | ||
samples_per_gpu=128, | ||
train=dict(pipeline=train_pipeline), | ||
val=dict(pipeline=test_pipeline), | ||
test=dict(pipeline=test_pipeline)) |
Oops, something went wrong.