diff --git a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
index 98b49f6e7c..dbfeb5c6e8 100644
--- a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
+++ b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
@@ -2,6 +2,7 @@
 import platform
 
 import pytest
+import torch
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
 
@@ -24,6 +25,17 @@ def test_diffusers_pipeline_inferencer():
             from_pretrained='runwayml/stable-diffusion-v1-5'))
 
     inferencer_instance = DiffusersPipelineInferencer(cfg, None)
+
+    def mock_encode_prompt(prompt, do_classifier_free_guidance,
+                           num_images_per_prompt, *args, **kwargs):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+        batch_size *= num_images_per_prompt
+        if do_classifier_free_guidance:
+            batch_size *= 2
+        return torch.randn(batch_size, 5, 16)  # 2 for cfg
+
+    inferencer_instance.model._encode_prompt = mock_encode_prompt
+
     text_prompts = 'Japanese anime style, girl'
     negative_prompt = 'bad face, bad hands'
     result = inferencer_instance(
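
Side note on the mock: `_encode_prompt` is patched at the instance level so the test never runs the real text encoder; only the batch dimension of the returned embeddings has to line up with what the pipeline expects. Below is a minimal standalone sketch of the same idea (`DummyPipeline` and the `(5, 16)` embedding shape are invented for illustration, not taken from the patch):

```python
import torch


class DummyPipeline:
    """Hypothetical stand-in for the wrapped diffusers pipeline."""

    def _encode_prompt(self, prompt, do_classifier_free_guidance,
                       num_images_per_prompt, *args, **kwargs):
        raise RuntimeError('the real encoder would load the text model')


def mock_encode_prompt(prompt, do_classifier_free_guidance,
                       num_images_per_prompt, *args, **kwargs):
    # Mirror the real encoder's batch bookkeeping: one row per prompt,
    # repeated per image, doubled when classifier-free guidance adds the
    # unconditional embeddings.
    batch_size = len(prompt) if isinstance(prompt, list) else 1
    batch_size *= num_images_per_prompt
    if do_classifier_free_guidance:
        batch_size *= 2
    # (batch, seq_len, dim) -- seq_len=5 and dim=16 are arbitrary; only the
    # batch axis has to be right for downstream shape checks.
    return torch.randn(batch_size, 5, 16)


pipe = DummyPipeline()
pipe._encode_prompt = mock_encode_prompt  # same monkey-patch as in the test

emb = pipe._encode_prompt(['a prompt'], True, 2)
assert emb.shape == (4, 5, 16)  # 1 prompt x 2 images x 2 (cond + uncond)
```

Replacing the bound method with a plain function (rather than a `MagicMock`) keeps the batch-size arithmetic in the test, which is the part the inferencer actually depends on.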