From dd09b38068942981547328fcf60c980a650dc57c Mon Sep 17 00:00:00 2001
From: MendelXu
Date: Thu, 16 Sep 2021 11:54:56 +0800
Subject: [PATCH] add image demo script

---
 README.md                            | 29 +++++++--
 demo/image_demo.py                   | 89 ++++++++++++++++++++++++++++
 ssod/apis/inference.py               | 81 +++++++++++++++++++++++++
 ssod/models/multi_stream_detector.py | 39 +++---------
 4 files changed, 202 insertions(+), 36 deletions(-)
 create mode 100644 demo/image_demo.py
 create mode 100644 ssod/apis/inference.py

diff --git a/README.md b/README.md
index 1a8cfad..27ccea3 100644
--- a/README.md
+++ b/README.md
@@ -104,7 +104,7 @@ make install
 #  train2017/
 #  val2017/
 #  unlabeled2017/
-#  annotations/ 
+#  annotations/
 ln -s ${YOUR_DATA} data
 bash tools/dataset/prepare_coco_data.sh conduct
 
@@ -141,13 +141,34 @@ bash tools/dist_train.sh configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe
 
-### Inference
+### Evaluation
 ```
 bash tools/dist_test.sh <CONFIG_FILE_PATH> <CHECKPOINT_PATH> --eval bbox --cfg-options model.test_cfg.rcnn.score_thr=<THR>
 ```
+### Inference
+  To run inference with a trained model and visualize the detection results:
 
-[1] [A Simple Semi-Supervised Learning Framework for Object Detection](https://arxiv.org/pdf/2005.04757.pdf)
+  ```shell script
+  # [IMAGE_FILE_PATH]: the path of your image file on the local file system
+  # [CONFIG_FILE]: the path of a config file
+  # [CHECKPOINT_PATH]: the path of a trained model corresponding to the provided config file
+  # [OUTPUT_PATH]: the directory in which to save the detection results
+  python demo/image_demo.py [IMAGE_FILE_PATH] [CONFIG_FILE] [CHECKPOINT_PATH] --output [OUTPUT_PATH]
+  ```
+  For example:
+  - Inference on a single image with the provided `R50` model:
+  ```shell script
+  python demo/image_demo.py /tmp/tmp.png configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py work_dirs/downloaded.model --output work_dirs/
+  ```
+  After the program completes, an image with the same name as the input will be saved to `work_dirs`.
 
-[2] [Instant-Teaching: An End-to-End Semi-SupervisedObject Detection Framework](https://arxiv.org/pdf/2103.11402.pdf)
+  - Inference on multiple images with the provided `R50` model:
+  ```shell script
+  python demo/image_demo.py '/tmp/*.jpg' configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py work_dirs/downloaded.model --output work_dirs/
+  ```
+[1] [A Simple Semi-Supervised Learning Framework for Object Detection](https://arxiv.org/pdf/2005.04757.pdf)
+
+
+[2] [Instant-Teaching: An End-to-End Semi-Supervised Object Detection Framework](https://arxiv.org/pdf/2103.11402.pdf)
diff --git a/demo/image_demo.py b/demo/image_demo.py
new file mode 100644
index 0000000..a04971b
--- /dev/null
+++ b/demo/image_demo.py
@@ -0,0 +1,89 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from thirdparty/mmdetection/demo/image_demo.py
+import asyncio
+import glob
+import os
+from argparse import ArgumentParser
+
+from mmcv import Config
+from mmdet.apis import async_inference_detector, inference_detector, show_result_pyplot
+
+from ssod.apis.inference import init_detector, save_result
+from ssod.utils import patch_config
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument("img", help="Image file")
+    parser.add_argument("config", help="Config file")
+    parser.add_argument("checkpoint", help="Checkpoint file")
+    parser.add_argument("--device", default="cuda:0", help="Device used for inference")
+    parser.add_argument(
+        "--score-thr", type=float, default=0.3, help="bbox score threshold"
+    )
+    parser.add_argument(
+        "--async-test",
+        action="store_true",
+        help="whether to set async options for async inference.",
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        default=None,
+        help="specify the directory to save visualization results.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main(args):
+    cfg = Config.fromfile(args.config)
+    # Does not affect inference; set only to avoid an index error when patching the config
+    cfg.work_dir = "./work_dirs"
+    cfg = patch_config(cfg)
+    # build the model from a config file and a checkpoint file
+    model = init_detector(cfg, args.checkpoint, device=args.device)
+    imgs = glob.glob(args.img)
+    for img in imgs:
+        # test a single image
+        result = inference_detector(model, img)
+        # show the results
+        if args.output is None:
+            show_result_pyplot(model, img, result, score_thr=args.score_thr)
+        else:
+            out_file_path = os.path.join(args.output, os.path.basename(img))
+            print(f"Save results to {out_file_path}")
+            save_result(
+                model, img, result, score_thr=args.score_thr, out_file=out_file_path
+            )
+
+
+async def async_main(args):
+    cfg = Config.fromfile(args.config)
+    # Does not affect inference; set only to avoid an index error when patching the config
+    cfg.work_dir = "./work_dirs"
+    cfg = patch_config(cfg)
+    # build the model from a config file and a checkpoint file
+    model = init_detector(cfg, args.checkpoint, device=args.device)
+    # expand the glob pattern into a list of image paths
+    args.img = glob.glob(args.img)
+    tasks = asyncio.create_task(async_inference_detector(model, args.img))
+    result = (await asyncio.gather(tasks))[0]  # unpack the single task's results
+    # show the results
+    for img, pred in zip(args.img, result):
+        if args.output is None:
+            show_result_pyplot(model, img, pred, score_thr=args.score_thr)
+        else:
+            out_file_path = os.path.join(args.output, os.path.basename(img))
+            print(f"Save results to {out_file_path}")
+            save_result(
+                model, img, pred, score_thr=args.score_thr, out_file=out_file_path
+            )
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    if args.async_test:
+        asyncio.run(async_main(args))
+    else:
+        main(args)
diff --git a/ssod/apis/inference.py b/ssod/apis/inference.py
new file mode 100644
index 0000000..6761dd0
--- /dev/null
+++ b/ssod/apis/inference.py
@@ -0,0 +1,81 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import mmcv
+from mmcv.runner import load_checkpoint
+
+from mmdet.core import get_classes
+from mmdet.models import build_detector
+
+
+def init_detector(config, checkpoint=None, device="cuda:0", cfg_options=None):
+    """Initialize a detector from config file.
+
+    Args:
+        config (str or :obj:`mmcv.Config`): Config file path or the config
+            object.
+        checkpoint (str, optional): Checkpoint path. If left as None, the model
+            will not load any weights.
+        cfg_options (dict): Options to override some settings in the used
+            config.
+
+    Returns:
+        nn.Module: The constructed detector.
+    """
+    if isinstance(config, str):
+        config = mmcv.Config.fromfile(config)
+    elif not isinstance(config, mmcv.Config):
+        raise TypeError(
+            "config must be a filename or Config object, " f"but got {type(config)}"
+        )
+    if cfg_options is not None:
+        config.merge_from_dict(cfg_options)
+    config.model.train_cfg = None
+
+    if hasattr(config.model, "model"):
+        config.model.model.pretrained = None
+        config.model.model.train_cfg = None
+    else:
+        config.model.pretrained = None
+
+    model = build_detector(config.model, test_cfg=config.get("test_cfg"))
+    if checkpoint is not None:
+        map_loc = "cpu" if device == "cpu" else None
+        checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
+        if "CLASSES" in checkpoint.get("meta", {}):
+            model.CLASSES = checkpoint["meta"]["CLASSES"]
+        else:
+            warnings.simplefilter("once")
+            warnings.warn(
+                "Class names are not saved in the checkpoint's "
+                "meta data, use COCO classes by default."
+            )
+            model.CLASSES = get_classes("coco")
+    model.cfg = config  # save the config in the model for convenience
+    model.to(device)
+    model.eval()
+    return model
+
+
+def save_result(model, img, result, score_thr=0.3, out_file="res.png"):
+    """Save the detection results on the image.
+
+    Args:
+        model (nn.Module): The loaded detector.
+        img (str or np.ndarray): Image filename or loaded image.
+        result (tuple[list] or list): The detection result, can be either
+            (bbox, segm) or just bbox.
+        score_thr (float): The threshold to visualize the bboxes and masks.
+        out_file (str): Specifies where to save the visualization result
+    """
+    if hasattr(model, "module"):
+        model = model.module
+    model.show_result(
+        img,
+        result,
+        score_thr=score_thr,
+        show=False,
+        out_file=out_file,
+        bbox_color=(72, 101, 241),
+        text_color=(72, 101, 241),
+    )
diff --git a/ssod/models/multi_stream_detector.py b/ssod/models/multi_stream_detector.py
index ae153b3..38e8b29 100644
--- a/ssod/models/multi_stream_detector.py
+++ b/ssod/models/multi_stream_detector.py
@@ -24,12 +24,14 @@ def model(self, **kwargs) -> TwoStageDetector:
         else:
             model: TwoStageDetector = getattr(self, self.inference_on)
         return model
-    def freeze(self,model_ref:str):
+
+    def freeze(self, model_ref: str):
         assert model_ref in self.submodules
         model = getattr(self, model_ref)
         model.eval()
         for param in model.parameters():
-            param.requires_grad=False
+            param.requires_grad = False
+
     def forward_test(self, imgs, img_metas, **kwargs):
         return self.model(**kwargs).forward_test(imgs, img_metas, **kwargs)
@@ -52,33 +54,6 @@ def simple_test(self, img, img_metas, **kwargs):
     async def async_simple_test(self, img, img_metas, **kwargs):
         return self.model(**kwargs).async_simple_test(img, img_metas, **kwargs)
-    def show_result(
-        self,
-        img,
-        result,
-        score_thr=0.3,
-        bbox_color=(72, 101, 241),
-        text_color=(72, 101, 241),
-        mask_color=None,
-        thickness=2,
-        font_size=13,
-        win_name="",
-        show=False,
-        wait_time=0,
-        out_file=None,
-    ):
-        return self.model().show_result(
-            self,
-            img,
-            result,
-            score_thr,
-            bbox_color,
-            text_color,
-            mask_color,
-            thickness,
-            font_size,
-            win_name,
-            show,
-            wait_time,
-            out_file,
-        )
+    def show_result(self, *args, **kwargs):
+        self.model().CLASSES = self.CLASSES
+        return self.model().show_result(*args, **kwargs)
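Usage note: beyond the `demo/image_demo.py` CLI, the helpers this patch adds in `ssod/apis/inference.py` can be called directly from Python. The snippet below is a minimal sketch that mirrors the demo script's flow; the config, checkpoint, and image paths are the placeholder examples from the README, not files shipped with the repository.

```python
# Minimal programmatic sketch of the new inference helpers (paths are placeholders).
from mmcv import Config
from mmdet.apis import inference_detector

from ssod.apis.inference import init_detector, save_result
from ssod.utils import patch_config

cfg = Config.fromfile(
    "configs/soft_teacher/soft_teacher_faster_rcnn_r50_caffe_fpn_coco_full_720k.py"
)
cfg.work_dir = "./work_dirs"  # not used for inference; set only so patch_config succeeds
cfg = patch_config(cfg)

# Build the detector and load the trained weights.
model = init_detector(cfg, "work_dirs/downloaded.model", device="cuda:0")

# Run single-image inference and write the visualization to disk.
result = inference_detector(model, "/tmp/tmp.png")
save_result(model, "/tmp/tmp.png", result, score_thr=0.3, out_file="work_dirs/tmp.png")
```

Because `save_result` calls `model.show_result` with `show=False` and an `out_file`, the drawn boxes are written to disk and no display is needed.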