diff --git a/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py index 2143cd64db..573c0750f6 100644 --- a/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py +++ b/mmagic/apis/inferencers/diffusers_pipeline_inferencer.py @@ -22,7 +22,7 @@ class DiffusersPipelineInferencer(BaseMMagicInferencer): postprocess=[]) def preprocess(self, - text: InputsType, + text: InputsType = None, negative_prompt: InputsType = None, num_inference_steps: int = 20, height=None, @@ -37,7 +37,8 @@ def preprocess(self, result(Dict): Results of preprocess. """ result = self.extra_parameters - result['prompt'] = text + if text: + result['prompt'] = text if negative_prompt: result['negative_prompt'] = negative_prompt if num_inference_steps: diff --git a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py index f2ff296e18..2e9285a7a0 100644 --- a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py +++ b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py @@ -2,7 +2,6 @@ import platform import pytest -import torch from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION @@ -21,24 +20,11 @@ def test_diffusers_pipeline_inferencer(): cfg = dict( model=dict( - type='DiffusionPipeline', - from_pretrained='runwayml/stable-diffusion-v1-5')) + type='DiffusionPipeline', from_pretrained='google/ddpm-cat-256')) inferencer_instance = DiffusersPipelineInferencer(cfg, None) - - def mock_infer(*args, **kwargs): - return dict(samples=torch.randn(1, 3, 64, 64)) - - inferencer_instance.model.infer = mock_infer - - text_prompts = 'Japanese anime style, girl' - negative_prompt = 'bad face, bad hands' - result = inferencer_instance( - text=text_prompts, - negative_prompt=negative_prompt, - height=64, - width=64) - assert result[1][0].size == (64, 64) + result = inferencer_instance() + assert result[1][0].size == (256, 256) def teardown_module():