diff --git a/README.md b/README.md
index cd64703..01d0439 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##refere
This repo |
- Download |
+ Download |
|
65.8 |
45.7 |
@@ -57,7 +57,7 @@ This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##refere
164k val |
This repo |
- Download ‡ |
+ Download ‡ |
|
66.8 |
51.2 |
@@ -112,7 +112,7 @@ This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##refere
This repo |
- Download |
+ Download |
|
94.64 |
86.50 |
@@ -240,7 +240,7 @@ python demo.py single \
To run on a webcam:
-```console
+```bash
python demo.py live \
--config-path configs/voc12.yaml \
--model-path deeplabv2_resnet101_msc-vocaug-20000.pth
@@ -252,12 +252,11 @@ To run a CRF post-processing, add `--crf`. To run on a CPU, add `--cpu`.
### torch.hub
-Model setup with 3 lines
+Model setup with two lines
```python
import torch.hub
-model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=182)
-model.load_state_dict(torch.load("deeplabv2_resnet101_msc-cocostuff164k-100000.pth"))
+model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", pretrained='cocostuff164k', n_classes=182)
```
### Difference with Caffe version
diff --git a/hubconf.py b/hubconf.py
index f2431ab..521c5bf 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -7,36 +7,41 @@
from __future__ import print_function
+from torch.hub import load_state_dict_from_url
-def deeplabv2_resnet101(pretrained=False, **kwargs):
- """
- DeepLab v2 model with ResNet-101 backbone
- n_classes (int): the number of classes
- """
+model_url_root = "https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/"
+model_dict = {
+ "cocostuff10k": ("deeplabv2_resnet101_msc-cocostuff10k-20000.pth", 182),
+ "cocostuff164k": ("deeplabv2_resnet101_msc-cocostuff164k-100000.pth", 182),
+ "voc12": ("deeplabv2_resnet101_msc-vocaug-20000.pth", 21),
+}
- if pretrained:
- raise NotImplementedError(
- "Please download from "
- "https://github.com/kazuto1011/deeplab-pytorch/tree/master#performance"
- )
+
+def deeplabv2_resnet101(pretrained=None, n_classes=182, scales=None):
from libs.models.deeplabv2 import DeepLabV2
from libs.models.msc import MSC
- base = DeepLabV2(n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18, 24], **kwargs)
- model = MSC(base=base, scales=[0.5, 0.75])
+ # Model parameters
+ n_blocks = [3, 4, 23, 3]
+ atrous_rates = [6, 12, 18, 24]
+ if scales is None:
+ scales = [0.5, 0.75]
- return model
+ base = DeepLabV2(n_classes=n_classes, n_blocks=n_blocks, atrous_rates=atrous_rates)
+ model = MSC(base=base, scales=scales)
+ # Load pretrained models
+ if isinstance(pretrained, str):
-if __name__ == "__main__":
- import torch.hub
+ assert pretrained in model_dict, list(model_dict.keys())
+ expected = model_dict[pretrained][1]
+ error_message = "Expected: n_classes={}".format(expected)
+ assert n_classes == expected, error_message
- model = torch.hub.load(
- "kazuto1011/deeplab-pytorch",
- "deeplabv2_resnet101",
- n_classes=182,
- force_reload=True,
- )
+ model_url = model_url_root + model_dict[pretrained][0]
+ state_dict = load_state_dict_from_url(model_url)
+ model.load_state_dict(state_dict)
+
+ return model
- print(model)