Skip to content

Commit

Permalink
[Fix] Fix verify dataset tool in 1.x (#1062)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ezra-Yu authored Sep 30, 2022
1 parent 8c5d86a commit 080eb79
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions tools/misc/verify_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from mmengine import (Config, DictAction, track_parallel_progress,
track_progress)

from mmcls.datasets import PIPELINES, build_dataset
from mmcls.datasets import build_dataset
from mmcls.registry import TRANSFORMS


def parse_args():
Expand Down Expand Up @@ -46,15 +47,14 @@ def parse_args():
class DatasetValidator():
"""the dataset tool class to check if all file are broken."""

def __init__(self, dataset_cfg, log_file_path):
    """Prepare a dataset whose files will be checked for corruption.

    Args:
        dataset_cfg (dict | Config): Dataset config. Its first pipeline
            transform must be ``LoadImageFromFile``; all remaining
            transforms are stripped so that validation only loads each
            image once, without any augmentation.
        log_file_path (Path): File where the paths of broken images are
            recorded.
    """
    super().__init__()
    # Keep only the LoadImageFromFile step: loading is the operation
    # that reveals a broken file, and skipping the rest of the pipeline
    # makes the check as cheap as possible.
    assert dataset_cfg.pipeline[0]['type'] == 'LoadImageFromFile', (
        'This tool is only for datasets that need to load images from '
        'files.')
    self.pipeline = TRANSFORMS.build(dataset_cfg.pipeline[0])
    dataset_cfg.pipeline = []
    self.dataset = build_dataset(dataset_cfg)
    self.log_file_path = log_file_path
Expand Down Expand Up @@ -102,13 +102,22 @@ def main():
# touch output file to save broken files list.
output_path = Path(args.out_path)
if not output_path.parent.exists():
raise Exception('log_file parent directory not found.')
raise Exception("Path '--out-path' parent directory not found.")
if output_path.exists():
os.remove(output_path)
output_path.touch()

# do valid
validator = DatasetValidator(cfg, output_path, args.phase)
if args.phase == 'train':
dataset_cfg = cfg.train_dataloader.dataset
elif args.phase == 'val':
dataset_cfg = cfg.val_dataloader.dataset
elif args.phase == 'test':
dataset_cfg = cfg.test_dataloader.dataset
else:
raise ValueError("'--phase' only support 'train', 'val' and 'test'.")

# do validate
validator = DatasetValidator(dataset_cfg, output_path)

if args.num_process > 1:
# The default chunksize calculation method of Pool.map
Expand Down

0 comments on commit 080eb79

Please sign in to comment.