""" Specifies the inference interfaces for Speech Translation (ST) modules.
Authors:
* Aku Rouhe 2021
* Peter Plantinga 2021
* Loren Lugosch 2020
* Mirco Ravanelli 2020
* Titouan Parcollet 2021
* Abdel Heba 2021
* Andreas Nautsch 2022, 2023
* Pooneh Mousavi 2023
* Sylvain de Langen 2023
* Adel Moumen 2023
* Pradnya Kandarkar 2023
"""
import torch
from speechbrain.inference.interfaces import Pretrained
class EncoderDecoderS2UT(Pretrained):
    """A ready-to-use Encoder Decoder for speech-to-unit translation model

    The class can be used to run the entire encoder-decoder S2UT model
    (translate_file()) to translate speech. The given YAML must contain the
    fields specified in the *_NEEDED[] lists.

    Example
    -------
    >>> from speechbrain.inference.ST import EncoderDecoderS2UT
    >>> tmpdir = getfixture("tmpdir")
    >>> s2ut_model = EncoderDecoderS2UT.from_hparams(source="speechbrain/s2st-transformer-fr-en-hubert-l6-k100-cvss", savedir=tmpdir) # doctest: +SKIP
    >>> s2ut_model.translate_file("speechbrain/s2st-transformer-fr-en-hubert-l6-k100-cvss/example-fr.wav") # doctest: +SKIP
    """

    HPARAMS_NEEDED = ["sample_rate"]
    MODULES_NEEDED = ["encoder", "decoder"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache the expected input sampling rate from the loaded hparams.
        self.sample_rate = self.hparams.sample_rate

    def translate_file(self, path):
        """Translates the given audiofile into a sequence of speech units.

        Arguments
        ---------
        path : str
            Path to audio file which to translate.

        Returns
        -------
        int[]
            The audiofile translation produced by this speech-to-unit
            translation model.
        """
        audio = self.load_audio(path)
        audio = audio.to(self.device)
        # Fake a batch: add a leading batch dimension; the single waveform
        # has relative length 1.0 by construction.
        batch = audio.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        predicted_tokens = self.translate_batch(batch, rel_length)
        return predicted_tokens[0]

    def encode_batch(self, wavs, wav_lens):
        """Encodes the input audio into a sequence of hidden states

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = EncoderDecoderS2UT.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.tensor
            The encoded batch
        """
        wavs = wavs.float()
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        encoder_out = self.mods.encoder(wavs, wav_lens)
        return encoder_out

    def translate_batch(self, wavs, wav_lens):
        """Translates the input audio into a sequence of speech units.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = EncoderDecoderS2UT.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        list
            Each waveform in the batch translated.
        tensor
            Each predicted token id.
        """
        # Inference only: no gradients needed for encoding or decoding.
        with torch.no_grad():
            wav_lens = wav_lens.to(self.device)
            encoder_out = self.encode_batch(wavs, wav_lens)
            # Decoder returns (tokens, ...); only the token sequences are kept.
            predicted_tokens, _, _, _ = self.mods.decoder(encoder_out, wav_lens)
        return predicted_tokens

    def forward(self, wavs, wav_lens):
        """Runs full translation"""
        return self.encode_batch(wavs, wav_lens)