speechbrain.lobes.models.BESTRQ module

A few components to support BEST-RQ training as described in the original paper: https://arxiv.org/pdf/2202.01855.

Authors

  • Ryan Whetten 2024

  • Titouan Parcollet 2025

Summary

Functions:

brq_mask_collate_fn

This creates a batch from a list of samples and also creates the mask that will be used to mask the inputs of BEST-RQ.

compute_mask

This function generates the masks of BEST-RQ.

Reference

speechbrain.lobes.models.BESTRQ.compute_mask(shape, sample_lens, mask_prob, mask_length)[source]

This function generates the masks of BEST-RQ.

It generates a unique mask for the whole batch, based on the shortest utterance. This is important because, if the batch contains one short sentence and many long ones, only a few frames would otherwise be masked, which may alter training.

In particular, from the smallest length passed in sample_lens, we generate N masks with N = mask_prob * smallest_len. Hence, mask_prob is the probability for a frame to start a mask, not the probability for a frame to be masked.

If a sentence is 100 time steps long, a mask_prob of 0.15 and a mask_length of 4 would result in 100*0.15*4 = 60 frames, i.e. 60% of the frames, being masked.

Parameters:
  • shape (tuple) – The shape of the input tensor to be masked. Usually (Batch, Time, Fea).

  • sample_lens (list) – List of int corresponding to the number of frames of each sample in the batch. E.g. (12,13,14,20)

  • mask_prob (float) – Probability for a frame to spawn a mask. Frames already masked cannot spawn new masks.

  • mask_length (int) – Number of frames covered by a mask.

Returns:

The computed mask, i.e. the indices of the frames to be masked.

Return type:

torch.Tensor

Example

>>> compute_mask((2,50,60), [40, 50], 0.15, 2).shape
torch.Size([12])
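
The shape in this example follows from the formula above: the shortest sample has 40 frames, so about int(0.15 * 40) = 6 masks are drawn, each covering mask_length = 2 frames, i.e. 12 masked indices. A quick, purely illustrative check of that arithmetic (assuming the mask count is rounded down to an integer):

>>> mask_prob, mask_length = 0.15, 2
>>> smallest_len = min([40, 50])
>>> int(mask_prob * smallest_len) * mask_length
12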
speechbrain.lobes.models.BESTRQ.brq_mask_collate_fn(samples_lst, get_out_len_fn, mask_prob, mask_length, n_mels)[source]

This creates a batch from a list of samples and also creates the mask that will be used to mask the inputs of BEST-RQ. To create the mask we need to know the output shape after the latent extractor, hence the get_out_len_fn argument. One could also create masks per sample (when loading the audio file) and then collate them, but at that point the length of the shortest sample in the batch (which determines the number of masked frames) is not yet known, so it is better to do it here.

Parameters:
  • samples_lst (list) – List of samples returned by the audio_pipeline.

  • get_out_len_fn (function) – Function that calculates the length of a sample after it passes through the feature extractor.

  • mask_prob (float) – Probability for a frame to spawn a mask. Frames already masked cannot spawn new masks.

  • mask_length (int) – Number of contiguous frames that will be masked.

  • n_mels (int) – Number of Mels filterbanks in the last dimension of the input tensor.

Returns:

  • wavs_padded (torch.Tensor, shape (B, T)) – Audio arrays with right-sided padding.

  • wav_lens (torch.Tensor, shape (B,)) – For each sample, the percentage of the array that is not padding.

  • mask (torch.Tensor, shape (T,)) – Indices of the frames to be masked in the input tensor.
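
A minimal usage sketch (not part of the original docstring), assuming each element of samples_lst is a dict with an "id" and a "sig" waveform tensor, as a typical SpeechBrain audio_pipeline would produce, and that the latent extractor downsamples the waveform by a hypothetical factor of 640 samples per frame; the exact pipeline, sample structure, and length function should be taken from the BEST-RQ recipe.

>>> from functools import partial
>>> import torch
>>> from torch.utils.data import DataLoader
>>> from speechbrain.lobes.models.BESTRQ import brq_mask_collate_fn
>>> # Hypothetical samples mimicking the output of an audio_pipeline.
>>> samples = [
...     {"id": "utt1", "sig": torch.randn(16000)},
...     {"id": "utt2", "sig": torch.randn(24000)},
... ]
>>> collate = partial(
...     brq_mask_collate_fn,
...     get_out_len_fn=lambda wav_len: wav_len // 640,  # hypothetical downsampling rule
...     mask_prob=0.15,
...     mask_length=4,
...     n_mels=80,
... )
>>> loader = DataLoader(samples, batch_size=2, collate_fn=collate)
>>> wavs_padded, wav_lens, mask = next(iter(loader))

The same 1-D mask is shared by every sample in the batch, so it can be applied directly to the padded feature tensor, e.g. by replacing the selected frames with random noise before computing the BEST-RQ loss.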