Source code for speechbrain.lobes.models.huggingface_transformers.whisper

"""This lobe enables the integration of huggingface pretrained whisper model.

The HuggingFace `transformers` library needs to be installed:
https://huggingface.co/transformers/installation.html

Authors
 * Adel Moumen 2022
 * Titouan Parcollet 2022
 * Luca Della Libera 2022
 * Ha Nguyen 2023
"""

import torch
import logging
from torch import nn

from speechbrain.lobes.models.huggingface_transformers.huggingface import (
    HFTransformersInterface,
)

logger = logging.getLogger(__name__)


class Whisper(HFTransformersInterface):
    """This lobe enables the integration of HuggingFace pretrained Whisper model.

    Source paper whisper:
        https://cdn.openai.com/papers/whisper.pdf
    Transformer from HuggingFace needs to be installed:
        https://huggingface.co/transformers/installation.html

    Some parts of the code are also adapted from the official OpenAI repository:
        https://github.com/openai/whisper

    The model can be finetuned. It will download automatically the model from
    HuggingFace or use a local path.

    Arguments
    ---------
    source : str
        HuggingFace hub name: e.g "openai/whisper-tiny"
    save_path : str
        Path (dir) of the downloaded model.
    sampling_rate : int (default: 16000)
        Sampling rate of the audio signal.
    encoder_only : bool (default: False)
        If True, the forward function outputs the hidden states from the last
        transformer layer of the encoder.
        If False, one step of the decoder is performed and returned.
    freeze : bool (default: False)
        If True, the model is frozen.
    freeze_encoder : bool (default: False)
        If True, the encoder is frozen.
    output_attentions : bool (default: True)
        If True, the forward function outputs the attention weights.
        If False, the attention output of the decoder is ``None``.
    output_all_hiddens: bool (default: False)
        If True, the forward function outputs the hidden states from all
        transformer layers of the encoder.
        For example whisper-base has 6 transformer layers and the output is of
        shape (7, B, T, C), where the output of the CNN output is added to the
        beginning.
        If False, the forward function outputs the hidden states only from the
        last transformer layer of the encoder.

    Example
    -------
    >>> model_hub = "openai/whisper-tiny"
    >>> save_path = "savedir"
    >>> sampling_rate = 16000
    >>> model = Whisper(model_hub, save_path, sampling_rate)
    >>> tokens = torch.tensor([[1, 1]]) * model.model.config.decoder_start_token_id
    >>> inputs = torch.randn([1, 93680])
    >>> outputs = model(inputs, tokens)
    """

    def __init__(
        self,
        source,
        save_path,
        sampling_rate=16000,
        encoder_only=False,
        freeze=False,
        freeze_encoder=False,
        output_attentions=True,
        output_all_hiddens=False,
    ):
        super().__init__(
            source=source,
            save_path=save_path,
            freeze=freeze,
            sampling_rate=sampling_rate,
        )
        self.sampling_rate = sampling_rate
        self.encoder_only = encoder_only
        self.freeze_encoder = freeze_encoder
        self.output_attentions = output_attentions
        self.output_all_hiddens = output_all_hiddens

        # The tokenizer is only loaded when the decoder is going to be used.
        if encoder_only:
            self.tokenizer = None
        else:
            self.load_tokenizer(source)

        # The feature extractor is always needed: its parameters drive the
        # waveform -> log-Mel conversion done in `_get_mel`.
        self.load_feature_extractor(
            source, save_path, sampling_rate=sampling_rate
        )

        self._n_fft = self.feature_extractor.n_fft
        self._hop_length = self.feature_extractor.hop_length
        self._n_samples = self.feature_extractor.n_samples

        # The following breaking changes were introduced in transformers>=4.29:
        # 1) mel_filters.shape = (..., feature_extractor.feature_size) instead
        #    of (feature_extractor.feature_size, ...)
        # 2) mel_filters.dtype = float64 instead of float32
        # The following code fixes the issue in a backward compatible way
        mel_filters = self.feature_extractor.mel_filters
        feature_size = self.feature_extractor.feature_size
        if mel_filters.shape[0] != feature_size:
            mel_filters = mel_filters.T
        # Explicit check rather than `assert`, so it also runs under `python -O`.
        if mel_filters.shape[0] != feature_size:
            raise ValueError(
                "Unexpected mel filter bank shape "
                f"{tuple(mel_filters.shape)}: no dimension matches the "
                f"feature size ({feature_size})."
            )
        self.register_buffer(
            "_mel_filters", torch.as_tensor(mel_filters, dtype=torch.float32)
        )
        #################################################################

        # When the whole model is frozen the encoder needs no extra handling;
        # otherwise honour the encoder-only freeze request.
        if not self.freeze and self.freeze_encoder:
            logger.warning(
                "speechbrain.lobes.models.huggingface_transformers.whisper - whisper encoder is frozen."
            )
            for param in self.model.encoder.parameters():
                param.requires_grad = False

    def freeze_model(self, model):
        """Freezes parameters of a model.

        Arguments
        ---------
        model : from AutoModel.from_config
            Valid HuggingFace transformers model object.
        """
        logger.warning(
            "speechbrain.lobes.models.huggingface_transformers.whisper - whisper encoder-decoder is frozen."
        )
        # The model is deliberately kept in train mode so that dropout and
        # layer norm are computed adequately.
        model.train()
        for param in model.parameters():
            param.requires_grad = False

    def forward(self, wav, decoder_input_ids=None):
        """Perform mel transformation and one step of the whisper (encoder-decoder).

        Arguments
        ---------
        wav : torch.Tensor (signal)
            A batch of audio signals to transform to features.
        decoder_input_ids : torch.Tensor
            This is necessary if we want to use the decoder.

            A batch of decoder inputs tokens.
            The first tokens need to dictate the behavior of the decoder.
            It needs to start with the bos_token, the language token,
            the task token, and finally the timestamp token.

            Please refer to the whisper paper for more details or go to the
            seq2seq decoding code in SpeechBrain to see how to generate the
            tokens with Greedy Search and/or Beam Search.

        Returns
        -------
        torch.Tensor or tuple
            If ``encoder_only`` is True, only the encoder output is returned.
            Otherwise the tuple ``(out_encoder, logits, attn)`` is returned,
            where ``logits`` and ``attn`` come from ``forward_decoder``.

        Raises
        ------
        ValueError
            If the decoder is used (``encoder_only`` is False) and
            ``decoder_input_ids`` is None.
        """
        # Fail before running the encoder instead of deep inside the decoder.
        if not self.encoder_only and decoder_input_ids is None:
            raise ValueError(
                "decoder_input_ids must be provided when encoder_only is False."
            )

        if self.freeze:
            with torch.no_grad():
                return self._encode_and_decode(wav, decoder_input_ids)
        return self._encode_and_decode(wav, decoder_input_ids)

    def _encode_and_decode(self, wav, decoder_input_ids):
        """Run the encoder and, unless ``encoder_only`` is set, one decoder step."""
        out_encoder = self.forward_encoder(wav)
        if self.encoder_only:
            return out_encoder

        # With all hidden states stacked, the decoder attends to the last layer.
        if self.output_all_hiddens:
            decoder_memory = out_encoder[-1]
        else:
            decoder_memory = out_encoder
        logits, attn = self.forward_decoder(decoder_memory, decoder_input_ids)
        return out_encoder, logits, attn

    def forward_encoder(self, wav):
        """Perform one step of the whisper encoder on a batch of waveforms.

        The waveforms are converted to log-Mel features internally before
        being passed to the encoder.

        Arguments
        ---------
        wav : torch.Tensor (signal)
            A batch of audio signals to transform to features.

        Returns
        -------
        torch.Tensor
            The last hidden state of the encoder, or all hidden states stacked
            along a new first dimension if ``output_all_hiddens`` is True.
        """
        if self.freeze_encoder:
            with torch.no_grad():
                return self._get_encoder_states(wav)
        return self._get_encoder_states(wav)

    def _get_encoder_states(self, wav):
        """Takes an input waveform and return its corresponding encoder states.
        Returns the last hidden state of the encoder or all hidden states if
        output_all_hiddens is True.

        Arguments
        ---------
        wav : torch.Tensor (signal)
            A batch of audio signals to transform to features.

        Returns
        -------
        torch.Tensor
            The encoder states.
        """
        mel = self._get_mel(wav)
        if self.output_all_hiddens:
            states = self.model.encoder(mel, output_hidden_states=True)
            return torch.stack(states.hidden_states)
        return self.model.encoder(mel).last_hidden_state

    def _get_mel(self, wav):
        """Takes an input waveform and return its corresponding mel spectrogram
        according to HuggingFace implementation. WARNING: it's slow! Better
        push this in the DataLoader.

        Arguments
        ---------
        wav : torch.Tensor (signal)
            A batch of audio signals to transform to features.

        Returns
        -------
        torch.Tensor
            The log-Mel spectrograms of the padded or trimmed waveforms.
        """
        padded_wav = self._pad_or_trim(wav)
        return self._log_mel_spectrogram(padded_wav)

    def _log_mel_spectrogram(self, audio):
        """Compute the Mel spectrogram of a batch of input waveforms.

        Reference: adapted from
        https://github.com/openai/whisper/blob/eff383b27b783e280c089475852ba83f20f64998/whisper/audio.py#L92

        Arguments
        ---------
        audio : torch.Tensor
            A batch of audio waveforms in 16 kHz.

        Returns
        -------
        torch.Tensor
            A tensor that contains the batch of Mel spectrograms.
        """
        window = torch.hann_window(self._n_fft, device=audio.device)
        stft = torch.stft(
            audio,
            self._n_fft,
            self._hop_length,
            window=window,
            return_complex=True,
        )
        # Power spectrum; the last STFT frame is dropped, as in the reference
        # implementation.
        magnitudes = stft[..., :-1].abs() ** 2

        mel_spec = self._mel_filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        # Floor each example at 8.0 below its own maximum.
        per_example_max = log_spec.flatten(start_dim=1).max(dim=-1)[0]
        log_spec = torch.maximum(
            log_spec, (per_example_max - 8.0)[:, None, None]
        )
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec

    def _pad_or_trim(self, array, axis=-1):
        """Pad or trim the input so that it has exactly the number of samples
        expected by the feature extractor along ``axis``.

        Reference: adapted from
        https://github.com/openai/whisper/blob/eff383b27b783e280c089475852ba83f20f64998/whisper/audio.py#L52

        Arguments
        ---------
        array : torch.Tensor
            A tensor that contains the batch of audio signals.
        axis : int
            The axis along which to pad or trim.

        Returns
        -------
        torch.Tensor
            The padded (with trailing zeros) or trimmed tensor.
        """
        current = array.shape[axis]

        if current > self._n_samples:
            # Keep only the first `_n_samples` entries along `axis`.
            array = array.index_select(
                dim=axis,
                index=torch.arange(self._n_samples, device=array.device),
            )
        elif current < self._n_samples:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, self._n_samples - current)
            # `nn.functional.pad` expects (left, right) pairs ordered from the
            # last dimension to the first, hence the reversal.
            array = nn.functional.pad(
                array, [pad for sizes in pad_widths[::-1] for pad in sizes]
            )

        return array

    def forward_decoder(self, audio_features, decoder_input_ids):
        """Perform one step of the whisper decoder.

        Arguments
        ---------
        audio_features : torch.Tensor
            A batch of audio features (mel + whisper encoding).
        decoder_input_ids : torch.Tensor
            A batch of decoder inputs tokens.
            The first tokens need to dictate the behavior of the decoder.
            It needs to start with the bos_token, the language token,
            the task token, and finally the timestamp token.

            Please refer to the whisper paper for more details or go to the
            seq2seq decoding code in SpeechBrain to see how to generate the
            tokens with Greedy Search and/or Beam Search.

        Returns
        -------
        logits : torch.Tensor
            The decoder's last hidden state projected onto the token
            embedding matrix, cast to the dtype of ``audio_features``.
        attn : torch.Tensor or None
            The last element of the decoder's ``attentions`` output with its
            first two dimensions merged, or None when ``output_attentions``
            is False.
        """
        output_states = self.model.decoder(
            encoder_hidden_states=audio_features,
            input_ids=decoder_input_ids,
            output_attentions=self.output_attentions,
        )

        # The decoder only returns attentions when they were requested;
        # indexing them unconditionally would fail on None.
        attn = None
        if self.output_attentions and output_states.attentions is not None:
            attn = output_states.attentions[-1]
            attn = attn.view(attn.shape[0] * attn.shape[1], *attn.shape[2:])

        hidden = output_states.last_hidden_state

        # Project onto the (transposed) token embedding matrix to get logits.
        embedding_weight = self.model.decoder.embed_tokens.weight.to(
            hidden.dtype
        )
        logits = (hidden @ torch.transpose(embedding_weight, 0, 1)).to(
            audio_features.dtype
        )

        return logits, attn