Source code for speechbrain.lobes.models.huggingface_transformers.mbart

"""This lobe enables the integration of huggingface pretrained mBART models.
Reference: https://arxiv.org/abs/2001.08210

The Transformers library from HuggingFace needs to be installed:
https://huggingface.co/transformers/installation.html

Authors
 * Ha Nguyen 2023
"""

import logging

import torch

from speechbrain.lobes.models.huggingface_transformers.huggingface import (
    HFTransformersInterface,
)

logger = logging.getLogger(__name__)


class mBART(HFTransformersInterface):
    """This lobe enables the integration of HuggingFace and SpeechBrain
    pretrained mBART models.

    Source paper mBART: https://arxiv.org/abs/2001.08210
    The Transformers library from HuggingFace needs to be installed:
    https://huggingface.co/transformers/installation.html

    The model is normally used as a text decoder of seq2seq models. It will
    automatically download the model from HuggingFace or use a local path.

    Arguments
    ---------
    source : str
        HuggingFace hub name, e.g. "facebook/mbart-large-50-many-to-many-mmt".
    save_path : str
        Path (dir) of the downloaded model.
    freeze : bool (default: True)
        If True, the model is frozen. If False, the model will be trained
        alongside the rest of the pipeline.
    target_lang : str (default: "fr_XX", i.e. French)
        The target language code, following the mBART-50 language codes.
    decoder_only : bool (default: True)
        If True, only keep the decoder part (and/or the lm_head) of the model.
        This is useful when coupling a pretrained speech encoder (e.g.
        wav2vec 2.0) with a text-based pretrained decoder (e.g. mBART, NLLB).
    share_input_output_embed : bool (default: True)
        If True, use the embedding layer as the lm_head.

    Example
    -------
    >>> src = torch.rand([10, 1, 1024])
    >>> tgt = torch.LongTensor([[250008, 313, 25, 525, 773, 21525, 4004, 2]])
    >>> model_hub = "facebook/mbart-large-50-many-to-many-mmt"
    >>> save_path = "savedir"
    >>> model = mBART(model_hub, save_path) # doctest: +SKIP
    >>> outputs = model(src, tgt) # doctest: +SKIP
    """

    def __init__(
        self,
        source,
        save_path,
        freeze=True,
        target_lang="fr_XX",
        decoder_only=True,
        share_input_output_embed=True,
    ):
        super().__init__(
            source=source, save_path=save_path, freeze=freeze, seq2seqlm=True
        )
        self.target_lang = target_lang
        self.decoder_only = decoder_only
        self.share_input_output_embed = share_input_output_embed

        self.load_tokenizer(
            source=source, pad_token=None, tgt_lang=target_lang
        )

        if share_input_output_embed:
            # Tie the lm_head weights to the decoder token embeddings and
            # freeze both.
            self.model.lm_head.weight = (
                self.model.model.decoder.embed_tokens.weight
            )
            self.model.lm_head.requires_grad = False
            self.model.model.decoder.embed_tokens.requires_grad = False

        if decoder_only:
            # When we only want to use the decoder part
            del self.model.model.encoder

        for k, p in self.model.named_parameters():
            # It is common practice to fine-tune only the encoder_attn and
            # layer_norm layers of this model.
            if "encoder_attn" in k or "layer_norm" in k:
                p.requires_grad = True
            else:
                p.requires_grad = False
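
    # A minimal usage sketch, kept as a comment (it assumes network access
    # to the HuggingFace hub). After construction, only the encoder_attn and
    # layer_norm parameters are left trainable:
    #
    #   model = mBART("facebook/mbart-large-50-many-to-many-mmt", "savedir")
    #   trainable = [
    #       n for n, p in model.model.named_parameters() if p.requires_grad
    #   ]
    #   assert all(
    #       "encoder_attn" in n or "layer_norm" in n for n in trainable
    #   )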

    def forward(self, src, tgt, pad_idx=0):
        """This method implements a forward step for a machine translation
        task, using the output features of a speech encoder (e.g. wav2vec
        2.0) as the source (the mBART text encoder stack is dropped when
        decoder_only=True).

        Arguments
        ---------
        src : torch.Tensor
            Output features from the speech encoder (transcription side).
        tgt : torch.Tensor
            The token sequence fed to the decoder (translation side,
            required).
        pad_idx : int
            The index of the <pad> token (default: 0).
        """
        # Replace SpeechBrain's pad_idx (0) with the model's own
        # pad_token_id, since mBART's pad index differs from 0.
        tgt = self.custom_padding(
            tgt, 0, self.model.model.decoder.config.pad_token_id
        )

        if self.freeze:
            with torch.no_grad():
                if hasattr(self.model.model, "encoder"):
                    src = self.model.model.encoder(
                        inputs_embeds=src
                    ).last_hidden_state.detach()

                dec_out = self.model.model.decoder(
                    input_ids=tgt, encoder_hidden_states=src
                ).last_hidden_state.detach()
                dec_out = self.model.lm_head(dec_out).detach()
                return dec_out

        if hasattr(self.model.model, "encoder"):
            src = self.model.model.encoder(inputs_embeds=src).last_hidden_state

        dec_out = self.model.model.decoder(
            input_ids=tgt, encoder_hidden_states=src
        ).last_hidden_state
        dec_out = self.model.lm_head(dec_out)
        return dec_out
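
    # Hedged sketch of a typical call (variable names are illustrative):
    # `enc_out` stands in for speech-encoder features projected to the
    # decoder's hidden size (1024 for mbart-large), and `tokens` for mBART
    # token ids padded with SpeechBrain's pad_idx of 0.
    #
    #   enc_out = torch.rand([2, 50, 1024])  # (batch, time, hidden)
    #   tokens = torch.LongTensor(
    #       [[250008, 313, 2, 0], [250008, 525, 773, 2]]
    #   )
    #   logits = model(enc_out, tokens)  # (batch, seq_len, vocab_size)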

    @torch.no_grad()
    def decode(self, tgt, encoder_out, enc_len=None):
        """This method implements a decoding step for the transformer model.

        Arguments
        ---------
        tgt : torch.Tensor
            The sequence to the decoder.
        encoder_out : torch.Tensor
            Hidden output of the encoder.
        enc_len : torch.LongTensor
            The actual length of the encoder states.

        Returns
        -------
        torch.Tensor
            Logits over the vocabulary for each decoding step.
        torch.Tensor
            The cross-attention weights of the last decoder layer.
        """
        if tgt.dtype != torch.long:
            tgt = tgt.long()

        tgt_mask = torch.ones(tgt.size(), device=tgt.device)

        output = self.model.model.decoder(
            input_ids=tgt,
            encoder_hidden_states=encoder_out,
            attention_mask=tgt_mask,
            output_attentions=True,
        )

        return (
            self.model.lm_head(output.last_hidden_state),
            output.cross_attentions[-1],
        )
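
    # Hedged sketch of one greedy step driven by decode (in practice a
    # SpeechBrain beam searcher calls this method); `enc_out` is assumed to
    # come from the coupled speech encoder:
    #
    #   tgt = torch.LongTensor([[250008]])  # start from the target lang id
    #   logits, cross_attn = model.decode(tgt, enc_out)
    #   next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    #   tgt = torch.cat([tgt, next_token], dim=1)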

    def custom_padding(self, x, org_pad, custom_pad):
        """This method customizes the padding index.

        The default pad_idx of SpeechBrain is 0. However, some text-based
        models such as mBART reserve 0 for something else and are trained
        with a specific pad_idx. This method changes org_pad to custom_pad.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor with the original pad_idx.
        org_pad : int
            Original pad_idx.
        custom_pad : int
            Custom pad_idx.
        """
        out = x.clone()
        out[x == org_pad] = custom_pad
        return out
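
    # For instance, mapping SpeechBrain's pad_idx (0) to mBART's pad_token_id
    # (1 for the facebook/mbart-large-50 checkpoints):
    #
    #   x = torch.LongTensor([[250008, 313, 2, 0, 0]])
    #   model.custom_padding(x, 0, 1)
    #   # -> tensor([[250008, 313, 2, 1, 1]])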

    def override_config(self, config):
        """If the config needs to be overridden, here is the place.

        Arguments
        ---------
        config : MBartConfig
            The original config to be overridden.

        Returns
        -------
        config : MBartConfig
            The overridden config.
        """
        config.decoder_layerdrop = 0.05
        return config
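
# override_config is expected to be applied by HFTransformersInterface while
# the checkpoint is loaded, so the layerdrop change takes effect
# automatically. A manual equivalent (sketch, assuming transformers is
# installed) would be:
#
#   from transformers import MBartConfig
#
#   config = MBartConfig.from_pretrained(
#       "facebook/mbart-large-50-many-to-many-mmt"
#   )
#   config = model.override_config(config)
#   assert config.decoder_layerdrop == 0.05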