diff --git a/mmagic/apis/inferencers/text2image_inferencer.py b/mmagic/apis/inferencers/text2image_inferencer.py index e232859a18..61411f5877 100644 --- a/mmagic/apis/inferencers/text2image_inferencer.py +++ b/mmagic/apis/inferencers/text2image_inferencer.py @@ -15,18 +15,23 @@ class Text2ImageInferencer(BaseMMagicInferencer): """inferencer that predicts with text2image models.""" func_kwargs = dict( - preprocess=['text', 'control'], + preprocess=['text', 'control', 'negative_prompt'], forward=[], visualize=['result_out_dir'], postprocess=[]) extra_parameters = dict(height=None, width=None, seed=1) - def preprocess(self, text: InputsType, control: str = None) -> Dict: + def preprocess(self, + text: InputsType, + control: str = None, + negative_prompt: InputsType = None) -> Dict: """Process the inputs into a model-feedable format. Args: text(InputsType): text input for text-to-image model. + control(str): control img dir for controlnet. + negative_prompt(InputsType): negative prompt. Returns: result(Dict): Results of preprocess. @@ -43,6 +48,9 @@ def preprocess(self, text: InputsType, control: str = None) -> Dict: result['control'] = control_img result.pop('seed', None) + if negative_prompt: + result['negative_prompt'] = negative_prompt + return result def forward(self, inputs: InputsType) -> PredType: