Source code for speechbrain.dataio.iterators

"""Webdataset compatible iterators

Authors:
 * Aku Rouhe 2021
"""
import bisect
import random
from dataclasses import dataclass, field
from functools import partial
from typing import Any
from speechbrain.dataio.batch import PaddedBatch


@dataclass(order=True)
class LengthItem:
    """Data class for lengths"""

    length: int
    data: Any = field(compare=False)

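# Illustration (a sketch, not part of the original module): order=True together
# with field(compare=False) means LengthItems compare purely by length, which
# is what lets bisect.insort keep the data buffer sorted further below.
assert LengthItem(3, data="b") < LengthItem(5, data="a")
assert LengthItem(3, data="x") == LengthItem(3, data="y")
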
def total_length_with_padding(lengths):
    """Determines how long the batch would be (with padding)"""
    return len(lengths) * max(lengths)

def padding_ratio(lengths):
    """Determines how much of the batch is padding."""
    return 1.0 - sum(lengths) / total_length_with_padding(lengths)

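# A small worked example (illustration only, not part of the original module):
# for lengths [3, 5, 4] the padded batch occupies 3 * 5 = 15 slots, of which
# 15 - (3 + 5 + 4) = 3 are padding, so the padding ratio is 0.2.
assert total_length_with_padding([3, 5, 4]) == 15
assert abs(padding_ratio([3, 5, 4]) - 0.2) < 1e-9
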
@dataclass(order=True)
class RatioIndex:
    """Data class for a padding ratio at a buffer index."""

    ratio: float
    index: int

def indices_around_random_pivot(
    databuffer,
    target_batch_numel,
    max_batch_size=None,
    max_batch_numel=None,
    max_padding_ratio=0.2,
    randint_generator=random.randint,
):
    """Random pivot sampler_fn for dynamic_bucketed_batch

    Creates a batch around a random pivot index in the sorted buffer.

    This works on the databuffer, which is assumed to be in sorted order. An
    index is chosen at random. This starts the window of indices: at first,
    only the randomly chosen pivot index is included. The window of indices is
    grown one index at a time, picking either the index to the right of the
    window or the index to the left, whichever increases the padding ratio the
    least, while making sure the batch does not exceed the maximum batch
    length nor the maximum padding ratio.

    Arguments
    ---------
    databuffer : list
        Sorted list of LengthItems.
    target_batch_numel : int
        Target of total batch length including padding, which is simply
        computed as batch size * length of longest example. This function
        aims to return the batch as soon as the gathered length exceeds this.
        If some limits are encountered first, this may not be satisfied.
    max_batch_size : None, int
        Maximum number of examples to include in the batch, or None to not
        limit by number of examples.
    max_batch_numel : None, int
        Maximum total batch length including padding, which is simply
        computed as batch size * length of longest example, or None to not
        apply a maximum.
    max_padding_ratio : None, float
        Maximum fraction of the batch that may be padding, or None to not
        apply a maximum.
    randint_generator : callable
        Function with the random.randint interface, used to pick the pivot.

    Returns
    -------
    list
        The chosen indices, a contiguous range of the sorted buffer.
    """
    bufferlen = len(databuffer)
    if max_batch_size is None:
        max_batch_size = bufferlen
    # Choose pivot:
    min_index = max_index = randint_generator(0, bufferlen - 1)
    lengths = [databuffer[min_index].length]

    # Define index filtering function:
    def possibly_consider(index, to_consider):
        """Adds an index to the to_consider list, if the index passes all
        requirements."""
        if index < 0 or index >= len(databuffer):
            return
        consideree = databuffer[index]
        updated_lengths = [consideree.length] + lengths
        if max_batch_numel is not None:
            updated_total = total_length_with_padding(updated_lengths)
            if updated_total > max_batch_numel:
                return
        updated_ratio = padding_ratio(updated_lengths)
        if max_padding_ratio is not None and updated_ratio > max_padding_ratio:
            return
        to_consider.append(RatioIndex(updated_ratio, index))

    # Loop till the target length is exceeded or max batch size is hit:
    while (
        max_index + 1 - min_index < max_batch_size
        and total_length_with_padding(lengths) < target_batch_numel
    ):
        # Consider indices to the left and to the right, if they
        # pass the requirements:
        to_consider = []
        possibly_consider(min_index - 1, to_consider)
        possibly_consider(max_index + 1, to_consider)
        # If neither passes the requirements, then we must return the batch
        # as it is now (there can be no better addition):
        if not to_consider:
            break
        # Pick the index that minimizes the padding ratio increase:
        to_add = min(to_consider)
        min_index = min(min_index, to_add.index)
        max_index = max(max_index, to_add.index)
        lengths.append(databuffer[to_add.index].length)
    return list(range(min_index, max_index + 1))

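# A minimal usage sketch (illustration only, not part of the original module).
# The buffer is already sorted by length and the "random" generator is pinned
# so the pivot is index 2 and the outcome is deterministic. The window grows
# to index 3 (lowest padding ratio increase), then to index 1, and stops:
# the padded total 3 * 7 = 21 exceeds the target of 20, and the long outlier
# at index 4 would push the padding ratio past the default 0.2 limit.
_example_buffer = [LengthItem(n, data=None) for n in (4, 5, 6, 7, 20)]
assert indices_around_random_pivot(
    _example_buffer,
    target_batch_numel=20,
    randint_generator=lambda a, b: 2,  # always pick index 2 as the pivot
) == [1, 2, 3]
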
def dynamic_bucketed_batch(
    data,
    len_key=None,
    len_fn=len,
    min_sample_len=None,
    max_sample_len=None,
    buffersize=1024,
    collate_fn=PaddedBatch,
    sampler_fn=indices_around_random_pivot,
    sampler_kwargs={},
    drop_end=False,
):
    """Produce batches from a sorted buffer

    This function keeps a sorted buffer of the incoming samples. The samples
    can be filtered for min/max length. An external sampler is used to choose
    samples for each batch, which allows different dynamic batching
    algorithms to be used.

    Arguments
    ---------
    data : iterable
        An iterable source of samples, such as an IterableDataset.
    len_key : str, None
        The key in the sample dict to use to fetch the length of the sample,
        or None if no key should be used.
    len_fn : callable
        Called with sample[len_key] if len_key is not None, else sample.
        Needs to return the sample length as an integer.
    min_sample_len : int, None
        Discard samples with length lower than this. If None, no minimum is
        applied.
    max_sample_len : int, None
        Discard samples with length larger than this. If None, no maximum is
        applied.
    buffersize : int
        The size of the internal sorted buffer. The buffer is always filled
        up before yielding a batch of samples.
    collate_fn : callable
        Called with a list of samples. This should return a batch. By
        default, the SpeechBrain PaddedBatch class is used, which works for
        dict-like samples and pads any tensors.
    sampler_fn : callable
        Called with the sorted data buffer. Needs to return a list of
        indices, which make up the next batch. By default
        ``indices_around_random_pivot`` is used.
    sampler_kwargs : dict
        Keyword arguments, passed to sampler_fn.
    drop_end : bool
        After the data stream is exhausted, should batches be made until the
        data buffer is exhausted, or should the rest of the buffer be
        discarded. Without new samples, the last batches might not be
        efficient to process. Note: you can use ``.repeat`` on `webdataset`
        IterableDatasets to never run out of new samples, and then use
        `speechbrain.dataio.dataloader.LoopedLoader` to set a nominal epoch
        length.

    Yields
    ------
    Batches, as produced by ``collate_fn``.
    """
    databuffer = []
    if sampler_kwargs:
        sampler_fn = partial(sampler_fn, **sampler_kwargs)
    for sample in data:
        # The length fetching interface has multiple valid call signatures:
        if len_key is not None and len_fn is not None:
            length = len_fn(sample[len_key])
        elif len_key is not None:
            length = sample[len_key]
        elif len_fn is not None:
            length = len_fn(sample)
        else:
            raise ValueError("Must specify at least one of len_key or len_fn")
        # Possibly filter by length:
        if (min_sample_len is not None and length < min_sample_len) or (
            max_sample_len is not None and length > max_sample_len
        ):
            # Drop sample
            continue
        item = LengthItem(length, sample)
        # bisect.insort inserts in sorted order.
        # This should be a good way to maintain a sorted list,
        # but perhaps simply filling up the buffer and calling .sort()
        # could be good as well (Python's sort leverages already sorted
        # segments).
        bisect.insort(databuffer, item)
        if len(databuffer) == buffersize:
            indices = sampler_fn(databuffer)
            batch_list = []
            # Popping from highest to lowest index is safe:
            for i in sorted(indices, reverse=True):
                item = databuffer.pop(i)
                batch_list.append(item.data)
            yield collate_fn(batch_list)
    # The data stream was exhausted. The data buffer is relatively full at
    # first, but cannot be replenished, so batches might not be efficiently
    # produced. Either stop, or exhaust the buffer:
    if drop_end:
        return
    while databuffer:
        indices = sampler_fn(databuffer)
        batch_list = []
        for i in sorted(indices, reverse=True):
            item = databuffer.pop(i)
            batch_list.append(item.data)
        yield collate_fn(batch_list)

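# A minimal usage sketch (illustration only, not part of the original module):
# batch a stream of dict samples by the length stored under the hypothetical
# "wav_len" key, using a plain list as collate_fn so the example needs no
# tensors. A small buffersize is used here so batches appear quickly; with
# real (e.g. webdataset) samples the default buffersize and PaddedBatch
# collation would typically be used instead.
_stream = ({"id": i, "wav_len": (i % 7) + 1} for i in range(100))
for _batch in dynamic_bucketed_batch(
    _stream,
    len_key="wav_len",
    len_fn=None,  # the stored value is already the length
    buffersize=16,
    collate_fn=list,
    sampler_kwargs={"target_batch_numel": 32},
):
    pass  # each _batch is a list of samples with similar "wav_len"
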