-
Notifications
You must be signed in to change notification settings - Fork 212
[Feature request] Compatibility with iterable-style datasets #1237
Comments
For a little more context, I'll paste below example code for a custom data module:
import os
import glob
import pickle
import numpy as np
import cv2
import torch
import torchvision.transforms as T
import warnings
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pytorch_lightning as pl
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin
from nvidia import dali
from nvidia.dali import pipeline_def, types, fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy
# Read label map (dict, like 1: person, 2: car, etc.)
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only load a 'coco_idx2label' file that comes from a trusted source.
with open('coco_idx2label', 'rb') as f:
    idx2label = pickle.load(f)

# Get urls (.tar file paths). Shards are looked up in ./coco_shards_dali
# relative to the current working directory and sorted by name.
shard_dir = os.path.join(os.getcwd(), 'coco_shards_dali')
train_dali_urls = sorted(glob.glob(os.path.join(shard_dir, 'train*')))
val_dali_urls = sorted(glob.glob(os.path.join(shard_dir, 'val*')))
# For example:
# ['/home/ubuntu/data/coco_shards_dali/train-000000.tar',
# '/home/ubuntu/data/coco_shards_dali/train-000001.tar',
# ...
# '/home/ubuntu/data/coco_shards_dali/train-000031.tar']
class DataModuleClass(pl.LightningDataModule):
    """LightningDataModule that serves WebDataset shards through NVIDIA DALI.

    ``setup()`` builds one DALI pipeline for the training shards and, when
    validation shards were given, one for the validation shards. Each
    pipeline is wrapped in a ``DALIGenericIterator`` subclass whose batches
    are dicts with the keys ``'images'``, ``'bboxes'`` and ``'labels'``.
    """

    # Side length used for resizing/cropping images and for un-normalizing boxes.
    IMAGE_SIZE = 416
    # Number of box/label slots per sample: inputs are reshaped to this many
    # boxes and the outputs are padded back up to it.
    MAX_BOXES = 64

    def __init__(self,
                 idx2label,
                 train_urls,
                 val_urls=None,
                 batch_size=16,
                 num_workers=None,
                 mean=(103.530, 116.280, 123.675),
                 std=(57.375, 57.120, 58.395),
                 seed=42):
        """Store the configuration; no pipeline is built until ``setup()``.

        Args:
            idx2label: Label map (dict, like 1: person, 2: car, etc.). Only
                stored on the instance; not used by the pipeline itself.
            train_urls: Paths of the training shards, passed to the
                WebDataset reader.
            val_urls: Paths of the validation shards. When ``None`` or empty
                no validation loader is created.
            batch_size: Batch size of each DALI pipeline.
            num_workers: Number of DALI pipeline threads (``num_threads``).
                ``None`` (default) means CPU count divided by visible CUDA
                device count, with a floor of 1.
            mean: Per-channel mean given to ``fn.crop_mirror_normalize``.
            std: Per-channel std given to ``fn.crop_mirror_normalize``.
            seed: Seed passed to the pipelines and the reader.
        """
        # Initialise the Lightning base class before adding our own state.
        super().__init__()
        self.idx2label = idx2label
        self.train_urls = train_urls
        self.val_urls = val_urls
        self.batch_size = batch_size
        # Resolved here rather than in the signature so that merely defining
        # the class cannot fail on a machine without a CUDA device.
        if num_workers is None:
            num_workers = self._default_num_workers()
        self.num_workers = num_workers
        # Copy so that instances never share (or mutate) the default values.
        self.mean = list(mean)
        self.std = list(std)
        self.seed = seed
        self.prepare_data_per_node = False
        self._log_hyperparams = False

    @staticmethod
    def _default_num_workers():
        """Return CPU count // CUDA device count, never less than 1."""
        cpu_count = os.cpu_count() or 1
        gpu_count = torch.cuda.device_count() or 1
        return max(1, cpu_count // gpu_count)

    def prepare_data(self):
        """Hook for steps that should be done on only one GPU, like getting data.

        Nothing to do here.
        """
        pass

    def setup(self, stage=None):
        """Create the train (and optionally val) DALI loaders.

        Device and shard placement come from the attached trainer
        (``local_rank``, ``global_rank``, ``world_size``). When no trainer
        with those attributes is attached yet, a warning is emitted and the
        single-device defaults (device 0, shard 0 of 1) are used.

        Args:
            stage: Accepted for the Lightning hook signature; not used.
        """
        # getattr keeps this safe even if no trainer attribute exists yet.
        trainer = getattr(self, 'trainer', None)
        if hasattr(trainer, 'local_rank'):
            device_id = trainer.local_rank
            shard_id = trainer.global_rank
            num_shards = trainer.world_size
        else:
            warnings.warn('DataModule setup called before trainer init, using default device_id, shard_id, num_shards')
            device_id = 0
            shard_id = 0
            num_shards = 1

        self.train_loader = self._build_loader(
            self.train_urls, device_id, shard_id, num_shards, train=True)
        if self.val_urls:
            self.val_loader = self._build_loader(
                self.val_urls, device_id, shard_id, num_shards, train=False)

    def _build_loader(self, urls, device_id, shard_id, num_shards, train):
        """Build a DALI pipeline over ``urls`` and wrap it in a dict-yielding iterator.

        Shuffling and augmentation are both switched by ``train``.
        """
        pipe = self._wds_pipeline(urls=urls,
                                  batch_size=self.batch_size,
                                  num_threads=self.num_workers,
                                  device='gpu',
                                  device_id=device_id,
                                  shard_id=shard_id,
                                  num_shards=num_shards,
                                  random_shuffle=train,
                                  seed=self.seed,
                                  train=train)

        class LightningWrapper(DALIGenericIterator):
            """Return the first pipeline's outputs as a flat dict."""

            def __next__(self):
                item = super().__next__()
                images = item[0]['images']
                bboxes = item[0]['bboxes']
                labels = item[0]['labels']
                return {'images': images, 'bboxes': bboxes, 'labels': labels}

        return LightningWrapper(
            pipe,
            ['images', 'bboxes', 'labels'],
            reader_name='Reader',
            last_batch_policy=LastBatchPolicy.PARTIAL,
            auto_reset=True)

    def train_dataloader(self):
        """Return the training loader created by ``setup()``."""
        return self.train_loader

    def val_dataloader(self):
        """Return the validation loader, or ``None`` when there are no val shards."""
        # Same truthiness test as setup(), so an empty url list cannot lead
        # to reading a loader that was never created.
        if self.val_urls:
            return self.val_loader
        return None

    def _decode_augment(self, images, bboxes, labels, device, seed=0, fp16=True, train=True):
        """Add the decode/crop/augment/normalize stages to the DALI graph.

        Args:
            images: Encoded images from the reader.
            bboxes: Boxes from the reader; reshaped to ``(MAX_BOXES, 4)`` and
                handled in ``xyWH`` layout.
            labels: Labels from the reader.
            device: ``'gpu'`` selects the ``'mixed'`` decoder and moves boxes
                and labels to the GPU; any other value decodes on the CPU.
            seed: Unused; kept so the signature stays stable for callers.
            fp16: Output images as FLOAT16 when true, FLOAT otherwise.
            train: Enables the random crop ranges, horizontal flip and
                colour augmentation.

        Returns:
            ``(images, bboxes, labels)``: images normalized in CHW layout,
            boxes scaled by ``IMAGE_SIZE`` and cast to INT64, labels cast to
            INT64; boxes and labels are padded to ``MAX_BOXES`` entries.
        """
        size = self.IMAGE_SIZE
        max_boxes = self.MAX_BOXES

        bboxes = fn.reshape(bboxes, shape=[max_boxes, 4])
        # Adjust boxes due to rounding issues with xyWH format
        bboxes = dali.math.clamp(bboxes, lo=0.0, hi=1.0)
        xy = bboxes[:, 0:2]
        wh = bboxes[:, 2:4]
        # Shrink width/height wherever x + w or y + h would exceed 1.0.
        wh = wh - dali.math.max(0.0, (xy + wh) - 1.0)
        bboxes = fn.cat(xy, wh, axis=1)

        if train:
            aspect_ratio = [0.5, 2.0]
            thresholds = [0, 0.1, 0.3, 0.5, 0.7, 0.9]
            scaling = [0.3, 1.0]
        else:
            aspect_ratio = [1.0, 1.0]
            thresholds = [0.9]
            scaling = [1.0, 1.0]

        crop_begin, crop_size, bboxes, labels = fn.random_bbox_crop(
            bboxes, labels,
            device='cpu',
            aspect_ratio=aspect_ratio,
            thresholds=thresholds,
            scaling=scaling,
            bbox_layout='xyWH',
            allow_no_crop=True,
            num_attempts=50)

        # Decode only the selected crop window.
        images = fn.decoders.image_slice(images, crop_begin, crop_size,
                                         device='mixed' if device == 'gpu' else 'cpu',
                                         output_type=types.RGB)

        # With probability 0.0 the flip never fires, so evaluation shares the
        # same graph shape as training.
        flip_coin = fn.random.coin_flip(probability=0.5 if train else 0.0)

        images = fn.resize(images, resize_x=size, resize_y=size,
                           min_filter=types.DALIInterpType.INTERP_TRIANGULAR)

        if train:
            saturation = fn.random.uniform(range=[0.5, 1.5])
            contrast = fn.random.uniform(range=[0.5, 1.5])
            brightness = fn.random.uniform(range=[0.875, 1.125])
            hue = fn.random.uniform(range=[-0.5, 0.5])
            # use float to avoid clipping and quantizing the intermediate result
            images = fn.hsv(images, dtype=types.FLOAT, hue=hue, saturation=saturation)
            images = fn.brightness_contrast(images,
                                            contrast_center=128,  # input is in float, but in 0..255 range
                                            dtype=types.UINT8,
                                            brightness=brightness,
                                            contrast=contrast)

        dtype = types.FLOAT16 if fp16 else types.FLOAT
        # Boxes and images are mirrored by the same coin flip.
        bboxes = fn.bb_flip(bboxes, ltrb=False, horizontal=flip_coin)
        images = fn.crop_mirror_normalize(images,
                                          crop=(size, size),
                                          mean=self.mean,
                                          std=self.std,
                                          mirror=flip_coin,
                                          dtype=dtype,
                                          output_layout='CHW',
                                          pad_output=False)

        # Un-normalize
        bboxes = bboxes * size

        # Pad
        bboxes = fn.pad(bboxes, fill_value=0.0, axes=(0,), shape=(max_boxes,))
        labels = fn.pad(labels, fill_value=0.0, axes=(0,), shape=(max_boxes,))

        if device == 'gpu':
            labels = labels.gpu()
            bboxes = bboxes.gpu()

        # Cast to int
        bboxes = fn.cast(bboxes, dtype=types.INT64)
        labels = fn.cast(labels, dtype=types.INT64)
        return images, bboxes, labels

    @pipeline_def
    def _wds_pipeline(self,
                      urls,
                      device,
                      shard_id=0,
                      num_shards=1,
                      random_shuffle=True,
                      train=True):
        """DALI pipeline definition: WebDataset reader followed by ``_decode_augment``.

        Each sample is read from the ``jpg``, ``bboxes`` and ``labels``
        components (as UINT8, FLOAT and INT32); a missing component is an
        error. The reader is named ``'Reader'``, which is the name the
        iterator wrapper refers to. ``fp16`` is not forwarded, so
        ``_decode_augment`` runs with its default.

        Args:
            urls: Shard paths for the reader.
            device: Forwarded to ``_decode_augment``.
            shard_id: Index of the shard this pipeline reads.
            num_shards: Total number of shards the data is split into.
            random_shuffle: Whether the reader shuffles samples.
            train: Forwarded to ``_decode_augment``.
        """
        images, bboxes, labels = fn.readers.webdataset(
            paths=urls,
            shard_id=shard_id,
            num_shards=num_shards,
            random_shuffle=random_shuffle,
            ext=['jpg', 'bboxes', 'labels'],
            missing_component_behavior='error',
            dtypes=[types.UINT8, types.FLOAT, types.INT32],
            seed=self.seed,
            name='Reader')
        return self._decode_augment(images, bboxes=bboxes, labels=labels,
                                    device=device, seed=self.seed, train=train)
# Instantiate the datamodule with the label map and the train/val shard paths.
datamodule = DataModuleClass(
idx2label,
train_urls=train_dali_urls,
val_urls=val_dali_urls,
batch_size=16,
)
# If you need information from the dataset to build your model, then run
# prepare_data() and setup() manually (Lightning ensures the method runs on
# the correct devices).
datamodule.prepare_data()
datamodule.setup(stage='fit') |
Hi @austinmw Thanks for your request! This is a current limitation of certain tasks in Flash where they cannot be directly used with your own datamodule because the model needs to provide the collate function for the data. IceVision models are slightly more complex again in that they need to provide the dataloader in full. I think it should be possible for us to find a workaround there as this would be a great use-case to support 😃 |
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. |
Not stale |
Hey @austinmw just to give you an update. We have resolved most of the framework issues that stood in the way of supporting your use-case, and now just need to document it properly and ship it in our upcoming 0.8 release. Can't give an exact timeline, but we're aiming for weeks rather than months. I'll come back here when I can give an updated code snippet to make this work 😃 |
Awesome news, can't wait to see, thanks! |
🚀 Feature
I'd like to be able to train on iterable-style datasets instead of just map-style datasets.
(A map-style dataset in PyTorch has `__getitem__` and `__len__`, whereas iterable-style datasets only have `__iter__`.)
Motivation
Many image datasets in commercial use cases are very large, and therefore require iterable-style rather than map-style.
(Users may create custom iterable datasets, or use torchdata, webdataset, DALI, etc.)
Pitch
Vision tasks seem to require iterating over the entire dataset and building records prior to training (e.g. `ObjectDetectionData`). This does not make sense as a required step for large datasets. Say, for example, you want to compare models on a dataset of 10M images. Having to iterate over this dataset for potentially several hours before training starts seems like an unnecessary and costly step. Users should be able to begin training online and have each sample from an iterable dataset provide the necessary information.
Lack of this capability, in my opinion, prevents adoption of the vision tasks in this library for large-scale image training in commercial settings.
Additional context
lightning-bolts object detectors seem to support this style of dataset already.
Links:
https://pytorch.org/blog/efficient-pytorch-io-library-for-large-datasets-many-files-many-gpus/
https://github.com/pytorch/data
The text was updated successfully, but these errors were encountered: