speechbrain.lobes.models.huggingface_transformers.mbart module

This lobe enables the integration of HuggingFace pretrained mBART models. Reference: https://arxiv.org/abs/2001.08210

The transformers library from HuggingFace needs to be installed: https://huggingface.co/transformers/installation.html

Authors
  • Ha Nguyen 2023

Summary

Classes:

mBART – This lobe enables the integration of HuggingFace and SpeechBrain pretrained mBART models.

Reference

class speechbrain.lobes.models.huggingface_transformers.mbart.mBART(source, save_path, freeze=True, target_lang='fr_XX', decoder_only=True, share_input_output_embed=True)[source]

Bases: HFTransformersInterface

This lobe enables the integration of HuggingFace and SpeechBrain pretrained mBART models.

Source paper (mBART): https://arxiv.org/abs/2001.08210. The transformers library from HuggingFace needs to be installed: https://huggingface.co/transformers/installation.html

The model is normally used as a text decoder of seq2seq models. It will automatically download the model from HuggingFace or use a local path.

Parameters:
  • source (str) – HuggingFace hub name, e.g. "facebook/mbart-large-50-many-to-many-mmt"

  • save_path (str) – Path (dir) of the downloaded model.

  • freeze (bool (default: True)) – If True, the model is frozen. If False, the model will be trained alongside the rest of the pipeline.

  • target_lang (str (default: fr_XX (i.e. French))) – The target language code according to the mBART model.

  • decoder_only (bool (default: True)) – If True, only take the decoder part (and/or the lm_head) of the model. This is useful in case one wants to couple a pre-trained speech encoder (e.g. wav2vec) with a text-based pre-trained decoder (e.g. mBART, NLLB).

  • share_input_output_embed (bool (default: True)) – If True, use the embedding layer as the lm_head.

Example

>>> import torch
>>> src = torch.rand([10, 1, 1024])
>>> tgt = torch.LongTensor([[250008,    313,     25,    525,    773,  21525,   4004,      2]])
>>> model_hub = "facebook/mbart-large-50-many-to-many-mmt"
>>> save_path = "savedir"
>>> model = mBART(model_hub, save_path)
>>> outputs = model(src, tgt)
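The non-default constructor arguments can be combined in the same way; a hedged sketch (the language code and freeze value below are illustrative choices, not defaults of this module):

>>> model_de = mBART(model_hub, save_path, freeze=False, target_lang="de_DE")
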
forward(src, tgt, pad_idx=0)[source]

This method implements a forward step for the MT task using the output of a wav2vec 2.0 encoder (same as above, but without the encoder stack).

Parameters:
  • src (torch.Tensor) – Output features from the wav2vec 2.0 (w2v2) encoder (transcription side).

  • tgt (torch.Tensor) – The target sequence fed to the decoder (translation side, required).

  • pad_idx (int) – The index for <pad> token (default=0).

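A hedged sketch only (reusing src, tgt, and model from the class-level example above; the output name dec_out is illustrative):

>>> # src: features from a speech encoder, tgt: target token ids
>>> dec_out = model(src, tgt, pad_idx=0)
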
decode(tgt, encoder_out, enc_len=None)[source]

This method implements a decoding step for the transformer model.

Parameters:
  • tgt (torch.Tensor) – The sequence to the decoder.

  • encoder_out (torch.Tensor) – Hidden output of the encoder.

  • enc_len (torch.LongTensor) – The actual length of encoder states.

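A hedged sketch of one decoding step (the start token, length, and variable names are illustrative assumptions; encoder_out stands for hidden states produced by the upstream speech encoder, and in practice SpeechBrain's searchers drive this method):

>>> tgt = torch.LongTensor([[250008]])                  # running hypothesis, e.g. a language-id token
>>> enc_len = torch.LongTensor([10])                    # actual lengths of the encoder states
>>> step_out = model.decode(tgt, encoder_out, enc_len)  # decoder output for this step
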
custom_padding(x, org_pad, custom_pad)[source]

This method customizes the padding. The default pad_idx in SpeechBrain is 0. However, some text-based models like mBART reserve 0 for something else and are trained with a specific pad_idx. This method changes org_pad to custom_pad.

Parameters:
  • x (torch.Tensor) – Input tensor with original pad_idx

  • org_pad (int) – Original pad_idx

  • custom_pad (int) – Custom pad_idx

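Conceptually this is a value remap on the padded positions; a minimal hedged sketch (0 is SpeechBrain's default pad_idx, while 1 stands in for the model's own pad id and is illustrative here):

>>> x = torch.LongTensor([[250008, 313, 2, 0, 0]])  # sequence padded with pad_idx 0
>>> x_remapped = model.custom_padding(x, 0, 1)      # positions equal to 0 become 1
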
override_config(config)[source]

If the config needs to be overridden, this is the place to do it.

Parameters:

config (MBartConfig) – The original config that needs to be overridden.

Return type:

MBartConfig – the overridden config.

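If a subclass needs a different configuration, this hook is the place to change it; a hedged sketch (the attribute tweaked below is an arbitrary MBartConfig field, not what this class actually overrides):

>>> class MyMBART(mBART):
...     def override_config(self, config):
...         config.dropout = 0.1  # illustrative tweak applied before the model is built
...         return config
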
training: bool