speechbrain.lobes.models.huggingface_whisper module

This lobe enables the integration of HuggingFace pretrained Whisper models.

The HuggingFace Transformers library must be installed: https://huggingface.co/transformers/installation.html

Authors
  • Adel Moumen 2022

  • Titouan Parcollet 2022

  • Luca Della Libera 2022

Summary

Classes:

HuggingFaceWhisper

This lobe enables the integration of HuggingFace pretrained Whisper models (Whisper paper: https://cdn.openai.com/papers/whisper.pdf). The HuggingFace Transformers library must be installed: https://huggingface.co/transformers/installation.html

Reference

class speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper(source, save_path, sampling_rate=16000, encoder_only=False, freeze=False, freeze_encoder=False, output_attentions=True, output_all_hiddens=False)

Bases: Module

This lobe enables the integration of HuggingFace pretrained Whisper models. Whisper paper: https://cdn.openai.com/papers/whisper.pdf

The HuggingFace Transformers library must be installed: https://huggingface.co/transformers/installation.html

Some parts of the code are also adapted from the official OpenAI repository: https://github.com/openai/whisper

The model can be fine-tuned. It is downloaded automatically from HuggingFace, or loaded from a local path.

Parameters:
  • source (str) – HuggingFace hub name, e.g. "openai/whisper-tiny".

  • save_path (str) – Path (directory) where the downloaded model is saved.

  • sampling_rate (int (default: 16000)) – Sampling rate of the audio signal.

  • encoder_only (bool (default: False)) – If True, the forward function outputs the hidden states from the last transformer layer of the encoder. If False, one step of the decoder is performed and its output is returned.

  • freeze (bool (default: False)) – If True, the model is frozen.

  • freeze_encoder (bool (default: False)) – If True, the encoder is frozen.

  • output_attentions (bool (default: True)) – If True, the forward function outputs the attention weights.

  • output_all_hiddens (bool (default: False)) – If True, the forward function outputs the hidden states from all transformer layers of the encoder. For example, whisper-base has 6 transformer layers, so the output has shape (7, B, T, C): the output of the CNN front-end is prepended to the outputs of the 6 layers. If False, the forward function outputs the hidden states from the last transformer layer of the encoder only.
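To make the shape convention of output_all_hiddens concrete, here is a minimal pure-Python sketch of the arithmetic described above. The helper encoder_hidden_shape is hypothetical (it is not part of SpeechBrain), and the value T = 1500 assumes a 30 s padded input; the hidden size of 512 is that of whisper-base.

```python
def encoder_hidden_shape(num_layers, batch, frames, dim, output_all_hiddens=False):
    """Expected shape of the encoder output (hypothetical helper, not SpeechBrain API).

    With output_all_hiddens=True, the CNN front-end output is stacked in front
    of the outputs of the `num_layers` transformer layers: (num_layers + 1, B, T, C).
    Otherwise only the last transformer layer is returned: (B, T, C).
    """
    if output_all_hiddens:
        return (num_layers + 1, batch, frames, dim)
    return (batch, frames, dim)


# whisper-base: 6 encoder layers, hidden size 512.
# B=2 and T=1500 (assumes a 30 s padded input) are illustrative values.
print(encoder_hidden_shape(6, 2, 1500, 512, output_all_hiddens=True))  # (7, 2, 1500, 512)
print(encoder_hidden_shape(6, 2, 1500, 512))                           # (2, 1500, 512)
```

The leading dimension is one larger than the number of transformer layers because the front-end output counts as an extra "layer" in the stack.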

Example

>>> import torch
>>> from speechbrain.lobes.models.huggingface_whisper import HuggingFaceWhisper
>>> model_hub = "openai/whisper-tiny"
>>> save_path = "savedir"
>>> sampling_rate = 16000
>>> model = HuggingFaceWhisper(model_hub, save_path, sampling_rate)
>>> tokens = torch.tensor([[1, 1]]) * model.model.config.decoder_start_token_id
>>> inputs = torch.randn([1, 93680])
>>> outputs = model(inputs, tokens)

forward(wav, decoder_input_ids=None)

Compute the Mel features and perform one step of Whisper (encoder-decoder).

Parameters:
  • wav (torch.Tensor (signal)) – A batch of audio signals to transform to features.

  • decoder_input_ids (torch.Tensor) –

    Required in order to use the decoder.

    A batch of decoder input tokens. The first tokens dictate the behavior of the decoder: the sequence must start with the bos_token, followed by the language token, the task token, and finally the timestamp token.

    Please refer to the Whisper paper for more details, or see the seq2seq.py file in SpeechBrain for how to generate the tokens with greedy search and/or beam search.
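As an illustration of the required prefix, the sketch below assembles Whisper's special tokens in the order given above. The token strings (<|startoftranscript|>, <|en|>, <|transcribe|>, <|translate|>, <|notimestamps|>) are those used by OpenAI's Whisper tokenizer; the function build_prompt_tokens is hypothetical. It covers only the common case of decoding without timestamps, where the last prefix token is <|notimestamps|>. Converting the strings to ids would be done with the model's tokenizer (for example, convert_tokens_to_ids on a HuggingFace WhisperTokenizer), which is left out here.

```python
def build_prompt_tokens(language="en", task="transcribe"):
    """Assemble the Whisper decoder prefix as special-token strings (hypothetical helper).

    Order: bos (start of transcript), language, task, then the timestamp token.
    Only the no-timestamps case is handled here.
    """
    if task not in ("transcribe", "translate"):
        raise ValueError(f"unknown task: {task!r}")
    return [
        "<|startoftranscript|>",  # bos token
        f"<|{language}|>",        # language token, e.g. <|en|>, <|fr|>
        f"<|{task}|>",            # task token
        "<|notimestamps|>",       # timestamp token: disable timestamp prediction
    ]


print(build_prompt_tokens("fr", "translate"))
# ['<|startoftranscript|>', '<|fr|>', '<|translate|>', '<|notimestamps|>']
```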

forward_encoder(wav)

Perform one step of the Whisper encoder with Mel FBANKs as input.

Parameters:
  • wav (torch.Tensor (FBANKs)) – A batch of Mel FBANKs (as computed for HuggingFace) to transform into features.
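The encoder works on fixed-size Mel inputs. The sketch below reproduces the padding and framing arithmetic using the constants from OpenAI's whisper/audio.py (16 kHz audio, 30 s chunks, hop length 160); whether a particular SpeechBrain version pads the input in exactly this way is an assumption here. The helpers pad_or_trim and encoder_frames are illustrative, and encoder_frames applies the standard Conv1d output-length formula to Whisper's second front-end convolution (kernel 3, padding 1, stride 2).

```python
# Constants from OpenAI Whisper's audio front-end (whisper/audio.py).
SAMPLE_RATE = 16000                     # Hz
CHUNK_LENGTH = 30                       # seconds of audio per input window
HOP_LENGTH = 160                        # samples between STFT frames (10 ms)
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples
N_FRAMES = N_SAMPLES // HOP_LENGTH      # 3000 Mel frames


def pad_or_trim(samples, length=N_SAMPLES):
    """Zero-pad or truncate a list of samples to exactly `length` entries."""
    if len(samples) >= length:
        return samples[:length]
    return samples + [0.0] * (length - len(samples))


def encoder_frames(num_mel_frames, kernel_size=3, padding=1, stride=2):
    """Output length of the stride-2 front-end convolution (Conv1d formula)."""
    return (num_mel_frames + 2 * padding - kernel_size) // stride + 1


wav = [0.0] * 93680  # about 5.9 s of audio, the length used in the class example
padded = pad_or_trim(wav)
print(len(padded), N_FRAMES, encoder_frames(N_FRAMES))  # 480000 3000 1500
```

Under these assumptions, every input (short or long) reaches the transformer layers as 1500 frames, which is the T dimension of the encoder hidden states.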

forward_decoder(audio_features, decoder_input_ids)

Perform one step of the Whisper decoder.

Parameters:
  • audio_features (torch.Tensor) – A batch of audio features (Mel + Whisper encoding).

  • decoder_input_ids (torch.Tensor) –

    A batch of decoder input tokens. The first tokens dictate the behavior of the decoder: the sequence must start with the bos_token, followed by the language token, the task token, and finally the timestamp token.

    Please refer to the Whisper paper for more details, or see the seq2seq.py file in SpeechBrain for how to generate the tokens with greedy search and/or beam search.
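Because forward_decoder performs a single step, producing a transcript means calling it repeatedly and appending the chosen token to decoder_input_ids. The pure-Python sketch below shows that loop as greedy search. Both functions are hypothetical: toy_decoder_step stands in for the decoder and returns next-token scores, and the loop is not SpeechBrain's searcher implementation.

```python
def greedy_decode(decoder_step, prompt_ids, eos_id, max_new_tokens=10):
    """Greedy search: run one decoder step at a time and append the argmax token.

    `decoder_step(ids)` must return a list of scores over the vocabulary for
    the next token, given all tokens so far (hypothetical interface).
    """
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        scores = decoder_step(ids)
        next_id = max(range(len(scores)), key=scores.__getitem__)  # argmax
        ids.append(next_id)
        if next_id == eos_id:
            break
    return ids


def toy_decoder_step(ids, vocab_size=6, eos_id=5):
    """Toy decoder: prefers tokens 1, 2, 3 in turn, then end-of-text."""
    generated = len(ids) - 4  # tokens produced after a 4-token prompt
    target = generated + 1 if generated < 3 else eos_id
    return [1.0 if tok == target else 0.0 for tok in range(vocab_size)]


print(greedy_decode(toy_decoder_step, prompt_ids=[0, 0, 0, 0], eos_id=5))
# [0, 0, 0, 0, 1, 2, 3, 5]
```

The sketch re-feeds the whole prefix at every step, matching the one-step interface described above; beam search keeps several such prefixes and their cumulative scores instead of a single argmax path.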

training: bool