diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1e909f0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,38 @@ +#vs code +.history/ +.vscode +.idea +.history +.DS_Store +#python +__pycache__/ +*/__pycache__ +*.egg-info +build +#lib +tests +thirdparty +thirdparty/ + +#develop +wandb +data +data/ +*.pkl +*.pkl.json +*.log.json +work_dirs/ +figures +cp.py + +# Pytorch +*.pth +*.py~ +*.sh~ +launch.py + +#nvidia +*.qdrep +*.sqlite + +.pytest* diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000..2ba5713 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,2 @@ +[settings] +known_third_party = PIL,cv2,mmcv,mmdet,numpy,prettytable,setuptools,torch diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..66396d9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: + - repo: https://github.com/ambv/black + rev: 21.5b1 + hooks: + - id: black + - repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: trailing-whitespace + - id: check-yaml + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: check-merge-conflict + - id: fix-encoding-pragma + args: ["--remove"] + - id: mixed-line-ending + args: ["--fix=lf"] + - repo: https://github.com/jumanjihouse/pre-commit-hooks + rev: 2.1.5 + hooks: + - id: markdownlint + args: ["-r", "~MD002,~MD013,~MD024,~MD029,~MD033,~MD034,~MD036", "-t", "allow_different_nesting"] + - repo: https://github.com/myint/docformatter + rev: v1.4 + hooks: + - id: docformatter + args: ["--in-place", "--wrap-descriptions", "79"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..0b9b656 --- /dev/null +++ b/Makefile @@ -0,0 +1,11 @@ +pre: + python -m pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html + mkdir -p thirdparty + git clone https://github.com/open-mmlab/mmdetection.git thirdparty/mmdetection + cd thirdparty/mmdetection && python -m pip install -e . +install: + make pre + python -m pip install -e . +clean: + rm -rf thirdparty + rm -r ssod.egg-info diff --git a/README.md b/README.md index 5ebb12f..5c4512e 100644 --- a/README.md +++ b/README.md @@ -2,20 +2,12 @@ By [Mengde Xu*](https://scholar.google.com/citations?user=C04zJHEAAAAJ&hl=zh-CN), [Zheng Zhang*](https://github.com/stupidZZ), [Han Hu](https://github.com/ancientmooner), [Jianfeng Wang](https://github.com/amsword), [Lijuan Wang](https://www.microsoft.com/en-us/research/people/lijuanw/), [Fangyun Wei](https://scholar.google.com.tw/citations?user=-ncz2s8AAAAJ&hl=zh-TW), [Xiang Bai](http://cloud.eic.hust.edu.cn:8071/~xbai/), [Zicheng Liu](https://www.microsoft.com/en-us/research/people/zliu/). +![](./resources/pipeline.png) This repo is the official implementation of ["End-to-End Semi-Supervised Object Detection with Soft Teacher"](https://arxiv.org/abs/2106.09018). -**Code and models will be released soon.** - -## Introduction - -This paper presents an end-to-end semi-supervised object detection approach, in contrast to previous more complex multi-stage methods. The end-to-end training gradually improves pseudo label qualities during the curriculum, and the more and more accurate pseudo labels in turn benefit object detection training. 
We also propose two simple yet effective techniques within this framework: a soft teacher mechanism where the classification loss of each unlabeled bounding box is weighed by the classification score produced by the teacher network; a box jittering approach to select reliable pseudo boxes for the learning of box regression. On COCO benchmark, the proposed approach outperforms previous methods by a large margin under various labelling ratios, i.e. 1\%, 5\% and 10\%. Moreover, our approach proves to perform also well when the amount of labeled data is relatively large. For example, it can improve a 40.9 mAP baseline detector trained using the full COCO training set by +3.6 mAP, reaching 44.5 mAP, by leveraging the 123K unlabeled images of COCO. On the state-of-the-art Swin Transformer based object detector (58.9 mAP on test-dev), it can still significantly improve the detection accuracy by +1.5 mAP, reaching 60.4 mAP, and improve the instance segmentation accuracy by +1.2 mAP, reaching 52.4 mAP. Further incorporating with the Object365 pre-trained model, the detection accuracy reaches 61.3 mAP and the instance segmentation accuracy reaches 53.0 mAP, pushing the new state-of-the-art.
-
-In this repository, we provide model implementation (with Pytorch) as well as data preparation, training and evaluation
-scripts on MS-COCO.
-
 ## Citation
-```
+```bib
@article{xu2021end,
  title={End-to-End Semi-Supervised Object Detection with Soft Teacher},
  author={Xu, Mengde and Zhang, Zheng and Hu, Han and Wang, Jianfeng and Wang, Lijuan and Wei, Fangyun and Bai, Xiang and Liu, Zicheng},
@@ -23,3 +15,132 @@ scripts on MS-COCO.
  year={2021}
}
```
+
+## Main Results
+
+### Partial Labeled Data
+
+Following STAC[1], we evaluate on 5 different data splits for each setting and report the average performance over the 5 splits.
The results are shown below.
+
+#### 1% labeled data
+| Method | mAP| Model Weights |Config Files|
+| ---- | -------| ----- |----|
+| Baseline| 10.0 |-|[Config](configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py)|
+| Ours (thr=5e-2) | 21.62 |[Drive](https://drive.google.com/drive/folders/1QA8sAw49DJiMHF-Cr7q0j7KgKjlJyklV?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py)|
+| Ours (thr=1e-3)|22.64| [Drive](https://drive.google.com/drive/folders/1QA8sAw49DJiMHF-Cr7q0j7KgKjlJyklV?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py)|
+
+#### 5% labeled data
+| Method | mAP| Model Weights |Config Files|
+| ---- | -------| ----- |----|
+| Baseline| 20.92 |-|[Config](configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py)|
+| Ours (thr=5e-2) | 30.42 |[Drive](https://drive.google.com/drive/folders/1FBWj5SB888m0LU_XYUOK9QEgiubSbU-8?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py)|
+| Ours (thr=1e-3)|31.7| [Drive](https://drive.google.com/drive/folders/1FBWj5SB888m0LU_XYUOK9QEgiubSbU-8?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py)|
+
+#### 10% labeled data
+| Method | mAP| Model Weights |Config Files|
+| ---- | -------| ----- |----|
+| Baseline| 26.94 |-|[Config](configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py)|
+| Ours (thr=5e-2) | 33.78 |[Drive](https://drive.google.com/drive/folders/1WyAVpfnWxEgvxCLUesxzNB81fM_de9DI?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py)|
+| Ours (thr=1e-3)|34.7| [Drive](https://drive.google.com/drive/folders/1WyAVpfnWxEgvxCLUesxzNB81fM_de9DI?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py)|
+
+### Full Labeled Data
+
+#### Faster R-CNN (ResNet-50)
+| Model | mAP| Model Weights |Config Files|
+| ------ |--- | ----- |----|
+| Baseline | 40.9 | - | [Config](configs/baseline/faster_rcnn_r50_caffe_fpn_coco_full_720k.py) |
+| Ours (thr=5e-2) | 44.05 |[Drive](https://drive.google.com/file/d/1QSwAcU1dpmqVkJiXufW_QaQu-puOeblG/view?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py)|
+| Ours (thr=1e-3) | 44.6 |[Drive](https://drive.google.com/file/d/1QSwAcU1dpmqVkJiXufW_QaQu-puOeblG/view?usp=sharing)|[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py)|
+| Ours* (thr=5e-2) | 44.5 | - | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_1440k.py) |
+| Ours* (thr=1e-3) | 44.9 | - | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_1440k.py) |
+
+#### Faster R-CNN (ResNet-101)
+| Model | mAP| Model Weights |Config Files|
+| ------ |--- | ----- |----|
+| Baseline | 43.8 | - | [Config](configs/baseline/faster_rcnn_r101_caffe_fpn_coco_full_720k.py) |
+| Ours* (thr=5e-2) | 46.8 | - |[Config](configs/soft_teacher/soft_teacher_faster_rcnn_r101_caffe_fpn_coco_full_1080k.py) |
+| Ours* (thr=1e-3) | 47.3 | - | [Config](configs/soft_teacher/soft_teacher_faster_rcnn_r101_caffe_fpn_coco_full_1080k.py) |
+
+
+### Notes
+- Ours* means a longer training schedule is used.
+- `thr` indicates `model.test_cfg.rcnn.score_thr` in the config files. This inference trick was first introduced by Instant-Teaching[2]; see the example below.
+- All models are trained on 8×V100 GPUs.
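+
+For example, evaluating a trained model with the looser threshold used in the `thr=1e-3` rows only requires overriding that config key at test time (the checkpoint path below is a placeholder for a downloaded or self-trained weight file):
+```shell script
+bash tools/dist_test.sh configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py \
+    <CHECKPOINT_PATH> 8 --eval bbox \
+    --cfg-options model.test_cfg.rcnn.score_thr=1e-3
+```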
+
+## Usage
+
+### Requirements
+- `Ubuntu 16.04`
+- `Anaconda3` with `python=3.6`
+- `Pytorch=1.9.0`
+- `mmdetection=2.16.0+fe46ffe`
+- `mmcv=1.3.9`
+- `wandb=0.10.31`
+
+#### Notes
+- We use [wandb](https://wandb.ai/) for visualization; if you don't want to use it, comment out lines `276-289` in `configs/soft_teacher/base.py`.
+
+### Installation
+```
+make install
+```
+
+### Data Preparation
+- Download the COCO dataset
+- Execute the following command to generate the dataset splits:
+```shell script
+# YOUR_DATA should be a directory containing the coco dataset.
+# For example:
+# YOUR_DATA/
+#  coco/
+#    train2017/
+#    val2017/
+#    unlabeled2017/
+#    annotations/
+ln -s ${YOUR_DATA} data
+bash tools/dataset/prepare_coco_data.sh conduct
+
+```
+
+### Training
+- To train a model on the **partial labeled data** setting:
+```shell script
+# JOB_TYPE: 'baseline' or 'semi', decides which kind of job to run
+# PERCENT_LABELED_DATA: 1, 5, 10. The ratio of labeled coco data in the whole training dataset.
+# GPU_NUM: number of gpus used to run the job
+for FOLD in 1 2 3 4 5;
+do
+  bash tools/dist_train_partially.sh ${JOB_TYPE} ${FOLD} ${PERCENT_LABELED_DATA} ${GPU_NUM}
+done
+```
+For example, we could run the following script to train our model on 10% labeled data with 8 GPUs:
+
+```shell script
+for FOLD in 1 2 3 4 5;
+do
+  bash tools/dist_train_partially.sh semi ${FOLD} 10 8
+done
+```
+
+- To train a model on the **full labeled data** setting:
+```shell script
+bash tools/dist_train.sh <CONFIG_FILE_PATH> <GPU_NUM>
+```
+For example, to train our `R50` model with 8 GPUs:
+```shell script
+bash tools/dist_train.sh configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py 8
+```
+
+### Inference
+```
+bash tools/dist_test.sh <CONFIG_FILE_PATH> <CHECKPOINT_PATH> <GPU_NUM> --eval bbox --cfg-options model.test_cfg.rcnn.score_thr=<THR>
+```
+
+[1] [A Simple Semi-Supervised Learning Framework for Object Detection](https://arxiv.org/pdf/2005.04757.pdf)
+
+[2] [Instant-Teaching: An End-to-End Semi-Supervised Object Detection Framework](https://arxiv.org/pdf/2103.11402.pdf)
diff --git a/configs/baseline/base.py b/configs/baseline/base.py
new file mode 100644
index 0000000..fed6e60
--- /dev/null
+++ b/configs/baseline/base.py
@@ -0,0 +1,123 @@
+mmdet_base = "../../thirdparty/mmdetection/configs/_base_"
+_base_ = [
+    f"{mmdet_base}/models/faster_rcnn_r50_fpn.py",
+    f"{mmdet_base}/datasets/coco_detection.py",
+    f"{mmdet_base}/schedules/schedule_1x.py",
+    f"{mmdet_base}/default_runtime.py",
+]
+
+model = dict(
+    backbone=dict(
+        norm_cfg=dict(requires_grad=False),
+        norm_eval=True,
+        style="caffe",
+        init_cfg=dict(
+            type="Pretrained", checkpoint="open-mmlab://detectron2/resnet50_caffe"
+        ),
+    )
+)
+
+img_norm_cfg = dict(mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
+
+train_pipeline = [
+    dict(type="LoadImageFromFile"),
+    dict(type="LoadAnnotations", with_bbox=True),
+    dict(
+        type="Sequential",
+        transforms=[
+            dict(
+                type="RandResize",
+                img_scale=[(1333, 400), (1333, 1200)],
+                multiscale_mode="range",
+                keep_ratio=True,
+            ),
+            dict(type="RandFlip", flip_ratio=0.5),
+            dict(
+                type="OneOf",
+                transforms=[
+                    dict(type=k)
+                    for k in [
+                        "Identity",
+                        "AutoContrast",
+                        "RandEqualize",
+                        "RandSolarize",
+                        "RandColor",
+                        "RandContrast",
+                        "RandBrightness",
+                        "RandSharpness",
+                        "RandPosterize",
+                    ]
+                ],
+            ),
+        ],
+    ),
+    dict(type="Pad", size_divisor=32),
+    dict(type="Normalize", **img_norm_cfg),
+    dict(type="ExtraAttrs", tag="sup"),
+    dict(type="DefaultFormatBundle"),
+    dict(
+        type="Collect",
+        keys=["img", "gt_bboxes", "gt_labels"],
+        meta_keys=(
+            "filename",
"ori_shape", + "img_shape", + "img_norm_cfg", + "pad_shape", + "scale_factor", + "tag", + ), + ), +] + +test_pipeline = [ + dict(type="LoadImageFromFile"), + dict( + type="MultiScaleFlipAug", + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type="Resize", keep_ratio=True), + dict(type="RandomFlip"), + dict(type="Normalize", **img_norm_cfg), + dict(type="Pad", size_divisor=32), + dict(type="ImageToTensor", keys=["img"]), + dict(type="Collect", keys=["img"]), + ], + ), +] + +data = dict( + samples_per_gpu=1, + workers_per_gpu=1, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline), +) + +optimizer = dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0001) +lr_config = dict(step=[120000, 160000]) +runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000) +checkpoint_config = dict(by_epoch=False, interval=4000, max_keep_ckpts=10) +evaluation = dict(interval=4000) + +fp16 = dict(loss_scale="dynamic") + +log_config = dict( + interval=50, + hooks=[ + dict(type="TextLoggerHook"), + dict( + type="WandbLoggerHook", + init_kwargs=dict( + project="pre_release", + name="${cfg_name}", + config=dict( + work_dirs="${work_dir}", + total_step="${runner.max_iters}", + ), + ), + by_epoch=False, + ), + ], +) diff --git a/configs/baseline/faster_rcnn_r101_caffe_fpn_coco_full_720k.py b/configs/baseline/faster_rcnn_r101_caffe_fpn_coco_full_720k.py new file mode 100644 index 0000000..e7be399 --- /dev/null +++ b/configs/baseline/faster_rcnn_r101_caffe_fpn_coco_full_720k.py @@ -0,0 +1,20 @@ +_base_ = "base.py" +model = dict( + backbone=dict( + depth=101, + init_cfg=dict(checkpoint="open-mmlab://detectron2/resnet101_caffe"), + ) +) + +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + ann_file="data/coco/annotations/instances_train2017.json", + img_prefix="data/coco/train2017/", + ), +) + +optimizer = dict(lr=0.02) +lr_config = dict(step=[120000 * 4, 160000 * 4]) +runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000 * 4) diff --git a/configs/baseline/faster_rcnn_r50_caffe_fpn_coco_full_720k.py b/configs/baseline/faster_rcnn_r50_caffe_fpn_coco_full_720k.py new file mode 100644 index 0000000..1f20148 --- /dev/null +++ b/configs/baseline/faster_rcnn_r50_caffe_fpn_coco_full_720k.py @@ -0,0 +1,14 @@ +_base_ = "base.py" + +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + ann_file="data/coco/annotations/instances_train2017.json", + img_prefix="data/coco/train2017/", + ), +) + +optimizer = dict(lr=0.02) +lr_config = dict(step=[120000 * 4, 160000 * 4]) +runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000 * 4) diff --git a/configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py b/configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py new file mode 100644 index 0000000..333719b --- /dev/null +++ b/configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py @@ -0,0 +1,32 @@ +_base_ = "base.py" +fold = 1 +percent = 1 +data = dict( + samples_per_gpu=1, + workers_per_gpu=1, + train=dict( + ann_file="data/coco/annotations/semi_supervised/instances_train2017.${fold}@${percent}.json", + img_prefix="data/coco/train2017/", + ), +) +work_dir = "work_dirs/${cfg_name}/${percent}/${fold}" +log_config = dict( + interval=50, + hooks=[ + dict(type="TextLoggerHook"), + dict( + type="WandbLoggerHook", + init_kwargs=dict( + project="pre_release", + name="${cfg_name}", + config=dict( + fold="${fold}", + percent="${percent}", + work_dirs="${work_dir}", + 
total_step="${runner.max_iters}", + ), + ), + by_epoch=False, + ), + ], +) diff --git a/configs/soft_teacher/base.py b/configs/soft_teacher/base.py new file mode 100644 index 0000000..b90c452 --- /dev/null +++ b/configs/soft_teacher/base.py @@ -0,0 +1,287 @@ +mmdet_base = "../../thirdparty/mmdetection/configs/_base_" +_base_ = [ + f"{mmdet_base}/models/faster_rcnn_r50_fpn.py", + f"{mmdet_base}/datasets/coco_detection.py", + f"{mmdet_base}/schedules/schedule_1x.py", + f"{mmdet_base}/default_runtime.py", +] + +model = dict( + backbone=dict( + norm_cfg=dict(requires_grad=False), + norm_eval=True, + style="caffe", + init_cfg=dict( + type="Pretrained", checkpoint="open-mmlab://detectron2/resnet50_caffe" + ), + ) +) + +img_norm_cfg = dict(mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) + +train_pipeline = [ + dict(type="LoadImageFromFile"), + dict(type="LoadAnnotations", with_bbox=True), + dict( + type="Sequential", + transforms=[ + dict( + type="RandResize", + img_scale=[(1333, 400), (1333, 1200)], + multiscale_mode="range", + keep_ratio=True, + ), + dict(type="RandFlip", flip_ratio=0.5), + dict( + type="OneOf", + transforms=[ + dict(type=k) + for k in [ + "Identity", + "AutoContrast", + "RandEqualize", + "RandSolarize", + "RandColor", + "RandContrast", + "RandBrightness", + "RandSharpness", + "RandPosterize", + ] + ], + ), + ], + record=True, + ), + dict(type="Pad", size_divisor=32), + dict(type="Normalize", **img_norm_cfg), + dict(type="ExtraAttrs", tag="sup"), + dict(type="DefaultFormatBundle"), + dict( + type="Collect", + keys=["img", "gt_bboxes", "gt_labels"], + meta_keys=( + "filename", + "ori_shape", + "img_shape", + "img_norm_cfg", + "pad_shape", + "scale_factor", + "tag", + ), + ), +] + +strong_pipeline = [ + dict( + type="Sequential", + transforms=[ + dict( + type="RandResize", + img_scale=[(1333, 400), (1333, 1200)], + multiscale_mode="range", + keep_ratio=True, + ), + dict(type="RandFlip", flip_ratio=0.5), + dict( + type="ShuffledSequential", + transforms=[ + dict( + type="OneOf", + transforms=[ + dict(type=k) + for k in [ + "Identity", + "AutoContrast", + "RandEqualize", + "RandSolarize", + "RandColor", + "RandContrast", + "RandBrightness", + "RandSharpness", + "RandPosterize", + ] + ], + ), + dict( + type="OneOf", + transforms=[ + dict(type="RandTranslate", x=(-0.1, 0.1)), + dict(type="RandTranslate", y=(-0.1, 0.1)), + dict(type="RandRotate", angle=(-30, 30)), + [ + dict(type="RandShear", x=(-30, 30)), + dict(type="RandShear", y=(-30, 30)), + ], + ], + ), + ], + ), + dict( + type="RandErase", + n_iterations=(1, 5), + size=[0, 0.2], + squared=True, + ), + ], + record=True, + ), + dict(type="Pad", size_divisor=32), + dict(type="Normalize", **img_norm_cfg), + dict(type="ExtraAttrs", tag="unsup_student"), + dict(type="DefaultFormatBundle"), + dict( + type="Collect", + keys=["img", "gt_bboxes", "gt_labels"], + meta_keys=( + "filename", + "ori_shape", + "img_shape", + "img_norm_cfg", + "pad_shape", + "scale_factor", + "tag", + "transform_matrix", + ), + ), +] +weak_pipeline = [ + dict( + type="Sequential", + transforms=[ + dict( + type="RandResize", + img_scale=[(1333, 400), (1333, 1200)], + multiscale_mode="range", + keep_ratio=True, + ), + dict(type="RandFlip", flip_ratio=0.5), + ], + record=True, + ), + dict(type="Pad", size_divisor=32), + dict(type="Normalize", **img_norm_cfg), + dict(type="ExtraAttrs", tag="unsup_teacher"), + dict(type="DefaultFormatBundle"), + dict( + type="Collect", + keys=["img", "gt_bboxes", "gt_labels"], + meta_keys=( + "filename", + 
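+            # the "tag" meta key (set by ExtraAttrs above) and "transform_matrix"
+            # below let the SoftTeacher wrapper tell teacher/student views apart
+            # and map boxes between them.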
"ori_shape", + "img_shape", + "img_norm_cfg", + "pad_shape", + "scale_factor", + "tag", + "transform_matrix", + ), + ), +] +unsup_pipeline = [ + dict(type="LoadImageFromFile"), + # dict(type="LoadAnnotations", with_bbox=True), + # generate fake labels for data format compability + dict(type="PseudoSamples", with_bbox=True), + dict( + type="MultiBranch", unsup_teacher=strong_pipeline, unsup_student=weak_pipeline + ), +] + +test_pipeline = [ + dict(type="LoadImageFromFile"), + dict( + type="MultiScaleFlipAug", + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type="Resize", keep_ratio=True), + dict(type="RandomFlip"), + dict(type="Normalize", **img_norm_cfg), + dict(type="Pad", size_divisor=32), + dict(type="ImageToTensor", keys=["img"]), + dict(type="Collect", keys=["img"]), + ], + ), +] +data = dict( + samples_per_gpu=None, + workers_per_gpu=None, + train=dict( + _delete_=True, + type="SemiDataset", + sup=dict( + type="CocoDataset", + ann_file=None, + img_prefix=None, + pipeline=train_pipeline, + ), + unsup=dict( + type="CocoDataset", + ann_file=None, + img_prefix=None, + pipeline=unsup_pipeline, + filter_empty_gt=False, + ), + ), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline), + sampler=dict( + train=dict( + type="SemiBalanceSampler", + sample_ratio=[1, 4], + by_prob=True, + # at_least_one=True, + epoch_length=7330, + ) + ), +) + +semi_wrapper = dict( + type="SoftTeacher", + model="${model}", + train_cfg=dict( + use_teacher_proposal=False, + pseudo_label_initial_score_thr=0.5, + rpn_pseudo_threshold=0.9, + cls_pseudo_threshold=0.9, + reg_pseudo_threshold=0.01, + jitter_times=10, + jitter_scale=0.06, + min_pseduo_box_size=0, + unsup_weight=4.0, + ), + test_cfg=dict(inference_on="student"), +) + +custom_hooks = [ + dict(type="NumClassCheckHook"), + dict(type="WeightSummary"), + dict(type="MeanTeacher", momentum=0.999, interval=1, warm_up=0), + dict(type="Weighter", steps=[-5000], vals=[4, 0], name="unsup_weight"), +] +evaluation = dict(type="SubModulesDistEvalHook", interval=4000) +optimizer = dict(type="SGD", lr=0.01, momentum=0.9, weight_decay=0.0001) +lr_config = dict(step=[120000, 160000]) +runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000) +checkpoint_config = dict(by_epoch=False, interval=4000, max_keep_ckpts=20) + +fp16 = dict(loss_scale="dynamic") + +log_config = dict( + interval=50, + hooks=[ + dict(type="TextLoggerHook"), + dict( + type="WandbLoggerHook", + init_kwargs=dict( + project="pre_release", + name="${cfg_name}", + config=dict( + work_dirs="${work_dir}", + total_step="${runner.max_iters}", + ), + ), + by_epoch=False, + ), + ], +) diff --git a/configs/soft_teacher/soft_teacher_faster_rcnn_r101_caffe_fpn_coco_full_1080k.py b/configs/soft_teacher/soft_teacher_faster_rcnn_r101_caffe_fpn_coco_full_1080k.py new file mode 100644 index 0000000..373b5d7 --- /dev/null +++ b/configs/soft_teacher/soft_teacher_faster_rcnn_r101_caffe_fpn_coco_full_1080k.py @@ -0,0 +1,36 @@ +_base_ = "base.py" +model = dict( + backbone=dict( + depth=101, + init_cfg=dict(checkpoint="open-mmlab://detectron2/resnet101_caffe"), + ) +) + +data = dict( + samples_per_gpu=8, + workers_per_gpu=8, + train=dict( + sup=dict( + ann_file="data/coco/annotations/instances_train2017.json", + img_prefix="data/coco/train2017/", + ), + unsup=dict( + ann_file="data/coco/annotations/instances_unlabeled2017.json", + img_prefix="data/coco/unlabeled2017/", + ), + ), + sampler=dict( + train=dict( + sample_ratio=[1, 1], + ) + ), +) + +semi_wrapper = dict( + train_cfg=dict( + 
unsup_weight=2.0,
+    )
+)
+
+lr_config = dict(step=[120000 * 6, 160000 * 6])
+runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000 * 6)
diff --git a/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py b/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py
new file mode 100644
index 0000000..2f7ead2
--- /dev/null
+++ b/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py
@@ -0,0 +1,48 @@
+_base_ = "base.py"
+
+data = dict(
+    samples_per_gpu=5,
+    workers_per_gpu=5,
+    train=dict(
+        sup=dict(
+            type="CocoDataset",
+            ann_file="data/coco/annotations/semi_supervised/instances_train2017.${fold}@${percent}.json",
+            img_prefix="data/coco/train2017/",
+        ),
+        unsup=dict(
+            type="CocoDataset",
+            ann_file="data/coco/annotations/semi_supervised/instances_train2017.${fold}@${percent}-unlabeled.json",
+            img_prefix="data/coco/train2017/",
+        ),
+    ),
+    sampler=dict(
+        train=dict(
+            sample_ratio=[1, 4],
+        )
+    ),
+)
+
+fold = 1
+percent = 1
+
+work_dir = "work_dirs/${cfg_name}/${percent}/${fold}"
+log_config = dict(
+    interval=50,
+    hooks=[
+        dict(type="TextLoggerHook"),
+        dict(
+            type="WandbLoggerHook",
+            init_kwargs=dict(
+                project="pre_release",
+                name="${cfg_name}",
+                config=dict(
+                    fold="${fold}",
+                    percent="${percent}",
+                    work_dirs="${work_dir}",
+                    total_step="${runner.max_iters}",
+                ),
+            ),
+            by_epoch=False,
+        ),
+    ],
+)
diff --git a/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_1440k.py b/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_1440k.py
new file mode 100644
index 0000000..056459d
--- /dev/null
+++ b/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_1440k.py
@@ -0,0 +1,30 @@
+_base_ = "base.py"
+
+data = dict(
+    samples_per_gpu=8,
+    workers_per_gpu=8,
+    train=dict(
+        sup=dict(
+            ann_file="data/coco/annotations/instances_train2017.json",
+            img_prefix="data/coco/train2017/",
+        ),
+        unsup=dict(
+            ann_file="data/coco/annotations/instances_unlabeled2017.json",
+            img_prefix="data/coco/unlabeled2017/",
+        ),
+    ),
+    sampler=dict(
+        train=dict(
+            sample_ratio=[1, 1],
+        )
+    ),
+)
+
+semi_wrapper = dict(
+    train_cfg=dict(
+        unsup_weight=2.0,
+    )
+)
+
+lr_config = dict(step=[120000 * 8, 160000 * 8])
+runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000 * 8)
diff --git a/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py b/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py
new file mode 100644
index 0000000..eb959a1
--- /dev/null
+++ b/configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py
@@ -0,0 +1,36 @@
+_base_ = "base.py"
+
+data = dict(
+    samples_per_gpu=8,
+    workers_per_gpu=8,
+    train=dict(
+        sup=dict(
+            ann_file="data/coco/annotations/instances_train2017.json",
+            img_prefix="data/coco/train2017/",
+        ),
+        unsup=dict(
+            ann_file="data/coco/annotations/instances_unlabeled2017.json",
+            img_prefix="data/coco/unlabeled2017/",
+        ),
+    ),
+    sampler=dict(
+        train=dict(
+            sample_ratio=[1, 1],
+        )
+    ),
+)
+
+semi_wrapper = dict(
+    train_cfg=dict(
+        unsup_weight=2.0,
+    )
+)
+
+lr_config = dict(step=[120000 * 4, 160000 * 4])
+runner = dict(_delete_=True, type="IterBasedRunner", max_iters=180000 * 4)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..f3d6626
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+torch
+torchvision
+mmcv-full
+wandb
+prettytable
diff --git a/resources/pipeline.png b/resources/pipeline.png
new file mode 100644
index 0000000..74f1466
Binary files /dev/null and b/resources/pipeline.png differ
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..cd88f38
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,99 @@
+import re
+
+from setuptools import find_packages, setup
+
+
+def get_version():
+    version_file = "ssod/version.py"
+    with open(version_file, "r") as f:
+        exec(compile(f.read(), version_file, "exec"))
+    return locals()["__version__"]
+
+
+def parse_requirements(fname="requirements.txt", with_version=True):
+    """Parse the package dependencies listed in a requirements file but strips
+    specific versioning information.
+
+    Args:
+        fname (str): path to requirements file
+        with_version (bool, default=True): if True include version specs
+    Returns:
+        List[str]: list of requirements items
+    CommandLine:
+        python -c "import setup; print(setup.parse_requirements())"
+    """
+    import sys
+    from os.path import exists
+
+    require_fpath = fname
+
+    def parse_line(line):
+        """Parse information from a line in a requirements text file."""
+        if line.startswith("-r "):
+            # Allow specifying requirements in other files
+            target = line.split(" ")[1]
+            for info in parse_require_file(target):
+                yield info
+        else:
+            info = {"line": line}
+            if line.startswith("-e "):
+                info["package"] = line.split("#egg=")[1]
+            else:
+                # Remove versioning from the package
+                pat = "(" + "|".join([">=", "==", ">"]) + ")"
+                parts = re.split(pat, line, maxsplit=1)
+                parts = [p.strip() for p in parts]
+
+                info["package"] = parts[0]
+                if len(parts) > 1:
+                    op, rest = parts[1:]
+                    if ";" in rest:
+                        # Handle platform specific dependencies
+                        # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
+                        version, platform_deps = map(str.strip, rest.split(";"))
+                        info["platform_deps"] = platform_deps
+                    else:
+                        version = rest  # NOQA
+                    info["version"] = (op, version)
+            yield info
+
+    def parse_require_file(fpath):
+        with open(fpath, "r") as f:
+            for line in f.readlines():
+                line = line.strip()
+                if line and not line.startswith("#"):
+                    for info in parse_line(line):
+                        yield info
+
+    def gen_packages_items():
+        if exists(require_fpath):
+            for info in parse_require_file(require_fpath):
+                parts = [info["package"]]
+                if with_version and "version" in info:
+                    parts.extend(info["version"])
+                if not sys.version.startswith("3.4"):
+                    # apparently package_deps are broken in 3.4
+                    platform_deps = info.get("platform_deps")
+                    if platform_deps is not None:
+                        parts.append(";" + platform_deps)
+                item = "".join(parts)
+                yield item
+
+    packages = list(gen_packages_items())
+    return packages
+
+
+if __name__ == "__main__":
+    install_requires = parse_requirements()
+    setup(
+        name="ssod",
+        version=get_version(),
+        description="Semi-Supervised Object Detection Benchmark",
+        author="someone",
+        author_email="someone",
+        packages=find_packages(exclude=("configs", "tools", "demo")),
+        install_requires=install_requires,
+        include_package_data=True,
+        ext_modules=[],
+        zip_safe=False,
+    )
diff --git a/ssod/__init__.py b/ssod/__init__.py
new file mode 100644
index 0000000..aed4fa3
--- /dev/null
+++ b/ssod/__init__.py
@@ -0,0 +1 @@
+from .models import *
diff --git a/ssod/apis/__init__.py b/ssod/apis/__init__.py
new file mode 100644
index 0000000..f87a80d
--- /dev/null
+++ b/ssod/apis/__init__.py
@@ -0,0 +1,3 @@
+from .train import get_root_logger, set_random_seed, train_detector
+
+__all__ = ["get_root_logger", "set_random_seed", "train_detector"]
diff --git a/ssod/apis/train.py b/ssod/apis/train.py
new file mode 100644
index
0000000..1896fc2
--- /dev/null
+++ b/ssod/apis/train.py
@@ -0,0 +1,205 @@
+import random
+import warnings
+
+import numpy as np
+import torch
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import (
+    HOOKS,
+    DistSamplerSeedHook,
+    EpochBasedRunner,
+    Fp16OptimizerHook,
+    OptimizerHook,
+    build_optimizer,
+    build_runner,
+)
+from mmcv.utils import build_from_cfg
+from mmdet.core import DistEvalHook, EvalHook
+from mmdet.datasets import build_dataset, replace_ImageToTensor
+
+from ssod.datasets import build_dataloader
+from ssod.utils import find_latest_checkpoint, get_root_logger, patch_runner
+
+
+def set_random_seed(seed, deterministic=False):
+    """Set random seed.
+
+    Args:
+        seed (int): Seed to be used.
+        deterministic (bool): Whether to set the deterministic option for
+            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+            to True and `torch.backends.cudnn.benchmark` to False.
+            Default: False.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    if deterministic:
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+
+def train_detector(
+    model, dataset, cfg, distributed=False, validate=False, timestamp=None, meta=None
+):
+    logger = get_root_logger(log_level=cfg.log_level)
+
+    # prepare data loaders
+    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+    if "imgs_per_gpu" in cfg.data:
+        logger.warning(
+            '"imgs_per_gpu" is deprecated in MMDet V2.0. '
+            'Please use "samples_per_gpu" instead'
+        )
+        if "samples_per_gpu" in cfg.data:
+            logger.warning(
+                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
+                f"={cfg.data.imgs_per_gpu} is used in this experiment"
+            )
+        else:
+            logger.warning(
+                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f"{cfg.data.imgs_per_gpu} in this experiment"
+            )
+            cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
+
+    data_loaders = [
+        build_dataloader(
+            ds,
+            cfg.data.samples_per_gpu,
+            cfg.data.workers_per_gpu,
+            # cfg.gpus will be ignored if distributed
+            len(cfg.gpu_ids),
+            dist=distributed,
+            seed=cfg.seed,
+            sampler_cfg=cfg.data.get("sampler", {}).get("train", {}),
+        )
+        for ds in dataset
+    ]
+
+    # put model on gpus
+    if distributed:
+        find_unused_parameters = cfg.get("find_unused_parameters", False)
+        # Sets the `find_unused_parameters` parameter in
+        # torch.nn.parallel.DistributedDataParallel
+        model = MMDistributedDataParallel(
+            model.cuda(),
+            device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False,
+            find_unused_parameters=find_unused_parameters,
+        )
+    else:
+        model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+    # build runner
+    optimizer = build_optimizer(model, cfg.optimizer)
+
+    if "runner" not in cfg:
+        cfg.runner = {"type": "EpochBasedRunner", "max_epochs": cfg.total_epochs}
+        warnings.warn(
+            "config is now expected to have a `runner` section, "
+            "please set `runner` in your config.",
+            UserWarning,
+        )
+    else:
+        if "total_epochs" in cfg:
+            assert cfg.total_epochs == cfg.runner.max_epochs
+
+    runner = build_runner(
+        cfg.runner,
+        default_args=dict(
+            model=model,
+            optimizer=optimizer,
+            work_dir=cfg.work_dir,
+            logger=logger,
+            meta=meta,
+        ),
+    )
+
+    # an ugly workaround to make .log and .log.json filenames the same
+    runner.timestamp = timestamp
+
+    # fp16 setting
+    fp16_cfg = cfg.get("fp16", None)
+    if fp16_cfg is not None:
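+        # the configs in this repo enable this branch via
+        # `fp16 = dict(loss_scale="dynamic")`; Fp16OptimizerHook then applies
+        # dynamic loss scaling around the optimizer step.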
optimizer_config = Fp16OptimizerHook(
+            **cfg.optimizer_config, **fp16_cfg, distributed=distributed
+        )
+    elif distributed and "type" not in cfg.optimizer_config:
+        optimizer_config = OptimizerHook(**cfg.optimizer_config)
+    else:
+        optimizer_config = cfg.optimizer_config
+
+    # register hooks
+    runner.register_training_hooks(
+        cfg.lr_config,
+        optimizer_config,
+        cfg.checkpoint_config,
+        cfg.log_config,
+        cfg.get("momentum_config", None),
+    )
+    if distributed:
+        if isinstance(runner, EpochBasedRunner):
+            runner.register_hook(DistSamplerSeedHook())
+
+    # register eval hooks
+    if validate:
+        # Support batch_size > 1 in validation
+        val_samples_per_gpu = cfg.data.val.pop("samples_per_gpu", 1)
+        if val_samples_per_gpu > 1:
+            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
+            cfg.data.val.pipeline = replace_ImageToTensor(cfg.data.val.pipeline)
+        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+        val_dataloader = build_dataloader(
+            val_dataset,
+            samples_per_gpu=val_samples_per_gpu,
+            workers_per_gpu=cfg.data.workers_per_gpu,
+            dist=distributed,
+            shuffle=False,
+        )
+        eval_cfg = cfg.get("evaluation", {})
+        eval_cfg["by_epoch"] = eval_cfg.get(
+            "by_epoch", cfg.runner["type"] != "IterBasedRunner"
+        )
+        if "type" not in eval_cfg:
+            eval_hook = DistEvalHook if distributed else EvalHook
+            eval_hook = eval_hook(val_dataloader, **eval_cfg)
+
+        else:
+            eval_hook = build_from_cfg(
+                eval_cfg, HOOKS, default_args=dict(dataloader=val_dataloader)
+            )
+
+        runner.register_hook(eval_hook)
+
+    # user-defined hooks
+    if cfg.get("custom_hooks", None):
+        custom_hooks = cfg.custom_hooks
+        assert isinstance(
+            custom_hooks, list
+        ), f"custom_hooks expects list type, but got {type(custom_hooks)}"
+        for hook_cfg in cfg.custom_hooks:
+            assert isinstance(hook_cfg, dict), (
+                "Each item in custom_hooks expects dict type, but got "
+                f"{type(hook_cfg)}"
+            )
+            hook_cfg = hook_cfg.copy()
+            priority = hook_cfg.pop("priority", "NORMAL")
+            hook = build_from_cfg(hook_cfg, HOOKS)
+            runner.register_hook(hook, priority=priority)
+
+    runner = patch_runner(runner)
+    resume_from = None
+    if cfg.get("auto_resume", True):
+        resume_from = find_latest_checkpoint(cfg.work_dir)
+    if resume_from is not None:
+        cfg.resume_from = resume_from
+
+    if cfg.resume_from:
+        runner.resume(cfg.resume_from)
+    elif cfg.load_from:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run(data_loaders, cfg.workflow)
diff --git a/ssod/core/__init__.py b/ssod/core/__init__.py
new file mode 100644
index 0000000..f7a82a1
--- /dev/null
+++ b/ssod/core/__init__.py
@@ -0,0 +1 @@
+from .masks import TrimapMasks
diff --git a/ssod/core/masks/__init__.py b/ssod/core/masks/__init__.py
new file mode 100644
index 0000000..c410fc6
--- /dev/null
+++ b/ssod/core/masks/__init__.py
@@ -0,0 +1 @@
+from .structures import TrimapMasks
diff --git a/ssod/core/masks/structures.py b/ssod/core/masks/structures.py
new file mode 100644
index 0000000..045d04a
--- /dev/null
+++ b/ssod/core/masks/structures.py
@@ -0,0 +1,60 @@
+"""
+Designed for pseudo masks.
+In a `TrimapMasks`, part of the mask can be ignored when computing the loss.
+"""
+import numpy as np
+import torch
+from mmcv.ops.roi_align import roi_align
+from mmdet.core import BitmapMasks
+
+
+class TrimapMasks(BitmapMasks):
+    def __init__(self, masks, height, width, ignore_value=255):
+        """
+        Args:
+            ignore_value: flag to ignore in loss computation.
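+                Mask pixels equal to this value (255 by default) mark regions
+                meant to be excluded when the mask loss is computed.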
+ See `mmdet.core.BitmapMasks` for more information + """ + super().__init__(masks, height, width) + self.ignore_value = ignore_value + + def crop_and_resize( + self, bboxes, out_shape, inds, device="cpu", interpolation="bilinear" + ): + """See :func:`BaseInstanceMasks.crop_and_resize`.""" + if len(self.masks) == 0: + empty_masks = np.empty((0, *out_shape), dtype=np.uint8) + return BitmapMasks(empty_masks, *out_shape) + + # convert bboxes to tensor + if isinstance(bboxes, np.ndarray): + bboxes = torch.from_numpy(bboxes).to(device=device) + if isinstance(inds, np.ndarray): + inds = torch.from_numpy(inds).to(device=device) + + num_bbox = bboxes.shape[0] + fake_inds = torch.arange(num_bbox, device=device).to(dtype=bboxes.dtype)[ + :, None + ] + rois = torch.cat([fake_inds, bboxes], dim=1) # Nx5 + rois = rois.to(device=device) + if num_bbox > 0: + gt_masks_th = ( + torch.from_numpy(self.masks) + .to(device) + .index_select(0, inds) + .to(dtype=rois.dtype) + ) + targets = roi_align( + gt_masks_th[:, None, :, :], rois, out_shape, 1.0, 0, "avg", True + ).squeeze(1) + # for a mask: + # value<0.5 -> background, + # 0.5<=value<=1 -> foreground + # value>1 -> ignored area + resized_masks = (targets >= 0.5).float() + resized_masks[targets > 1] = self.ignore_value + resized_masks = resized_masks.cpu().numpy() + else: + resized_masks = [] + return BitmapMasks(resized_masks, *out_shape) diff --git a/ssod/datasets/__init__.py b/ssod/datasets/__init__.py new file mode 100644 index 0000000..726976c --- /dev/null +++ b/ssod/datasets/__init__.py @@ -0,0 +1,15 @@ +from mmdet.datasets import build_dataset + +from .builder import build_dataloader +from .dataset_wrappers import SemiDataset +from .pipelines import * +from .pseudo_coco import PseudoCocoDataset +from .samplers import DistributedGroupSemiBalanceSampler + +__all__ = [ + "PseudoCocoDataset", + "build_dataloader", + "build_dataset", + "SemiDataset", + "DistributedGroupSemiBalanceSampler", +] diff --git a/ssod/datasets/builder.py b/ssod/datasets/builder.py new file mode 100644 index 0000000..91a57ab --- /dev/null +++ b/ssod/datasets/builder.py @@ -0,0 +1,172 @@ +from collections.abc import Mapping, Sequence +from functools import partial + +import torch +from mmcv.parallel import DataContainer +from mmcv.runner import get_dist_info +from mmcv.utils import Registry, build_from_cfg +from mmdet.datasets.builder import worker_init_fn +from mmdet.datasets.samplers import ( + DistributedGroupSampler, + DistributedSampler, + GroupSampler, +) +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.data.dataloader import default_collate + +SAMPLERS = Registry("sampler") + +SAMPLERS.register_module(module=DistributedGroupSampler) +SAMPLERS.register_module(module=DistributedSampler) +SAMPLERS.register_module(module=GroupSampler) + + +def build_sampler(cfg, dist=False, group=False, default_args=None): + if cfg and ("type" in cfg): + sampler_type = cfg.get("type") + else: + sampler_type = default_args.get("type") + if group: + sampler_type = "Group" + sampler_type + if dist: + sampler_type = "Distributed" + sampler_type + + if cfg: + cfg.update(type=sampler_type) + else: + cfg = dict(type=sampler_type) + + return build_from_cfg(cfg, SAMPLERS, default_args) + + +def build_dataloader( + dataset, + samples_per_gpu, + workers_per_gpu, + num_gpus=1, + dist=True, + shuffle=True, + seed=None, + sampler_cfg=None, + **kwargs, +): + rank, world_size = get_dist_info() + default_sampler_cfg = dict(type="Sampler", dataset=dataset) + if 
shuffle:
+        default_sampler_cfg.update(samples_per_gpu=samples_per_gpu)
+    else:
+        default_sampler_cfg.update(shuffle=False)
+    if dist:
+        default_sampler_cfg.update(num_replicas=world_size, rank=rank, seed=seed)
+        sampler = build_sampler(sampler_cfg, dist, shuffle, default_sampler_cfg)
+
+        batch_size = samples_per_gpu
+        num_workers = workers_per_gpu
+    else:
+        # non-distributed: use a (group) sampler when shuffling, passing
+        # `default_sampler_cfg` via the `default_args` parameter
+        sampler = (
+            build_sampler(sampler_cfg, dist, shuffle, default_sampler_cfg)
+            if shuffle
+            else None
+        )
+        batch_size = num_gpus * samples_per_gpu
+        num_workers = num_gpus * workers_per_gpu
+
+    init_fn = (
+        partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
+        if seed is not None
+        else None
+    )
+
+    data_loader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        sampler=sampler,
+        num_workers=num_workers,
+        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu, flatten=True),
+        pin_memory=False,
+        worker_init_fn=init_fn,
+        **kwargs,
+    )
+    return data_loader
+
+
+def collate(batch, samples_per_gpu=1, flatten=False):
+    """Puts each data field into a tensor/DataContainer with outer dimension
+    batch size.
+
+    Extend default_collate to add support for
+    :type:`~mmcv.parallel.DataContainer`. There are 3 cases.
+
+    1. cpu_only = True, e.g., meta data
+    2. cpu_only = False, stack = True, e.g., images tensors
+    3. cpu_only = False, stack = False, e.g., gt bboxes
+    """
+    if not isinstance(batch, Sequence):
+        raise TypeError(f"{type(batch)} is not supported.")
+
+    if isinstance(batch[0], DataContainer):
+        stacked = []
+        if batch[0].cpu_only:
+            for i in range(0, len(batch), samples_per_gpu):
+                stacked.append(
+                    [sample.data for sample in batch[i : i + samples_per_gpu]]
+                )
+            return DataContainer(
+                stacked, batch[0].stack, batch[0].padding_value, cpu_only=True
+            )
+        elif batch[0].stack:
+            for i in range(0, len(batch), samples_per_gpu):
+                assert isinstance(batch[i].data, torch.Tensor)
+
+                if batch[i].pad_dims is not None:
+                    ndim = batch[i].dim()
+                    assert ndim > batch[i].pad_dims
+                    max_shape = [0 for _ in range(batch[i].pad_dims)]
+                    for dim in range(1, batch[i].pad_dims + 1):
+                        max_shape[dim - 1] = batch[i].size(-dim)
+                    for sample in batch[i : i + samples_per_gpu]:
+                        for dim in range(0, ndim - batch[i].pad_dims):
+                            assert batch[i].size(dim) == sample.size(dim)
+                        for dim in range(1, batch[i].pad_dims + 1):
+                            max_shape[dim - 1] = max(
+                                max_shape[dim - 1], sample.size(-dim)
+                            )
+                    padded_samples = []
+                    for sample in batch[i : i + samples_per_gpu]:
+                        pad = [0 for _ in range(batch[i].pad_dims * 2)]
+                        for dim in range(1, batch[i].pad_dims + 1):
+                            pad[2 * dim - 1] = max_shape[dim - 1] - sample.size(-dim)
+                        padded_samples.append(
+                            F.pad(sample.data, pad, value=sample.padding_value)
+                        )
+                    stacked.append(default_collate(padded_samples))
+                elif batch[i].pad_dims is None:
+                    stacked.append(
+                        default_collate(
+                            [sample.data for sample in batch[i : i + samples_per_gpu]]
+                        )
+                    )
+                else:
+                    raise ValueError("pad_dims should be either None or integers (1-3)")
+
+        else:
+            for i in range(0, len(batch), samples_per_gpu):
+                stacked.append(
+                    [sample.data for sample in batch[i : i + samples_per_gpu]]
+                )
+        return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
+    elif any([isinstance(b, Sequence) for b in batch]):
+        if flatten:
+            flattened = []
+            for b in batch:
+                if isinstance(b, Sequence):
+                    flattened.extend(b)
+                else:
+                    flattened.extend([b])
+            return collate(flattened, len(flattened))
+        else:
+            transposed = zip(*batch)
+            return [collate(samples, samples_per_gpu) for samples in transposed]
+    elif isinstance(batch[0], Mapping):
+        return {
+            key:
collate([d[key] for d in batch], samples_per_gpu) for key in batch[0]
+        }
+    else:
+        return default_collate(batch)
diff --git a/ssod/datasets/dataset_wrappers.py b/ssod/datasets/dataset_wrappers.py
new file mode 100644
index 0000000..6fd3bcc
--- /dev/null
+++ b/ssod/datasets/dataset_wrappers.py
@@ -0,0 +1,17 @@
+from mmdet.datasets import DATASETS, ConcatDataset, build_dataset
+
+
+@DATASETS.register_module()
+class SemiDataset(ConcatDataset):
+    """Wrapper that pairs a labeled (sup) and an unlabeled (unsup) dataset
+    for semi-supervised object detection."""
+
+    def __init__(self, sup: dict, unsup: dict, **kwargs):
+        super().__init__([build_dataset(sup), build_dataset(unsup)], **kwargs)
+
+    @property
+    def sup(self):
+        return self.datasets[0]
+
+    @property
+    def unsup(self):
+        return self.datasets[1]
diff --git a/ssod/datasets/pipelines/__init__.py b/ssod/datasets/pipelines/__init__.py
new file mode 100644
index 0000000..674ec18
--- /dev/null
+++ b/ssod/datasets/pipelines/__init__.py
@@ -0,0 +1,2 @@
+from .formating import *
+from .rand_aug import *
diff --git a/ssod/datasets/pipelines/formating.py b/ssod/datasets/pipelines/formating.py
new file mode 100644
index 0000000..4f1bd3c
--- /dev/null
+++ b/ssod/datasets/pipelines/formating.py
@@ -0,0 +1,78 @@
+import numpy as np
+from mmdet.datasets import PIPELINES
+from mmdet.datasets.pipelines.formating import Collect
+
+from ssod.core import TrimapMasks
+
+
+@PIPELINES.register_module()
+class ExtraAttrs(object):
+    def __init__(self, **attrs):
+        self.attrs = attrs
+
+    def __call__(self, results):
+        for k, v in self.attrs.items():
+            assert k not in results
+            results[k] = v
+        return results
+
+
+@PIPELINES.register_module()
+class ExtraCollect(Collect):
+    def __init__(self, *args, extra_meta_keys=[], **kwargs):
+        super().__init__(*args, **kwargs)
+        self.meta_keys = self.meta_keys + tuple(extra_meta_keys)
+
+
+@PIPELINES.register_module()
+class PseudoSamples(object):
+    def __init__(
+        self, with_bbox=False, with_mask=False, with_seg=False, fill_value=255
+    ):
+        """
+        Replace the gt labels in the original data with fake labels, or add
+        extra fake labels for unlabeled data. This removes the effect of
+        labeled data and keeps the sample's elements aligned with other samples.
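+
+        For example, with `with_bbox=True` an unlabeled sample receives an
+        empty `gt_bboxes` of shape (0, 4) and empty `gt_labels`, so it passes
+        through the same formatting transforms as labeled samples.
+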
+        Args:
+            with_bbox: whether to add empty fake `gt_bboxes` / `gt_labels`.
+            with_mask: whether to add fake `gt_masks` filled with `fill_value`.
+            with_seg: whether to add a fake `gt_semantic_seg` map.
+            fill_value: value used to fill the fake masks / segmentation map.
+        """
+        self.with_bbox = with_bbox
+        self.with_mask = with_mask
+        self.with_seg = with_seg
+        self.fill_value = fill_value
+
+    def __call__(self, results):
+        if self.with_bbox:
+            results["gt_bboxes"] = np.zeros((0, 4))
+            results["gt_labels"] = np.zeros((0,))
+            if "bbox_fields" not in results:
+                results["bbox_fields"] = []
+            if "gt_bboxes" not in results["bbox_fields"]:
+                results["bbox_fields"].append("gt_bboxes")
+        if self.with_mask:
+            num_inst = len(results["gt_bboxes"])
+            h, w = results["img"].shape[:2]
+            results["gt_masks"] = TrimapMasks(
+                [
+                    self.fill_value * np.ones((h, w), dtype=np.uint8)
+                    for _ in range(num_inst)
+                ],
+                h,
+                w,
+            )
+
+            if "mask_fields" not in results:
+                results["mask_fields"] = []
+            if "gt_masks" not in results["mask_fields"]:
+                results["mask_fields"].append("gt_masks")
+        if self.with_seg:
+            results["gt_semantic_seg"] = self.fill_value * np.ones(
+                results["img"].shape[:2], dtype=np.uint8
+            )
+            if "seg_fields" not in results:
+                results["seg_fields"] = []
+            if "gt_semantic_seg" not in results["seg_fields"]:
+                results["seg_fields"].append("gt_semantic_seg")
+        return results
diff --git a/ssod/datasets/pipelines/geo_utils.py b/ssod/datasets/pipelines/geo_utils.py
new file mode 100644
index 0000000..c495027
--- /dev/null
+++ b/ssod/datasets/pipelines/geo_utils.py
@@ -0,0 +1,94 @@
+"""
+Record the geometric transformation information used in the augmentation in a
+transformation matrix.
+"""
+import numpy as np
+
+
+class GeometricTransformationBase(object):
+    @classmethod
+    def inverse(cls, results):
+        # compute the inverse of the accumulated 3x3 transform
+        return np.linalg.inv(results["transform_matrix"])
+
+    @classmethod
+    def apply(cls, results, operator, **kwargs):
+        trans_matrix = getattr(cls, f"_get_{operator}_matrix")(**kwargs)
+        if "transform_matrix" not in results:
+            results["transform_matrix"] = trans_matrix
+        else:
+            base_transformation = results["transform_matrix"]
+            results["transform_matrix"] = np.dot(trans_matrix, base_transformation)
+
+    @classmethod
+    def apply_cv2_matrix(cls, results, cv2_matrix):
+        if cv2_matrix.shape[0] == 2:
+            mat = np.concatenate(
+                [cv2_matrix, np.array([0, 0, 1]).reshape((1, 3))], axis=0
+            )
+        else:
+            mat = cv2_matrix
+        base_transformation = results["transform_matrix"]
+        results["transform_matrix"] = np.dot(mat, base_transformation)
+        return results
+
+    @classmethod
+    def _get_rotate_matrix(cls, degree=None, cv2_rotation_matrix=None, inverse=False):
+        # TODO: this is rotated by zero point
+        if degree is None and cv2_rotation_matrix is None:
+            raise ValueError(
+                "At least one of degree or rotation matrix should be provided"
+            )
+        if degree:
+            if inverse:
+                degree = -degree
+            rad = degree * np.pi / 180
+            sin_a = np.sin(rad)
+            cos_a = np.cos(rad)
+            return np.array([[cos_a, sin_a, 0], [-sin_a, cos_a, 0], [0, 0, 1]])  # 3x3
+        else:
+            mat = np.concatenate(
+                [cv2_rotation_matrix, np.array([0, 0, 1]).reshape((1, 3))], axis=0
+            )
+            if inverse:
+                mat = mat * np.array([[1, -1, -1], [-1, 1, -1], [1, 1, 1]])
+            return mat
+
+    @classmethod
+    def _get_shift_matrix(cls, dx=0, dy=0, inverse=False):
+        if inverse:
+            dx = -dx
+            dy = -dy
+        return np.array([[1, 0, dx], [0, 1, dy], [0, 0, 1]])
+
+    @classmethod
+    def _get_shear_matrix(
+        cls, degree=None, magnitude=None, direction="horizontal", inverse=False
+    ):
+        if magnitude is None:
+            assert degree is not None
+            rad = degree * np.pi / 180
+            magnitude = np.tan(rad)
+
+        if inverse:
+            magnitude = -magnitude
+        if direction == "horizontal":
+            shear_matrix = np.float32([[1,
magnitude, 0], [0, 1, 0], [0, 0, 1]])
+        else:
+            shear_matrix = np.float32([[1, 0, 0], [magnitude, 1, 0], [0, 0, 1]])
+        return shear_matrix
+
+    @classmethod
+    def _get_flip_matrix(cls, shape, direction="horizontal", inverse=False):
+        # a flip is its own inverse, so `inverse` needs no special handling
+        h, w = shape
+        if direction == "horizontal":
+            flip_matrix = np.float32([[-1, 0, w], [0, 1, 0], [0, 0, 1]])
+        else:
+            # mirror y: y' = h - y
+            flip_matrix = np.float32([[1, 0, 0], [0, -1, h], [0, 0, 1]])
+        return flip_matrix
+
+    @classmethod
+    def _get_scale_matrix(cls, sx, sy, inverse=False):
+        if inverse:
+            sx = 1 / sx
+            sy = 1 / sy
+        return np.float32([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
diff --git a/ssod/datasets/pipelines/rand_aug.py b/ssod/datasets/pipelines/rand_aug.py
new file mode 100644
index 0000000..b0bdec3
--- /dev/null
+++ b/ssod/datasets/pipelines/rand_aug.py
@@ -0,0 +1,965 @@
+"""
+Modified from https://github.com/google-research/ssl_detection/blob/master/detection/utils/augmentation.py.
+"""
+import copy
+
+import cv2
+import mmcv
+import numpy as np
+from PIL import Image, ImageEnhance, ImageOps
+from mmcv.image.colorspace import bgr2rgb, rgb2bgr
+from mmdet.core.mask import BitmapMasks, PolygonMasks
+from mmdet.datasets import PIPELINES
+from mmdet.datasets.pipelines import Compose as BaseCompose
+from mmdet.datasets.pipelines import transforms
+
+from .geo_utils import GeometricTransformationBase as GTrans
+
+PARAMETER_MAX = 10
+
+
+def int_parameter(level, maxval, max_level=None):
+    if max_level is None:
+        max_level = PARAMETER_MAX
+    return int(level * maxval / max_level)
+
+
+def float_parameter(level, maxval, max_level=None):
+    if max_level is None:
+        max_level = PARAMETER_MAX
+    return float(level) * maxval / max_level
+
+
+class RandAug(object):
+    """refer to https://github.com/google-research/ssl_detection/blob/00d52272f
+    61b56eade8d5ace18213cba6c74f6d8/detection/utils/augmentation.py#L240."""
+
+    def __init__(
+        self,
+        prob: float = 1.0,
+        magnitude: int = 10,
+        random_magnitude: bool = True,
+        record: bool = False,
+        magnitude_limit: int = 10,
+    ):
+        assert 0 <= prob <= 1, f"probability should be in [0,1] but got {prob}"
+        assert (
+            magnitude <= PARAMETER_MAX
+        ), f"magnitude should be no larger than max value {PARAMETER_MAX} but got {magnitude}"
+
+        self.prob = prob
+        self.magnitude = magnitude
+        self.magnitude_limit = magnitude_limit
+        self.random_magnitude = random_magnitude
+        self.record = record
+        self.buffer = None
+
+    def __call__(self, results):
+        if np.random.random() < self.prob:
+            magnitude = self.magnitude
+            if self.random_magnitude:
+                magnitude = np.random.randint(1, magnitude)
+            if self.record:
+                if "aug_info" not in results:
+                    results["aug_info"] = []
+                results["aug_info"].append(self.get_aug_info(magnitude=magnitude))
+            results = self.apply(results, magnitude)
+        # clear buffer
+        return results
+
+    def apply(self, results, magnitude: int = None):
+        raise NotImplementedError()
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(prob={self.prob},magnitude={self.magnitude},max_magnitude={self.magnitude_limit},random_magnitude={self.random_magnitude})"
+
+    def get_aug_info(self, **kwargs):
+        aug_info = dict(type=self.__class__.__name__)
+        aug_info.update(
+            dict(
+                prob=1.0,
+                random_magnitude=False,
+                record=False,
+                magnitude=self.magnitude,
+            )
+        )
+        aug_info.update(kwargs)
+        return aug_info
+
+    def enable_record(self, mode: bool = True):
+        self.record = mode
+
+
+@PIPELINES.register_module()
+class Identity(RandAug):
+    def apply(self, results, magnitude: int = None):
+        return results
+
+
+@PIPELINES.register_module()
+class
AutoContrast(RandAug): + def apply(self, results, magnitude=None): + for key in results.get("img_fields", ["img"]): + img = bgr2rgb(results[key]) + results[key] = rgb2bgr( + np.asarray(ImageOps.autocontrast(Image.fromarray(img)), dtype=img.dtype) + ) + return results + + +@PIPELINES.register_module() +class RandEqualize(RandAug): + def apply(self, results, magnitude=None): + for key in results.get("img_fields", ["img"]): + img = bgr2rgb(results[key]) + results[key] = rgb2bgr( + np.asarray(ImageOps.equalize(Image.fromarray(img)), dtype=img.dtype) + ) + return results + + +@PIPELINES.register_module() +class RandSolarize(RandAug): + def apply(self, results, magnitude=None): + for key in results.get("img_fields", ["img"]): + img = results[key] + results[key] = mmcv.solarize( + img, min(int_parameter(magnitude, 256, self.magnitude_limit), 255) + ) + return results + + +def _enhancer_impl(enhancer): + """Sets level to be between 0.1 and 1.8 for ImageEnhance transforms of + PIL.""" + + def impl(pil_img, level, max_level=None): + v = float_parameter(level, 1.8, max_level) + 0.1 # going to 0 just destroys it + return enhancer(pil_img).enhance(v) + + return impl + + +class RandEnhance(RandAug): + op = None + + def apply(self, results, magnitude=None): + for key in results.get("img_fields", ["img"]): + img = bgr2rgb(results[key]) + + results[key] = rgb2bgr( + np.asarray( + _enhancer_impl(self.op)( + Image.fromarray(img), magnitude, self.magnitude_limit + ), + dtype=img.dtype, + ) + ) + return results + + +@PIPELINES.register_module() +class RandColor(RandEnhance): + op = ImageEnhance.Color + + +@PIPELINES.register_module() +class RandContrast(RandEnhance): + op = ImageEnhance.Contrast + + +@PIPELINES.register_module() +class RandBrightness(RandEnhance): + op = ImageEnhance.Brightness + + +@PIPELINES.register_module() +class RandSharpness(RandEnhance): + op = ImageEnhance.Sharpness + + +@PIPELINES.register_module() +class RandPosterize(RandAug): + def apply(self, results, magnitude=None): + for key in results.get("img_fields", ["img"]): + img = bgr2rgb(results[key]) + magnitude = int_parameter(magnitude, 4, self.magnitude_limit) + results[key] = rgb2bgr( + np.asarray( + ImageOps.posterize(Image.fromarray(img), 4 - magnitude), + dtype=img.dtype, + ) + ) + return results + + +@PIPELINES.register_module() +class Sequential(BaseCompose): + def __init__(self, transforms, record: bool = False): + super().__init__(transforms) + self.record = record + self.enable_record(record) + + def enable_record(self, mode: bool = True): + # enable children to record + self.record = mode + for transform in self.transforms: + transform.enable_record(mode) + + +@PIPELINES.register_module() +class OneOf(Sequential): + def __init__(self, transforms, record: bool = False): + self.transforms = [] + for trans in transforms: + if isinstance(trans, list): + self.transforms.append(Sequential(trans)) + else: + assert isinstance(trans, dict) + self.transforms.append(Sequential([trans])) + self.enable_record(record) + + def __call__(self, results): + transform = np.random.choice(self.transforms) + return transform(results) + + +@PIPELINES.register_module() +class ShuffledSequential(Sequential): + def __call__(self, data): + order = np.random.permutation(len(self.transforms)) + for idx in order: + t = self.transforms[idx] + data = t(data) + if data is None: + return None + return data + + +""" +Geometric Augmentation. 
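+
+Each op below can record the affine transform it applies into
+results["transform_matrix"] (via `GTrans`), which the unsupervised pipelines
+collect as a meta key so that boxes can later be mapped between differently
+augmented views of the same image.
+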
Modified from thirdparty/mmdetection/mmdet/datasets/pipelines/auto_augment.py
+"""
+
+
+def bbox2fields():
+    """The key correspondence from bboxes to labels, masks and
+    segmentations."""
+    bbox2label = {"gt_bboxes": "gt_labels", "gt_bboxes_ignore": "gt_labels_ignore"}
+    bbox2mask = {"gt_bboxes": "gt_masks", "gt_bboxes_ignore": "gt_masks_ignore"}
+    bbox2seg = {
+        "gt_bboxes": "gt_semantic_seg",
+    }
+    return bbox2label, bbox2mask, bbox2seg
+
+
+class GeometricAugmentation(object):
+    def __init__(
+        self,
+        img_fill_val=125,
+        seg_ignore_label=255,
+        min_size=0,
+        prob: float = 1.0,
+        random_magnitude: bool = True,
+        record: bool = False,
+    ):
+        if isinstance(img_fill_val, (float, int)):
+            img_fill_val = tuple([float(img_fill_val)] * 3)
+        elif isinstance(img_fill_val, tuple):
+            assert len(img_fill_val) == 3, "img_fill_val as tuple must have 3 elements."
+            img_fill_val = tuple([float(val) for val in img_fill_val])
+        assert np.all(
+            [0 <= val <= 255 for val in img_fill_val]
+        ), "all elements of img_fill_val should be within range [0, 255]."
+        self.img_fill_val = img_fill_val
+        self.seg_ignore_label = seg_ignore_label
+        self.min_size = min_size
+        self.prob = prob
+        self.random_magnitude = random_magnitude
+        self.record = record
+
+    def __call__(self, results):
+        if np.random.random() < self.prob:
+            magnitude: dict = self.get_magnitude(results)
+            if self.record:
+                if "aug_info" not in results:
+                    results["aug_info"] = []
+                results["aug_info"].append(self.get_aug_info(**magnitude))
+            results = self.apply(results, **magnitude)
+            self._filter_invalid(results, min_size=self.min_size)
+        return results
+
+    def get_magnitude(self, results) -> dict:
+        raise NotImplementedError()
+
+    def apply(self, results, **kwargs):
+        raise NotImplementedError()
+
+    def enable_record(self, mode: bool = True):
+        self.record = mode
+
+    def get_aug_info(self, **kwargs):
+        aug_info = dict(type=self.__class__.__name__)
+        aug_info.update(
+            dict(
+                # make op deterministic
+                prob=1.0,
+                random_magnitude=False,
+                record=False,
+                img_fill_val=self.img_fill_val,
+                seg_ignore_label=self.seg_ignore_label,
+                min_size=self.min_size,
+            )
+        )
+        aug_info.update(kwargs)
+        return aug_info
+
+    def _filter_invalid(self, results, min_size=0):
+        """Filter bboxes and masks too small or translated out of image."""
+        if min_size is None:
+            return results
+        bbox2label, bbox2mask, _ = bbox2fields()
+        for key in results.get("bbox_fields", []):
+            bbox_w = results[key][:, 2] - results[key][:, 0]
+            bbox_h = results[key][:, 3] - results[key][:, 1]
+            valid_inds = (bbox_w > min_size) & (bbox_h > min_size)
+            valid_inds = np.nonzero(valid_inds)[0]
+            results[key] = results[key][valid_inds]
+            # label fields, e.g. gt_labels and gt_labels_ignore
+            label_key = bbox2label.get(key)
+            if label_key in results:
+                results[label_key] = results[label_key][valid_inds]
+            # mask fields, e.g.
gt_masks and gt_masks_ignore
+            mask_key = bbox2mask.get(key)
+            if mask_key in results:
+                results[mask_key] = results[mask_key][valid_inds]
+        return results
+
+    def __repr__(self):
+        return f"""{self.__class__.__name__}(
+        img_fill_val={self.img_fill_val},
+        seg_ignore_label={self.seg_ignore_label},
+        min_size={self.min_size},
+        prob={self.prob},
+        random_magnitude={self.random_magnitude},
+        )"""
+
+
+@PIPELINES.register_module()
+class RandTranslate(GeometricAugmentation):
+    def __init__(self, x=None, y=None, **kwargs):
+        super().__init__(**kwargs)
+        self.x = x
+        self.y = y
+        if self.x is None and self.y is None:
+            self.prob = 0.0
+
+    def get_magnitude(self, results):
+        magnitude = {}
+        if self.random_magnitude:
+            if isinstance(self.x, (list, tuple)):
+                assert len(self.x) == 2
+                x = np.random.random() * (self.x[1] - self.x[0]) + self.x[0]
+                magnitude["x"] = x
+            if isinstance(self.y, (list, tuple)):
+                assert len(self.y) == 2
+                y = np.random.random() * (self.y[1] - self.y[0]) + self.y[0]
+                magnitude["y"] = y
+        else:
+            if self.x is not None:
+                assert isinstance(self.x, (int, float))
+                magnitude["x"] = self.x
+            if self.y is not None:
+                assert isinstance(self.y, (int, float))
+                magnitude["y"] = self.y
+        return magnitude
+
+    def apply(self, results, x=None, y=None):
+        # convert ratio to pixel offset
+        h, w, c = results["img_shape"]
+        if x is not None:
+            x = w * x
+        if y is not None:
+            y = h * y
+        if x is not None:
+            # translate horizontally
+            self._translate(results, x)
+        if y is not None:
+            # translate vertically
+            self._translate(results, y, direction="vertical")
+        return results
+
+    def _translate(self, results, offset, direction="horizontal"):
+        if self.record:
+            GTrans.apply(
+                results,
+                "shift",
+                dx=offset if direction == "horizontal" else 0,
+                dy=offset if direction == "vertical" else 0,
+            )
+        self._translate_img(results, offset, direction=direction)
+        self._translate_bboxes(results, offset, direction=direction)
+        # fill_val defaults to 0 for BitmapMasks and None for PolygonMasks.
+        self._translate_masks(results, offset, direction=direction)
+        self._translate_seg(
+            results, offset, fill_val=self.seg_ignore_label, direction=direction
+        )
+
+    def _translate_img(self, results, offset, direction="horizontal"):
+        for key in results.get("img_fields", ["img"]):
+            img = results[key].copy()
+            results[key] = mmcv.imtranslate(
+                img, offset, direction, self.img_fill_val
+            ).astype(img.dtype)
+
+    def _translate_bboxes(self, results, offset, direction="horizontal"):
+        """Shift bboxes horizontally or vertically, according to offset."""
+        h, w, c = results["img_shape"]
+        for key in results.get("bbox_fields", []):
+            min_x, min_y, max_x, max_y = np.split(
+                results[key], results[key].shape[-1], axis=-1
+            )
+            if direction == "horizontal":
+                min_x = np.maximum(0, min_x + offset)
+                max_x = np.minimum(w, max_x + offset)
+            elif direction == "vertical":
+                min_y = np.maximum(0, min_y + offset)
+                max_y = np.minimum(h, max_y + offset)
+
+            # boxes translated outside the image will be filtered out, along
+            # with the corresponding masks, by invoking ``_filter_invalid``.
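+            # Worked example (illustrative numbers): with w=100 and
+            # offset=+30, a box (80, 10, 95, 20) becomes min_x = max(0, 110)
+            # = 110 and max_x = min(100, 125) = 100, i.e. a degenerate box
+            # whose width is negative, so ``_filter_invalid`` drops it.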
+ results[key] = np.concatenate([min_x, min_y, max_x, max_y], axis=-1) + + def _translate_masks(self, results, offset, direction="horizontal", fill_val=0): + """Translate masks horizontally or vertically.""" + h, w, c = results["img_shape"] + for key in results.get("mask_fields", []): + masks = results[key] + results[key] = masks.translate((h, w), offset, direction, fill_val) + + def _translate_seg(self, results, offset, direction="horizontal", fill_val=255): + """Translate segmentation maps horizontally or vertically.""" + for key in results.get("seg_fields", []): + seg = results[key].copy() + results[key] = mmcv.imtranslate(seg, offset, direction, fill_val).astype( + seg.dtype + ) + + def __repr__(self): + repr_str = super().__repr__() + return ("\n").join( + repr_str.split("\n")[:-1] + + [f"x={self.x}", f"y={self.y}"] + + repr_str.split("\n")[-1:] + ) + + +@PIPELINES.register_module() +class RandRotate(GeometricAugmentation): + def __init__(self, angle=None, center=None, scale=1, **kwargs): + super().__init__(**kwargs) + self.angle = angle + self.center = center + self.scale = scale + if self.angle is None: + self.prob = 0.0 + + def get_magnitude(self, results): + magnitude = {} + if self.random_magnitude: + if isinstance(self.angle, (list, tuple)): + assert len(self.angle) == 2 + angle = ( + np.random.random() * (self.angle[1] - self.angle[0]) + self.angle[0] + ) + magnitude["angle"] = angle + else: + if self.angle is not None: + assert isinstance(self.angle, (int, float)) + magnitude["angle"] = self.angle + + return magnitude + + def apply(self, results, angle: float = None): + h, w = results["img"].shape[:2] + center = self.center + if center is None: + center = ((w - 1) * 0.5, (h - 1) * 0.5) + self._rotate_img(results, angle, center, self.scale) + rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale) + if self.record: + GTrans.apply(results, "rotate", cv2_rotation_matrix=rotate_matrix) + self._rotate_bboxes(results, rotate_matrix) + self._rotate_masks(results, angle, center, self.scale, fill_val=0) + self._rotate_seg( + results, angle, center, self.scale, fill_val=self.seg_ignore_label + ) + return results + + def _rotate_img(self, results, angle, center=None, scale=1.0): + """Rotate the image. + + Args: + results (dict): Result dict from loading pipeline. + angle (float): Rotation angle in degrees, positive values + mean clockwise rotation. Same in ``mmcv.imrotate``. + center (tuple[float], optional): Center point (w, h) of the + rotation. Same in ``mmcv.imrotate``. + scale (int | float): Isotropic scale factor. Same in + ``mmcv.imrotate``. 
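+
+        Note:
+            When ``center`` is None, :meth:`apply` falls back to the image
+            midpoint ``((w - 1) * 0.5, (h - 1) * 0.5)``, e.g. (319.5, 239.5)
+            for a 640x480 image.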
+ """ + for key in results.get("img_fields", ["img"]): + img = results[key].copy() + img_rotated = mmcv.imrotate( + img, angle, center, scale, border_value=self.img_fill_val + ) + results[key] = img_rotated.astype(img.dtype) + + def _rotate_bboxes(self, results, rotate_matrix): + """Rotate the bboxes.""" + h, w, c = results["img_shape"] + for key in results.get("bbox_fields", []): + min_x, min_y, max_x, max_y = np.split( + results[key], results[key].shape[-1], axis=-1 + ) + coordinates = np.stack( + [[min_x, min_y], [max_x, min_y], [min_x, max_y], [max_x, max_y]] + ) # [4, 2, nb_bbox, 1] + # pad 1 to convert from format [x, y] to homogeneous + # coordinates format [x, y, 1] + coordinates = np.concatenate( + ( + coordinates, + np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype), + ), + axis=1, + ) # [4, 3, nb_bbox, 1] + coordinates = coordinates.transpose((2, 0, 1, 3)) # [nb_bbox, 4, 3, 1] + rotated_coords = np.matmul(rotate_matrix, coordinates) # [nb_bbox, 4, 2, 1] + rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2] + min_x, min_y = ( + np.min(rotated_coords[:, :, 0], axis=1), + np.min(rotated_coords[:, :, 1], axis=1), + ) + max_x, max_y = ( + np.max(rotated_coords[:, :, 0], axis=1), + np.max(rotated_coords[:, :, 1], axis=1), + ) + min_x, min_y = ( + np.clip(min_x, a_min=0, a_max=w), + np.clip(min_y, a_min=0, a_max=h), + ) + max_x, max_y = ( + np.clip(max_x, a_min=min_x, a_max=w), + np.clip(max_y, a_min=min_y, a_max=h), + ) + results[key] = np.stack([min_x, min_y, max_x, max_y], axis=-1).astype( + results[key].dtype + ) + + def _rotate_masks(self, results, angle, center=None, scale=1.0, fill_val=0): + """Rotate the masks.""" + h, w, c = results["img_shape"] + for key in results.get("mask_fields", []): + masks = results[key] + results[key] = masks.rotate((h, w), angle, center, scale, fill_val) + + def _rotate_seg(self, results, angle, center=None, scale=1.0, fill_val=255): + """Rotate the segmentation map.""" + for key in results.get("seg_fields", []): + seg = results[key].copy() + results[key] = mmcv.imrotate( + seg, angle, center, scale, border_value=fill_val + ).astype(seg.dtype) + + def __repr__(self): + repr_str = super().__repr__() + return ("\n").join( + repr_str.split("\n")[:-1] + + [f"angle={self.angle}", f"center={self.center}", f"scale={self.scale}"] + + repr_str.split("\n")[-1:] + ) + + +@PIPELINES.register_module() +class RandShear(GeometricAugmentation): + def __init__(self, x=None, y=None, interpolation="bilinear", **kwargs): + super().__init__(**kwargs) + self.x = x + self.y = y + self.interpolation = interpolation + if self.x is None and self.y is None: + self.prob = 0.0 + + def get_magnitude(self, results): + magnitude = {} + if self.random_magnitude: + if isinstance(self.x, (list, tuple)): + assert len(self.x) == 2 + x = np.random.random() * (self.x[1] - self.x[0]) + self.x[0] + magnitude["x"] = x + if isinstance(self.y, (list, tuple)): + assert len(self.y) == 2 + y = np.random.random() * (self.y[1] - self.y[0]) + self.y[0] + magnitude["y"] = y + else: + if self.x is not None: + assert isinstance(self.x, (int, float)) + magnitude["x"] = self.x + if self.y is not None: + assert isinstance(self.y, (int, float)) + magnitude["y"] = self.y + return magnitude + + def apply(self, results, x=None, y=None): + if x is not None: + # translate horizontally + self._shear(results, np.tanh(-x * np.pi / 180)) + if y is not None: + # translate veritically + self._shear(results, np.tanh(y * np.pi / 180), direction="vertical") + return results + + def _shear(self, results, 
magnitude, direction="horizontal"): + if self.record: + GTrans.apply(results, "shear", magnitude=magnitude, direction=direction) + self._shear_img(results, magnitude, direction, interpolation=self.interpolation) + self._shear_bboxes(results, magnitude, direction=direction) + # fill_val defaultly 0 for BitmapMasks and None for PolygonMasks. + self._shear_masks( + results, magnitude, direction=direction, interpolation=self.interpolation + ) + self._shear_seg( + results, + magnitude, + direction=direction, + interpolation=self.interpolation, + fill_val=self.seg_ignore_label, + ) + + def _shear_img( + self, results, magnitude, direction="horizontal", interpolation="bilinear" + ): + """Shear the image. + + Args: + results (dict): Result dict from loading pipeline. + magnitude (int | float): The magnitude used for shear. + direction (str): The direction for shear, either "horizontal" + or "vertical". + interpolation (str): Same as in :func:`mmcv.imshear`. + """ + for key in results.get("img_fields", ["img"]): + img = results[key] + img_sheared = mmcv.imshear( + img, + magnitude, + direction, + border_value=self.img_fill_val, + interpolation=interpolation, + ) + results[key] = img_sheared.astype(img.dtype) + + def _shear_bboxes(self, results, magnitude, direction="horizontal"): + """Shear the bboxes.""" + h, w, c = results["img_shape"] + if direction == "horizontal": + shear_matrix = np.stack([[1, magnitude], [0, 1]]).astype( + np.float32 + ) # [2, 2] + else: + shear_matrix = np.stack([[1, 0], [magnitude, 1]]).astype(np.float32) + for key in results.get("bbox_fields", []): + min_x, min_y, max_x, max_y = np.split( + results[key], results[key].shape[-1], axis=-1 + ) + coordinates = np.stack( + [[min_x, min_y], [max_x, min_y], [min_x, max_y], [max_x, max_y]] + ) # [4, 2, nb_box, 1] + coordinates = ( + coordinates[..., 0].transpose((2, 1, 0)).astype(np.float32) + ) # [nb_box, 2, 4] + new_coords = np.matmul( + shear_matrix[None, :, :], coordinates + ) # [nb_box, 2, 4] + min_x = np.min(new_coords[:, 0, :], axis=-1) + min_y = np.min(new_coords[:, 1, :], axis=-1) + max_x = np.max(new_coords[:, 0, :], axis=-1) + max_y = np.max(new_coords[:, 1, :], axis=-1) + min_x = np.clip(min_x, a_min=0, a_max=w) + min_y = np.clip(min_y, a_min=0, a_max=h) + max_x = np.clip(max_x, a_min=min_x, a_max=w) + max_y = np.clip(max_y, a_min=min_y, a_max=h) + results[key] = np.stack([min_x, min_y, max_x, max_y], axis=-1).astype( + results[key].dtype + ) + + def _shear_masks( + self, + results, + magnitude, + direction="horizontal", + fill_val=0, + interpolation="bilinear", + ): + """Shear the masks.""" + h, w, c = results["img_shape"] + for key in results.get("mask_fields", []): + masks = results[key] + results[key] = masks.shear( + (h, w), + magnitude, + direction, + border_value=fill_val, + interpolation=interpolation, + ) + + def _shear_seg( + self, + results, + magnitude, + direction="horizontal", + fill_val=255, + interpolation="bilinear", + ): + """Shear the segmentation maps.""" + for key in results.get("seg_fields", []): + seg = results[key] + results[key] = mmcv.imshear( + seg, + magnitude, + direction, + border_value=fill_val, + interpolation=interpolation, + ).astype(seg.dtype) + + def __repr__(self): + repr_str = super().__repr__() + return ("\n").join( + repr_str.split("\n")[:-1] + + [f"x_magnitude={self.x}", f"y_magnitude={self.y}"] + + repr_str.split("\n")[-1:] + ) + + +@PIPELINES.register_module() +class RandErase(GeometricAugmentation): + def __init__( + self, + n_iterations=None, + size=None, + squared: bool = 
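+        # ``size`` as a single float f gives a fixed (f*h, f*w) patch, while
+        # a (low, high) tuple draws the side ratio uniformly from that range;
+        # with ``squared`` the same ratio is shared by height and width.
+        # E.g. (illustrative numbers) size=(0.02, 0.2) on a 400x600 image
+        # yields patches from about 8x12 up to 80x120 pixels.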
True,
+        patches=None,
+        **kwargs,
+    ):
+        kwargs.update(min_size=None)
+        super().__init__(**kwargs)
+        self.n_iterations = n_iterations
+        self.size = size
+        self.squared = squared
+        self.patches = patches
+
+    def get_magnitude(self, results):
+        magnitude = {}
+        if self.random_magnitude:
+            n_iterations = self._get_erase_cycle()
+            patches = []
+            h, w, c = results["img_shape"]
+            for i in range(n_iterations):
+                # randomly sample a patch size within the image
+                ph, pw = self._get_patch_size(h, w)
+                # randomly sample the patch's top-left corner within the image
+                px, py = np.random.randint(0, w - pw), np.random.randint(0, h - ph)
+                patches.append([px, py, px + pw, py + ph])
+            magnitude["patches"] = patches
+        else:
+            assert self.patches is not None
+            magnitude["patches"] = self.patches
+
+        return magnitude
+
+    def _get_erase_cycle(self):
+        if isinstance(self.n_iterations, int):
+            n_iterations = self.n_iterations
+        else:
+            assert (
+                isinstance(self.n_iterations, (tuple, list))
+                and len(self.n_iterations) == 2
+            )
+            n_iterations = np.random.randint(*self.n_iterations)
+        return n_iterations
+
+    def _get_patch_size(self, h, w):
+        if isinstance(self.size, float):
+            assert 0 < self.size < 1
+            return int(self.size * h), int(self.size * w)
+        else:
+            assert isinstance(self.size, (tuple, list))
+            assert len(self.size) == 2
+            assert 0 <= self.size[0] < 1 and 0 <= self.size[1] < 1
+            w_ratio = np.random.random() * (self.size[1] - self.size[0]) + self.size[0]
+            h_ratio = w_ratio
+
+            if not self.squared:
+                h_ratio = (
+                    np.random.random() * (self.size[1] - self.size[0]) + self.size[0]
+                )
+            return int(h_ratio * h), int(w_ratio * w)
+
+    def apply(self, results, patches: list):
+        for patch in patches:
+            self._erase_image(results, patch, fill_val=self.img_fill_val)
+            self._erase_mask(results, patch)
+            self._erase_seg(results, patch, fill_val=self.seg_ignore_label)
+        return results
+
+    def _erase_image(self, results, patch, fill_val=128):
+        for key in results.get("img_fields", ["img"]):
+            tmp = results[key].copy()
+            x1, y1, x2, y2 = patch
+            tmp[y1:y2, x1:x2, :] = fill_val
+            results[key] = tmp
+
+    def _erase_mask(self, results, patch, fill_val=0):
+        for key in results.get("mask_fields", []):
+            masks = results[key]
+            if isinstance(masks, PolygonMasks):
+                # convert polygon masks to bitmap masks
+                masks = masks.to_bitmap()
+            x1, y1, x2, y2 = patch
+            tmp = masks.masks.copy()
+            tmp[:, y1:y2, x1:x2] = fill_val
+            masks = BitmapMasks(tmp, masks.height, masks.width)
+            results[key] = masks
+
+    def _erase_seg(self, results, patch, fill_val=0):
+        for key in results.get("seg_fields", []):
+            seg = results[key].copy()
+            x1, y1, x2, y2 = patch
+            seg[y1:y2, x1:x2] = fill_val
+            results[key] = seg
+
+
+@PIPELINES.register_module()
+class RecomputeBox(object):
+    def __init__(self, record=False):
+        self.record = record
+
+    def __call__(self, results):
+        if self.record:
+            if "aug_info" not in results:
+                results["aug_info"] = []
+            results["aug_info"].append(dict(type="RecomputeBox"))
+        _, bbox2mask, _ = bbox2fields()
+        for key in results.get("bbox_fields", []):
+            mask_key = bbox2mask.get(key)
+            if mask_key in results:
+                masks = results[mask_key]
+                results[key] = self._recompute_bbox(masks)
+        return results
+
+    def enable_record(self, mode: bool = True):
+        self.record = mode
+
+    def _recompute_bbox(self, masks):
+        boxes = np.zeros((masks.masks.shape[0], 4), dtype=np.float32)
+        x_any = np.any(masks.masks, axis=1)
+        y_any = np.any(masks.masks, axis=2)
+        for idx in range(masks.masks.shape[0]):
+            x = np.where(x_any[idx, :])[0]
+            y = np.where(y_any[idx, :])[0]
+            if len(x) > 0 and len(y)
> 0: + boxes[idx, :] = np.array( + [x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=np.float32 + ) + return boxes + + +# TODO: Implement Augmentation Inside Box + + +@PIPELINES.register_module() +class RandResize(transforms.Resize): + def __init__(self, record=False, **kwargs): + super().__init__(**kwargs) + self.record = record + + def __call__(self, results): + results = super().__call__(results) + if self.record: + scale_factor = results["scale_factor"] + GTrans.apply(results, "scale", sx=scale_factor[0], sy=scale_factor[1]) + + if "aug_info" not in results: + results["aug_info"] = [] + new_h, new_w = results["img"].shape[:2] + results["aug_info"].append( + dict( + type=self.__class__.__name__, + record=False, + img_scale=(new_w, new_h), + keep_ratio=False, + bbox_clip_border=self.bbox_clip_border, + backend=self.backend, + ) + ) + return results + + def enable_record(self, mode: bool = True): + self.record = mode + + +@PIPELINES.register_module() +class RandFlip(transforms.RandomFlip): + def __init__(self, record=False, **kwargs): + super().__init__(**kwargs) + self.record = record + + def __call__(self, results): + results = super().__call__(results) + if self.record: + if "aug_info" not in results: + results["aug_info"] = [] + if results["flip"]: + GTrans.apply( + results, + "flip", + direction=results["flip_direction"], + shape=results["img_shape"][:2], + ) + results["aug_info"].append( + dict( + type=self.__class__.__name__, + record=False, + flip_ratio=1.0, + direction=results["flip_direction"], + ) + ) + else: + results["aug_info"].append( + dict( + type=self.__class__.__name__, + record=False, + flip_ratio=0.0, + direction="vertical", + ) + ) + return results + + def enable_record(self, mode: bool = True): + self.record = mode + + +@PIPELINES.register_module() +class MultiBranch(object): + def __init__(self, **transform_group): + self.transform_group = {k: BaseCompose(v) for k, v in transform_group.items()} + + def __call__(self, results): + multi_results = [] + for k, v in self.transform_group.items(): + res = v(copy.deepcopy(results)) + if res is None: + return None + # res["img_metas"]["tag"] = k + multi_results.append(res) + return multi_results diff --git a/ssod/datasets/pseudo_coco.py b/ssod/datasets/pseudo_coco.py new file mode 100644 index 0000000..6db1666 --- /dev/null +++ b/ssod/datasets/pseudo_coco.py @@ -0,0 +1,86 @@ +import copy +import json + +from mmdet.datasets import DATASETS, CocoDataset +from mmdet.datasets.api_wrappers import COCO + + +@DATASETS.register_module() +class PseudoCocoDataset(CocoDataset): + def __init__( + self, + ann_file, + pseudo_ann_file, + pipeline, + confidence_threshold=0.9, + classes=None, + data_root=None, + img_prefix="", + seg_prefix=None, + proposal_file=None, + test_mode=False, + filter_empty_gt=True, + ): + self.confidence_threshold = confidence_threshold + self.pseudo_ann_file = pseudo_ann_file + + super().__init__( + ann_file, + pipeline, + classes, + data_root, + img_prefix, + seg_prefix, + proposal_file, + test_mode=test_mode, + filter_empty_gt=filter_empty_gt, + ) + + def load_pesudo_targets(self, pseudo_ann_file): + with open(pseudo_ann_file) as f: + pesudo_anns = json.load(f) + print(f"loading {len(pesudo_anns)} results") + + def _add_attr(dict_terms, **kwargs): + new_dict = copy.copy(dict_terms) + new_dict.update(**kwargs) + return new_dict + + def _compute_area(bbox): + _, _, w, h = bbox + return w * h + + pesudo_anns = [ + _add_attr(ann, id=i, area=_compute_area(ann["bbox"])) + for i, ann in enumerate(pesudo_anns) + if 
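+            # Each entry is expected in COCO detection-result style, e.g.
+            # (illustrative record):
+            #   {"image_id": 42, "category_id": 3, "bbox": [x, y, w, h],
+            #    "score": 0.97}
+            # Only detections whose score clears ``confidence_threshold``
+            # are kept as pseudo labels; ``id`` and ``area`` are filled in
+            # by ``_add_attr`` above.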
ann["score"] > self.confidence_threshold + ] + print( + f"With {len(pesudo_anns)} results over threshold {self.confidence_threshold}" + ) + + return pesudo_anns + + def load_annotations(self, ann_file): + """Load annotation from COCO style annotation file. + + Args: + ann_file (str): Path of annotation file. + Returns: + list[dict]: Annotation info from COCO api. + """ + pesudo_anns = self.load_pesudo_targets(self.pseudo_ann_file) + self.coco = COCO(ann_file) + self.coco.dataset["annotations"] = pesudo_anns + self.coco.createIndex() + + self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES) + self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} + self.img_ids = self.coco.get_img_ids() + data_infos = [] + for i in self.img_ids: + info = self.coco.load_imgs([i])[0] + info["filename"] = info["file_name"] + data_infos.append(info) + + return data_infos diff --git a/ssod/datasets/samplers/__init__.py b/ssod/datasets/samplers/__init__.py new file mode 100644 index 0000000..5bdc225 --- /dev/null +++ b/ssod/datasets/samplers/__init__.py @@ -0,0 +1,4 @@ +from .semi_sampler import DistributedGroupSemiBalanceSampler +__all__ = [ + "DistributedGroupSemiBalanceSampler", +] diff --git a/ssod/datasets/samplers/semi_sampler.py b/ssod/datasets/samplers/semi_sampler.py new file mode 100644 index 0000000..0726aff --- /dev/null +++ b/ssod/datasets/samplers/semi_sampler.py @@ -0,0 +1,196 @@ +from __future__ import division + +import numpy as np +import torch +from mmcv.runner import get_dist_info +from torch.utils.data import Sampler, WeightedRandomSampler + +from ..builder import SAMPLERS + + +@SAMPLERS.register_module() +class DistributedGroupSemiBalanceSampler(Sampler): + def __init__( + self, + dataset, + by_prob=False, + epoch_length=7330, + sample_ratio=None, + samples_per_gpu=1, + num_replicas=None, + rank=None, + **kwargs + ): + _rank, _num_replicas = get_dist_info() + if num_replicas is None: + num_replicas = _num_replicas + if rank is None: + rank = _rank + + self.dataset = dataset + self.samples_per_gpu = samples_per_gpu + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.by_prob = by_prob + + assert hasattr(self.dataset, "flag") + self.flag = self.dataset.flag + self.group_sizes = np.bincount(self.flag) + self.num_samples = 0 + self.cumulative_sizes = dataset.cumulative_sizes + # decide the frequency to sample each kind of datasets + if not isinstance(sample_ratio, list): + sample_ratio = [sample_ratio] * len(self.cumulative_sizes) + self.sample_ratio = sample_ratio + self.sample_ratio = [ + int(sr / min(self.sample_ratio)) for sr in self.sample_ratio + ] + self.size_of_dataset = [] + cumulative_sizes = [0] + self.cumulative_sizes + + for i, _ in enumerate(self.group_sizes): + size_of_dataset = 0 + cur_group_inds = np.where(self.flag == i)[0] + for j in range(len(self.cumulative_sizes)): + cur_group_cur_dataset = np.where( + np.logical_and( + cur_group_inds > cumulative_sizes[j], + cur_group_inds < cumulative_sizes[j + 1], + ) + )[0] + size_per_dataset = len(cur_group_cur_dataset) + size_of_dataset = max( + size_of_dataset, np.ceil(size_per_dataset / self.sample_ratio[j]) + ) + + self.size_of_dataset.append( + int(np.ceil(size_of_dataset / self.samples_per_gpu / self.num_replicas)) + * self.samples_per_gpu + ) + for j in range(len(self.cumulative_sizes)): + self.num_samples += self.size_of_dataset[-1] * self.sample_ratio[j] + + self.total_size = self.num_samples * self.num_replicas + group_factor = [g / sum(self.group_sizes) for g in 
self.group_sizes] + self.epoch_length = [int(np.round(gf * epoch_length)) for gf in group_factor] + self.epoch_length[-1] = epoch_length - sum(self.epoch_length[:-1]) + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = [] + cumulative_sizes = [0] + self.cumulative_sizes + for i, size in enumerate(self.group_sizes): + if size > 0: + indice = np.where(self.flag == i)[0] + assert len(indice) == size + indice_per_dataset = [] + + for j in range(len(self.cumulative_sizes)): + indice_per_dataset.append( + indice[ + np.where( + np.logical_and( + indice >= cumulative_sizes[j], + indice < cumulative_sizes[j + 1], + ) + )[0] + ] + ) + + shuffled_indice_per_dataset = [ + s[list(torch.randperm(int(s.shape[0]), generator=g).numpy())] + for s in indice_per_dataset + ] + # split into + total_indice = [] + batch_idx = 0 + # pdb.set_trace() + while batch_idx < self.epoch_length[i] * self.num_replicas: + ratio = [x / sum(self.sample_ratio) for x in self.sample_ratio] + if self.by_prob: + indicator = list( + WeightedRandomSampler( + ratio, + self.samples_per_gpu, + replacement=True, + generator=g, + ) + ) + unique, counts = np.unique(indicator, return_counts=True) + ratio = [0] * len(shuffled_indice_per_dataset) + for u, c in zip(unique, counts): + ratio[u] = c + assert len(ratio) == 2, "Only two set is suppoted" + if ratio[0] == 0: + ratio[0] = 1 + ratio[1] -= 1 + elif ratio[1] == 0: + ratio[1] = 1 + ratio[0] -= 1 + + ratio = [r / sum(ratio) for r in ratio] + + # num of each dataset + ratio = [int(r * self.samples_per_gpu) for r in ratio] + + ratio[-1] = self.samples_per_gpu - sum(ratio[:-1]) + selected = [] + # print(ratio) + for j in range(len(shuffled_indice_per_dataset)): + if len(shuffled_indice_per_dataset[j]) < ratio[j]: + shuffled_indice_per_dataset[j] = np.concatenate( + ( + shuffled_indice_per_dataset[j], + indice_per_dataset[j][ + list( + torch.randperm( + int(indice_per_dataset[j].shape[0]), + generator=g, + ).numpy() + ) + ], + ) + ) + + selected.append(shuffled_indice_per_dataset[j][: ratio[j]]) + shuffled_indice_per_dataset[j] = shuffled_indice_per_dataset[j][ + ratio[j] : + ] + selected = np.concatenate(selected) + total_indice.append(selected) + batch_idx += 1 + # print(self.size_of_dataset) + indice = np.concatenate(total_indice) + indices.append(indice) + indices = np.concatenate(indices) # k + indices = [ + indices[j] + for i in list( + torch.randperm( + len(indices) // self.samples_per_gpu, + generator=g, + ) + ) + for j in range( + i * self.samples_per_gpu, + (i + 1) * self.samples_per_gpu, + ) + ] + + offset = len(self) * self.rank + indices = indices[offset : offset + len(self)] + assert len(indices) == len(self) + return iter(indices) + + def __len__(self): + return sum(self.epoch_length) * self.samples_per_gpu + + def set_epoch(self, epoch): + self.epoch = epoch + # duplicated, implement it by weight instead of sampling + # def update_sample_ratio(self): + # if self.dynamic_step is not None: + # self.sample_ratio = [d(self.epoch) for d in self.dynamic] diff --git a/ssod/models/__init__.py b/ssod/models/__init__.py new file mode 100644 index 0000000..8e811cf --- /dev/null +++ b/ssod/models/__init__.py @@ -0,0 +1 @@ +from .soft_teacher import SoftTeacher \ No newline at end of file diff --git a/ssod/models/multi_stream_detector.py b/ssod/models/multi_stream_detector.py new file mode 100644 index 0000000..ae153b3 --- /dev/null +++ b/ssod/models/multi_stream_detector.py @@ -0,0 +1,84 @@ +from typing import 
Dict
+from mmdet.models import BaseDetector, TwoStageDetector
+
+
+class MultiSteamDetector(BaseDetector):
+    def __init__(
+        self, model: Dict[str, TwoStageDetector], train_cfg=None, test_cfg=None
+    ):
+        super(MultiSteamDetector, self).__init__()
+        self.submodules = list(model.keys())
+        for k, v in model.items():
+            setattr(self, k, v)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+        self.inference_on = self.test_cfg.get("inference_on", self.submodules[0])
+
+    def model(self, **kwargs) -> TwoStageDetector:
+        if "submodule" in kwargs:
+            assert (
+                kwargs["submodule"] in self.submodules
+            ), "Detector does not contain submodule {}".format(kwargs["submodule"])
+            model: TwoStageDetector = getattr(self, kwargs["submodule"])
+        else:
+            model: TwoStageDetector = getattr(self, self.inference_on)
+        return model
+
+    def freeze(self, model_ref: str):
+        assert model_ref in self.submodules
+        model = getattr(self, model_ref)
+        model.eval()
+        for param in model.parameters():
+            param.requires_grad = False
+
+    def forward_test(self, imgs, img_metas, **kwargs):
+        return self.model(**kwargs).forward_test(imgs, img_metas, **kwargs)
+
+    async def aforward_test(self, *, img, img_metas, **kwargs):
+        return self.model(**kwargs).aforward_test(img, img_metas, **kwargs)
+
+    def extract_feat(self, imgs):
+        return self.model().extract_feat(imgs)
+
+    def aug_test(self, imgs, img_metas, **kwargs):
+        return self.model(**kwargs).aug_test(imgs, img_metas, **kwargs)
+
+    def simple_test(self, img, img_metas, **kwargs):
+        return self.model(**kwargs).simple_test(img, img_metas, **kwargs)
+
+    async def async_simple_test(self, img, img_metas, **kwargs):
+        return self.model(**kwargs).async_simple_test(img, img_metas, **kwargs)
+
+    def show_result(
+        self,
+        img,
+        result,
+        score_thr=0.3,
+        bbox_color=(72, 101, 241),
+        text_color=(72, 101, 241),
+        mask_color=None,
+        thickness=2,
+        font_size=13,
+        win_name="",
+        show=False,
+        wait_time=0,
+        out_file=None,
+    ):
+        return self.model().show_result(
+            img,
+            result,
+            score_thr,
+            bbox_color,
+            text_color,
+            mask_color,
+            thickness,
+            font_size,
+            win_name,
+            show,
+            wait_time,
+            out_file,
+        )
diff --git a/ssod/models/soft_teacher.py b/ssod/models/soft_teacher.py
new file mode 100644
index 0000000..380f919
--- /dev/null
+++ b/ssod/models/soft_teacher.py
@@ -0,0 +1,511 @@
+import torch
+from mmcv.runner.fp16_utils import force_fp32
+from mmdet.core import bbox2roi, multi_apply
+from mmdet.models import DETECTORS, build_detector
+
+from ssod.utils.structure_utils import dict_split, weighted_loss
+from ssod.utils import log_image_with_boxes, log_every_n
+
+from .multi_stream_detector import MultiSteamDetector
+from .utils import Transform2D, filter_invalid
+
+
+@DETECTORS.register_module()
+class SoftTeacher(MultiSteamDetector):
+    def __init__(self, model: dict, train_cfg=None, test_cfg=None):
+        super(SoftTeacher, self).__init__(
+            dict(teacher=build_detector(model), student=build_detector(model)),
+            train_cfg=train_cfg,
+            test_cfg=test_cfg,
+        )
+        if train_cfg is not None:
+            self.freeze("teacher")
+            self.unsup_weight = self.train_cfg.unsup_weight
+
+    def forward_train(self, img, img_metas, **kwargs):
+        super().forward_train(img, img_metas, **kwargs)
+        kwargs.update({"img": img})
+        kwargs.update({"img_metas": img_metas})
+        kwargs.update({"tag": [meta["tag"] for meta in img_metas]})
+        data_groups = dict_split(kwargs, "tag")
+        for _, v in
data_groups.items(): + v.pop("tag") + + loss = {} + #! Warnings: By splitting losses for supervised data and unsupervised data with different names, + #! it means that at least one sample for each group should be provided on each gpu. + #! In some situation, we can only put one image per gpu, we have to return the sum of loss + #! and log the loss with logger instead. Or it will try to sync tensors don't exist. + if "sup" in data_groups: + gt_bboxes = data_groups["sup"]["gt_bboxes"] + log_every_n( + {"sup_gt_num": sum([len(bbox) for bbox in gt_bboxes]) / len(gt_bboxes)} + ) + sup_loss = self.student.forward_train(**data_groups["sup"]) + sup_loss = {"sup_" + k: v for k, v in sup_loss.items()} + loss.update(**sup_loss) + if "unsup_student" in data_groups: + unsup_loss = weighted_loss( + self.foward_unsup_train( + data_groups["unsup_teacher"], data_groups["unsup_student"] + ), + weight=self.unsup_weight, + ) + unsup_loss = {"unsup_" + k: v for k, v in unsup_loss.items()} + loss.update(**unsup_loss) + + return loss + + def foward_unsup_train(self, teacher_data, student_data): + # sort the teacher and student input to avoid some bugs + tnames = [meta["filename"] for meta in teacher_data["img_metas"]] + snames = [meta["filename"] for meta in student_data["img_metas"]] + tidx = [tnames.index(name) for name in snames] + with torch.no_grad(): + teacher_info = self.extract_teacher_info( + teacher_data["img"][ + torch.Tensor(tidx).to(teacher_data["img"].device).long() + ], + [teacher_data["img_metas"][idx] for idx in tidx], + [teacher_data["proposals"][idx] for idx in tidx] + if ("proposals" in teacher_data) + and (teacher_data["proposals"] is not None) + else None, + ) + student_info = self.extract_student_info(**student_data) + + return self.compute_pseudo_label_loss(student_info, teacher_info) + + def compute_pseudo_label_loss(self, student_info, teacher_info): + M = self._get_trans_mat( + teacher_info["transform_matrix"], student_info["transform_matrix"] + ) + + pseudo_bboxes = self._transform_bbox( + teacher_info["det_bboxes"], + M, + [meta["img_shape"] for meta in student_info["img_metas"]], + ) + pseudo_labels = teacher_info["det_labels"] + loss = {} + rpn_loss, proposal_list = self.rpn_loss( + student_info["rpn_out"], + pseudo_bboxes, + student_info["img_metas"], + student_info=student_info, + ) + loss.update(rpn_loss) + if proposal_list is not None: + student_info["proposals"] = proposal_list + if self.train_cfg.use_teacher_proposal: + proposals = self._transform_bbox( + teacher_info["proposals"], + M, + [meta["img_shape"] for meta in student_info["img_metas"]], + ) + else: + proposals = student_info["proposals"] + + loss.update( + self.unsup_rcnn_cls_loss( + student_info["backbone_feature"], + student_info["img_metas"], + proposals, + pseudo_bboxes, + pseudo_labels, + teacher_info["transform_matrix"], + student_info["transform_matrix"], + teacher_info["img_metas"], + teacher_info["backbone_feature"], + student_info=student_info, + ) + ) + loss.update( + self.unsup_rcnn_reg_loss( + student_info["backbone_feature"], + student_info["img_metas"], + proposals, + pseudo_bboxes, + pseudo_labels, + student_info=student_info, + ) + ) + return loss + + def rpn_loss( + self, + rpn_out, + pseudo_bboxes, + img_metas, + gt_bboxes_ignore=None, + student_info=None, + **kwargs, + ): + if self.student.with_rpn: + gt_bboxes = [] + for bbox in pseudo_bboxes: + bbox, _, _ = filter_invalid( + bbox[:, :4], + score=bbox[ + :, 4 + ], # TODO: replace with foreground score, here is classification score, + 
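+                    # Teacher pseudo boxes carry extra columns: [:4] are the
+                    # coordinates, [4] is the classification score used here,
+                    # and [5:] are the per-coordinate jitter uncertainties
+                    # consumed later by ``unsup_rcnn_reg_loss``.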
thr=self.train_cfg.rpn_pseudo_threshold, + min_size=self.train_cfg.min_pseduo_box_size, + ) + gt_bboxes.append(bbox) + log_every_n( + {"rpn_gt_num": sum([len(bbox) for bbox in gt_bboxes]) / len(gt_bboxes)} + ) + loss_inputs = rpn_out + [[bbox.float() for bbox in gt_bboxes], img_metas] + losses = self.student.rpn_head.loss( + *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore + ) + proposal_cfg = self.student.train_cfg.get( + "rpn_proposal", self.student.test_cfg.rpn + ) + proposal_list = self.student.rpn_head.get_bboxes( + *rpn_out, img_metas, cfg=proposal_cfg + ) + log_image_with_boxes( + "rpn", + student_info["img"][0], + pseudo_bboxes[0][:, :4], + bbox_tag="rpn_pseudo_label", + scores=pseudo_bboxes[0][:, 4], + interval=500, + img_norm_cfg=student_info["img_metas"][0]["img_norm_cfg"], + ) + return losses, proposal_list + else: + return {}, None + + def unsup_rcnn_cls_loss( + self, + feat, + img_metas, + proposal_list, + pseudo_bboxes, + pseudo_labels, + teacher_transMat, + student_transMat, + teacher_img_metas, + teacher_feat, + student_info=None, + **kwargs, + ): + gt_bboxes, gt_labels, _ = multi_apply( + filter_invalid, + [bbox[:, :4] for bbox in pseudo_bboxes], + pseudo_labels, + [bbox[:, 4] for bbox in pseudo_bboxes], + thr=self.train_cfg.cls_pseudo_threshold, + ) + log_every_n( + {"rcnn_cls_gt_num": sum([len(bbox) for bbox in gt_bboxes]) / len(gt_bboxes)} + ) + sampling_results = self.get_sampling_result( + img_metas, + proposal_list, + gt_bboxes, + gt_labels, + ) + selected_bboxes = [res.bboxes[:, :4] for res in sampling_results] + rois = bbox2roi(selected_bboxes) + bbox_results = self.student.roi_head._bbox_forward(feat, rois) + bbox_targets = self.student.roi_head.bbox_head.get_targets( + sampling_results, gt_bboxes, gt_labels, self.student.train_cfg.rcnn + ) + M = self._get_trans_mat(student_transMat, teacher_transMat) + aligned_proposals = self._transform_bbox( + selected_bboxes, + M, + [meta["img_shape"] for meta in teacher_img_metas], + ) + with torch.no_grad(): + _, _scores = self.teacher.roi_head.simple_test_bboxes( + teacher_feat, + teacher_img_metas, + aligned_proposals, + None, + rescale=False, + ) + bg_score = torch.cat([_score[:, -1] for _score in _scores]) + assigned_label, _, _, _ = bbox_targets + neg_inds = assigned_label == self.student.roi_head.bbox_head.num_classes + bbox_targets[1][neg_inds] = bg_score[neg_inds].detach() + loss = self.student.roi_head.bbox_head.loss( + bbox_results["cls_score"], + bbox_results["bbox_pred"], + rois, + *bbox_targets, + reduction_override="none", + ) + loss["loss_cls"] = loss["loss_cls"].sum() / max(bbox_targets[1].sum(), 1.0) + loss["loss_bbox"] = loss["loss_bbox"].sum() / max( + bbox_targets[1].size()[0], 1.0 + ) + if len(gt_bboxes[0]) > 0: + log_image_with_boxes( + "rcnn_cls", + student_info["img"][0], + gt_bboxes[0], + bbox_tag="pseudo_label", + labels=gt_labels[0], + class_names=self.CLASSES, + interval=500, + img_norm_cfg=student_info["img_metas"][0]["img_norm_cfg"], + ) + return loss + + def unsup_rcnn_reg_loss( + self, + feat, + img_metas, + proposal_list, + pseudo_bboxes, + pseudo_labels, + student_info=None, + **kwargs, + ): + gt_bboxes, gt_labels, _ = multi_apply( + filter_invalid, + [bbox[:, :4] for bbox in pseudo_bboxes], + pseudo_labels, + [-bbox[:, 5:].mean(dim=-1) for bbox in pseudo_bboxes], + thr=-self.train_cfg.reg_pseudo_threshold, + ) + log_every_n( + {"rcnn_reg_gt_num": sum([len(bbox) for bbox in gt_bboxes]) / len(gt_bboxes)} + ) + loss_bbox = self.student.roi_head.forward_train( + feat, img_metas, proposal_list, 
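+            # ``gt_bboxes`` are the pseudo boxes kept by the filter above: a
+            # box survives when -mean(unc) > -reg_pseudo_threshold, i.e. when
+            # its mean relative jitter uncertainty is below the threshold
+            # (e.g. 0.02 < 0.05 passes; numbers illustrative).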
gt_bboxes, gt_labels, **kwargs + )["loss_bbox"] + if len(gt_bboxes[0]) > 0: + log_image_with_boxes( + "rcnn_reg", + student_info["img"][0], + gt_bboxes[0], + bbox_tag="pseudo_label", + labels=gt_labels[0], + class_names=self.CLASSES, + interval=500, + img_norm_cfg=student_info["img_metas"][0]["img_norm_cfg"], + ) + return {"loss_bbox": loss_bbox} + + def get_sampling_result( + self, + img_metas, + proposal_list, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None, + **kwargs, + ): + num_imgs = len(img_metas) + if gt_bboxes_ignore is None: + gt_bboxes_ignore = [None for _ in range(num_imgs)] + sampling_results = [] + for i in range(num_imgs): + assign_result = self.student.roi_head.bbox_assigner.assign( + proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i], gt_labels[i] + ) + sampling_result = self.student.roi_head.bbox_sampler.sample( + assign_result, + proposal_list[i], + gt_bboxes[i], + gt_labels[i], + ) + sampling_results.append(sampling_result) + return sampling_results + + @force_fp32(apply_to=["bboxes", "trans_mat"]) + def _transform_bbox(self, bboxes, trans_mat, max_shape): + bboxes = Transform2D.transform_bboxes(bboxes, trans_mat, max_shape) + return bboxes + + @force_fp32(apply_to=["a", "b"]) + def _get_trans_mat(self, a, b): + return [bt @ at.inverse() for bt, at in zip(b, a)] + + def extract_student_info(self, img, img_metas, proposals=None, **kwargs): + student_info = {} + student_info["img"] = img + feat = self.student.extract_feat(img) + student_info["backbone_feature"] = feat + if self.student.with_rpn: + rpn_out = self.student.rpn_head(feat) + student_info["rpn_out"] = list(rpn_out) + student_info["img_metas"] = img_metas + student_info["proposals"] = proposals + student_info["transform_matrix"] = [ + torch.from_numpy(meta["transform_matrix"]).float().to(feat[0][0].device) + for meta in img_metas + ] + return student_info + + def extract_teacher_info(self, img, img_metas, proposals=None, **kwargs): + teacher_info = {} + feat = self.teacher.extract_feat(img) + teacher_info["backbone_feature"] = feat + if proposals is None: + proposal_cfg = self.teacher.train_cfg.get( + "rpn_proposal", self.teacher.test_cfg.rpn + ) + rpn_out = list(self.teacher.rpn_head(feat)) + proposal_list = self.teacher.rpn_head.get_bboxes( + *rpn_out, img_metas, cfg=proposal_cfg + ) + else: + proposal_list = proposals + teacher_info["proposals"] = proposal_list + + proposal_list, proposal_label_list = self.teacher.roi_head.simple_test_bboxes( + feat, img_metas, proposal_list, self.teacher.test_cfg.rcnn, rescale=False + ) + + proposal_list = [p.to(feat[0].device) for p in proposal_list] + proposal_list = [ + p if p.shape[0] > 0 else p.new_zeros(0, 5) for p in proposal_list + ] + proposal_label_list = [p.to(feat[0].device) for p in proposal_label_list] + # filter invalid box roughly + if isinstance(self.train_cfg.pseudo_label_initial_score_thr, float): + thr = self.train_cfg.pseudo_label_initial_score_thr + else: + # TODO: use dynamic threshold + raise NotImplementedError("Dynamic Threshold is not implemented yet.") + proposal_list, proposal_label_list, _ = list( + zip( + *[ + filter_invalid( + proposal, + proposal_label, + proposal[:, -1], + thr=thr, + min_size=self.train_cfg.min_pseduo_box_size, + ) + for proposal, proposal_label in zip( + proposal_list, proposal_label_list + ) + ] + ) + ) + det_bboxes = proposal_list + reg_unc = self.compute_uncertainty_with_aug( + feat, img_metas, proposal_list, proposal_label_list + ) + det_bboxes = [ + torch.cat([bbox, unc], dim=-1) for bbox, unc in zip(det_bboxes, 
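+            # ``reg_unc`` comes from ``compute_uncertainty_with_aug``: each
+            # box is jittered ``jitter_times`` times with noise proportional
+            # to ``jitter_scale`` times its size, re-regressed by the
+            # teacher, and the std of the outputs (normalized by box
+            # width/height) is appended as extra columns.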
reg_unc) + ] + det_labels = proposal_label_list + teacher_info["det_bboxes"] = det_bboxes + teacher_info["det_labels"] = det_labels + teacher_info["transform_matrix"] = [ + torch.from_numpy(meta["transform_matrix"]).float().to(feat[0][0].device) + for meta in img_metas + ] + teacher_info["img_metas"] = img_metas + return teacher_info + + def compute_uncertainty_with_aug( + self, feat, img_metas, proposal_list, proposal_label_list + ): + auged_proposal_list = self.aug_box( + proposal_list, self.train_cfg.jitter_times, self.train_cfg.jitter_scale + ) + # flatten + auged_proposal_list = [ + auged.reshape(-1, auged.shape[-1]) for auged in auged_proposal_list + ] + + bboxes, _ = self.teacher.roi_head.simple_test_bboxes( + feat, + img_metas, + auged_proposal_list, + None, + rescale=False, + ) + reg_channel = max([bbox.shape[-1] for bbox in bboxes]) // 4 + bboxes = [ + bbox.reshape(self.train_cfg.jitter_times, -1, bbox.shape[-1]) + if bbox.numel() > 0 + else bbox.new_zeros(self.train_cfg.jitter_times, 0, 4 * reg_channel).float() + for bbox in bboxes + ] + + box_unc = [bbox.std(dim=0) for bbox in bboxes] + bboxes = [bbox.mean(dim=0) for bbox in bboxes] + # scores = [score.mean(dim=0) for score in scores] + if reg_channel != 1: + bboxes = [ + bbox.reshape(bbox.shape[0], reg_channel, 4)[ + torch.arange(bbox.shape[0]), label + ] + for bbox, label in zip(bboxes, proposal_label_list) + ] + box_unc = [ + unc.reshape(unc.shape[0], reg_channel, 4)[ + torch.arange(unc.shape[0]), label + ] + for unc, label in zip(box_unc, proposal_label_list) + ] + + box_shape = [(bbox[:, 2:4] - bbox[:, :2]).clamp(min=1.0) for bbox in bboxes] + # relative unc + box_unc = [ + unc / wh[:, None, :].expand(-1, 2, 2).reshape(-1, 4) + if wh.numel() > 0 + else unc + for unc, wh in zip(box_unc, box_shape) + ] + return box_unc + + @staticmethod + def aug_box(boxes, times=1, frac=0.06): + def _aug_single(box): + # random translate + # TODO: random flip or something + box_scale = box[:, 2:4] - box[:, :2] + box_scale = ( + box_scale.clamp(min=1)[:, None, :].expand(-1, 2, 2).reshape(-1, 4) + ) + aug_scale = box_scale * frac # [n,4] + + offset = ( + torch.randn(times, box.shape[0], 4, device=box.device) + * aug_scale[None, ...] + ) + new_box = box.clone()[None, ...].expand(times, box.shape[0], -1) + return torch.cat( + [new_box[:, :, :4].clone() + offset, new_box[:, :, 4:]], dim=-1 + ) + + return [_aug_single(box) for box in boxes] + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + if not any(["student" in key or "teacher" in key for key in state_dict.keys()]): + keys = list(state_dict.keys()) + state_dict.update({"teacher." + k: state_dict[k] for k in keys}) + state_dict.update({"student." 
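+                # e.g. a plain detector key "backbone.conv1.weight" is
+                # duplicated to "teacher.backbone.conv1.weight" and
+                # "student.backbone.conv1.weight", so a supervised baseline
+                # checkpoint can initialize both branches.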
+ k: state_dict[k] for k in keys}) + for k in keys: + state_dict.pop(k) + + return super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) diff --git a/ssod/models/utils/__init__.py b/ssod/models/utils/__init__.py new file mode 100644 index 0000000..5b1fb55 --- /dev/null +++ b/ssod/models/utils/__init__.py @@ -0,0 +1 @@ +from .bbox_utils import Transform2D, filter_invalid diff --git a/ssod/models/utils/bbox_utils.py b/ssod/models/utils/bbox_utils.py new file mode 100644 index 0000000..6614679 --- /dev/null +++ b/ssod/models/utils/bbox_utils.py @@ -0,0 +1,255 @@ +import warnings +from collections.abc import Sequence +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from mmdet.core.mask.structures import BitmapMasks +from torch.nn import functional as F + + +def bbox2points(box): + min_x, min_y, max_x, max_y = torch.split(box[:, :4], [1, 1, 1, 1], dim=1) + + return torch.cat( + [min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y], dim=1 + ).reshape( + -1, 2 + ) # n*4,2 + + +def points2bbox(point, max_w, max_h): + point = point.reshape(-1, 4, 2) + if point.size()[0] > 0: + min_xy = point.min(dim=1)[0] + max_xy = point.max(dim=1)[0] + xmin = min_xy[:, 0].clamp(min=0, max=max_w) + ymin = min_xy[:, 1].clamp(min=0, max=max_h) + xmax = max_xy[:, 0].clamp(min=0, max=max_w) + ymax = max_xy[:, 1].clamp(min=0, max=max_h) + min_xy = torch.stack([xmin, ymin], dim=1) + max_xy = torch.stack([xmax, ymax], dim=1) + return torch.cat([min_xy, max_xy], dim=1) # n,4 + else: + return point.new_zeros(0, 4) + + +def check_is_tensor(obj): + """Checks whether the supplied object is a tensor.""" + if not isinstance(obj, torch.Tensor): + raise TypeError("Input type is not a torch.Tensor. Got {}".format(type(obj))) + + +def normal_transform_pixel( + height: int, + width: int, + eps: float = 1e-14, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + tr_mat = torch.tensor( + [[1.0, 0.0, -1.0], [0.0, 1.0, -1.0], [0.0, 0.0, 1.0]], + device=device, + dtype=dtype, + ) # 3x3 + + # prevent divide by zero bugs + width_denom: float = eps if width == 1 else width - 1.0 + height_denom: float = eps if height == 1 else height - 1.0 + + tr_mat[0, 0] = tr_mat[0, 0] * 2.0 / width_denom + tr_mat[1, 1] = tr_mat[1, 1] * 2.0 / height_denom + + return tr_mat.unsqueeze(0) # 1x3x3 + + +def normalize_homography( + dst_pix_trans_src_pix: torch.Tensor, + dsize_src: Tuple[int, int], + dsize_dst: Tuple[int, int], +) -> torch.Tensor: + check_is_tensor(dst_pix_trans_src_pix) + + if not ( + len(dst_pix_trans_src_pix.shape) == 3 + or dst_pix_trans_src_pix.shape[-2:] == (3, 3) + ): + raise ValueError( + "Input dst_pix_trans_src_pix must be a Bx3x3 tensor. 
Got {}".format( + dst_pix_trans_src_pix.shape + ) + ) + + # source and destination sizes + src_h, src_w = dsize_src + dst_h, dst_w = dsize_dst + + # compute the transformation pixel/norm for src/dst + src_norm_trans_src_pix: torch.Tensor = normal_transform_pixel(src_h, src_w).to( + dst_pix_trans_src_pix + ) + src_pix_trans_src_norm = torch.inverse(src_norm_trans_src_pix.float()).to( + src_norm_trans_src_pix.dtype + ) + dst_norm_trans_dst_pix: torch.Tensor = normal_transform_pixel(dst_h, dst_w).to( + dst_pix_trans_src_pix + ) + + # compute chain transformations + dst_norm_trans_src_norm: torch.Tensor = dst_norm_trans_dst_pix @ ( + dst_pix_trans_src_pix @ src_pix_trans_src_norm + ) + return dst_norm_trans_src_norm + + +def warp_affine( + src: torch.Tensor, + M: torch.Tensor, + dsize: Tuple[int, int], + mode: str = "bilinear", + padding_mode: str = "zeros", + align_corners: Optional[bool] = None, +) -> torch.Tensor: + if not isinstance(src, torch.Tensor): + raise TypeError( + "Input src type is not a torch.Tensor. Got {}".format(type(src)) + ) + + if not isinstance(M, torch.Tensor): + raise TypeError("Input M type is not a torch.Tensor. Got {}".format(type(M))) + + if not len(src.shape) == 4: + raise ValueError("Input src must be a BxCxHxW tensor. Got {}".format(src.shape)) + + if not (len(M.shape) == 3 or M.shape[-2:] == (2, 3)): + raise ValueError("Input M must be a Bx2x3 tensor. Got {}".format(M.shape)) + + # TODO: remove the statement below in kornia v0.6 + if align_corners is None: + message: str = ( + "The align_corners default value has been changed. By default now is set True " + "in order to match cv2.warpAffine." + ) + warnings.warn(message) + # set default value for align corners + align_corners = True + + B, C, H, W = src.size() + + # we generate a 3x3 transformation matrix from 2x3 affine + + dst_norm_trans_src_norm: torch.Tensor = normalize_homography(M, (H, W), dsize) + + src_norm_trans_dst_norm = torch.inverse(dst_norm_trans_src_norm.float()) + + grid = F.affine_grid( + src_norm_trans_dst_norm[:, :2, :], + [B, C, dsize[0], dsize[1]], + align_corners=align_corners, + ) + + return F.grid_sample( + src.float(), + grid, + align_corners=align_corners, + mode=mode, + padding_mode=padding_mode, + ).to(src.dtype) + + +class Transform2D: + @staticmethod + def transform_bboxes(bbox, M, out_shape): + if isinstance(bbox, Sequence): + assert len(bbox) == len(M) + return [ + Transform2D.transform_bboxes(b, m, o) + for b, m, o in zip(bbox, M, out_shape) + ] + else: + if bbox.shape[0] == 0: + return bbox + score = None + if bbox.shape[1] > 4: + score = bbox[:, 4:] + points = bbox2points(bbox[:, :4]) + points = torch.cat( + [points, points.new_ones(points.shape[0], 1)], dim=1 + ) # n,3 + points = torch.matmul(M, points.t()).t() + points = points[:, :2] / points[:, 2:3] + bbox = points2bbox(points, out_shape[1], out_shape[0]) + if score is not None: + return torch.cat([bbox, score], dim=1) + return bbox + + @staticmethod + def transform_masks( + mask: Union[BitmapMasks, List[BitmapMasks]], + M: Union[torch.Tensor, List[torch.Tensor]], + out_shape: Union[list, List[list]], + ): + if isinstance(mask, Sequence): + assert len(mask) == len(M) + return [ + Transform2D.transform_masks(b, m, o) + for b, m, o in zip(mask, M, out_shape) + ] + else: + if mask.masks.shape[0] == 0: + return BitmapMasks(np.zeros((0, *out_shape)), *out_shape) + mask_tensor = ( + torch.from_numpy(mask.masks[:, None, ...]).to(M.device).to(M.dtype) + ) + return BitmapMasks( + warp_affine( + mask_tensor, + M[None, 
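+                    # the single 3x3 matrix is broadcast across the instance
+                    # dimension so every mask of this image is warped with
+                    # the same affine transform.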
...].expand(mask.masks.shape[0], -1, -1), + out_shape, + ) + .squeeze(1) + .cpu() + .numpy(), + out_shape[0], + out_shape[1], + ) + + @staticmethod + def transform_image(img, M, out_shape): + if isinstance(img, Sequence): + assert len(img) == len(M) + return [ + Transform2D.transform_image(b, m, shape) + for b, m, shape in zip(img, M, out_shape) + ] + else: + if img.dim() == 2: + img = img[None, None, ...] + elif img.dim() == 3: + img = img[None, ...] + + return ( + warp_affine(img.float(), M[None, ...], out_shape, mode="nearest") + .squeeze() + .to(img.dtype) + ) + + +def filter_invalid(bbox, label=None, score=None, mask=None, thr=0.0, min_size=0): + if (score is not None) and (thr > 0): + valid = score > thr + bbox = bbox[valid] + if label is not None: + label = label[valid] + if mask is not None: + mask = BitmapMasks(mask.masks[valid.cpu().numpy()], mask.height, mask.width) + if min_size is not None: + bw = bbox[:, 2] - bbox[:, 0] + bh = bbox[:, 3] - bbox[:, 1] + valid = (bw > min_size) & (bh > min_size) + bbox = bbox[valid] + if label is not None: + label = label[valid] + if mask is not None: + mask = BitmapMasks(mask.masks[valid.cpu().numpy()], mask.height, mask.width) + return bbox, label, mask diff --git a/ssod/utils/__init__.py b/ssod/utils/__init__.py new file mode 100644 index 0000000..3b5f2e4 --- /dev/null +++ b/ssod/utils/__init__.py @@ -0,0 +1,19 @@ +from .exts import NamedOptimizerConstructor +from .hooks import Weighter, MeanTeacher, WeightSummary, SubModulesDistEvalHook +from .logger import get_root_logger, log_every_n, log_image_with_boxes +from .patch import patch_config, patch_runner, find_latest_checkpoint + + +__all__ = [ + "get_root_logger", + "log_every_n", + "log_image_with_boxes", + "patch_config", + "patch_runner", + "find_latest_checkpoint", + "Weighter", + "MeanTeacher", + "WeightSummary", + "SubModulesDistEvalHook", + "NamedOptimizerConstructor", +] diff --git a/ssod/utils/exts/__init__.py b/ssod/utils/exts/__init__.py new file mode 100644 index 0000000..306fb0f --- /dev/null +++ b/ssod/utils/exts/__init__.py @@ -0,0 +1 @@ +from .optimizer_constructor import NamedOptimizerConstructor diff --git a/ssod/utils/exts/optimizer_constructor.py b/ssod/utils/exts/optimizer_constructor.py new file mode 100644 index 0000000..999e16c --- /dev/null +++ b/ssod/utils/exts/optimizer_constructor.py @@ -0,0 +1,113 @@ +import warnings + +import torch +from torch.nn import GroupNorm, LayerNorm + +from mmcv.utils import _BatchNorm, _InstanceNorm, build_from_cfg +from mmcv.utils.ext_loader import check_ops_exist +from mmcv.runner.optimizer.builder import OPTIMIZER_BUILDERS, OPTIMIZERS +from mmcv.runner.optimizer import DefaultOptimizerConstructor + + +@OPTIMIZER_BUILDERS.register_module() +class NamedOptimizerConstructor(DefaultOptimizerConstructor): + """Main difference to default constructor: + + 1) Add name to parame groups + """ + + def add_params(self, params, module, prefix="", is_dcn_module=None): + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + prefix (str): The prefix of the module + is_dcn_module (int|float|None): If the current module is a + submodule of DCN, `is_dcn_module` will be passed to + control conv_offset layer's learning rate. Defaults to None. 
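+
+        Example:
+            An illustrative ``paramwise_cfg`` (values assumed, not taken from
+            any config in this repo)::
+
+                paramwise_cfg = dict(
+                    custom_keys={"backbone": dict(lr_mult=0.1)},
+                    norm_decay_mult=0.0,
+                )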
+ """ + # get param-wise options + custom_keys = self.paramwise_cfg.get("custom_keys", {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + + bias_lr_mult = self.paramwise_cfg.get("bias_lr_mult", 1.0) + bias_decay_mult = self.paramwise_cfg.get("bias_decay_mult", 1.0) + norm_decay_mult = self.paramwise_cfg.get("norm_decay_mult", 1.0) + dwconv_decay_mult = self.paramwise_cfg.get("dwconv_decay_mult", 1.0) + bypass_duplicate = self.paramwise_cfg.get("bypass_duplicate", False) + dcn_offset_lr_mult = self.paramwise_cfg.get("dcn_offset_lr_mult", 1.0) + + # special rules for norm layers and depth-wise conv layers + is_norm = isinstance(module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) + is_dwconv = ( + isinstance(module, torch.nn.Conv2d) and module.in_channels == module.groups + ) + + for name, param in module.named_parameters(recurse=False): + param_group = {"params": [param], "name": f"{prefix}.{name}"} + if not param.requires_grad: + params.append(param_group) + continue + if bypass_duplicate and self._is_in(param_group, params): + warnings.warn( + f"{prefix} is duplicate. It is skipped since " + f"bypass_duplicate={bypass_duplicate}" + ) + continue + # if the parameter match one of the custom keys, ignore other rules + is_custom = False + for key in sorted_keys: + if key in f"{prefix}.{name}": + is_custom = True + lr_mult = custom_keys[key].get("lr_mult", 1.0) + param_group["lr"] = self.base_lr * lr_mult + if self.base_wd is not None: + decay_mult = custom_keys[key].get("decay_mult", 1.0) + param_group["weight_decay"] = self.base_wd * decay_mult + break + + if not is_custom: + # bias_lr_mult affects all bias parameters + # except for norm.bias dcn.conv_offset.bias + if name == "bias" and not (is_norm or is_dcn_module): + param_group["lr"] = self.base_lr * bias_lr_mult + + if ( + prefix.find("conv_offset") != -1 + and is_dcn_module + and isinstance(module, torch.nn.Conv2d) + ): + # deal with both dcn_offset's bias & weight + param_group["lr"] = self.base_lr * dcn_offset_lr_mult + + # apply weight decay policies + if self.base_wd is not None: + # norm decay + if is_norm: + param_group["weight_decay"] = self.base_wd * norm_decay_mult + # depth-wise conv + elif is_dwconv: + param_group["weight_decay"] = self.base_wd * dwconv_decay_mult + # bias lr and decay + elif name == "bias" and not is_dcn_module: + # TODO: current bias_decay_mult will have affect on DCN + param_group["weight_decay"] = self.base_wd * bias_decay_mult + params.append(param_group) + + if check_ops_exist(): + from mmcv.ops import DeformConv2d, ModulatedDeformConv2d + + is_dcn_module = isinstance(module, (DeformConv2d, ModulatedDeformConv2d)) + else: + is_dcn_module = False + for child_name, child_mod in module.named_children(): + child_prefix = f"{prefix}.{child_name}" if prefix else child_name + self.add_params( + params, child_mod, prefix=child_prefix, is_dcn_module=is_dcn_module + ) diff --git a/ssod/utils/hooks/__init__.py b/ssod/utils/hooks/__init__.py new file mode 100644 index 0000000..880e29f --- /dev/null +++ b/ssod/utils/hooks/__init__.py @@ -0,0 +1,12 @@ +from .weight_adjust import Weighter +from .mean_teacher import MeanTeacher +from .weights_summary import WeightSummary +from .submodules_evaluation import SubModulesDistEvalHook # ,SubModulesEvalHook + + +__all__ = [ + "Weighter", + "MeanTeacher", + "SubModulesDistEvalHook", + "WeightSummary", +] diff --git a/ssod/utils/hooks/mean_teacher.py 
b/ssod/utils/hooks/mean_teacher.py new file mode 100644 index 0000000..be71dab --- /dev/null +++ b/ssod/utils/hooks/mean_teacher.py @@ -0,0 +1,64 @@ +from mmcv.parallel import is_module_wrapper +from mmcv.runner.hooks import HOOKS, Hook +from bisect import bisect_right +from ..logger import log_every_n + + +@HOOKS.register_module() +class MeanTeacher(Hook): + def __init__( + self, + momentum=0.999, + interval=1, + warm_up=100, + decay_intervals=None, + decay_factor=0.1, + ): + assert 0 <= momentum <= 1 + self.momentum = momentum + assert isinstance(interval, int) and interval > 0 + self.warm_up = warm_up + self.interval = interval + assert isinstance(decay_intervals, list) or decay_intervals is None + self.decay_intervals = decay_intervals + self.decay_factor = decay_factor + + def before_run(self, runner): + model = runner.model + if is_module_wrapper(model): + model = model.module + assert hasattr(model, "teacher") + assert hasattr(model, "student") + # only do it at the initial stage + if runner.iter == 0: + log_every_n("Clone all parameters of student to teacher...") + self.momentum_update(model, 0) + + def before_train_iter(self, runner): + """Update the EMA (teacher) parameters every `self.interval` iterations.""" + curr_step = runner.iter + if curr_step % self.interval != 0: + return + model = runner.model + if is_module_wrapper(model): + model = model.module + # Warm up the momentum to compensate for the instability at the + # beginning of training + momentum = min( + self.momentum, 1 - (1 + self.warm_up) / (curr_step + 1 + self.warm_up) + ) + runner.log_buffer.output["ema_momentum"] = momentum + self.momentum_update(model, momentum) + + def after_train_iter(self, runner): + curr_step = runner.iter + if self.decay_intervals is None: + return + self.momentum = 1 - (1 - self.momentum) * self.decay_factor ** bisect_right( + self.decay_intervals, curr_step + ) + + def momentum_update(self, model, momentum): + for (src_name, src_parm), (tgt_name, tgt_parm) in zip( + model.student.named_parameters(), model.teacher.named_parameters() + ): + tgt_parm.data.mul_(momentum).add_(src_parm.data, alpha=1 - momentum) diff --git a/ssod/utils/hooks/submodules_evaluation.py b/ssod/utils/hooks/submodules_evaluation.py new file mode 100644 index 0000000..13f5abe --- /dev/null +++ b/ssod/utils/hooks/submodules_evaluation.py @@ -0,0 +1,106 @@ +import os.path as osp + +import torch.distributed as dist +from mmcv.parallel import is_module_wrapper +from mmcv.runner.hooks import HOOKS +from mmdet.core import DistEvalHook +from torch.nn.modules.batchnorm import _BatchNorm + + +@HOOKS.register_module() +class SubModulesDistEvalHook(DistEvalHook): + def __init__(self, *args, evaluated_modules=None, **kwargs): + super().__init__(*args, **kwargs) + self.evaluated_modules = evaluated_modules + + def before_run(self, runner): + if is_module_wrapper(runner.model): + model = runner.model.module + else: + model = runner.model + assert hasattr(model, "submodules") + assert hasattr(model, "inference_on") + + def _do_evaluate(self, runner): + """Perform evaluation and save checkpoints.""" + # Synchronization of BatchNorm's buffers (running_mean + # and running_var) is not supported in the DDP of pytorch, + # which may cause inconsistent performance of models + # across ranks, so we broadcast BatchNorm's buffers + # of rank 0 to other ranks to avoid this.
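+ # Flow of this hook (a sketch): 1) sync BN buffers from rank 0; 2) for each + # entry of `model.submodules` (e.g. "teacher" and "student"), switch + # `inference_on` and run `multi_gpu_test`; 3) log every submodule's metrics + # under its own prefix (e.g. "teacher.bbox_mAP") and keep the best score + # for `save_best` checkpointing.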
+ + if self.broadcast_bn_buffer: + model = runner.model + for name, module in model.named_modules(): + if isinstance(module, _BatchNorm) and module.track_running_stats: + dist.broadcast(module.running_var, 0) + dist.broadcast(module.running_mean, 0) + + if not self._should_evaluate(runner): + return + # TODO: add `runner.mode = "val"` + + tmpdir = self.tmpdir + if tmpdir is None: + tmpdir = osp.join(runner.work_dir, ".eval_hook") + + if is_module_wrapper(runner.model): + model_ref = runner.model.module + else: + model_ref = runner.model + if not self.evaluated_modules: + submodules = model_ref.submodules + else: + submodules = self.evaluated_modules + key_scores = [] + from mmdet.apis import multi_gpu_test + + for submodule in submodules: + # switch the submodule used for inference + model_ref.inference_on = submodule + results = multi_gpu_test( + runner.model, + self.dataloader, + tmpdir=tmpdir, + gpu_collect=self.gpu_collect, + ) + if runner.rank == 0: + key_score = self.evaluate(runner, results, prefix=submodule) + if key_score is not None: + key_scores.append(key_score) + + if runner.rank == 0: + runner.log_buffer.ready = True + if len(key_scores) == 0: + key_scores = [None] + best_score = key_scores[0] + for key_score in key_scores: + if hasattr(self, "compare_func") and self.compare_func( + key_score, best_score + ): + best_score = key_score + + print("\n") + runner.log_buffer.output["eval_iter_num"] = len(self.dataloader) + if self.save_best: + self._save_ckpt(runner, best_score) + + def evaluate(self, runner, results, prefix=""): + """Evaluate the results. + Args: + runner (:obj:`mmcv.Runner`): The underlying training runner. + results (list): Output results. + prefix (str): Prefix prepended to each metric name in the log buffer. + """ + eval_res = self.dataloader.dataset.evaluate( + results, logger=runner.logger, **self.eval_kwargs + ) + for name, val in eval_res.items(): + runner.log_buffer.output[".".join([prefix, name])] = val + + if self.save_best is not None: + if self.key_indicator == "auto": + # infer from eval_results + self._init_rule(self.rule, list(eval_res.keys())[0]) + return eval_res[self.key_indicator] + + return None diff --git a/ssod/utils/hooks/weight_adjust.py b/ssod/utils/hooks/weight_adjust.py new file mode 100644 index 0000000..c874685 --- /dev/null +++ b/ssod/utils/hooks/weight_adjust.py @@ -0,0 +1,35 @@ +from mmcv.parallel import is_module_wrapper +from mmcv.runner.hooks import HOOKS, Hook +from bisect import bisect_right + + +@HOOKS.register_module() +class Weighter(Hook): + def __init__( + self, + steps=None, + vals=None, + name=None, + ): + self.steps = steps + self.vals = vals + self.name = name + if self.name is not None: + assert self.steps is not None + assert self.vals is not None + assert len(self.vals) == len(self.steps) + 1 + + def before_train_iter(self, runner): + curr_step = runner.iter + if self.name is None: + return + model = runner.model + if is_module_wrapper(model): + model = model.module + assert hasattr(model, self.name) + # negative entries count back from the end of training + self.steps = [s if s > 0 else runner.max_iters + s for s in self.steps] + runner.log_buffer.output[self.name] = self.vals[ + bisect_right(self.steps, curr_step) + ] + + setattr(model, self.name, runner.log_buffer.output[self.name]) diff --git a/ssod/utils/hooks/weights_summary.py b/ssod/utils/hooks/weights_summary.py new file mode 100644 index 0000000..11dbd61 --- /dev/null +++ b/ssod/utils/hooks/weights_summary.py @@ -0,0 +1,101 @@ +import os.path as osp + +import torch.distributed as dist +from mmcv.parallel import is_module_wrapper +from mmcv.runner.hooks import HOOKS, Hook +from ..logger import
get_root_logger +from prettytable import PrettyTable + + +def bool2str(input): + if input: + return "Y" + else: + return "N" + + +def unknown(): + return "-" + + +def shape_str(size): + size = [str(s) for s in size] + return "X".join(size) + + +def min_max_str(input): + return "Min:{:.3f} Max:{:.3f}".format(input.min(), input.max()) + + +def construct_params_dict(input): + assert isinstance(input, list) + param_dict = {} + for group in input: + if "name" in group: + param_dict[group["name"]] = group + return param_dict + + +def max_match_sub_str(strs, sub_str): + # return the longest name in `strs` that is a prefix of `sub_str` + matched = None + for child in strs: + if len(child) <= len(sub_str): + if child == sub_str: + return child + elif sub_str[: len(child)] == child: + if matched is None or len(matched) < len(child): + matched = child + return matched + + +def get_optim(optimizer, params_dict, name, key): + rel_name = max_match_sub_str(list(params_dict.keys()), name) + if rel_name is not None: + return params_dict[rel_name][key] + else: + if key in optimizer.defaults: + return optimizer.defaults[key] + + +@HOOKS.register_module() +class WeightSummary(Hook): + def before_run(self, runner): + if runner.rank != 0: + return + if is_module_wrapper(runner.model): + model = runner.model.module + else: + model = runner.model + weight_summaries = self.collect_model_info(model, optimizer=runner.optimizer) + logger = get_root_logger() + logger.info(weight_summaries) + + @staticmethod + def collect_model_info(model, optimizer=None, rich_text=False): + param_groups = None + if optimizer is not None: + param_groups = construct_params_dict(optimizer.param_groups) + + if not rich_text: + table = PrettyTable( + ["Name", "Optimized", "Shape", "Value Scale [Min,Max]", "Lr", "Wd"] + ) + for name, param in model.named_parameters(): + table.add_row( + [ + name, + bool2str(param.requires_grad), + shape_str(param.size()), + min_max_str(param), + unknown() + if param_groups is None + else get_optim(optimizer, param_groups, name, "lr"), + unknown() + if param_groups is None + else get_optim(optimizer, param_groups, name, "weight_decay"), + ] + ) + return "\n" + table.get_string(title="Model Information") + else: + # rich-text summaries are not implemented yet + pass diff --git a/ssod/utils/logger.py b/ssod/utils/logger.py new file mode 100644 index 0000000..c10e294 --- /dev/null +++ b/ssod/utils/logger.py @@ -0,0 +1,172 @@ +import logging +import os +import sys +from collections import Counter +from typing import Tuple + +import mmcv +import numpy as np +import torch +from mmcv.runner.dist_utils import get_dist_info +from mmcv.utils import get_logger +from mmdet.core.visualization import imshow_det_bboxes + +try: + import wandb +except ImportError: + wandb = None + +_log_counter = Counter() + + +def get_root_logger(log_file=None, log_level=logging.INFO): + """Get root logger. + + Args: + log_file (str, optional): File path of log. Defaults to None. + log_level (int, optional): The level of logger. + Defaults to logging.INFO.
+ + Returns: + :obj:`logging.Logger`: The obtained logger + """ + logger = get_logger(name="mmdet.ssod", log_file=log_file, log_level=log_level) + logger.propagate = False + return logger + + +def _find_caller(): + frame = sys._getframe(2) + while frame: + code = frame.f_code + if os.path.join("utils", "logger.") not in code.co_filename: + mod_name = frame.f_globals["__name__"] + if mod_name == "__main__": + mod_name = r"ssod" + return mod_name, (code.co_filename, frame.f_lineno, code.co_name) + frame = frame.f_back + + +def convert_box(tag, boxes, box_labels, class_labels, std, scores=None): + if isinstance(std, int): + std = [std, std] + if len(std) != 4: + std = std[::-1] * 2 + std = boxes.new_tensor(std).reshape(1, 4) + wandb_box = {} + boxes = boxes / std + boxes = boxes.detach().cpu().numpy().tolist() + box_labels = box_labels.detach().cpu().numpy().tolist() + class_labels = {k: class_labels[k] for k in range(len(class_labels))} + wandb_box["class_labels"] = class_labels + assert len(boxes) == len(box_labels) + if scores is not None: + scores = scores.detach().cpu().numpy().tolist() + box_data = [ + dict( + position=dict(minX=box[0], minY=box[1], maxX=box[2], maxY=box[3]), + class_id=label, + scores=dict(cls=scores[i]), + ) + for i, (box, label) in enumerate(zip(boxes, box_labels)) + ] + else: + box_data = [ + dict( + position=dict(minX=box[0], minY=box[1], maxX=box[2], maxY=box[3]), + class_id=label, + ) + for i, (box, label) in enumerate(zip(boxes, box_labels)) + ] + + wandb_box["box_data"] = box_data + return {tag: wandb.data_types.BoundingBoxes2D(wandb_box, tag)} + + +def color_transform(img_tensor, mean, std, to_rgb=False): + img_np = img_tensor.detach().cpu().numpy().transpose((1, 2, 0)).astype(np.float32) + return mmcv.imdenormalize(img_np, mean, std, to_bgr=not to_rgb) + + +def log_image_with_boxes( + tag: str, + image: torch.Tensor, + bboxes: torch.Tensor, + bbox_tag: str = None, + labels: torch.Tensor = None, + scores: torch.Tensor = None, + class_names: Tuple[str] = None, + filename: str = None, + img_norm_cfg: dict = None, + backend: str = "auto", + interval: int = 50, +): + rank, _ = get_dist_info() + if rank != 0: + return + _, key = _find_caller() + _log_counter[key] += 1 + if not (interval == 1 or _log_counter[key] % interval == 1): + return + if backend == "auto": + if wandb is None: + backend = "file" + else: + backend = "wandb" + + if backend == "wandb": + if wandb is None: + raise ImportError("wandb is not installed") + assert ( + wandb.run is not None + ), "wandb has not been initialized, call `wandb.init` first" + + elif backend != "file": + raise TypeError("backend must be file or wandb") + + if filename is None: + filename = f"{_log_counter[key]}.jpg" + if bbox_tag is None: + # default box-group name / output sub-directory + bbox_tag = "vis" + if img_norm_cfg is not None: + image = color_transform(image, **img_norm_cfg) + if labels is None: + labels = bboxes.new_zeros(bboxes.shape[0]).long() + class_names = ["foreground"] + if backend == "wandb": + im = {} + im["data_or_path"] = image + im["boxes"] = convert_box( + bbox_tag, bboxes, labels, class_names, scores=scores, std=image.shape[:2] + ) + wandb.log({tag: wandb.Image(**im)}, commit=False) + elif backend == "file": + root_dir = os.environ.get("WORK_DIR", ".") + + imshow_det_bboxes( + image, + bboxes.cpu().detach().numpy(), + labels.cpu().detach().numpy(), + class_names=class_names, + show=False, + out_file=os.path.join(root_dir, tag, bbox_tag, filename), + ) + else: + raise TypeError("backend must be file or wandb") + + +def log_every_n(msg: str, n:
int = 50, level: int = logging.DEBUG, backend="auto"): + """Log `msg` at most once every `n` calls from the same call site. + + Args: + msg (str | dict): The message to log; a dict is forwarded to wandb + when a run is active, otherwise it goes to the text logger. + n (int): Logging interval in number of calls. + level (int): Level used for the text logger. + backend (str): Reserved for future use; currently unused. + """ + caller_module, key = _find_caller() + _log_counter[key] += 1 + if n == 1 or _log_counter[key] % n == 1: + if isinstance(msg, dict) and (wandb is not None) and (wandb.run is not None): + wandb.log(msg, commit=False) + else: + get_root_logger().log(level, msg) diff --git a/ssod/utils/patch.py b/ssod/utils/patch.py new file mode 100644 index 0000000..ddad8ab --- /dev/null +++ b/ssod/utils/patch.py @@ -0,0 +1,81 @@ +import glob +import os +import os.path as osp +import shutil +import types + +from mmcv.runner import BaseRunner, EpochBasedRunner, IterBasedRunner +from mmcv.utils import Config + +from .signature import parse_method_info +from .vars import resolve + + +def find_latest_checkpoint(path, ext="pth"): + if not osp.exists(path): + return None + if osp.exists(osp.join(path, f"latest.{ext}")): + return osp.join(path, f"latest.{ext}") + + checkpoints = glob.glob(osp.join(path, f"*.{ext}")) + if len(checkpoints) == 0: + return None + latest = -1 + latest_path = None + for checkpoint in checkpoints: + count = int(osp.basename(checkpoint).split("_")[-1].split(".")[0]) + if count > latest: + latest = count + latest_path = checkpoint + return latest_path + + +def patch_checkpoint(runner: BaseRunner): + # patch save_checkpoint + old_save_checkpoint = runner.save_checkpoint + params = parse_method_info(old_save_checkpoint) + default_tmpl = params["filename_tmpl"].default + + def save_checkpoint(self, out_dir, **kwargs): + create_symlink = kwargs.get("create_symlink", True) + filename_tmpl = kwargs.get("filename_tmpl", default_tmpl) + # suppress the runner's own symlink; a real copy is made below instead + kwargs.update(create_symlink=False) + old_save_checkpoint(out_dir, **kwargs) + if create_symlink: + dst_file = osp.join(out_dir, "latest.pth") + if isinstance(self, EpochBasedRunner): + filename = filename_tmpl.format(self.epoch + 1) + elif isinstance(self, IterBasedRunner): + filename = filename_tmpl.format(self.iter + 1) + else: + raise NotImplementedError() + filepath = osp.join(out_dir, filename) + shutil.copy(filepath, dst_file) + + runner.save_checkpoint = types.MethodType(save_checkpoint, runner) + return runner + + +def patch_runner(runner): + runner = patch_checkpoint(runner) + return runner + + +def setup_env(cfg): + os.environ["WORK_DIR"] = cfg.work_dir + + +def patch_config(cfg): + + cfg_dict = super(Config, cfg).__getattribute__("_cfg_dict").to_dict() + cfg_dict["cfg_name"] = osp.splitext(osp.basename(cfg.filename))[0] + cfg_dict = resolve(cfg_dict) + cfg = Config(cfg_dict, filename=cfg.filename) + # wrap the model config for semi-supervised training + if cfg.get("semi_wrapper", None) is not None: + cfg.model = cfg.semi_wrapper + cfg.pop("semi_wrapper") + # enable environment variables (e.g. WORK_DIR) for other modules + setup_env(cfg) + return cfg diff --git a/ssod/utils/signature.py b/ssod/utils/signature.py new file mode 100644 index 0000000..abe83e0 --- /dev/null +++ b/ssod/utils/signature.py @@ -0,0 +1,7 @@ +import inspect + + +def parse_method_info(method): + sig = inspect.signature(method) + params = sig.parameters + return params diff --git a/ssod/utils/structure_utils.py b/ssod/utils/structure_utils.py new file mode 100644 index 0000000..fc48067 --- /dev/null +++ b/ssod/utils/structure_utils.py @@ -0,0 +1,153 @@ +import warnings +from collections import Counter +from collections.abc import Mapping, Sequence +from numbers import Number +from typing import Dict, List + +import numpy as np +import torch +from mmdet.core.mask.structures import BitmapMasks +from torch.nn import functional as F +
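+# Example of the grouping helpers defined below on a hypothetical batch dict +# (a minimal sketch): data = {"tag": ["sup", "unsup", "sup"], "img": [im0, im1, im2]}; +# dict_split(data, "tag") -> {"sup": {"tag": ["sup", "sup"], "img": [im0, im2]}, +# "unsup": {"tag": ["unsup"], "img": [im1]}}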
+_step_counter = Counter() + + +def list_concat(data_list: List[list]): + if isinstance(data_list[0], torch.Tensor): + return torch.cat(data_list) + else: + endpoint = [d for d in data_list[0]] + + for i in range(1, len(data_list)): + endpoint.extend(data_list[i]) + return endpoint + + +def sequence_concat(a, b): + if isinstance(a, Sequence) and isinstance(b, Sequence): + return a + b + else: + return None + + +def dict_concat(dicts: List[Dict[str, list]]): + return {k: list_concat([d[k] for d in dicts]) for k in dicts[0].keys()} + + +def dict_fuse(obj_list, reference_obj): + if isinstance(reference_obj, torch.Tensor): + return torch.stack(obj_list) + return obj_list + + +def dict_select(dict1: Dict[str, list], key: str, value: str): + flag = [v == value for v in dict1[key]] + return { + k: dict_fuse([vv for vv, ff in zip(v, flag) if ff], v) for k, v in dict1.items() + } + + +def dict_split(dict1, key): + group_names = list(set(dict1[key])) + dict_groups = {k: dict_select(dict1, key, k) for k in group_names} + + return dict_groups + + +def dict_sum(a, b): + if isinstance(a, dict): + assert isinstance(b, dict) + return {k: dict_sum(v, b[k]) for k, v in a.items()} + elif isinstance(a, list): + assert len(a) == len(b) + return [dict_sum(aa, bb) for aa, bb in zip(a, b)] + else: + return a + b + + +def zero_like(tensor_pack, prefix=""): + if isinstance(tensor_pack, Sequence): + return [zero_like(t) for t in tensor_pack] + elif isinstance(tensor_pack, Mapping): + return {prefix + k: zero_like(v) for k, v in tensor_pack.items()} + elif isinstance(tensor_pack, torch.Tensor): + return tensor_pack.new_zeros(tensor_pack.shape) + elif isinstance(tensor_pack, np.ndarray): + return np.zeros_like(tensor_pack) + else: + warnings.warn("Unexpected data type {}".format(type(tensor_pack))) + return 0 + + +def pad_stack(tensors, shape, pad_value=255): + tensors = torch.stack( + [ + F.pad( + tensor, + pad=[0, shape[1] - tensor.shape[1], 0, shape[0] - tensor.shape[0]], + value=pad_value, + ) + for tensor in tensors + ] + ) + return tensors + + +def result2bbox(result): + num_class = len(result) + + bbox = np.concatenate(result) + if bbox.shape[0] == 0: + label = np.zeros(0, dtype=np.uint8) + else: + label = np.concatenate( + [[i] * len(result[i]) for i in range(num_class) if len(result[i]) > 0] + ).reshape((-1,)) + return bbox, label + + +def result2mask(result): + num_class = len(result) + mask = [np.stack(result[i]) for i in range(num_class) if len(result[i]) > 0] + if len(mask) > 0: + mask = np.concatenate(mask) + else: + mask = np.zeros((0, 1, 1)) + return BitmapMasks(mask, mask.shape[1], mask.shape[2]), None + + +def sequence_mul(obj, multiplier): + if isinstance(obj, Sequence): + return [o * multiplier for o in obj] + else: + return obj * multiplier + + +def is_match(word, word_list): + for keyword in word_list: + if keyword in word: + return True + return False + + +def weighted_loss(loss: dict, weight, ignore_keys=[], warmup=0): + _step_counter["weight"] += 1 + lambda_weight = ( + lambda x: x * (_step_counter["weight"] - 1) / warmup + if _step_counter["weight"] <= warmup + else x + ) + if isinstance(weight, Mapping): + for k, v in weight.items(): + for name, loss_item in loss.items(): + if (k in name) and ("loss" in name): + loss[name] = sequence_mul(loss[name], lambda_weight(v)) + elif isinstance(weight, Number): + for name, loss_item in loss.items(): + if "loss" in name: + if not is_match(name, ignore_keys): + loss[name] = sequence_mul(loss[name], lambda_weight(weight)) + else: + loss[name] = 
sequence_mul(loss[name], 0.0) + else: + raise NotImplementedError() + return loss diff --git a/ssod/utils/vars.py b/ssod/utils/vars.py new file mode 100644 index 0000000..49da8fd --- /dev/null +++ b/ssod/utils/vars.py @@ -0,0 +1,35 @@ +import re +from typing import Union + +pattern = re.compile(r"\$\{[a-zA-Z\d_.]*\}") + + +def get_value(cfg: dict, chained_key: str): + keys = chained_key.split(".") + if len(keys) == 1: + return cfg[keys[0]] + else: + return get_value(cfg[keys[0]], ".".join(keys[1:])) + + +# e.g. resolve({"a": 1, "b": "${a}", "c": "x${a}"}) -> {"a": 1, "b": 1, "c": "x1"} +def resolve(cfg: Union[dict, list], base=None): + if base is None: + base = cfg + if isinstance(cfg, dict): + return {k: resolve(v, base) for k, v in cfg.items()} + elif isinstance(cfg, list): + return [resolve(v, base) for v in cfg] + elif isinstance(cfg, tuple): + return tuple([resolve(v, base) for v in cfg]) + elif isinstance(cfg, str): + # interpolate ${...} references in the string + var_names = pattern.findall(cfg) + if len(var_names) == 1 and len(cfg) == len(var_names[0]): + return get_value(base, var_names[0][2:-1]) + else: + vars = [get_value(base, name[2:-1]) for name in var_names] + for name, var in zip(var_names, vars): + cfg = cfg.replace(name, str(var)) + return cfg + else: + return cfg diff --git a/ssod/version.py b/ssod/version.py new file mode 100644 index 0000000..8d885d2 --- /dev/null +++ b/ssod/version.py @@ -0,0 +1,3 @@ +__version__ = "0.0.1" + +__all__ = ["__version__"] diff --git a/tools/dataset/prepare_coco_data.sh b/tools/dataset/prepare_coco_data.sh new file mode 100644 index 0000000..b2cdb32 --- /dev/null +++ b/tools/dataset/prepare_coco_data.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + + +help() { + echo "Usage: $0 [option...] download|conduct|full" + echo "download download the COCO dataset" + echo "conduct conduct the data split for semi-supervised training and evaluation" + echo "option:" + echo " -r, --root [PATH] select the root path of the dataset. The default dataset root is ssod/data" +} +download() { + mkdir -p coco + for split in train2017 val2017 unlabeled2017; + do + wget http://images.cocodataset.org/zips/${split}.zip; + unzip ${split}.zip + done + wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip + unzip annotations_trainval2017.zip +} +conduct() { + OFFSET=$RANDOM + for percent in 1 5 10; do + for fold in 1 2 3 4 5; do + python tools/dataset/semi_coco.py --percent ${percent} --seed ${fold} --data-dir "${data_root}"/coco --seed-offset ${OFFSET} + done + done +} + +data_root=data +ROOT=$(dirname "$0")/../.. + +cd "${ROOT}" + +case $1 in + -r | --root) + data_root=$2 + shift 2 + ;; +esac +mkdir -p ${data_root} +case $1 in + download) + cd ${data_root} + download + ;; + conduct) + conduct + ;; + full) + cd ${data_root} + download + cd .. + conduct + ;; + *) + help + exit 0 + ;; +esac diff --git a/tools/dataset/semi_coco.py b/tools/dataset/semi_coco.py new file mode 100644 index 0000000..3ef0b9b --- /dev/null +++ b/tools/dataset/semi_coco.py @@ -0,0 +1,121 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
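+ +# The split files produced below are written to +# <data-dir>/annotations/semi_supervised/instances_train<version>.<seed>@<percent>.json +# plus a matching "-unlabeled" variant, e.g. instances_train2017.1@10.json.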
+ +"""Generate labeled and unlabeled dataset for coco train. + +Example: +python tools/dataset/semi_coco.py --data-dir data/coco --percent 10 --seed 1 +""" + +import argparse +import numpy as np +import json +import os + + +def prepare_coco_data(seed=1, percent=10.0, version=2017, seed_offset=0): + """Prepare the COCO dataset for semi-supervised learning. + Args: + seed: random seed for the dataset split + percent: percentage of labeled data + version: COCO dataset version + seed_offset: offset added to the seed so different runs draw different splits + """ + + def _save_anno(name, images, annotations): + """Save annotation.""" + print( + ">> Processing dataset {}.json ({} images {} annotations)".format( + name, len(images), len(annotations) + ) + ) + new_anno = {} + new_anno["images"] = images + new_anno["annotations"] = annotations + new_anno["licenses"] = anno["licenses"] + new_anno["categories"] = anno["categories"] + new_anno["info"] = anno["info"] + path = "{}/{}".format(COCOANNODIR, "semi_supervised") + if not os.path.exists(path): + os.mkdir(path) + + with open( + "{root}/{folder}/{save_name}.json".format( + save_name=name, root=COCOANNODIR, folder="semi_supervised" + ), + "w", + ) as f: + json.dump(new_anno, f) + print( + ">> Data {}.json saved ({} images {} annotations)".format( + name, len(images), len(annotations) + ) + ) + + np.random.seed(seed + seed_offset) + COCOANNODIR = os.path.join(DATA_DIR, "annotations") + + anno = json.load( + open(os.path.join(COCOANNODIR, "instances_train{}.json".format(version))) + ) + + image_list = anno["images"] + labeled_tot = int(percent / 100.0 * len(image_list)) + labeled_ind = np.random.choice( + range(len(image_list)), size=labeled_tot, replace=False + ) + labeled_id = [] + labeled_images = [] + unlabeled_images = [] + labeled_ind = set(labeled_ind) + for i in range(len(image_list)): + if i in labeled_ind: + labeled_images.append(image_list[i]) + labeled_id.append(image_list[i]["id"]) + else: + unlabeled_images.append(image_list[i]) + + # get all annotations of labeled images + labeled_id = set(labeled_id) + labeled_annotations = [] + unlabeled_annotations = [] + for an in anno["annotations"]: + if an["image_id"] in labeled_id: + labeled_annotations.append(an) + else: + unlabeled_annotations.append(an) + + # save labeled and unlabeled + save_name = "instances_train{version}.{seed}@{tot}".format( + version=version, seed=seed, tot=int(percent) + ) + _save_anno(save_name, labeled_images, labeled_annotations) + save_name = "instances_train{version}.{seed}@{tot}-unlabeled".format( + version=version, seed=seed, tot=int(percent) + ) + _save_anno(save_name, unlabeled_images, unlabeled_annotations) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--data-dir", type=str) + parser.add_argument("--percent", type=float, default=10) + parser.add_argument("--version", type=int, default=2017) + parser.add_argument("--seed", type=int, help="seed", default=1) + parser.add_argument("--seed-offset", type=int, default=0) + args = parser.parse_args() + print(args) + DATA_DIR = args.data_dir + prepare_coco_data(args.seed, args.percent, args.version, args.seed_offset) diff --git a/tools/dataset/semi_coco.sh b/tools/dataset/semi_coco.sh new file mode 100644 index 0000000..2d28e67 --- /dev/null +++ b/tools/dataset/semi_coco.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -x +OFFSET=$RANDOM +for percent in 1 5 10; do + for fold in 1 2 3 4 5; do + python $(dirname "$0")/semi_coco.py --percent ${percent} --seed ${fold} --data-dir $1 --seed-offset ${OFFSET} + done +done diff --git a/tools/dist_test.sh b/tools/dist_test.sh new file mode 100755 index
0000000..3c74ec6 --- /dev/null +++ b/tools/dist_test.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +CONFIG=$1 +CHECKPOINT=$2 +GPUS=$3 +PORT=${PORT:-29500} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} diff --git a/tools/dist_train.sh b/tools/dist_train.sh new file mode 100755 index 0000000..5b43fff --- /dev/null +++ b/tools/dist_train.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +CONFIG=$1 +GPUS=$2 +PORT=${PORT:-29500} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} diff --git a/tools/dist_train_partially.sh b/tools/dist_train_partially.sh new file mode 100644 index 0000000..465d3b3 --- /dev/null +++ b/tools/dist_train_partially.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +set -x + +TYPE=$1 +FOLD=$2 +PERCENT=$3 +GPUS=$4 +PORT=${PORT:-29500} + + +# export so that the launched python processes can import ssod +export PYTHONPATH="$(dirname $0)/..":$PYTHONPATH + +if [[ ${TYPE} == 'baseline' ]]; then + python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/train.py configs/baseline/faster_rcnn_r50_caffe_fpn_coco_partial_180k.py --launcher pytorch \ + --cfg-options fold=${FOLD} percent=${PERCENT} ${@:5} +else + python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/train.py configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_180k.py --launcher pytorch \ + --cfg-options fold=${FOLD} percent=${PERCENT} ${@:5} +fi diff --git a/tools/misc/browse_dataset.py b/tools/misc/browse_dataset.py new file mode 100644 index 0000000..041f98c --- /dev/null +++ b/tools/misc/browse_dataset.py @@ -0,0 +1,173 @@ +import argparse +import os +from pathlib import Path + +import mmcv +import torch +from mmcv import Config, DictAction +from mmdet.core.utils import mask2ndarray +from mmdet.core.visualization import imshow_det_bboxes + +from ssod.datasets import build_dataset +from ssod.models.utils import Transform2D + + +def parse_args(): + parser = argparse.ArgumentParser(description="Browse a dataset") + parser.add_argument("config", help="train config file path") + parser.add_argument( + "--skip-type", + type=str, + nargs="+", + default=["DefaultFormatBundle", "Normalize", "Collect"], + help="skip some pipeline steps that are not needed for visualization", + ) + parser.add_argument( + "--output-dir", + default=None, + type=str, + help="If there is no display interface, you can save the visualizations here", + ) + parser.add_argument("--not-show", default=False, action="store_true") + parser.add_argument( + "--show-interval", type=float, default=2, help="the interval of show (s)" + ) + parser.add_argument( + "--cfg-options", + nargs="+", + action=DictAction, + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g.
key="[(a,b),(c,d)]" ' + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) + args = parser.parse_args() + return args + + +def remove_pipe(pipelines, skip_type): + if isinstance(pipelines, list): + new_pipelines = [] + for pipe in pipelines: + pipe = remove_pipe(pipe, skip_type) + if pipe is not None: + new_pipelines.append(pipe) + return new_pipelines + elif isinstance(pipelines, dict): + if pipelines["type"] in skip_type: + return None + elif pipelines["type"] == "MultiBranch": + new_pipelines = {} + for k, v in pipelines.items(): + if k != "type": + new_pipelines[k] = remove_pipe(v, skip_type) + else: + new_pipelines[k] = v + return new_pipelines + else: + return pipelines + else: + raise NotImplementedError() + + +def retrieve_data_cfg(config_path, skip_type, cfg_options): + cfg = Config.fromfile(config_path) + if cfg_options is not None: + cfg.merge_from_dict(cfg_options) + # import modules from string list. + if cfg.get("custom_imports", None): + from mmcv.utils import import_modules_from_strings + + import_modules_from_strings(**cfg["custom_imports"]) + train_data_cfg = cfg.data.train + while "dataset" in train_data_cfg: + train_data_cfg = train_data_cfg["dataset"] + train_data_cfg["pipeline"] = remove_pipe(train_data_cfg["pipeline"], skip_type) + return cfg + + +def main(): + args = parse_args() + cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options) + + dataset = build_dataset(cfg.data.train) + + progress_bar = mmcv.ProgressBar(len(dataset)) + + for item in dataset: + if not isinstance(item, list): + item = [item] + bboxes = [] + labels = [] + tran_mats = [] + out_shapes = [] + for it in item: + trans_matrix = it["transform_matrix"] + bbox = it["gt_bboxes"] + tran_mats.append(trans_matrix) + bboxes.append(bbox) + labels.append(it["gt_labels"]) + out_shapes.append(it["img_shape"]) + + filename = ( + os.path.join(args.output_dir, Path(it["filename"]).name) + if args.output_dir is not None + else None + ) + + gt_masks = it.get("gt_masks", None) + if gt_masks is not None: + gt_masks = mask2ndarray(gt_masks) + + imshow_det_bboxes( + it["img"], + it["gt_bboxes"], + it["gt_labels"], + gt_masks, + class_names=dataset.CLASSES, + show=not args.not_show, + wait_time=args.show_interval, + out_file=filename, + bbox_color=(255, 102, 61), + text_color=(255, 102, 61), + ) + + if len(tran_mats) == 2: + # check equality between different augmentation + transed_bboxes = Transform2D.transform_bboxes( + torch.from_numpy(bboxes[1]).float(), + torch.from_numpy(tran_mats[0]).float() + @ torch.from_numpy(tran_mats[1]).float().inverse(), + out_shapes[0], + ) + img = imshow_det_bboxes( + item[0]["img"], + item[0]["gt_bboxes"], + item[0]["gt_labels"], + class_names=dataset.CLASSES, + show=False, + wait_time=args.show_interval, + out_file=None, + bbox_color=(255, 102, 61), + text_color=(255, 102, 61), + ) + imshow_det_bboxes( + img, + transed_bboxes.numpy(), + labels[1], + class_names=dataset.CLASSES, + show=True, + wait_time=args.show_interval, + out_file=None, + bbox_color=(0, 0, 255), + text_color=(0, 0, 255), + thickness=5, + ) + + progress_bar.update() + + +if __name__ == "__main__": + main() diff --git a/tools/test.py b/tools/test.py new file mode 100644 index 0000000..f09e298 --- /dev/null +++ b/tools/test.py @@ -0,0 +1,261 @@ +import argparse +import os +import os.path as osp +import time +import warnings + +import mmcv +import torch +from mmcv import Config, DictAction +from mmcv.cnn import fuse_conv_bn +from mmcv.parallel import 
MMDataParallel, MMDistributedDataParallel +from mmcv.runner import get_dist_info, init_dist, load_checkpoint, wrap_fp16_model +from mmdet.apis import multi_gpu_test, single_gpu_test +from mmdet.datasets import build_dataloader, build_dataset, replace_ImageToTensor +from mmdet.models import build_detector + +from ssod.utils import patch_config + + +def parse_args(): + parser = argparse.ArgumentParser(description="MMDet test (and eval) a model") + parser.add_argument("config", help="test config file path") + parser.add_argument("checkpoint", help="checkpoint file") + parser.add_argument( + "--work-dir", + help="the directory to save the file containing evaluation metrics", + ) + parser.add_argument("--out", help="output result file in pickle format") + parser.add_argument( + "--fuse-conv-bn", + action="store_true", + help="Whether to fuse conv and bn; this will slightly increase " + "the inference speed", + ) + parser.add_argument( + "--format-only", + action="store_true", + help="Format the output results without performing evaluation. It is " + "useful when you want to format the result to a specific format and " + "submit it to the test server", + ) + parser.add_argument( + "--eval", + type=str, + nargs="+", + help='evaluation metrics, which depends on the dataset, e.g., "bbox",' + ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC', + ) + parser.add_argument("--show", action="store_true", help="show results") + parser.add_argument( + "--show-dir", help="directory where painted images will be saved" + ) + parser.add_argument( + "--show-score-thr", + type=float, + default=0.3, + help="score threshold (default: 0.3)", + ) + parser.add_argument( + "--gpu-collect", + action="store_true", + help="whether to use gpu to collect results.", + ) + parser.add_argument( + "--tmpdir", + help="tmp directory used for collecting results from multiple " + "workers, available when gpu-collect is not specified", + ) + parser.add_argument( + "--cfg-options", + nargs="+", + action=DictAction, + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g.
key="[(a,b),(c,d)]" ' + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) + parser.add_argument( + "--options", + nargs="+", + action=DictAction, + help="custom options for evaluation, the key-value pair in xxx=yyy " + "format will be kwargs for dataset.evaluate() function (deprecate), " + "change to --eval-options instead.", + ) + parser.add_argument( + "--eval-options", + nargs="+", + action=DictAction, + help="custom options for evaluation, the key-value pair in xxx=yyy " + "format will be kwargs for dataset.evaluate() function", + ) + parser.add_argument( + "--launcher", + choices=["none", "pytorch", "slurm", "mpi"], + default="none", + help="job launcher", + ) + parser.add_argument("--local_rank", type=int, default=0) + args = parser.parse_args() + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(args.local_rank) + + if args.options and args.eval_options: + raise ValueError( + "--options and --eval-options cannot be both " + "specified, --options is deprecated in favor of --eval-options" + ) + if args.options: + warnings.warn("--options is deprecated in favor of --eval-options") + args.eval_options = args.options + return args + + +def main(): + args = parse_args() + + assert args.out or args.eval or args.format_only or args.show or args.show_dir, ( + "Please specify at least one operation (save/eval/format/show the " + 'results / save the results) with the argument "--out", "--eval"' + ', "--format-only", "--show" or "--show-dir"' + ) + + if args.eval and args.format_only: + raise ValueError("--eval and --format_only cannot be both specified") + + if args.out is not None and not args.out.endswith((".pkl", ".pickle")): + raise ValueError("The output file must be a pkl file.") + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + # import modules from string list. + if cfg.get("custom_imports", None): + from mmcv.utils import import_modules_from_strings + + import_modules_from_strings(**cfg["custom_imports"]) + # set cudnn_benchmark + if cfg.get("cudnn_benchmark", False): + torch.backends.cudnn.benchmark = True + + cfg.model.pretrained = None + if cfg.model.get("neck"): + if isinstance(cfg.model.neck, list): + for neck_cfg in cfg.model.neck: + if neck_cfg.get("rfp_backbone"): + if neck_cfg.rfp_backbone.get("pretrained"): + neck_cfg.rfp_backbone.pretrained = None + elif cfg.model.neck.get("rfp_backbone"): + if cfg.model.neck.rfp_backbone.get("pretrained"): + cfg.model.neck.rfp_backbone.pretrained = None + + # in case the test dataset is concatenated + samples_per_gpu = 1 + if isinstance(cfg.data.test, dict): + cfg.data.test.test_mode = True + samples_per_gpu = cfg.data.test.pop("samples_per_gpu", 1) + if samples_per_gpu > 1: + # Replace 'ImageToTensor' to 'DefaultFormatBundle' + cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) + elif isinstance(cfg.data.test, list): + for ds_cfg in cfg.data.test: + ds_cfg.test_mode = True + samples_per_gpu = max( + [ds_cfg.pop("samples_per_gpu", 1) for ds_cfg in cfg.data.test] + ) + if samples_per_gpu > 1: + for ds_cfg in cfg.data.test: + ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline) + + # init distributed env first, since logger depends on the dist info. 
+ if args.launcher == "none": + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + + rank, _ = get_dist_info() + # allows not to create + if args.work_dir is not None and rank == 0: + cfg.work_dir = args.work_dir + mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) + timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) + json_file = osp.join(args.work_dir, f"eval_{timestamp}.json") + elif cfg.get("work_dir", None) is None: + cfg.work_dir = osp.join( + "./work_dirs", osp.splitext(osp.basename(args.config))[0] + ) + cfg = patch_config(cfg) + # build the dataloader + dataset = build_dataset(cfg.data.test) + data_loader = build_dataloader( + dataset, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False, + ) + + # build the model and load checkpoint + cfg.model.train_cfg = None + model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg")) + fp16_cfg = cfg.get("fp16", None) + if fp16_cfg is not None: + wrap_fp16_model(model) + checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu") + if args.fuse_conv_bn: + model = fuse_conv_bn(model) + # old versions did not save class info in checkpoints, this walkaround is + # for backward compatibility + if "CLASSES" in checkpoint.get("meta", {}): + model.CLASSES = checkpoint["meta"]["CLASSES"] + else: + model.CLASSES = dataset.CLASSES + + if not distributed: + model = MMDataParallel(model, device_ids=[0]) + outputs = single_gpu_test( + model, data_loader, args.show, args.show_dir, args.show_score_thr + ) + else: + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False, + ) + outputs = multi_gpu_test(model, data_loader, args.tmpdir, args.gpu_collect) + + rank, _ = get_dist_info() + if rank == 0: + if args.out: + print(f"\nwriting results to {args.out}") + mmcv.dump(outputs, args.out) + kwargs = {} if args.eval_options is None else args.eval_options + if args.format_only: + dataset.format_results(outputs, **kwargs) + if args.eval: + eval_kwargs = cfg.get("evaluation", {}).copy() + # hard-code way to remove EvalHook args + for key in [ + "type", + "interval", + "tmpdir", + "start", + "gpu_collect", + "save_best", + "rule", + ]: + eval_kwargs.pop(key, None) + eval_kwargs.update(dict(metric=args.eval, **kwargs)) + metric = dataset.evaluate(outputs, **eval_kwargs) + print(metric) + metric_dict = dict(config=args.config, metric=metric) + if args.work_dir is not None and rank == 0: + mmcv.dump(metric_dict, json_file) + + +if __name__ == "__main__": + main() diff --git a/tools/train.py b/tools/train.py new file mode 100644 index 0000000..2e741cf --- /dev/null +++ b/tools/train.py @@ -0,0 +1,198 @@ +import argparse +import copy +import os +import os.path as osp +import time +import warnings +from logging import log + +import mmcv +import torch +from mmcv import Config, DictAction +from mmcv.runner import get_dist_info, init_dist +from mmcv.utils import get_git_hash +from mmdet import __version__ +from mmdet.models import build_detector +from mmdet.utils import collect_env + +from ssod.apis import get_root_logger, set_random_seed, train_detector +from ssod.datasets import build_dataset +from ssod.utils import patch_config + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train a detector") + parser.add_argument("config", help="train config file path") + parser.add_argument("--work-dir", help="the dir to save logs and models") + 
parser.add_argument("--resume-from", help="the checkpoint file to resume from") + parser.add_argument( + "--no-validate", + action="store_true", + help="whether not to evaluate the checkpoint during training", + ) + group_gpus = parser.add_mutually_exclusive_group() + group_gpus.add_argument( + "--gpus", + type=int, + help="number of gpus to use " "(only applicable to non-distributed training)", + ) + group_gpus.add_argument( + "--gpu-ids", + type=int, + nargs="+", + help="ids of gpus to use " "(only applicable to non-distributed training)", + ) + parser.add_argument("--seed", type=int, default=None, help="random seed") + parser.add_argument( + "--deterministic", + action="store_true", + help="whether to set deterministic options for CUDNN backend.", + ) + parser.add_argument( + "--options", + nargs="+", + action=DictAction, + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + parser.add_argument( + "--cfg-options", + nargs="+", + action=DictAction, + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) + parser.add_argument( + "--launcher", + choices=["none", "pytorch", "slurm", "mpi"], + default="none", + help="job launcher", + ) + parser.add_argument("--local_rank", type=int, default=0) + args = parser.parse_args() + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(args.local_rank) + + if args.options and args.cfg_options: + raise ValueError( + "--options and --cfg-options cannot be both " + "specified, --options is deprecated in favor of --cfg-options" + ) + if args.options: + warnings.warn("--options is deprecated in favor of --cfg-options") + args.cfg_options = args.options + + return args + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + # import modules from string list. + if cfg.get("custom_imports", None): + from mmcv.utils import import_modules_from_strings + + import_modules_from_strings(**cfg["custom_imports"]) + # set cudnn_benchmark + if cfg.get("cudnn_benchmark", False): + torch.backends.cudnn.benchmark = True + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get("work_dir", None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join( + "./work_dirs", osp.splitext(osp.basename(args.config))[0] + ) + cfg = patch_config(cfg) + if args.resume_from is not None: + cfg.resume_from = args.resume_from + if args.gpu_ids is not None: + cfg.gpu_ids = args.gpu_ids + else: + cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) + + # init distributed env first, since logger depends on the dist info. 
+ if args.launcher == "none": + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + # re-set gpu_ids with distributed training mode + _, world_size = get_dist_info() + cfg.gpu_ids = range(world_size) + # create work_dir + mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) + # dump config + cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) + # init the logger before other steps + timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) + log_file = osp.join(cfg.work_dir, f"{timestamp}.log") + logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) + + # init the meta dict to record some important information such as + # environment info and seed, which will be logged + meta = dict() + # log env info + env_info_dict = collect_env() + env_info = "\n".join([(f"{k}: {v}") for k, v in env_info_dict.items()]) + dash_line = "-" * 60 + "\n" + logger.info(logger.handlers) + logger.info("Environment info:\n" + dash_line + env_info + "\n" + dash_line) + meta["env_info"] = env_info + meta["config"] = cfg.pretty_text + # log some basic info + logger.info(f"Distributed training: {distributed}") + logger.info(f"Config:\n{cfg.pretty_text}") + + # set random seeds + if args.seed is not None: + logger.info( + f"Set random seed to {args.seed}, " f"deterministic: {args.deterministic}" + ) + set_random_seed(args.seed, deterministic=args.deterministic) + cfg.seed = args.seed + meta["seed"] = args.seed + meta["exp_name"] = osp.basename(args.config) + + model = build_detector( + cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg") + ) + model.init_weights() + + datasets = [build_dataset(cfg.data.train)] + if len(cfg.workflow) == 2: + val_dataset = copy.deepcopy(cfg.data.val) + val_dataset.pipeline = cfg.data.train.pipeline + datasets.append(build_dataset(val_dataset)) + if cfg.checkpoint_config is not None: + # save mmdet version, config file content and class names in + # checkpoints as meta data + cfg.checkpoint_config.meta = dict( + mmdet_version=__version__ + get_git_hash()[:7], CLASSES=datasets[0].CLASSES + ) + # add an attribute for visualization convenience + model.CLASSES = datasets[0].CLASSES + train_detector( + model, + datasets, + cfg, + distributed=distributed, + validate=(not args.no_validate), + timestamp=timestamp, + meta=meta, + ) + + +if __name__ == "__main__": + main()