Source code for speechbrain.processing.NMF

"""Non-negative matrix factorization

Authors
 * Cem Subakan
"""
import torch
from speechbrain.processing.features import spectral_magnitude
import speechbrain.processing.features as spf


def spectral_phase(stft, power=2, log=False):
    """Returns the phase of a complex spectrogram.

    Arguments
    ---------
    stft : torch.Tensor
        A tensor, output from the stft function. The last dimension must
        hold the (real, imaginary) pair; any number of leading dimensions
        is accepted (e.g. [BS, nfft//2 + 1, T, 2] or [nfft//2 + 1, T, 2]).
    power : int
        Not used by this function.
    log : bool
        Not used by this function.

    Returns
    -------
    phase : torch.Tensor
        The phase angle in radians, computed as atan2(imag, real). Its shape
        is the shape of ``stft`` without the last dimension.

    Example
    -------
    >>> BS, nfft, T = 10, 20, 300
    >>> X_stft = torch.randn(BS, nfft//2 + 1, T, 2)
    >>> phase_mix = spectral_phase(X_stft)
    """
    # Index the trailing (real, imag) axis with an ellipsis so the function
    # does not depend on the number of leading dimensions.
    real_part = stft[..., 0]
    imag_part = stft[..., 1]
    return torch.atan2(imag_part, real_part)
def NMF_separate_spectra(Whats, Xmix, n_iter=1000):
    """This function separates the mixture signals, given NMF template matrices.

    The templates are kept fixed; only the activations are estimated, using
    multiplicative updates on the column-normalized mixture spectra.

    Arguments
    ---------
    Whats : list
        This list contains the list [W1, W2], where W1 W2 are respectively
        the NMF template matrices that correspond to source1 and source2.
        W1, W2 are of size [nfft/2 + 1, K], where nfft is the fft size for
        STFT, and K is the number of vectors (templates) in W.
    Xmix : torch.tensor
        This is the magnitude spectra for the mixtures.
        The size is [BS x T x nfft//2 + 1] where,
        BS = batch size, nfft = fft size, T = number of time steps in the
        spectra.
    n_iter : int
        Number of multiplicative update iterations used to estimate the
        activations (default: 1000).

    Outputs
    -------
    X1hat : Separated spectrum for source1
        Size = [BS x (nfft/2 +1) x T] where,
        BS = batch size, nfft = fft size, T = number of time steps in the
        spectra.
    X2hat : Separated Spectrum for source2
        The size definitions are the same as above.

    Example
    --------
    >>> BS, nfft, T = 4, 20, 400
    >>> K1, K2 = 10, 10
    >>> W1hat = torch.randn(nfft//2 + 1, K1)
    >>> W2hat = torch.randn(nfft//2 + 1, K2)
    >>> Whats = [W1hat, W2hat]
    >>> Xmix = torch.randn(BS, T, nfft//2 + 1)
    >>> X1hat, X2hat = NMF_separate_spectra(Whats, Xmix)
    """
    W1, W2 = Whats

    nmixtures, n_frames, n_freq = Xmix.shape

    # Stack every frame of every mixture as one column -> [n_freq, BS * T].
    # Columns are ordered mixture-major (all frames of mixture 0, then
    # mixture 1, ...), which is the order the final split relies on.
    # The input is already [BS, T, n_freq], so a plain reshape keeps each
    # spectral frame intact; permuting first would mix time and frequency.
    Xmix = Xmix.reshape(-1, n_freq).t()
    n = Xmix.shape[1]
    eps = 1e-20  # guards the divisions below against zero denominators

    # Normalize input: every column sums to (approximately) one.
    g = Xmix.sum(dim=0) + eps
    z = Xmix / g

    # initialize
    w = torch.cat([W1, W2], dim=1)
    K = w.size(1)
    K1 = W1.size(1)

    # Create the activations on the same device / dtype as the data so that
    # the function also works for inputs that do not live on the CPU.
    h = 0.1 * torch.rand(K, n, device=z.device, dtype=z.dtype)
    h /= torch.sum(h, dim=0) + eps

    for _ in range(n_iter):
        v = z / (torch.matmul(w, h) + eps)
        nh = h * torch.matmul(w.t(), v)
        h = nh / (torch.sum(nh, dim=0) + eps)

    # Undo the input normalization.
    h *= g

    separated = []
    for template_cols in (slice(None, K1), slice(K1, None)):
        # [n_freq, BS * T] reconstruction using one source's templates only.
        Xhat = torch.matmul(w[:, template_cols], h[template_cols, :])
        # Cut the column axis back into per-mixture chunks -> [BS, n_freq, T].
        Xhat = torch.stack(torch.split(Xhat, n_frames, dim=1), dim=0)
        separated.append(Xhat)

    Xhat1, Xhat2 = separated
    return Xhat1, Xhat2
def reconstruct_results(
    X1hat, X2hat, X_stft, sample_rate, win_length, hop_length,
):
    """This function reconstructs the separated spectra into waveforms.

    For every mixture, each source estimate is turned into a ratio mask
    (estimate divided by the sum of both estimates), the mask is applied to
    the mixture spectrum (mixture magnitude combined with mixture phase), and
    the result is inverted with an ISTFT. Each waveform is finally divided by
    ten times its own standard deviation.

    Arguments
    ---------
    X1hat : torch.tensor
        The separated spectrum for source 1 of size [BS, nfft/2 + 1, T],
        where, BS = batch size, nfft = fft size, T = length of the spectra.
    X2hat : torch.tensor
        The separated spectrum for source 2 of size [BS, nfft/2 + 1, T].
        The size definitions are the same as X1hat.
    X_stft : torch.tensor
        The STFT of the mixtures, from which both the magnitude and the
        phase are taken. The size is [BS x nfft//2 + 1 x T x 2] where,
        BS = batch size, nfft = fft size, T = number of time steps in the
        spectra. The last dimension is to represent complex numbers.
    sample_rate : int
        The sampling rate (in Hz) in which we would like to save the results.
    win_length : int
        The length of stft windows (in ms).
    hop_length : int
        The length with which we shift the STFT windows (in ms).

    Returns
    -------
    x1hats : list
        List of waveforms for source 1.
    x2hats : list
        List of waveforms for source 2.

    Example
    -------
    >>> BS, nfft, T = 10, 512, 16000
    >>> sample_rate, win_length, hop_length = 16000, 25, 10
    >>> X1hat = torch.randn(BS, nfft//2 + 1, T)
    >>> X2hat = torch.randn(BS, nfft//2 + 1, T)
    >>> X_stft = torch.randn(BS, nfft//2 + 1, T, 2)
    >>> x1hats, x2hats = reconstruct_results(X1hat, X2hat, X_stft, sample_rate, win_length, hop_length)
    """
    istft = spf.ISTFT(
        sample_rate=sample_rate, win_length=win_length, hop_length=hop_length
    )

    # Mixture phase and magnitude; both sources reuse them.
    mix_phase = spectral_phase(X_stft)
    mix_mag = spectral_magnitude(X_stft, power=2)

    eps = 1e-25  # keeps the mask denominator away from exact zero
    div_factor = 10
    x1hats, x2hats = [], []

    for idx, est1 in enumerate(X1hat):
        est2 = X2hat[idx]

        # Quantities shared by both sources for this mixture.
        denom = eps + est1 + est2
        mag = mix_mag[idx].unsqueeze(-1)
        phase = mix_phase[idx]
        # (cos, sin) pair on a trailing axis, i.e. the unit-modulus
        # real/imaginary representation of the mixture phase.
        phasor = torch.stack([torch.cos(phase), torch.sin(phase)], dim=-1)

        waveforms = []
        for estimate in (est1, est2):
            masked_stft = (estimate / denom).unsqueeze(-1) * mag * phasor
            # Add a batch axis and swap the frequency and time axes
            # ([F, T, 2] -> [1, T, F, 2]) before handing the tensor to the
            # ISTFT.
            masked_stft = masked_stft.unsqueeze(0).permute(0, 2, 1, 3)
            waveforms.append(istft(masked_stft))

        shat1, shat2 = waveforms
        x1hats.append(shat1 / (div_factor * shat1.std()))
        x2hats.append(shat2 / (div_factor * shat2.std()))

    return x1hats, x2hats