-
Notifications
You must be signed in to change notification settings - Fork 1
/
synthclip_loader.py
117 lines (97 loc) · 3.14 KB
/
synthclip_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
Minimal code for loading SynthCI
"""
from PIL import Image
import pandas as pd
import logging
import zipfile
import os
from torch.utils.data import Dataset
from huggingface_hub import snapshot_download
from huggingface_hub import hf_hub_download
from torchvision import transforms
class CsvDataset(Dataset):
    """Map-style dataset of (image, caption) pairs listed in a delimited text file.

    Each row of ``input_filename`` names an image file (column ``img_key``) and
    its caption (column ``caption_key``). Image paths are resolved by plain
    string concatenation with ``prefix_path``. In practice ``prefix_path``
    should therefore end with a path separator, or be a deliberate filename
    prefix.

    Args:
        input_filename: Path to the CSV/TSV file, read with ``pandas.read_csv``.
        transforms: Callable applied to each RGB ``PIL.Image`` before it is
            returned from ``__getitem__``.
        img_key: Name of the column holding image paths (relative to
            ``prefix_path``).
        caption_key: Name of the column holding captions.
        prefix_path: String prepended verbatim to every image path.
        sep: Field delimiter passed to ``pandas.read_csv`` (tab by default).
        tokenizer: Stored as ``self.tokenize``. ``__getitem__`` does NOT apply
            it; captions are always returned as raw strings.
    """

    def __init__(
        self,
        input_filename,
        transforms,
        img_key,
        caption_key,
        prefix_path,
        sep="\t",
        tokenizer=None,
    ):
        # Lazy %-style args: the message is only formatted if DEBUG is enabled.
        logging.debug("Loading csv data from %s.", input_filename)
        df = pd.read_csv(input_filename, sep=sep)
        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.transforms = transforms
        self.prefix_path = prefix_path
        self.tokenize = tokenizer
        logging.debug("Done loading data.")

    def __len__(self):
        """Return the number of (image, caption) rows loaded from the file."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(self.transforms(rgb_image), caption_string)`` for row ``idx``."""
        image_path = self.prefix_path + str(self.images[idx])
        # ``convert`` loads the pixel data and returns a new image, so the source
        # file can be closed right away instead of relying on PIL to release it
        # (PIL keeps multi-frame files open after load()).
        with Image.open(image_path) as img:
            rgb = img.convert("RGB")
        image = self.transforms(rgb)
        # Captions are returned untokenized; ``self.tokenize`` is not applied here.
        text = str(self.captions[idx])
        return image, text
def _extract_zip_archives(directory):
    """Extract every ``*.zip`` file directly inside ``directory`` into it, then delete it.

    Only archives present before extraction starts are processed and removed,
    so ``.zip`` files that come out of an archive are left untouched.
    ``directory`` is joined with ``os.path.join``, so a trailing slash is optional.
    """
    archives = sorted(name for name in os.listdir(directory) if name.endswith(".zip"))
    for name in archives:
        archive_path = os.path.join(directory, name)
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(directory)
        # Delete only after a successful extraction of this archive.
        os.remove(archive_path)


def main():
    """Download one SynthCI-30 shard and its caption CSV, then save/show one sample."""
    # Imported before any download so a missing matplotlib fails fast.
    import matplotlib.pyplot as plt

    repo_id = "hammh0a/SynthCLIP"
    # To download the full dataset instead of a single shard, use:
    # snapshot_download(repo_id=repo_id, repo_type="dataset", cache_dir="./cache/",
    #                   local_dir_use_symlinks=False, local_dir="./synthclip_data/")

    # Download only SynthCI-30/data/0.zip and SynthCI-30/combined_images_and_captions.csv
    # from the dataset repo into ./synthclip_data/.
    for filename in (
        "./SynthCI-30/data/0.zip",
        "./SynthCI-30/combined_images_and_captions.csv",
    ):
        hf_hub_download(
            repo_id=repo_id,
            repo_type="dataset",
            cache_dir="./cache/",
            local_dir_use_symlinks=False,
            local_dir="./synthclip_data/",
            filename=filename,
        )

    # The trailing slash matters: CsvDataset builds image paths as prefix + name.
    prefix = "./synthclip_data/SynthCI-30/data/"
    # ./synthclip_data/SynthCI-30/data/ may hold several zip shards; unpack them all.
    _extract_zip_archives(prefix)

    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    )
    dataset = CsvDataset(
        input_filename="./synthclip_data/SynthCI-30/combined_images_and_captions.csv",
        transforms=transform,
        img_key="image_path",
        caption_key="caption",
        prefix_path=prefix,
    )

    img, caption = dataset[0]
    # ToTensor produces channels-first (C, H, W); imshow wants (H, W, C).
    plt.imshow(img.permute(1, 2, 0))
    plt.title(caption)
    plt.savefig("sample.png")
    plt.show()


if __name__ == "__main__":
    main()