speechbrain.processing.features module

Low-level feature pipeline components

This library gathers functions that compute popular speech features over batches of data. All the classes are of type nn.Module, which makes the pipeline end-to-end differentiable so that gradients can be backpropagated through them. Our functions are modified versions of the ones in the torchaudio toolkit (https://github.com/pytorch/audio).

Example

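>>> import torch
>>> from speechbrain.dataio.dataio import read_audio
>>> from speechbrain.processing.features import (
...     STFT, spectral_magnitude, Filterbank, DCT, Deltas,
...     ContextWindow, InputNormalization,
... )
>>> signal = read_audio('tests/samples/single-mic/example1.wav')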
>>> signal = signal.unsqueeze(0)
>>> compute_STFT = STFT(
...     sample_rate=16000, win_length=25, hop_length=10, n_fft=400
... )
>>> features = compute_STFT(signal)
>>> features = spectral_magnitude(features)
>>> compute_fbanks = Filterbank(n_mels=40)
>>> features = compute_fbanks(features)
>>> compute_mfccs = DCT(input_size=40, n_out=20)
>>> features = compute_mfccs(features)
>>> compute_deltas = Deltas(input_size=20)
>>> delta1 = compute_deltas(features)
>>> delta2 = compute_deltas(delta1)
>>> features = torch.cat([features, delta1, delta2], dim=2)
>>> compute_cw = ContextWindow(left_frames=5, right_frames=5)
>>> features = compute_cw(features)
>>> norm = InputNormalization()
>>> features = norm(features, torch.tensor([1]).float())
Authors
  • Mirco Ravanelli 2020

  • Peter Plantinga 2025

  • Rogier van Dalen 2025

Summary

Classes:

ContextWindow

Computes the context window.

DCT

Computes the discrete cosine transform.

Deltas

Computes delta coefficients (time derivatives).

DynamicRangeCompression

Dynamic range compression for audio signals - clipped log scale with an optional multiplier

Filterbank

Computes filter bank (FBANK) features given spectral magnitudes.

GlobalNorm

A global normalization module - computes a single mean and standard deviation for the entire batch across unmasked positions and uses it to normalize the inputs to the desired mean and standard deviation.

ISTFT

Computes the Inverse Short-Term Fourier Transform (ISTFT)

InputNormalization

Performs mean and variance normalization over the time and possibly the (global) batch dimension of the input.

MinLevelNorm

A commonly used normalization for the decibel scale

STFT

Computes the Short-Term Fourier Transform (STFT).

Functions:

combine_gaussian_statistics

Combine the first- and second-order moments from two pieces of data.

combine_gaussian_statistics_distributed

Combine the first- and second-order moments from multiple pieces of data using torch.distributed.

gaussian_statistics

Compute first- and second-order moments of data, and return them as the count, mean, and variance of a vector over one or more dimensions.

make_padding_mask

Create a mask from relative lengths along a given dimension.

mean_std_update

Update the mean and variance statistics run_mean and run_std that have been computed on run_count samples to integrate the new samples x.

spectral_magnitude

Returns the magnitude of a complex spectrogram.

Reference

class speechbrain.processing.features.STFT(sample_rate, win_length=25, hop_length=10, n_fft=400, window_fn=torch.hamming_window, normalized_stft=False, center=True, pad_mode='constant', onesided=True)[source]

Bases: Module

Computes the Short-Term Fourier Transform (STFT).

This class computes the Short-Term Fourier Transform of an audio signal. It supports multi-channel audio inputs (batch, time, channels).

Parameters:
  • sample_rate (int) – Sample rate of the input audio signal (e.g. 16000).

  • win_length (float) – Length (in ms) of the sliding window used to compute the STFT.

  • hop_length (float) – Length (in ms) of the hop of the sliding window used to compute the STFT.

  • n_fft (int) – Number of FFT points of the STFT. It defines the frequency resolution (n_fft should be >= the window length in samples).

  • window_fn (function) – A function that takes an integer (number of samples) and outputs a tensor to be multiplied with each window before fft.

  • normalized_stft (bool) – If True, the function returns the normalized STFT results, i.e., multiplied by win_length^-0.5 (default is False).

  • center (bool) – If True (default), the input will be padded on both sides so that the t-th frame is centered at time t × hop_length. Otherwise, the t-th frame begins at time t × hop_length.

  • pad_mode (str) – One of 'constant' (default), 'reflect', 'replicate', or 'circular'. 'constant' pads the input tensor boundaries with a constant value. 'reflect' pads the input tensor using the reflection of the input boundary. 'replicate' pads the input tensor using replication of the input boundary. 'circular' pads using circular replication.

  • onesided (bool) – If True (default), only n_fft/2 + 1 values are returned. The remaining values are redundant because of the conjugate symmetry of the Fourier transform.
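
The millisecond-based parameters above determine the output shape. The following is a minimal sketch of that arithmetic in plain Python, not the library's implementation. It assumes that millisecond values are converted to samples by rounding sample_rate / 1000 × ms and that framing follows the usual centered-STFT rule. The helper names ms_to_samples and stft_output_shape are hypothetical.

```python
def ms_to_samples(sample_rate, ms):
    """Convert a duration in milliseconds to a whole number of samples."""
    # Assumption: the conversion rounds to the nearest sample.
    return int(round(sample_rate / 1000.0 * ms))


def stft_output_shape(num_samples, sample_rate, win_length_ms, hop_length_ms,
                      n_fft, center=True, onesided=True):
    """Return (win_samples, hop_samples, n_frames, n_bins) for one signal."""
    win = ms_to_samples(sample_rate, win_length_ms)
    hop = ms_to_samples(sample_rate, hop_length_ms)
    if center:
        # Padding n_fft // 2 samples on each side yields one frame per hop.
        n_frames = 1 + num_samples // hop
    else:
        # Without padding, the first frame must fit entirely in the signal.
        n_frames = 1 + (num_samples - n_fft) // hop
    # A one-sided spectrum keeps the non-redundant half plus the DC bin.
    n_bins = n_fft // 2 + 1 if onesided else n_fft
    return win, hop, n_frames, n_bins


print(stft_output_shape(16000, 16000, 25, 10, 400))  # (400, 160, 101, 201)
```

For one second of 16 kHz audio with the default settings, this gives 400-sample windows, a 160-sample hop, 101 frames, and 201 frequency bins, which agrees with the shape shown in the STFT example.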

Example

>>> import torch
>>> compute_STFT = STFT(
...     sample_rate=16000, win_length=25, hop_length=10, n_fft=400
... )
>>> inputs = torch.randn([10, 16000])
>>> features = compute_STFT(inputs)
>>> features.shape
torch.Size([10, 101, 201, 2])
forward(x)[source]

Returns the STFT generated from the input waveforms.

Parameters:

x (torch.Tensor) – A batch of audio signals to transform.

Returns:

stft

Return type:

torch.Tensor

get_filter_properties() → FilterProperties[source]
class speechbrain.processing.features.ISTFT(sample_rate, n_fft=None, win_length=25, hop_length=10, window_fn=torch.hamming_window, normalized_stft=False, center=True, onesided=True, epsilon=1e-12)[source]

Bases: Module

Computes the Inverse Short-Term Fourier Transform (ISTFT)

This class computes the Inverse Short-Term Fourier Transform of an audio signal. It supports multi-channel audio inputs (batch, time_step, n_fft, 2, n_channels [optional]).

Parameters:
  • sample_rate (int) – Sample rate of the input audio signal (e.g. 16000).

  • n_fft (int) – Number of points in FFT.

  • win_length (float) – Length (in ms) of the sliding window used when computing the STFT.

  • hop_length (float) – Length (in ms) of the hop of the sliding window used when computing the STFT.

  • window_fn (function) – A function that takes an integer (number of samples) and outputs a tensor to be used as a window for ifft.

  • normalized_stft (bool) – If True, the function assumes that it's working with the normalized STFT results (default is False).

  • center (bool) – If True (default), the function assumes that the STFT result was padded on both sides.

  • onesided (bool) – If True (default), the function assumes that there are n_fft/2 + 1 values for each time frame of the STFT.

  • epsilon (float) – A small value to avoid division by 0 when normalizing by the sum of the squared window. Playing with it can fix some abnormalities at the beginning and at the end of the reconstructed signal. The default value of epsilon is 1e-12.

Example

>>> import torch
>>> compute_STFT = STFT(
...     sample_rate=16000, win_length=25, hop_length=10, n_fft=400
... )
>>> compute_ISTFT = ISTFT(
...     sample_rate=16000, win_length=25, hop_length=10
... )
>>> inputs = torch.randn([10, 16000])
>>> outputs = compute_ISTFT(compute_STFT(inputs))
>>> outputs.shape
torch.Size([10, 16000])
forward(x, sig_length=None)[source]

Returns the ISTFT generated from the input signal.

Parameters:
  • x (torch.Tensor) – A batch of audio signals in the frequency domain to transform.

  • sig_length (int) – The length of the output signal in number of samples. If not specified will be equal to: (time_step - 1) * hop_length + n_fft

Returns:

istft

Return type:

torch.Tensor

speechbrain.processing.features.spectral_magnitude(stft, power: float = 1, log: bool = False, eps: float = 1e-14)[source]

Returns the magnitude of a complex spectrogram.

Parameters:
  • stft (torch.Tensor) – A tensor, output from the stft function.

  • power (float) – The exponent applied to the squared magnitude. Use power=1 for the power spectrogram and power=0.5 for the magnitude spectrogram.

  • log (bool) – Whether to apply log to the spectral features.

  • eps (float) – A small value to prevent square root of zero.

Returns:

spectr

Return type:

torch.Tensor

Example

>>> a = torch.Tensor([[3, 4]])
>>> spectral_magnitude(a, power=0.5)
tensor([5.])
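
The relation between the power argument and the result can be sketched for a single complex bin as follows. This is a plain-Python illustration under stated assumptions, not the library's code: it assumes the squared real and imaginary parts are summed and then raised to power, and that eps is added only when power < 1 or when the logarithm is taken. The function name spectral_magnitude_sketch is hypothetical.

```python
import math


def spectral_magnitude_sketch(real, imag, power=1.0, log=False, eps=1e-14):
    """Magnitude of one complex STFT bin given as (real, imag)."""
    # Squared magnitude: this is the power spectrum value (power=1).
    spectr = real ** 2 + imag ** 2
    if power < 1:
        # Assumption: eps guards the fractional power at exactly zero.
        spectr += eps
    spectr = spectr ** power
    if log:
        # Assumption: eps also guards the logarithm.
        spectr = math.log(spectr + eps)
    return spectr


print(spectral_magnitude_sketch(3.0, 4.0))             # 25.0
print(spectral_magnitude_sketch(3.0, 4.0, power=0.5))  # approximately 5.0
```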
class speechbrain.processing.features.Filterbank(n_mels=40, log_mel=True, filter_shape='triangular', f_min=0, f_max=8000, n_fft=400, sample_rate=16000, power_spectrogram=2, amin=1e-10, ref_value=1.0, top_db=80.0, param_change_factor=1.0, param_rand_factor=0.0, freeze=True)[source]

Bases: Module

Computes filter bank (FBANK) features given spectral magnitudes.

Parameters:
  • n_mels (int) – Number of Mel filters used to average the spectrogram.

  • log_mel (bool) – If True, it computes the log of the FBANKs.

  • filter_shape (str) – Shape of the filters ('triangular', 'rectangular', 'gaussian').

  • f_min (int) – Lowest frequency for the Mel filters.

  • f_max (int) – Highest frequency for the Mel filters.

  • n_fft (int) – Number of FFT points of the STFT. It defines the frequency resolution (n_fft should be >= the window length in samples).

  • sample_rate (int) – Sample rate of the input audio signal (e.g. 16000).

  • power_spectrogram (float) – Exponent used for spectrogram computation.

  • amin (float) – Minimum amplitude (used for numerical stability).

  • ref_value (float) – Reference value used for the dB scale.

  • top_db (float) – Minimum negative cut-off in decibels.

  • param_change_factor (float) – If freeze=False, this parameter affects the speed at which the filter parameters (i.e., central_freqs and bands) can be changed. When high (e.g., param_change_factor=1), the filters change a lot during training. When low (e.g., param_change_factor=0.1), the filter parameters are more stable during training.

  • param_rand_factor (float) – This parameter can be used to randomly change the filter parameters (i.e., central frequencies and bands) during training. It is thus a sort of regularization. param_rand_factor=0 has no effect, while param_rand_factor=0.15 allows random variations within ±15% of the standard values of the filter parameters (e.g., if the central frequency is 100 Hz, it can be randomly changed to any value from 85 Hz to 115 Hz).

  • freeze (bool) – If False, the central frequency and the band of each filter are added to nn.parameters. If True, the standard frozen features are computed.
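
To make n_mels, f_min, f_max, and param_rand_factor more concrete, here is a rough sketch of how Mel filter centers are commonly placed and how a ±15% random range translates to Hz. It assumes the widely used conversion mel = 2595 · log10(1 + f / 700) and equal spacing on the Mel scale. Whether the library uses exactly this formula is an assumption, and all function names are hypothetical.

```python
import math


def hz_to_mel(hz):
    """Common Hz-to-Mel conversion (assumed, not taken from the library)."""
    return 2595.0 * math.log10(1.0 + hz / 700.0)


def mel_to_hz(mel):
    """Inverse of hz_to_mel."""
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)


def mel_center_frequencies(n_mels=40, f_min=0.0, f_max=8000.0):
    """Center frequencies (Hz) of n_mels filters, equally spaced in Mel."""
    lo, hi = hz_to_mel(f_min), hz_to_mel(f_max)
    step = (hi - lo) / (n_mels + 1)
    # Skip the two end points: they are the outer edges, not centers.
    return [mel_to_hz(lo + step * i) for i in range(1, n_mels + 1)]


def jitter_range(value, param_rand_factor):
    """Interval a parameter may fall in for a given param_rand_factor."""
    return value * (1.0 - param_rand_factor), value * (1.0 + param_rand_factor)


centers = mel_center_frequencies()
print(len(centers))               # 40
print(jitter_range(100.0, 0.15))  # roughly (85.0, 115.0)
```

Centers spaced this way are dense at low frequencies and sparse at high frequencies, which is the usual motivation for Mel filter banks.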

Example

>>> import torch
>>> compute_fbanks = Filterbank()
>>> inputs = torch.randn([10, 101, 201])
>>> features = compute_fbanks(inputs)
>>> features.shape
torch.Size([10, 101, 40])
forward(spectrogram)[source]

Returns the FBANKs.

Parameters:

spectrogram (torch.Tensor) – A batch of spectrogram tensors.

Returns:

fbanks

Return type:

torch.Tensor

class speechbrain.processing.features.DCT(input_size, n_out=20, ortho_norm=True)[source]

Bases: Module

Computes the discrete cosine transform.

This class is primarily used to compute MFCC features of an audio signal given a set of FBANK features as input.

Parameters:
  • input_size (int) – Expected size of the last dimension in the input.

  • n_out (int) – Number of output coefficients.

  • ortho_norm (bool) – Whether to use orthogonal norm.
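
As an illustration of what these parameters control, the following plain-Python sketch applies a type-II DCT to one feature vector and keeps the first n_out coefficients. It is not the library's implementation. The orthonormal scaling shown is the standard one, while returning raw cosine sums when ortho_norm=False is an assumption of this sketch. The name dct_ii is hypothetical.

```python
import math


def dct_ii(x, n_out, ortho_norm=True):
    """First n_out type-II DCT coefficients of the sequence x."""
    n = len(x)
    out = []
    for k in range(n_out):
        # Project x onto the k-th cosine basis function.
        coeff = sum(
            x[i] * math.cos(math.pi / n * (i + 0.5) * k) for i in range(n)
        )
        if ortho_norm:
            # Orthonormal scaling: the k = 0 term uses a different factor.
            coeff *= math.sqrt(1.0 / n) if k == 0 else math.sqrt(2.0 / n)
        out.append(coeff)
    return out


# A constant vector has all its energy in the first coefficient.
print(dct_ii([1.0, 1.0, 1.0, 1.0], n_out=3))  # approximately [2.0, 0.0, 0.0]
```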

Example

>>> import torch
>>> inputs = torch.randn([10, 101, 40])
>>> compute_mfccs = DCT(input_size=inputs.size(-1))
>>> features = compute_mfccs(inputs)
>>> features.shape
torch.Size([10, 101, 20])
forward(x)[source]

Returns the DCT of the input tensor.

Parameters:

x (torch.Tensor) – A batch of tensors to transform, usually fbank features.

Returns:

dct

Return type:

torch.Tensor

class speechbrain.processing.features.Deltas(input_size, window_length=5)[source]

Bases: Module

Computes delta coefficients (time derivatives).

Parameters:
  • input_size (int) – The expected size of the inputs for parameter initialization.

  • window_length (int) – Length of the window used to compute the time derivatives.
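
The role of window_length can be illustrated with the standard regression formula for delta coefficients, shown here for a single feature track in plain Python. This is a sketch under assumptions: it uses N = (window_length - 1) // 2 frames on each side and repeats the first and last frame at the edges, which may differ from the library's exact behavior. The name deltas_sketch is hypothetical.

```python
def deltas_sketch(frames, window_length=5):
    """Delta coefficients of a 1-D sequence via the regression formula."""
    n = (window_length - 1) // 2
    # Normalizer of the regression formula: 2 * sum of squared offsets.
    denom = 2 * sum(i * i for i in range(1, n + 1))
    last = len(frames) - 1
    out = []
    for t in range(len(frames)):
        acc = 0.0
        for i in range(1, n + 1):
            # Assumption: edges are handled by repeating the boundary frame.
            ahead = frames[min(t + i, last)]
            behind = frames[max(t - i, 0)]
            acc += i * (ahead - behind)
        out.append(acc / denom)
    return out


# A ramp with slope 1 has delta 1 away from the edges.
print(deltas_sketch([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])[2:4])  # [1.0, 1.0]
```

Applying the same operation to the deltas gives the second-order (delta-delta) coefficients, as in the module-level example.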

Example

>>> inputs = torch.randn([10, 101, 20])
>>> compute_deltas = Deltas(input_size=inputs.size(-1))
>>> features = compute_deltas(inputs)
>>> features.shape
torch.Size([10, 101, 20])
forward(x)[source]

Returns the delta coefficients.

Parameters:

x (torch.Tensor) – A batch of tensors.

Returns:

delta_coeff

Return type:

torch.Tensor

class speechbrain.processing.features.ContextWindow(left_frames=0, right_frames=0)[source]

Bases: Module

Computes the context window.

This class applies a context window by gathering multiple time steps in a single feature vector. The operation is performed with a convolutional layer based on a fixed kernel designed for that.

Parameters:
  • left_frames (int) – Number of left frames (i.e., past frames) to collect.

  • right_frames (int) – Number of right frames (i.e., future frames) to collect.
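
The effect of left_frames and right_frames on the feature dimension can be sketched without any convolution. The plain-Python version below stacks neighboring frames directly. It assumes zero-filled frames beyond the sequence boundaries and a past-to-future ordering of the stacked frames; both are assumptions of this sketch, and the name context_window_sketch is hypothetical.

```python
def context_window_sketch(frames, left_frames=0, right_frames=0):
    """Stack each frame with its neighbors into one longer feature vector."""
    n_feats = len(frames[0])
    zero = [0.0] * n_feats
    out = []
    for t in range(len(frames)):
        stacked = []
        for offset in range(-left_frames, right_frames + 1):
            idx = t + offset
            # Assumption: frames outside the sequence are filled with zeros.
            stacked.extend(frames[idx] if 0 <= idx < len(frames) else zero)
        out.append(stacked)
    return out


frames = [[float(t)] * 20 for t in range(101)]
windows = context_window_sketch(frames, left_frames=5, right_frames=5)
print(len(windows), len(windows[0]))  # 101 220
```

The output feature size is n_feats × (left_frames + right_frames + 1), which is 20 × 11 = 220 here.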

Example

>>> import torch
>>> compute_cw = ContextWindow(left_frames=5, right_frames=5)
>>> inputs = torch.randn([10, 101, 20])
>>> features = compute_cw(inputs)
>>> features.shape
torch.Size([10, 101, 220])
forward(x)[source]

Returns the tensor with the surrounding context.

Parameters:

x (torch.Tensor) – A batch of tensors.

Returns:

cw_x – The context-enriched tensor

Return type:

torch.Tensor

speechbrain.processing.features.gaussian_statistics(x: Tensor, mask: Tensor | None = None, dim: int | tuple | None = None)[source]

Compute first- and second-order moments of data, and return them as the count, mean, and variance of a vector over one or more dimensions.

Parameters:
  • x (torch.Tensor) – The tensor to compute the statistics over.

  • mask (torch.Tensor) – Padding mask to exclude padding from the statistics computation. For dimensions in dim, the mask size should exactly match x. All dimensions other than dim should be ones (e.g. [B, T, 1, …]). Ones / trues are valid positions, and zeros / falses are padding positions.

  • dim (int | tuple | None) – The dimension or dimensions that the statistics should be computed over. The other dimensions are retained in the output. If None, then scalar-valued statistics will be returned.

Returns:

  • count (int) – The number of values in the statistics computation. Without padding, this is just the product of the lengths of the dimensions in dim.

  • mean (torch.Tensor) – The mean of the non-padding values over the dimensions in dim.

  • variance (torch.Tensor) – The (biased) variance of the non-padding values over dim.

Example

>>> x = torch.tensor([[1., 3., 0.]])
>>> mask = torch.tensor([[True, True, False]])
>>> dim = (0, 1)
>>> count, mean, variance = gaussian_statistics(x, mask, dim)
>>> count
2
>>> mean
tensor(2.)
>>> variance
tensor(1.)
speechbrain.processing.features.combine_gaussian_statistics(left_statistics: Tuple[int, Tensor, Tensor | None], right_statistics: Tuple[int, Tensor, Tensor | None])[source]

Combine the first- and second-order moments from two pieces of data. The data and the result are in the form (count, mean, variance). The result is the mean and variance as if they had been computed on the concatenation of the data for left_statistics and the data for right_statistics.

Parameters:
  • left_statistics (Tuple[int, torch.Tensor, Optional[torch.Tensor]]) – One set of gaussian stats: count, mean, variance

  • right_statistics (Tuple[int, torch.Tensor, Optional[torch.Tensor]]) – Another set of gaussian stats: count, mean, variance

Returns:

  • count – The total number of elements in the data.

  • mean – The combined mean.

  • variance – The combined variance, relative to the new mean. Returns None if either statistics set has a variance of None.
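
The combination rule can be written out for scalar statistics. This is a minimal sketch using the standard formula for pooling biased variances, with plain floats standing in for tensors; it is not the library's code, and the name combine_stats_sketch is hypothetical.

```python
def combine_stats_sketch(left, right):
    """Pool two (count, mean, variance) triples into one."""
    (n1, m1, v1), (n2, m2, v2) = left, right
    n = n1 + n2
    if n == 0:
        # Nothing to combine; return the left statistics unchanged.
        return left
    mean = (n1 * m1 + n2 * m2) / n
    if v1 is None or v2 is None:
        return n, mean, None
    # Each part contributes its own variance plus the squared shift of its
    # mean relative to the combined mean.
    variance = (
        n1 * (v1 + (m1 - mean) ** 2) + n2 * (v2 + (m2 - mean) ** 2)
    ) / n
    return n, mean, variance


# Statistics of [1, 3] combined with statistics of [5]
# equal the statistics of [1, 3, 5].
print(combine_stats_sketch((2, 2.0, 1.0), (1, 5.0, 0.0)))
```

The result has count 3, mean 3.0, and variance 8/3, the same values obtained by computing the biased statistics of [1, 3, 5] directly.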

speechbrain.processing.features.combine_gaussian_statistics_distributed(statistics: Tuple[int, Tensor, Tensor])[source]

Combine the first- and second-order moments from multiple pieces of data using torch.distributed. The data and the result are in the form (count, mean, variance). The result is the mean and variance as if they had been computed on the concatenation of the data for statistics across all parallel processes.

Parameters:

statistics (Tuple[int, torch.Tensor, torch.Tensor]) – A set of gaussian statistics to reduce across all processes. The three elements of the tuple represent the count, mean, and variance.

Returns:

  • count – The total number of elements in the data across processes.

  • mean – The combined mean.

  • variance – The combined variance, relative to the new mean.

speechbrain.processing.features.mean_std_update(x: Tensor, mask: Tensor | None, dim: int | tuple | None, run_count: int, run_mean: Tensor, run_std: Tensor)[source]

Update the mean and variance statistics run_mean and run_std that have been computed on run_count samples to integrate the new samples x.

WARNING: Must be called in sync across processes.

Parameters:
  • x (torch.Tensor) – The new values to add to the running stats.

  • mask (torch.Tensor) – Padding mask to exclude padding from the statistics computation. All dimensions other than batch and time should be ones (e.g. [B, T, 1, …]). Ones / trues are valid positions, and zeros / falses are padding positions.

  • dim (tuple or int) – The dimension or dimensions to reduce (e.g. 1 for length).

  • run_count (float or torch.Tensor) – The running number of samples seen so far.

  • run_mean (float or torch.Tensor) – The running mean of samples seen so far.

  • run_std (float or torch.Tensor) – The running standard deviations from the mean.

Returns:

  • new_run_count (torch.Tensor) – Updated count of all samples, now including x.

  • new_run_mean (torch.Tensor) – Updated running mean of all samples, now including x.

  • new_run_std (torch.Tensor) – Updated running standard deviations of all samples, now including x.

Example

>>> input_tensor = torch.tensor([[-1.0, 0.0, 1.0, 0.0]])
>>> input_length = torch.tensor([0.75])
>>> input_length_dim = 1
>>> input_mask = make_padding_mask(input_tensor, input_length, input_length_dim)
>>> dim = (0, input_length_dim)
>>> run_count, run_mean, run_std = 0, torch.tensor(0.0), torch.tensor(1.0)
>>> run_count, run_mean, run_std = mean_std_update(
...     input_tensor, input_mask, dim, run_count, run_mean, run_std
... )
>>> run_count
3
>>> run_mean
tensor(0.)
>>> run_std
tensor(0.8165)
class speechbrain.processing.features.InputNormalization(mean_norm=True, std_norm=True, norm_type='global', avg_factor=None, length_dim=1, update_until_epoch=2, avoid_padding_norm=False, epsilon=1e-10, device='cpu')[source]

Bases: Module

Performs mean and variance normalization over the time and possibly the (global) batch dimension of the input.

When the default norm_type of "global" is used, running mean and variance statistics are computed and stored incorporating all the samples seen.

WARNING: at first, the running statistics do not represent the "true" mean and variance, but are estimates based on the data seen so far. Once enough data has been seen, the stats should closely approximate the "true" values.

WARNING: Using global normalization, the first call of forward() will throw an error if no updates have been performed (including the current batch), i.e. on first call the epoch >= update_until_epoch or the module is first called in .eval() mode.

Parameters:
  • mean_norm (bool, default True) – If True, the mean will be normalized. Passing False is deprecated.

  • std_norm (bool, default True) – If True, the variance will be normalized.

  • norm_type (str, default "global") –

    String parameter whose value defines how the statistics are computed:
    • 'sentence' computes norms per utterance (no running stats)

    • 'batch' computes norms per input tensor (no running stats)

    • 'global' computes norms over all inputs (single mean, variance)

    • 'speaker' - DEPRECATED

  • avg_factor (float, optional) – Passing avg_factor is DEPRECATED as this exactly matches the behavior of BatchNorm. To maintain this behavior, use speechbrain.nnet.normalization.BatchNorm1d(momentum=avg_factor).

  • length_dim (int, default 1) – The dimension for which to mask out the padding positions.

  • update_until_epoch (int, default 2) – The epoch at which updates to the norm stats stop. With the default value, updates stop after one epoch, because they stop as soon as epoch == update_until_epoch.

  • avoid_padding_norm (bool, default False) – Regardless of the value passed here, padding is ignored for statistics computation. However, if False is passed for avoid_padding_norm, padding will get normalized along with the rest of the input tensor. If True, the padding will not be affected by this normalization operation.

  • epsilon (float, default 1e-10) – A small value to improve the numerical stability of the variance.

  • device (str or torch.device) – The device on which to create the global statistics. Can be changed later with .to(device).
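
The difference between the 'sentence' and 'batch' settings of norm_type can be sketched on nested lists of shape [batch, time]. The plain-Python code below is an illustration under assumptions, not the library's implementation: it uses the biased variance, adds epsilon to the variance before the square root, and normalizes padded positions too (as with avoid_padding_norm=False). All function names are hypothetical.

```python
import math


def mean_std(values, epsilon=1e-10):
    """Mean and (biased) standard deviation of a list of floats."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    # Assumption: epsilon is added to the variance for stability.
    return mean, math.sqrt(var + epsilon)


def valid_part(row, rel_len):
    """Non-padded prefix of a row, given its relative length."""
    return row[: int(round(rel_len * len(row)))]


def sentence_norm(batch, lengths):
    """Normalize each utterance with its own statistics."""
    out = []
    for row, rel_len in zip(batch, lengths):
        mean, std = mean_std(valid_part(row, rel_len))
        out.append([(v - mean) / std for v in row])
    return out


def batch_norm(batch, lengths):
    """Normalize all utterances with statistics pooled over the batch."""
    pooled = [
        v for row, rel_len in zip(batch, lengths)
        for v in valid_part(row, rel_len)
    ]
    mean, std = mean_std(pooled)
    return [[(v - mean) / std for v in row] for row in batch]


inputs = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
lens = [1.0, 1.0, 1.0]
print([round(v, 4) for v in sentence_norm(inputs, lens)[0]])
print([round(v, 4) for v in batch_norm(inputs, lens)[0]])
```

With these inputs, the first row comes out close to [-1.2247, 0.0, 1.2247] for 'sentence' and [-1.5492, -1.1619, -0.7746] for 'batch', in line with the values shown in the InputNormalization example.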

Example

>>> import torch
>>> inputs = torch.arange(9).view(3, 3).float()
>>> inputs
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
>>> input_lens = torch.ones(3)
>>> norm = InputNormalization(norm_type="sentence")
>>> features = norm(inputs, input_lens)
>>> features
tensor([[-1.2247,  0.0000,  1.2247],
        [-1.2247,  0.0000,  1.2247],
        [-1.2247,  0.0000,  1.2247]])
>>> norm = InputNormalization(norm_type="batch")
>>> features = norm(inputs, input_lens)
>>> features
tensor([[-1.5492, -1.1619, -0.7746],
        [-0.3873,  0.0000,  0.3873],
        [ 0.7746,  1.1619,  1.5492]])
>>> norm = InputNormalization(norm_type="global")
>>> features = norm(inputs, input_lens)
>>> features.mean() < 1e-7
tensor(True)
>>> features = norm(inputs + 1, input_lens)
>>> features.mean()
tensor(0.1901)
>>> features = norm(inputs, input_lens)
>>> features.mean()
tensor(-0.1270)
>>> features = norm(inputs - 1, input_lens)
>>> features.mean()
tensor(-0.3735)
>>> features = norm(inputs, input_lens)
>>> features.mean() < 1e-7
tensor(True)
spk_dict_mean: Dict[int, Tensor]
spk_dict_std: Dict[int, Tensor]
spk_dict_count: Dict[int, int]
NORM_TYPES = ('global', 'batch', 'sentence')
forward(x, lengths=None, epoch=None)[source]

Normalizes the input tensor, x, according to the norm_type.

Excludes the padded portion of the tensor by using the passed relative lengths. Automatically updates running mean, variance if "global" or "speaker" norm is used.

Parameters:
  • x (torch.Tensor) – The input tensor to normalize.

  • lengths (torch.Tensor, optional) – The relative length of each sentence (e.g., [0.7, 0.9, 1.0]), used to avoid computing stats on the padding part of the tensor.

  • epoch (int, optional) – The current epoch count, used to stop updates to global stats after enough samples have been seen (e.g. one epoch).

Returns:

x – The normalized tensor.

Return type:

torch.Tensor

to(device)[source]

Puts the needed tensors on the right device.

speechbrain.processing.features.make_padding_mask(x, lengths=None, length_dim=1, eps=1e-06)[source]

Create a mask from relative lengths along a given dimension.

Parameters:
  • x (torch.Tensor) – The input tensor demonstrating the size of the target mask.

  • lengths (torch.Tensor, optional) – The relative lengths of an input batch of utterances. If None, all positions are considered valid (i.e. mask is all True).

  • length_dim (int, default 1) – The dimension for which the lengths indicate padded positions.

  • eps (float, default 1e-6) – A small constant to avoid floating point errors in computation of the padding mask.

Returns:

padding_mask – A boolean tensor with True for valid positions and False for padding positions. The padding_mask can be multiplied with x via broadcasting, as all dimensions other than length and batch are singleton dimensions.

Return type:

torch.Tensor

Example

>>> input_tensor = torch.arange(3 * 4 * 2).view(3, 4, 2)
>>> lengths = torch.tensor([1.0, 0.75, 0.5])
>>> mask = make_padding_mask(input_tensor, lengths)
>>> mask.shape
torch.Size([3, 4, 1])
>>> input_tensor * mask
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5],
         [ 6,  7]],

        [[ 8,  9],
         [10, 11],
         [12, 13],
         [ 0,  0]],

        [[16, 17],
         [18, 19],
         [ 0,  0],
         [ 0,  0]]])
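
The mapping from relative lengths to a boolean mask can be sketched in plain Python. The comparison used here, t < length × max_len - eps, is one way the eps argument could guard against floating point error and is an assumption of this sketch, not a description of the library's code. The name padding_mask_sketch is hypothetical.

```python
def padding_mask_sketch(max_len, lengths, eps=1e-6):
    """Boolean mask per utterance: True for valid steps, False for padding."""
    return [
        # Assumption: eps keeps a product such as 0.75 * 4 = 3.0000001
        # from marking one extra step as valid.
        [t < rel_len * max_len - eps for t in range(max_len)]
        for rel_len in lengths
    ]


for row in padding_mask_sketch(4, [1.0, 0.75, 0.5]):
    print(row)
# [True, True, True, True]
# [True, True, True, False]
# [True, True, False, False]
```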
class speechbrain.processing.features.GlobalNorm(norm_mean=0.0, norm_std=1.0, update_steps=None, length_dim=2, mask_value=0.0)[source]

Bases: Module

A global normalization module - computes a single mean and standard deviation for the entire batch across unmasked positions and uses it to normalize the inputs to the desired mean and standard deviation.

This normalization is reversible - it is possible to use the .denormalize() method to recover the original values.

Parameters:
  • norm_mean (float, default 0.0) – the desired normalized mean

  • norm_std (float, default 1.0) – the desired normalized standard deviation

  • update_steps (float, optional) – the number of steps over which statistics will be collected

  • length_dim (int, default 2) – the dimension used to represent the length

  • mask_value (float, default 0.0) – the value with which to fill masked positions. Without a mask_value, the masked positions would be normalized, which might not be desired.
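
How norm_mean and norm_std enter the computation, and why the operation is reversible, can be sketched with running sums in plain Python. This is an illustration under assumptions (biased variance, no masking, no update_steps limit), not the library's implementation. The class name GlobalNormSketch is hypothetical.

```python
import math


class GlobalNormSketch:
    """Running global normalization to a target mean and std."""

    def __init__(self, norm_mean=0.0, norm_std=1.0):
        self.norm_mean = norm_mean
        self.norm_std = norm_std
        self.count = 0
        self.total = 0.0     # running sum of values
        self.total_sq = 0.0  # running sum of squared values

    def update(self, values):
        """Fold new values into the running statistics."""
        self.count += len(values)
        self.total += sum(values)
        self.total_sq += sum(v * v for v in values)

    def _stats(self):
        mean = self.total / self.count
        var = self.total_sq / self.count - mean * mean  # biased variance
        return mean, math.sqrt(var)

    def normalize(self, values):
        mean, std = self._stats()
        return [
            (v - mean) / std * self.norm_std + self.norm_mean for v in values
        ]

    def denormalize(self, values):
        """Exact inverse of normalize for the current statistics."""
        mean, std = self._stats()
        return [
            (v - self.norm_mean) / self.norm_std * std + mean for v in values
        ]


norm = GlobalNormSketch(norm_mean=0.5, norm_std=0.2)
norm.update([1.0, 2.0, 3.0])
print([round(v, 4) for v in norm.normalize([1.0, 2.0, 3.0])])
norm.update([5.0, 10.0, -4.0])
print([round(v, 4) for v in norm.normalize([5.0, 10.0, -4.0])])
```

Under these assumptions, the two printed rows are close to [0.2551, 0.5, 0.7449] and [0.6027, 0.8397, 0.1761], matching the first two outputs of the GlobalNorm example.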

Example

>>> import torch
>>> from speechbrain.processing.features import GlobalNorm
>>> global_norm = GlobalNorm(
...     norm_mean=0.5,
...     norm_std=0.2,
...     update_steps=3,
...     length_dim=1
... )
>>> x = torch.tensor([[1., 2., 3.]])
>>> x_norm = global_norm(x)
>>> x_norm
tensor([[0.2551, 0.5000, 0.7449]])
>>> x = torch.tensor([[5., 10., -4.]])
>>> x_norm = global_norm(x)
>>> x_norm
tensor([[0.6027, 0.8397, 0.1761]])
>>> x_denorm = global_norm.denormalize(x_norm)
>>> x_denorm
tensor([[ 5.0000, 10.0000, -4.0000]])
>>> x = torch.tensor([[100., -100., -50.]])
>>> global_norm.freeze()
>>> global_norm(x)
tensor([[ 5.1054, -4.3740, -2.0041]])
>>> global_norm.denormalize(x_norm)
tensor([[ 5.0000, 10.0000, -4.0000]])
>>> global_norm.unfreeze()
>>> global_norm(x)
tensor([[ 5.1054, -4.3740, -2.0041]])
>>> global_norm.denormalize(x_norm)
tensor([[ 5.0000, 10.0000, -4.0000]])
forward(x, lengths=None, mask_value=None, skip_update=False)[source]

Normalizes the tensor provided

Parameters:
  • x (torch.Tensor) – the tensor to normalize

  • lengths (torch.Tensor, optional) – a tensor of relative lengths (padding will not count towards normalization)

  • mask_value (float, optional) – the value to use for masked positions

  • skip_update (bool, default False) – whether to skip updates to the norm

Returns:

result – the normalized tensor

Return type:

torch.Tensor

should_update()[source]

Whether to perform an update.

normalize(x)[source]

Performs the normalization operation against the running mean and standard deviation

Parameters:

x (torch.Tensor) – the tensor to normalize

Returns:

result – the normalized tensor

Return type:

torch.Tensor

denormalize(x)[source]

Reverses the normalization process

Parameters:

x (torch.Tensor) – a normalized tensor

Returns:

result – a denormalized version of x

Return type:

torch.Tensor

freeze()[source]

Stops updates to the running mean/std

unfreeze()[source]

Resumes updates to the running mean/std

class speechbrain.processing.features.MinLevelNorm(min_level_db)[source]

Bases: Module

A commonly used normalization for the decibel scale

The scheme is as follows

x_norm = (x - min_level_db)/-min_level_db * 2 - 1

The rationale behind the scheme is as follows:

The top of the scale is assumed to be 0 dB. x_rel = (x - min) / (max - min) gives the relative position on the scale between the minimum and the maximum, where the minimum is 0 and the maximum is 1.

The subsequent rescaling (x_rel * 2 - 1) puts it on a scale from -1. to 1. with the middle of the range centered at zero.

Parameters:

min_level_db (float) – the minimum level
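
The scheme above and its inverse can be written out for scalar values. This is a minimal sketch of the stated formula in plain Python; the inverse is obtained by solving the formula for x, and the function names are hypothetical.

```python
def min_level_normalize(x, min_level_db):
    """Map a dB value from [min_level_db, 0] to [-1, 1]."""
    return (x - min_level_db) / -min_level_db * 2.0 - 1.0


def min_level_denormalize(x_norm, min_level_db):
    """Invert min_level_normalize."""
    return (x_norm + 1.0) / 2.0 * -min_level_db + min_level_db


for db in (-50.0, -20.0, -80.0):
    print(round(min_level_normalize(db, -100.0), 4))
# 0.0
# 0.6
# -0.6
```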

Example

>>> norm = MinLevelNorm(min_level_db=-100.)
>>> x = torch.tensor([-50., -20., -80.])
>>> x_norm = norm(x)
>>> x_norm
tensor([ 0.0000,  0.6000, -0.6000])
forward(x)[source]

Normalizes audio features in decibels (usually spectrograms)

Parameters:

x (torch.Tensor) – input features

Returns:

normalized_features – the normalized features

Return type:

torch.Tensor

denormalize(x)[source]

Reverses the min level normalization process

Parameters:

x (torch.Tensor) – the normalized tensor

Returns:

result – the denormalized tensor

Return type:

torch.Tensor

class speechbrain.processing.features.DynamicRangeCompression(multiplier=1, clip_val=1e-05)[source]

Bases: Module

Dynamic range compression for audio signals - clipped log scale with an optional multiplier

Parameters:
  • multiplier (float) – the multiplier constant

  • clip_val (float) – the minimum accepted value (values below this minimum will be clipped)
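
The interaction of multiplier and clip_val can be sketched for a scalar input. The formula log(max(x, clip_val) × multiplier) is inferred from the documented example values rather than from the library's source, so treat it as an assumption. The function name is hypothetical.

```python
import math


def dynamic_range_compression_sketch(x, multiplier=1.0, clip_val=1e-5):
    """Clipped natural-log compression of a non-negative value."""
    # Values below clip_val are raised to clip_val so the log stays finite.
    return math.log(max(x, clip_val) * multiplier)


print(round(dynamic_range_compression_sketch(10.0), 4))       # 2.3026
print(round(dynamic_range_compression_sketch(0.0), 4))        # -11.5129
print(round(dynamic_range_compression_sketch(10.0, 2.0), 4))  # 2.9957
```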

Example

>>> drc = DynamicRangeCompression()
>>> x = torch.tensor([10., 20., 0., 30.])
>>> drc(x)
tensor([  2.3026,   2.9957, -11.5129,   3.4012])
>>> drc = DynamicRangeCompression(2.)
>>> x = torch.tensor([10., 20., 0., 30.])
>>> drc(x)
tensor([  2.9957,   3.6889, -10.8198,   4.0943])
forward(x)[source]

Performs the forward pass

Parameters:

x (torch.Tensor) – the source signal

Returns:

result – the result

Return type:

torch.Tensor