Source code for speechbrain.inference.enhancement

""" Specifies the inference interfaces for speech enhancement modules.

Authors:
 * Aku Rouhe 2021
 * Peter Plantinga 2021
 * Loren Lugosch 2020
 * Mirco Ravanelli 2020
 * Titouan Parcollet 2021
 * Abdel Heba 2021
 * Andreas Nautsch 2022, 2023
 * Pooneh Mousavi 2023
 * Sylvain de Langen 2023
 * Adel Moumen 2023
 * Pradnya Kandarkar 2023
"""
import torch
import torchaudio
from speechbrain.inference.interfaces import Pretrained
from speechbrain.utils.callchains import lengths_arg_exists


class SpectralMaskEnhancement(Pretrained):
    """A ready-to-use model for speech enhancement.

    The noisy waveform is converted to ``log1p`` spectral magnitudes, the
    ``enhance_model`` module predicts a mask that is multiplied with those
    features, and ``resynth`` turns the masked features back into a waveform.

    Arguments
    ---------
    See ``Pretrained``.

    Example
    -------
    >>> import torch
    >>> from speechbrain.inference.enhancement import SpectralMaskEnhancement
    >>> # Model is downloaded from the speechbrain HuggingFace repo
    >>> tmpdir = getfixture("tmpdir")
    >>> enhancer = SpectralMaskEnhancement.from_hparams(
    ...     source="speechbrain/metricgan-plus-voicebank",
    ...     savedir=tmpdir,
    ... )
    >>> enhanced = enhancer.enhance_file(
    ...     "speechbrain/metricgan-plus-voicebank/example.wav"
    ... )
    """

    HPARAMS_NEEDED = ["compute_stft", "spectral_magnitude", "resynth"]
    MODULES_NEEDED = ["enhance_model"]

    def compute_features(self, wavs):
        """Compute the log spectral magnitude features for masking.

        Arguments
        ---------
        wavs : torch.Tensor
            A batch of waveforms to convert to log spectral mags.

        Returns
        -------
        torch.Tensor
            ``log1p`` of the spectral magnitude of the STFT of ``wavs``.
        """
        feats = self.hparams.compute_stft(wavs)
        feats = self.hparams.spectral_magnitude(feats)
        return torch.log1p(feats)

    def enhance_batch(self, noisy, lengths=None):
        """Enhance a batch of noisy waveforms.

        Arguments
        ---------
        noisy : torch.Tensor
            A batch of waveforms to perform enhancement on.
        lengths : torch.Tensor
            The lengths of the waveforms if the enhancement model handles them.
            When ``None``, the model is called without a ``lengths`` argument.

        Returns
        -------
        torch.Tensor
            A batch of enhanced waveforms of the same shape as input.
        """
        noisy = noisy.to(self.device)
        noisy_features = self.compute_features(noisy)

        # Perform masking-based enhancement, multiplying output with input.
        if lengths is not None:
            mask = self.mods.enhance_model(noisy_features, lengths=lengths)
        else:
            mask = self.mods.enhance_model(noisy_features)
        enhanced = torch.mul(mask, noisy_features)

        # Undo the log1p compression before resynthesis; the noisy waveform
        # is handed to ``resynth`` alongside the enhanced magnitudes.
        return self.hparams.resynth(torch.expm1(enhanced), noisy)

    def enhance_file(self, filename, output_filename=None, **kwargs):
        """Enhance a wav file.

        Arguments
        ---------
        filename : str
            Location on disk to load file for enhancement.
        output_filename : str
            If provided, writes enhanced data to this file.
        **kwargs : dict
            Extra keyword arguments forwarded to ``load_audio``.

        Returns
        -------
        torch.Tensor
            The enhanced waveform with the fake batch dimension squeezed out.
        """
        noisy = self.load_audio(filename, **kwargs)
        noisy = noisy.to(self.device)

        # Fake a batch:
        batch = noisy.unsqueeze(0)
        if lengths_arg_exists(self.enhance_batch):
            # A single utterance: its relative length is passed as 1.0.
            enhanced = self.enhance_batch(batch, lengths=torch.tensor([1.0]))
        else:
            enhanced = self.enhance_batch(batch)

        if output_filename is not None:
            # ``enhanced`` may live on an accelerator (``enhance_batch`` moves
            # data to ``self.device``); hand a detached CPU tensor to the
            # writer. The value returned below is left untouched.
            torchaudio.save(
                uri=output_filename,
                src=enhanced.detach().cpu(),
                sample_rate=self.hparams.compute_stft.sample_rate,
            )

        return enhanced.squeeze(0)
class WaveformEnhancement(Pretrained):
    """A ready-to-use model for speech enhancement.

    The ``enhance_model`` module is applied directly to the noisy waveform and
    the first element of its output is returned as the enhanced waveform.

    Arguments
    ---------
    See ``Pretrained``.

    Example
    -------
    >>> from speechbrain.inference.enhancement import WaveformEnhancement
    >>> # Model is downloaded from the speechbrain HuggingFace repo
    >>> tmpdir = getfixture("tmpdir")
    >>> enhancer = WaveformEnhancement.from_hparams(
    ...     source="speechbrain/mtl-mimic-voicebank",
    ...     savedir=tmpdir,
    ... )
    >>> enhanced = enhancer.enhance_file(
    ...     "speechbrain/mtl-mimic-voicebank/example.wav"
    ... )
    """

    MODULES_NEEDED = ["enhance_model"]

    def enhance_batch(self, noisy, lengths=None):
        """Enhance a batch of noisy waveforms.

        Arguments
        ---------
        noisy : torch.Tensor
            A batch of waveforms to perform enhancement on.
        lengths : torch.Tensor
            The lengths of the waveforms if the enhancement model handles them.
            Accepted but not forwarded to the enhancement model here.

        Returns
        -------
        torch.Tensor
            A batch of enhanced waveforms of the same shape as input.
        """
        noisy = noisy.to(self.device)
        # The model returns a pair; only the first element (the enhanced
        # waveform) is used.
        enhanced_wav, _ = self.mods.enhance_model(noisy)
        return enhanced_wav

    def enhance_file(self, filename, output_filename=None, **kwargs):
        """Enhance a wav file.

        Arguments
        ---------
        filename : str
            Location on disk to load file for enhancement.
        output_filename : str
            If provided, writes enhanced data to this file.
        **kwargs : dict
            Extra keyword arguments forwarded to ``load_audio``.

        Returns
        -------
        torch.Tensor
            The enhanced waveform with the fake batch dimension squeezed out.
        """
        noisy = self.load_audio(filename, **kwargs)

        # Fake a batch:
        batch = noisy.unsqueeze(0)
        enhanced = self.enhance_batch(batch)

        if output_filename is not None:
            # ``enhance_batch`` moves the input to ``self.device``, so the
            # result may live on an accelerator; hand a detached CPU tensor to
            # the writer. The value returned below is left untouched.
            torchaudio.save(
                uri=output_filename,
                src=enhanced.detach().cpu(),
                sample_rate=self.audio_normalizer.sample_rate,
            )

        return enhanced.squeeze(0)

    def forward(self, noisy, lengths=None):
        """Runs enhancement on the noisy input"""
        return self.enhance_batch(noisy, lengths)