diff --git a/README.md b/README.md index cd64703..01d0439 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##refere This repo - Download + Download 65.8 45.7 @@ -57,7 +57,7 @@ This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##refere 164k val This repo - Download ‡ + Download ‡ 66.8 51.2 @@ -112,7 +112,7 @@ This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##refere This repo - Download + Download 94.64 86.50 @@ -240,7 +240,7 @@ python demo.py single \ To run on a webcam: -```console +```bash python demo.py live \ --config-path configs/voc12.yaml \ --model-path deeplabv2_resnet101_msc-vocaug-20000.pth @@ -252,12 +252,11 @@ To run a CRF post-processing, add `--crf`. To run on a CPU, add `--cpu`. ### torch.hub -Model setup with 3 lines +Model setup with two lines ```python import torch.hub -model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=182) -model.load_state_dict(torch.load("deeplabv2_resnet101_msc-cocostuff164k-100000.pth")) +model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", pretrained='cocostuff164k', n_classes=182) ``` ### Difference with Caffe version diff --git a/hubconf.py b/hubconf.py index f2431ab..521c5bf 100644 --- a/hubconf.py +++ b/hubconf.py @@ -7,36 +7,41 @@ from __future__ import print_function +from torch.hub import load_state_dict_from_url -def deeplabv2_resnet101(pretrained=False, **kwargs): - """ - DeepLab v2 model with ResNet-101 backbone - n_classes (int): the number of classes - """ +model_url_root = "https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/" +model_dict = { + "cocostuff10k": ("deeplabv2_resnet101_msc-cocostuff10k-20000.pth", 182), + "cocostuff164k": ("deeplabv2_resnet101_msc-cocostuff164k-100000.pth", 182), + "voc12": ("deeplabv2_resnet101_msc-vocaug-20000.pth", 21), +} - if pretrained: - raise NotImplementedError( - "Please download from " - "https://github.com/kazuto1011/deeplab-pytorch/tree/master#performance" - ) + +def deeplabv2_resnet101(pretrained=None, n_classes=182, scales=None): from libs.models.deeplabv2 import DeepLabV2 from libs.models.msc import MSC - base = DeepLabV2(n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24], **kwargs) - model = MSC(base=base, scales=[0.5, 0.75]) + # Model parameters + n_blocks = [3, 4, 23, 3] + atrous_rates = [6, 12, 18, 24] + if scales is None: + scales = [0.5, 0.75] - return model + base = DeepLabV2(n_classes=n_classes, n_blocks=n_blocks, atrous_rates=atrous_rates) + model = MSC(base=base, scales=scales) + # Load pretrained models + if isinstance(pretrained, str): -if __name__ == "__main__": - import torch.hub + assert pretrained in model_dict, list(model_dict.keys()) + expected = model_dict[pretrained][1] + error_message = "Expected: n_classes={}".format(expected) + assert n_classes == expected, error_message - model = torch.hub.load( - "kazuto1011/deeplab-pytorch", - "deeplabv2_resnet101", - n_classes=182, - force_reload=True, - ) + model_url = model_url_root + model_dict[pretrained][0] + state_dict = load_state_dict_from_url(model_url) + model.load_state_dict(state_dict) + + return model - print(model)