Source code for speechbrain.augment.augmenter

"""Classes for implementing data augmentation pipelines.

Authors
 * Mirco Ravanelli 2022
"""
import random
import logging
import torch
import torch.nn.functional as F
from speechbrain.utils.callchains import lengths_arg_exists

logger = logging.getLogger(__name__)


[docs] class Augmenter(torch.nn.Module): """Applies pipelines of data augmentation. Arguments --------- parallel_augment: bool If False, the augmentations are applied sequentially with the order specified in the pipeline argument. When True, all the N augmentations are concatenated in the output on the batch axis. parallel_augment_fixed_bs: bool If False, each augmenter (performed in parallel) generates a number of augmented examples equal to the batch size. Thus, overall, with this option N*batch size artificial data are generated, where N is the number of augmenters. When True, the number of total augmented examples is kept fixed at the batch size, thus, for each augmenter, fixed at batch size // N examples. This option is useful to keep controlled the number of synthetic examples with respect to the original data distribution, as it keep always 50% of original data, and 50% of augmented data. concat_original: bool if True, the original input is concatenated with the augmented outputs (on the batch axis). min_augmentations: int The number of augmentations applied to the input signal is randomly sampled between min_augmentations and max_augmentations. For instance, if the augmentation dict contains N=6 augmentations and we set select min_augmentations=1 and max_augmentations=4 we apply up to M=4 augmentations. The selected augmentations are applied in the order specified in the augmentations dict. If shuffle_augmentations = True, a random set of M augmentations is selected. max_augmentations: int Maximum number of augmentations to apply. See min_augmentations for more details. shuffle_augmentations: bool If True, it shuffles the entries of the augmentations dictionary. The effect is to randomply select the order of the augmentations to apply. repeat_augment: int Applies the augmentation algorithm N times. This can be used to perform more data augmentation. augment_start_index: int The index of the first element in the input batch from which data augmentation should begin. This argument allows you to specify the starting point for applying data augmentation. augment_end_index: int The index of the last element in the input batch at which data augmentation should stop. You can use this argument to define the endpoint for applying data augmentation within the batch. concat_start_index: int If `concat_original` is set to True, you can specify a subpart of the original batch to concatenate in the output. Use this argument to select the index of the first element from the original input batch to start copying from. concat_end_index: int If `concat_original` is set to True, you can specify a subpart of the original batch to concatenate in the output. Use this argument to select the index of the last element from the original input batch to end the copying process. augment_prob: float The probability (0.0 to 1.0) of applying data augmentation. When set to 0.0, the original signal is returned without any augmentation. When set to 1.0, augmentation is always applied. Values in between determine the likelihood of augmentation. augmentations: list List of augmentater objects to combine to perform data augmentation. enable_augmentations: list A list of booleans used to selectively enable or disable specific augmentation techniques within the 'augmentations' list. Each boolean corresponds to an augmentation object in the 'augmentations' list and should be of the same length and order. This feature is useful for performing ablations on augmentation techniques to tailor them for a specific task. Example ------- >>> from speechbrain.augment.time_domain import DropFreq, DropChunk >>> freq_dropper = DropFreq() >>> chunk_dropper = DropChunk(drop_start=100, drop_end=16000) >>> augment = Augmenter(parallel_augment=False, concat_original=False, augmentations=[freq_dropper, chunk_dropper]) >>> signal = torch.rand([4, 16000]) >>> output_signal, lenghts = augment(signal, lengths=torch.tensor([0.2,0.5,0.7,1.0])) """ def __init__( self, parallel_augment=False, parallel_augment_fixed_bs=False, concat_original=False, min_augmentations=None, max_augmentations=None, shuffle_augmentations=False, repeat_augment=1, augment_start_index=0, augment_end_index=None, concat_start_index=0, concat_end_index=None, augment_prob=1.0, augmentations=list(), enable_augmentations=None, ): super().__init__() self.parallel_augment = parallel_augment self.parallel_augment_fixed_bs = parallel_augment_fixed_bs self.concat_original = concat_original self.augmentations = augmentations self.min_augmentations = min_augmentations self.max_augmentations = max_augmentations self.shuffle_augmentations = shuffle_augmentations self.augment_start_index = augment_start_index self.augment_end_index = augment_end_index self.concat_start_index = concat_start_index self.concat_end_index = concat_end_index self.repeat_augment = repeat_augment self.augment_prob = augment_prob # Check min and max augmentations self.check_min_max_augmentations() # This variable represents the total number of augmentations to perform for each signal, # including the original signal in the count. self.num_augmentations = None self.do_augment = True # Check repeat augment arguments if not isinstance(self.repeat_augment, int): raise ValueError("repeat_augment must be an integer.") if self.repeat_augment < 0: raise ValueError("repeat_augment must be greater than 0.") if self.augment_end_index is not None: if self.augment_end_index < self.augment_start_index: raise ValueError( "augment_end_index must be smaller or equal to augment_start_index." ) if self.concat_end_index is not None: if self.concat_end_index < self.concat_start_index: raise ValueError( "concat_end_index must be smaller or equal to concat_start_index." ) # Managing enable augmentations if enable_augmentations is None: enable_augmentations = [True] * len(augmentations) elif not isinstance(enable_augmentations, list): raise ValueError("enable_augmentations must be a list.") elif len(enable_augmentations) != len(augmentations): raise ValueError( "enable_augmentations must have the same length as augmentations." ) else: augmentations = [ aug for aug, enabled in zip(augmentations, enable_augmentations) if enabled ] # Turn augmentations into a dictionary self.augmentations = { augmentation.__class__.__name__ + str(i): augmentation for i, augmentation in enumerate(augmentations) } if len(self.augmentations) == 0: logger.warning( "No augmentation is applied because the augmentation list is empty." ) # Check min and max augmentations if self.max_augmentations <= 0: logger.warning( "No augmentations applied because max_augmentations is non-positive." ) if self.min_augmentations < 0: self.min_augmentations = 0 logger.warning( "min_augmentations is negative. Modified to be non-negative." ) if self.min_augmentations > self.max_augmentations: logger.warning( "min_augmentations is greater than max_augmentations. min_augmentations set to max_augmentations." ) self.max_augmentations = self.min_augmentations # Check if augmentation modules need the length argument self.require_lengths = {} for aug_key, aug_fun in self.augmentations.items(): self.require_lengths[aug_key] = lengths_arg_exists(aug_fun.forward)
[docs] def augment(self, x, lengths, selected_augmentations): """Applies data augmentation on the seleted augmentations. Arguments --------- x : torch.Tensor (batch, time, channel) input to augment. lengths : torch.Tensor The length of each sequence in the batch. selected_augmentations: dict Dictionary containg the selected augmentation to apply. """ next_input = x next_lengths = lengths output = [] output_lengths = [] out_lengths = lengths for k, augment_name in enumerate(selected_augmentations): augment_fun = self.augmentations[augment_name] idx = torch.arange(x.shape[0]) if self.parallel_augment and self.parallel_augment_fixed_bs: idx_startstop = torch.linspace( 0, x.shape[0], len(selected_augmentations) + 1 ).to(torch.int) idx_start = idx_startstop[k] idx_stop = idx_startstop[k + 1] idx = idx[idx_start:idx_stop] # Check input arguments if self.require_lengths[augment_name]: out = augment_fun( next_input[idx, ...], lengths=next_lengths[idx] ) else: out = augment_fun(next_input[idx, ...]) # Check output arguments if isinstance(out, tuple): if len(out) == 2: out, out_lengths = out else: raise ValueError( "The function must return max two arguments (Tensor, Length[optional])" ) # Manage sequential or parallel augmentation if not self.parallel_augment: next_input = out next_lengths = out_lengths[idx] else: output.append(out) output_lengths.append(out_lengths) if self.parallel_augment: # Concatenate all the augmented data output, output_lengths = self.concatenate_outputs( output, output_lengths ) else: # Take the last agumented signal of the pipeline output = out output_lengths = out_lengths return output, output_lengths
[docs] def forward(self, x, lengths): """Applies data augmentation. Arguments --------- x : torch.Tensor (batch, time, channel) input to augment. lengths : torch.Tensor The length of each sequence in the batch. """ # Determine whether to apply data augmentation self.do_augment = True if random.random() > self.augment_prob: self.do_augment = False return x, lengths x_original = x len_original = lengths # Determine the ending index for augmentation, considering user-specified or default values. self.augment_end_index_batch = ( min(self.augment_end_index, x.shape[0]) if self.augment_end_index is not None else x.shape[0] ) # If the augmentation starting index is beyond the size of the data, return the original data. if self.augment_start_index >= x.shape[0]: self.do_augment = False logger.warning( "No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch." ) return x, lengths # Select the number of augmentations to apply self.N_augment = torch.randint( low=self.min_augmentations, high=self.max_augmentations + 1, size=(1,), device=x.device, ) # Get augmentations list augmentations_lst = list(self.augmentations.keys()) # No augmentation if ( self.repeat_augment == 0 or self.N_augment == 0 or len(augmentations_lst) == 0 ): self.do_augment = False return x, lengths # Shuffle augmentation if self.shuffle_augmentations: random.shuffle(augmentations_lst) # Select the augmentations to apply selected_augmentations = augmentations_lst[0 : self.N_augment] # Select the portion of the input to augment and update lengths accordingly. x = x[self.augment_start_index : self.augment_end_index_batch] lengths = lengths[ self.augment_start_index : self.augment_end_index_batch ] # Lists to collect the outputs output_lst = [] output_len_lst = [] # Concatenate the original signal if required self.skip_concat = not (self.concat_original) if self.concat_original: # Check start index if self.concat_start_index >= x.shape[0]: self.skip_concat = True pass else: self.skip_concat = False # Determine the ending index for concatenation, considering user-specified or default values. self.concat_end_index_batch = ( min(self.concat_end_index, x_original.shape[0]) if self.concat_end_index is not None else x_original.shape[0] ) output_lst.append( x_original[ self.concat_start_index : self.concat_end_index_batch ] ) output_len_lst.append( len_original[ self.concat_start_index : self.concat_end_index_batch ] ) # Perform augmentations for i in range(self.repeat_augment): output, output_lengths = self.augment( x, lengths, selected_augmentations ) output_lst.append(output) output_len_lst.append(output_lengths) # Concatenate the final outputs while handling scenarios where # different temporal dimensions may arise due to augmentations # like speed change. output, output_lengths = self.concatenate_outputs( output_lst, output_len_lst ) return output, output_lengths
[docs] def concatenate_outputs(self, augment_lst, augment_len_lst): """ Concatenate a list of augmented signals, accounting for varying temporal lengths. Padding is applied to ensure all signals can be concatenated. Arguments --------- augmentations : List of torch.Tensor List of augmented signals to be concatenated. augmentation_lengths : List of torch.Tensor List of lengths corresponding to the augmented signals. Returns ------- concatenated_signals : torch.Tensor A tensor containing the concatenated signals. concatenated_lengths : torch.Tensor A tensor containing the concatenated signal lengths. Notes ----- This function takes a list of augmented signals, which may have different temporal lengths due to variations such as speed changes. It pads the signals to match the maximum temporal dimension found among the input signals and rescales the lengths accordingly before concatenating them. """ # Find the maximum temporal dimension (batch length) among the sequences max_len = max(augment.shape[1] for augment in augment_lst) # Rescale the sequence lengths to adjust for augmented batches with different temporal dimensions. augment_len_lst = [ length * (output.shape[1] / max_len) for length, output in zip(augment_len_lst, augment_lst) ] # Pad sequences to match the maximum temporal dimension. # Note that some augmented batches, like those with speed changes, may have different temporal dimensions. augment_lst = [ F.pad(output, (0, max_len - output.shape[1])) for output in augment_lst ] # Concatenate the padded sequences and rescaled lengths output = torch.cat(augment_lst, dim=0) output_lengths = torch.cat(augment_len_lst, dim=0) return output, output_lengths
[docs] def replicate_multiple_labels(self, *args): """ Replicates the labels along the batch axis a number of times that corresponds to the number of augmentations. Indeed parallel and concatenation augmentations alter the time dimension. Arguments --------- args : torch.Tensor Input label tensors to be replicated. Can be a uniq or a list of Tensors. Returns ------- augmented_labels: torch.Tensor Labels corresponding to the augmented input. Returns as many Tensor as given in input. """ # Determine whether to apply data augmentation if not self.do_augment: return args list_of_augmented_labels = [] for labels in args: list_of_augmented_labels.append(self.replicate_labels(labels)) return list_of_augmented_labels
[docs] def replicate_labels(self, labels): """ Replicates the labels along the batch axis a number of times that corresponds to the number of augmentations. Indeed parallel and concatenation augmentations alter the time dimension. Arguments --------- labels : torch.Tensor Input label tensors to be replicated. Returns ------- augmented_labels: torch.Tensor Labels corresponding to the augmented input. Returns as many Tensor as given in input. """ # Determine whether to apply data augmentation if not self.do_augment: return labels augmented_labels = [] if self.concat_original and not (self.skip_concat): augmented_labels = [ labels[self.concat_start_index : self.concat_end_index_batch] ] selected_labels = labels[ self.augment_start_index : self.augment_end_index_batch ] if self.parallel_augment: selected_labels = torch.cat( [selected_labels] * self.N_augment, dim=0 ) augmented_labels = ( augmented_labels + [selected_labels] * self.repeat_augment ) augmented_labels = torch.cat(augmented_labels, dim=0) return augmented_labels
[docs] def check_min_max_augmentations(self): """Checks the min_augmentations and max_augmentations arguments. """ if self.min_augmentations is None: self.min_augmentations = 1 if self.max_augmentations is None: self.max_augmentations = len(self.augmentations) if self.max_augmentations > len(self.augmentations): self.max_augmentations = len(self.augmentations) if self.min_augmentations > len(self.augmentations): self.min_augmentations = len(self.augmentations)