Source code for speechbrain.inference.metrics

""" Specifies the inference interfaces for metric estimation modules.

Authors:
 * Aku Rouhe 2021
 * Peter Plantinga 2021
 * Loren Lugosch 2020
 * Mirco Ravanelli 2020
 * Titouan Parcollet 2021
 * Abdel Heba 2021
 * Andreas Nautsch 2022, 2023
 * Pooneh Mousavi 2023
 * Sylvain de Langen 2023
 * Adel Moumen 2023
 * Pradnya Kandarkar 2023
"""
import torch
from speechbrain.inference.interfaces import Pretrained


class SNREstimator(Pretrained):
    """A "ready-to-use" SNR estimator.

    Every separated source estimate is paired with the mixture it was
    extracted from, the pair is passed through the ``encoder`` module,
    pooled with the ``stat_pooling`` hyperparameter, projected by the
    ``encoder_out`` module, and finally rescaled to the
    ``[snrmin, snrmax]`` range given in the hyperparameters.
    """

    MODULES_NEEDED = ["encoder", "encoder_out"]
    HPARAMS_NEEDED = ["stat_pooling", "snrmax", "snrmin"]

    def estimate_batch(self, mix, predictions):
        """Run SI-SNR estimation on the estimated sources, and mixture.

        Arguments
        ---------
        mix : torch.Tensor
            The mixture of sources of shape B X T
        predictions : torch.Tensor
            The estimated sources, of size (B x T x C), where
            B is batch size,
            T is number of time points,
            C is number of sources.

        Returns
        -------
        tensor
            Estimate of SNR, rescaled to the ``[snrmin, snrmax]`` range.
            Rows follow the flattened ``predictions`` order (all sources of
            batch item 0, then all sources of batch item 1, ...); the result
            is passed through ``.squeeze()``, so singleton dimensions are
            dropped.
        """
        # Remember C before the source axis is folded into the batch axis.
        num_sources = predictions.size(-1)

        # (B, T, C) -> (B, C, T) -> (B * C, T): one row per (item, source),
        # in batch-major order.
        predictions = predictions.permute(0, 2, 1)
        predictions = predictions.reshape(-1, predictions.size(-1))

        if hasattr(self.hparams, "separation_norm_type"):
            if self.hparams.separation_norm_type == "max":
                predictions = (
                    predictions / predictions.max(dim=1, keepdim=True)[0]
                )
                mix = mix / mix.max(dim=1, keepdim=True)[0]

            elif self.hparams.separation_norm_type == "stnorm":
                predictions = (
                    predictions - predictions.mean(dim=1, keepdim=True)
                ) / predictions.std(dim=1, keepdim=True)
                mix = (mix - mix.mean(dim=1, keepdim=True)) / mix.std(
                    dim=1, keepdim=True
                )

        # The truncation to min_T only matters when the assert below is
        # stripped (python -O); otherwise both lengths are equal.
        min_T = min(predictions.shape[1], mix.shape[1])
        assert predictions.shape[1] == mix.shape[1], "lengths change"

        # Repeat each mixture once per source, keeping copies adjacent, so
        # that row i of `mix_repeat` is the mixture belonging to row i of
        # the flattened `predictions`. Tiling the whole batch instead
        # (`mix.repeat(C, 1)`) would pair estimates with the wrong mixture
        # as soon as B > 1, and a hard-coded repeat count breaks C != 2.
        mix_repeat = mix.repeat_interleave(num_sources, dim=0)

        # Stack estimate and mixture as two channels: (B * C, 2, min_T).
        inp_cat = torch.cat(
            [
                predictions[:, :min_T].unsqueeze(1),
                mix_repeat[:, :min_T].unsqueeze(1),
            ],
            dim=1,
        )

        enc = self.mods.encoder(inp_cat)
        enc = enc.permute(0, 2, 1)
        enc_stats = self.hparams.stat_pooling(enc)

        # this gets the SI-SNR estimate in the compressed range 0-1
        snrhat = self.mods.encoder_out(enc_stats).squeeze()

        # get the SI-SNR estimate in the true range
        snrhat = self.gettrue_snrrange(snrhat)
        return snrhat

    def forward(self, mix, predictions):
        """Just run the batch estimate"""
        return self.estimate_batch(mix, predictions)

    def gettrue_snrrange(self, inp):
        """Convert from 0-1 range to true snr range

        Arguments
        ---------
        inp : torch.Tensor
            Values in the compressed range, where 0 maps to ``snrmin`` and
            1 maps to ``snrmax``.

        Returns
        -------
        torch.Tensor
            ``inp`` scaled by ``snrmax - snrmin`` and shifted by ``snrmin``.
        """
        rnge = self.hparams.snrmax - self.hparams.snrmin
        inp = inp * rnge
        inp = inp + self.hparams.snrmin
        return inp