add ut and fix base inferencer
liuwenran committed Sep 8, 2023
1 parent 83a30e8 commit a08fd8c
Showing 4 changed files with 69 additions and 9 deletions.
8 changes: 4 additions & 4 deletions mmagic/apis/inferencers/base_mmagic_inferencer.py
@@ -130,10 +130,10 @@ def __call__(self, **kwargs) -> Union[Dict, List[Dict]]:
         Returns:
             Union[Dict, List[Dict]]: Results of inference pipeline.
         """
-        if 'extra_parameters' in kwargs.keys():
-            if 'infer_with_grad' in kwargs['extra_parameters'].keys():
-                if kwargs['extra_parameters']['infer_with_grad']:
-                    results = self.base_call(**kwargs)
+        if ('extra_parameters' in kwargs.keys()
+                and 'infer_with_grad' in kwargs['extra_parameters'].keys()
+                and kwargs['extra_parameters']['infer_with_grad']):
+            results = self.base_call(**kwargs)
         else:
             with torch.no_grad():
                 results = self.base_call(**kwargs)
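Note on the fix: in the old nested form the else was bound only to the outermost if, so a call that passed extra_parameters without a truthy infer_with_grad fell through every branch and left results unassigned. The flattened condition routes those calls to the torch.no_grad() path instead. A minimal standalone sketch of the new control flow, where call and base_call are hypothetical stand-ins for the inferencer's methods:

import torch

def call(base_call, **kwargs):
    # Keep gradients only when the caller explicitly opts in via
    # extra_parameters={'infer_with_grad': True}.
    if ('extra_parameters' in kwargs
            and kwargs['extra_parameters'].get('infer_with_grad')):
        return base_call(**kwargs)
    # Every other call, including extra_parameters without the flag,
    # now runs under no_grad rather than falling through unassigned.
    with torch.no_grad():
        return base_call(**kwargs)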
18 changes: 13 additions & 5 deletions mmagic/apis/inferencers/diffusers_pipeline_inferencer.py
@@ -14,16 +14,19 @@ class DiffusersPipelineInferencer(BaseMMagicInferencer):
     """inferencer that predicts with text2image models."""
 
     func_kwargs = dict(
-        preprocess=['text', 'negative_prompt'],
+        preprocess=[
+            'text', 'negative_prompt', 'num_inference_steps', 'height', 'width'
+        ],
         forward=[],
         visualize=['result_out_dir'],
         postprocess=[])
 
-    extra_parameters = dict(height=None, width=None)
-
     def preprocess(self,
                    text: InputsType,
-                   negative_prompt: InputsType = None) -> Dict:
+                   negative_prompt: InputsType = None,
+                   num_inference_steps: int = 20,
+                   height=None,
+                   width=None) -> Dict:
         """Process the inputs into a model-feedable format.
 
         Args:
@@ -35,9 +38,14 @@ def preprocess(self,
         """
         result = self.extra_parameters
         result['prompt'] = text
-
         if negative_prompt:
             result['negative_prompt'] = negative_prompt
+        if num_inference_steps:
+            result['num_inference_steps'] = num_inference_steps
+        if height:
+            result['height'] = height
+        if width:
+            result['width'] = width
 
         return result
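For context, a usage sketch of the widened preprocess interface; it mirrors the new unit test below, and the config values are the ones used there. num_inference_steps, height and width can now be passed straight through __call__ because they were added to func_kwargs['preprocess']:

from mmagic.apis.inferencers.diffusers_pipeline_inferencer import \
    DiffusersPipelineInferencer

cfg = dict(
    model=dict(
        type='DiffusionPipeline',
        from_pretrained='runwayml/stable-diffusion-v1-5'))
inferencer = DiffusersPipelineInferencer(cfg, None)

# The new kwargs are forwarded to preprocess() and from there into the
# diffusers pipeline call.
result = inferencer(
    text='Japanese anime style, girl',
    negative_prompt='bad face, bad hands',
    num_inference_steps=20,
    height=128,
    width=128)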
12 changes: 12 additions & 0 deletions mmagic/models/archs/wrapper.py
@@ -183,6 +183,18 @@ def to(
         torch_device: Optional[Union[str, torch.device]] = None,
         torch_dtype: Optional[torch.dtype] = None,
     ):
+        """Put the wrapped module to a device or convert it to torch_dtype.
+
+        There are two to() methods: nn.Module.to() and
+        diffusers.pipeline.to(). If both args are passed,
+        diffusers.pipeline.to() is called.
+
+        Args:
+            torch_device: The device to put the module on.
+            torch_dtype: The dtype to convert the module to.
+
+        Returns:
+            self: The wrapped module itself.
+        """
         if torch_dtype is None:
             self.model.to(torch_device)
         else:
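A minimal stand-in illustrating the dispatch the new docstring describes. The class is hypothetical and the else branch is an assumption, since the commit only shows the torch_dtype is None path; nn.Module.to(device, dtype) happens to accept the same argument pair as diffusers.pipeline.to():

import torch
import torch.nn as nn


class ModuleWrapper:
    # Hypothetical sketch of the wrapped-module dispatch described above.

    def __init__(self, model):
        self.model = model

    def to(self, torch_device=None, torch_dtype=None):
        if torch_dtype is None:
            # Device only: plain nn.Module.to(device) semantics.
            self.model.to(torch_device)
        else:
            # Assumed else branch: pass both device and dtype.
            self.model.to(torch_device, torch_dtype)
        return self


wrapper = ModuleWrapper(nn.Linear(2, 2)).to('cpu', torch.float32)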
40 changes: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+
+import pytest
+from mmengine.utils import digit_version
+from mmengine.utils.dl_utils import TORCH_VERSION
+
+from mmagic.apis.inferencers.diffusers_pipeline_inferencer import \
+    DiffusersPipelineInferencer
+from mmagic.utils import register_all_modules
+
+register_all_modules()
+
+
+@pytest.mark.skipif(
+    'win' in platform.system().lower()
+    or digit_version(TORCH_VERSION) <= digit_version('1.8.1'),
+    reason='skip on windows due to limited RAM '
+    'and get_submodule requires torch >= 1.9.0')
+def test_diffusers_pipeline_inferencer():
+    cfg = dict(
+        model=dict(
+            type='DiffusionPipeline',
+            from_pretrained='runwayml/stable-diffusion-v1-5'))
+
+    inferencer_instance = DiffusersPipelineInferencer(cfg, None)
+    text_prompts = 'Japanese anime style, girl'
+    negative_prompt = 'bad face, bad hands'
+    result = inferencer_instance(
+        text=text_prompts,
+        negative_prompt=negative_prompt,
+        height=128,
+        width=128)
+    assert result[1][0].size == (128, 128)
+
+
+def teardown_module():
+    import gc
+    gc.collect()
+    globals().clear()
+    locals().clear()
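One detail behind the assertion above: with height=128 and width=128, the check result[1][0].size == (128, 128) works because the returned images appear to be PIL images, and PIL reports Image.size as (width, height). A tiny self-contained illustration of that convention:

from PIL import Image

img = Image.new('RGB', (128, 128))  # Image.new takes (width, height)
assert img.size == (128, 128)       # .size reports (width, height) as well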
