Other models from torchvision #6

Open
madiltalay opened this issue Apr 15, 2019 · 2 comments

Comments

@madiltalay

Hi, I want to use lightweight base models from torchvision, such as MobileNet, SqueezeNet, and DenseNet. Any tips on how to edit the code?

@potterhsu
Owner

To add a new backbone, follow these steps:

  1. Create a class that inherits from backbone.base.Base, then override its features method

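    from typing import Tuple

    import torchvision
    from torch import nn
    import backbone.base

    class DenseNet161(backbone.base.Base):
        def __init__(self, pretrained: bool):
            super().__init__(pretrained)

        def features(self) -> Tuple[nn.Module, nn.Module, int, int]:
            # Load the torchvision model, with pretrained weights if requested.
            densenet161 = torchvision.models.densenet161(pretrained=self._pretrained)
            ...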
  2. In backbone/base.py, add a branch for the new backbone, and don't forget to extend the options to include it

  3. Specify the new backbone when running the script

    $ python train.py -s=xxx -b=densenet161
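
Step 2 could look roughly like the sketch below. The thread does not show the contents of backbone/base.py, so the names `OPTIONS` and `from_name`, and the stand-in classes, are assumptions made for illustration. Adapt them to whatever the file actually defines.

```python
from typing import List, Type


class Base:
    """Stand-in for backbone.base.Base (hypothetical, simplified)."""

    # Assumed list of valid backbone names, e.g. the choices accepted by -b.
    OPTIONS: List[str] = ['resnet18', 'densenet161']

    def __init__(self, pretrained: bool):
        self._pretrained = pretrained

    @staticmethod
    def from_name(name: str) -> Type['Base']:
        # One branch per backbone; the densenet161 branch is the new one.
        if name == 'resnet18':
            return ResNet18
        elif name == 'densenet161':
            return DenseNet161
        raise ValueError(f'Unknown backbone: {name!r}; expected one of {Base.OPTIONS}')


class ResNet18(Base):
    """Placeholder for an already supported backbone."""


class DenseNet161(Base):
    """Placeholder for the new backbone from step 1."""


# Resolve the class from the name given on the command line, then build it.
backbone_class = Base.from_name('densenet161')
backbone = backbone_class(pretrained=False)
print(type(backbone).__name__)  # DenseNet161
```

Keeping the name list and the branches side by side makes it harder to add one and forget the other, which is the mistake step 2 warns about.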
    

Hope this helps.

@madiltalay
Author

Thanks for the prompt reply. I'm now working out the right dimensions for:

  1. input
  2. features
  3. hidden
