Missing torch.load #403
Comments
Python pickling is not going to be a solution here; it's too intimately tied to Python's object model. If there is a bug with saving state_dict() in Python and loading it in .NET, I want to make sure to fix it, or at a minimum provide a more actionable error message. As the article on saving and loading demonstrates, you have to use the special format for exporting model weights. Could you please provide more details on what you are saving and how you are loading it? If you can provide some source files for Python and .NET (narrowed down, preferably), that would be great. |
I am attempting to hijack the populated model described in main.py and save its state_dict as described, using exportsd.py. At this empty line, the following code was inserted to export the state_dict:

f = open("gpt2-pytorch_model.ts", "wb")
exportsd.save_state_dict(model.to("cpu").state_dict(), f)
f.close()

When I attempted to load it: since in PyTorch it is possible to simply load the state_dict without first defining the model, I am attempting to check whether that is possible in TorchSharp. As the PyTorch load function takes more parameters, I wonder whether TorchSharp needs those too. I hope I have understood correctly that this is how to use the exportsd.py function. |
Thanks for that information. In terms of how TorchSharp model serialization works, it loads and saves model parameters (weights), not models. That means that in order to load weights, you have to have an exact copy of the original model defined in .NET, and an instance of the model created (presumably with random or empty weights). Parameters should be represented as fields in the model, as should buffers (tensors that are used by the model but not affected by training), and they must have exactly the same names as in the original model. I'll construct some negative unit tests for this and see if I can improve the error messages to be more informative.
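For illustration, a minimal sketch of what the .NET side might look like (MyModel and lin1 are hypothetical names, and this assumes the current TorchSharp API, where RegisterComponents() registers fields as parameters and submodules under their field names):

using TorchSharp;
using static TorchSharp.torch;

// The field name "lin1" must match the attribute name in the original Python
// module (self.lin1 = nn.Linear(...)), so that the exported state_dict keys
// ("lin1.weight", "lin1.bias") line up with the names registered here.
class MyModel : nn.Module<Tensor, Tensor>
{
    private readonly nn.Module<Tensor, Tensor> lin1 = nn.Linear(10, 10);

    public MyModel() : base(nameof(MyModel))
    {
        RegisterComponents();
    }

    public override Tensor forward(Tensor input) => lin1.forward(input);
}
|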
Create a unit test to test the feasibility:

[Fact]
public void LoadModelTest()
{
    string fileName = "gpt2-pytorch_model.ts";
    if (File.Exists(fileName)) {
        var stateDict = Module.Load(fileName);
        Assert.NotNull(stateDict);
    }
}
|
Do you think this is a valid use case? Enhance the exportsd.py script so that TorchSharp-compatible parameters and weights can be exported, loaded back into TorchSharp, and used to populate a TorchSharp custom model, as described in the PyTorch main.py. This would give the .NET community an additional way to evaluate how closely compatible TorchSharp is with its PyTorch counterpart. |
Yes, that is exactly the intent of Module.load -- to deserialize parameters. However, the model definition needs to exactly match the original, since (unlike Python) .NET cannot create a class definition at runtime and then instantiate it. Like I mentioned, Module.load() loads weights, not modules.
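As a sketch of that intent, reusing the hypothetical MyModel from the sketch above (Module.load here is the instance method that reads the format written by exportsd.py):

// Create an instance with the exact same definition, then fill in the weights.
var model = new MyModel();            // parameters start out randomly initialized
model.load("gpt2-pytorch_model.ts");  // overwrites them with the exported values
|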
This is the method that loads weights (not modules), as described in Utils.py. Is that conceivably possible with TorchSharp -- exporting TorchSharp-compatible weights and loading them back into a compatible model using TorchSharp-compatible code, as described below?

def load_weight(model, state_dict):
    old_keys = []
    new_keys = []
    for key in state_dict.keys():
        new_key = None
        if key.endswith(".g"):
            new_key = key[:-2] + ".weight"
        elif key.endswith(".b"):
            new_key = key[:-2] + ".bias"
        elif key.endswith(".w"):
            new_key = key[:-2] + ".weight"
        if new_key:
            old_keys.append(key)
            new_keys.append(new_key)
    for old_key, new_key in zip(old_keys, new_keys):
        state_dict[new_key] = state_dict.pop(old_key)

    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata
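For illustration, a hedged sketch of the same key-renaming idea on the .NET side, assuming the exported tensors have already been read into a Dictionary<string, Tensor> named stateDict, and that the target model exposes load_state_dict as TorchSharp's Module does:

// Rename legacy suffixes (.g/.w -> .weight, .b -> .bias) before loading.
// Requires System.Collections.Generic plus the usings from the sketch above.
var renamed = new Dictionary<string, Tensor>();
foreach (var (key, tensor) in stateDict)
{
    var newKey = key;
    if (key.EndsWith(".g") || key.EndsWith(".w"))
        newKey = key[..^2] + ".weight";
    else if (key.EndsWith(".b"))
        newKey = key[..^2] + ".bias";
    renamed[newKey] = tensor;
}
model.load_state_dict(renamed);
|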
Anything that looks like a <String,Tensor> dictionary and was saved using the format that exportsd.py also uses should be possible to load, but when loading, the keys come from the model instance (either a custom module or Sequential) that the weights are being loaded into. On the saving side, the keys likewise come from the original model. Thus, the two have to exactly match -- that's the key here. Without seeing the model definition on both sides, it's hard to help debug it. The best I can do, and I will try to get that into the next release, is to improve the error messages so that they are more informative. |
@GeorgeS2019 -- I suggest adding a print statement (on your machine) to exportsd.py, something like:

for entry in sd:
    print(entry)
    stream.write(leb128.u.encode(len(entry)))
    stream.write(bytes(entry, 'utf-8'))
    _write_tensor(sd[entry], stream)

and see what the names of all the state_dict entries are, then compare that to the .NET module that you are loading the weights into.
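On the .NET side, one way to produce the matching list of names for comparison (a sketch, assuming a module instance named model):

// Print every registered parameter name, to compare against the Python output.
foreach (var (name, _) in model.named_parameters())
    Console.WriteLine(name);
|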
That discussion was back in 2019. I wonder how much of a moving target the PyTorch format is in 2021, and whether that will affect future TorchSharp imports of PyTorch models? |
In order to train with TorchSharp, you will always need a representation of the model in code. A longer-term solution than the one we have in place now (loading weights into a model instance) will be to generate the TorchSharp code from an ONNX graph, so that you have full fidelity, then load the weights into that. Implementing ONNX import will take some time, so if you need to transfer from PyTorch to TorchSharp before then, the existing mechanism is your only option. If the article at https://github.com/dotnet/TorchSharp/blob/main/docfx/articles/saveload.md doesn't describe the mechanics of the current approach in sufficient detail, please file a documentation issue with feedback on where the article is lacking in detail or clarity.
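For reference, the round trip of the current approach in outline, reusing the hypothetical MyModel from the sketch above (save and load operate on the weights only):

// Save the weights of one instance, then load them into a fresh instance
// of the identical model definition.
var trained = new MyModel();
trained.save("weights.dat");

var fresh = new MyModel();   // same code, fresh random weights
fresh.load("weights.dat");   // now holds the saved values
|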
@NiklasGustafsson I see this as eventually related to the ML.NET TorchSharp integration. |
@GeorgeS2019 I'm curious -- where did the diagram above come from? |
@NiklasGustafsson I took an existing diagram from the Microsoft documentation on TensorFlow integration with ML.NET, cloned it, and added the PyTorch part. |
One of the differences between TF and PyTorch is that TF saves the graph using protobuf, which is language-independent. In PyTorch, there really isn't a graph, and the weights are saved using Python pickling, which is a serialization format specific to Python. That's why you need something like ONNX to get not only the weights but also the graph, and why you need the model code in TorchSharp before you can load the weights. See, for example, the PyTorch documentation for 'load_state_dict()', which is analogous to what we're doing in TorchSharp.
Once the model has been exported to ONNX, we can (theoretically) use that to recreate the model code in C# (or F#), which can then be used to load the model weights. That will take a lot of work, so it's not something that is coming soon. In the meantime, you will have to recreate the model code in .NET if you want to load parameters from Python. |
I am closing this issue now that a unit test has been provided showing how to save state_dicts in PyTorch and load them in TorchSharp.
This is similar to the discussion of the missing torch.save.
PyTorch's torch.load is implemented in torch/serialization.py, which documents the parameters available for loading.
[Example of the torch.load options omitted.]
Currently, LibTorchSharp implements only one of the possible loading options listed above.
Since pickling is overkill for .NET, as discussed, and as I am still learning: is there a need to provide more loading options, analogous to torch.load, rather than just Module.load in TorchSharp?
I am raising this issue because I fail to load a saved state_dict, created through exportsd.py, back into TorchSharp using Module.Load.
I did not get any error message; the process simply crashes.
Suggestion: error messages on load failure would help make loading a state_dict more reliable.