Source code for speechbrain.inference.separation

""" Specifies the inference interfaces for speech separation modules.

Authors:
 * Aku Rouhe 2021
 * Peter Plantinga 2021
 * Loren Lugosch 2020
 * Mirco Ravanelli 2020
 * Titouan Parcollet 2021
 * Abdel Heba 2021
 * Andreas Nautsch 2022, 2023
 * Pooneh Mousavi 2023
 * Sylvain de Langen 2023
 * Adel Moumen 2023
 * Pradnya Kandarkar 2023
"""
import torch
import torchaudio
import torch.nn.functional as F
from speechbrain.utils.fetching import fetch
from speechbrain.utils.data_utils import split_path
from speechbrain.inference.interfaces import Pretrained


class SepformerSeparation(Pretrained):
    """A "ready-to-use" speech separation model.

    Uses Sepformer architecture.

    The wrapped modules listed in ``MODULES_NEEDED`` are applied in the order
    encoder -> masknet -> decoder: the encoder output is multiplied by one
    estimated mask per speaker and each masked representation is decoded back
    to a waveform.

    Example
    -------
    >>> tmpdir = getfixture("tmpdir")
    >>> model = SepformerSeparation.from_hparams(
    ...     source="speechbrain/sepformer-wsj02mix",
    ...     savedir=tmpdir)
    >>> mix = torch.randn(1, 400)
    >>> est_sources = model.separate_batch(mix)
    >>> print(est_sources.shape)
    torch.Size([1, 400, 2])
    """

    MODULES_NEEDED = ["encoder", "masknet", "decoder"]

    def separate_batch(self, mix):
        """Run source separation on batch of audio.

        Arguments
        ---------
        mix : torch.Tensor
            The mixture of sources. Dimension 1 is treated as time
            (``[batch, time]`` in the class example).

        Returns
        -------
        tensor
            Separated sources, with the estimated sources stacked on the
            last dimension (``[batch, time, num_spks]`` in the class
            example). The time dimension always matches that of ``mix``.
        """
        # Separation
        mix = mix.to(self.device)
        mix_w = self.mods.encoder(mix)
        est_mask = self.mods.masknet(mix_w)
        # One copy of the encoded mixture per speaker, so that each copy can
        # be multiplied by its own mask.
        mix_w = torch.stack([mix_w] * self.hparams.num_spks)
        sep_h = mix_w * est_mask

        # Decoding: each masked representation is decoded separately and the
        # results are concatenated on a new trailing "speaker" dimension.
        est_source = torch.cat(
            [
                self.mods.decoder(sep_h[i]).unsqueeze(-1)
                for i in range(self.hparams.num_spks)
            ],
            dim=-1,
        )

        # T changed after conv1d in encoder, fix it here:
        # zero-pad at the end when the estimate is too short, otherwise trim.
        T_origin = mix.size(1)
        T_est = est_source.size(1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin, :]

        return est_source

    def separate_file(self, path, savedir="audio_cache"):
        """Separate sources from file.

        Arguments
        ---------
        path : str
            Path to file which has a mixture of sources. It can be a local
            path, a web url, or a huggingface repo.
        savedir : path
            Path where to store the wav signals (when downloaded from the web).

        Returns
        -------
        tensor
            Separated sources, each one scaled so that its largest absolute
            value along the time dimension is 1. A source that is entirely
            zero is returned as zeros (it is not rescaled).
        """
        source, fl = split_path(path)
        path = fetch(fl, source=source, savedir=savedir)

        batch, fs_file = torchaudio.load(path)
        batch = batch.to(self.device)
        fs_model = self.hparams.sample_rate

        # resample the data if needed
        if fs_file != fs_model:
            print(
                "Resampling the audio from {} Hz to {} Hz".format(
                    fs_file, fs_model
                )
            )
            tf = torchaudio.transforms.Resample(
                orig_freq=fs_file, new_freq=fs_model
            ).to(self.device)
            # NOTE(review): dim 0 is averaged (downmix) only on this
            # resampling path. A multi-row tensor loaded at the model's
            # sample rate is passed on unchanged, so each row is separated
            # as an independent batch item -- confirm this is intended.
            batch = batch.mean(dim=0, keepdim=True)
            batch = tf(batch)

        est_sources = self.separate_batch(batch)

        # Peak-normalise every estimated source along time. A silent source
        # has a peak of 0 and dividing by it would fill the output with NaN
        # (0 / 0), so in that case divide by 1 and keep the zeros.
        peak = est_sources.abs().max(dim=1, keepdim=True)[0]
        peak = torch.where(peak > 0, peak, torch.ones_like(peak))
        est_sources = est_sources / peak
        return est_sources

    def forward(self, mix):
        """Runs separation on the input mix"""
        return self.separate_batch(mix)