From 79d091e9bff4642b511c80a2dd01c479eb93aaa0 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 14 Oct 2023 18:15:03 +0200 Subject: [PATCH] fix overwritten zeroshot templates issue #109. Thanks to @djghosh13. --- clip_benchmark/datasets/builder.py | 135 +++++++++++++++++------------ 1 file changed, 80 insertions(+), 55 deletions(-) diff --git a/clip_benchmark/datasets/builder.py b/clip_benchmark/datasets/builder.py index 0e659d9..635d967 100644 --- a/clip_benchmark/datasets/builder.py +++ b/clip_benchmark/datasets/builder.py @@ -49,47 +49,38 @@ def build_dataset(dataset_name, root="root", transform=None, split="test", downl if task in ('zeroshot_classification', 'linear_probe'): # Only load templates and classnames if we have to current_folder = os.path.dirname(__file__) - if dataset_name == "babel_imagenet": - classnames = json.load(open(os.path.join(current_folder, "babel_imagenet.json"))) - assert language.upper() in classnames, f"Language '{language}' not supported for Babel-ImageNet" - classnames = classnames[language.upper()] - templates = json.load(open(os.path.join(current_folder, "nllb_dist13b_prompts.json"))) - templates = templates[language.upper()] - templates = [t.replace('{}', '{c}') for t in templates] - else: - if custom_classname_file and not os.path.exists(custom_classname_file): - # look at current_folder - custom_classname_file_attempt = os.path.join(current_folder, custom_classname_file) - assert os.path.exists(custom_classname_file_attempt), f"Custom classname file '{custom_classname_file}' does not exist" - custom_classname_file = custom_classname_file_attempt - else: - custom_classname_file = os.path.join(current_folder, language + "_classnames.json") - - if custom_template_file and not os.path.exists(custom_template_file): - # look at current_folder - custom_template_file_attempt = os.path.join(current_folder, custom_template_file) - assert os.path.exists(custom_template_file_attempt), f"Custom template file '{custom_template_file}' does not exist" - custom_template_file = custom_template_file_attempt - else: - custom_template_file = os.path.join(current_folder, language + "_zeroshot_classification_templates.json") - + # Load _classnames.json (default) + default_classname_file = os.path.join(current_folder, language + "_classnames.json") + with open(default_classname_file, "r") as f: + default_classnames = json.load(f) + # Load _zeroshot_classification_templates.json + default_template_file = os.path.join(current_folder, language + "_zeroshot_classification_templates.json") + with open(default_template_file, "r") as f: + default_templates = json.load(f) + + # Load custom classnames file if --custom_classname_file is specified + if custom_classname_file: + if not os.path.exists(custom_classname_file): + custom_classname_file = os.path.join(current_folder, custom_classname_file) + assert os.path.exists(custom_classname_file), f"Custom classname file '{custom_classname_file}' does not exist" with open(custom_classname_file, "r") as f: - classnames = json.load(f) + custom_classnames = json.load(f) + else: + custom_classnames = None + # Load custom template file if --custom_template_file is specified + if custom_template_file: + if not os.path.exists(custom_template_file): + # look at current_folder + custom_template_file = os.path.join(current_folder, custom_template_file) + assert os.path.exists(custom_template_file), f"Custom template file '{custom_template_file}' does not exist" with open(custom_template_file, "r") as f: - templates = json.load(f) - - default_template = templates["imagenet1k"] if "imagenet1k" in templates else None + custom_templates = json.load(f) + else: + custom_templates = None + + default_or_custom_classnames = custom_classnames if custom_classnames else default_classnames - if dataset_name.startswith("tfds/") or dataset_name.startswith("vtab/") or dataset_name.startswith("wds/"): - name = dataset_name.split("/")[-1] - else: - name = dataset_name - templates = templates.get(name, default_template) - assert templates is not None, f"Templates for dataset '{dataset_name}' not found in '{custom_template_file}'" - else: - classnames, templates = None, None - def download_imagenet(r): os.makedirs(r, exist_ok=True) call(f"wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz --output-document={r}/ILSVRC2012_devkit_t12.tar.gz", shell=True) @@ -107,7 +98,7 @@ def download_imagenet(r): if not os.path.exists(root): download_imagenet(root) ds = ImageNet(root=root, split="train" if train else "val", transform=transform, **kwargs) - ds.classes = classnames["imagenet1k"] + ds.classes = default_or_custom_classnames["imagenet1k"] elif dataset_name == "imagenet-w": assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}" from imagenet_w import AddWatermark @@ -125,26 +116,33 @@ def download_imagenet(r): assert index_normalize is not None, "Normalize not found in transform" transform.transforms.insert(index_normalize, AddWatermark(crop_size)) ds = ImageNet(root=root, split="train" if train else "val", transform=transform, **kwargs) - ds.classes = classnames["imagenet1k"] + ds.classes = default_or_custom_classnames["imagenet1k"] elif dataset_name == "babel_imagenet": assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}" # babel ImageNet from https://github.com/gregor-ge/Babel-ImageNet if not os.path.exists(root): download_imagenet(root) + classnames = json.load(open(os.path.join(current_folder, "babel_imagenet.json"))) + assert language.upper() in classnames, f"Language '{language}' not supported for Babel-ImageNet" + classnames = classnames[language.upper()] + templates = json.load(open(os.path.join(current_folder, "nllb_dist13b_prompts.json"))) + templates = templates[language.upper()] + templates = [t.replace('{}', '{c}') for t in templates] idxs, classnames = classnames ds = babel_imagenet.BabelImageNet(root=root, idxs=idxs, split="train" if train else "val", transform=transform, **kwargs) ds.classes = classnames + ds.templates = templates elif dataset_name == "imagenet1k-unverified": assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}" split = "train" if train else "val" ds = ImageFolder(root=os.path.join(root, split), transform=transform, **kwargs) # use classnames from OpenAI - ds.classes = classnames["imagenet1k"] + ds.classes = default_or_custom_classnames["imagenet1k"] elif dataset_name == "imagenetv2": assert split == "test", f"Only `test` split available for {dataset_name}" os.makedirs(root, exist_ok=True) ds = imagenetv2.ImageNetV2Dataset(variant="matched-frequency", transform=transform, location=root) - ds.classes = classnames["imagenet1k"] + ds.classes = default_or_custom_classnames["imagenet1k"] elif dataset_name == "imagenet_sketch": assert split == "test", f"Only `test` split available for {dataset_name}" # Downloadable from https://drive.google.com/open?id=1Mj0i5HBthqH1p_yeXzsg22gZduvgoNeA @@ -161,7 +159,7 @@ def download_imagenet(r): call("unzip ImageNet-Sketch.zip", shell=True) call(f"mv sketch {root}", shell=True) ds = ImageFolder(root=root, transform=transform, **kwargs) - ds.classes = classnames["imagenet1k"] + ds.classes = default_or_custom_classnames["imagenet1k"] elif dataset_name == "imagenet-a": assert split == "test", f"Only `test` split available for {dataset_name}" # Downloadable from https://people.eecs.berkeley.edu/~hendrycks/imagenet-a.tar @@ -172,7 +170,7 @@ def download_imagenet(r): call("tar xvf imagenet-a.tar", shell=True) call(f"mv imagenet-a {root}", shell=True) ds = ImageFolder(root=root, transform=transform, **kwargs) - ds.classes = classnames["imagenet1k"] + ds.classes = default_or_custom_classnames["imagenet1k"] imagenet_a_wnids = ['n01498041', 'n01531178', 'n01534433', 'n01558993', 'n01580077', 'n01614925', 'n01616318', 'n01631663', 'n01641577', 'n01669191', 'n01677366', 'n01687978', 'n01694178', 'n01698640', 'n01735189', 'n01770081', 'n01770393', 'n01774750', 'n01784675', 'n01819313', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01882714', 'n01910747', 'n01914609', 'n01924916', 'n01944390', 'n01985128', 'n01986214', 'n02007558', 'n02009912', 'n02037110', 'n02051845', 'n02077923', 'n02085620', 'n02099601', 'n02106550', 'n02106662', 'n02110958', 'n02119022', 'n02123394', 'n02127052', 'n02129165', 'n02133161', 'n02137549', 'n02165456', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02231487', 'n02233338', 'n02236044', 'n02259212', 'n02268443', 'n02279972', 'n02280649', 'n02281787', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02361337', 'n02410509', 'n02445715', 'n02454379', 'n02486410', 'n02492035', 'n02504458', 'n02655020', 'n02669723', 'n02672831', 'n02676566', 'n02690373', 'n02701002', 'n02730930', 'n02777292', 'n02782093', 'n02787622', 'n02793495', 'n02797295', 'n02802426', 'n02814860', 'n02815834', 'n02837789', 'n02879718', 'n02883205', 'n02895154', 'n02906734', 'n02948072', 'n02951358', 'n02980441', 'n02992211', 'n02999410', 'n03014705', 'n03026506', 'n03124043', 'n03125729', 'n03187595', 'n03196217', 'n03223299', 'n03250847', 'n03255030', 'n03291819', 'n03325584', 'n03355925', 'n03384352', 'n03388043', 'n03417042', 'n03443371', 'n03444034', 'n03445924', 'n03452741', 'n03483316', 'n03584829', 'n03590841', 'n03594945', 'n03617480', 'n03666591', 'n03670208', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03775071', 'n03788195', 'n03804744', 'n03837869', 'n03840681', 'n03854065', 'n03888257', 'n03891332', 'n03935335', 'n03982430', 'n04019541', 'n04033901', 'n04039381', 'n04067472', 'n04086273', 'n04099969', 'n04118538', 'n04131690', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04179913', 'n04208210', 'n04235860', 'n04252077', 'n04252225', 'n04254120', 'n04270147', 'n04275548', 'n04310018', 'n04317175', 'n04344873', 'n04347754', 'n04355338', 'n04366367', 'n04376876', 'n04389033', 'n04399382', 'n04442312', 'n04456115', 'n04482393', 'n04507155', 'n04509417', 'n04532670', 'n04540053', 'n04554684', 'n04562935', 'n04591713', 'n04606251', 'n07583066', 'n07695742', 'n07697313', 'n07697537', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07749582', 'n07753592', 'n07760859', 'n07768694', 'n07831146', 'n09229709', 'n09246464', 'n09472597', 'n09835506', 'n11879895', 'n12057211', 'n12144580', 'n12267677'] imagenet_a_mask = [wnid in set(imagenet_a_wnids) for wnid in all_imagenet_wordnet_ids] ds.classes = [cl for cl, mask in zip(ds.classes, imagenet_a_mask) if mask] @@ -188,7 +186,7 @@ def download_imagenet(r): imagenet_r_wnids = {'n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214', 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', 'n02437616', 'n02445715', 'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677'} imagenet_r_mask = [wnid in imagenet_r_wnids for wnid in all_imagenet_wordnet_ids] ds = ImageFolder(root=root, transform=transform, **kwargs) - ds.classes = classnames["imagenet1k"] + ds.classes = default_or_custom_classnames["imagenet1k"] ds.classes = [cl for cl, mask in zip(ds.classes, imagenet_r_mask) if mask] elif dataset_name == "imagenet-o": assert split == "test", f"Only `test` split available for {dataset_name}" @@ -200,7 +198,7 @@ def download_imagenet(r): call("tar xvf imagenet-o.tar", shell=True) call(f"mv imagenet-o {root}", shell=True) ds = ImageFolder(root=root, transform=transform, **kwargs) - ds.classes = classnames["imagenet1k"] + ds.classes = default_or_custom_classnames["imagenet1k"] imagenet_o_wnids = ['n01443537', 'n01704323', 'n01770081', 'n01784675', 'n01819313', 'n01820546', 'n01910747', 'n01917289', 'n01968897', 'n02074367', 'n02317335', 'n02319095', 'n02395406', 'n02454379', 'n02606052', 'n02655020', 'n02666196', 'n02672831', 'n02730930', 'n02777292', 'n02783161', 'n02786058', 'n02787622', 'n02791270', 'n02808304', 'n02817516', 'n02841315', 'n02865351', 'n02877765', 'n02892767', 'n02906734', 'n02910353', 'n02916936', 'n02948072', 'n02965783', 'n03000134', 'n03000684', 'n03017168', 'n03026506', 'n03032252', 'n03075370', 'n03109150', 'n03126707', 'n03134739', 'n03160309', 'n03196217', 'n03207743', 'n03218198', 'n03223299', 'n03240683', 'n03271574', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03344393', 'n03347037', 'n03372029', 'n03376595', 'n03388043', 'n03388183', 'n03400231', 'n03445777', 'n03457902', 'n03467068', 'n03482405', 'n03483316', 'n03494278', 'n03530642', 'n03544143', 'n03584829', 'n03590841', 'n03598930', 'n03602883', 'n03649909', 'n03661043', 'n03666591', 'n03676483', 'n03692522', 'n03706229', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03742115', 'n03786901', 'n03788365', 'n03794056', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03840681', 'n03843555', 'n03854065', 'n03857828', 'n03868863', 'n03874293', 'n03884397', 'n03891251', 'n03908714', 'n03920288', 'n03929660', 'n03930313', 'n03937543', 'n03942813', 'n03944341', 'n03961711', 'n03970156', 'n03982430', 'n03991062', 'n03995372', 'n03998194', 'n04005630', 'n04023962', 'n04033901', 'n04040759', 'n04067472', 'n04074963', 'n04116512', 'n04118776', 'n04125021', 'n04127249', 'n04131690', 'n04141975', 'n04153751', 'n04154565', 'n04201297', 'n04204347', 'n04209133', 'n04209239', 'n04228054', 'n04235860', 'n04243546', 'n04252077', 'n04254120', 'n04258138', 'n04265275', 'n04270147', 'n04275548', 'n04330267', 'n04332243', 'n04336792', 'n04347754', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04429376', 'n04435653', 'n04442312', 'n04482393', 'n04501370', 'n04507155', 'n04525305', 'n04542943', 'n04554684', 'n04557648', 'n04562935', 'n04579432', 'n04591157', 'n04597913', 'n04599235', 'n06785654', 'n06874185', 'n07615774', 'n07693725', 'n07695742', 'n07697537', 'n07711569', 'n07714990', 'n07715103', 'n07716358', 'n07717410', 'n07718472', 'n07720875', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753275', 'n07753592', 'n07754684', 'n07768694', 'n07836838', 'n07871810', 'n07873807', 'n07880968', 'n09229709', 'n09472597', 'n12144580', 'n12267677', 'n13052670'] imagenet_o_mask = [wnid in set(imagenet_o_wnids) for wnid in all_imagenet_wordnet_ids] ds.classes = [cl for cl, mask in zip(ds.classes, imagenet_o_mask) if mask] @@ -380,7 +378,7 @@ def download_mscoco_split(target_split): # also available in "vtab/caltech101" using VTAB splits, we advice to use VTAB version rather than this one # since in this one (torchvision) there are no pre-defined test splits ds = caltech101.Caltech101(root=root, target_type="category", transform=transform, download=download, **kwargs) - ds.classes = classnames["caltech101"] + ds.classes = default_or_custom_classnames["caltech101"] elif dataset_name == "flowers": assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}" ds = Flowers102(root=root, split=split, transform=transform, download=download, **kwargs) @@ -389,32 +387,32 @@ def download_mscoco_split(target_split): # TODO figure out minimal torchvision version needed instead of decrementing if ds[0][1] == 1: ds.target_transform = lambda y:y-1 - ds.classes = classnames["flowers"] + ds.classes = default_or_custom_classnames["flowers"] elif dataset_name == "mnist": assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}" ds = MNIST(root=root, train=train, transform=transform, download=download, **kwargs) - ds.classes = classnames["mnist"] + ds.classes = default_or_custom_classnames["mnist"] elif dataset_name == "stl10": assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}" ds = STL10(root=root, split=split, transform=transform, download=download, **kwargs) elif dataset_name == "eurosat": warnings.warn(f"split argument ignored for `{dataset_name}`, there are no pre-defined train/test splits for this dataset") ds = EuroSAT(root=root, transform=transform, download=download, **kwargs) - ds.classes = classnames["eurosat"] + ds.classes = default_or_custom_classnames["eurosat"] elif dataset_name == "gtsrb": assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}" ds = GTSRB(root=root, split=split, transform=transform, download=download, **kwargs) - ds.classes = classnames["gtsrb"] + ds.classes = default_or_custom_classnames["gtsrb"] elif dataset_name == "country211": assert split in ("train", "valid", "test"), f"Only `train` and `valid` and `test` split available for {dataset_name}" ds = Country211(root=root, split=split, transform=transform, download=download, **kwargs) - ds.classes = classnames["country211"] + ds.classes = default_or_custom_classnames["country211"] elif dataset_name == "pcam": assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}" # Dead link. Fixed by this PR on torchvision https://github.com/pytorch/vision/pull/5645 # TODO figure out minimal torchvision version needed ds = PCAM(root=root, split=split, transform=transform, download=download, **kwargs) - ds.classes = classnames["pcam"] + ds.classes = default_or_custom_classnames["pcam"] elif dataset_name == "renderedsst2": assert split in ("train", "val", "test"), f"Only `train` and `val` and `test` split available for {dataset_name}" ds = RenderedSST2(root=root, split=split, transform=transform, download=download, **kwargs) @@ -432,7 +430,7 @@ def download_mscoco_split(target_split): call(f"unzip fer2013.zip -d {root}", shell=True) root = os.path.join(root, "train" if train else "test") ds = ImageFolder(root=root, transform=transform) - ds.classes = classnames["fer2013"] + ds.classes = default_or_custom_classnames["fer2013"] elif dataset_name.startswith("tfds/"): # TFDS datasets support using `timm` and `tensorflow_datasets` prefix, *name_list = dataset_name.split("/") @@ -442,18 +440,45 @@ def download_mscoco_split(target_split): # VTAB datasets support using `tensorflow_datasets` and `task_adaptation` prefix, *name_list = dataset_name.split("/") name = "/".join(name_list) - ds = build_vtab_dataset(name, download=download, split=split, data_dir=root, transform=transform, classnames=classnames) + ds = build_vtab_dataset(name, download=download, split=split, data_dir=root, transform=transform, classnames=default_or_custom_classnames) elif dataset_name.startswith("wds/"): # WebDataset support using `webdataset` library name = dataset_name.split("/", 1)[1] ds = build_wds_dataset(name, transform=transform, split=split, data_dir=root, cache_dir=wds_cache_dir) + # WDS specify classnames and templates on this own. elif dataset_name == "dummy": ds = Dummy() else: raise ValueError(f"Unsupported dataset: {dataset_name}.") - ds.templates = templates + + default_dataset_for_templates = "imagenet1k" + if dataset_name.startswith("tfds/") or dataset_name.startswith("vtab/") or dataset_name.startswith("wds/"): + prefix, *rest = dataset_name.split("/") + short_name = "/".join(rest) + keys_to_lookup = [dataset_name, short_name, default_dataset_for_templates] + else: + keys_to_lookup = [dataset_name, default_dataset_for_templates] + + # Specify templates for the dataset (if needed) + if custom_templates: + # We override with custom templates ONLY if they are provided. + ds.templates = value_from_first_key_found(custom_templates, keys=keys_to_lookup) + assert ds.templates is not None, f"Templates not specified for {dataset_name}" + elif not hasattr(ds, "templates"): + # No templates specified by the dataset itself, so we use templates are packaged with CLIP benchmark (loaded from _zeroshot_classification_templates.json). + ds.templates = value_from_first_key_found(default_templates, keys=keys_to_lookup) + assert ds.templates is not None, f"Templates not specified for {dataset_name}" + else: + # dataset has templates already (e.g., WDS case), so we keep it as is. + pass return ds +def value_from_first_key_found(dic, keys): + for k in keys: + if k in dic: + return dic[k] + + class Dummy(): def __init__(self):