From bced86e43fe2e13032a43d359f22d7bb654eeca7 Mon Sep 17 00:00:00 2001 From: rangoliu Date: Thu, 7 Sep 2023 13:30:30 +0800 Subject: [PATCH] [Enhance] add negative prompt for sd inferencer (#2021) add negative prompt for sd --- mmagic/apis/inferencers/text2image_inferencer.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/mmagic/apis/inferencers/text2image_inferencer.py b/mmagic/apis/inferencers/text2image_inferencer.py index e232859a18..61411f5877 100644 --- a/mmagic/apis/inferencers/text2image_inferencer.py +++ b/mmagic/apis/inferencers/text2image_inferencer.py @@ -15,18 +15,23 @@ class Text2ImageInferencer(BaseMMagicInferencer): """inferencer that predicts with text2image models.""" func_kwargs = dict( - preprocess=['text', 'control'], + preprocess=['text', 'control', 'negative_prompt'], forward=[], visualize=['result_out_dir'], postprocess=[]) extra_parameters = dict(height=None, width=None, seed=1) - def preprocess(self, text: InputsType, control: str = None) -> Dict: + def preprocess(self, + text: InputsType, + control: str = None, + negative_prompt: InputsType = None) -> Dict: """Process the inputs into a model-feedable format. Args: text(InputsType): text input for text-to-image model. + control(str): control img dir for controlnet. + negative_prompt(InputsType): negative prompt. Returns: result(Dict): Results of preprocess. @@ -43,6 +48,9 @@ def preprocess(self, text: InputsType, control: str = None) -> Dict: result['control'] = control_img result.pop('seed', None) + if negative_prompt: + result['negative_prompt'] = negative_prompt + return result def forward(self, inputs: InputsType) -> PredType: