[Design discussion] Batches and resampling #97
I'm not all too familiar with the way MLJFlux.jl currently handles this, but I'll try to give some comments and describe what our current approach on the deep learning side is. In FastAI.jl, batching is mostly an implementation detail that is independent of the semantic transformations applied to observations. Cross-validation is rarely done, since training time is usually fairly long already, but data containers are split into train/test sets and the training set is reshuffled every epoch. Both are done as lazy operations on the containers, so they don't incur any performance penalty. I'm not sure if that's what you mean, but I guess this does "break the batches"; however, many datasets (e.g. in computer vision) are larger than memory anyway, and so each observation is loaded and batched every single epoch. Since reloading batches is unavoidable, doing it as fast as possible is handled by DataLoaders.jl, with some of the important performance aspects explained here.

Regarding the interface used in MLJFlux.jl: I am curious if it might make sense to separate an MLJFlux model into a) the Flux model and b) the training hyperparameters. This would allow factoring out common configuration parameters. I'm not sure if it makes sense to make …
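For concreteness, here is a rough sketch of the lazy split/reshuffle/load pattern described above, as I understand it, using MLDataPattern.jl for the container operations and DataLoaders.jl for batch loading; the toy container and the loop body are placeholders, not anyone's actual training code.

```julia
using MLDataPattern: shuffleobs, splitobs
using DataLoaders: DataLoader

data = rand(Float32, 10, 1_000)          # toy container: 1000 observations, 10 features each

train, test = splitobs(data, at = 0.8)   # lazy split: no observations are copied

for epoch in 1:10
    # reshuffling each epoch is also lazy (it only permutes observation indices);
    # observations are loaded and collated into batches on demand
    for x_batch in DataLoader(shuffleobs(train), 32)
        # train on x_batch (here a 10×32 array) ...
    end
end
```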
Without speaking directly to MLJFlux, this is how …
As a consequence, applying a resampling iterator on top of a batch iterator will resample by batches. In deep learning, batching is usually the last step before augmentation: you do resampling, shuffling and splitting on the entire dataset, then batch each split, then augment each batch. k-fold-style splitting is not generally done, but if you were to do it, it would make more sense to me to do that on the dataset than on the batches.
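A minimal sketch of that ordering, in MLDataPattern.jl-style lazy operations (the arrays and sizes are arbitrary):

```julia
using MLDataPattern: shuffleobs, splitobs, eachbatch, kfolds

X = rand(Float32, 10, 1_000)             # 1000 observations
y = rand(1:3, 1_000)

# resampling, shuffling and splitting act on the whole dataset ...
train, valid = splitobs(shuffleobs((X, y)), at = 0.9)

# ... and only then is each split batched
train_batches = eachbatch(train, size = 32)
valid_batches = eachbatch(valid, size = 32)

# k-fold-style splitting, if wanted, would likewise act on the dataset, not on batches
folds = kfolds((X, y), 5)
```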
I would also add that the type of data in a batch can be quite heterogeneous: think nested dicts, strings and whatnot. If MLJ(Flux) wants to handle that level of complexity, then it's worth making a distinction between batches pre- and post-collation. The former is an AOS-style collection of observations sharing the same structure, while the latter is an SOA structure that is only required because models expect to work with contiguous memory regions.
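To make the two senses of "batch" concrete, here is a toy illustration (not any particular package's API):

```julia
# (1) pre-collation: a collection of observations, each with the same structure
samples = [(image = rand(Float32, 28, 28, 1), label = rand(1:10)) for _ in 1:32]

# (2) post-collation: the same data packed into contiguous arrays, the form
# most models actually consume
images = cat((s.image for s in samples)...; dims = 4)   # 28×28×1×32 array
labels = [s.label for s in samples]                     # 32-element vector
```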
If I'm understanding this correctly, this means caching batches? That's something you can do (MLDataPattern.jl does), but I think as the dataset scales to DL sizes, re-batching is quite cheap (especially when the memory could be used for something else).
@darsnack @ToucheSir @lorenzoh Thank you for your comments. These are very helpful. I will take a closer look at DataLoaders.jl, which would be nice to integrate if possible. It appears I was mistaken in thinking that re-batching is a likely performance bottleneck, partly in view of the fact that re-sampling beyond a holdout set is not so common in DL. This is good to know. I may misunderstand, but it seems there is not complete consensus here on where the responsibility for batching should lie. On the one hand, as @lorenzoh suggests, batching should be something handled on the training side, not in data preparation:
On the other hand, there is the practice pointed out by @darsnack that augmentation is performed after batching (so each batch gets the same amount of augmentation?). This suggests batching is a pre-training data preparation step, no?
In the current approach, MLJFlux considers `batch_size` a hyper-parameter of the model.

@lorenzoh Regarding this comment:
I think this is essentially the case. An MLJ "model" is a struct with a bunch of "training parameters" as fields (regularisation parameters, batch size, optimiser choice, and so forth), plus a single field called a `builder` that furnishes instructions on how to build the Flux "model" (new meaning) once the data has been inspected. The Flux model is something that can be accessed using the generic …
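To illustrate, here is a sketch using MLJFlux's current public API (exact keyword names may differ across versions, and `fitted_params` is assumed to be the generic accessor being referred to):

```julia
using MLJ, MLJFlux, Flux

X, y = @load_iris

# "training parameters" live in the MLJ model struct; `builder` says how to
# construct the Flux chain once the data has been inspected
clf = MLJFlux.NeuralNetworkClassifier(
    builder    = MLJFlux.Short(n_hidden = 16, dropout = 0.1, σ = Flux.relu),
    optimiser  = Flux.ADAM(0.01),
    batch_size = 8,
    lambda     = 0.01,                  # regularisation strength
    epochs     = 20,
)

mach = machine(clf, X, y)
fit!(mach)

chain = fitted_params(mach).chain       # the Flux model built by `builder`
```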
Augmentation can be done online or offline. Usually you would do offline augmentation either as a performance measure or as a way to increase the effective size of a dataset. The latter is possible to do in an online fashion as well, but it requires quite a bit more thought about how attributes like dataset length change as a result. In any case, MLJ(Flux) can probably just assume that the user will have done whatever offline augmentation is necessary already? WRT per-sample vs per-batch, I think the vast majority of augmentations are applied per-sample and prior to collation. That is, even if the samples in each batch are determined ahead of time, the augmentation function will only see one sample at a time. This is why I noted the distinction between (1) a collection of samples and (2) a collated set of tensors, both of which unfortunately have been given the moniker "batch". Augmentation usually happens before (1) or between (1) and (2).
Yes, to be clear, I didn't mean that augmentation happens on the full batch. It is still a per-sample operation. There are multiple ways to apply an augmentation. It can either be distinct from training, where an augmentation … Still, in both cases, batching is part of training (so I agree with @lorenzoh). You'll notice that if the augmentations are per-sample, then it doesn't really matter semantically whether this is done when you fetch sample …
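A tiny sketch of that equivalence: a per-sample augmentation gives the same result whether it runs when each sample is fetched or later, over the samples of an already-determined batch. The `augment` and `fetch_sample` functions here are stand-ins, not any package's API.

```julia
augment(sample) = sample          # stand-in for a random per-sample transformation
samples = [rand(Float32, 28, 28, 1) for _ in 1:96]
batch_indices = [1:32, 33:64, 65:96]

# Option A: augment at the moment each sample is fetched (e.g. inside `getobs`)
fetch_sample(i) = augment(samples[i])
batches_a = [[fetch_sample(i) for i in idxs] for idxs in batch_indices]

# Option B: fetch the raw samples for a batch, then augment each one before collation
batches_b = [[augment(samples[i]) for i in idxs] for idxs in batch_indices]

# either way the augmentation only ever sees one sample at a time, so the two
# are semantically the same
```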
There is a fundamental problem with the way we handle batches, at least for applications where the extra GPU speed gained with batching is important. The issue is the incompatibility of batching with observation resampling, as conventionally understood.
So, for example, if we are wrapping a model for cross-validation, then the observations get split up multiple times into different test/train sets. At present, a "batch" is understood to consist of multiple "observations", which means that resampling an MLJFlux model breaks the batches, an expensive operation for large data.
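To make the clash concrete, a rough sketch (assuming `X`, `y` are already-loaded image data; the measure is only illustrative): under MLJ's `evaluate!` with cross-validation, the observation-level partition changes every fold, so any batches built over those observations have to be rebuilt each time.

```julia
using MLJ, MLJFlux

clf  = MLJFlux.ImageClassifier(batch_size = 64)
mach = machine(clf, X, y)

evaluate!(mach,
          resampling = CV(nfolds = 6, shuffle = true),
          measure    = cross_entropy)
# each fold trains on a different subset of individual observations, so the
# batches assembled for one fold cannot be reused for the next
```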
I'm guessing this is a very familiar problem to people in deep learning, and so am copying some of them in for comment; I will also post a link on Slack. The solution I am considering for MLJ is to regard a "batch" of images as an unbreakable object that we consequently view as an observation, by definition. It would be natural to introduce a new parametric scientific type `Batch{SomeAtomicScitype}` to articulate a model's participation in this convention. Thoughts, anyone?
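Purely as a strawman, the convention might read something like the following; nothing here exists in ScientificTypes.jl or MLJFlux today, and the names are invented for illustration only:

```julia
# a parametric scitype whose observations are whole, unbreakable batches of
# some atomic scitype (hypothetical)
abstract type Batch{T} end            # e.g. Batch{GrayImage}, Batch{Continuous}

# under this convention an MLJFlux ImageClassifier might declare, say,
#
#     MLJModelInterface.input_scitype(::Type{<:ImageClassifier}) =
#         AbstractVector{<:Batch{<:Image}}
#
# so that a "row", for resampling purposes, is an entire batch of images
```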
Some consequences of this breaking change would be:

- `batch_size` disappears as a hyper-parameter of MLJFlux models, at least for `ImageClassifier`, but probably for all the models, for simplicity. So changing the batch size becomes the responsibility of a pre-processing transformer external to the model. I need to give some thought to transformers that reduce the number of observations, when inserted into MLJ pipelines (and learning networks, more generally). If that works, "smart" training of MLJ pipelines would mean no "re-batching" when retraining the composite model, unless the batch size changes, which is good.
- With this change one could implement the `reformat` and `selectrows` (now the same as "select batches") functions that constitute buy-in for MLJ's new data front-end; a rough sketch follows this list.
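Here is what that buy-in could look like when an "observation" is a whole, pre-formed batch. The method names follow MLJModelInterface's documented data front-end, but the bodies are illustrative only and assume batching has already happened upstream in a pre-processing transformer:

```julia
using MLJFlux: ImageClassifier
import MLJModelInterface as MMI

# `reformat` runs once per machine; with batching already done upstream, it
# might only convert each batch to the representation the Flux model wants
# (here, Float32 arrays)
MMI.reformat(::ImageClassifier, Xbatches, ybatches) =
    (map(b -> Float32.(b), Xbatches), ybatches)

# `selectrows` is what resampling strategies call; with batches as the
# observations, "selecting rows" is just indexing the vector of batches,
# so no re-batching ever happens
MMI.selectrows(::ImageClassifier, I, Xbatches, ybatches) =
    (view(Xbatches, I), view(ybatches, I))
```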