Skip to content

Commit

Permalink
mock text encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
liuwenran committed Sep 8, 2023
1 parent 4be2dd8 commit 6a5c503
Showing 1 changed file with 12 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import platform

import pytest
import torch
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

Expand All @@ -24,6 +25,17 @@ def test_diffusers_pipeline_inferencer():
from_pretrained='runwayml/stable-diffusion-v1-5'))

inferencer_instance = DiffusersPipelineInferencer(cfg, None)

def mock_encode_prompt(prompt, do_classifier_free_guidance,
num_images_per_prompt, *args, **kwargs):
batch_size = len(prompt) if isinstance(prompt, list) else 1
batch_size *= num_images_per_prompt
if do_classifier_free_guidance:
batch_size *= 2
return torch.randn(batch_size, 5, 16) # 2 for cfg

inferencer_instance.model._encode_prompt = mock_encode_prompt

text_prompts = 'Japanese anime style, girl'
negative_prompt = 'bad face, bad hands'
result = inferencer_instance(
Expand Down

0 comments on commit 6a5c503

Please sign in to comment.