HF Trainer TrainingArguments can't be used with default_factory #275

Open
levmckinney opened this issue Jul 14, 2023 · 1 comment

@levmckinney

Describe the bug
I've really been loving using simple-parsing in my projects. It looks like you are trying to maintain compatibility with Hugging Face's dataclasses (#172). One use case I've been trying to get working is exposing the TrainingArguments dataclass on the command line using simple-parsing, so that I don't have to manually pass all the different configuration options through. This was working great until I tried to add default arguments, at which point I started running into errors of the form:

ValueError: IntervalStrategy.STEPS is not a valid IntervalStrategy, please select one of ['no', 'steps', 'epoch']

I believe this is because at some point simple-parsing converts IntervalStrategy.STEPS into the string literal

'IntervalStrategy.STEPS'
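
The suspected failure mode can be reproduced without simple-parsing or transformers. This is only a minimal sketch: it assumes (as suspected above, not confirmed from simple-parsing's source) that the default value gets passed through str() somewhere along the way, and Color is a stand-in for IntervalStrategy.

```python
import enum


class Color(str, enum.Enum):
    """Stand-in for IntervalStrategy: an Enum that also subclasses str."""

    RED = "red"
    BLUE = "blue"


# str() on a member gives the qualified member name, not the underlying value.
as_text = str(Color.RED)

# The underlying value is what the Enum constructor accepts.
round_trip_ok = Color(Color.RED.value)

# Feeding the str() form back into the constructor fails.
try:
    Color(as_text)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(as_text)
print(round_trip_ok is Color.RED)
print(error_message)
```

If that assumption holds, any code path that stringifies the default and later calls the Enum constructor on the result would hit the same ValueError.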

To Reproduce

# issue.py
from dataclasses import dataclass

from transformers import TrainingArguments
from simple_parsing import field, parse

@dataclass
class HParams:
    """You can use Enums"""

    sub_component: TrainingArguments = field(
        default_factory=lambda: TrainingArguments(evaluation_strategy="steps")
    )

if __name__ == "__main__":
    my_preferences: HParams = parse(HParams)
    print(my_preferences)

So you don't have to dig through Hugging Face's code, here is a minimal reproduction of what's happening.

See huggingface/transformers#17933 for why the Enum inherits from str.

# simplified.py
# ======================= Their Code =======================
from typing import Union
import enum
from dataclasses import dataclass

from simple_parsing import parse, field

class Color(str, enum.Enum):
    RED = "red"
    ORANGE = "orange"
    BLUE = "blue"

@dataclass
class SubComponent:
    color: Union[str, Color] = Color.BLUE
    
    def __post_init__(self):
        self.color = Color(self.color)

# ======================= My Code =======================
@dataclass
class HParams:
    """You can use Enums"""

    sub_component: SubComponent = field(
        default_factory=lambda: SubComponent(color="red")
    )

if __name__ == "__main__":
    hparams: HParams = parse(HParams)
    print(hparams)

Expected behavior

$ python issue.py
HParams(sub_component=TrainingArguments(...))
$ python simplified.py
HParams(sub_component=SubComponent(color=<Color.RED: 'red'>))

Actual behavior
I get errors of the form:

$ python issue.py
Traceback (most recent call last):
  File "/home/lev/Projects/robust-llm/test_enum_parsing.py", line 15, in <module>
    hparams: HParams = parse(HParams)
                       ^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 1021, in parse
    parsed_args = parser.parse_args(args)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/argparse.py", line 1869, in parse_args
    args, argv = self.parse_known_args(args, namespace)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 349, in parse_known_args
    parsed_args = self._postprocessing(parsed_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 581, in _postprocessing
    parsed_args = self._instantiate_dataclasses(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 849, in _instantiate_dataclasses
    value_for_dataclass_field = _create_dataclass_instance(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 1137, in _create_dataclass_instance
    return constructor(**constructor_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 111, in __init__
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/transformers/training_args.py", line 1199, in __post_init__
    self.evaluation_strategy = IntervalStrategy(self.evaluation_strategy)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/enum.py", line 714, in __call__
    return cls.__new__(cls, value)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/enum.py", line 1138, in __new__
    raise exc
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/enum.py", line 1115, in __new__
    result = cls._missing_(value)
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/transformers/utils/generic.py", line 348, in _missing_
    raise ValueError(
ValueError: IntervalStrategy.STEPS is not a valid IntervalStrategy, please select one of ['no', 'steps', 'epoch']

Here is the simplified example that replicates the basic issue without the HF stuff.

$ python simplified.py
Traceback (most recent call last):
  File "/home/lev/Projects/robust-llm/test_enum_parsing.py", line 30, in <module>
    hparams: HParams = parse(HParams)
                       ^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 1021, in parse
    parsed_args = parser.parse_args(args)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/argparse.py", line 1869, in parse_args
    args, argv = self.parse_known_args(args, namespace)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 349, in parse_known_args
    parsed_args = self._postprocessing(parsed_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 581, in _postprocessing
    parsed_args = self._instantiate_dataclasses(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 849, in _instantiate_dataclasses
    value_for_dataclass_field = _create_dataclass_instance(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/site-packages/simple_parsing/parsing.py", line 1137, in _create_dataclass_instance
    return constructor(**constructor_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 4, in __init__
  File "/home/lev/Projects/robust-llm/test_enum_parsing.py", line 18, in __post_init__
    self.color = Color(self.color)
                 ^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/enum.py", line 714, in __call__
    return cls.__new__(cls, value)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lev/miniconda3/envs/robust-llm/lib/python3.11/enum.py", line 1130, in __new__
    raise ve_exc
ValueError: 'Color.RED' is not a valid Color

Desktop (please complete the following information):

  • Package versions: simple-parsing==0.1.3, transformers==4.30.2
  • Python version: python==3.11

Additional context
My current understanding is that the problem is caused by the Enum class (Color or IntervalStrategy) inheriting from str. This seems to be a hack on Hugging Face's side to help with serialization; see huggingface/transformers#17933.
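
For context, here is a small sketch of one commonly cited reason to mix str into an Enum: the members become directly JSON-serializable. Whether this is the exact motivation behind the linked transformers PR is an assumption on my part; PlainColor and StrColor are illustrative names only.

```python
import enum
import json


class PlainColor(enum.Enum):
    RED = "red"


class StrColor(str, enum.Enum):
    RED = "red"


# A str-mixin member is a real str instance, so json can encode it directly.
encoded = json.dumps({"color": StrColor.RED})

# A plain Enum member is not JSON serializable out of the box.
try:
    json.dumps({"color": PlainColor.RED})
    plain_is_serializable = True
except TypeError:
    plain_is_serializable = False

print(encoded)
print(plain_is_serializable)
```

The trade-off is that every member is also a str, which makes it harder to tell a member apart from an ordinary string with isinstance.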

@lebrice
Owner

lebrice commented Jul 14, 2023

Hey @levmckinney, thanks for posting!

I'm familiar with this issue, let me try to recall what's going on.
I believe what's happening is that SimpleParsing is parsing the value from str into an Enum, so in the __post_init__, you're calling the Color constructor with a Color instance, rather than a string.

I'll try to whip up a solution on Monday, but for now, I think you could fix it with something like:

@dataclass
class SubComponent:
    color: Union[str, Color] = Color.BLUE
    
    def __post_init__(self):
        if isinstance(self.color, str):
            self.color = Color(self.color)
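
One caveat with the isinstance(self.color, str) check: because Color subclasses str, Color members are themselves str instances, so the check is true for members and plain strings alike. A more permissive variant might look like the sketch below. Note that coerce_color is a hypothetical helper, not part of simple-parsing or transformers, and accepting the 'Color.RED' form from the traceback only works around the symptom in classes you control.

```python
import enum
from dataclasses import dataclass
from typing import Union


class Color(str, enum.Enum):
    RED = "red"
    ORANGE = "orange"
    BLUE = "blue"


def coerce_color(value: Union[str, Color]) -> Color:
    """Hypothetical helper: accept a member, a value ('red'), or a str() form ('Color.RED')."""
    if isinstance(value, Color):
        return value
    prefix = f"{Color.__name__}."
    if value.startswith(prefix):
        # Look the member up by name, e.g. 'Color.RED' -> Color['RED'].
        return Color[value[len(prefix):]]
    # Otherwise treat the input as an underlying value such as 'red'.
    return Color(value)


@dataclass
class SubComponent:
    color: Union[str, Color] = Color.BLUE

    def __post_init__(self):
        self.color = coerce_color(self.color)


print(SubComponent(color="red"))
print(SubComponent(color="Color.RED"))
print(SubComponent(color=Color.ORANGE))
```

This does not help with TrainingArguments itself, since its __post_init__ lives in transformers.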

I'm surprised though, I thought I had this issue already nailed down with my HuggingFace example / test. I guess one other approach would be to leave those HF classes as-is, but to add a custom handler for them.

I have to think about this, I'll get back to you, thanks again for posting!

lebrice added a commit that referenced this issue Jul 20, 2023