diff --git a/mmagic/structures/data_sample.py b/mmagic/structures/data_sample.py index dcd4e20ce2..248e0cc994 100644 --- a/mmagic/structures/data_sample.py +++ b/mmagic/structures/data_sample.py @@ -244,6 +244,11 @@ def set_gt_label( self.gt_label = label return self + def set_gt_prompt(self, value: Union[str, Sequence[str]]) -> 'DataSample': + """Set label of ``gt_label``.""" + self.prompt = value + return self + @property def gt_label(self): """This the function to fetch gt label.