speechbrain.lobes.models.huggingface_transformers.whisper module

This lobe enables the integration of the HuggingFace pretrained Whisper model.

The HuggingFace Transformers library needs to be installed: https://huggingface.co/transformers/installation.html

Authors
  • Adel Moumen 2022, 2024

  • Titouan Parcollet 2022

  • Luca Della Libera 2022

  • Ha Nguyen 2023

Summary

Classes:

Whisper – This lobe enables the integration of the HuggingFace pretrained Whisper model.

Reference

class speechbrain.lobes.models.huggingface_transformers.whisper.Whisper(source, save_path, sampling_rate=16000, encoder_only=False, freeze=False, freeze_encoder=False, output_attentions=False, output_all_hiddens=False, language=None, task='transcribe')[source]

Bases: HFTransformersInterface

This lobe enables the integration of the HuggingFace pretrained Whisper model.

Source paper: https://cdn.openai.com/papers/whisper.pdf

The HuggingFace Transformers library needs to be installed: https://huggingface.co/transformers/installation.html

Some parts of the code are also adapted from the official OpenAI repository: https://github.com/openai/whisper

The model can be fine-tuned. It will automatically download the model from HuggingFace or use a local path.

Parameters:
  • source (str) – HuggingFace hub name, e.g. "openai/whisper-tiny".

  • save_path (str) – Path (dir) of the downloaded model.

  • sampling_rate (int (default: 16000)) – Sampling rate of the audio signal.

  • encoder_only (bool (default: False)) – If True, the forward function outputs the hidden states from the last transformer layer of the encoder. If False, one step of the decoder is performed and returned.

  • freeze (bool (default: False)) – If True, the model is frozen.

  • freeze_encoder (bool (default: False)) – If True, the encoder is frozen.

  • output_attentions (bool (default: False)) – If True, the forward function outputs the attention weights. By default it is False, because flash attention requires output_attentions=False. When output_attentions is True, a from-scratch attention implementation is used instead, which can slow down the code and increase VRAM usage.

  • output_all_hiddens (bool (default: False)) – If True, the forward function outputs the hidden states from all transformer layers of the encoder. For example, whisper-base has 6 transformer layers and the output is of shape (7, B, T, C), where the CNN output is prepended to the beginning. If False, the forward function outputs the hidden states only from the last transformer layer of the encoder.

  • language (str (default: None)) – Language token to use for the decoder.

  • task (str (default: "transcribe")) – Task token to use for the decoder. It must be either "transcribe" or "translate".

Example

>>> import torch
>>> from speechbrain.lobes.models.huggingface_transformers.whisper import Whisper
>>> model_hub = "openai/whisper-tiny"
>>> save_path = "savedir"
>>> sampling_rate = 16000
>>> model = Whisper(model_hub, save_path, sampling_rate)
>>> tokens = torch.tensor([[1, 1]]) * model.model.config.decoder_start_token_id
>>> inputs = torch.randn([1, 93680])
>>> outputs = model(inputs, tokens)
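When only features are needed, the encoder can be used on its own. A minimal sketch reusing the variables above (the stacked shape follows the output_all_hiddens description; the exact layer count depends on the checkpoint):

>>> encoder = Whisper(model_hub, save_path, sampling_rate, encoder_only=True, output_all_hiddens=True)
>>> feats = encoder(inputs)  # (n_layers + 1, B, T, C); the CNN output comes first
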
freeze_model(model)[source]

Freezes parameters of a model.

Parameters:

model (from AutoModel.from_config) – Valid HuggingFace transformers model object.
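For example, to fine-tune the encoder while keeping the decoder fixed, one could freeze only the decoder submodule. A hedged sketch (get_decoder() is the generic HuggingFace accessor, not part of this lobe):

>>> model.freeze_model(model.model.get_decoder())  # decoder parameters no longer receive gradients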

forward(wav, decoder_input_ids=None)[source]

Perform the Mel transformation and one step of Whisper (encoder-decoder).

Parameters:
  • wav (torch.Tensor) – A batch of audio signals to transform to features.

  • decoder_input_ids (torch.Tensor) – Input tokens for the decoder. This can be language, task, etc. Please refer to the Whisper paper for more details, or see the seq2seq.py file in SpeechBrain for how to generate the tokens with Greedy Search and/or Beam Search.

Returns:

  • out_encoder (torch.Tensor) – The output of the encoder model.

  • decoder_logits (torch.Tensor) – The output of the decoder model.

  • decoder_attn (torch.Tensor) – The attention values of the decoder model.
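A minimal sketch of one full step, reusing model, inputs, and tokens from the class example (by analogy with forward_decoder, decoder_attn is presumably None unless output_attentions=True):

>>> out_encoder, decoder_logits, decoder_attn = model(inputs, tokens)
>>> next_token = decoder_logits[:, -1].argmax(dim=-1)  # greedy pick for the next decoding step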

log_mel_spectrogram(audio, padding: int = 0)[source]

Compute the Mel spectrogram of a batch of input waveforms.

Reference: adapted from https://github.com/openai/whisper/blob/eff383b27b783e280c089475852ba83f20f64998/whisper/audio.py#L92

Parameters:
  • audio (torch.Tensor) – A batch of audio waveforms sampled at 16 kHz.

  • padding (int) – The number of samples to append to the end of the audio tensor.

Returns:

log_spec – A tensor that contains the batch of Mel spectrograms.

Return type:

torch.Tensor
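A minimal sketch (the number of Mel bins depends on the checkpoint, e.g. 80 for whisper-tiny; the dimension order shown is an assumption):

>>> audio = torch.randn([1, 16000])  # one second of 16 kHz audio
>>> log_spec = model.log_mel_spectrogram(audio)  # (batch, n_mels, frames)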

pad_or_trim(array, length: int = 480000, axis=-1)[source]

Pad or trim the Mel spectrograms as expected by the encoder.

Reference: adapted from https://github.com/openai/whisper/blob/eff383b27b783e280c089475852ba83f20f64998/whisper/audio.py#L52

Parameters:
  • array (torch.Tensor) – A tensor that contains the batch of Mel spectrograms.

  • length (int) – The input tensor will be coerced to this number of samples.

  • axis (int) – The axis along which to pad.

Returns:

array – The padded tensor.

Return type:

torch.Tensor
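The default length of 480000 corresponds to Whisper's fixed 30-second input at 16 kHz. A minimal sketch on a raw waveform (applying it to waveforms before the Mel transform is an assumption here):

>>> short = torch.randn([1, 16000])
>>> fixed = model.pad_or_trim(short)  # zero-padded along the last axis up to 480000 elements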

forward_encoder(mel)[source]

Takes an input Mel spectrogram and returns its corresponding encoder states. Returns the last hidden state of the encoder, or all hidden states if output_all_hiddens is True.

Parameters:

mel (torch.Tensor) – A batch of Mel spectrograms to transform into features.

Returns:

The last hidden state of the encoder or all hidden states if output_all_hiddens is True.

Return type:

torch.Tensor
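A hedged sketch chaining the front end and the encoder, reusing inputs from the class example:

>>> mel = model.log_mel_spectrogram(model.pad_or_trim(inputs))
>>> encoder_states = model.forward_encoder(mel)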

forward_decoder(encoder_states, decoder_input_ids, use_cache=True, past_key_values=None)[source]

Perform one step of the Whisper decoder.

Parameters:
  • encoder_states (torch.Tensor) – A batch of encoder states (Mel features passed through the Whisper encoder).

  • decoder_input_ids (torch.Tensor) – Input tokens for the decoder. This can be language, task, etc. Please refer to the Whisper paper for more details, or see the seq2seq.py file in SpeechBrain for how to generate the tokens with Greedy Search and/or Beam Search.

  • use_cache (bool) – If True, keys and values are returned as output for KV caching.

  • past_key_values (torch.Tensor (default: None)) – If not None, the past key values are used for KV caching to avoid recomputing the attention weights.

Returns:

  • logits (torch.Tensor) – The logits of the decoder.

  • attn (torch.Tensor | None) – If output_attentions is True, the attention weights are returned. Otherwise, None is returned.

  • past_key_values (torch.Tensor) – The past key values of the decoder.
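A hedged sketch of two greedy steps with KV caching, reusing encoder_states from the sketch above. Feeding only the newest token on cached steps follows the usual HuggingFace convention and is an assumption here; a real system would use the search classes in seq2seq.py:

>>> ids = torch.tensor([[model.model.config.decoder_start_token_id]])
>>> logits, attn, past_kv = model.forward_decoder(encoder_states, ids)
>>> next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
>>> logits, attn, past_kv = model.forward_decoder(encoder_states, next_id, past_key_values=past_kv)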

property all_language_tokens

Returns the list of token ids for all language tokens.

property all_language_codes

Returns the list of language codes corresponding to the language tokens.

property non_speech_tokens

Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech annotations and to prevent sampling texts that are not actually spoken in the audio, e.g.

  • β™ͺβ™ͺβ™ͺ

  • ( SPEAKING FOREIGN LANGUAGE )

  • [DAVID] Hey there,

while keeping basic punctuation marks like commas, periods, question marks, exclamation points, etc.

Taken from: openai/whisper GitHub

property transcribe: int

Returns the token id corresponding to the value of the transcribe field.

property translate: int

Returns the token id corresponding to the value of the translate field.

property bos: int

Returns the token id corresponding to the value of the bos field.

property eos: int

Returns the token id corresponding to the value of the eos field.

property bos_lm: int

Returns the token id corresponding to the value of the bos_lm field.

property bos_prev: int

Returns the token id corresponding to the value of the bos_prev field.

property no_timestamps: int

Returns the token id corresponding to the value of the no_timestamps field.

property timestamp_begin: int

Returns the token id corresponding to the value of the timestamp_begin field.

property no_speech: int

Returns the token id corresponding to the value of the no_speech field.

property language_token: int

Returns the token id corresponding to the value of the language field.
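Together, these ids can assemble the standard Whisper decoder prompt (start-of-transcript, language, task, no-timestamps). A hedged sketch for a multilingual checkpoint, assuming bos maps to the start-of-transcript token:

>>> prompt = torch.tensor([[model.bos, model.language_token, model.transcribe, model.no_timestamps]])
>>> _, logits, attn = model(inputs, prompt)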

to_language_token(language)[source]

Returns the token id corresponding to the given language.

Parameters:

language (str) – The language to convert to a token.

Returns:

The token id corresponding to the given language.

Return type:

token

Raises:

KeyError – If the language is not found in the tokenizer.
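A minimal sketch (the code "fr" is illustrative; valid values are those listed by all_language_codes):

>>> fr_token = model.to_language_token("fr")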

set_language_token(language)[source]

Set the language token to the given language.

Parameters:

language (str) – The language to set the token to.

set_task(task)[source]

Set the task token to the given task.

Parameters:

task (str) – The task to set the token to.
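A hedged sketch configuring a multilingual model for French-to-English speech translation, assuming the same language-code format as to_language_token:

>>> model.set_language_token("fr")
>>> model.set_task("translate")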

property is_multilingual

Returns True if the model is multilingual, False otherwise.

property get_suppress_tokens

Returns the list of tokens to suppress.

detect_language(mel)[source]

Detect the language of the given mel spectrogram features.

Parameters:

mel (torch.Tensor) – Mel spectrogram features to detect the language of.

Returns:

  • language_tokens (torch.Tensor of shape (batch_size,)) – ids of the most probable language tokens, which appear after the startoftranscript token.

  • language_probs (List[Dict[str, float]]) – list of dictionaries containing the probability distribution over all languages.

Raises:

ValueError – If the model doesn’t have language tokens.
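A minimal sketch, guarded for multilingual checkpoints (English-only models raise ValueError), reusing inputs from the class example:

>>> if model.is_multilingual:
...     mel = model.log_mel_spectrogram(model.pad_or_trim(inputs))
...     language_tokens, language_probs = model.detect_language(mel)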