speechbrain.dataio.batch module

Batch collation

Authors
  • Aku Rouhe 2020

Summary

Classes:

BatchsizeGuesser

Tries to figure out the batch size, but never errors out.

PaddedBatch

Collate_fn when examples are dicts and have variable-length sequences.

PaddedData

Padded data and its lengths, as a namedtuple.

Reference

class speechbrain.dataio.batch.PaddedData(data, lengths)

Bases: tuple

data

Alias for field number 0

lengths

Alias for field number 1
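
With the default padding function, lengths holds each example's length relative to the longest one in the batch. A small illustration of the namedtuple behaviour (a sketch, not from the library's own docs):

>>> pd = PaddedData(torch.tensor([[1., 0.], [2., 1.]]), torch.tensor([0.5, 1.0]))
>>> data, lengths = pd  # unpacks like any tuple
>>> pd[0] is pd.data
True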

class speechbrain.dataio.batch.PaddedBatch(examples, padded_keys=None, device_prep_keys=None, padding_func=<function batch_pad_right>, padding_kwargs={}, apply_default_convert=True, nonpadded_stack=True)

Bases: object

Collate_fn when examples are dicts and have variable-length sequences.

Different elements in the examples get matched by key. All numpy tensors get converted to PyTorch tensors (via default_convert). Then, by default, all torch.Tensor-valued elements get padded and support collective pin_memory() and to() calls. Regular Python data types are just collected in a list.
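
This makes PaddedBatch directly usable as the collate_fn of a standard PyTorch DataLoader. A minimal sketch (the example data here is made up for illustration):

>>> from torch.utils.data import DataLoader
>>> examples = [
...     {"id": "ex1", "wav": torch.zeros(16000)},
...     {"id": "ex2", "wav": torch.zeros(8000)}]
>>> loader = DataLoader(examples, batch_size=2, collate_fn=PaddedBatch)
>>> batch = next(iter(loader))
>>> batch.id
['ex1', 'ex2']
>>> batch.wav.data.shape
torch.Size([2, 16000])
>>> batch.wav.lengths
tensor([1.0000, 0.5000])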

Parameters:
  • examples (list) – List of example dicts, as produced by a DataLoader.

  • padded_keys (list, None) – (Optional) List of keys to pad on. If None, pad all torch.Tensors.

  • device_prep_keys (list, None) – (Optional) Only these keys participate in collective memory pinning and moving with to(). If None, defaults to all items with torch.Tensor values.

  • padding_func (callable, optional) – Called with a list of tensors to be padded together. Needs to return two tensors: the padded data, and another tensor for the data lengths. See the sketch after this parameter list.

  • padding_kwargs (dict) – (Optional) Extra kwargs to pass to padding_func, e.g. mode, value.

  • apply_default_convert (bool) – Whether to apply PyTorch default_convert (numpy to torch recursively, etc.) on all data. Default: True, which usually does the right thing.

  • nonpadded_stack (bool) – Whether to apply PyTorch-default_collate-like stacking on values that didn't get padded. This stacks if it can, but doesn't error out if it cannot. Default: True, which usually does the right thing.
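
For instance, a custom padding function only needs to honour that two-tensor return contract. A sketch (pad_with_value is hypothetical, not part of the library):

>>> def pad_with_value(tensors, value=0.0):
...     # Pad to the right with `value`; return (padded, relative lengths).
...     max_len = max(t.size(0) for t in tensors)
...     padded = torch.full((len(tensors), max_len), value)
...     for i, t in enumerate(tensors):
...         padded[i, :t.size(0)] = t
...     lengths = torch.tensor([t.size(0) / max_len for t in tensors])
...     return padded, lengths
>>> batch = PaddedBatch([
...     {"x": torch.tensor([1., 2.])},
...     {"x": torch.tensor([3.])}],
...     padding_func=pad_with_value,
...     padding_kwargs={"value": -1.0})
>>> batch.x.data
tensor([[ 1.,  2.],
        [ 3., -1.]])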

Example

>>> batch = PaddedBatch([
...     {"id": "ex1", "foo": torch.Tensor([1.])},
...     {"id": "ex2", "foo": torch.Tensor([2., 1.])}])
>>> # Attribute or key-based access:
>>> batch.id
['ex1', 'ex2']
>>> batch["id"]
['ex1', 'ex2']
>>> # torch.Tensors get padded
>>> type(batch.foo)
<class 'speechbrain.dataio.batch.PaddedData'>
>>> batch.foo.data
tensor([[1., 0.],
        [2., 1.]])
>>> batch.foo.lengths
tensor([0.5000, 1.0000])
>>> # Batch supports collective operations:
>>> _ = batch.to(dtype=torch.half)
>>> batch.foo.data
tensor([[1., 0.],
        [2., 1.]], dtype=torch.float16)
>>> batch.foo.lengths
tensor([0.5000, 1.0000], dtype=torch.float16)
>>> # Numpy tensors get converted to torch and padded as well:
>>> import numpy as np
>>> batch = PaddedBatch([
...     {"wav": np.asarray([1,2,3,4])},
...     {"wav": np.asarray([1,2,3])}])
>>> batch.wav  # doctest: +ELLIPSIS
PaddedData(data=tensor([[1, 2,...
>>> # Basic stacking collation deals with non padded data:
>>> batch = PaddedBatch([
...     {"spk_id": torch.tensor([1]), "wav": torch.tensor([.1,.0,.3])},
...     {"spk_id": torch.tensor([2]), "wav": torch.tensor([.2,.3,-.1])}],
...     padded_keys=["wav"])
>>> batch.spk_id
tensor([[1],
        [2]])
>>> # And some data is left alone:
>>> batch = PaddedBatch([
...     {"text": ["Hello"]},
...     {"text": ["How", "are", "you?"]}])
>>> batch.text
[['Hello'], ['How', 'are', 'you?']]

__iter__()

Iterates over the different elements of the batch.

Return type:

Iterator over the batch.

Example

>>> batch = PaddedBatch([
...     {"id": "ex1", "val": torch.Tensor([1.])},
...     {"id": "ex2", "val": torch.Tensor([2., 1.])}])
>>> ids, vals = batch
>>> ids
['ex1', 'ex2']

pin_memory()

In-place, moves relevant elements to pinned memory.

to(*args, **kwargs)

In-place move/cast relevant elements.

Passes all arguments to torch.Tensor.to, see its documentation.
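
In a typical loop the whole batch is moved with a single call. A CPU-only sketch (on a GPU machine one would pass torch.device("cuda"), usually after pin_memory()):

>>> batch = PaddedBatch([
...     {"x": torch.tensor([1.])},
...     {"x": torch.tensor([2., 3.])}])
>>> _ = batch.to(torch.device("cpu"))
>>> batch.x.data.device
device(type='cpu')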

at_position(pos)

Gets the batch element at the given position, following key order.
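
Positions follow the key order of the examples, as in __iter__ above. A short sketch of the intended use (behaviour inferred from that key order):

>>> batch = PaddedBatch([
...     {"id": "ex1", "val": torch.Tensor([1.])},
...     {"id": "ex2", "val": torch.Tensor([2., 1.])}])
>>> batch.at_position(0)
['ex1', 'ex2']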

property batchsize

Returns the batch size.
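
For example:

>>> batch = PaddedBatch([{"id": "ex1"}, {"id": "ex2"}])
>>> batch.batchsize
2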

class speechbrain.dataio.batch.BatchsizeGuesser

Bases: object

Tries to figure out the batch size, but never errors out.

If nothing else works, falls back to guessing 1.

Example

>>> guesser = BatchsizeGuesser()
>>> # Works with simple tensors:
>>> guesser(torch.randn((2,3)))
2
>>> # Works with sequences of tensors:
>>> guesser((torch.randn((2,3)), torch.randint(high=5, size=(2,))))
2
>>> # Works with PaddedBatch:
>>> guesser(PaddedBatch([{"wav": [1.,2.,3.]}, {"wav": [4.,5.,6.]}]))
2
>>> guesser("Even weird non-batches have a fallback")
1
find_suitable_method(batch)

Tries the different guessing methods and notes which one worked.

attr_based(batch)

Guesses the batch size from the batch's batchsize attribute.

torch_tensor_bs(batch)

Guesses the batch size from the first dimension of a torch.Tensor.

len_of_first(batch)

Guesses the batch size as the length of the first element.

len_of_iter_first(batch)

Guesses the batch size as the length of the first element yielded by an iterator.

fallback(batch)

Guesses a batch size of 1.
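
Taken together, these methods form a chain of guesses that is tried in order. A hypothetical re-implementation of the idea (a sketch, not the library's exact code):

>>> def guess_batchsize(batch):
...     try:
...         return batch.batchsize          # attr_based
...     except AttributeError:
...         pass
...     if isinstance(batch, torch.Tensor):
...         return batch.shape[0]           # torch_tensor_bs
...     try:
...         return len(batch[0])            # len_of_first
...     except (TypeError, KeyError, IndexError):
...         pass
...     return 1                            # fallback
>>> guess_batchsize(torch.randn((2, 3)))
2
>>> guess_batchsize(42)
1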