"""Batch collation

Authors
  * Aku Rouhe 2020
"""
import collections
import torch
from speechbrain.utils.data_utils import mod_default_collate
from speechbrain.utils.data_utils import recursive_to
from speechbrain.utils.data_utils import batch_pad_right
from torch.utils.data._utils.collate import default_convert
from torch.utils.data._utils.pin_memory import (
    pin_memory as recursive_pin_memory,
)


PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"])
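# In a PaddedData tuple, ``data`` is the batch of padded tensors and
# ``lengths`` holds each example's relative length (true length divided by
# the padded length), as returned by the padding_func (batch_pad_right by
# default).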


class PaddedBatch:
    """Collate_fn when examples are dicts and have variable-length sequences.

    Different elements in the examples get matched by key.
    All numpy tensors get converted to torch (with PyTorch default_convert).
    Then, by default, all torch.Tensor valued elements get padded and support
    collective pin_memory() and to() calls.
    Regular Python data types are just collected in a list.

    Arguments
    ---------
    examples : list
        List of example dicts, as produced by a DataLoader.
    padded_keys : list, None
        (Optional) List of keys to pad on. If None, pad all torch.Tensors.
    device_prep_keys : list, None
        (Optional) Only these keys participate in collective memory pinning
        and moving with to(). If None, defaults to all items with
        torch.Tensor values.
    padding_func : callable, optional
        Called with a list of tensors to be padded together. Needs to return
        two tensors: the padded data, and another tensor for the data
        lengths.
    padding_kwargs : dict
        (Optional) Extra kwargs to pass to padding_func, e.g. mode, value.
    apply_default_convert : bool
        Whether to apply PyTorch default_convert (numpy to torch recursively,
        etc.) on all data. Default: True, usually does the right thing.
    nonpadded_stack : bool
        Whether to apply PyTorch-default_collate-like stacking on values that
        didn't get padded. This stacks if it can, but doesn't error out if it
        cannot. Default: True, usually does the right thing.

    Example
    -------
    >>> batch = PaddedBatch([
    ...     {"id": "ex1", "foo": torch.Tensor([1.])},
    ...     {"id": "ex2", "foo": torch.Tensor([2., 1.])}])
    >>> # Attribute or key-based access:
    >>> batch.id
    ['ex1', 'ex2']
    >>> batch["id"]
    ['ex1', 'ex2']
    >>> # torch.Tensors get padded
    >>> type(batch.foo)
    <class 'speechbrain.dataio.batch.PaddedData'>
    >>> batch.foo.data
    tensor([[1., 0.],
            [2., 1.]])
    >>> batch.foo.lengths
    tensor([0.5000, 1.0000])
    >>> # Batch supports collective operations:
    >>> _ = batch.to(dtype=torch.half)
    >>> batch.foo.data
    tensor([[1., 0.],
            [2., 1.]], dtype=torch.float16)
    >>> batch.foo.lengths
    tensor([0.5000, 1.0000], dtype=torch.float16)
    >>> # Numpy tensors get converted to torch and padded as well:
    >>> import numpy as np
    >>> batch = PaddedBatch([
    ...     {"wav": np.asarray([1,2,3,4])},
    ...     {"wav": np.asarray([1,2,3])}])
    >>> batch.wav  # doctest: +ELLIPSIS
    PaddedData(data=tensor([[1, 2,...
    >>> # Basic stacking collation deals with non padded data:
    >>> batch = PaddedBatch([
    ...     {"spk_id": torch.tensor([1]), "wav": torch.tensor([.1,.0,.3])},
    ...     {"spk_id": torch.tensor([2]), "wav": torch.tensor([.2,.3,-.1])}],
    ...     padded_keys=["wav"])
    >>> batch.spk_id
    tensor([[1],
            [2]])
    >>> # And some data is left alone:
    >>> batch = PaddedBatch([
    ...     {"text": ["Hello"]},
    ...     {"text": ["How", "are", "you?"]}])
    >>> batch.text
    [['Hello'], ['How', 'are', 'you?']]
    """

    def __init__(
        self,
        examples,
        padded_keys=None,
        device_prep_keys=None,
        padding_func=batch_pad_right,
        padding_kwargs={},
        apply_default_convert=True,
        nonpadded_stack=True,
    ):
        self.__length = len(examples)
        self.__keys = list(examples[0].keys())
        self.__padded_keys = []
        self.__device_prep_keys = []
        for key in self.__keys:
            values = [example[key] for example in examples]
            # Default convert usually does the right thing (numpy2torch etc.)
            if apply_default_convert:
                values = default_convert(values)
            if (padded_keys is not None and key in padded_keys) or (
                padded_keys is None and isinstance(values[0], torch.Tensor)
            ):
                # Padding and PaddedData
                self.__padded_keys.append(key)
                padded = PaddedData(*padding_func(values, **padding_kwargs))
                setattr(self, key, padded)
            else:
                # Default PyTorch collate usually does the right thing
                # (convert lists of equal sized tensors to batch tensors, etc.)
                if nonpadded_stack:
                    values = mod_default_collate(values)
                setattr(self, key, values)
            if (device_prep_keys is not None and key in device_prep_keys) or (
                device_prep_keys is None and isinstance(values[0], torch.Tensor)
            ):
                self.__device_prep_keys.append(key)

    def __len__(self):
        return self.__length

    def __getitem__(self, key):
        if key in self.__keys:
            return getattr(self, key)
        else:
            raise KeyError(f"Batch doesn't have key: {key}")

    def __iter__(self):
        """Iterates over the different elements of the batch.

        Example
        -------
        >>> batch = PaddedBatch([
        ...     {"id": "ex1", "val": torch.Tensor([1.])},
        ...     {"id": "ex2", "val": torch.Tensor([2., 1.])}])
        >>> ids, vals = batch
        >>> ids
        ['ex1', 'ex2']
        """
        return iter((getattr(self, key) for key in self.__keys))

    def pin_memory(self):
        """In-place, moves relevant elements to pinned memory."""
        for key in self.__device_prep_keys:
            value = getattr(self, key)
            pinned = recursive_pin_memory(value)
            setattr(self, key, pinned)
        return self

    def to(self, *args, **kwargs):
        """In-place move/cast relevant elements.

        Passes all arguments to torch.Tensor.to, see its documentation.
        """
        for key in self.__device_prep_keys:
            value = getattr(self, key)
            moved = recursive_to(value, *args, **kwargs)
            setattr(self, key, moved)
        return self

    def at_position(self, pos):
        """Fetch an element by its position among the batch's keys."""
        key = self.__keys[pos]
        return getattr(self, key)

    @property
    def batchsize(self):
        """Returns the batch size."""
        return self.__length
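
# Usage sketch (illustrative, not part of this module's API): PaddedBatch is
# designed to be passed to a PyTorch DataLoader as its collate_fn, so that
# variable-length examples come out padded into batches. The dataset and key
# names below are made up for demonstration.
def _example_padded_batch_dataloader():
    """Minimal sketch of PaddedBatch as a DataLoader collate_fn."""
    from torch.utils.data import DataLoader

    # A plain list of dicts works as a map-style dataset.
    examples = [
        {"id": "utt1", "wav": torch.randn(16000)},
        {"id": "utt2", "wav": torch.randn(12000)},
    ]
    loader = DataLoader(examples, batch_size=2, collate_fn=PaddedBatch)
    batch = next(iter(loader))
    wavs, lens = batch.wav  # PaddedData unpacks to (data, lengths)
    return wavs, lens


# Sketch of a custom padding_func (hypothetical, for illustration): any
# callable that takes a list of tensors and returns (padded_data, lengths)
# can stand in for batch_pad_right, e.g.
# PaddedBatch(examples, padding_func=_example_left_pad).
def _example_left_pad(tensors):
    """Left-pads 1-D tensors; returns (padded, relative_lengths)."""
    max_len = max(t.size(0) for t in tensors)
    padded = torch.stack(
        [torch.nn.functional.pad(t, (max_len - t.size(0), 0)) for t in tensors]
    )
    lengths = torch.tensor([t.size(0) / max_len for t in tensors])
    return padded, lengths
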
class BatchsizeGuesser:
    """Try to figure out the batch size, but never error out.

    If this cannot figure out anything else, it falls back to guessing 1.

    Example
    -------
    >>> guesser = BatchsizeGuesser()
    >>> # Works with simple tensors:
    >>> guesser(torch.randn((2,3)))
    2
    >>> # Works with sequences of tensors:
    >>> guesser((torch.randn((2,3)), torch.randint(high=5, size=(2,))))
    2
    >>> # Works with PaddedBatch:
    >>> guesser(PaddedBatch([{"wav": [1.,2.,3.]}, {"wav": [4.,5.,6.]}]))
    2
    >>> guesser("Even weird non-batches have a fallback")
    1
    """

    def __init__(self):
        self.method = None

    def __call__(self, batch):
        try:
            return self.method(batch)
        except:  # noqa: E722
            return self.find_suitable_method(batch)

    def find_suitable_method(self, batch):
        """Try the different methods and note which worked."""
        try:
            bs = self.attr_based(batch)
            self.method = self.attr_based
            return bs
        except:  # noqa: E722
            pass
        try:
            bs = self.torch_tensor_bs(batch)
            self.method = self.torch_tensor_bs
            return bs
        except:  # noqa: E722
            pass
        try:
            bs = self.len_of_first(batch)
            self.method = self.len_of_first
            return bs
        except:  # noqa: E722
            pass
        try:
            bs = self.len_of_iter_first(batch)
            self.method = self.len_of_iter_first
            return bs
        except:  # noqa: E722
            pass
        # Last ditch fallback:
        bs = self.fallback(batch)
        self.method = self.fallback
        return bs

    def attr_based(self, batch):
        """Batch size from a ``batchsize`` attribute (e.g. PaddedBatch)."""
        return batch.batchsize

    def torch_tensor_bs(self, batch):
        """Batch size from the first dimension of a torch.Tensor."""
        return batch.shape[0]

    def len_of_first(self, batch):
        """Batch size as the length of the first element of the batch."""
        return len(batch[0])

    def len_of_iter_first(self, batch):
        """Batch size as the length of the first item yielded by iteration."""
        return len(next(iter(batch)))

    def fallback(self, batch):
        """Last-resort guess: a batch size of 1."""
        return 1
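
# Usage sketch (illustrative): after the first successful guess, the guesser
# caches the strategy that worked, so later calls on similarly-structured
# batches skip the trial-and-error.
def _example_batchsize_guesser():
    """Minimal sketch of BatchsizeGuesser reusing its cached strategy."""
    guesser = BatchsizeGuesser()
    first = guesser(torch.randn(4, 100))  # trial-and-error finds shape[0]
    second = guesser(torch.randn(4, 50))  # cached torch_tensor_bs is reused
    return first, second  # both are 4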