"""This lobe enables the integration of huggingface pretrained whisper model.

Transformer from HuggingFace needs to be installed:
https://huggingface.co/transformers/installation.html

Authors
 * Adel Moumen 2022, 2024
 * Titouan Parcollet 2022
 * Luca Della Libera 2022
 * Ha Nguyen 2023
"""

from functools import cached_property

import numpy as np
import torch
import torch.nn as nn

from speechbrain.lobes.models.huggingface_transformers.huggingface import (
    HFTransformersInterface,
)
from speechbrain.utils.logger import get_logger

SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
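# For reference, with these defaults a full 30-second chunk maps to
# N_SAMPLES / HOP_LENGTH = 480000 / 160 = 3000 mel frames. The constant below
# is added purely as illustration (it mirrors openai/whisper's audio.py and is
# not used elsewhere in this file).
N_FRAMES = N_SAMPLES // HOP_LENGTH  # 3000 frames in a 30-second chunk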

logger = get_logger(__name__)


class Whisper(HFTransformersInterface):
    """This lobe enables the integration of the HuggingFace pretrained Whisper model.

    Source paper:
    https://cdn.openai.com/papers/whisper.pdf

    Transformers from HuggingFace needs to be installed:
    https://huggingface.co/transformers/installation.html

    Some parts of the code are also adapted from the official OpenAI repository:
    https://github.com/openai/whisper

    The model can be finetuned. It will automatically download the model from
    HuggingFace or use a local path.

    Arguments
    ---------
    source : str
        HuggingFace hub name: e.g. "openai/whisper-tiny"
    save_path : str
        Path (dir) of the downloaded model.
    sampling_rate : int (default: 16000)
        Sampling rate of the audio signal.
    encoder_only : bool (default: False)
        If True, the forward function outputs the hidden states from the last
        transformer layer of the encoder. If False, one step of the decoder is
        performed and returned.
    freeze : bool (default: False)
        If True, the model is frozen.
    freeze_encoder : bool (default: False)
        If True, the encoder is frozen.
    output_attentions : bool (default: False)
        If ``True``, the forward function outputs the attention weights. By
        default, it is ``False`` because flash attention requires having
        ``output_attentions=False``. In case ``output_attentions`` is ``True``,
        a from-scratch attention implementation is used, which can make the
        code slower and can increase the VRAM memory usage.
    output_all_hiddens : bool (default: False)
        If True, the forward function outputs the hidden states from all
        transformer layers of the encoder. For example, whisper-base has 6
        transformer layers and the output is of shape (7, B, T, C), where the
        output of the CNN is prepended. If False, the forward function outputs
        only the hidden states of the last transformer layer of the encoder.
    language : str (default: None)
        Language token to use for the decoder. If None and the model is
        multilingual, "en" is used.
    task : str (default: "transcribe")
        Task token to use for the decoder. It must be one of the following:
        - "transcribe"
        - "translate"

    Example
    -------
    >>> model_hub = "openai/whisper-tiny"
    >>> save_path = "savedir"
    >>> sampling_rate = 16000
    >>> model = Whisper(model_hub, save_path, sampling_rate)
    >>> tokens = torch.tensor([[1, 1]]) * model.model.config.decoder_start_token_id
    >>> inputs = torch.randn([1, 93680])
    >>> outputs = model(inputs, tokens)
    """

    def __init__(
        self,
        source,
        save_path,
        sampling_rate=16000,
        encoder_only=False,
        freeze=False,
        freeze_encoder=False,
        output_attentions=False,
        output_all_hiddens=False,
        language=None,
        task="transcribe",
    ):
        super().__init__(source=source, save_path=save_path, freeze=freeze)
        self.sampling_rate = sampling_rate
        self.encoder_only = encoder_only
        self.freeze_encoder = freeze_encoder
        self.output_attentions = output_attentions
        self.output_all_hiddens = output_all_hiddens
        self.language = language
        self.task = task

        if encoder_only:
            self.tokenizer = None
            # We first move the decoder to the CPU, then delete it to free memory.
            self.model.decoder.cpu()
            del self.model.decoder
            self.model.decoder = None
            import gc

            gc.collect()
            torch.cuda.empty_cache()
        else:
            # When the model is not multilingual (i.e. all Whisper models
            # ending in .en), the language and task tokens must not be set.
            self.load_tokenizer(
                source,
                bos_token="<|startoftranscript|>",
            )

            if self.is_multilingual:
                language = self.language or "en"
                self.tokenizer.set_prefix_tokens(
                    language=language, task=self.task
                )

        self.load_feature_extractor(
            source, save_path, sampling_rate=sampling_rate
        )

        self._n_fft = self.feature_extractor.n_fft
        self._hop_length = self.feature_extractor.hop_length
        self._n_samples = self.feature_extractor.n_samples
        # The following breaking changes were introduced in transformers>=4.29:
        # 1) mel_filters.shape = (..., feature_extractor.feature_size) instead of (feature_extractor.feature_size, ...)
        # 2) mel_filters.dtype = float64 instead of float32
        # The following code fixes the issue in a backward compatible way
        mel_filters = self.feature_extractor.mel_filters
        if mel_filters.shape[0] != self.feature_extractor.feature_size:
            mel_filters = mel_filters.T
        assert mel_filters.shape[0] == self.feature_extractor.feature_size
        self.register_buffer(
            "_mel_filters", torch.as_tensor(mel_filters, dtype=torch.float32)
        )

        # Freeze the encoder if requested (the full model is handled by `freeze`).
        if not self.freeze and self.freeze_encoder:
            logger.warning(
                "speechbrain.lobes.models.huggingface_transformers.whisper - whisper encoder is frozen."
            )
            for param in self.model.encoder.parameters():
                param.requires_grad = False
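    # Illustrative usage sketch (not part of the original file): with
    # ``encoder_only=True`` the decoder is deleted to save memory, so only
    # encoder representations can be extracted. The save path below is
    # hypothetical; for whisper-tiny the output would have shape (1, 1500, 384).
    #
    #     enc_model = Whisper("openai/whisper-tiny", "savedir", encoder_only=True)
    #     feats = enc_model(torch.randn(1, 16000))  # last encoder hidden states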
    def freeze_model(self, model):
        """
        Freezes parameters of a model.

        Arguments
        ---------
        model : from AutoModel.from_config
            Valid HuggingFace transformers model object.
        """
        logger.warning(
            "speechbrain.lobes.models.huggingface_transformers.whisper - whisper encoder-decoder is frozen."
        )
        # Keep the model in train mode so that dropout and LayerNorm behave as
        # during training; only gradient computation is disabled.
        model.train()
        for param in model.parameters():
            param.requires_grad = False
    def forward(self, wav, decoder_input_ids=None):
        """Perform mel transformation and one step of the whisper (encoder-decoder).

        Arguments
        ---------
        wav : torch.Tensor
            A batch of audio signals to transform to features.
        decoder_input_ids : torch.Tensor
            Input tokens for the decoder. This can be language, task, etc.
            Please refer to the whisper paper for more details or go to the
            seq2seq.py file in SpeechBrain to see how to generate the tokens
            with Greedy Search and/or Beam Search.

        Returns
        -------
        out_encoder : torch.Tensor
            The output of the encoder model.
        decoder_logits : torch.Tensor
            The output of the decoder model.
        decoder_attn : torch.Tensor
            The attention values of the decoder model.
        """

        def _forward():
            """Forward pass of the model"""
            mel = self._get_mel(wav)
            out_encoder = self.forward_encoder(mel)
            if self.encoder_only:
                return out_encoder

            if self.output_all_hiddens:
                decoder_logits, decoder_attn, _ = self.forward_decoder(
                    out_encoder[-1], decoder_input_ids
                )
            else:
                decoder_logits, decoder_attn, _ = self.forward_decoder(
                    out_encoder, decoder_input_ids
                )
            return out_encoder, decoder_logits, decoder_attn

        if self.freeze:
            with torch.no_grad():
                return _forward()
        return _forward()
    def _get_mel(self, wav):
        """
        Compute the mel spectrogram features from the input audio waveform.

        Arguments
        ---------
        wav : torch.Tensor
            A batch of audio signals to compute mel spectrogram features from.

        Returns
        -------
        torch.Tensor
            Mel spectrogram features computed from the input audio waveform.
        """
        mels = self.pad_or_trim(wav)
        mels = self.log_mel_spectrogram(mels)
        return mels
    def log_mel_spectrogram(
        self,
        audio,
        padding: int = 0,
    ):
        """Compute the Mel spectrogram of a batch of input waveforms.

        Reference: adapted from
        https://github.com/openai/whisper/blob/eff383b27b783e280c089475852ba83f20f64998/whisper/audio.py#L92

        Arguments
        ---------
        audio : torch.Tensor
            A batch of audio waveforms sampled at 16 kHz.
        padding : int
            The number of samples to append to the end of the audio tensor.

        Returns
        -------
        log_spec : torch.Tensor
            A tensor that contains the batch of Mel spectrograms.
        """
        if padding > 0:
            audio = nn.functional.pad(audio, (0, padding))
        window = torch.hann_window(self._n_fft, device=audio.device)
        stft = torch.stft(
            audio,
            self._n_fft,
            self._hop_length,
            window=window,
            return_complex=True,
        )
        magnitudes = stft[..., :-1].abs() ** 2

        filters = self._mel_filters
        mel_spec = filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
    def pad_or_trim(self, array, length: int = N_SAMPLES, axis=-1):
        """Pad or trim a batch of waveforms to `length` samples along `axis`,
        as expected by the encoder.

        Reference: adapted from
        https://github.com/openai/whisper/blob/eff383b27b783e280c089475852ba83f20f64998/whisper/audio.py#L52

        Arguments
        ---------
        array : torch.Tensor
            A tensor that contains the batch of waveforms.
        length : int
            Input tensor will be coerced to `length` number of samples.
        axis : int
            The axis along which to pad.

        Returns
        -------
        array : torch.Tensor
            The padded (or trimmed) tensor.
        """
        if array.shape[axis] > length:
            array = array.index_select(
                dim=axis,
                index=torch.arange(length, device=array.device),
            )

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (
                0,
                length - array.shape[axis],
            )
            array = nn.functional.pad(
                array, [pad for sizes in pad_widths[::-1] for pad in sizes]
            )

        return array
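    # Illustrative sketch (not in the original file): the full feature pipeline
    # applied by ``_get_mel`` on one second of dummy audio. For an 80-bin
    # checkpoint such as whisper-tiny, padding to N_SAMPLES yields 3000 frames.
    #
    #     model = Whisper("openai/whisper-tiny", "savedir", encoder_only=True)
    #     wav = torch.randn(1, SAMPLE_RATE)  # one second of dummy audio
    #     mel = model.log_mel_spectrogram(model.pad_or_trim(wav))
    #     assert mel.shape == (1, 80, 3000)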
    def forward_encoder(self, mel):
        """Takes an input mel and returns its corresponding encoder states.
        Returns the last hidden state of the encoder or all hidden states if
        output_all_hiddens is True.

        Arguments
        ---------
        mel : torch.Tensor
            A batch of Mel spectrograms to transform to features.

        Returns
        -------
        torch.Tensor
            The last hidden state of the encoder or all hidden states if
            output_all_hiddens is True.
        """
        encoder_states = self.model.encoder(
            mel, output_hidden_states=self.output_all_hiddens
        )
        if self.output_all_hiddens:
            return torch.stack(encoder_states.hidden_states)
        return encoder_states.last_hidden_state
    def forward_decoder(
        self,
        encoder_states,
        decoder_input_ids,
        use_cache=True,
        past_key_values=None,
    ):
        """Perform one step of the whisper decoder.

        Arguments
        ---------
        encoder_states : torch.Tensor
            A batch of encoder_states features (mel + whisper feature extractor).
        decoder_input_ids : torch.Tensor
            Input tokens for the decoder. This can be language, task, etc.
            Please refer to the whisper paper for more details or go to the
            seq2seq.py file in SpeechBrain to see how to generate the tokens
            with Greedy Search and/or Beam Search.
        use_cache : bool
            If True, keys and values are returned as output for KV caching.
        past_key_values : torch.Tensor (default: None)
            If not None, the past key values are used for KV caching and
            avoid recomputing the attention weights.

        Returns
        -------
        logits : torch.Tensor
            The logits of the decoder.
        attn : torch.Tensor | None
            If ``output_attentions`` is True, the attention weights are
            returned. Otherwise, ``None`` is returned.
        past_key_values : torch.Tensor
            The past key values of the decoder.
        """
        if past_key_values is not None:
            # With a KV cache we do not need to pass all past tokens, only the
            # most recent one (t - 1).
            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)

        output_states = self.model.decoder(
            encoder_hidden_states=encoder_states,
            input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            output_attentions=self.output_attentions,
            use_cache=use_cache,
        )

        if self.output_attentions:
            attn = output_states.attentions[-1]
            attn = attn.view(attn.shape[0] * attn.shape[1], *attn.shape[2:])
        else:
            attn = None

        x = output_states.last_hidden_state
        # The output projection is tied to the token embedding matrix.
        logits = (
            x
            @ torch.transpose(
                self.model.decoder.embed_tokens.weight.to(x.dtype), 0, 1
            )
        ).float()

        return logits, attn, output_states.past_key_values
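    # A minimal greedy-decoding loop built on ``forward_decoder`` and its KV
    # cache, added as an illustrative sketch: the method name and the
    # ``max_decode_steps`` argument are hypothetical, and production decoding
    # in SpeechBrain goes through the searchers in speechbrain/decoders.
    def _greedy_decode_sketch(self, wav, max_decode_steps=64):
        """Illustrative greedy decoding; returns a (batch, length) token tensor."""
        mel = self._get_mel(wav)
        out_encoder = self.forward_encoder(mel)
        if self.output_all_hiddens:
            out_encoder = out_encoder[-1]
        # Start from the decoder start token; on the first step the whole
        # sequence is fed, afterwards only the newest token (KV cache).
        tokens = torch.full(
            (wav.shape[0], 1),
            self.model.config.decoder_start_token_id,
            dtype=torch.long,
            device=wav.device,
        )
        past_key_values = None
        for _ in range(max_decode_steps):
            logits, _, past_key_values = self.forward_decoder(
                out_encoder, tokens, past_key_values=past_key_values
            )
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=-1)
            if (next_token == self.eos).all():
                break
        return tokens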
    @cached_property
    def all_language_tokens(self):
        """Returns the list of tokens corresponding to the language tokens."""
        from transformers.models.whisper.tokenization_whisper import LANGUAGES

        langs = list(LANGUAGES.keys())
        bos_token_id = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.bos_token
        )
        # The language tokens immediately follow the BOS token, in the order
        # given by LANGUAGES.
        result = [bos_token_id + 1 + i for i in range(len(langs))]
        return tuple(result)

    @cached_property
    def all_language_codes(self):
        """Returns the list of language codes corresponding to the language tokens."""
        from transformers.models.whisper.tokenization_whisper import LANGUAGES

        return tuple(LANGUAGES.keys())

    @cached_property
    def non_speech_tokens(self):
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags
        or non-speech annotations, to prevent sampling texts that are not actually
        spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks,
        exclamation points, etc.

        Taken from: openai/whisper GitHub
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {
            self.tokenizer.encode(" -", add_special_tokens=False)[0],
            self.tokenizer.encode(" '", add_special_tokens=False)[0],
        }
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.tokenizer.encode(symbol, add_special_tokens=False),
                self.tokenizer.encode(" " + symbol, add_special_tokens=False),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])
        return tuple(sorted(result))

    @cached_property
    def transcribe(self) -> int:
        """Returns the token id corresponding to the value of the `transcribe` field"""
        return self.tokenizer.convert_tokens_to_ids("<|transcribe|>")

    @cached_property
    def translate(self) -> int:
        """Returns the token id corresponding to the value of the `translate` field"""
        return self.tokenizer.convert_tokens_to_ids("<|translate|>")

    @cached_property
    def bos(self) -> int:
        """Returns the token id corresponding to the value of the `bos` field"""
        return self.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

    @cached_property
    def eos(self) -> int:
        """Returns the token id corresponding to the value of the `eos` field"""
        return self.tokenizer.convert_tokens_to_ids("<|endoftext|>")

    @cached_property
    def bos_lm(self) -> int:
        """Returns the token id corresponding to the value of the `bos_lm` field"""
        return self.tokenizer.convert_tokens_to_ids("<|startoflm|>")

    @cached_property
    def bos_prev(self) -> int:
        """Returns the token id corresponding to the value of the `bos_prev` field"""
        return self.tokenizer.convert_tokens_to_ids("<|startofprev|>")

    @cached_property
    def no_timestamps(self) -> int:
        """Returns the token id corresponding to the value of the `no_timestamps` field"""
        return self.tokenizer.convert_tokens_to_ids("<|notimestamps|>")

    @cached_property
    def timestamp_begin(self) -> int:
        """Returns the token id corresponding to the value of the `timestamp_begin` field"""
        return self.tokenizer.convert_tokens_to_ids("<|0.00|>")

    @cached_property
    def no_speech(self) -> int:
        """Returns the token id corresponding to the value of the `no_speech` field"""
        return self.no_timestamps - 1

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError(
                "This tokenizer does not have language token configured"
            )
        return self.to_language_token(self.language)
    def to_language_token(self, language):
        """Returns the token id corresponding to the given language.

        Arguments
        ---------
        language : str
            The language to convert to a token.

        Returns
        -------
        token
            The token id corresponding to the given language.

        Raises
        ------
        KeyError
            If the language is not found in the tokenizer.
        """
        # ``convert_tokens_to_ids`` is a method, not a dict: call it and treat
        # the unknown-token id as "not found".
        token = self.tokenizer.convert_tokens_to_ids(f"<|{language}|>")
        if token is not None and token != self.tokenizer.unk_token_id:
            return token
        raise KeyError(f"Language {language} not found in tokenizer.")
    def set_language_token(self, language):
        """Set the language token to the given language.

        Arguments
        ---------
        language : str
            The language to set the token to.
        """
        self.language = language
        self.tokenizer.set_prefix_tokens(language=self.language)
    def set_task(self, task):
        """Set the task token to the given task.

        Arguments
        ---------
        task : str
            The task to set the token to.
        """
        self.task = task
        self.tokenizer.set_prefix_tokens(task=self.task)
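    # Illustrative sketch (not in the original file): switching the decoder
    # prefix at runtime on a multilingual checkpoint.
    #
    #     model.set_language_token("fr")
    #     model.set_task("translate")  # must be "transcribe" or "translate"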
    @cached_property
    def is_multilingual(self):
        """Returns True if the model is multilingual, False otherwise."""
        return self.config.vocab_size >= 51865

    @cached_property
    def get_suppress_tokens(self):
        """Returns the list of tokens to suppress."""
        return tuple(sorted(self.config.suppress_tokens))
    @torch.no_grad()
    def detect_language(self, mel):
        """Detect the language of the given mel spectrogram features.

        Arguments
        ---------
        mel : torch.Tensor
            Mel spectrogram features to detect the language of.

        Returns
        -------
        language_tokens : torch.Tensor of shape (batch_size,)
            ids of the most probable language tokens, which appear after the
            startoftranscript token.
        language_probs : List[Dict[str, float]]
            list of dictionaries containing the probability distribution over
            all languages.

        Raises
        ------
        ValueError
            If the model doesn't have language tokens.
        """
        if self.tokenizer.language is None:
            raise ValueError(
                "This model doesn't have language tokens so it can't perform lang id"
            )

        batch_size = mel.shape[0]
        enc_states = self.model.encoder(mel).last_hidden_state
        decoder_input_ids = torch.tensor([[self.bos]] * batch_size).to(
            mel.device
        )
        logits = self.forward_decoder(enc_states, decoder_input_ids)[0][:, 0]
        mask = torch.ones(logits.shape[-1], dtype=torch.bool)
        mask[list(self.all_language_tokens)] = False
        logits[:, mask] = -np.inf
        language_tokens = logits.argmax(dim=-1)
        language_token_probs = logits.softmax(dim=-1).cpu()
        language_probs = [
            {
                c: language_token_probs[i, j].item()
                for j, c in zip(
                    self.all_language_tokens, self.all_language_codes
                )
            }
            for i in range(batch_size)
        ]
        return language_tokens, language_probs
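
# A minimal end-to-end sketch of language identification (illustrative, not
# part of the upstream API): the function name and save path are hypothetical,
# and a multilingual checkpoint is assumed since ``detect_language`` raises
# otherwise.
def _detect_language_sketch():
    """Run language ID on one second of dummy audio and print the top guess."""
    model = Whisper("openai/whisper-tiny", "savedir")
    mel = model._get_mel(torch.randn(1, SAMPLE_RATE))
    language_tokens, language_probs = model.detect_language(mel)
    best = max(language_probs[0], key=language_probs[0].get)
    print(language_tokens, best)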