# Source code for speechbrain.inference.encoders

""" Specifies the inference interfaces for speech and audio encoders.

Authors:
 * Aku Rouhe 2021
 * Peter Plantinga 2021
 * Loren Lugosch 2020
 * Mirco Ravanelli 2020
 * Titouan Parcollet 2021
 * Abdel Heba 2021
 * Andreas Nautsch 2022, 2023
 * Pooneh Mousavi 2023
 * Sylvain de Langen 2023
 * Adel Moumen 2023
 * Pradnya Kandarkar 2023
"""

import torch

from speechbrain.inference.interfaces import Pretrained


class WaveformEncoder(Pretrained):
    """A ready-to-use WaveformEncoder model.

    It can be used to wrap different embedding models such as SSL ones
    (wav2vec2) or speaker ones (Xvector) etc. Two functions are available:
    encode_batch and encode_file. They can be used to obtain the embeddings
    from a batch of audio tensors or directly from an audio file,
    respectively.

    The given YAML must contain the fields specified in the *_NEEDED[] lists.

    Arguments
    ---------
    See ``Pretrained``

    Example
    -------
    >>> from speechbrain.inference.encoders import WaveformEncoder
    >>> tmpdir = getfixture("tmpdir")
    >>> ssl_model = WaveformEncoder.from_hparams(
    ...     source="speechbrain/ssl-wav2vec2-base-libri",
    ...     savedir=tmpdir,
    ... ) # doctest: +SKIP
    >>> ssl_model.encode_file("samples/audio_samples/example_fr.wav") # doctest: +SKIP
    """

    MODULES_NEEDED = ["encoder"]

    def encode_file(self, path, **kwargs):
        """Encode the given audiofile into a sequence of embeddings.

        Arguments
        ---------
        path : str
            Path to audio file which to encode.
        **kwargs : dict
            Arguments forwarded to ``load_audio``

        Returns
        -------
        torch.Tensor
            The audiofile embeddings produced by this system. If the wrapped
            encoder returns a dict, its ``"embeddings"`` entry is returned;
            otherwise the encoder output itself is returned.
        """
        waveform = self.load_audio(path, **kwargs)
        # Fake a batch of size one whose single item spans the full length.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        results = self.encode_batch(batch, rel_length)
        # encode_batch returns whatever ``self.mods.encoder`` returns. A plain
        # tensor cannot be indexed with a string key, so the "embeddings"
        # entry is only unpacked when the encoder produced a dict.
        if isinstance(results, dict):
            return results["embeddings"]
        return results

    def encode_batch(self, wavs, wav_lens):
        """Encodes the input audio into a sequence of hidden states.

        The waveforms should already be in the model's desired format.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.
        wav_lens : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.Tensor
            The encoded batch (the raw output of ``self.mods.encoder``).
        """
        # Cast to float32 and move both inputs to the model's device.
        wavs = wavs.float()
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        encoder_out = self.mods.encoder(wavs, wav_lens)
        return encoder_out

    def forward(self, wavs, wav_lens):
        """Runs the encoder (delegates to ``encode_batch``)."""
        return self.encode_batch(wavs, wav_lens)
class MelSpectrogramEncoder(Pretrained):
    """A MelSpectrogramEncoder class created for the Zero-Shot Multi-Speaker TTS models.

    This is for speaker encoder models using the PyTorch MelSpectrogram
    transform for compatibility with the current TTS pipeline.

    This class can be used to encode a single waveform, a single
    mel-spectrogram, or a batch of mel-spectrograms.

    Arguments
    ---------
    See ``Pretrained``

    Example
    -------
    >>> import torchaudio
    >>> from speechbrain.inference.encoders import MelSpectrogramEncoder
    >>> # Model is downloaded from the speechbrain HuggingFace repo
    >>> tmpdir = getfixture("tmpdir")
    >>> encoder = MelSpectrogramEncoder.from_hparams(
    ...     source="speechbrain/tts-ecapa-voxceleb",
    ...     savedir=tmpdir,
    ... ) # doctest: +SKIP

    >>> # Compute embedding from a waveform (sample_rate must match the sample rate of the encoder)
    >>> signal, fs = torchaudio.load("tests/samples/single-mic/example1.wav") # doctest: +SKIP
    >>> spk_emb = encoder.encode_waveform(signal) # doctest: +SKIP

    >>> # Compute embedding from a mel-spectrogram (sample_rate must match the sample rate of the encoder)
    >>> mel_spec = encoder.mel_spectogram(audio=signal) # doctest: +SKIP
    >>> spk_emb = encoder.encode_mel_spectrogram(mel_spec) # doctest: +SKIP

    >>> # Compute embeddings for a batch of mel-spectrograms
    >>> spk_embs = encoder.encode_mel_spectrogram_batch(mel_spec) # doctest: +SKIP
    """

    MODULES_NEEDED = ["normalizer", "embedding_model"]

    def dynamic_range_compression(self, x, C=1, clip_val=1e-5):
        """Dynamic range compression for audio signals.

        Arguments
        ---------
        x : torch.Tensor
            Input values (e.g. a mel-spectrogram).
        C : float
            Multiplicative factor applied after clamping, before the log.
        clip_val : float
            Lower bound used to clamp ``x`` so the log never sees zero.

        Returns
        -------
        torch.Tensor
            ``log(clamp(x, min=clip_val) * C)``.
        """
        return torch.log(torch.clamp(x, min=clip_val) * C)

    def mel_spectogram(self, audio):
        """Calculates the MelSpectrogram for a raw audio signal.

        All transform settings are read from ``self.hparams``.

        Arguments
        ---------
        audio : torch.Tensor
            Input audio signal.

        Returns
        -------
        mel : torch.Tensor
            Mel-spectrogram.
        """
        # torchaudio is imported lazily so the module can be imported without it.
        from torchaudio import transforms

        audio_to_mel = transforms.MelSpectrogram(
            sample_rate=self.hparams.sample_rate,
            hop_length=self.hparams.hop_length,
            win_length=self.hparams.win_length,
            n_fft=self.hparams.n_fft,
            n_mels=self.hparams.n_mel_channels,
            f_min=self.hparams.mel_fmin,
            f_max=self.hparams.mel_fmax,
            power=self.hparams.power,
            normalized=self.hparams.mel_normalized,
            norm=self.hparams.norm,
            mel_scale=self.hparams.mel_scale,
        ).to(audio.device)

        mel = audio_to_mel(audio)

        # Optional log compression, controlled by the hparams flag.
        if self.hparams.dynamic_range_compression:
            mel = self.dynamic_range_compression(mel)

        return mel

    def encode_waveform(self, wav):
        """Encodes a single waveform.

        Arguments
        ---------
        wav : torch.Tensor
            Waveform.

        Returns
        -------
        encoder_out : torch.Tensor
            Speaker embedding for the input waveform.
        """
        # Moves tensor to the appropriate device
        wav = wav.to(self.device)

        # Computes mel-spectrogram
        mel_spec = self.mel_spectogram(audio=wav)

        # Calls encode_mel_spectrogram to compute the speaker embedding
        return self.encode_mel_spectrogram(mel_spec)

    def encode_mel_spectrogram(self, mel_spec):
        """Encodes a single mel-spectrogram.

        A 2-D input is treated as one un-batched spectrogram and gets a
        leading batch axis; any other input is passed through as a batch.

        Arguments
        ---------
        mel_spec : torch.Tensor
            Mel-spectrogram.

        Returns
        -------
        encoder_out : torch.Tensor
            Speaker embedding for the input mel-spectrogram.
        """
        # Fakes a batch
        batch = mel_spec
        if len(mel_spec.shape) == 2:
            batch = mel_spec.unsqueeze(0)
        # One full relative length per batch item, so an already-batched
        # input with more than one item gets matching lengths.
        rel_length = torch.ones(batch.shape[0])

        # Calls encode_mel_spectrogram_batch to compute speaker embeddings
        return self.encode_mel_spectrogram_batch(batch, rel_length)

    def encode_mel_spectrogram_batch(self, mel_specs, lens=None):
        """Encodes a batch of mel-spectrograms.

        Arguments
        ---------
        mel_specs : torch.Tensor
            Mel-spectrograms.
        lens : torch.Tensor
            Relative lengths of the mel-spectrograms. Defaults to all ones
            (every item full length).

        Returns
        -------
        encoder_out : torch.Tensor
            Speaker embedding for the input mel-spectrogram batch.
        """
        # Assigns full length if lens is not assigned
        if lens is None:
            lens = torch.ones(mel_specs.shape[0], device=self.device)

        # Moves the tensors to the appropriate device
        mel_specs, lens = mel_specs.to(self.device), lens.to(self.device)

        # Swap axes 1 and 2 before normalization; presumably
        # (batch, n_mels, time) -> (batch, time, n_mels) -- TODO confirm
        # against the normalizer/embedding_model expectations.
        mel_specs = torch.transpose(mel_specs, 1, 2)

        feats = self.hparams.normalizer(mel_specs, lens)
        encoder_out = self.hparams.embedding_model(feats)

        return encoder_out

    def forward(self, mel_specs, lens=None):
        """Runs the encoder on a batch of mel-spectrograms.

        Delegates to ``encode_mel_spectrogram_batch``.
        """
        return self.encode_mel_spectrogram_batch(mel_specs, lens)

    # Keep the previously defined (name-mangled) private method name
    # available, now pointing at the working implementation.
    __forward = forward