speechbrain.lobes.models.huggingface_transformers.whisper module
This lobe enables the integration of HuggingFace pretrained Whisper models.
The Transformers library from HuggingFace needs to be installed: https://huggingface.co/transformers/installation.html
- Authors
Adel Moumen 2022, 2024
Titouan Parcollet 2022
Luca Della Libera 2022
Ha Nguyen 2023
Summary
Classes:
Whisper – This lobe enables the integration of the HuggingFace pretrained Whisper model.
Reference
- class speechbrain.lobes.models.huggingface_transformers.whisper.Whisper(source, save_path, sampling_rate=16000, encoder_only=False, freeze=False, freeze_encoder=False, output_attentions=False, output_all_hiddens=False, language=None, task='transcribe')[source]
Bases:
HFTransformersInterface
This lobe enables the integration of HuggingFace pretrained Whisper model.
- Source paper: Whisper, "Robust Speech Recognition via Large-Scale Weak Supervision" (Radford et al., 2022).
The Transformers library from HuggingFace needs to be installed: https://huggingface.co/transformers/installation.html
Some parts of the code are also adapted from the official OpenAI repository: https://github.com/openai/whisper
The model can be finetuned. It will automatically download the model from HuggingFace or use a local path.
- Parameters:
source (str) – HuggingFace hub name, e.g. "openai/whisper-tiny".
save_path (str) – Path (dir) of the downloaded model.
sampling_rate (int (default: 16000)) – Sampling rate of the audio signal.
encoder_only (bool (default: False)) – If True, the forward function outputs the hidden states from the last transformer layer of the encoder. If False, one step of the decoder is performed and returned.
freeze (bool (default: False)) – If True, the model is frozen.
freeze_encoder (bool (default: False)) – If True, the encoder is frozen.
output_attentions (bool (default: False)) – If True, the forward function outputs the attention weights. By default, it is False because flash attention requires having output_attentions=False. In case output_attentions is True, a from-scratch attention implementation is used, which can make the code slower and can increase the VRAM memory usage.
output_all_hiddens (bool (default: False)) – If True, the forward function outputs the hidden states from all transformer layers of the encoder. For example, whisper-base has 6 transformer layers and the output is of shape (7, B, T, C), where the output of the CNN is added to the beginning. If False, the forward function outputs the hidden states only from the last transformer layer of the encoder.
language (str (default: "en")) – Language token to use for the decoder.
task (str (default: "transcribe")) – Task token to use for the decoder. It must be one of the following: "transcribe", "translate".
Example
>>> import torch
>>> model_hub = "openai/whisper-tiny"
>>> save_path = "savedir"
>>> sampling_rate = 16000
>>> model = Whisper(model_hub, save_path, sampling_rate)
>>> tokens = torch.tensor([[1, 1]]) * model.model.config.decoder_start_token_id
>>> inputs = torch.randn([1, 93680])
>>> outputs = model(inputs, tokens)
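As a complementary, hedged sketch of the encoder_only and output_all_hiddens options (assuming, as above, that the openai/whisper-tiny checkpoint can be downloaded into save_path; the shape in the comment follows the parameter description and is illustrative):

import torch
from speechbrain.lobes.models.huggingface_transformers.whisper import Whisper

# Encoder-only variant: forward() returns encoder hidden states only.
model = Whisper(
    "openai/whisper-tiny",
    "savedir",
    sampling_rate=16000,
    encoder_only=True,
    output_all_hiddens=True,  # stack every encoder layer plus the CNN output
)
wav = torch.randn([1, 16000])  # one second of random "audio" at 16 kHz
hidden = model(wav)  # (n_layers + 1, B, T, C) per the parameter description above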
- freeze_model(model)[source]
Freezes parameters of a model.
- Parameters:
model (from AutoModel.from_config) – Valid HuggingFace transformers model object.
- forward(wav, decoder_input_ids=None)[source]
Perform the Mel transformation and one step of the Whisper model (encoder-decoder).
- Parameters:
wav (torch.Tensor) – A batch of audio signals to transform to features.
decoder_input_ids (torch.Tensor) – Input tokens for the decoder. This can be language, task, etc. Please refer to the Whisper paper for more details, or go to the seq2seq.py file in SpeechBrain to see how to generate the tokens with Greedy Search and/or Beam Search.
- Returns:
out_encoder (torch.Tensor) β The output of the encoder model.
decoder_logits (torch.Tensor) β The output of the decoder model.
decoder_attn (torch.Tensor) β The attention values of the decoder model.
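Sketched end to end, a single decoder step looks like this (a minimal, hypothetical setup mirroring the class example above):

import torch
from speechbrain.lobes.models.huggingface_transformers.whisper import Whisper

model = Whisper("openai/whisper-tiny", "savedir", sampling_rate=16000)
wav = torch.randn([1, 16000])
# Start decoding from the model's start-of-transcript token.
bos = torch.tensor([[model.model.config.decoder_start_token_id]])
out_encoder, decoder_logits, decoder_attn = model(wav, bos)
# decoder_attn is None here, since output_attentions defaults to False.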
- log_mel_spectrogram(audio, padding: int = 0)[source]
Compute the Mel spectrogram of a batch of input waveforms.
Reference: adapted from https://github.com/openai/whisper/blob/eff383b27b783e280c089475852ba83f20f64998/whisper/audio.py#L92
- Parameters:
audio (torch.Tensor) – A batch of audio waveforms in 16 kHz.
padding (int) – The number of samples to append to the end of the audio tensor.
- Returns:
log_spec – A tensor that contains the batch of Mel spectrograms.
- Return type:
torch.Tensor
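A minimal sketch of direct feature extraction, reusing the model and wav objects from the forward() sketch above; the shapes in the comments are illustrative, not guaranteed by the API:

# Whisper consumes fixed 30 s windows: pad/trim first, then take log-Mel features.
padded = model.pad_or_trim(wav)               # (B, 480000) at 16 kHz
log_spec = model.log_mel_spectrogram(padded)
print(log_spec.shape)                         # roughly (B, n_mels, n_frames)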
- pad_or_trim(array, length: int = 480000, axis=-1)[source]
Pad or trim the input array along the given axis to the fixed length expected by the encoder (by default 480000 samples, i.e. 30 seconds at 16 kHz).
Reference: adapted from https://github.com/openai/whisper/blob/eff383b27b783e280c089475852ba83f20f64998/whisper/audio.py#L52
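For example (an illustrative sketch, reusing the model object from above):

import torch

short = torch.randn([1, 16000])   # 1 s: zero-padded up to 480000 samples
long = torch.randn([1, 600000])   # 37.5 s: truncated down to 480000 samples
print(model.pad_or_trim(short).shape)  # torch.Size([1, 480000])
print(model.pad_or_trim(long).shape)   # torch.Size([1, 480000])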
- forward_encoder(mel)[source]
Takes an input Mel spectrogram and returns its corresponding encoder states. Returns the last hidden state of the encoder, or all hidden states if output_all_hiddens is True.
- Parameters:
mel (torch.Tensor) – A batch of Mel spectrogram features.
- Returns:
The last hidden state of the encoder or all hidden states if output_all_hiddens is True.
- Return type:
torch.Tensor
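A minimal sketch (reusing model and wav from the forward() sketch above):

mel = model.log_mel_spectrogram(model.pad_or_trim(wav))
enc = model.forward_encoder(mel)
# enc holds the last encoder layer by default; with output_all_hiddens=True
# all layers (and the CNN output) are stacked along a leading dimension.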
- forward_decoder(encoder_states, decoder_input_ids, use_cache=True, past_key_values=None)[source]
Perform one step of the Whisper decoder.
- Parameters:
encoder_states (torch.Tensor) – A batch of encoder states (Mel features processed by the Whisper encoder).
decoder_input_ids (torch.Tensor) – Input tokens for the decoder. This can be language, task, etc. Please refer to the Whisper paper for more details, or go to the seq2seq.py file in SpeechBrain to see how to generate the tokens with Greedy Search and/or Beam Search.
use_cache (bool) – If True, keys and values are returned as output for KV caching.
past_key_values (torch.Tensor (default: None)) – If not None, the past key values are used for KV caching to avoid recomputing the attention weights.
- Returns:
logits (torch.Tensor) – The logits of the decoder.
attn (torch.Tensor | None) – If output_attentions is True, the attention weights are returned. Otherwise, None is returned.
past_key_values (torch.Tensor) – The past key values of the decoder.
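As a hypothetical greedy-decoding sketch built on this method (the production search strategies live in speechbrain.decoders; model and wav as in the forward() sketch above): after the first step, only the newest token needs to be fed in alongside the cache.

import torch

enc = model.forward_encoder(model.log_mel_spectrogram(model.pad_or_trim(wav)))
tokens = torch.tensor([[model.model.config.decoder_start_token_id]])
past = None
for _ in range(5):  # a few greedy steps, for illustration only
    step_input = tokens if past is None else tokens[:, -1:]
    logits, attn, past = model.forward_decoder(
        enc, step_input, use_cache=True, past_key_values=past
    )
    next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_tok], dim=-1)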
- property all_language_tokens
Returns the list of token ids corresponding to the language tokens.
- property all_language_codes
Returns the list of language codes corresponding to the language tokens.
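For instance, the two properties can be zipped to map codes to ids (a sketch, assuming a multilingual checkpoint such as openai/whisper-tiny):

if model.is_multilingual:
    # The two lists are aligned, so zipping pairs each code with its token id.
    pairs = list(zip(model.all_language_codes, model.all_language_tokens))
    print(pairs[:3])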
- property non_speech_tokens
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.:
♪♪♪
( SPEAKING FOREIGN LANGUAGE )
[DAVID] Hey there,
while keeping basic punctuation like commas, periods, question marks, exclamation points, etc.
Taken from: the openai/whisper GitHub repository.
- property no_timestamps: int
Returns the token id corresponding to the value of the no_timestamps field.
- property timestamp_begin: int
Returns the token id corresponding to the value of the timestamp_begin field.
- property language_token: int
Returns the token id corresponding to the value of the language field.
- set_language_token(language)[source]
Set the language token to the given language.
- Parameters:
language (str) – The language to set the token to.
- set_task(task)[source]
Set the task token to the given task.
- Parameters:
task (str) – The task to set the token to.
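A short sketch of reconfiguring the decoder prompt at runtime (assuming the language argument accepts the same kind of value as the constructor's language parameter, e.g. "fr"):

# Switch to French transcription instead of the values set at construction.
model.set_language_token("fr")   # assumed to accept a constructor-style code
model.set_task("transcribe")     # must be "transcribe" or "translate"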
- property is_multilingual
Returns True if the model is multilingual, False otherwise.
- property get_suppress_tokens
Returns the list of tokens to suppress.
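For illustration, both suppression lists can be inspected directly (a sketch; the concrete container types returned here are not guaranteed, hence the list() casts):

# Token ids that decoding strategies typically mask out before sampling.
print(list(model.get_suppress_tokens)[:10])
print(list(model.non_speech_tokens)[:10])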
- detect_language(mel)[source]
Detect the language of the given mel spectrogram features.
- Parameters:
mel (torch.Tensor) – Mel spectrogram features to detect the language of.
- Returns:
language_tokens (torch.Tensor of shape (batch_size,)) – ids of the most probable language tokens, which appear after the startoftranscript token.
language_probs (List[Dict[str, float]]) – list of dictionaries containing the probability distribution over all languages.
- Raises:
ValueError – If the model doesn't have language tokens.
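A minimal sketch of language identification (reusing model and wav from the forward() sketch above; requires a multilingual checkpoint, otherwise the ValueError above is raised):

mel = model.log_mel_spectrogram(model.pad_or_trim(wav))
language_tokens, language_probs = model.detect_language(mel)
# Pick the most probable language code for the first utterance in the batch.
best = max(language_probs[0], key=language_probs[0].get)
print(best, language_probs[0][best])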