"""Few components to support BEST RQ training as described in the
original paper: https://arxiv.org/pdf/2202.01855.
Authors
* Ryan Whetten 2024
* Titouan Parcollet 2025
"""
import random
import torch
from speechbrain.utils.data_utils import batch_pad_right
def compute_mask(shape, sample_lens, mask_prob, mask_length):
"""This function generates the masks of BEST-RQ.
It generates a unique mask for the whole batch and based on the shorter utte
rance. This is important as it may alter the training if the batch contains
one small sentence and many large ones as only few frames will be masked.
In particular, out of the smaller length passed to sample_lens, we will
generate N masks with N = mask_prob * smallest_len. Hence, mask_prob is
the probability for a frame to start a mask, and not to be masked.
If a sentence length is 100 time steps, a mask_prob of 0.15 and a mask size
of 4 would results in 100*0.15*4=60% of the frames being masked.
Arguments
---------
shape: tuple
The shape of the input tensor to be masked. Usually (Batch, Time, Fea).
sample_lens: list
List of ints giving the number of frames of each sample in the
batch, e.g. [12, 13, 14, 20].
mask_prob: float
Probability for a frame to spawn a mask. Frames already masked cannot
spawn new masks.
mask_length: int
Number of frames covered by a mask.
Returns
-------
mask : torch.Tensor, shape (N_masked,)
A 1-D tensor with the time indices of the frames to mask.
Example
-------
>>> compute_mask((2,50,60), [40, 50], 0.15, 2).shape
torch.Size([12])
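
A sketch of how the returned indices might be consumed downstream,
assuming the masked frames are overwritten along the time axis (BEST-RQ
replaces them with random noise):

>>> feats = torch.randn(2, 50, 60)
>>> mask = compute_mask((2, 50, 60), [40, 50], 0.15, 2)
>>> feats[:, mask, :] = torch.randn(2, mask.shape[0], 60)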
"""
min_sample_len = min(sample_lens)
# int() floors the float, so adding random.random() performs stochastic
# rounding: we round up with probability equal to the fractional part.
num_mask = int(mask_prob * min_sample_len + random.random())
# make sure there is at least 1 mask
if num_mask == 0:
num_mask = 1
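# Pick num_mask non-overlapping start positions aligned to multiples of
# mask_length, then expand each start into mask_length consecutive indices.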
permutation = torch.randperm(min_sample_len // mask_length) * mask_length
selected_indices = permutation[:num_mask]
selected_indices, _ = selected_indices.sort()
idx = []
for i in selected_indices:
idx.append(torch.arange(start=i, end=i + mask_length))
idx = torch.cat(idx)
return idx
def brq_mask_collate_fn(
samples_lst, get_out_len_fn, mask_prob, mask_length, n_mels
):
"""This creates a batch from a list of samples and also creates
the mask that will be used to mask the inputs of BEST-RQ.
To create the mask we need to know the output shape after the
latent extractor, therefore the argument `get_out_len_fn`.
One could also create masks per sample (when loading the audio file) and
then collate them but at that time one doesn't know the length of the
shortest sample in the batch (which determines the number of masked frames)
so it's better this way.
Arguments
---------
samples_lst : list
List of samples returned by the audio_pipeline.
get_out_len_fn : function
Function that computes the length of a sample after it passes through the feature extractor.
mask_prob : float
Probability for a frame to spawn a mask. Frames already masked cannot
spawn new masks.
mask_length : int
Number of contiguous frames that will be masked.
n_mels : int
Number of Mels filterbanks in the last dimension of the input tensor.
Returns
-------
wavs_padded : torch.Tensor, shape (B, T)
Audio arrays with right-sided padding.
wav_lens : torch.Tensor, shape (B,)
For each sample the percentage of the array that is not padding.
mask : torch.Tensor, shape (N_masked,)
Indices of the time frames to be masked in the input features.
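
Example
-------
A minimal sketch; `out_len` below is a hypothetical stand-in for the
model's `get_out_len_fn` (here assuming one output frame per 320 samples):

>>> samples = [
...     {"id": "utt1", "sig": torch.randn(16000)},
...     {"id": "utt2", "sig": torch.randn(12000)},
... ]
>>> def out_len(x):
...     return x // 320
>>> wavs, wav_lens, mask = brq_mask_collate_fn(samples, out_len, 0.15, 4, 80)
>>> wavs.shape
torch.Size([2, 16000])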
"""
wav_lst, latent_length_lst = [], []
ids = []
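# Gather the waveforms and their lengths (in frames) after the latent extractor.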
for sample in samples_lst:
ids.append(sample["id"])
sig = sample["sig"]
wav_lst.append(sig)
latent_length = get_out_len_fn(torch.as_tensor(sig.size(-1)))
latent_length_lst.append(latent_length.item())
bs = len(wav_lst)
wavs_padded, wav_lens = batch_pad_right(wav_lst)
batch_time_len = max(latent_length_lst)
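# A single mask at the feature-frame resolution is shared by the whole batch;
# it is sized from the shortest latent length (see compute_mask).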
mask = compute_mask(
(bs, batch_time_len, n_mels), latent_length_lst, mask_prob, mask_length
)
return (
torch.as_tensor(wavs_padded),
torch.as_tensor(wav_lens),
torch.as_tensor(mask),
)