speechbrain.dataio.dataloader module
PyTorch-compatible DataLoaders
Essentially, we extend the PyTorch DataLoader by adding the ability to save the data loading state, so that a checkpoint can be saved in the middle of an epoch.
Example
>>> import torch
>>> from speechbrain.dataio.dataloader import SaveableDataLoader
>>> from speechbrain.utils.checkpoints import Checkpointer
>>> # An example "dataset" and its loader
>>> dataset = torch.randn(10, 1)
>>> dataloader = SaveableDataLoader(dataset, num_workers=3)
>>> # Set up the checkpointer (getfixture is available when the doctest runs under pytest):
>>> tmpdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader})
>>> # Iterate:
>>> for i, data_point in enumerate(dataloader):
...     # Here you would process the data:
...     rainfall_amount_prediction = data_point * 4.
...     # Now, imagine the experiment gets killed on the fifth batch:
...     if i == 4:
...         break
...     # Luckily, you had just saved a checkpoint:
...     if i == 3:
...         _ = checkpointer.save_checkpoint(end_of_epoch=False)
>>> # So when you restart the experiment:
>>> new_dataloader = SaveableDataLoader(dataset, num_workers=3)
>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader})
>>> _ = new_checkpointer.recover_if_possible()
>>> # The dataloader fast-forwards to the position where we left off:
>>> assert next(iter(new_dataloader)) == dataset[4]
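To make the fast-forward in the example concrete, here is a framework-free model of the idea. It is a hypothetical sketch, not SpeechBrain code: `ResumableLoader`, `save_state` and `load_state` are illustrative names, and the model ignores workers, batching and the Checkpointer. It assumes the only state worth saving is the number of items already yielded in the current epoch, and that a recovered position is applied once, at the start of the next iteration.

```python
from typing import Iterator, List, Optional


class ResumableLoader:
    """Hypothetical, simplified model of a loader that can resume mid-epoch."""

    def __init__(self, data: List[float]) -> None:
        self.data = data
        self._num_yielded: Optional[int] = None  # None until iteration starts
        self._skip_to: Optional[int] = None  # set by load_state()

    def __iter__(self) -> Iterator[float]:
        # Apply a recovered position once, then forget it.
        start = self._skip_to or 0
        self._skip_to = None
        self._num_yielded = start
        for item in self.data[start:]:
            self._num_yielded += 1  # counted before the item is handed out
            yield item

    def save_state(self) -> Optional[int]:
        # The only state needed is how far into the epoch we are.
        return self._num_yielded

    def load_state(self, state: Optional[int]) -> None:
        if state is not None:
            self._skip_to = state


# Same scenario as the doctest: checkpoint after four items, crash on the fifth.
data = [float(x) for x in range(10)]
loader = ResumableLoader(data)
saved = None
for i, item in enumerate(loader):
    if i == 4:
        break  # the simulated crash
    if i == 3:
        saved = loader.save_state()  # four items have been yielded

new_loader = ResumableLoader(data)
new_loader.load_state(saved)
assert next(iter(new_loader)) == data[4]  # resumes where we left off
```

The real SaveableDataLoader operates on PyTorch's DataLoader iterator rather than on a Python list, so treat this only as an illustration of why a single counter is enough when the item order is deterministic.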
Authors: Aku Rouhe 2020
Summary
Classes:
SaveableDataLoader | A saveable version of the PyTorch DataLoader.
Functions:
make_dataloader | Makes a basic DataLoader with SpeechBrain defaults.
Reference
speechbrain.dataio.dataloader.make_dataloader(dataset, **loader_kwargs)
Makes a basic DataLoader with SpeechBrain defaults.
For DynamicItemDatasets (which return dicts), PaddedBatch is used as the default collate_fn.
Shuffling is implemented by ReproducibleRandomSampler.
If the Dataset is not an IterableDataset, the DataLoader is a SaveableDataLoader.
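Reproducible shuffling is what makes mid-epoch recovery meaningful with shuffle=True: skipping the first k items of a restarted epoch only lands on the right example if the restarted epoch uses the same order as the interrupted one. A minimal sketch, assuming the order is a deterministic function of a fixed seed and the epoch number; `epoch_order` is a hypothetical helper, not the ReproducibleRandomSampler API, and it uses Python's `random` module in place of a torch generator.

```python
import random
from typing import List


def epoch_order(num_items: int, seed: int, epoch: int) -> List[int]:
    """Hypothetical helper: the shuffled index order for one epoch."""
    rng = random.Random(seed + epoch)  # same seed and epoch -> same order
    order = list(range(num_items))
    rng.shuffle(order)
    return order


# Before the crash: four indices of epoch 2 were consumed.
before = epoch_order(10, seed=1234, epoch=2)
consumed = before[:4]

# After restarting the process, the order is rebuilt and fast-forwarded.
after = epoch_order(10, seed=1234, epoch=2)
remaining = after[4:]

# Nothing is skipped or repeated across the interruption.
assert consumed + remaining == before
```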
Parameters
dataset (Dataset) – The dataset to make a DataLoader for.
**loader_kwargs (dict) – Keyword arguments to DataLoader; see the PyTorch DataLoader for options.
Returns
DataLoader
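The three defaults described above can be summarised as a small decision function. This is a hypothetical, framework-free model of the rules, not the actual implementation: `resolve_loader` returns the name of the loader class and the keyword arguments it would receive, with class names written as plain strings. Raising a `ValueError` when both `shuffle=True` and a `sampler` are passed is an assumption here; PyTorch's own DataLoader also rejects that combination, which is why `shuffle` is swapped for a sampler rather than passed through.

```python
from typing import Any, Dict, Tuple


def resolve_loader(
    returns_dicts: bool, is_iterable: bool, **loader_kwargs: Any
) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical model of the make_dataloader defaults."""
    kwargs = dict(loader_kwargs)  # never mutate the caller's dict
    # 1. Dict-returning datasets get PaddedBatch unless a collate_fn is given.
    if returns_dicts and "collate_fn" not in kwargs:
        kwargs["collate_fn"] = "PaddedBatch"
    # 2. shuffle=True is replaced by a reproducible sampler.
    if kwargs.pop("shuffle", False):
        if kwargs.get("sampler") is not None:
            raise ValueError("Cannot specify both shuffle=True and a sampler")
        kwargs["sampler"] = "ReproducibleRandomSampler"
    # 3. Only map-style (non-iterable) datasets get the saveable loader.
    loader_class = "DataLoader" if is_iterable else "SaveableDataLoader"
    return loader_class, kwargs


# A dict-returning, map-style dataset with shuffling enabled:
cls, kw = resolve_loader(True, False, batch_size=8, shuffle=True)
```

For that call, the model yields "SaveableDataLoader" with `collate_fn` and `sampler` filled in and `shuffle` removed.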
class speechbrain.dataio.dataloader.SaveableDataLoader(*args, **kwargs)
Bases: Generic[torch.utils.data.dataloader.T_co]
A saveable version of the PyTorch DataLoader.
See torch.utils.data.DataLoader for usage. This class should work exactly like the basic PyTorch DataLoader, except that it can also be checkpointed with SpeechBrain’s Checkpointer.
Note
1. Saveability is implemented by somewhat magical means: the module patches PyTorch's internal DataLoader iterator so that, after a checkpoint is recovered, the sampler is fast-forwarded past the indices that had already been yielded.
2. The data loader cannot recover once __iter__ has been entered. Normally this is not a problem, as recovery should happen before training begins. However, just before evaluation it is also typical to recover the checkpoint with the best performance. Thus, if a checkpoint is loaded after __iter__ has been entered, we assume it is for this reason: a warning is logged and the saved loader state is otherwise ignored.
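Point 2 can be sketched as a guard in the load hook. This is a hypothetical, simplified model (the class `GuardedLoader` and its `load_position` method are illustrative names, not SpeechBrain API): once iteration has started, a loaded position is logged and discarded, so reloading the best checkpoint before evaluation does not disturb the loader. Ignoring the position when the checkpoint was taken at the end of an epoch is an additional assumption, motivated by the `end_of_epoch` flag in the module example.

```python
import logging
from typing import Iterator, List, Optional

logger = logging.getLogger(__name__)


class GuardedLoader:
    """Hypothetical model of the recover-only-before-__iter__ rule."""

    def __init__(self, data: List[int]) -> None:
        self.data = data
        self.iteration_started = False
        self.skip_to: Optional[int] = None

    def __iter__(self) -> Iterator[int]:
        # Marked as soon as iter() is called, like entering __iter__.
        self.iteration_started = True
        start = self.skip_to or 0
        self.skip_to = None
        return iter(self.data[start:])

    def load_position(self, position: int, end_of_epoch: bool) -> None:
        if self.iteration_started:
            # Typical case: the best checkpoint is reloaded for evaluation.
            logger.warning("Loader already iterated; saved position ignored.")
            return
        if end_of_epoch:
            # Assumption: a finished epoch restarts from the beginning.
            return
        self.skip_to = position


fresh = GuardedLoader(list(range(10)))
fresh.load_position(4, end_of_epoch=False)
assert next(iter(fresh)) == 4  # recovery applied before iteration

in_use = GuardedLoader(list(range(10)))
_ = iter(in_use)  # iteration has already begun
in_use.load_position(4, end_of_epoch=False)  # logs a warning, no effect
assert next(iter(in_use)) == 0
```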
dataset: torch.utils.data.dataset.Dataset[T_co]
sampler: torch.utils.data.sampler.Sampler