Commit: init

MendelXu committed Sep 9, 2021
1 parent 7fa0dbd commit ba4b1ab
Showing 61 changed files with 5,458 additions and 10 deletions.
38 changes: 38 additions & 0 deletions .gitignore
@@ -0,0 +1,38 @@
#vs code
.history/
.vscode
.idea
.history
.DS_Store
#python
__pycache__/
*/__pycache__
*.egg-info
build
#lib
tests
thirdparty
thirdparty/

#develop
wandb
data
data/
*.pkl
*.pkl.json
*.log.json
work_dirs/
figures
cp.py

# Pytorch
*.pth
*.py~
*.sh~
launch.py

#nvidia
*.qdrep
*.sqlite

.pytest*
2 changes: 2 additions & 0 deletions .isort.cfg
@@ -0,0 +1,2 @@
[settings]
known_third_party = PIL,cv2,mmcv,mmdet,numpy,prettytable,setuptools,torch
31 changes: 31 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,31 @@
repos:
  - repo: https://github.com/ambv/black
    rev: 21.5b1
    hooks:
      - id: black
  - repo: https://github.com/asottile/seed-isort-config
    rev: v2.2.0
    hooks:
      - id: seed-isort-config
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.0.1
    hooks:
      - id: trailing-whitespace
      - id: check-yaml
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
      - id: check-merge-conflict
      - id: fix-encoding-pragma
        args: ["--remove"]
      - id: mixed-line-ending
        args: ["--fix=lf"]
  - repo: https://github.com/jumanjihouse/pre-commit-hooks
    rev: 2.1.5
    hooks:
      - id: markdownlint
        args: ["-r", "~MD002,~MD013,~MD024,~MD029,~MD033,~MD034,~MD036", "-t", "allow_different_nesting"]
  - repo: https://github.com/myint/docformatter
    rev: v1.4
    hooks:
      - id: docformatter
        args: ["--in-place", "--wrap-descriptions", "79"]
11 changes: 11 additions & 0 deletions Makefile
@@ -0,0 +1,11 @@
pre:
	python -m pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
	mkdir -p thirdparty
	git clone https://github.com/open-mmlab/mmdetection.git thirdparty/mmdetection
	cd thirdparty/mmdetection && python -m pip install -e .
install:
	make pre
	python -m pip install -e .
clean:
	rm -rf thirdparty
	rm -r ssod.egg-info
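A rough usage sketch of the targets above (the `ssod` package name is inferred from the `ssod.egg-info` cleanup target, so treat the import check as an assumption):

```shell script
make install                      # runs `make pre` (requirements + thirdparty mmdetection), then installs this repo
python -c "import mmdet, ssod"    # quick sanity check that both packages import
make clean                        # removes the thirdparty checkout and the egg-info
```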
141 changes: 131 additions & 10 deletions README.md
@@ -2,24 +2,145 @@

By [Mengde Xu*](https://scholar.google.com/citations?user=C04zJHEAAAAJ&hl=zh-CN), [Zheng Zhang*](https://github.com/stupidZZ), [Han Hu](https://github.com/ancientmooner), [Jianfeng Wang](https://github.com/amsword), [Lijuan Wang](https://www.microsoft.com/en-us/research/people/lijuanw/), [Fangyun Wei](https://scholar.google.com.tw/citations?user=-ncz2s8AAAAJ&hl=zh-TW), [Xiang Bai](http://cloud.eic.hust.edu.cn:8071/~xbai/), [Zicheng Liu](https://www.microsoft.com/en-us/research/people/zliu/).

![](./resources/pipeline.png)
This repo is the official implementation of ["End-to-End Semi-Supervised Object Detection with Soft Teacher"](https://arxiv.org/abs/2106.09018).


## Introduction

This paper presents an end-to-end semi-supervised object detection approach, in contrast to previous more complex multi-stage methods. The end-to-end training gradually improves pseudo-label quality during the curriculum, and the increasingly accurate pseudo labels in turn benefit object detection training. We also propose two simple yet effective techniques within this framework: a soft teacher mechanism, where the classification loss of each unlabeled bounding box is weighed by the classification score produced by the teacher network, and a box jittering approach to select reliable pseudo boxes for learning box regression. On the COCO benchmark, the proposed approach outperforms previous methods by a large margin under various labeling ratios, i.e. 1%, 5% and 10%. Moreover, our approach also performs well when the amount of labeled data is relatively large. For example, it improves a 40.9 mAP baseline detector trained on the full COCO training set by +3.6 mAP, reaching 44.5 mAP, by leveraging the 123K unlabeled images of COCO. On a state-of-the-art Swin Transformer based object detector (58.9 mAP on test-dev), it still significantly improves the detection accuracy by +1.5 mAP, reaching 60.4 mAP, and improves the instance segmentation accuracy by +1.2 mAP, reaching 52.4 mAP. Further incorporating the Object365 pre-trained model, the detection accuracy reaches 61.3 mAP and the instance segmentation accuracy reaches 53.0 mAP, setting a new state-of-the-art.
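As a rough schematic in our own notation (not taken verbatim from the paper), the soft teacher mechanism described above weights the per-box classification loss on unlabeled images by the teacher's score:

$$\mathcal{L}^{u}_{\mathrm{cls}} \;=\; \frac{\sum_{i} w_i\,\ell_{\mathrm{cls}}\!\left(b_i,\hat{y}_i\right)}{\sum_{i} w_i}, \qquad w_i = s_i^{\mathrm{teacher}},$$

where $b_i$ are the box candidates of an unlabeled image, $\hat{y}_i$ their pseudo labels, and $s_i^{\mathrm{teacher}}$ is the classification score the teacher network produces for $b_i$.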

In this repository, we provide the model implementation (in PyTorch) as well as data preparation, training, and evaluation scripts for MS-COCO.

## Citation

```bib
@article{xu2021end,
title={End-to-End Semi-Supervised Object Detection with Soft Teacher},
author={Xu, Mengde and Zhang, Zheng and Hu, Han and Wang, Jianfeng and Wang, Lijuan and Wei, Fangyun and Bai, Xiang and Liu, Zicheng},
journal={arXiv preprint arXiv:2106.09018},
year={2021}
}
```

## Main Results

### Partial Labeled Data

Following STAC [1], we evaluate on 5 different data splits for each setting and report the average performance over the 5 splits. The results are shown below:

#### 1% labeled data
| Method | mAP| Model Weights |Config Files|
| ---- | -------| ----- |----|
| Baseline| 10.0 |-|[Config](configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py)|
| Ours (thr=5e-2) | 21.62 |[Drive](https://drive.google.com/drive/folders/1QA8sAw49DJiMHF-Cr7q0j7KgKjlJyklV?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py)|
| Ours (thr=1e-3)|22.64| [Drive](https://drive.google.com/drive/folders/1QA8sAw49DJiMHF-Cr7q0j7KgKjlJyklV?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py)|

#### 5% labeled data
| Method | mAP| Model Weights |Config Files|
| ---- | -------| ----- |----|
| Baseline| 20.92 |-|[Config](configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py)|
| Ours (thr=5e-2) | 30.42 |[Drive](https://drive.google.com/drive/folders/1FBWj5SB888m0LU_XYUOK9QEgiubSbU-8?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py)|
| Ours (thr=1e-3)|31.7| [Drive](https://drive.google.com/drive/folders/1FBWj5SB888m0LU_XYUOK9QEgiubSbU-8?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py)|

#### 10% labeled data
| Method | mAP| Model Weights |Config Files|
| ---- | -------| ----- |----|
| Baseline| 26.94 |-|[Config](configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py)|
| Ours (thr=5e-2) | 33.78 |[Drive](https://drive.google.com/drive/folders/1WyAVpfnWxEgvxCLUesxzNB81fM_de9DI?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py)|
| Ours (thr=1e-3)|34.7| [Drive](https://drive.google.com/drive/folders/1WyAVpfnWxEgvxCLUesxzNB81fM_de9DI?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py)|

### Full Labeled Data

#### Faster R-CNN (ResNet-50)
| Model | mAP| Model Weights |Config Files|
| ------ |--- | ----- |----|
| Baseline | 40.9 | - | [Config](configs/baseline/faster_rcnn_r50_caffe_fpn_coco_full_720k.py) |
| Ours (thr=5e-2) | 44.05 |[Drive](https://drive.google.com/file/d/1QSwAcU1dpmqVkJiXufW_QaQu-puOeblG/view?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py)|
| Ours (thr=1e-3) | 44.6 |[Drive](https://drive.google.com/file/d/1QSwAcU1dpmqVkJiXufW_QaQu-puOeblG/view?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py)|
| Ours* (thr=5e-2) | 44.5 | - | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_1440k.py) |
| Ours* (thr=1e-3) | 44.9 | - | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_1440k.py) |

#### Faster R-CNN (ResNet-101)
| Model | mAP| Model Weights |Config Files|
| ------ |--- | ----- |----|
| Baseline | 43.8 | - | [Config](configs/baseline/faster_rcnn_r101_caffe_fpn_coco_full_720k.py) |
| Ours* (thr=5e-2) | 46.8 | - |[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r101_caffe_fpn_coco_full_1080k.py) |
| Ours* (thr=1e-3) | 47.3 | - | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r101_caffe_fpn_coco_full_1080k.py) |


### Notes
- Ours* means we use a longer training schedule.
- `thr` indicates `model.test_cfg.rcnn.score_thr` in the config files. This inference trick was first introduced by Instant-Teaching [2].
- All models are trained on 8 V100 GPUs.

## Usage

### Requirements
- `Ubuntu 16.04`
- `Anaconda3` with `python=3.6`
- `Pytorch=1.9.0`
- `mmdetection=2.16.0+fe46ffe`
- `mmcv=1.3.9`
- `wandb=0.10.31`

#### Notes
- We use [wandb](https://wandb.ai/) for visualization. If you don't want to use it, just comment out lines `276-289` in `configs/soft_teacher/base.py`, or see the sketch below.
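A minimal sketch of the two usual alternatives (both `wandb login` and the `WANDB_MODE` environment variable are standard wandb usage, not specific to this repo):

```shell script
# Log in once so the WandbLoggerHook can sync metrics to your account.
wandb login

# Or keep wandb from uploading anything (runs stay local) without touching the config.
WANDB_MODE=dryrun bash tools/dist_train.sh <CONFIG_FILE_PATH> <NUM_GPUS>
```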

### Installation
```
make install
```

### Data Preparation
- Download the COCO dataset
- Execute the following command to generate the dataset splits:
```shell script
# YOUR_DATA should be a directory containing the COCO dataset.
# For example:
# YOUR_DATA/
#   coco/
#     train2017/
#     val2017/
#     unlabeled2017/
#     annotations/
ln -s ${YOUR_DATA} data
bash tools/dataset/prepare_coco_data.sh conduct

```
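If everything is in place, a quick optional sanity check of the layout assumed above might look like:

```shell script
ls -l data          # `data` should be a symlink to ${YOUR_DATA}
ls data/coco        # expect: annotations  train2017  unlabeled2017  val2017
```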

### Training
- To train the model on the **partial labeled data** setting:
```shell script
# JOB_TYPE: 'baseline' or 'semi', decides which kind of job to run
# PERCENT_LABELED_DATA: 1, 5, 10. The percentage of labeled COCO data in the whole training set.
# GPU_NUM: number of GPUs to run the job
for FOLD in 1 2 3 4 5;
do
bash tools/dist_train_partially.sh <JOB_TYPE> ${FOLD} <PERCENT_LABELED_DATA> <GPU_NUM>
done
```
For example, we could run the following script to train our model on 10% labeled data with 8 GPUs:

```shell script
for FOLD in 1 2 3 4 5;
do
bash tools/dist_train_partially.sh semi ${FOLD} 10 8
done
```

- To train the model on the **full labeled data** setting:
```shell script
bash tools/dist_train.sh <CONFIG_FILE_PATH> <NUM_GPUS>
```
For example, to train our `R50` model with 8 GPUs:
```shell script
bash tools/dist_train.sh configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py 8
```



### Inference
```
bash tools/dist_test.sh <CONFIG_FILE_PATH> <CHECKPOINT_PATH> <NUM_GPUS> --eval bbox --cfg-options model.test_cfg.rcnn.score_thr=<THR>
```
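For instance, evaluating a downloaded `R50` full-data checkpoint with 8 GPUs at `thr=5e-2` could look like the following (the checkpoint path is just a placeholder for wherever you saved the file from the Drive links above):

```shell script
bash tools/dist_test.sh \
    configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py \
    work_dirs/downloads/soft_teacher_r50_full.pth \
    8 --eval bbox --cfg-options model.test_cfg.rcnn.score_thr=5e-2
```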

[1] [A Simple Semi-Supervised Learning Framework for Object Detection](https://arxiv.org/pdf/2005.04757.pdf)


[2] [Instant-Teaching: An End-to-End Semi-Supervised Object Detection Framework](https://arxiv.org/pdf/2103.11402.pdf)

123 changes: 123 additions & 0 deletions configs/baseline/base.py
@@ -0,0 +1,123 @@
mmdet_base = "../../thirdparty/mmdetection/configs/_base_"
_base_ = [
    f"{mmdet_base}/models/faster_rcnn_r50_fpn.py",
    f"{mmdet_base}/datasets/coco_detection.py",
    f"{mmdet_base}/schedules/schedule_1x.py",
    f"{mmdet_base}/default_runtime.py",
]

model = dict(
    backbone=dict(
        norm_cfg=dict(requires_grad=False),
        norm_eval=True,
        style="caffe",
        init_cfg=dict(
            type="Pretrained", checkpoint="open-mmlab://detectron2/resnet50_caffe"
        ),
    )
)

img_norm_cfg = dict(mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)

train_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(type="LoadAnnotations", with_bbox=True),
    dict(
        type="Sequential",
        transforms=[
            dict(
                type="RandResize",
                img_scale=[(1333, 400), (1333, 1200)],
                multiscale_mode="range",
                keep_ratio=True,
            ),
            dict(type="RandFlip", flip_ratio=0.5),
            dict(
                type="OneOf",
                transforms=[
                    dict(type=k)
                    for k in [
                        "Identity",
                        "AutoContrast",
                        "RandEqualize",
                        "RandSolarize",
                        "RandColor",
                        "RandContrast",
                        "RandBrightness",
                        "RandSharpness",
                        "RandPosterize",
                    ]
                ],
            ),
        ],
    ),
    dict(type="Pad", size_divisor=32),
    dict(type="Normalize", **img_norm_cfg),
    dict(type="ExtraAttrs", tag="sup"),
    dict(type="DefaultFormatBundle"),
    dict(
        type="Collect",
        keys=["img", "gt_bboxes", "gt_labels"],
        meta_keys=(
            "filename",
            "ori_shape",
            "img_shape",
            "img_norm_cfg",
            "pad_shape",
            "scale_factor",
            "tag",
        ),
    ),
]

test_pipeline = [
    dict(type="LoadImageFromFile"),
    dict(
        type="MultiScaleFlipAug",
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type="Resize", keep_ratio=True),
            dict(type="RandomFlip"),
            dict(type="Normalize", **img_norm_cfg),
            dict(type="Pad", size_divisor=32),
            dict(type="ImageToTensor", keys=["img"]),
            dict(type="Collect", keys=["img"]),
        ],
    ),
]

data = dict(
    samples_per_gpu=1,
    workers_per_gpu=1,
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline),
)

optimizer = dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0001)
lr_config = dict(step=[120000, 160000])
runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000)
checkpoint_config = dict(by_epoch=False, interval=4000, max_keep_ckpts=10)
evaluation = dict(interval=4000)

fp16 = dict(loss_scale="dynamic")

log_config = dict(
    interval=50,
    hooks=[
        dict(type="TextLoggerHook"),
        dict(
            type="WandbLoggerHook",
            init_kwargs=dict(
                project="pre_release",
                name="${cfg_name}",
                config=dict(
                    work_dirs="${work_dir}",
                    total_step="${runner.max_iters}",
                ),
            ),
            by_epoch=False,
        ),
    ],
)
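Since this config pulls several `_base_` files from the thirdparty mmdetection checkout, it can help to dump the fully resolved configuration. A small sketch, assuming the `make install` layout above and a standard `mmcv` installation:

```shell script
# Print the merged config (base files plus the overrides in this file).
python -c "from mmcv import Config; print(Config.fromfile('configs/baseline/base.py').pretty_text)"
```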
20 changes: 20 additions & 0 deletions configs/baseline/faster_rcnn_r101_caffe_fpn_coco_full_720k.py
@@ -0,0 +1,20 @@
_base_ = "base.py"
model = dict(
    backbone=dict(
        depth=101,
        init_cfg=dict(checkpoint="open-mmlab://detectron2/resnet101_caffe"),
    )
)

data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        ann_file="data/coco/annotations/instances_train2017.json",
        img_prefix="data/coco/train2017/",
    ),
)

optimizer = dict(lr=0.02)
lr_config = dict(step=[120000 * 4, 160000 * 4])
runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000 * 4)
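Following the full-labeled-data training command from the README above, this supervised baseline could presumably be launched as:

```shell script
bash tools/dist_train.sh configs/baseline/faster_rcnn_r101_caffe_fpn_coco_full_720k.py 8
```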