
Cannot Infer Shapes from pretrained Models. #15

Open
mycpuorg opened this issue Apr 3, 2022 · 0 comments

Comments


mycpuorg commented Apr 3, 2022

Hi,
I have a relatively straightforward situation: I need to validate input shapes when `torch.jit.load`-ing a saved model, and I can't seem to find a solution.

Save a traced model and verify its input shapes:

```python
import torch
import torchvision

test_model = torchvision.models.resnet.resnet18(pretrained=True).eval()
with torch.no_grad():
    inp_224 = torch.rand(1, 3, 224, 224, dtype=torch.float)
    script_module_224 = torch.jit.trace(test_model, inp_224)
    graph_inputs = list(script_module_224.graph.inputs())
    graph_inputs = graph_inputs[1:]  # ignore self
    print(graph_inputs[0].type().sizes())  # [1, 3, 224, 224]
    script_module_224.save("saved_model_224.pt")
```

Load the same model: now there is no way to validate the traced input shapes.

```python
import torch

# ... later elsewhere load my saved model
loaded_224 = torch.jit.load("saved_model_224.pt")
# there's nothing preventing me from sending incorrect input shapes,
# i.e., traced with 224 but called with 500
inp_500 = torch.rand(1, 3, 500, 500, dtype=torch.float)
loaded_224(inp_500)
```
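For what it's worth, the same behavior reproduces without downloading pretrained weights. This is a minimal sketch of my assumption; the small `Conv2d` stand-in is mine, not part of the original report:

```python
import io
import torch

# Small stand-in module in place of resnet18; traced at 224x224.
model = torch.nn.Conv2d(3, 8, kernel_size=3).eval()
traced = torch.jit.trace(model, torch.rand(1, 3, 224, 224))

# Round-trip through save/load using an in-memory buffer.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

# Nothing prevents calling the loaded module with a different
# spatial size than the one it was traced with.
out = loaded(torch.rand(1, 3, 500, 500))
print(out.shape)  # torch.Size([1, 8, 498, 498])
```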

I'd like to prevent feeding incorrectly shaped inputs after loading models.
In particular, I would like to be able to do something like:

```python
# hypothetical API:
traced_input_shape = loaded_224.get_input_shape()
if inp_500.shape != traced_input_shape:
    print("Error: Trying to run inference with incorrectly shaped inputs!")
    # die
```
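In the meantime, one workaround might be to record the traced input shape myself at save time and check it after loading. This is a sketch assuming only torch's documented `_extra_files` argument to `torch.jit.save`/`torch.jit.load`; the small `Conv2d` stand-in and the `input_shape.json` name are mine:

```python
import io
import json
import torch

# Small stand-in module; the same pattern applies to the traced resnet18 above.
model = torch.nn.Conv2d(3, 8, kernel_size=3).eval()
inp_224 = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(model, inp_224)

# Record the traced input shape alongside the model via _extra_files.
extra = {"input_shape.json": json.dumps(list(inp_224.shape))}
buffer = io.BytesIO()
torch.jit.save(traced, buffer, _extra_files=extra)

# ... later elsewhere: load the model plus the recorded shape, then validate.
buffer.seek(0)
files = {"input_shape.json": ""}
loaded = torch.jit.load(buffer, _extra_files=files)
traced_input_shape = tuple(json.loads(files["input_shape.json"]))

inp_500 = torch.rand(1, 3, 500, 500)
if inp_500.shape != traced_input_shape:
    print("Error: Trying to run inference with incorrectly shaped inputs!")
```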

I tried using torchlayers to help with this situation by:

```python
import torchlayers as tl

tl.build(loaded_224, inp_224)
```

This failed (reasonably) with:

```
PickleError: ScriptModules cannot be deepcopied using copy.deepcopy or saved using torch.save. Mixed serialization of script and non-script modules is not supported. For purely script modules use my_script_module.save(<filename>) instead.
```

Any recommendations?
