"""Frequency-Domain Sequential Data Augmentation Classes
This module comprises classes tailored for augmenting sequential data in the
frequency domain, such as spectrograms and mel spectrograms.
Its primary purpose is to enhance the resilience of neural models during the training process.
Authors:
- Peter Plantinga (2020)
- Mirco Ravanelli (2023)
"""
import random
import torch
class SpectrogramDrop(torch.nn.Module):
"""This class drops slices of the input spectrogram.
Using `SpectrogramDrop` as an augmentation strategy helps a models learn to rely
on all parts of the signal, since it can't expect a given part to be
present.
Reference:
https://arxiv.org/abs/1904.08779
    Arguments
    ---------
    drop_length_low : int
        The low end of chunk lengths to drop, in time frames (or frequency
        bins, depending on `dim`).
    drop_length_high : int
        The high end of chunk lengths to drop, in time frames (or frequency
        bins, depending on `dim`).
    drop_count_low : int
        The low end of the number of chunks to drop.
    drop_count_high : int
        The high end of the number of chunks to drop.
    replace : str
        - 'zeros': Masked values are replaced with zeros.
        - 'mean': Masked values are replaced with the mean value of the spectrogram.
        - 'rand': Masked values are replaced with random numbers ranging between
          the maximum and minimum values of the spectrogram.
        - 'cutcat': Masked values are replaced with chunks from other signals in the batch.
        - 'swap': Masked values are replaced with other chunks from the same sentence.
        - 'random_selection': A random selection among the approaches above.
    dim : int
        The dimension to mask. If dim=1, time masking is applied; if dim=2,
        frequency masking is applied.
Example
-------
>>> # time-masking
>>> drop = SpectrogramDrop(dim=1)
>>> spectrogram = torch.rand(4, 150, 40)
>>> print(spectrogram.shape)
torch.Size([4, 150, 40])
>>> out = drop(spectrogram)
>>> print(out.shape)
torch.Size([4, 150, 40])
>>> # frequency-masking
>>> drop = SpectrogramDrop(dim=2)
>>> spectrogram = torch.rand(4, 150, 40)
>>> print(spectrogram.shape)
torch.Size([4, 150, 40])
>>> out = drop(spectrogram)
>>> print(out.shape)
torch.Size([4, 150, 40])
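    >>> # Illustrative: any of the `replace` strategies can be selected; here
    >>> # the masked values are filled with the spectrogram mean
    >>> drop = SpectrogramDrop(dim=1, replace="mean")
    >>> out = drop(spectrogram)
    >>> print(out.shape)
    torch.Size([4, 150, 40])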
"""
def __init__(
self,
drop_length_low=5,
drop_length_high=15,
drop_count_low=1,
drop_count_high=3,
replace="zeros",
dim=1,
):
super().__init__()
self.drop_length_low = drop_length_low
self.drop_length_high = drop_length_high
self.drop_count_low = drop_count_low
self.drop_count_high = drop_count_high
self.replace = replace
self.dim = dim
        # Validate low <= high
        if drop_length_low > drop_length_high:
            raise ValueError(
                "drop_length_low must not be greater than drop_length_high"
            )
        if drop_count_low > drop_count_high:
            raise ValueError(
                "drop_count_low must not be greater than drop_count_high"
            )
self.replace_opts = [
"zeros",
"mean",
"rand",
"cutcat",
"swap",
"random_selection",
]
if self.replace not in self.replace_opts:
raise ValueError(
f"Invalid 'replace' option. Select one of {', '.join(self.replace_opts)}"
)
def forward(self, spectrogram):
"""
Apply the DropChunk augmentation to the input spectrogram.
This method randomly drops chunks of the input spectrogram to augment the data.
Arguments
---------
spectrogram : torch.Tensor
Input spectrogram of shape `[batch, time, fea]`.
Returns
-------
torch.Tensor
Augmented spectrogram of shape `[batch, time, fea]`.
"""
        # Manage 4D tensors by folding the extra dimension into the batch.
        # Keep the original shape so it can be restored before returning.
        original_shape = spectrogram.shape
        if spectrogram.dim() == 4:
            spectrogram = spectrogram.view(
                -1, spectrogram.shape[2], spectrogram.shape[3]
            )
        # Get the batch, time, and feature dimensions
        batch_size, time_duration, fea_size = spectrogram.shape

        # Select the size of the dimension to mask (time or frequency)
        if self.dim == 1:
            D = time_duration
        else:
            D = fea_size
# Randomly select the number of chunks to drop (same for all samples in the batch)
n_masks = torch.randint(
low=self.drop_count_low,
high=self.drop_count_high + 1,
size=(1,),
device=spectrogram.device,
)
        # If the number of chunks to drop is 0, return the spectrogram unchanged
        if n_masks == 0:
            return spectrogram.view(*original_shape)
# Randomly sample the lengths of the chunks to drop
mask_len = torch.randint(
low=self.drop_length_low,
high=self.drop_length_high,
size=(batch_size, n_masks),
device=spectrogram.device,
).unsqueeze(2)
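        # `mask_len` has shape [batch, n_masks, 1]: each example in the batch
        # gets its own chunk lengths, while the number of chunks is shared.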
# Randomly sample the positions of the chunks to drop
mask_pos = torch.randint(
0,
            max(1, D - mask_len.max()),
(batch_size, n_masks),
device=spectrogram.device,
).unsqueeze(2)
# Compute the mask for the selected chunk positions
arange = torch.arange(D, device=spectrogram.device).view(1, 1, -1)
mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
mask = mask.any(dim=1)
mask = mask.unsqueeze(2) if self.dim == 1 else mask.unsqueeze(1)
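        # `mask` now has shape [batch, time, 1] for time masking or
        # [batch, 1, fea] for frequency masking, so it broadcasts across the
        # non-masked dimension of the spectrogram.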
        # Determine how the masked chunks are filled in. Use a local variable
        # so that "random_selection" draws a new strategy on every call instead
        # of permanently overwriting the configured option.
        replace = self.replace
        if replace == "random_selection":
            replace = random.choice(self.replace_opts[:-1])
        if replace == "zeros":
            spectrogram = spectrogram.masked_fill_(mask, 0.0)
        elif replace == "mean":
            mean = spectrogram.mean().detach()
            spectrogram = spectrogram.masked_fill_(mask, mean)
        elif replace == "rand":
            max_spectrogram = spectrogram.max().detach()
            min_spectrogram = spectrogram.min().detach()
            rand_spectrogram = torch.rand_like(spectrogram)
            rand_spectrogram = (
                rand_spectrogram * (max_spectrogram - min_spectrogram)
                + min_spectrogram
            )
            mask = mask.float()
            spectrogram = (1 - mask) * spectrogram + mask * rand_spectrogram
        elif replace == "cutcat":
            # Fill masked chunks with the corresponding chunks of the previous
            # example in the batch (batch rolled by one position).
            rolled_spectrogram = torch.roll(spectrogram, shifts=1, dims=0)
            mask = mask.float()
            spectrogram = (1 - mask) * spectrogram + mask * rolled_spectrogram
        elif replace == "swap":
            # Fill masked chunks with chunks of the same example taken from a
            # randomly shifted time position.
            shift = torch.randint(
                low=1,
                high=spectrogram.shape[1],
                size=(1,),
                device=spectrogram.device,
            )
            rolled_spectrogram = torch.roll(
                spectrogram, shifts=shift.item(), dims=1
            )
            mask = mask.float()
            spectrogram = (1 - mask) * spectrogram + mask * rolled_spectrogram
        return spectrogram.view(*original_shape)
class Warping(torch.nn.Module):
"""
Apply time or frequency warping to a spectrogram.
If `dim=1`, time warping is applied; if `dim=2`, frequency warping is applied.
This implementation selects a center and a window length to perform warping.
It ensures that the temporal dimension remains unchanged by upsampling or
downsampling the affected regions accordingly.
Reference:
https://arxiv.org/abs/1904.08779
Arguments
---------
warp_window : int, optional
The width of the warping window. Default is 5.
    warp_mode : str, optional
        The interpolation mode used when resizing the warped segments.
        Default is "bicubic".
dim : int, optional
Dimension along which to apply warping (1 for time, 2 for frequency).
Default is 1.
Example
-------
>>> # Time-warping
>>> warp = Warping()
>>> spectrogram = torch.rand(4, 150, 40)
>>> print(spectrogram.shape)
torch.Size([4, 150, 40])
>>> out = warp(spectrogram)
>>> print(out.shape)
torch.Size([4, 150, 40])
>>> # Frequency-warping
>>> warp = Warping(dim=2)
>>> spectrogram = torch.rand(4, 150, 40)
>>> print(spectrogram.shape)
torch.Size([4, 150, 40])
>>> out = warp(spectrogram)
>>> print(out.shape)
torch.Size([4, 150, 40])
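    >>> # Illustrative: a larger `warp_window` allows stronger warping while
    >>> # the output shape is still preserved
    >>> warp = Warping(warp_window=10)
    >>> out = warp(spectrogram)
    >>> print(out.shape)
    torch.Size([4, 150, 40])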
"""
def __init__(self, warp_window=5, warp_mode="bicubic", dim=1):
super().__init__()
self.warp_window = warp_window
self.warp_mode = warp_mode
self.dim = dim
def forward(self, spectrogram):
"""
Apply warping to the input spectrogram.
Arguments
---------
spectrogram : torch.Tensor
Input spectrogram with shape `[batch, time, fea]`.
Returns
-------
torch.Tensor
Augmented spectrogram with shape `[batch, time, fea]`.
"""
# Set warping dimension
if self.dim == 2:
spectrogram = spectrogram.transpose(1, 2)
original_size = spectrogram.shape
window = self.warp_window
# 2d interpolation requires 4D or higher dimension tensors
# x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
if spectrogram.dim() == 3:
spectrogram = spectrogram.unsqueeze(1)
        len_original = spectrogram.shape[2]
        if len_original - window <= window:
            # Too short to warp: return unchanged (undo the transpose if needed).
            spectrogram = spectrogram.view(*original_size)
            return spectrogram.transpose(1, 2) if self.dim == 2 else spectrogram
# Compute center and corresponding window
c = torch.randint(window, len_original - window, (1,))[0]
w = torch.randint(c - window, c + window, (1,))[0] + 1
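        # `c` is the randomly selected warp center and `w` is the new length of
        # the left segment after warping; the right segment is resized to
        # `len_original - w` so that the total length stays the same.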
# Update the left part of the spectrogram
left = torch.nn.functional.interpolate(
spectrogram[:, :, :c],
(w, spectrogram.shape[3]),
mode=self.warp_mode,
align_corners=True,
)
# Update the right part of the spectrogram.
# When the left part is expanded, the right part is compressed by the
# same factor, and vice versa.
right = torch.nn.functional.interpolate(
spectrogram[:, :, c:],
(len_original - w, spectrogram.shape[3]),
mode=self.warp_mode,
align_corners=True,
)
# Injecting the warped left and right parts.
spectrogram[:, :, :w] = left
spectrogram[:, :, w:] = right
spectrogram = spectrogram.view(*original_size)
# Transpose if freq warping is applied.
if self.dim == 2:
spectrogram = spectrogram.transpose(1, 2)
return spectrogram
class RandomShift(torch.nn.Module):
"""Shifts the input tensor by a random amount, allowing for either a time
or frequency (or channel) shift depending on the specified axis.
It is crucial to calibrate the minimum and maximum shifts according to the
requirements of your specific task.
We recommend using small shifts to preserve information integrity.
Using large shifts may result in the loss of significant data and could
potentially lead to misalignments with corresponding labels.
Arguments
---------
min_shift : int
The minimum channel shift.
max_shift : int
The maximum channel shift.
dim: int
The dimension to shift.
    Example
    -------
    >>> # time shift
    >>> signal = torch.zeros(4, 100, 80)
    >>> signal[0, 50, :] = 1
    >>> rand_shift = RandomShift(dim=1, min_shift=-10, max_shift=10)
    >>> lengths = torch.tensor([0.2, 0.8, 0.9, 1.0])
    >>> output_signal, lengths = rand_shift(signal, lengths)
    >>> # frequency shift
    >>> signal = torch.zeros(4, 100, 80)
    >>> signal[0, :, 40] = 1
    >>> rand_shift = RandomShift(dim=2, min_shift=-10, max_shift=10)
    >>> lengths = torch.tensor([0.2, 0.8, 0.9, 1.0])
    >>> output_signal, lengths = rand_shift(signal, lengths)
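    >>> # The output keeps the input shape; only the content is rolled along `dim`
    >>> print(output_signal.shape)
    torch.Size([4, 100, 80])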
"""
def __init__(self, min_shift=0, max_shift=0, dim=1):
super().__init__()
self.min_shift = min_shift
self.max_shift = max_shift
self.dim = dim
# Check arguments
if self.max_shift < self.min_shift:
raise ValueError("max_shift must be >= min_shift")
def forward(self, waveforms, lengths):
"""
Arguments
---------
waveforms : tensor
Shape should be `[batch, time]` or `[batch, time, channels]`.
lengths : tensor
Shape should be a single dimension, `[batch]`.
Returns
-------
Tensor of shape `[batch, time]` or `[batch, time, channels]`
"""
        # Pick a random shift amount along the selected dimension
N_shifts = torch.randint(
low=self.min_shift,
high=self.max_shift + 1,
size=(1,),
device=waveforms.device,
)
waveforms = torch.roll(waveforms, shifts=N_shifts.item(), dims=self.dim)
# Update lengths in the case of temporal shift.
if self.dim == 1:
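            # `lengths` stores relative lengths in [0, 1], so convert the shift
            # to a fraction of the time dimension before adding it.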
lengths = lengths + N_shifts / waveforms.shape[self.dim]
lengths = torch.clamp(lengths, min=0.0, max=1.0)
return waveforms, lengths
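

# The snippet below is a minimal, illustrative sketch (not part of the library
# API) showing how the augmentations defined above can be chained on a batch of
# random "spectrograms". Shapes and hyperparameters are arbitrary example values.
if __name__ == "__main__":
    spectrogram = torch.rand(8, 200, 80)  # [batch, time, fea]
    lengths = torch.ones(8)  # relative lengths, one per example

    # SpecAugment-style time and frequency masking
    time_drop = SpectrogramDrop(dim=1, replace="zeros")
    freq_drop = SpectrogramDrop(dim=2, replace="mean")

    # Time warping and a small random time shift
    warp = Warping(warp_window=5, dim=1)
    shift = RandomShift(min_shift=-5, max_shift=5, dim=1)

    augmented = freq_drop(time_drop(warp(spectrogram)))
    augmented, lengths = shift(augmented, lengths)
    print(augmented.shape)  # torch.Size([8, 200, 80])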