Skip to content

Commit

Permalink
add image demo script
Browse files Browse the repository at this point in the history
  • Loading branch information
MendelXu committed Sep 16, 2021
1 parent 99630f9 commit dd09b38
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 36 deletions.
29 changes: 25 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ make install
# train2017/
# val2017/
# unlabeled2017/
# annotations/
# annotations/
ln -s ${YOUR_DATA} data
bash tools/dataset/prepare_coco_data.sh conduct

Expand Down Expand Up @@ -141,13 +141,34 @@ bash tools/dist_train.sh configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe



### Inference
### Evaluation
```
bash tools/dist_test.sh <CONFIG_FILE_PATH> <CHECKPOINT_PATH> <NUM_GPUS> --eval bbox --cfg-options model.test_cfg.rcnn.score_thr=<THR>
```
### Inference
To inference with trained model and visualize the detection results:

[1] [A Simple Semi-Supervised Learning Framework for Object Detection](https://arxiv.org/pdf/2005.04757.pdf)
```shell script
# [IMAGE_FILE_PATH]: the path of your image file in local file system
# [CONFIG_FILE]: the path of a confile file
# [CHECKPOINT_PATH]: the path of a trained model related to provided confilg file.
# [OUTPUT_PATH]: the directory to save detection result
python demo/image_demo.py [IMAGE_FILE_PATH] [CONFIG_FILE] [CHECKPOINT_PATH] --output [OUTPUT_PATH]
```
For example:
- Inference on single image with provided `R50` model:
```shell script
python demo/image_demo.py /tmp/tmp.png configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py work_dirs/downloaded.model --output work_dirs/
```

After the program completes, a image with the same name as input will be saved to `work_dirs`

[2] [Instant-Teaching: An End-to-End Semi-SupervisedObject Detection Framework](https://arxiv.org/pdf/2103.11402.pdf)
- Inference on many images with provided `R50` model:
```shell script
python demo/image_demo.py '/tmp/*.jpg' configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py work_dirs/downloaded.model --output work_dirs/
```

[1] [A Simple Semi-Supervised Learning Framework for Object Detection](https://arxiv.org/pdf/2005.04757.pdf)


[2] [Instant-Teaching: An End-to-End Semi-SupervisedObject Detection Framework](https://arxiv.org/pdf/2103.11402.pdf)
89 changes: 89 additions & 0 deletions demo/image_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from thirdparty/mmdetection/demo/image_demo.py
import asyncio
import glob
import os
from argparse import ArgumentParser

from mmcv import Config
from mmdet.apis import async_inference_detector, inference_detector, show_result_pyplot

from ssod.apis.inference import init_detector, save_result
from ssod.utils import patch_config


def parse_args():
parser = ArgumentParser()
parser.add_argument("img", help="Image file")
parser.add_argument("config", help="Config file")
parser.add_argument("checkpoint", help="Checkpoint file")
parser.add_argument("--device", default="cuda:0", help="Device used for inference")
parser.add_argument(
"--score-thr", type=float, default=0.3, help="bbox score threshold"
)
parser.add_argument(
"--async-test",
action="store_true",
help="whether to set async options for async inference.",
)
parser.add_argument(
"--output",
type=str,
default=None,
help="specify the directory to save visualization results.",
)
args = parser.parse_args()
return args


def main(args):
cfg = Config.fromfile(args.config)
# Not affect anything, just avoid index error
cfg.work_dir = "./work_dirs"
cfg = patch_config(cfg)
# build the model from a config file and a checkpoint file
model = init_detector(cfg, args.checkpoint, device=args.device)
imgs = glob.glob(args.img)
for img in imgs:
# test a single image
result = inference_detector(model, img)
# show the results
if args.output is None:
show_result_pyplot(model, img, result, score_thr=args.score_thr)
else:
out_file_path = os.path.join(args.output, os.path.basename(img))
print(f"Save results to {out_file_path}")
save_result(
model, img, result, score_thr=args.score_thr, out_file=out_file_path
)


async def async_main(args):
cfg = Config.fromfile(args.config)
# Not affect anything, just avoid index error
cfg.work_dir = "./work_dirs"
cfg = patch_config(cfg)
# build the model from a config file and a checkpoint file
model = init_detector(cfg, args.checkpoint, device=args.device)
# test a single image
args.img = glob.glob(args.img)
tasks = asyncio.create_task(async_inference_detector(model, args.img))
result = await asyncio.gather(tasks)
# show the results
for img, pred in zip(args.img, result):
if args.output is None:
show_result_pyplot(model, img, pred, score_thr=args.score_thr)
else:
out_file_path = os.path.join(args.output, os.path.basename(img))
print(f"Save results to {out_file_path}")
save_result(
model, img, pred, score_thr=args.score_thr, out_file=out_file_path
)


if __name__ == "__main__":
args = parse_args()
if args.async_test:
asyncio.run(async_main(args))
else:
main(args)
81 changes: 81 additions & 0 deletions ssod/apis/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
from mmcv.runner import load_checkpoint

from mmdet.core import get_classes
from mmdet.models import build_detector


def init_detector(config, checkpoint=None, device="cuda:0", cfg_options=None):
"""Initialize a detector from config file.
Args:
config (str or :obj:`mmcv.Config`): Config file path or the config
object.
checkpoint (str, optional): Checkpoint path. If left as None, the model
will not load any weights.
cfg_options (dict): Options to override some settings in the used
config.
Returns:
nn.Module: The constructed detector.
"""
if isinstance(config, str):
config = mmcv.Config.fromfile(config)
elif not isinstance(config, mmcv.Config):
raise TypeError(
"config must be a filename or Config object, " f"but got {type(config)}"
)
if cfg_options is not None:
config.merge_from_dict(cfg_options)
config.model.train_cfg = None

if hasattr(config.model, "model"):
config.model.model.pretrained = None
config.model.model.train_cfg = None
else:
config.model.pretrained = None

model = build_detector(config.model, test_cfg=config.get("test_cfg"))
if checkpoint is not None:
map_loc = "cpu" if device == "cpu" else None
checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
if "CLASSES" in checkpoint.get("meta", {}):
model.CLASSES = checkpoint["meta"]["CLASSES"]
else:
warnings.simplefilter("once")
warnings.warn(
"Class names are not saved in the checkpoint's "
"meta data, use COCO classes by default."
)
model.CLASSES = get_classes("coco")
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
return model


def save_result(model, img, result, score_thr=0.3, out_file="res.png"):
"""Save the detection results on the image.
Args:
model (nn.Module): The loaded detector.
img (str or np.ndarray): Image filename or loaded image.
result (tuple[list] or list): The detection result, can be either
(bbox, segm) or just bbox.
score_thr (float): The threshold to visualize the bboxes and masks.
out_file (str): Specifies where to save the visualization result
"""
if hasattr(model, "module"):
model = model.module
model.show_result(
img,
result,
score_thr=score_thr,
show=False,
out_file=out_file,
bbox_color=(72, 101, 241),
text_color=(72, 101, 241),
)
39 changes: 7 additions & 32 deletions ssod/models/multi_stream_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ def model(self, **kwargs) -> TwoStageDetector:
else:
model: TwoStageDetector = getattr(self, self.inference_on)
return model
def freeze(self,model_ref:str):

def freeze(self, model_ref: str):
assert model_ref in self.submodules
model = getattr(self, model_ref)
model.eval()
for param in model.parameters():
param.requires_grad=False
param.requires_grad = False

def forward_test(self, imgs, img_metas, **kwargs):

return self.model(**kwargs).forward_test(imgs, img_metas, **kwargs)
Expand All @@ -52,33 +54,6 @@ def simple_test(self, img, img_metas, **kwargs):
async def async_simple_test(self, img, img_metas, **kwargs):
return self.model(**kwargs).async_simple_test(img, img_metas, **kwargs)

def show_result(
self,
img,
result,
score_thr=0.3,
bbox_color=(72, 101, 241),
text_color=(72, 101, 241),
mask_color=None,
thickness=2,
font_size=13,
win_name="",
show=False,
wait_time=0,
out_file=None,
):
return self.model().show_result(
self,
img,
result,
score_thr,
bbox_color,
text_color,
mask_color,
thickness,
font_size,
win_name,
show,
wait_time,
out_file,
)
def show_result(self, *args, **kwargs):
self.model().CLASSES = self.CLASSES
return self.model().show_result(*args, **kwargs)

0 comments on commit dd09b38

Please sign in to comment.