speechbrain.dataio.sampler module

PyTorch compatible samplers.

These determine the order of iteration through a dataset.

Authors:
  • Aku Rouhe 2020

  • Samuele Cornell 2020

  • Ralf Leibold 2020

  • Artem Ploujnikov 2021

  • Andreas Nautsch 2021

Summary

Classes:

BalancingDataSampler

A data sampler that takes a single key from the dataset and ensures an approximately equal distribution of samples across the values of that key.

ConcatDatasetBatchSampler

This sampler is built to work with a standard PyTorch ConcatDataset.

DistributedSamplerWrapper

This wrapper allows any sampler (for example, a batch sampler) to be used correctly with Distributed Data Parallel (DDP).

DynamicBatchSampler

This BatchSampler batches examples together by grouping them by their length.

ReproducibleRandomSampler

A modification of RandomSampler which always returns the same values.

ReproducibleWeightedRandomSampler

A reproducible modification of WeightedRandomSampler.

Reference

class speechbrain.dataio.sampler.ReproducibleRandomSampler(data_source, seed=563375142, epoch=0, **kwargs)

Bases: RandomSampler

A modification of RandomSampler which always returns the same values.

See also torch.utils.data.RandomSampler. This class has mostly the same behaviour and arguments, except that it adds ‘seed’ and ‘epoch’ and does not support ‘generator’.

Note

Call set_epoch before every epoch. Otherwise, the sampler will produce the same sequence of indices every epoch.

Parameters:
  • data_source (Dataset) – The data source to sample indices for.

  • seed (int) – The base seed to use for the random number generator. It is recommended to use a value which has a good mix of 0 and 1 bits.

  • epoch (int) – The epoch to start at.
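The reproducibility described above can be illustrated with a minimal sketch. It assumes that the sampler re-seeds a private random number generator with seed + epoch before each pass; that combination rule is an assumption made here for illustration, and the sketch uses Python's standard random module rather than torch. The helper name reproducible_order is hypothetical.

```python
import random
from typing import List


def reproducible_order(num_items: int, seed: int = 563375142, epoch: int = 0) -> List[int]:
    """Return a shuffled list of indices that depends only on (seed, epoch).

    Hypothetical helper: assumes the generator is re-seeded with seed + epoch.
    """
    rng = random.Random(seed + epoch)  # fresh generator, so no hidden state carries over
    order = list(range(num_items))
    rng.shuffle(order)
    return order


# Same seed and epoch give the identical order, even across separate "runs".
first_run = reproducible_order(10, epoch=0)
second_run = reproducible_order(10, epoch=0)
print(first_run == second_run)
# Moving to a new epoch (the analogue of calling set_epoch) gives a new permutation.
print(reproducible_order(10, epoch=1))
```

If the epoch is never changed, every pass returns the same order, which is exactly the pitfall the Note warns about.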

Example

>>> import tempfile
>>> import torch
>>> from speechbrain.utils.checkpoints import Checkpointer
>>> from speechbrain.dataio.dataloader import SaveableDataLoader
>>> from speechbrain.dataio.sampler import ReproducibleRandomSampler
>>> # An example "dataset"
>>> dataset = torch.arange(10).unsqueeze(1)
>>> # Create the random sampler:
>>> sampler = ReproducibleRandomSampler(dataset)
>>> dataloader = SaveableDataLoader(dataset, sampler=sampler,
...     num_workers=3)
>>> # Set up the checkpointer.
>>> # Note that the sampler itself does not need to be saved.
>>> tmpdir = tempfile.mkdtemp()
>>> checkpointer = Checkpointer(tmpdir, {"dataloader": dataloader})
>>> # Iterate:
>>> subset = []
>>> for i, data_point in enumerate(dataloader):
...     # Say you save a checkpoint on the fourth batch:
...     if i == 3:
...         _ = checkpointer.save_checkpoint(end_of_epoch=False)
...     # Record the items you would get if you continued from here
...     if i >= 4:
...         subset.append(data_point.item())
>>> # What if instead you had to restart the experiment?
>>> new_sampler = ReproducibleRandomSampler(dataset)
>>> new_dataloader = SaveableDataLoader(dataset, sampler=new_sampler,
...     num_workers=3)
>>> new_checkpointer = Checkpointer(tmpdir, {"dataloader": new_dataloader})
>>> _ = new_checkpointer.recover_if_possible()
>>> # You'll get the same random order again:
>>> new_subset = [data_point.item() for data_point in new_dataloader]
>>> assert subset == new_subset

set_epoch(epoch)

Sets the current epoch. You can also access self.epoch directly, but this interface is kept to mirror torch.utils.data.distributed.DistributedSampler.

data_source: Sized
replacement: bool
class speechbrain.dataio.sampler.ReproducibleWeightedRandomSampler(weights, num_samples, replacement, seed=129491412, epoch=0, **kwargs)

Bases: WeightedRandomSampler

A reproducible modification of WeightedRandomSampler.

See also torch.utils.data.WeightedRandomSampler. This class has the same behaviour and arguments, except that it adds ‘seed’ and ‘epoch’ and does not support ‘generator’.

Note

Call set_epoch before every epoch. Otherwise, the sampler will produce the same sequence of indices every epoch.

Parameters:
  • weights (sequence of float) – Weights for each index. Doesn’t need to sum to one.

  • num_samples (int) – Number of samples to draw.

  • replacement (bool) – To draw with replacement or not (within an epoch of num_samples).

  • seed (int) – The base seed to use for the random number generator. It is recommended to use a value which has a good mix of 0 and 1 bits.

  • epoch (int) – The epoch to start at.
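As a rough illustration of weighted, reproducible sampling, here is a standard-library sketch. It assumes the same seed + epoch seeding rule as above and uses random.choices in place of torch.multinomial, so it will not reproduce the exact indices printed in the example below. The helper reproducible_weighted_indices is hypothetical.

```python
import random
from typing import List, Sequence


def reproducible_weighted_indices(
    weights: Sequence[float],
    num_samples: int,
    replacement: bool,
    seed: int = 129491412,
    epoch: int = 0,
) -> List[int]:
    """Draw indices with probability proportional to weights (hypothetical helper)."""
    rng = random.Random(seed + epoch)  # assumed seeding rule: seed + epoch
    population = range(len(weights))
    if replacement:
        # Weights need not sum to one: random.choices only uses them relative to their total.
        return rng.choices(population, weights=weights, k=num_samples)
    remaining = [float(w) for w in weights]
    if num_samples > sum(1 for w in remaining if w > 0):
        raise ValueError("num_samples exceeds the number of indices with non-zero weight")
    drawn = []
    for _ in range(num_samples):
        idx = rng.choices(population, weights=remaining, k=1)[0]
        drawn.append(idx)
        remaining[idx] = 0.0  # without replacement, an index can be drawn only once
    return drawn


w = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]
print(reproducible_weighted_indices(w, 5, replacement=True))
print(reproducible_weighted_indices(w, 5, replacement=True, epoch=1))
```

Two calls with the same arguments always return the same list, mirroring how the two samplers a and b in the example agree with each other.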

Example

>>> from speechbrain.dataio.sampler import ReproducibleWeightedRandomSampler
>>> a = ReproducibleWeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)
>>> b = ReproducibleWeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)
>>> list(a)
[3, 1, 4, 4, 4]
>>> list(b)
[3, 1, 4, 4, 4]
>>> a.set_epoch(1)
>>> list(a)
[4, 5, 4, 4, 3]
>>> b.set_epoch(1)
>>> list(b)
[4, 5, 4, 4, 3]

set_epoch(epoch)

Sets the current epoch. You can also access self.epoch directly, but this interface is kept to mirror torch.utils.data.distributed.DistributedSampler.

weights: Tensor
num_samples: int
replacement: bool
class speechbrain.dataio.sampler.ConcatDatasetBatchSampler(samplers, batch_sizes: (tuple, list), epoch=0)

Bases: Sampler

This sampler is built to work with a standard PyTorch ConcatDataset.

It retrieves elements from the concatenated datasets and places them in the same batch in the proportions given by batch_sizes. For example, batch sizes 8, 16 mean that each batch has 24 elements: the first 8 come from the first dataset in the ConcatDataset object and the last 16 from the second. More than two datasets are supported; in that case, provide one batch size per dataset (e.g., three batch sizes for three datasets).

Note

Batches are drawn from the datasets until one of them is exhausted. Thus, the number of examples in a training epoch is dictated by the dataset that runs out first, i.e., the one that can supply the fewest full batches at its batch size.
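The batch composition described above can be sketched in plain Python. The sketch assumes that each sampler is represented by the list of local indices it yields in one epoch, that global indices are formed by adding the total length of the preceding datasets (which is how torch.utils.data.ConcatDataset maps indices), and that an epoch stops as soon as any sampler cannot fill its share of a batch. The helper concat_batches is hypothetical.

```python
from itertools import accumulate
from typing import Iterator, List, Sequence


def concat_batches(
    sampler_orders: Sequence[Sequence[int]],
    dataset_lengths: Sequence[int],
    batch_sizes: Sequence[int],
) -> Iterator[List[int]]:
    """Yield batches of global ConcatDataset indices (hypothetical helper)."""
    # Offset of each dataset inside the concatenation: 0, len(d0), len(d0) + len(d1), ...
    offsets = [0] + list(accumulate(dataset_lengths))[:-1]
    # The epoch ends when the first sampler runs out of full shares.
    num_batches = min(len(order) // bs for order, bs in zip(sampler_orders, batch_sizes))
    iters = [iter(order) for order in sampler_orders]
    for _ in range(num_batches):
        batch: List[int] = []
        for it, offset, bs in zip(iters, offsets, batch_sizes):
            # Take bs local indices from this sampler and shift them to global indices.
            batch.extend(offset + next(it) for _ in range(bs))
        yield batch


batches = list(concat_batches([range(10), range(20)], [10, 20], [2, 4]))
print(batches[0])  # [0, 1, 10, 11, 12, 13]
```

With lengths 10 and 20 and batch sizes 2 and 4, both datasets supply exactly 5 full shares, so the sketch yields 5 batches of 6 indices.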

Parameters:
  • samplers (list or tuple) – The samplers, one for each dataset in the ConcatDataset, in the same order as the datasets.

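  • batch_sizes (list or tuple) – Number of examples drawn from each dataset for every batch, one entry per sampler, in the same order as samplers.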

  • epoch (int) – The epoch to start at.

Example

>>> import torch
>>> from speechbrain.dataio.sampler import ConcatDatasetBatchSampler, ReproducibleRandomSampler
>>> from speechbrain.dataio.dataloader import SaveableDataLoader
>>> # Example "datasets"
>>> dataset1 = torch.arange(0, 10).unsqueeze(1)
>>> dataset2 = torch.arange(20, 40).unsqueeze(1)
>>> tot_dataset = torch.utils.data.ConcatDataset([dataset1, dataset2])
>>> sampler1 = ReproducibleRandomSampler(dataset1)
>>> sampler2 = ReproducibleRandomSampler(dataset2)
>>> # Each batch holds 2 items from dataset1 followed by 4 items from dataset2
>>> tot_sampler = ConcatDatasetBatchSampler([sampler1, sampler2], [2, 4])
>>> dataloader = SaveableDataLoader(tot_dataset, batch_sampler=tot_sampler,
...     num_workers=3)
>>> for data_point in dataloader:
...     assert len(data_point) == 6
...     for i in range(2):
...         assert data_point[i].item() in range(0, 10)
...     for i in range(2, 6):
...         assert data_point[i].item() in range(20, 40)

set_epoch(epoch)

Sets the current epoch. You can also access self.epoch directly, but this interface is kept to mirror torch.utils.data.distributed.DistributedSampler.

class speechbrain.dataio.sampler.DynamicBatchSampler(dataset, max_batch_length: int, num_buckets: int | None = None, length_func=lambda x: x["duration"], shuffle: bool = True, batch_ordering: str = 'random', max_batch_ex: int | None = None, bucket_boundaries: List[int] = [], lengths_list: List[int] | None = None, seed: int = 42, epoch: int = 0, drop_last: bool = False, verbose: bool = False)

Bases: Sampler

This BatchSampler batches examples together by grouping them by their length.

All examples in a batch have approximately the same length, so padding is minimized. This enables faster training on datasets where example lengths vary significantly (e.g., LibriSpeech). Inspired by: https://www.tensorflow.org/api_docs/python/tf/data/experimental/bucket_by_sequence_length

Dynamic batching is controlled by max_batch_length, the upper limit on the sum of the lengths of the examples in a batch. For example, if ex1 has length 4, ex2 has length 5, and max_batch_length is 6, then ex1 and ex2 are placed alone in two separate batches, because 4 + 5 = 9 exceeds 6.
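The max_batch_length rule from the example above can be checked with a simplified greedy sketch. It packs examples in the given order and starts a new batch whenever adding the next example would push the summed length over the limit. It deliberately ignores bucketing, shuffling, and max_batch_ex, so it only illustrates the constraint and is not the sampler's actual algorithm; pack_by_total_length is a hypothetical name.

```python
from typing import List, Sequence


def pack_by_total_length(lengths: Sequence[int], max_batch_length: int) -> List[List[int]]:
    """Greedily group example indices so each batch's summed length stays within the limit."""
    batches: List[List[int]] = []
    current: List[int] = []
    total = 0
    for idx, length in enumerate(lengths):
        # Close the current batch if this example would overflow it.
        if current and total + length > max_batch_length:
            batches.append(current)
            current, total = [], 0
        # An example longer than the limit still gets a batch of its own.
        current.append(idx)
        total += length
    if current:
        batches.append(current)
    return batches


print(pack_by_total_length([4, 5], 6))  # [[0], [1]]
```

For the lengths 4 and 5 with a limit of 6, the sketch reproduces the outcome stated above: each example ends up alone in its own batch.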

The length of each example can be obtained in two ways. If the input dataset is a DynamicItemDataset, it can be obtained by specifying a length_func; the default assumes that a “duration” entry is present in the annotation. Alternatively, the length of each example can be passed to this class upon instantiation as a list, via lengths_list.

Examples are grouped together by defining a set of discrete length intervals (buckets). Examples whose lengths fall into the same interval can be batched together.

The number of buckets can be specified by using the arg num_buckets. There is usually an optimal range for the value of this argument.

If num_buckets == 1, all examples can be batched together. This gives maximum randomization, but training is slower because long and short examples can end up in the same batch, so a large fraction of each batch is padding. As the number of buckets grows, only examples of similar length can be grouped together, trading randomization for speed. TL;DR: a low number gives better randomization, a high number gives faster training. Note, however, that if num_buckets is set too high, training speed decreases again: as num_buckets approaches the number of examples in the dataset, batches become small, which hurts training speed and possibly performance.

The buckets can also be specified by passing a list of right boundaries to the bucket_boundaries argument instead of specifying num_buckets.
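A minimal sketch of how example lengths might be mapped to buckets given explicit bucket_boundaries. It assumes the boundaries are sorted, that each boundary is the inclusive right edge of its bucket, and that lengths beyond the last boundary fall into one extra overflow bucket; the real sampler may treat the edges differently. The helper assign_buckets is hypothetical.

```python
import bisect
from collections import defaultdict
from typing import Dict, List, Sequence


def assign_buckets(lengths: Sequence[int], bucket_boundaries: Sequence[int]) -> Dict[int, List[int]]:
    """Group example indices by bucket id (hypothetical helper).

    Assumes bucket_boundaries is sorted ascending and each boundary is the
    inclusive right edge of its bucket; longer examples go to an overflow bucket.
    """
    buckets: Dict[int, List[int]] = defaultdict(list)
    for idx, length in enumerate(lengths):
        # bisect_left returns the position of the first boundary >= length, i.e. the bucket id.
        bucket_id = bisect.bisect_left(bucket_boundaries, length)
        buckets[bucket_id].append(idx)
    return dict(buckets)


print(assign_buckets([3, 10, 11, 25, 100], [10, 20, 50]))
# {0: [0, 1], 1: [2], 2: [3], 3: [4]}
```

Only examples that share a bucket id would then be candidates for the same batch, which is how bucketing keeps padding low.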

Example

>>> import torch
>>> import speechbrain as sb
>>> from speechbrain.dataio.sampler import DynamicBatchSampler
>>> from speechbrain.dataio.dataset import DynamicItemDataset
>>> from speechbrain.dataio.dataloader import SaveableDataLoader
>>> from speechbrain.dataio.batch import PaddedBatch
>>> import numpy as np
>>> item_lengths = sorted([np.random.randint(10, 100) for x in range(20)])
>>> dataset = {"ex_{}".format(x) : {"wav" :torch.randn(x)} for x in item_lengths}
>>> dataset = DynamicItemDataset(dataset)
>>> dataset.set_output_keys(["wav"])
>>> length_func = lambda x : len(x) # trivial in this example
>>> bsampler = DynamicBatchSampler(dataset, 20, 4, length_func, shuffle=False, batch_ordering='descending')
>>> dataloader = SaveableDataLoader(dataset, batch_sampler=bsampler, collate_fn=PaddedBatch)
>>> for i, b in enumerate(dataloader):
...     data, length = b["wav"]
>>> assert data.shape[-1] == max(item_lengths)

Parameters:
  • dataset (torch.utils.data.Dataset) – Pytorch Dataset from which elements will be sampled.

  • max_batch_length (int) – Upper limit for the sum of the length of examples in a batch. Should be chosen based on your GPU memory.

  • num_buckets (int) – Number of discrete buckets used to group examples together. If num_buckets == 1, all examples can be batched together. As the number of buckets grows, only examples of similar length can be grouped together, trading randomization for speed: a low number gives better randomization, a high number gives faster training. However, if set too high, training speed decreases: as num_buckets approaches the number of examples in the dataset, batches become small, which hurts training speed and possibly performance. NOTE: you must specify either bucket_boundaries or num_buckets.

  • length_func (callable) – Function used to get the length of each example from the dataset. This argument can be used only when the dataset is a SpeechBrain DynamicItemDataset object. It can be any function: e.g., lambda x: x[“duration”]*16000 returns the number of samples if the duration key in the annotation is in seconds and the audio has a 16 kHz sampling rate.

  • shuffle (bool) – Whether to shuffle the examples at each epoch.

  • batch_ordering (string) – If 'random', batches are randomly permuted; if 'ascending' or 'descending', batches are sorted by length in that order.

  • max_batch_ex (int) – If set, limits the maximum number of examples in a batch, superseding max_batch_length whenever the number of examples would exceed this value. For example, if you have many short examples and their batches would be too large, you can use this argument to limit the batch size for these short examples.

  • bucket_boundaries (list) – Specifies the right boundaries of the buckets manually, as an alternative to num_buckets.

  • lengths_list (list) – Overrides length_func by passing a list containing the length of each example in the dataset. This argument must be set when the dataset is a plain PyTorch Dataset object rather than a DynamicItemDataset object, because length_func cannot be used with PyTorch Datasets.

  • seed (int) – The base seed to use for the random number generator.

  • epoch (int) – The epoch to start at.

  • drop_last (bool) – If True, the sampler will drop the last examples which have not been grouped.

  • verbose (bool) – If True, also log the statistics of each batch at the first epoch.
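To see how max_batch_length and max_batch_ex can interact, here is a hedged sketch. It assumes that each bucket gets a fixed batch size equal to max_batch_length divided by the bucket's right boundary (its longest allowed example), with a minimum of 1, and that this size is then capped by max_batch_ex. This is an illustration built on that assumption, not a description of the actual implementation; bucket_batch_sizes is a hypothetical helper.

```python
from typing import List, Optional, Sequence


def bucket_batch_sizes(
    bucket_boundaries: Sequence[int],
    max_batch_length: int,
    max_batch_ex: Optional[int] = None,
) -> List[int]:
    """Per-bucket batch sizes under the assumptions stated above (hypothetical helper)."""
    sizes = []
    for right_edge in bucket_boundaries:
        # Worst case inside a bucket: every example is as long as its right edge,
        # so size * right_edge never exceeds max_batch_length (except the size-1 floor).
        size = max(1, max_batch_length // right_edge)
        if max_batch_ex is not None:
            size = min(size, max_batch_ex)  # cap very large batches of short examples
        sizes.append(size)
    return sizes


print(bucket_batch_sizes([10, 20, 50], 100, max_batch_ex=8))  # [8, 5, 2]
```

Without the cap, the bucket of short examples would get 10 examples per batch; max_batch_ex lowers it to 8 while leaving the buckets of longer examples unchanged.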

get_durations(batch)

Gets durations of the elements in the batch.

set_epoch(epoch)

Sets the current epoch. You can also access self.epoch directly, but this interface is kept to mirror torch.utils.data.distributed.DistributedSampler.

class speechbrain.dataio.sampler.DistributedSamplerWrapper(sampler, *args, **kwargs)

Bases: DistributedSampler

This wrapper allows any sampler (for example, a batch sampler) to be used correctly with Distributed Data Parallel (DDP).

Passing the same sampler unchanged to each DDP process gives every process access to all of the data in the dataset, instead of a subset that is unique to that process. This wrapper prevents this, so that each process uses only its own subset of the original data.

Note

This is automatically applied to any sampler in the Brain class when DDP training is used.
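The per-process split can be pictured with a small sketch. It assumes no extra shuffling and that the wrapper distributes the wrapped sampler's output the same way torch.utils.data.distributed.DistributedSampler distributes dataset indices: pad by wrapping around until the count is divisible by the number of processes, then give each process every num_replicas-th element starting at its rank. The helper shard_for_rank is hypothetical.

```python
from typing import List, Sequence


def shard_for_rank(indices: Sequence[int], num_replicas: int, rank: int) -> List[int]:
    """Return the slice of a sampler's output that one DDP process would see (hypothetical)."""
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    if len(indices) == 0:
        return []
    per_replica = -(-len(indices) // num_replicas)  # ceiling division
    total = per_replica * num_replicas
    # Pad by wrapping around so that every process gets the same number of items.
    padded = [indices[i % len(indices)] for i in range(total)]
    # Interleaved split: process `rank` takes positions rank, rank + num_replicas, ...
    return padded[rank:total:num_replicas]


sampler_output = list(range(10))
print(shard_for_rank(sampler_output, num_replicas=4, rank=2))  # [2, 6, 0]
```

Together, the four shards cover every index, and no process iterates over the full dataset on its own.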

set_epoch(epoch)

Passes set_epoch() through to both the DistributedSampler and the wrapped sampler.

class speechbrain.dataio.sampler.BalancingDataSampler(dataset, key, num_samples=None, replacement=True, seed=563375142, epoch=0, **kwargs)

Bases: ReproducibleWeightedRandomSampler

A data sampler that takes a single key from the dataset and ensures an approximately equal distribution of samples across the values of that key.

Parameters:
  • dataset (DynamicItemDataset) – The dataset from which samples will be drawn.

  • key (str) – The dataset key whose values define the groups to balance (e.g., a category label).

  • num_samples (int) – Number of samples to draw.

  • replacement (bool) – To draw with replacement or not (within an epoch of num_samples).

  • seed (int) – The base seed to use for the random number generator. It is recommended to use a value which has a good mix of 0 and 1 bits.

  • epoch (int) – The epoch to start at.
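The weights printed in the example below are consistent with giving each item a weight of 1 divided by the number of items that share its key value, so that every value carries the same total weight. Here is a minimal sketch of that rule; the rule is inferred from the printed weights rather than taken from the implementation, and balancing_weights is a hypothetical helper.

```python
from collections import Counter
from typing import Hashable, List, Sequence


def balancing_weights(values: Sequence[Hashable]) -> List[float]:
    """Weight each item by the inverse frequency of its key value (hypothetical helper)."""
    counts = Counter(values)  # how many items share each key value
    # Each item gets 1 / count, so the items of every key value sum to 1.0.
    return [1.0 / counts[value] for value in values]


print(balancing_weights(["A", "A", "B"]))  # [0.5, 0.5, 1.0]
```

Sampling indices with these weights, for instance with the weighted sampler documented above, then draws category A and category B about equally often even though A has twice as many items.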

Example

>>> from speechbrain.dataio.sampler import BalancingDataSampler
>>> from speechbrain.dataio.dataset import DynamicItemDataset
>>> sample_data = {
...   1: {"category": "A",
...       "text": "This is a test"},
...   2: {"category": "A",
...       "text": "This is a second test"},
...   3: {"category": "B",
...       "text": "This is a third test"}
...  }
>>> dataset = DynamicItemDataset(data=sample_data)
>>> sampler = BalancingDataSampler(
...     dataset=dataset,
...     key="category",
...     num_samples=10
... )
>>> sampler.weights
tensor([0.5000, 0.5000, 1.0000], dtype=torch.float64)
>>> it = iter(sampler)
>>> [next(it) for _ in range(10)]
[2, 2, 1, 2, 2, 0, 1, 1, 1, 2]

weights: Tensor
num_samples: int
replacement: bool