From a08fd8c2ddc30c5840dc2724aa657b27693e59ff Mon Sep 17 00:00:00 2001
From: liuwenran <448073814@qq.com>
Date: Fri, 8 Sep 2023 16:26:40 +0800
Subject: [PATCH] add ut and fix base inferencer

---
 .../inferencers/base_mmagic_inferencer.py     |  8 ++--
 .../diffusers_pipeline_inferencer.py          | 18 ++++++---
 mmagic/models/archs/wrapper.py                | 12 ++++++
 .../test_diffusers_pipeline_inferencer.py     | 40 +++++++++++++++++++
 4 files changed, 69 insertions(+), 9 deletions(-)
 create mode 100644 tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py

diff --git a/mmagic/apis/inferencers/base_mmagic_inferencer.py b/mmagic/apis/inferencers/base_mmagic_inferencer.py
index 99bfca2680..8fb93401e1 100644
--- a/mmagic/apis/inferencers/base_mmagic_inferencer.py
+++ b/mmagic/apis/inferencers/base_mmagic_inferencer.py
@@ -130,10 +130,10 @@ def __call__(self, **kwargs) -> Union[Dict, List[Dict]]:
         Returns:
             Union[Dict, List[Dict]]: Results of inference pipeline.
         """
-        if 'extra_parameters' in kwargs.keys():
-            if 'infer_with_grad' in kwargs['extra_parameters'].keys():
-                if kwargs['extra_parameters']['infer_with_grad']:
-                    results = self.base_call(**kwargs)
+        if ('extra_parameters' in kwargs.keys()
+                and 'infer_with_grad' in kwargs['extra_parameters'].keys()
+                and kwargs['extra_parameters']['infer_with_grad']):
+            results = self.base_call(**kwargs)
         else:
             with torch.no_grad():
                 results = self.base_call(**kwargs)
diff --git a/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
index 5af47ebf13..2143cd64db 100644
--- a/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
+++ b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
@@ -14,16 +14,19 @@ class DiffusersPipelineInferencer(BaseMMagicInferencer):
     """inferencer that predicts with text2image models."""
 
     func_kwargs = dict(
-        preprocess=['text', 'negative_prompt'],
+        preprocess=[
+            'text', 'negative_prompt', 'num_inference_steps', 'height', 'width'
+        ],
         forward=[],
         visualize=['result_out_dir'],
         postprocess=[])
 
-    extra_parameters = dict(height=None, width=None)
-
     def preprocess(self,
                    text: InputsType,
-                   negative_prompt: InputsType = None) -> Dict:
+                   negative_prompt: InputsType = None,
+                   num_inference_steps: int = 20,
+                   height=None,
+                   width=None) -> Dict:
         """Process the inputs into a model-feedable format.
 
         Args:
@@ -35,9 +38,14 @@ def preprocess(self,
         """
         result = self.extra_parameters
        result['prompt'] = text
-
         if negative_prompt:
             result['negative_prompt'] = negative_prompt
+        if num_inference_steps:
+            result['num_inference_steps'] = num_inference_steps
+        if height:
+            result['height'] = height
+        if width:
+            result['width'] = width
 
         return result
 
diff --git a/mmagic/models/archs/wrapper.py b/mmagic/models/archs/wrapper.py
index b985ac231d..ebf141f7c3 100644
--- a/mmagic/models/archs/wrapper.py
+++ b/mmagic/models/archs/wrapper.py
@@ -183,6 +183,18 @@ def to(
         torch_device: Optional[Union[str, torch.device]] = None,
         torch_dtype: Optional[torch.dtype] = None,
     ):
+        """Move the wrapped module to a device or convert it to torch_dtype.
+        There are two to() functions: one is nn.Module.to() and the other is
+        diffusers.pipeline.to(). If both arguments are passed,
+        diffusers.pipeline.to() is called.
+
+        Args:
+            torch_device: The device to move the module to.
+            torch_dtype: The dtype to convert the module to.
+
+        Returns:
+            self: The wrapped module itself.
+        """
         if torch_dtype is None:
             self.model.to(torch_device)
         else:
diff --git a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
new file mode 100644
index 0000000000..a16fcd9c74
--- /dev/null
+++ b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+import pytest
+from mmengine.utils import digit_version
+from mmengine.utils.dl_utils import TORCH_VERSION
+
+from mmagic.apis.inferencers.diffusers_pipeline_inferencer import \
+    DiffusersPipelineInferencer
+from mmagic.utils import register_all_modules
+
+register_all_modules()
+
+
+@pytest.mark.skipif(
+    'win' in platform.system().lower()
+    or digit_version(TORCH_VERSION) <= digit_version('1.8.1'),
+    reason='skip on windows due to limited RAM '
+    'and get_submodule requires torch >= 1.9.0')
+def test_diffusers_pipeline_inferencer():
+    cfg = dict(
+        model=dict(
+            type='DiffusionPipeline',
+            from_pretrained='runwayml/stable-diffusion-v1-5'))
+
+    inferencer_instance = DiffusersPipelineInferencer(cfg, None)
+    text_prompts = 'Japanese anime style, girl'
+    negative_prompt = 'bad face, bad hands'
+    result = inferencer_instance(
+        text=text_prompts,
+        negative_prompt=negative_prompt,
+        height=128,
+        width=128)
+    assert result[1][0].size == (128, 128)
+
+
+def teardown_module():
+    import gc
+    gc.collect()
+    globals().clear()
+    locals().clear()