speechbrain.dataio.dataloader module

PyTorch-compatible DataLoaders

Essentially, we extend the PyTorch DataLoader with the ability to save the data loading state, so that a checkpoint can be saved in the middle of an epoch.

Example

>>> import torch
>>> from speechbrain.dataio.dataloader import SaveableDataLoader
>>> from speechbrain.utils.checkpoints import Checkpointer
>>> # An example "dataset" and its loader
>>> dataset = torch.randn(10, 1)
>>> dataloader = SaveableDataLoader(dataset, num_workers=3)
>>> # Setup the checkpointer:
>>> tmpdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader})
>>> # Iterate:
>>> for i, data_point in enumerate(dataloader):
...     # Here you would process the data:
...     rainfall_amount_prediction = data_point * 4.
...     # Now, imagine the experiment gets killed on the fifth batch:
...     if i == 4:
...         break
...     # Luckily, you had just saved a checkpoint:
...     if i == 3:
...         _ = checkpointer.save_checkpoint(end_of_epoch=False)
>>> # So when you restart the experiment:
>>> new_dataloader = SaveableDataLoader(dataset, num_workers=3)
>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader})
>>> _ = new_checkpointer.recover_if_possible()
>>> # The dataloader fast-forwards to the position where we left off:
>>> assert next(iter(new_dataloader)) == dataset[4]

Authors:
  • Aku Rouhe 2020

Summary

Classes:

SaveableDataLoader

A saveable version of the PyTorch DataLoader.

Functions:

make_dataloader

Makes a basic DataLoader with SpeechBrain defaults.

Reference

speechbrain.dataio.dataloader.make_dataloader(dataset, **loader_kwargs)[source]

Makes a basic DataLoader with SpeechBrain defaults.

For DynamicItemDatasets (which return dicts), PaddedBatch is used as the default collate_fn.

Shuffling is implemented by ReproducibleRandomSampler.

If the Dataset is not an IterableDataset, the DataLoader is a SaveableDataLoader.

Parameters
  • dataset (Dataset) – The dataset to make a DataLoader for.

  • **loader_kwargs (dict) – Keyword args to DataLoader, see PyTorch DataLoader for options.

Returns

The created DataLoader.

Return type

DataLoader
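
As a quick illustration of these defaults, a hypothetical snippet (not part of the module's own doctests) might look as follows. It assumes only what is stated above: shuffle=True is routed through ReproducibleRandomSampler, and a map-style (non-iterable) dataset yields a SaveableDataLoader.

>>> import torch
>>> from speechbrain.dataio.dataloader import make_dataloader, SaveableDataLoader
>>> dataset = torch.randn(10, 1)  # a plain map-style "dataset"
>>> loader = make_dataloader(dataset, shuffle=True, batch_size=2)
>>> # Map-style (non-iterable) datasets should get the saveable loader:
>>> assert isinstance(loader, SaveableDataLoader)
>>> # 10 items in batches of 2 -> 5 batches per epoch:
>>> assert len(loader) == 5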

class speechbrain.dataio.dataloader.SaveableDataLoader(*args, **kwargs)[source]

Bases: torch.utils.data.dataloader.DataLoader[torch.utils.data.dataloader.T_co]

A saveable version of the PyTorch DataLoader.

See torch.utils.data.DataLoader for usage. This class should work exactly like the PyTorch basic DataLoader, but this can be checkpointed with SpeechBrain’s Checkpointer.

Note

1. The saveability is implemented via some unfortunately slightly magical means.

2. The data loader cannot recover after entering __iter__. Normally this is not a problem, as recovery should happen before training begins. However, just before evaluation, it is also typical to recover the checkpoint at which performance was best. Thus, if a checkpoint is loaded after entering __iter__, we just assume it is for this reason. A warning is logged, but that is all.

dataset: torch.utils.data.dataset.Dataset[torch.utils.data.dataloader.T_co]
batch_size: Optional[int]
num_workers: int
pin_memory: bool
drop_last: bool
timeout: float
sampler: torch.utils.data.sampler.Sampler
prefetch_factor: int
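
The "slightly magical means" mentioned in the note boil down to bookkeeping around the underlying iterator: count how many batches have been yielded, save that count with the checkpoint, and on recovery fast-forward past that many batches before resuming. The sketch below illustrates the idea against a plain torch.utils.data.DataLoader; the class name SkippingDataLoader and its save_state/load_state hooks are hypothetical and do not mirror SpeechBrain's actual internals.

from torch.utils.data import DataLoader

class SkippingDataLoader(DataLoader):
    """Hypothetical sketch: a DataLoader that can resume mid-epoch."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._num_yielded = 0   # batches yielded in the current epoch
        self._skip_to = None    # set when a mid-epoch state is loaded

    def __iter__(self):
        iterator = super().__iter__()
        if self._skip_to is not None:
            # Resuming mid-epoch: discard the already-seen batches.
            for _ in range(self._skip_to):
                next(iterator)
            self._num_yielded = self._skip_to
            self._skip_to = None
        else:
            self._num_yielded = 0
        for batch in iterator:
            self._num_yielded += 1
            yield batch

    # Hooks a checkpointer could call (names are illustrative only):
    def save_state(self, path):
        with open(path, "w") as fo:
            fo.write(str(self._num_yielded))

    def load_state(self, path):
        with open(path) as fi:
            self._skip_to = int(fi.read())

Fast-forwarding only lands on the same data if the sampler reproduces the same ordering across runs, which is one reason make_dataloader implements shuffling with ReproducibleRandomSampler rather than the stock random sampler.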