How to save class definitions from imported modules? #670

Open
j-adamczyk opened this issue Jul 22, 2024 · 0 comments

I have a fairly typical machine learning use case: train a model, ship it to a server, and run inference there. I chose dill for serialization. I have many custom preprocessing classes, all of them scikit-learn compatible transformers; the training code runs from train.py and imports them from other modules. I built a scikit-learn Pipeline object and saved it with dill.

After training, I save my pipeline in train.py like:

with open(model_path, "wb") as file:
    dill.dump(pipeline, file, recurse=True)

I load it like:

with open(model_path, "rb") as file:
    pipeline = dill.load(file)

I get errors like:

Traceback (most recent call last):
  File "/home/jakub/PycharmProjects/project-name/tmp.py", line 26, in <module>
    model = dill.load(file)
            ^^^^^^^^^^^^^^^
  File "/home/jakub/.cache/pypoetry/virtualenvs/project-name-xf0sj8ML-py3.11/lib/python3.11/site-packages/dill/_dill.py", line 289, in load
    return Unpickler(file, ignore=ignore, **kwds).load()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jakub/.cache/pypoetry/virtualenvs/project-name-xf0sj8ML-py3.11/lib/python3.11/site-packages/dill/_dill.py", line 444, in load
    obj = StockUnpickler.load(self)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jakub/.cache/pypoetry/virtualenvs/project-name-xf0sj8ML-py3.11/lib/python3.11/site-packages/dill/_dill.py", line 434, in find_class
    return StockUnpickler.find_class(self, module, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'ClientHistoricalMedianTransformer' on <module 'transformers' from '/home/jakub/PycharmProjects/project-name/src/create_datasets/transformers.py'>

How can I make this work? Saving those additional classes is crucial for me. There are also too many additional classes to save them separately, if that matters.
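
For context, here is a minimal sketch of what appears to be going on, using only the standard library `pickle` (the `StockUnpickler.find_class` frames in the traceback suggest the lookup happens there). The module name `demo_transformers` and the class `MedianTransformer` are hypothetical stand-ins, not the real project code. The sketch assumes the class is importable at dump time, so it is stored by reference (module name plus class name) rather than by value.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the real `transformers` module in the traceback.
mod = types.ModuleType("demo_transformers")


class MedianTransformer:
    def __init__(self, value):
        self.value = value


# Make the class look as if it were defined inside the stand-in module.
MedianTransformer.__module__ = "demo_transformers"
MedianTransformer.__qualname__ = "MedianTransformer"
mod.MedianTransformer = MedianTransformer
sys.modules["demo_transformers"] = mod

payload = pickle.dumps(MedianTransformer(3.5))

# The payload carries the module and class names, not the class body.
assert b"demo_transformers" in payload
assert b"MedianTransformer" in payload

# Simulate the loading side: the module exists but the class is missing.
del mod.MedianTransformer

try:
    pickle.loads(payload)
    error_message = None
except AttributeError as exc:
    # Same kind of failure as in the traceback above.
    error_message = str(exc)

print(error_message)
```

If this is the mechanism at play, the class definition has to be importable under the same module path on the loading side, or the class has to be serialized by value instead.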
