diff --git a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
index dbfeb5c6e8..f2ff296e18 100644
--- a/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
+++ b/tests/test_apis/test_inferencers/test_diffusers_pipeline_inferencer.py
@@ -26,15 +26,10 @@ def test_diffusers_pipeline_inferencer():
 
     inferencer_instance = DiffusersPipelineInferencer(cfg, None)
 
-    def mock_encode_prompt(prompt, do_classifier_free_guidance,
-                           num_images_per_prompt, *args, **kwargs):
-        batch_size = len(prompt) if isinstance(prompt, list) else 1
-        batch_size *= num_images_per_prompt
-        if do_classifier_free_guidance:
-            batch_size *= 2
-        return torch.randn(batch_size, 5, 16)  # 2 for cfg
-
-    inferencer_instance.model._encode_prompt = mock_encode_prompt
+    def mock_infer(*args, **kwargs):
+        return dict(samples=torch.randn(1, 3, 64, 64))
+
+    inferencer_instance.model.infer = mock_infer
 
     text_prompts = 'Japanese anime style, girl'
     negative_prompt = 'bad face, bad hands'
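
The hunk above swaps a mock of the prompt-encoding step for a stub of the whole model.infer call, so the test never touches the real diffusion pipeline. Below is a minimal self-contained sketch of that stubbing pattern; _FakeModel and _FakeInferencer are stand-in assumptions used only to make the snippet runnable, not the real DiffusersPipelineInferencer API.

import torch


class _FakeModel:
    def infer(self, *args, **kwargs):
        # Would be the expensive diffusion call in the real inferencer.
        raise RuntimeError('real inference should not run in unit tests')


class _FakeInferencer:
    def __init__(self):
        self.model = _FakeModel()


inferencer_instance = _FakeInferencer()


def mock_infer(*args, **kwargs):
    # Return the same structure the patched test expects from model.infer:
    # a dict with a 'samples' tensor of shape (N, C, H, W).
    return dict(samples=torch.randn(1, 3, 64, 64))


# Monkey-patch the method on the instance so the test short-circuits the
# pipeline while keeping the surrounding inferencer code path intact.
inferencer_instance.model.infer = mock_infer

result = inferencer_instance.model.infer()
assert result['samples'].shape == (1, 3, 64, 64)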