Skip to content

Commit

Permalink
update ut
Browse files Browse the repository at this point in the history
  • Loading branch information
liuwenran committed Sep 11, 2023
1 parent 5333cc5 commit 0126598
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 19 deletions.
5 changes: 3 additions & 2 deletions mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DiffusersPipelineInferencer(BaseMMagicInferencer):
postprocess=[])

def preprocess(self,
text: InputsType,
text: InputsType = None,
negative_prompt: InputsType = None,
num_inference_steps: int = 20,
height=None,
Expand All @@ -37,7 +37,8 @@ def preprocess(self,
result(Dict): Results of preprocess.
"""
result = self.extra_parameters
result['prompt'] = text
if text:
result['prompt'] = text
if negative_prompt:
result['negative_prompt'] = negative_prompt
if num_inference_steps:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import platform

import pytest
import torch
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

Expand All @@ -21,24 +20,11 @@
def test_diffusers_pipeline_inferencer():
cfg = dict(
model=dict(
type='DiffusionPipeline',
from_pretrained='runwayml/stable-diffusion-v1-5'))
type='DiffusionPipeline', from_pretrained='google/ddpm-cat-256'))

inferencer_instance = DiffusersPipelineInferencer(cfg, None)

def mock_infer(*args, **kwargs):
return dict(samples=torch.randn(1, 3, 64, 64))

inferencer_instance.model.infer = mock_infer

text_prompts = 'Japanese anime style, girl'
negative_prompt = 'bad face, bad hands'
result = inferencer_instance(
text=text_prompts,
negative_prompt=negative_prompt,
height=64,
width=64)
assert result[1][0].size == (64, 64)
result = inferencer_instance()
assert result[1][0].size == (256, 256)


def teardown_module():
Expand Down

0 comments on commit 0126598

Please sign in to comment.