speechbrain.dataio.dataloader module

PyTorch compatible DataLoaders

This module extends the PyTorch DataLoader with the ability to save its data loading state, so that a checkpoint can be saved in the middle of an epoch.


Example:

>>> import torch
>>> from speechbrain.utils.checkpoints import Checkpointer
>>> # An example "dataset" and its loader
>>> dataset = torch.randn(10, 1)
>>> dataloader = SaveableDataLoader(dataset, num_workers=3)
>>> # Set up the checkpointer:
>>> tmpdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader})
>>> # Iterate:
>>> for i, data_point in enumerate(dataloader):
...     # Here you would process the data:
...     rainfall_amount_prediction = data_point * 4.
...     # Now, imagine the experiment gets killed on the fifth batch:
...     if i == 4:
...         break
...     # Luckily, you had just saved a checkpoint:
...     if i == 3:
...         _ = checkpointer.save_checkpoint(end_of_epoch=False)
>>> # So when you restart the experiment:
>>> new_dataloader = SaveableDataLoader(dataset, num_workers=3)
>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader})
>>> _ = new_checkpointer.recover_if_possible()
>>> # The dataloader fast-forwards to the position where we left off:
>>> assert next(iter(new_dataloader)) == dataset[4]

Authors:
  • Aku Rouhe 2020




Classes:

  • LoopedLoader – Loops an underlying iterable indefinitely, with nominal epoch lengths.

  • SaveableDataLoader – A saveable version of the PyTorch DataLoader.

Functions:

  • make_dataloader – Makes a basic DataLoader with SpeechBrain defaults.


speechbrain.dataio.dataloader.make_dataloader(dataset, looped_nominal_epoch=None, **loader_kwargs)

Makes a basic DataLoader with SpeechBrain defaults.

For DynamicItemDatasets (which return dicts), use PaddedBatch as the default collate_fn.
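As a rough illustration of what padding-based collation does for dict examples, here is a minimal pure-Python sketch. The helper `pad_collate` and its `_lengths` keys are hypothetical and are not SpeechBrain's PaddedBatch, which operates on tensors and differs in detail; the sketch only shows the idea of right-padding variable-length fields in a batch of dicts while recording their true lengths.

```python
from typing import Any, Dict, List


def pad_collate(examples: List[Dict[str, Any]], pad_value: int = 0) -> Dict[str, Any]:
    """Hypothetical collate_fn: batch a list of dicts, padding list fields."""
    batch: Dict[str, Any] = {}
    for key in examples[0]:
        values = [ex[key] for ex in examples]
        if isinstance(values[0], list):
            max_len = max(len(v) for v in values)
            # Right-pad every sequence to the longest one in this batch.
            batch[key] = [v + [pad_value] * (max_len - len(v)) for v in values]
            # Keep the true lengths so padded positions can be masked later.
            batch[key + "_lengths"] = [len(v) for v in values]
        else:
            # Non-sequence fields (e.g. utterance ids) are simply gathered.
            batch[key] = values
    return batch


examples = [
    {"id": "utt1", "tokens": [5, 2, 9]},
    {"id": "utt2", "tokens": [7]},
]
print(pad_collate(examples))
# {'id': ['utt1', 'utt2'], 'tokens': [[5, 2, 9], [7, 0, 0]], 'tokens_lengths': [3, 1]}
```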

Shuffling is implemented with ReproducibleRandomSampler.
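The point of a reproducible sampler is that the shuffle order depends only on a seed and the epoch number. A minimal stdlib sketch of that idea follows; the class `SeededEpochShuffler` is a hypothetical stand-in written for illustration, not the actual ReproducibleRandomSampler, and its `set_epoch` method simply follows the pattern familiar from PyTorch's DistributedSampler.

```python
import random
from typing import Iterator, List


class SeededEpochShuffler:
    """Hypothetical sampler yielding a reproducible permutation per epoch.

    The order depends only on (seed, epoch), so a restarted run that uses
    the same seed and epoch number sees exactly the same order.
    """

    def __init__(self, num_items: int, seed: int = 1234) -> None:
        self.num_items = num_items
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Call before each epoch; otherwise every epoch repeats one order.
        self.epoch = epoch

    def __iter__(self) -> Iterator[int]:
        # A private RNG keeps the global `random` state untouched.
        rng = random.Random(self.seed + self.epoch)
        order: List[int] = list(range(self.num_items))
        rng.shuffle(order)
        return iter(order)

    def __len__(self) -> int:
        return self.num_items


first_run = SeededEpochShuffler(8)
restarted_run = SeededEpochShuffler(8)
first_run.set_epoch(3)
restarted_run.set_epoch(3)
print(list(first_run) == list(restarted_run))  # True
```

Deriving the RNG from the seed and epoch, rather than drawing from shared global state, is what makes a resumed run reproduce the order of the interrupted one.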

If the Dataset is not an IterableDataset, the DataLoader is a SaveableDataLoader.

If the Dataset is a webdataset.dataset.Composable, the default batch_size is set to None.

It can also loop over the underlying DataLoader continuously, stopping iterations at nominal epoch lengths.

Parameters:

  • dataset (Dataset) – The dataset to make a DataLoader for.

  • looped_nominal_epoch (None, int) – If an integer is given, loop the underlying DataLoader infinitely and set a nominal epoch length in batches (or whatever the DataLoader yields).

  • **loader_kwargs (dict) – Keyword arguments passed to the DataLoader; see the PyTorch DataLoader for options.

Returns:

  • DataLoader – If looped_nominal_epoch is None.

  • LoopedLoader – If looped_nominal_epoch is not None.

class speechbrain.dataio.dataloader.SaveableDataLoader(*args, **kwargs)

Bases: DataLoader

A saveable version of the PyTorch DataLoader.

See torch.utils.data.DataLoader for usage. This class should work exactly like the basic PyTorch DataLoader, but it can also be checkpointed with SpeechBrain’s Checkpointer.


Note:

  1. The saveability is implemented via some, unfortunately, slightly magical means.

  2. The data loader cannot recover after entering __iter__. Normally this is not a problem, as recovery should happen before training begins. However, just before evaluation, it is also typical to recover the checkpoint at which performance was the best. Thus, if a checkpoint is loaded after entering __iter__, we assume it is for this reason. A warning is logged, but that is all.
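The general idea behind mid-epoch recovery can be sketched without PyTorch: count how many items have been yielded, and on recovery fast-forward past that many before yielding again. The class `ResumableIterable` below is hypothetical and deliberately simplified; it is not SpeechBrain's implementation, and its plain `save()`/`load()` calls stand in for the Checkpointer. It also imitates the documented limitation that loading after iteration has started only logs a warning.

```python
import logging
from typing import Iterable, Iterator, Optional

logger = logging.getLogger(__name__)


class ResumableIterable:
    """Hypothetical wrapper that resumes mid-epoch by fast-forwarding."""

    def __init__(self, data: Iterable) -> None:
        self.data = data
        self._num_yielded = 0                 # items yielded this epoch
        self._skip_to: Optional[int] = None   # set by load(), used by __iter__
        self._iterating = False               # cleared only when an epoch completes

    def __iter__(self) -> Iterator:
        self._iterating = True
        self._num_yielded = 0
        it = iter(self.data)
        if self._skip_to is not None:
            # Fast-forward past the items consumed before the checkpoint.
            for _ in range(self._skip_to):
                next(it, None)
            self._num_yielded = self._skip_to
            self._skip_to = None
        for item in it:
            self._num_yielded += 1
            yield item
        self._iterating = False

    def save(self) -> int:
        # The only state needed is how far into the epoch we are.
        return self._num_yielded

    def load(self, num_yielded: int) -> None:
        if self._iterating:
            # Too late to fast-forward: warn and ignore, as documented above.
            logger.warning("State loaded after iteration started; ignoring it.")
            return
        self._skip_to = num_yielded


data = list(range(10))
loader = ResumableIterable(data)
for i, x in enumerate(loader):
    if i == 3:
        state = loader.save()  # four items consumed so far
    if i == 4:
        break  # the "crash"

resumed = ResumableIterable(data)
resumed.load(state)
print(next(iter(resumed)))  # 4
```

This mirrors the module-level doctest: after a checkpoint taken at the fourth item, the next item served on recovery is the fifth one.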

dataset: Dataset[T_co]
batch_size: Optional[int]
num_workers: int
pin_memory: bool
drop_last: bool
timeout: float
sampler: Union[Sampler, Iterable]
pin_memory_device: str
prefetch_factor: int
class speechbrain.dataio.dataloader.LoopedLoader(loader, epoch_length, batchsize_fn=None)

Bases: object

Loops an underlying iterable indefinitely, with nominal epoch lengths

This is useful for working with IterableDatasets, and particularly webdataset-style loading. We recommend using .repeat() on the webdataset IterableDataset instance, so that the underlying dataloader naturally continues forever.

Parameters:

  • loader (iterable) – A DataLoader or other iterable that is looped repeatedly.

  • epoch_length (int) – The length of the nominal epoch. After this many steps, raises StopIteration.
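A minimal sketch of the looping behaviour, assuming the wrapper restarts the underlying iterable whenever it is exhausted and keeps its position across nominal epochs. `NominalEpochLoop` is a hypothetical name; the real LoopedLoader also takes batchsize_fn and supports checkpointing, both of which this sketch omits.

```python
from typing import Iterable, Iterator


class NominalEpochLoop:
    """Hypothetical sketch: loop an iterable forever, cut into nominal epochs."""

    def __init__(self, loader: Iterable, epoch_length: int) -> None:
        self.loader = loader
        self.epoch_length = epoch_length
        self.iterator: Iterator = iter(loader)
        self.step = 0  # steps taken in the current nominal epoch

    def __iter__(self) -> "NominalEpochLoop":
        return self

    def __next__(self):
        if self.step >= self.epoch_length:
            # End of the nominal epoch: reset the counter but keep the
            # underlying iterator, so the next epoch continues where we stopped.
            self.step = 0
            raise StopIteration
        try:
            batch = next(self.iterator)
        except StopIteration:
            # The underlying data ran out: start it over.
            self.iterator = iter(self.loader)
            batch = next(self.iterator)
        self.step += 1
        return batch


looped = NominalEpochLoop([1, 2, 3], epoch_length=2)
print([list(looped) for _ in range(3)])
# [[1, 2], [3, 1], [2, 3]]
```

Because the nominal epoch boundary is decoupled from the end of the data, each "epoch" here is exactly epoch_length steps long even though the underlying list has three items.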


Saves the needed information.

load(path, end_of_epoch=True, device=None)

Loads the needed information.
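For a looped loader, the state worth saving is plausibly just a few progress counters. The sketch below is a hypothetical stand-in, assuming that state is a step counter within the nominal epoch plus a running total, written to a small text file; `StepCounter` is an illustrative name, and its `load` mirrors the documented signature load(path, end_of_epoch=True, device=None) while ignoring the last two arguments.

```python
import os
import tempfile


class StepCounter:
    """Hypothetical stand-in for a looped loader's progress counters."""

    def __init__(self) -> None:
        self.step = 0         # position inside the current nominal epoch
        self.total_steps = 0  # steps taken across all nominal epochs

    def save(self, path: str) -> None:
        # One integer per line keeps the checkpoint file human-readable.
        with open(path, "w") as fo:
            fo.write(f"{self.step}\n{self.total_steps}\n")

    def load(self, path: str, end_of_epoch: bool = True, device=None) -> None:
        # end_of_epoch and device only mirror the documented signature;
        # this sketch does not use them.
        with open(path) as fi:
            self.step = int(fi.readline())
            self.total_steps = int(fi.readline())


counter = StepCounter()
counter.step, counter.total_steps = 3, 103
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "looped_loader.ckpt")
    counter.save(path)
    restored = StepCounter()
    restored.load(path)
print(restored.step, restored.total_steps)  # 3 103
```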