speechbrain.processing.vocal_features module
Functions for analyzing vocal characteristics: jitter, shimmer, HNR, and GNE.
These features are typically used to analyze dysarthric voices with more traditional (i.e. non-deep-learning) approaches, and are often useful as a baseline for tasks such as pathology detection. Inspired by PRAAT.
- Authors
Peter Plantinga, 2024
Summary
Functions:
autocorrelate | Generate autocorrelation scores using circular convolution.
compute_autocorr_features | Compute features based on autocorrelation.
compute_cross_correlation | Computes the correlation between two sets of frames.
compute_gne | An algorithm for GNE computation from the original paper.
compute_hilbert_envelopes | Compute the Hilbert envelope of the signal in a specific frequency band using FFT.
compute_periodic_features | Function to compute periodic features: jitter, shimmer.
compute_spectral_features | Compute statistical measures on spectral frames such as flux, skew, spread, flatness.
inverse_filter | Perform inverse filtering on frames to estimate glottal pulse train.
spec_norm | Normalize the given value by the spectrum.
Reference
- speechbrain.processing.vocal_features.compute_autocorr_features(frames, min_lag, max_lag, neighbors=5)
Compute features based on autocorrelation.
- Parameters:
frames (torch.Tensor) – The audio frames to be evaluated for autocorrelation, shape [batch, frame, sample].
min_lag (int) – The minimum number of samples to consider for potential period length.
max_lag (int) – The maximum number of samples to consider for potential period length.
neighbors (int) – The number of neighbors to use for the rolling median, which helps avoid octave errors.
- Returns:
harmonicity (torch.Tensor) – The highest autocorrelation score relative to the 0-lag score. Used to compute HNR.
best_lags (torch.Tensor) – The lag corresponding to the highest autocorrelation score, an estimate of period length.
Example
>>> import torch
>>> from speechbrain.processing.vocal_features import compute_autocorr_features
>>> audio = torch.rand(1, 16000)
>>> frames = audio.unfold(-1, 800, 200)
>>> frames.shape
torch.Size([1, 77, 800])
>>> harmonicity, best_lags = compute_autocorr_features(frames, 100, 200)
>>> harmonicity.shape
torch.Size([1, 77])
>>> best_lags.shape
torch.Size([1, 77])
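The two outputs map onto familiar quantities. As a minimal sketch in plain Python (the helper names `harmonicity_to_hnr_db` and `lag_to_f0` are hypothetical and not part of SpeechBrain), a normalized autocorrelation peak r can be converted to a harmonics-to-noise ratio with the PRAAT-style formula 10 * log10(r / (1 - r)), and a lag in samples can be converted to a fundamental frequency estimate by dividing the sample rate by it:

```python
import math


def harmonicity_to_hnr_db(harmonicity, eps=1e-10):
    """Convert a normalized autocorrelation peak (between 0 and 1) to HNR in dB."""
    # Clamp away from 0 and 1 so the logarithm stays finite.
    r = min(max(harmonicity, eps), 1.0 - eps)
    return 10.0 * math.log10(r / (1.0 - r))


def lag_to_f0(lag, sample_rate=16000):
    """Convert a period length in samples to a frequency in Hz."""
    return sample_rate / lag


print(harmonicity_to_hnr_db(0.5))  # 0.0 (harmonic and noise energy are equal)
print(lag_to_f0(160))              # 100.0
```

With min_lag=100 and max_lag=200 at 16 kHz, as in the example above, the search range corresponds to fundamental frequencies between 80 Hz and 160 Hz.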
- speechbrain.processing.vocal_features.autocorrelate(frames)
Generate autocorrelation scores using circular convolution.
- Parameters:
frames (torch.Tensor) – The audio frames to be evaluated for autocorrelation, shape [batch, frame, sample].
- Returns:
autocorrelation – The ratio of the best candidate lag's autocorrelation score against the theoretical maximum autocorrelation score at lag 0, normalized by the autocorrelation score of the window.
- Return type:
torch.Tensor
Example
>>> import torch
>>> from speechbrain.processing.vocal_features import autocorrelate
>>> audio = torch.rand(1, 16000)
>>> frames = audio.unfold(-1, 800, 200)
>>> frames.shape
torch.Size([1, 77, 800])
>>> autocorrelation = autocorrelate(frames)
>>> autocorrelation.shape
torch.Size([1, 77, 401])
- speechbrain.processing.vocal_features.compute_periodic_features(frames, best_lags, neighbors=4)
Compute periodic features: jitter and shimmer.
- Parameters:
frames (torch.Tensor) – The framed audio to use for feature computation, dims [batch, frame, sample].
best_lags (torch.Tensor) – The estimated period length for each frame, dims [batch, frame].
neighbors (int) – Number of neighbors to use in comparison.
- Returns:
jitter (torch.Tensor) – The average absolute deviation in period over the frame.
shimmer (torch.Tensor) – The average absolute deviation in amplitude over the frame.
Example
>>> import torch
>>> from speechbrain.processing.vocal_features import (
...     compute_autocorr_features,
...     compute_periodic_features,
... )
>>> audio = torch.rand(1, 16000)
>>> frames = audio.unfold(-1, 800, 200)
>>> frames.shape
torch.Size([1, 77, 800])
>>> harmonicity, best_lags = compute_autocorr_features(frames, 100, 200)
>>> jitter, shimmer = compute_periodic_features(frames, best_lags)
>>> jitter.shape
torch.Size([1, 77])
>>> shimmer.shape
torch.Size([1, 77])
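To make "average absolute deviation" concrete, here is a minimal plain-Python sketch on hypothetical per-cycle measurements. It assumes the deviation is taken against the mean and expressed as a fraction of that mean, which is a common convention but not necessarily what the library computes (the library works on framed tensors and compares each period with its neighbors). The function name and the sample values are illustrative only.

```python
def relative_mean_abs_deviation(values):
    """Mean absolute deviation from the mean, as a fraction of the mean."""
    mean = sum(values) / len(values)
    deviation = sum(abs(v - mean) for v in values) / len(values)
    return deviation / mean


# Hypothetical measurements for the glottal cycles inside one frame.
period_lengths = [160, 162, 158, 161, 159]        # samples per cycle
peak_amplitudes = [0.80, 0.78, 0.82, 0.79, 0.81]  # peak amplitude per cycle

jitter = relative_mean_abs_deviation(period_lengths)    # period perturbation
shimmer = relative_mean_abs_deviation(peak_amplitudes)  # amplitude perturbation
print(f"jitter={jitter:.4f}, shimmer={shimmer:.4f}")
```

A perfectly periodic signal gives zero for both values; irregular vibration of the vocal folds raises them.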
- speechbrain.processing.vocal_features.compute_spectral_features(spectrum, eps=1e-10)
Compute statistical measures on spectral frames such as flux, skew, spread, flatness.
Reference page for computing values: https://www.mathworks.com/help/audio/ug/spectral-descriptors.html
- Parameters:
spectrum (torch.Tensor) – The spectrum to use for feature computation, dims [batch, frame, freq].
eps (float) – A small value to avoid division by 0.
- Returns:
features – A [batch, frame, 8] tensor of spectral features for each frame:
centroid: The mean of the spectrum.
spread: The stdev of the spectrum.
skew: The spectral balance.
kurtosis: The spectral tailedness.
entropy: The peakiness of the spectrum.
flatness: The ratio of geometric mean to arithmetic mean.
crest: The ratio of spectral maximum to arithmetic mean.
flux: The average delta-squared between one spectral value and its successor.
- Return type:
torch.Tensor
Example
>>> import torch
>>> from speechbrain.processing.vocal_features import compute_spectral_features
>>> audio = torch.rand(1, 16000)
>>> window_size = 800
>>> frames = audio.unfold(-1, window_size, 200)
>>> frames.shape
torch.Size([1, 77, 800])
>>> hann = torch.hann_window(window_size).view(1, 1, -1)
>>> windowed_frames = frames * hann
>>> spectrum = torch.abs(torch.fft.rfft(windowed_frames))
>>> spectral_features = compute_spectral_features(spectrum)
>>> spectral_features.shape
torch.Size([1, 77, 8])
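Four of the descriptors listed above can be sketched directly from their definitions. The following plain-Python example treats a single magnitude spectrum as a distribution over bin indices. The function name `spectral_summary`, the use of bin indices rather than Hz, and the exact placement of `eps` are assumptions for illustration; the library's normalization may differ.

```python
import math


def spectral_summary(spectrum, eps=1e-10):
    """Summarize one magnitude spectrum given as a list of non-negative floats."""
    n_bins = len(spectrum)
    total = sum(spectrum) + eps
    # Centroid and spread: mean and standard deviation of the bin index,
    # weighted by the spectrum.
    centroid = sum(k * s for k, s in enumerate(spectrum)) / total
    variance = sum((k - centroid) ** 2 * s for k, s in enumerate(spectrum)) / total
    arithmetic_mean = sum(spectrum) / n_bins
    geometric_mean = math.exp(sum(math.log(s + eps) for s in spectrum) / n_bins)
    return {
        "centroid": centroid,
        "spread": math.sqrt(variance),
        "flatness": geometric_mean / (arithmetic_mean + eps),
        "crest": max(spectrum) / (arithmetic_mean + eps),
    }


flat = spectral_summary([1.0] * 8)                                   # noise-like
peaky = spectral_summary([0.0, 0.0, 0.0, 8.0, 0.0, 0.0, 0.0, 0.0])  # tone-like
print(flat["flatness"], peaky["flatness"])
print(flat["crest"], peaky["crest"])
```

The flat spectrum has flatness and crest both close to 1, while the single-peak spectrum has flatness near 0 and a crest close to the number of bins.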
- speechbrain.processing.vocal_features.spec_norm(value, spectrum, eps=1e-10)
Normalize the given value by the spectrum.
- speechbrain.processing.vocal_features.compute_gne(audio, sample_rate=16000, bandwidth=1000, fshift=300, frame_len=0.03, hop_len=0.01)
An algorithm for GNE computation from the original paper:
"Glottal-to-Noise Excitation Ratio - a New Measure for Describing Pathological Voices" by D. Michaelis, T. Gramss, and H. W. Strube.
This algorithm divides the signal into frequency bands, and compares the correlation between the bands. High correlation indicates a relatively low amount of noise in the signal, whereas lower correlation could be a sign of pathology in the vocal signal.
Godino-Llorente et al., in "The Effectiveness of the Glottal to Noise Excitation Ratio for the Screening of Voice Disorders", explore suitable values for the bandwidth and frequency shift parameters. The defaults here are the ones recommended in that work.
- Parameters:
audio (torch.Tensor) – The batched audio signal to use for GNE computation, [batch, sample].
sample_rate (float) – The sample rate of the input audio.
bandwidth (float) – The width of the frequency bands used for computing correlation.
fshift (float) – The shift between frequency bands used for computing correlation.
frame_len (float) – Length of each analysis frame, in seconds.
hop_len (float) – Length of time between the start of each analysis frame, in seconds.
- Returns:
gne – The glottal-to-noise-excitation ratio for each frame of the audio signal.
- Return type:
torch.Tensor
Example
>>> import torch
>>> from speechbrain.processing.vocal_features import compute_gne
>>> sample_rate = 16000
>>> audio = torch.rand(1, sample_rate)  # 1s of audio
>>> gne = compute_gne(audio, sample_rate=sample_rate)
>>> gne.shape
torch.Size([1, 98])
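To make the roles of bandwidth and fshift concrete, the sketch below lays out band centers and lists the pairs of bands that would be compared, following the rule in the Michaelis et al. paper that only bands whose center frequencies differ by at least half the bandwidth are correlated. The band placement (first band starting at 0 Hz, last band ending at or below Nyquist), the 10 kHz analysis rate (borrowed from the default of compute_hilbert_envelopes), and the helper name `gne_band_pairs` are all assumptions for illustration; the library may lay out its bands differently.

```python
from itertools import combinations


def gne_band_pairs(sample_rate=10000, bandwidth=1000, fshift=300):
    """Return hypothetical band centers (Hz) and the pairs eligible for correlation."""
    nyquist = sample_rate / 2
    centers = []
    center = bandwidth / 2  # assumed: lowest band spans 0 Hz to `bandwidth`
    while center + bandwidth / 2 <= nyquist:
        centers.append(center)
        center += fshift
    # Bands that overlap too much are trivially correlated, so skip pairs
    # whose centers are closer than half the bandwidth.
    pairs = [
        (low, high)
        for low, high in combinations(centers, 2)
        if high - low >= bandwidth / 2
    ]
    return centers, pairs


centers, pairs = gne_band_pairs()
print(len(centers), len(pairs))  # 14 78
```

In the paper, the GNE value for a frame is the largest cross-correlation found across such pairs of band envelopes, so a smaller fshift yields more bands and more pairs to search.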
- speechbrain.processing.vocal_features.inverse_filter(frames, lpc_order=13)
Perform inverse filtering on frames to estimate glottal pulse train.
Uses autocorrelation method and Linear Predictive Coding (LPC). Algorithm from https://course.ece.cmu.edu/~ece792/handouts/RS_Chap_LPC.pdf
- Parameters:
frames (torch.Tensor) – The audio frames to filter using inverse filter.
lpc_order (int) – The size of the filter to compute and use on the frames.
- Returns:
filtered_frames – The frames after the inverse filter is applied.
- Return type:
torch.Tensor
Example
>>> import torch
>>> from speechbrain.processing.vocal_features import inverse_filter
>>> audio = torch.rand(1, 10000)
>>> frames = audio.unfold(-1, 300, 100)
>>> frames.shape
torch.Size([1, 98, 300])
>>> filtered_frames = inverse_filter(frames)
>>> filtered_frames.shape
torch.Size([1, 98, 300])
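The autocorrelation method can be sketched for a single frame in plain Python: compute the frame's autocorrelation, run the Levinson-Durbin recursion to obtain prediction-error coefficients, then apply them as an FIR filter so that what remains approximates the excitation. The function names are hypothetical and this is not the library's implementation, which operates on batched tensors. The test signal is a synthetic one-pole decay chosen so that the result is easy to check.

```python
def lpc_coefficients(frame, order):
    """Levinson-Durbin recursion; returns [1, a1, ..., a_order]."""
    n = len(frame)
    # Autocorrelation of the frame for lags 0..order.
    r = [
        sum(frame[i] * frame[i + lag] for i in range(n - lag))
        for lag in range(order + 1)
    ]
    coeffs = [1.0] + [0.0] * order
    error = r[0]
    for i in range(1, order + 1):
        acc = r[i] + sum(coeffs[j] * r[i - j] for j in range(1, i))
        reflection = -acc / error
        previous = coeffs[:]
        for j in range(1, i):
            coeffs[j] = previous[j] + reflection * previous[i - j]
        coeffs[i] = reflection
        error *= 1.0 - reflection * reflection
    return coeffs


def apply_inverse_filter(frame, coeffs):
    """FIR-filter the frame with the prediction-error coefficients."""
    order = len(coeffs) - 1
    return [
        sum(coeffs[k] * frame[t - k] for k in range(min(order, t) + 1))
        for t in range(len(frame))
    ]


# Synthetic frame: impulse response of the one-pole filter 1 / (1 - 0.9 z^-1).
frame = [0.9 ** t for t in range(200)]
coeffs = lpc_coefficients(frame, order=2)
residual = apply_inverse_filter(frame, coeffs)
print([round(c, 3) for c in coeffs])
print(round(residual[0], 3), max(abs(v) for v in residual[1:]))
```

For this signal the recovered coefficients are approximately [1, -0.9, 0], and the residual collapses to a single pulse at the first sample, which is the excitation that generated the frame.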
- speechbrain.processing.vocal_features.compute_hilbert_envelopes(frames, center_freq, bandwidth=1000, sample_rate=10000)
Compute the Hilbert envelope of the signal in a specific frequency band using FFT.
- Parameters:
- Returns:
envelopes – The computed envelopes.
- Return type:
torch.Tensor
Example
>>> import torch
>>> from speechbrain.processing.vocal_features import compute_hilbert_envelopes
>>> audio = torch.rand(1, 10000)
>>> frames = audio.unfold(-1, 300, 100)
>>> frames.shape
torch.Size([1, 98, 300])
>>> envelope = compute_hilbert_envelopes(frames, 1000)
>>> envelope.shape
torch.Size([1, 98, 300])
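The idea of a band-limited Hilbert envelope computed in the frequency domain can be sketched as follows: transform the frame, keep only the positive-frequency bins inside the band (doubled, which yields the analytic signal), transform back, and take the magnitude. This plain-Python version uses a naive DFT so that it has no dependencies. The function name and the rectangular band shape are assumptions; the library may weight the band differently.

```python
import cmath
import math


def band_hilbert_envelope(frame, center_freq, bandwidth=1000, sample_rate=10000):
    """Envelope of `frame` restricted to center_freq +/- bandwidth / 2."""
    n = len(frame)
    # Forward DFT (naive O(n^2), acceptable for short frames).
    spectrum = [
        sum(x * cmath.exp(-2j * math.pi * k * t / n) for t, x in enumerate(frame))
        for k in range(n)
    ]
    low = center_freq - bandwidth / 2
    high = center_freq + bandwidth / 2
    analytic = [0j] * n
    for k in range(1, (n + 1) // 2):  # strictly positive frequencies only
        if low <= k * sample_rate / n <= high:
            analytic[k] = 2 * spectrum[k]
    # Inverse DFT, then magnitude of the complex analytic signal.
    return [
        abs(sum(z * cmath.exp(2j * math.pi * k * t / n) for k, z in enumerate(analytic)) / n)
        for t in range(n)
    ]


# A 1 kHz cosine sampled at 10 kHz has a constant envelope of 1.
tone = [math.cos(2 * math.pi * 1000 * t / 10000) for t in range(100)]
in_band = band_hilbert_envelope(tone, center_freq=1000)
out_of_band = band_hilbert_envelope(tone, center_freq=3000)
print(round(in_band[0], 6), round(out_of_band[0], 6))
```

The band around 1000 Hz recovers the tone's constant envelope, while the band around 3000 Hz contains almost no energy.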
- speechbrain.processing.vocal_features.compute_cross_correlation(frames_a, frames_b, width=None)
Computes the correlation between two sets of frames.
- Parameters:
frames_a (torch.Tensor)
frames_b (torch.Tensor) – The two sets of frames to compare using cross-correlation, shape [batch, frame, sample].
width (int, default is None) – The number of samples before and after 0 lag. A width of 3 returns 7 results. If None, 0 lag is put at the front and the result has half the original length plus one, a convenient default for autocorrelation as there are no repeated values.
- Returns:
The cross-correlation between frames_a and frames_b.
Example
>>> import torch
>>> from speechbrain.processing.vocal_features import compute_cross_correlation
>>> frames = torch.arange(10).view(1, 1, -1).float()
>>> compute_cross_correlation(frames, frames, width=3)
tensor([[[0.6316, 0.7193, 0.8421, 1.0000, 0.8421, 0.7193, 0.6316]]])
>>> compute_cross_correlation(frames, frames)
tensor([[[1.0000, 0.8421, 0.7193, 0.6316, 0.5789, 0.5614]]])
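The numbers in the example above can be reproduced with a direct, dependency-free sketch of circular cross-correlation normalized by the energy of both frames. The function name is hypothetical and the library's internals may differ (for instance, by using an FFT), but this normalization and lag layout match the documented output for this input.

```python
import math


def circular_cross_correlation(a, b, width=None):
    """Normalized circular cross-correlation of two equal-length frames."""
    n = len(a)
    norm = math.sqrt(sum(x * x for x in a) * sum(y * y for y in b))

    def score(lag):
        # Indices wrap around the end of the frame (circular shift).
        return sum(a[i] * b[(i + lag) % n] for i in range(n)) / norm

    if width is None:
        lags = range(n // 2 + 1)         # 0 lag first, no mirrored values
    else:
        lags = range(-width, width + 1)  # symmetric window around 0 lag
    return [score(lag) for lag in lags]


frame = [float(i) for i in range(10)]
print([round(v, 4) for v in circular_cross_correlation(frame, frame, width=3)])
# [0.6316, 0.7193, 0.8421, 1.0, 0.8421, 0.7193, 0.6316]
print([round(v, 4) for v in circular_cross_correlation(frame, frame)])
# [1.0, 0.8421, 0.7193, 0.6316, 0.5789, 0.5614]
```

Because a circular autocorrelation is symmetric around 0 lag, the width=None form returns only the first half plus one value, which is why a 10-sample frame yields 6 results.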