# Source code for speechbrain.inference.encoders
""" Specifies the inference interfaces for speech and audio encoders.
Authors:
* Aku Rouhe 2021
* Peter Plantinga 2021
* Loren Lugosch 2020
* Mirco Ravanelli 2020
* Titouan Parcollet 2021
* Abdel Heba 2021
* Andreas Nautsch 2022, 2023
* Pooneh Mousavi 2023
* Sylvain de Langen 2023
* Adel Moumen 2023
* Pradnya Kandarkar 2023
"""
import torch
from speechbrain.inference.interfaces import Pretrained
[docs]
class WaveformEncoder(Pretrained):
    """A ready-to-use waveformEncoder model

    It can be used to wrap different embedding models such as SSL ones (wav2vec2)
    or speaker ones (Xvector) etc. Two functions are available: encode_batch and
    encode_file. They can be used to obtain the embeddings directly from a batch
    of audio tensors or from an audio file respectively.

    The given YAML must contain the fields specified in the *_NEEDED[] lists.

    Arguments
    ---------
    See ``Pretrained``

    Example
    -------
    >>> from speechbrain.inference.encoders import WaveformEncoder
    >>> tmpdir = getfixture("tmpdir")
    >>> ssl_model = WaveformEncoder.from_hparams(
    ...     source="speechbrain/ssl-wav2vec2-base-libri",
    ...     savedir=tmpdir,
    ... ) # doctest: +SKIP
    >>> ssl_model.encode_file("samples/audio_samples/example_fr.wav") # doctest: +SKIP
    """

    MODULES_NEEDED = ["encoder"]

    def encode_file(self, path, **kwargs):
        """Encode the given audiofile into a sequence of embeddings.

        Arguments
        ---------
        path : str
            Path to audio file which to encode.
        **kwargs : dict
            Arguments forwarded to ``load_audio``

        Returns
        -------
        torch.Tensor
            The audiofile embeddings produced by this system.
        """
        waveform = self.load_audio(path, **kwargs)
        # Fake a batch: one item, full relative length.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        results = self.encode_batch(batch, rel_length)
        # ``encode_batch`` returns the raw encoder output. A tensor cannot be
        # indexed with a string (that raises IndexError), so return it as-is.
        # Any other output keeps the original ``["embeddings"]`` lookup.
        if isinstance(results, torch.Tensor):
            return results
        return results["embeddings"]

    def encode_batch(self, wavs, wav_lens):
        """Encodes the input audio into a sequence of hidden states

        The waveforms should already be in the model's desired format.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.Tensor
            The encoded batch
        """
        # Cast to float32 and move both inputs to the model's device before
        # handing them to the wrapped encoder module.
        wavs = wavs.float()
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        encoder_out = self.mods.encoder(wavs, wav_lens)
        return encoder_out

    def forward(self, wavs, wav_lens):
        """Runs the encoder (thin alias of ``encode_batch``)."""
        return self.encode_batch(wavs, wav_lens)
[docs]
class MelSpectrogramEncoder(Pretrained):
    """A MelSpectrogramEncoder class created for the Zero-Shot Multi-Speaker TTS models.

    This is for speaker encoder models using the PyTorch MelSpectrogram transform for compatibility with the
    current TTS pipeline.

    This class can be used to encode a single waveform, a single mel-spectrogram, or a batch of mel-spectrograms.

    Arguments
    ---------
    See ``Pretrained``

    Example
    -------
    >>> import torchaudio
    >>> from speechbrain.inference.encoders import MelSpectrogramEncoder
    >>> # Model is downloaded from the speechbrain HuggingFace repo
    >>> tmpdir = getfixture("tmpdir")
    >>> encoder = MelSpectrogramEncoder.from_hparams(
    ...     source="speechbrain/tts-ecapa-voxceleb",
    ...     savedir=tmpdir,
    ... ) # doctest: +SKIP
    >>> # Compute embedding from a waveform (sample_rate must match the sample rate of the encoder)
    >>> signal, fs = torchaudio.load("tests/samples/single-mic/example1.wav") # doctest: +SKIP
    >>> spk_emb = encoder.encode_waveform(signal) # doctest: +SKIP
    >>> # Compute embedding from a mel-spectrogram (sample_rate must match the sample rate of the encoder)
    >>> mel_spec = encoder.mel_spectogram(audio=signal) # doctest: +SKIP
    >>> spk_emb = encoder.encode_mel_spectrogram(mel_spec) # doctest: +SKIP
    >>> # Compute embeddings for a batch of mel-spectrograms
    >>> spk_embs = encoder.encode_mel_spectrogram_batch(mel_spec) # doctest: +SKIP
    """

    MODULES_NEEDED = ["normalizer", "embedding_model"]

    def dynamic_range_compression(self, x, C=1, clip_val=1e-5):
        """Dynamic range compression for audio signals

        Arguments
        ---------
        x : torch.Tensor
            Input values to compress.
        C : float
            Factor applied to the clamped values before the logarithm.
        clip_val : float
            Lower bound applied to ``x`` before the logarithm.

        Returns
        -------
        torch.Tensor
            ``log(clamp(x, min=clip_val) * C)``
        """
        return torch.log(torch.clamp(x, min=clip_val) * C)

    def mel_spectogram(self, audio):
        """calculates MelSpectrogram for a raw audio signal

        Arguments
        ---------
        audio : torch.tensor
            input audio signal

        Returns
        -------
        mel : torch.Tensor
            Mel-spectrogram
        """
        from torchaudio import transforms

        # Every transform setting is read from the loaded hyperparameters; the
        # transform is placed on the same device as the input audio.
        audio_to_mel = transforms.MelSpectrogram(
            sample_rate=self.hparams.sample_rate,
            hop_length=self.hparams.hop_length,
            win_length=self.hparams.win_length,
            n_fft=self.hparams.n_fft,
            n_mels=self.hparams.n_mel_channels,
            f_min=self.hparams.mel_fmin,
            f_max=self.hparams.mel_fmax,
            power=self.hparams.power,
            normalized=self.hparams.mel_normalized,
            norm=self.hparams.norm,
            mel_scale=self.hparams.mel_scale,
        ).to(audio.device)

        mel = audio_to_mel(audio)

        # Optional log compression, switched on by the hyperparameters.
        if self.hparams.dynamic_range_compression:
            mel = self.dynamic_range_compression(mel)

        return mel

    def encode_waveform(self, wav):
        """
        Encodes a single waveform

        Arguments
        ---------
        wav : torch.Tensor
            waveform

        Returns
        -------
        encoder_out : torch.Tensor
            Speaker embedding for the input waveform
        """
        # Moves tensor to the appropriate device
        wav = wav.to(self.device)

        # Computes mel-spectrogram
        mel_spec = self.mel_spectogram(audio=wav)

        # Calls encode_mel_spectrogram to compute the speaker embedding
        return self.encode_mel_spectrogram(mel_spec)

    def encode_mel_spectrogram(self, mel_spec):
        """
        Encodes a single mel-spectrogram

        Arguments
        ---------
        mel_spec : torch.Tensor
            Mel-spectrogram. A 2-D input gets a leading batch dimension added;
            any other rank is passed through unchanged.

        Returns
        -------
        encoder_out : torch.Tensor
            Speaker embedding for the input mel-spectrogram
        """
        # Fakes a batch
        batch = mel_spec
        if len(mel_spec.shape) == 2:
            batch = mel_spec.unsqueeze(0)
        # One full relative length per item, so an input that already carries
        # a leading dimension larger than 1 still gets matching lengths.
        # For a single item this equals ``torch.tensor([1.0])``.
        rel_length = torch.ones(batch.shape[0])

        # Calls encode_mel_spectrogram_batch to compute speaker embeddings
        results = self.encode_mel_spectrogram_batch(batch, rel_length)

        return results

    def encode_mel_spectrogram_batch(self, mel_specs, lens=None):
        """
        Encodes a batch of mel-spectrograms

        Arguments
        ---------
        mel_specs : torch.Tensor
            Mel-spectrograms
        lens : torch.Tensor
            Relative lengths of the mel-spectrograms. When omitted, every item
            is treated as full length (1.0).

        Returns
        -------
        encoder_out : torch.Tensor
            Speaker embedding for the input mel-spectrogram batch
        """
        # Assigns full length if lens is not assigned
        if lens is None:
            lens = torch.ones(mel_specs.shape[0], device=self.device)

        # Moves the tensors to the appropriate device
        mel_specs, lens = mel_specs.to(self.device), lens.to(self.device)

        # Computes speaker embeddings. Dimensions 1 and 2 are swapped before
        # normalization (presumably [batch, n_mels, time] -> [batch, time,
        # n_mels] -- TODO confirm against the normalizer's expected layout).
        mel_specs = torch.transpose(mel_specs, 1, 2)
        feats = self.hparams.normalizer(mel_specs, lens)
        encoder_out = self.hparams.embedding_model(feats)

        return encoder_out

    def forward(self, mel_specs, lens=None):
        """Runs the encoder on a batch of mel-spectrograms.

        Public counterpart of ``WaveformEncoder.forward``; delegates to
        ``encode_mel_spectrogram_batch``.
        """
        return self.encode_mel_spectrogram_batch(mel_specs, lens)

    def __forward(self, mel_specs, lens):
        """Runs the encoder"""
        # Kept for backward compatibility. The double underscore makes Python
        # mangle this name to ``_MelSpectrogramEncoder__forward``, so it never
        # acted as the module's ``forward``. It previously called
        # ``self.encode_batch``, which this class does not define.
        return self.forward(mel_specs, lens)