"""Library for computing STOI computation.
Reference: "End-to-End Waveform Utterance Enhancement for Direct Evaluation
Metrics Optimization by Fully Convolutional Neural Networks", TASLP, 2018
Authors:
Szu-Wei, Fu 2020
"""
import numpy as np
import torch
import torchaudio
from speechbrain.utils.torch_audio_backend import check_torchaudio_backend
check_torchaudio_backend()
smallVal = np.finfo("float").eps # To avoid divide by zero
def thirdoct(fs, nfft, num_bands, min_freq):
"""Returns the 1/3 octave band matrix.
Arguments
---------
fs : int
Sampling rate.
nfft : int
FFT size.
num_bands : int
Number of 1/3 octave bands.
min_freq : int
Center frequency of the lowest 1/3 octave band.
Returns
-------
obm : torch.Tensor
The 1/3 octave band matrix of shape (num_bands, nfft // 2 + 1).
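Example
-------
A minimal usage sketch with the 10 kHz / 512-point FFT settings used by
stoi_loss below; the matrix has one row per band and one column per FFT
bin up to Nyquist.
>>> obm = thirdoct(fs=10000, nfft=512, num_bands=15, min_freq=150)
>>> obm.shape
torch.Size([15, 257])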
"""
f = torch.linspace(0, fs, nfft + 1)
f = f[: int(nfft / 2) + 1]
k = torch.from_numpy(np.array(range(num_bands)).astype(float))
cf = torch.pow(2.0 ** (1.0 / 3), k) * min_freq
freq_low = min_freq * torch.pow(2.0, (2 * k - 1) / 6)
freq_high = min_freq * torch.pow(2.0, (2 * k + 1) / 6)
obm = torch.zeros(num_bands, len(f)) # Octave band matrix, shape (num_bands, num_freq_bins)
for i in range(len(cf)):
# Match 1/3 oct band freq with fft frequency bin
f_bin = torch.argmin(torch.square(f - freq_low[i]))
freq_low[i] = f[f_bin]
fl_ii = f_bin
f_bin = torch.argmin(torch.square(f - freq_high[i]))
freq_high[i] = f[f_bin]
fh_ii = f_bin
# Assign to the octave band matrix
obm[i, fl_ii:fh_ii] = 1
return obm
def removeSilentFrames(x, y, dyn_range=40, N=256, K=128):
"""Removes silent frames from the STOI computation.
This function can be used as a loss function for training
with SGD-based updates.
Arguments
---------
x : torch.Tensor
The clean (reference) waveform.
y : torch.Tensor
The degraded (enhanced) waveform.
dyn_range : int
Dynamic range in dB used for the silence mask computation.
N : int
Frame length in samples.
K : int
Hop (step) size in samples.
Returns
-------
A list with two elements: x and y with the silent frames removed.
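Example
-------
A minimal sketch, assuming two 1-D waveforms of the same length; both
outputs keep the same set of non-silent frames, so their shapes match.
>>> x = torch.randn(16000)
>>> y = x + 0.01 * torch.randn(16000)
>>> x_sil, y_sil = removeSilentFrames(x, y)
>>> x_sil.shape == y_sil.shape
True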
"""
w = torch.unsqueeze(torch.from_numpy(np.hanning(N)), 0).to(torch.float)
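# Build 50%-overlapping frames: X1 holds frames starting at multiples of N,
# X2 holds frames shifted by K samples; the two are interleaved into X.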
X1 = x[0 : int(x.shape[0]) // N * N].reshape(int(x.shape[0]) // N, N).T
X2 = (
x[K : (int(x.shape[0]) - K) // N * N + K]
.reshape((int(x.shape[0]) - K) // N, N)
.T
)
X = torch.zeros(N, X1.shape[1] + X2.shape[1])
X[:, 0::2] = X1
X[:, 1::2] = X2
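# Frame energies in dB; frames more than dyn_range dB below the loudest frame are masked out.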
energy = 20 * torch.log10(
torch.sqrt(torch.matmul(w**2, X**2)) / 16.0 + smallVal
)
Max_energy = torch.max(energy)
msk = torch.squeeze((energy - Max_energy + dyn_range > 0))
Y1 = y[0 : int(y.shape[0]) // N * N].reshape(int(y.shape[0]) // N, N).T
Y2 = (
y[K : (int(y.shape[0]) - K) // N * N + K]
.reshape((int(y.shape[0]) - K) // N, N)
.T
)
Y = torch.zeros(N, Y1.shape[1] + Y2.shape[1])
Y[:, 0::2] = Y1
Y[:, 1::2] = Y2
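# Re-apply the window to the retained frames and overlap-add them (hop K) back into waveforms.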
x_sil = w.T.repeat(1, X[:, msk].shape[-1]) * X[:, msk]
y_sil = w.T.repeat(1, X[:, msk].shape[-1]) * Y[:, msk]
x_sil = torch.cat(
(
x_sil[0:K, 0],
(x_sil[0:K, 1:] + x_sil[K:, 0:-1]).T.flatten(),
x_sil[K:N, -1],
),
axis=0,
)
y_sil = torch.cat(
(
y_sil[0:K, 0],
(y_sil[0:K, 1:] + y_sil[K:, 0:-1]).T.flatten(),
y_sil[K:N, -1],
),
axis=0,
)
return [x_sil, y_sil]
def stoi_loss(y_pred_batch, y_true_batch, lens, reduction="mean"):
"""Compute the STOI score and return -1 * that score.
This function can be used as a loss function for training
with SGD-based updates.
Arguments
---------
y_pred_batch : torch.Tensor
The degraded (enhanced) waveforms.
y_true_batch : torch.Tensor
The clean (reference) waveforms.
lens : torch.Tensor
The relative lengths of the waveforms within the batch.
reduction : str
The type of reduction ("mean" or "batch") to use.
Returns
-------
The computed STOI loss.
Example
-------
>>> a = torch.sin(torch.arange(16000, dtype=torch.float32)).unsqueeze(0)
>>> b = a + 0.001
>>> -stoi_loss(b, a, torch.ones(1))
tensor(0.7...)
"""
y_pred_batch = torch.squeeze(y_pred_batch, dim=-1)
y_true_batch = torch.squeeze(y_true_batch, dim=-1)
batch_size = y_pred_batch.shape[0]
fs = 16000 # Sampling rate
N = 30 # length of temporal envelope vectors
J = 15.0 # Number of one-third octave bands
octave_band = thirdoct(fs=10000, nfft=512, num_bands=15, min_freq=150)
c = 5.62341325 # 10^(-Beta/20) with Beta = -15
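# D holds the per-utterance STOI value (correlation averaged over bands and frames)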
D = torch.zeros(batch_size)
resampler = torchaudio.transforms.Resample(fs, 10000).to(
y_pred_batch.device
)
for i in range(0, batch_size): # Run over each utterance in the batch
y_true = y_true_batch[i, 0 : int(lens[i] * y_pred_batch.shape[1])]
y_pred = y_pred_batch[i, 0 : int(lens[i] * y_pred_batch.shape[1])]
y_true, y_pred = resampler(y_true), resampler(y_pred)
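# Discard frames that are more than 40 dB below the loudest frame of the clean signal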
[y_sil_true, y_sil_pred] = removeSilentFrames(y_true, y_pred)
stft_true = torchaudio.transforms.Spectrogram(
n_fft=512, win_length=256, hop_length=128, power=2
)(y_sil_true)
stft_pred = torchaudio.transforms.Spectrogram(
n_fft=512, win_length=256, hop_length=128, power=2
)(y_sil_pred)
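# Group the power spectra into 15 one-third octave bands and take the square root to get band envelopes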
OCT_true = torch.sqrt(torch.matmul(octave_band, stft_true) + 1e-14)
OCT_pred = torch.sqrt(torch.matmul(octave_band, stft_pred) + 1e-14)
M = int(
stft_pred.shape[-1] - (N - 1)
) # number of temporal envelope vectors
X = torch.zeros(15 * M, 30)
Y = torch.zeros(15 * M, 30)
for m in range(0, M): # Run over temporal envelope vectors
X[m * 15 : (m + 1) * 15, :] = OCT_true[:, m : m + N]
Y[m * 15 : (m + 1) * 15, :] = OCT_pred[:, m : m + N]
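# Scale each degraded envelope vector to the norm of the clean one, then clip it at (1 + c) times the clean envelope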
alpha = torch.norm(X, dim=-1, keepdim=True) / (
torch.norm(Y, dim=-1, keepdim=True) + smallVal
)
ay = Y * alpha
y = torch.min(ay, X + X * c)
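# Zero-mean and unit-norm each envelope vector, then accumulate the correlation over all bands and frames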
xn = X - torch.mean(X, dim=-1, keepdim=True)
xn = xn / (torch.norm(xn, dim=-1, keepdim=True) + smallVal)
yn = y - torch.mean(y, dim=-1, keepdim=True)
yn = yn / (torch.norm(yn, dim=-1, keepdim=True) + smallVal)
d = torch.sum(xn * yn)
D[i] = d / (J * M)
if reduction == "mean":
return -D.mean()
return -D