""" Specifies the inference interfaces for speech and audio encoders.
Authors:
* Aku Rouhe 2021
* Peter Plantinga 2021
* Loren Lugosch 2020
* Mirco Ravanelli 2020
* Titouan Parcollet 2021
* Abdel Heba 2021
* Andreas Nautsch 2022, 2023
* Pooneh Mousavi 2023
* Sylvain de Langen 2023
* Adel Moumen 2023
* Pradnya Kandarkar 2023
"""
import torch
from speechbrain.inference.interfaces import Pretrained
class MelSpectrogramEncoder(Pretrained):
    """A MelSpectrogramEncoder class created for the Zero-Shot Multi-Speaker TTS models.

    This is for speaker encoder models using the PyTorch MelSpectrogram transform for compatibility with the
    current TTS pipeline.

    This class can be used to encode a single waveform, a single mel-spectrogram, or a batch of mel-spectrograms.

    Example
    -------
    >>> import torchaudio
    >>> from speechbrain.inference.encoders import MelSpectrogramEncoder
    >>> # Model is downloaded from the speechbrain HuggingFace repo
    >>> tmpdir = getfixture("tmpdir")
    >>> encoder = MelSpectrogramEncoder.from_hparams(
    ...     source="speechbrain/tts-ecapa-voxceleb",
    ...     savedir=tmpdir,
    ... ) # doctest: +SKIP

    >>> # Compute embedding from a waveform (sample_rate must match the sample rate of the encoder)
    >>> signal, fs = torchaudio.load("tests/samples/single-mic/example1.wav") # doctest: +SKIP
    >>> spk_emb = encoder.encode_waveform(signal) # doctest: +SKIP

    >>> # Compute embedding from a mel-spectrogram (sample_rate must match the sample rate of the encoder)
    >>> mel_spec = encoder.mel_spectogram(audio=signal) # doctest: +SKIP
    >>> spk_emb = encoder.encode_mel_spectrogram(mel_spec) # doctest: +SKIP

    >>> # Compute embeddings for a batch of mel-spectrograms
    >>> spk_embs = encoder.encode_mel_spectrogram_batch(mel_spec) # doctest: +SKIP
    """

    # Modules that must be present in the loaded hparams for this interface.
    MODULES_NEEDED = ["normalizer", "embedding_model"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def dynamic_range_compression(self, x, C=1, clip_val=1e-5):
        """Dynamic range compression for audio signals.

        Computes ``log(clamp(x, min=clip_val) * C)``.

        Arguments
        ---------
        x : torch.Tensor
            Input to compress (in this class, the mel-spectrogram).
        C : float
            Multiplicative factor applied after clamping and before the log.
        clip_val : float
            Lower bound applied to ``x`` so that, for positive ``C``,
            the log argument stays strictly positive.

        Returns
        -------
        torch.Tensor
            The log-compressed input, same shape as ``x``.
        """
        return torch.log(torch.clamp(x, min=clip_val) * C)

    def mel_spectogram(self, audio):
        """Calculates the MelSpectrogram for a raw audio signal.

        The transform is configured entirely from ``self.hparams``
        (``sample_rate``, ``hop_length``, ``win_length``, ``n_fft``,
        ``n_mel_channels``, ``mel_fmin``, ``mel_fmax``, ``power``,
        ``mel_normalized``, ``norm``, ``mel_scale``). If
        ``hparams.dynamic_range_compression`` is truthy, the output is
        additionally log-compressed with :meth:`dynamic_range_compression`.

        Arguments
        ---------
        audio : torch.Tensor
            Input audio signal, sampled at ``hparams.sample_rate``.

        Returns
        -------
        mel : torch.Tensor
            Mel-spectrogram
        """
        from torchaudio import transforms

        # NOTE(review): the transform (and its filterbank) is rebuilt on every
        # call; consider caching it if this becomes a hot path.
        audio_to_mel = transforms.MelSpectrogram(
            sample_rate=self.hparams.sample_rate,
            hop_length=self.hparams.hop_length,
            win_length=self.hparams.win_length,
            n_fft=self.hparams.n_fft,
            n_mels=self.hparams.n_mel_channels,
            f_min=self.hparams.mel_fmin,
            f_max=self.hparams.mel_fmax,
            power=self.hparams.power,
            normalized=self.hparams.mel_normalized,
            norm=self.hparams.norm,
            mel_scale=self.hparams.mel_scale,
        ).to(audio.device)

        mel = audio_to_mel(audio)

        if self.hparams.dynamic_range_compression:
            mel = self.dynamic_range_compression(mel)

        return mel

    def encode_waveform(self, wav):
        """Encodes a single waveform.

        The waveform is converted to a mel-spectrogram with
        :meth:`mel_spectogram` and then encoded with
        :meth:`encode_mel_spectrogram`.

        Arguments
        ---------
        wav : torch.Tensor
            Waveform; its sample rate must match ``hparams.sample_rate``.

        Returns
        -------
        encoder_out : torch.Tensor
            Speaker embedding for the input waveform
        """
        # Moves the waveform to the model device before feature extraction.
        wav = wav.to(self.device)
        mel_spec = self.mel_spectogram(audio=wav)
        return self.encode_mel_spectrogram(mel_spec)

    def encode_mel_spectrogram(self, mel_spec):
        """Encodes a single mel-spectrogram.

        A 2-D input (presumably ``(n_mels, frames)``, as produced by
        :meth:`mel_spectogram` for a 1-D waveform) gets a leading batch
        dimension of size 1. Inputs that already have a batch dimension
        are passed through unchanged.

        Arguments
        ---------
        mel_spec : torch.Tensor
            Mel-spectrogram

        Returns
        -------
        encoder_out : torch.Tensor
            Speaker embedding for the input mel-spectrogram
        """
        # Fakes a batch for an unbatched input.
        batch = mel_spec.unsqueeze(0) if mel_spec.dim() == 2 else mel_spec

        # lens=None makes encode_mel_spectrogram_batch assign a full relative
        # length to every item in the batch (on self.device). A hard-coded
        # length-1 tensor here would mismatch a batch of size > 1.
        return self.encode_mel_spectrogram_batch(batch)

    def encode_mel_spectrogram_batch(self, mel_specs, lens=None):
        """Encodes a batch of mel-spectrograms.

        Arguments
        ---------
        mel_specs : torch.Tensor
            Batch of mel-spectrograms. Dimensions 1 and 2 are swapped before
            normalization; presumably the input is ``(batch, n_mels, frames)``
            and the model expects ``(batch, frames, n_mels)`` — TODO confirm.
        lens : torch.Tensor, optional
            Relative lengths of the mel-spectrograms. Defaults to all ones
            (every item is full length).

        Returns
        -------
        encoder_out : torch.Tensor
            Speaker embedding for the input mel-spectrogram batch
        """
        # Assigns full length if lens is not assigned
        if lens is None:
            lens = torch.ones(mel_specs.shape[0], device=self.device)

        # Moves the tensors to the appropriate device
        mel_specs, lens = mel_specs.to(self.device), lens.to(self.device)

        # Computes speaker embeddings
        mel_specs = torch.transpose(mel_specs, 1, 2)
        feats = self.hparams.normalizer(mel_specs, lens)
        encoder_out = self.hparams.embedding_model(feats)

        return encoder_out

    def __forward(self, mel_specs, lens=None):
        """Runs the encoder on a batch of mel-spectrograms.

        Note: the double-underscore name is mangled to
        ``_MelSpectrogramEncoder__forward``; it is not ``nn.Module.forward``.
        """
        # Previously delegated to ``self.encode_batch``, which this class does
        # not define; the batch entry point here is encode_mel_spectrogram_batch.
        return self.encode_mel_spectrogram_batch(mel_specs, lens)