speechbrain.lobes.models.huggingface_transformers.mbart module
This lobe enables the integration of HuggingFace pretrained mBART models. Reference: https://arxiv.org/abs/2001.08210
The Transformers library from HuggingFace needs to be installed: https://huggingface.co/transformers/installation.html
- Authors
Ha Nguyen 2023
Summary
Classes:
mBART – This lobe enables the integration of HuggingFace and SpeechBrain pretrained mBART models.
Reference
- class speechbrain.lobes.models.huggingface_transformers.mbart.mBART(source, save_path, freeze=True, target_lang='fr_XX', decoder_only=True, share_input_output_embed=True)[source]
Bases:
HFTransformersInterface
This lobe enables the integration of HuggingFace and SpeechBrain pretrained mBART models.
Source paper mBART: https://arxiv.org/abs/2001.08210. The Transformers library from HuggingFace needs to be installed: https://huggingface.co/transformers/installation.html
The model is normally used as a text decoder of seq2seq models. It will automatically download the model from HuggingFace or use a local path.
- Parameters:
source (str) – HuggingFace hub name, e.g. "facebook/mbart-large-50-many-to-many-mmt".
save_path (str) – Path (dir) of the downloaded model.
freeze (bool (default: True)) – If True, the model is frozen. If False, the model will be trained alongside the rest of the pipeline.
target_lang (str (default: fr_XX (i.e. French))) – The target language code according to the mBART model.
decoder_only (bool (default: True)) – If True, only take the decoder part (and/or the lm_head) of the model. This is useful when coupling a pre-trained speech encoder (e.g. wav2vec) with a text-based pre-trained decoder (e.g. mBART, NLLB).
share_input_output_embed (bool (default: True)) – If True, use the embedding layer as the lm_head.
Example
>>> import torch
>>> src = torch.rand([10, 1, 1024])
>>> tgt = torch.LongTensor([[250008, 313, 25, 525, 773, 21525, 4004, 2]])
>>> model_hub = "facebook/mbart-large-50-many-to-many-mmt"
>>> save_path = "savedir"
>>> model = mBART(model_hub, save_path)
>>> outputs = model(src, tgt)
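In the example above, 250008 corresponds to the fr_XX language code that mBART-50 expects at the start of the target sequence, and 2 is the end-of-sequence token. To fine-tune the decoder together with a speech encoder, the model can be created unfrozen; a minimal sketch reusing the names above (the argument values are illustrative, not taken from any recipe):
>>> model_ft = mBART(model_hub, save_path, freeze=False, target_lang="fr_XX")
>>> outputs = model_ft(src, tgt)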
- forward(src, tgt, pad_idx=0)[source]
This method implements a forward step for the machine translation (MT) task using a wav2vec encoder (same as above, but without the encoder stack).
- Parameters:
src (torch.Tensor) – Output features from the wav2vec 2.0 encoder (transcription).
tgt (torch.Tensor) – The sequence to the decoder (translation) (required).
pad_idx (int) – The index of the <pad> token (default=0).
- Returns:
dec_out – Decoder output.
- Return type:
torch.Tensor
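Since pad_idx marks padded positions in tgt, variable-length target batches can be zero-padded before the call. A minimal sketch reusing model and src from the class example (the token ids and padding layout are illustrative):
>>> tgt_padded = torch.LongTensor([[250008, 313, 25, 2, 0, 0]])
>>> dec_out = model(src, tgt_padded, pad_idx=0)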
- decode(tgt, encoder_out, enc_len=None)[source]
This method implements a decoding step for the transformer model.
- Parameters:
tgt (torch.Tensor) – The sequence to the decoder.
encoder_out (torch.Tensor) – Hidden output of the encoder.
enc_len (torch.LongTensor) – The actual length of encoder states.
- Returns:
output (torch.Tensor) – Output of the transformer.
cross_attention (torch.Tensor) – Attention value.
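A minimal sketch of a single decoding step, reusing model and src from the class example (the one-token prefix is illustrative; in a full recipe this method is usually driven step by step by a searcher rather than called directly):
>>> prefix = torch.LongTensor([[250008]])
>>> output, cross_attention = model.decode(prefix, src)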