-
Notifications
You must be signed in to change notification settings - Fork 1
/
prepare_pcontext.py
104 lines (83 loc) · 3.07 KB
/
prepare_pcontext.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Prepare PASCAL Context dataset"""
import click
import shutil
import tarfile
import torch
from tqdm import tqdm
from pathlib import Path
from segm.utils.download import download
def download_pcontext(path, overwrite=False):
    """Download the PASCAL Context source files and lay them out under ``path``.

    Every entry of the table below is fetched through ``download`` into
    ``path / "downloads"``. A file whose name ends in ``.tar`` is extracted
    into ``path``; any other file is moved to ``path / "VOCdevkit" / "VOC2010"``
    under the file name returned by ``download``.

    Args:
        path: Dataset root directory, given as a ``pathlib.Path`` or a string.
        overwrite: Forwarded unchanged to ``download``.
    """
    # Accept plain strings too: the "/" joins below only work on Path objects.
    path = Path(path)

    # Each entry is (url, file name inside the download directory, SHA-1 digest
    # handed to ``download``). An empty file name makes the target the download
    # directory itself; presumably ``download`` then derives the name from the
    # URL -- TODO confirm against segm.utils.download.
    _AUG_DOWNLOAD_URLS = [
        (
            "https://www.dropbox.com/s/wtdibo9lb2fur70/VOCtrainval_03-May-2010.tar?dl=1",
            "VOCtrainval_03-May-2010.tar",
            "bf9985e9f2b064752bf6bd654d89f017c76c395a",
        ),
        (
            "https://codalabuser.blob.core.windows.net/public/trainval_merged.json",
            "",
            "169325d9f7e9047537fedca7b04de4dddf10b881",
        ),
        (
            "https://hangzh.s3.amazonaws.com/encoding/data/pcontext/train.pth",
            "",
            "4bfb49e8c1cefe352df876c9b5434e655c9c1d07",
        ),
        (
            "https://hangzh.s3.amazonaws.com/encoding/data/pcontext/val.pth",
            "",
            "ebedc94247ec616c57b9a2df15091784826a7b0c",
        ),
    ]
    download_dir = path / "downloads"
    download_dir.mkdir(parents=True, exist_ok=True)
    voc_dir = path / "VOCdevkit" / "VOC2010"

    for url, filename, checksum in _AUG_DOWNLOAD_URLS:
        # ``download`` returns the local path of the fetched file.
        filename = download(
            url,
            path=str(download_dir / filename),
            overwrite=overwrite,
            sha1_hash=checksum,
        )
        if Path(filename).suffix == ".tar":
            # NOTE(review): extractall() applies no member filter here; the
            # archive is presumably validated by ``download`` through the SHA-1
            # digest above -- confirm before pointing this at other sources.
            with tarfile.open(filename) as tar:
                tar.extractall(path=str(path))
        else:
            # Do not depend on the archive having created the target directory.
            voc_dir.mkdir(parents=True, exist_ok=True)
            shutil.move(filename, str(voc_dir / Path(filename).name))
def _export_split(torch_path, out_dir, list_path):
    """Save every label of one split as a PNG and write its sorted id list."""
    # NOTE(review): torch.load unpickles the downloaded file. Newer PyTorch
    # releases reportedly default to weights_only=True, which may reject the
    # non-tensor label objects stored here -- confirm against the installed
    # version before changing the call.
    split_dict = torch.load(str(torch_path))
    names = []
    for idx, label in tqdm(split_dict.items()):
        idx = str(idx)
        # Insert an underscore after the first four characters of the key.
        new_idx = idx[:4] + "_" + idx[4:]
        names.append(new_idx)
        # ``label`` must expose ``save(path)``; presumably a PIL image -- verify.
        label.save(str(out_dir / f"{new_idx}.png"))
    with open(str(list_path), "w") as f:
        f.writelines(line + "\n" for line in sorted(names))


@click.command(help="Initialize PASCAL Context dataset.")
@click.argument("download_dir", type=str)
def main(download_dir):
    """Download PASCAL Context and export its label images and split lists.

    The dataset is placed in ``<download_dir>/pcontext``. Labels from
    ``train.pth`` and ``val.pth`` are written as PNG files to
    ``VOCdevkit/VOC2010/SegmentationClassContext``, and the sorted image ids
    of each split go to ``ImageSets/SegmentationContext/{train,val}.txt``.

    The command-line help text comes from the explicit ``help=`` argument of
    the decorator above, not from this docstring.
    """
    dataset_dir = Path(download_dir) / "pcontext"
    download_pcontext(dataset_dir, overwrite=False)

    voc_dir = dataset_dir / "VOCdevkit" / "VOC2010"
    out_dir = voc_dir / "SegmentationClassContext"
    imageset_dir = voc_dir / "ImageSets" / "SegmentationContext"
    out_dir.mkdir(parents=True, exist_ok=True)
    imageset_dir.mkdir(parents=True, exist_ok=True)

    # Same order as before: the train split is fully exported, then val.
    for split in ("train", "val"):
        _export_split(
            voc_dir / f"{split}.pth",
            out_dir,
            imageset_dir / f"{split}.txt",
        )
# Script entry point: run the click command only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()