speechbrain.lobes.augment module

Combinations of processing algorithms to implement common augmentations.

Examples:
  • SpecAugment

  • Environmental corruption (noise, reverberation)

Authors
  • Peter Plantinga 2020

  • Jianyuan Zhong 2020

Summary

Classes:

  • EnvCorrupt – Environmental corruptions for speech signals: noise, reverb, babble.

  • SpecAugment – An implementation of the SpecAugment algorithm.

  • TimeDomainSpecAugment – A time-domain approximation of the SpecAugment algorithm.

Reference

class speechbrain.lobes.augment.SpecAugment(time_warp=True, time_warp_window=5, time_warp_mode='bicubic', freq_mask=True, freq_mask_width=(0, 20), n_freq_mask=2, time_mask=True, time_mask_width=(0, 100), n_time_mask=2, replace_with_zero=True)

Bases: torch.nn.modules.module.Module

An implementation of the SpecAugment algorithm.

Reference:

https://arxiv.org/abs/1904.08779

Parameters
  • time_warp (bool) – Whether to apply time warping.

  • time_warp_window (int) – Time warp window: the maximum number of frames by which the warp point can be shifted.

  • time_warp_mode (str) – Interpolation mode for time warping (default “bicubic”).

  • freq_mask (bool) – Whether to apply frequency masking.

  • freq_mask_width (int or tuple) – Range of frequency mask widths, as (min, max).

  • n_freq_mask (int) – Number of frequency masks to apply.

  • time_mask (bool) – Whether to apply time masking.

  • time_mask_width (int or tuple) – Range of time mask widths, as (min, max).

  • n_time_mask (int) – Number of time masks to apply.

  • replace_with_zero (bool) – If True, replace masked values with 0; otherwise, replace them with the mean of the input tensor.

Example

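>>> import torch
>>> from speechbrain.lobes.augment import SpecAugment
>>> aug = SpecAugment()
>>> a = torch.rand([8, 120, 80])
>>> a = aug(a)
>>> print(a.shape)
torch.Size([8, 120, 80])

forward(x)

Applies the enabled augmentations (time warping, frequency masking, time masking) to x and returns the result.

time_warp(x)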

Applies time warping using torch.nn.functional.interpolate.
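The warping step can be sketched as follows, assuming a (batch, time, freq) input like the one in the example above. `time_warp_sketch` is a hypothetical stand-alone function written for illustration; how the warp point c and its destination w are sampled is an assumption modeled on the parameter description, not a copy of the library code.

```python
import torch
import torch.nn.functional as F


def time_warp_sketch(x: torch.Tensor, window: int = 5, mode: str = "bicubic") -> torch.Tensor:
    """Hypothetical sketch: warp a (batch, time, freq) tensor along time, keeping its shape.

    A point c is drawn at random and moved to a nearby position w (at most
    `window` frames away); frames [0, c) are stretched or squeezed into [0, w)
    and frames [c, T) into [w, T).
    """
    batch, time, freq = x.shape
    if time - window <= window:
        # Too few frames to warp safely; return the input unchanged.
        return x
    x4 = x.unsqueeze(1)  # 2D interpolation expects (N, C, H, W)
    c = int(torch.randint(window, time - window, (1,)))
    w = int(torch.randint(c - window, c + window, (1,))) + 1
    left = F.interpolate(x4[:, :, :c], size=(w, freq), mode=mode, align_corners=True)
    right = F.interpolate(x4[:, :, c:], size=(time - w, freq), mode=mode, align_corners=True)
    # Re-join the two warped halves; total length is w + (time - w) = time.
    return torch.cat([left, right], dim=2).squeeze(1)
```

Interpolating each half separately keeps the output length equal to the input length, so downstream layers see the same shape.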

mask_along_axis(x, dim)

Mask along time or frequency axis.

Parameters
  • x (torch.Tensor) – Input tensor.

  • dim (int) – The dimension to mask, e.g., 1 for time or 2 for frequency in a (batch, time, freq) tensor.
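A minimal sketch of the masking step is shown below. The function `mask_along_axis_sketch` and its keyword arguments (mirroring freq_mask_width/time_mask_width, n_freq_mask/n_time_mask and replace_with_zero) are illustrative assumptions; it assumes a (batch, time, freq) input and independent mask positions for each example, which may differ from the library's exact sampling.

```python
import torch


def mask_along_axis_sketch(
    x: torch.Tensor,
    dim: int,
    width_range=(0, 20),
    n_mask: int = 2,
    replace_with_zero: bool = True,
) -> torch.Tensor:
    """Hypothetical sketch: apply n_mask random masks along dim (1 = time, 2 = freq)."""
    batch, axis_len = x.shape[0], x.shape[dim]
    # Widths drawn from [low, high); starts chosen so masks mostly fit inside the axis.
    widths = torch.randint(width_range[0], width_range[1], (batch, n_mask, 1))
    starts = torch.randint(0, max(1, axis_len - int(widths.max())), (batch, n_mask, 1))
    positions = torch.arange(axis_len).view(1, 1, -1)
    # True wherever any of the n_mask masks covers a position: shape (batch, axis_len).
    mask = ((positions >= starts) & (positions < starts + widths)).any(dim=1)
    # Reshape so the mask broadcasts across the axis that is not being masked.
    mask = mask.unsqueeze(2) if dim == 1 else mask.unsqueeze(1)
    fill = 0.0 if replace_with_zero else float(x.mean())
    return x.masked_fill(mask, fill)
```

For example, `mask_along_axis_sketch(feats, dim=2)` zeroes whole frequency bands across all frames, while `dim=1` zeroes whole frames across all frequencies.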

class speechbrain.lobes.augment.TimeDomainSpecAugment(perturb_prob=1.0, drop_freq_prob=1.0, drop_chunk_prob=1.0, speeds=[95, 100, 105], sample_rate=16000, drop_freq_count_low=0, drop_freq_count_high=3, drop_chunk_count_low=0, drop_chunk_count_high=5, drop_chunk_length_low=1000, drop_chunk_length_high=2000, drop_chunk_noise_factor=0)

Bases: torch.nn.modules.module.Module

A time-domain approximation of the SpecAugment algorithm.

This augmentation module implements three augmentations in the time domain:

  1. Drop chunks of the audio (zero amplitude or white noise)

  2. Drop frequency bands (with band-drop filters)

  3. Speed perturbation (via resampling to a slightly different rate)
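The first of these, chunk dropping, can be sketched in plain PyTorch. `drop_chunks_sketch` is hypothetical: its arguments loosely mirror the drop_chunk_* parameters, and treating the count and length bounds as inclusive, as well as using uniform white noise scaled by the mean absolute amplitude, are assumptions rather than the library's exact behavior.

```python
import torch


def drop_chunks_sketch(
    waveforms: torch.Tensor,
    count_range=(0, 5),
    length_range=(1000, 2000),
    noise_factor: float = 0.0,
) -> torch.Tensor:
    """Hypothetical sketch: replace random chunks of (batch, time) audio with silence or noise."""
    out = waveforms.clone()
    batch, time = out.shape
    for i in range(batch):
        # Number of chunks for this utterance (both bounds assumed inclusive).
        n_chunks = int(torch.randint(count_range[0], count_range[1] + 1, (1,)))
        for _ in range(n_chunks):
            length = int(torch.randint(length_range[0], length_range[1] + 1, (1,)))
            length = min(length, time)
            start = int(torch.randint(0, time - length + 1, (1,)))
            if noise_factor > 0:
                # Uniform white noise, scaled relative to the original mean absolute amplitude.
                scale = noise_factor * waveforms[i].abs().mean()
                out[i, start:start + length] = scale * (2 * torch.rand(length) - 1)
            else:
                out[i, start:start + length] = 0.0
    return out
```

With the default noise factor of 0 the dropped regions become exact silence, matching "zero amplitude" in the list above.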

Parameters
  • perturb_prob (float from 0 to 1) – The probability that a batch will have speed perturbation applied.

  • drop_freq_prob (float from 0 to 1) – The probability that a batch will have frequencies dropped.

  • drop_chunk_prob (float from 0 to 1) – The probability that a batch will have chunks dropped.

  • speeds (list of ints) – A set of different speeds to use to perturb each batch, each given as a percentage of the original speed (e.g., 95 means 95%). See speechbrain.processing.speech_augmentation.SpeedPerturb.

  • sample_rate (int) – Sampling rate of the input waveforms.

  • drop_freq_count_low (int) – Lowest number of frequencies that could be dropped.

  • drop_freq_count_high (int) – Highest number of frequencies that could be dropped.

  • drop_chunk_count_low (int) – Lowest number of chunks that could be dropped.

  • drop_chunk_count_high (int) – Highest number of chunks that could be dropped.

  • drop_chunk_length_low (int) – Lowest length, in samples, of chunks that could be dropped.

  • drop_chunk_length_high (int) – Highest length, in samples, of chunks that could be dropped.

  • drop_chunk_noise_factor (float) – The noise factor used to scale the white noise inserted, relative to the average amplitude of the utterance. Default 0 (no noise inserted).
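Speed perturbation changes each waveform's length in proportion to the chosen speed, which is why speeds=[80] turns 16000 samples into 12800 in the example below. The sketch illustrates only that length arithmetic: `speed_perturb_sketch` is a hypothetical name, and it swaps in simple linear interpolation for the band-limited resampling a real implementation would use.

```python
import torch
import torch.nn.functional as F


def speed_perturb_sketch(waveforms: torch.Tensor, speed: int) -> torch.Tensor:
    """Hypothetical sketch: resample (batch, time) audio to `speed` percent of its length.

    Linear interpolation is used as a stand-in for proper resampling.
    """
    new_len = waveforms.shape[1] * speed // 100
    # 1D interpolation expects (batch, channels, time).
    resampled = F.interpolate(
        waveforms.unsqueeze(1), size=new_len, mode="linear", align_corners=False
    )
    return resampled.squeeze(1)
```

At speeds below 100 the output is shorter (faster speech); above 100 it is longer.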

Example

>>> import torch
>>> from speechbrain.lobes.augment import TimeDomainSpecAugment
>>> inputs = torch.randn([10, 16000])
>>> feature_maker = TimeDomainSpecAugment(speeds=[80])
>>> feats = feature_maker(inputs, torch.ones(10))
>>> feats.shape
torch.Size([10, 12800])

forward(waveforms, lengths)

Returns the distorted waveforms.

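Parameters
  • waveforms (torch.Tensor) – The waveforms to distort, e.g., of shape (batch, time).

  • lengths (torch.Tensor) – The relative length of each waveform, from 0 to 1.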

class speechbrain.lobes.augment.EnvCorrupt(reverb_prob=1.0, babble_prob=1.0, noise_prob=1.0, openrir_folder=None, openrir_max_noise_len=None, reverb_csv=None, noise_csv=None, noise_num_workers=0, babble_speaker_count=0, babble_snr_low=0, babble_snr_high=0, noise_snr_low=0, noise_snr_high=0, rir_scale_factor=1.0)

Bases: torch.nn.modules.module.Module

Environmental corruptions for speech signals: noise, reverb, babble.

Parameters
  • reverb_prob (float from 0 to 1) – The probability that each batch will have reverberation applied.

  • babble_prob (float from 0 to 1) – The probability that each batch will have babble added.

  • noise_prob (float from 0 to 1) – The probability that each batch will have noise added.

  • openrir_folder (str) – If provided, download and prepare OpenRIR in this location. The reverberation CSV and noise CSV will come from here unless overridden by the reverb_csv or noise_csv arguments.

  • openrir_max_noise_len (float) – The maximum length in seconds for a noise segment from OpenRIR. Only takes effect if openrir_folder is used for noises. Longer noises are cut into segments of at most this length.

  • reverb_csv (str) – A prepared csv file for loading room impulse responses.

  • noise_csv (str) – A prepared csv file for loading noise data.

  • noise_num_workers (int) – Number of workers to use for loading noises.

  • babble_speaker_count (int) – Number of speakers to use for babble. Must be less than batch size.

  • babble_snr_low (int) – Lowest generated SNR (in dB) of the reverberated signal to babble.

  • babble_snr_high (int) – Highest generated SNR (in dB) of the reverberated signal to babble.

  • noise_snr_low (int) – Lowest generated SNR (in dB) of the babble-corrupted signal to noise.

  • noise_snr_high (int) – Highest generated SNR (in dB) of the babble-corrupted signal to noise.

  • rir_scale_factor (float) – Compresses or dilates the given impulse response. If 0 < rir_scale_factor < 1, the impulse response is compressed (less reverb); if rir_scale_factor > 1, it is dilated (more reverb).
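The SNR parameters set how loud the added signal is relative to the signal it is added to. The sketch below shows one common way to mix at a target SNR, assumed here to be in dB, plus a babble signal built by summing other utterances from the same batch (consistent with babble_speaker_count having to be smaller than the batch size). Both function names are hypothetical, and building babble by rolling the batch is an illustrative assumption, not necessarily how EnvCorrupt does it.

```python
import torch


def mix_at_snr_sketch(clean: torch.Tensor, noise: torch.Tensor, snr_db: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: add (batch, time) noise so each row reaches snr_db (in dB)."""
    clean_power = clean.pow(2).mean(dim=1, keepdim=True)
    noise_power = noise.pow(2).mean(dim=1, keepdim=True).clamp_min(1e-12)
    # SNR_dB = 10 * log10(P_clean / (gain**2 * P_noise)), solved for gain.
    gain = torch.sqrt(clean_power / (noise_power * 10 ** (snr_db.view(-1, 1) / 10)))
    return clean + gain * noise


def babble_sketch(waveforms: torch.Tensor, speaker_count: int, snr_low: float, snr_high: float) -> torch.Tensor:
    """Hypothetical sketch: use other utterances in the batch as babble."""
    batch = waveforms.shape[0]
    if not 0 < speaker_count < batch:
        raise ValueError("speaker_count must be positive and less than the batch size")
    # Rolling the batch by k pairs each utterance with a different one.
    babble = sum(waveforms.roll(k, dims=0) for k in range(1, speaker_count + 1))
    # One SNR per utterance, drawn uniformly from [snr_low, snr_high].
    snr = snr_low + (snr_high - snr_low) * torch.rand(batch)
    return mix_at_snr_sketch(waveforms, babble, snr)
```

The same mixing helper could serve for the noise stage, with the babble-corrupted signal as the "clean" input.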

Example

>>> import torch
>>> from speechbrain.lobes.augment import EnvCorrupt
>>> inputs = torch.randn([10, 16000])
>>> corrupter = EnvCorrupt(babble_speaker_count=9)
>>> feats = corrupter(inputs, torch.ones(10))

forward(waveforms, lengths)

Returns the distorted waveforms.

Parameters
  • waveforms (torch.Tensor) – The waveforms to distort.

  • lengths (torch.Tensor) – The relative length of each waveform, from 0 to 1.
