speechbrain.dataio.batch module
Batch collation
- Authors
Aku Rouhe 2020
Summary
Classes:
BatchsizeGuesser | Try to figure out the batchsize, but never error out
PaddedBatch | Collate_fn when examples are dicts and have variable-length sequences.
PaddedData | PaddedData(data, lengths)
Reference
- class speechbrain.dataio.batch.PaddedData(data, lengths)
Bases: tuple
- data
Alias for field number 0
- lengths
Alias for field number 1
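Because PaddedData is a named tuple, its two fields can be read either by name or by position, and it unpacks like any tuple. The sketch below rebuilds an equivalent two-field structure with collections.namedtuple. The name PaddedDataSketch is hypothetical, and plain nested lists stand in for the torch tensors the real class holds.

```python
from collections import namedtuple

# Hypothetical stand-in for speechbrain.dataio.batch.PaddedData:
# a two-field named tuple with the documented fields (data, lengths).
PaddedDataSketch = namedtuple("PaddedDataSketch", ["data", "lengths"])

# Plain nested lists stand in for torch tensors in this sketch.
padded = PaddedDataSketch(data=[[1.0, 0.0], [2.0, 1.0]], lengths=[0.5, 1.0])

# Field names are aliases for tuple positions 0 and 1.
print(padded.data is padded[0])     # True
print(padded.lengths is padded[1])  # True

# Tuple unpacking works as for any tuple.
data, lengths = padded
print(lengths)  # [0.5, 1.0]
```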
- class speechbrain.dataio.batch.PaddedBatch(examples, padded_keys=None, device_prep_keys=None, padding_func=<function batch_pad_right>, padding_kwargs=None, per_key_padding_kwargs=None, apply_default_convert=True, nonpadded_stack=True)
Bases: object
Collate_fn when examples are dicts and have variable-length sequences.
Different elements in the examples get matched by key. All NumPy arrays get converted to torch tensors (PyTorch default_convert). Then, by default, all torch.Tensor-valued elements get padded and support collective pin_memory() and to() calls. Regular Python data types are just collected in a list.
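The key-matching idea can be illustrated without torch. The following is a minimal sketch, not SpeechBrain's implementation: the function name collate_by_key is hypothetical, and it assumes that values which are lists of numbers play the role of tensors (the real class pads torch.Tensor values after default_convert). Those values are right-padded and paired with relative lengths; anything else is collected in a list.

```python
from typing import Any, Dict, List


def collate_by_key(
    examples: List[Dict[str, Any]], pad_value: float = 0.0
) -> Dict[str, Any]:
    """Hypothetical sketch of key-matched collation; not SpeechBrain's code.

    Values that are lists of numbers stand in for tensors: they are
    right-padded to a common length and paired with relative lengths.
    Every other value is simply collected in a list.
    """
    batch: Dict[str, Any] = {}
    for key in examples[0]:
        # Match elements across examples by key.
        values = [example[key] for example in examples]
        is_numeric = all(
            isinstance(v, list) and all(isinstance(x, (int, float)) for x in v)
            for v in values
        )
        if is_numeric:
            max_len = max(len(v) for v in values)
            data = [v + [pad_value] * (max_len - len(v)) for v in values]
            lengths = [len(v) / max_len if max_len else 1.0 for v in values]
            batch[key] = (data, lengths)
        else:
            batch[key] = values
    return batch


examples = [
    {"id": "ex1", "foo": [1.0]},
    {"id": "ex2", "foo": [2.0, 1.0]},
]
batch = collate_by_key(examples)
print(batch["id"])   # ['ex1', 'ex2']
print(batch["foo"])  # ([[1.0, 0.0], [2.0, 1.0]], [0.5, 1.0])
```

Lists of strings, such as the "text" values in the Example, fail the numeric check and are left as a plain list, which mirrors the "some data is left alone" behaviour.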
- Parameters:
examples (list) – List of example dicts, as produced by the DataLoader.
padded_keys (list, None) – (Optional) List of keys to pad on. If None, pad all torch.Tensors.
device_prep_keys (list, None) – (Optional) Only these keys participate in collective memory pinning and moving with to(). If None, defaults to all items with torch.Tensor values.
padding_func (callable, optional) – Called with a list of tensors to be padded together. Needs to return two tensors: the padded data, and another tensor for the data lengths.
padding_kwargs (dict, None) – (Optional) Extra kwargs to pass to padding_func, e.g. mode, value. This is used as the default padding configuration for all keys.
per_key_padding_kwargs (dict, None) – (Optional) Per-key padding configuration. Keys in this dict should match the keys in the examples. Each value should be a dict with padding parameters (e.g., {‘value’: -100, ‘mode’: ‘constant’}). If a key is not in this dict, the global padding_kwargs will be used.
apply_default_convert (bool) – Whether to apply PyTorch default_convert (NumPy to torch recursively, etc.) on all data. Default: True, which usually does the right thing.
nonpadded_stack (bool) – Whether to apply PyTorch default_collate-like stacking on values that didn't get padded. This stacks if it can, but doesn't error out if it cannot. Default: True, which usually does the right thing.
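The padding_func contract (return the padded data plus lengths) and the fallback from per_key_padding_kwargs to padding_kwargs can be sketched in pure Python. Everything here is hypothetical: pad_right_sketch is not batch_pad_right, lists stand in for tensors, and the choice to replace (rather than merge) the global settings when a key has its own entry is an assumption of this sketch. Lengths are returned relative to the longest sequence, which matches the 0.5 and 1.0 values in the Example.

```python
from typing import Any, Dict, List, Optional, Sequence, Tuple


def pad_right_sketch(
    sequences: Sequence[Sequence[Any]], value: Any = 0
) -> Tuple[List[List[Any]], List[float]]:
    """Hypothetical padding_func: right-pad every sequence to the longest one.

    Returns (padded_data, relative_lengths), the two outputs that
    padding_func is documented to produce.
    """
    max_len = max(len(seq) for seq in sequences)
    padded = [list(seq) + [value] * (max_len - len(seq)) for seq in sequences]
    lengths = [len(seq) / max_len if max_len else 1.0 for seq in sequences]
    return padded, lengths


def kwargs_for_key(
    key: str,
    padding_kwargs: Optional[Dict[str, Any]],
    per_key_padding_kwargs: Optional[Dict[str, Dict[str, Any]]],
) -> Dict[str, Any]:
    """Use the per-key config when the key has one, else the global default.

    Assumption: a per-key entry replaces the global settings entirely.
    """
    if per_key_padding_kwargs and key in per_key_padding_kwargs:
        return per_key_padding_kwargs[key]
    return padding_kwargs or {}


per_key = {"wav": {"value": 0}, "labels": {"value": -100}}
wavs, wav_lens = pad_right_sketch(
    [[1, 2, 3], [4, 5]], **kwargs_for_key("wav", None, per_key)
)
labels, _ = pad_right_sketch([[1, 2], [3]], **kwargs_for_key("labels", None, per_key))
print(wavs)    # [[1, 2, 3], [4, 5, 0]]
print(labels)  # [[1, 2], [3, -100]]
```

Using -100 for label padding is a common convention because it is the default ignore_index of PyTorch's cross-entropy loss, so padded label positions do not contribute to the loss.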
Example
>>> import torch
>>> from speechbrain.dataio.batch import PaddedBatch
>>> batch = PaddedBatch(
...     [
...         {"id": "ex1", "foo": torch.Tensor([1.0])},
...         {"id": "ex2", "foo": torch.Tensor([2.0, 1.0])},
...     ]
... )
>>> # Attribute or key-based access:
>>> batch.id
['ex1', 'ex2']
>>> batch["id"]
['ex1', 'ex2']
>>> # torch.Tensors get padded
>>> type(batch.foo)
<class 'speechbrain.dataio.batch.PaddedData'>
>>> batch.foo.data
tensor([[1., 0.],
        [2., 1.]])
>>> batch.foo.lengths
tensor([0.5000, 1.0000])
>>> # Batch supports collective operations:
>>> _ = batch.to(dtype=torch.half)
>>> batch.foo.data
tensor([[1., 0.],
        [2., 1.]], dtype=torch.float16)
>>> batch.foo.lengths
tensor([0.5000, 1.0000], dtype=torch.float16)
>>> # Numpy tensors get converted to torch and padded as well:
>>> import numpy as np
>>> batch = PaddedBatch(
...     [{"wav": np.asarray([1, 2, 3, 4])}, {"wav": np.asarray([1, 2, 3])}]
... )
>>> batch.wav  # doctest: +ELLIPSIS
PaddedData(data=tensor([[1, 2,...
>>> # Basic stacking collation deals with non padded data:
>>> batch = PaddedBatch(
...     [
...         {
...             "spk_id": torch.tensor([1]),
...             "wav": torch.tensor([0.1, 0.0, 0.3]),
...         },
...         {
...             "spk_id": torch.tensor([2]),
...             "wav": torch.tensor([0.2, 0.3, -0.1]),
...         },
...     ],
...     padded_keys=["wav"],
... )
>>> batch.spk_id
tensor([[1],
        [2]])
>>> # And some data is left alone:
>>> batch = PaddedBatch(
...     [{"text": ["Hello"]}, {"text": ["How", "are", "you?"]}]
... )
>>> batch.text
[['Hello'], ['How', 'are', 'you?']]
>>> # Per-key padding configuration:
>>> batch = PaddedBatch(
...     [
...         {
...             "wav": torch.tensor([1, 2, 3]),
...             "labels": torch.tensor([1, 2]),
...         },
...         {"wav": torch.tensor([4, 5]), "labels": torch.tensor([3])},
...     ],
...     per_key_padding_kwargs={
...         "wav": {"value": 0},
...         "labels": {"value": -100},
...     },
... )
>>> batch.wav.data
tensor([[1, 2, 3],
        [4, 5, 0]])
>>> batch.labels.data
tensor([[   1,    2],
        [   3, -100]])
- __iter__()
Iterates over the different elements of the batch.
- Return type:
Iterator over the batch.
Example
>>> import torch
>>> from speechbrain.dataio.batch import PaddedBatch
>>> batch = PaddedBatch(
...     [
...         {"id": "ex1", "val": torch.Tensor([1.0])},
...         {"id": "ex2", "val": torch.Tensor([2.0, 1.0])},
...     ]
... )
>>> ids, vals = batch
>>> ids
['ex1', 'ex2']
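Unpacking a batch with `ids, vals = batch` works because iteration yields one collated element per key. The sketch below shows that mechanism in isolation. The class IterableBatchSketch is hypothetical, it does no padding, and it assumes the key order is taken from the first example's insertion order.

```python
from typing import Any, Dict, Iterator, List


class IterableBatchSketch:
    """Hypothetical stand-in: collate by key, then iterate in key order."""

    def __init__(self, examples: List[Dict[str, Any]]) -> None:
        # Key order follows the first example (Python dicts keep insertion order).
        self._keys = list(examples[0].keys())
        self._values = {key: [ex[key] for ex in examples] for key in self._keys}

    def __iter__(self) -> Iterator[Any]:
        # One collated element per key, so the batch unpacks like a tuple.
        return (self._values[key] for key in self._keys)


batch = IterableBatchSketch(
    [{"id": "ex1", "val": [1.0]}, {"id": "ex2", "val": [2.0, 1.0]}]
)
ids, vals = batch
print(ids)   # ['ex1', 'ex2']
print(vals)  # [[1.0], [2.0, 1.0]]
```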
- to(*args, **kwargs)
Moves and/or casts the relevant elements (those selected by device_prep_keys) in place.
Passes all arguments to torch.Tensor.to; see its documentation.
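torch.Tensor.to returns a tensor rather than modifying its input, so an "in-place" collective move amounts to replacing each stored value with the result of its own .to() call. The sketch below illustrates that pattern. CollectiveToSketch and FakeTensor are hypothetical stand-ins (FakeTensor only records a dtype string so the example runs without torch), and returning self is an assumption that allows chaining.

```python
from typing import Any, Dict, Iterable


class CollectiveToSketch:
    """Hypothetical sketch of an in-place, collective .to() over selected keys."""

    def __init__(self, elements: Dict[str, Any], device_prep_keys: Iterable[str]) -> None:
        self.elements = dict(elements)
        self.device_prep_keys = list(device_prep_keys)

    def to(self, *args: Any, **kwargs: Any) -> "CollectiveToSketch":
        # Forward every argument unchanged to each selected element's own .to()
        # and store the returned value back under the same key.
        for key in self.device_prep_keys:
            self.elements[key] = self.elements[key].to(*args, **kwargs)
        # Returning self is an assumption of this sketch; it allows chaining.
        return self


class FakeTensor:
    """Minimal stand-in for torch.Tensor that only records the requested dtype."""

    def __init__(self, dtype: str = "float32") -> None:
        self.dtype = dtype

    def to(self, dtype: str) -> "FakeTensor":
        return FakeTensor(dtype)


batch = CollectiveToSketch({"foo": FakeTensor(), "id": ["ex1", "ex2"]}, device_prep_keys=["foo"])
batch.to(dtype="float16")
print(batch.elements["foo"].dtype)  # float16
print(batch.elements["id"])         # ['ex1', 'ex2']
```

Keys outside device_prep_keys, such as the "id" list here, are never touched, which is why non-tensor values can safely sit in the same batch.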
- property batchsize
Returns the batch size.
- class speechbrain.dataio.batch.BatchsizeGuesser
Bases: object
Try to figure out the batchsize, but never error out.
If it cannot figure out anything else, it falls back to guessing 1.
Example
>>> import torch
>>> from speechbrain.dataio.batch import BatchsizeGuesser, PaddedBatch
>>> guesser = BatchsizeGuesser()
>>> # Works with simple tensors:
>>> guesser(torch.randn((2, 3)))
2
>>> # Works with sequences of tensors:
>>> guesser((torch.randn((2, 3)), torch.randint(high=5, size=(2,))))
2
>>> # Works with PaddedBatch:
>>> guesser(
...     PaddedBatch([{"wav": [1.0, 2.0, 3.0]}, {"wav": [4.0, 5.0, 6.0]}])
... )
2
>>> guesser("Even weird non-batches have a fallback")
1
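A "never error out" guesser can be built as an ordered chain of strategies, each wrapped in a try/except, with 1 as the last resort. The sketch below is one such heuristic, not BatchsizeGuesser's actual logic: the helper names and the strategy order (a batchsize attribute, then .shape[0], then the first element's .shape[0]) are assumptions. ShapedStandIn is a hypothetical object exposing only .shape, so the example runs without torch; real tensors and NumPy arrays also expose .shape.

```python
from typing import Any, Callable, List


def _from_batchsize_attr(batch: Any) -> int:
    return int(batch.batchsize)


def _from_shape(batch: Any) -> int:
    return int(batch.shape[0])


def _from_first_element_shape(batch: Any) -> int:
    return int(batch[0].shape[0])


def guess_batchsize(batch: Any) -> int:
    """Hypothetical heuristic: try each strategy in order, never raise, else 1."""
    strategies: List[Callable[[Any], int]] = [
        _from_batchsize_attr,       # e.g. an object with a batchsize property
        _from_shape,                # tensor or array: size of the first dimension
        _from_first_element_shape,  # tuple/list of tensors: use the first one
    ]
    for strategy in strategies:
        try:
            return strategy(batch)
        except Exception:
            # Any failure just means "try the next strategy".
            continue
    return 1


class ShapedStandIn:
    """Stand-in for a tensor or array that only exposes a .shape tuple."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


print(guess_batchsize(ShapedStandIn(2, 3)))                       # 2
print(guess_batchsize((ShapedStandIn(2, 3), ShapedStandIn(2))))   # 2
print(guess_batchsize("Even weird non-batches have a fallback"))  # 1
```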