speechbrain.dataio.dataloader module
PyTorch-compatible DataLoaders
Essentially, we extend the PyTorch DataLoader with the ability to save the data loading state, so that a checkpoint can be saved in the middle of an epoch.
Example
>>> import torch
>>> from speechbrain.utils.checkpoints import Checkpointer
>>> # An example "dataset" and its loader
>>> dataset = torch.randn(10, 1)
>>> dataloader = SaveableDataLoader(dataset, num_workers=3)
>>> # Setup the checkpointer:
>>> tmpdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader})
>>> # Iterate:
>>> for i, data_point in enumerate(dataloader):
...     # Here you would process the data:
...     rainfall_amount_prediction = data_point * 4.
...     # Now, imagine the experiment gets killed on the fifth batch:
...     if i == 4:
...         break
...     # Luckily, you had just saved a checkpoint:
...     if i == 3:
...         _ = checkpointer.save_checkpoint(end_of_epoch=False)
>>> # So when you restart the experiment:
>>> new_dataloader = SaveableDataLoader(dataset, num_workers=3)
>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader})
>>> _ = new_checkpointer.recover_if_possible()
>>> # The dataloader fast-forwards to the position where we left off:
>>> assert next(iter(new_dataloader)) == dataset[4]
- Authors:
Aku Rouhe 2020
Summary
Classes:
- LoopedLoader: Loops an underlying iterable indefinitely, with nominal epoch lengths.
- SaveableDataLoader: A saveable version of the PyTorch DataLoader.
Functions:
- make_dataloader: Makes a basic DataLoader with SpeechBrain defaults.
Reference
- speechbrain.dataio.dataloader.make_dataloader(dataset, looped_nominal_epoch=None, **loader_kwargs)[source]
Makes a basic DataLoader with SpeechBrain defaults.
For DynamicItemDatasets (which return dicts), PaddedBatch is used as the default collate_fn.
Shuffling is implemented by ReproducibleRandomSampler.
If the Dataset is not an IterableDataset, the DataLoader is a SaveableDataLoader.
If the Dataset is a webdataset.dataset.Composable, the default batch_size is None.
The loader can also loop over the underlying DataLoader continuously, stopping iteration at nominal epoch lengths (see the sketch after the parameter list below).
- Parameters
dataset (Dataset) – The dataset to make a DataLoader for.
looped_nominal_epoch (None, int) – If an integer is given, loop the underlying DataLoader infinitely and set a nominal epoch length in batches (or whatever the DataLoader yields).
**loader_kwargs (dict) – Keyword args to DataLoader, see PyTorch DataLoader for options.
- Returns
DataLoader – If looped_nominal_epoch is None
LoopedLoader – If looped_nominal_epoch is not None
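A minimal sketch of how the return type depends on looped_nominal_epoch, using a toy tensor as the dataset (illustrative only; not part of the original docstring):
>>> import torch
>>> from speechbrain.dataio.dataloader import (
...     make_dataloader, SaveableDataLoader, LoopedLoader
... )
>>> dataset = torch.randn(10, 1)
>>> # A map-style dataset yields a SaveableDataLoader:
>>> loader = make_dataloader(dataset, batch_size=2)
>>> isinstance(loader, SaveableDataLoader)
True
>>> # With a nominal epoch length, a LoopedLoader is returned instead:
>>> looped = make_dataloader(dataset, looped_nominal_epoch=3, batch_size=2)
>>> isinstance(looped, LoopedLoader)
True
>>> len(looped)  # nominal epoch length, in batches
3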
- class speechbrain.dataio.dataloader.SaveableDataLoader(*args, **kwargs)[source]
Bases:
DataLoader
A saveable version of the PyTorch DataLoader.
See torch.utils.data.DataLoader for usage. This class works exactly like the basic PyTorch DataLoader, but it can additionally be checkpointed with SpeechBrain’s Checkpointer.
Note
1. The saveability is implemented via some, unfortunately, slightly magical means.
2. The data loader cannot recover after entering __iter__. Normally this is not a problem, as recovery should happen before training begins. However, just before evaluation, it is also typical to recover the checkpoint at which performance was the best. Thus, if a checkpoint is loaded after entering __iter__, we just assume it is for this reason. A warning is logged, but that is all.
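A minimal sketch of point 2 (hypothetical setup; the exact warning text is omitted): once iteration has begun, recovery only logs a warning and the saved position is not applied.
>>> import torch
>>> from speechbrain.utils.checkpoints import Checkpointer
>>> dataset = torch.randn(10, 1)
>>> loader = SaveableDataLoader(dataset)
>>> tmpdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(tmpdir, {"dataloader": loader})
>>> _ = checkpointer.save_checkpoint()
>>> _ = iter(loader)  # iteration has now begun
>>> # Recovering now only logs a warning; the data loading position is not restored:
>>> _ = checkpointer.recover_if_possible()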
- class speechbrain.dataio.dataloader.LoopedLoader(loader, epoch_length, batchsize_fn=None)[source]
Bases:
object
Loops an underlying iterable indefinitely, with nominal epoch lengths
This is useful for working with IterableDatasets, and particularly webdataset-style loading. We recommend calling .repeat() on the webdataset IterableDataset instance, so that the underlying dataloader naturally continues forever.
- Parameters
loader (iterable) – A DataLoader or other iterable that is looped repeatedly.
epoch_length (int) – The length of the nominal epoch. After this many steps, StopIteration is raised.
batchsize_fn (callable) – Function for determining the batch size of each yielded batch; if None, the batch size is guessed automatically.
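A minimal sketch of the looping behavior, using itertools.count as a stand-in for an infinite underlying loader (illustrative; plain integers stand in for batches):
>>> from itertools import count
>>> from speechbrain.dataio.dataloader import LoopedLoader
>>> infinite_stream = count()  # e.g. a .repeat()ed webdataset loader
>>> looped = LoopedLoader(infinite_stream, epoch_length=3)
>>> list(looped)  # one nominal epoch
[0, 1, 2]
>>> list(looped)  # the next epoch continues where the previous one stopped
[3, 4, 5]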