speechbrain.lobes.models.huggingface_transformers.gpt module

This lobe enables the integration of the HuggingFace pretrained GPT2LMHeadModel.

The HuggingFace Transformers library needs to be installed: https://huggingface.co/transformers/installation.html

Authors
  • Pooneh Mousavi 2023

  • Simone Alghisi 2023

Summary

Classes:

GPT

This lobe enables the integration of the HuggingFace pretrained GPT model.

Reference

class speechbrain.lobes.models.huggingface_transformers.gpt.GPT(source, save_path, freeze=False, max_new_tokens=200, min_length=1, top_k=45, top_p=0.9, num_beams=8, eos_token_id=50258, early_stopping=True)[source]

Bases: HFTransformersInterface

This lobe enables the integration of the HuggingFace pretrained GPT model.
Source paper:

https://life-extension.github.io/2020/05/27/GPT%E6%8A%80%E6%9C%AF%E5%88%9D%E6%8E%A2/language-models.pdf

The HuggingFace Transformers library needs to be installed:

https://huggingface.co/transformers/installation.html

The model can be finetuned. It will automatically download the model from the HuggingFace hub or use a local path.

Parameters:
  • source (str) – HuggingFace hub name, e.g. "gpt2".

  • save_path (str) – Path (dir) of the downloaded model.

  • freeze (bool (default: False)) – If True, the model is frozen. If False, the model will be trained alongside the rest of the pipeline.

Example

>>> import torch
>>> model_hub = "gpt2"
>>> save_path = "savedir"
>>> model = GPT(model_hub, save_path)
>>> tokens = torch.tensor([[1, 1]])
>>> tokens_type = torch.tensor([[1, 1]])
>>> attention_mask = torch.tensor([[1, 1]])
>>> outputs = model(tokens, tokens_type, attention_mask)
forward(input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor)[source]

Takes a conversation history as input and returns the corresponding model output.

Parameters:
  • input_ids (torch.Tensor) – A batch of input-ids to transform into features.

  • token_type_ids (torch.Tensor) – Token type (speaker) ID for each token in input_ids.

  • attention_mask (torch.Tensor) – A batch of attention masks.
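As an illustrative sketch (not taken from the SpeechBrain source), the three tensors that forward expects can be built from unequal-length dialogue turns by right-padding to the batch maximum; the token IDs and speaker IDs below are arbitrary placeholders:

```python
def pad_batch(sequences, pad_id=0):
    """Right-pad each sequence to the batch maximum and build an
    attention mask (1 for real tokens, 0 for padding)."""
    max_len = max(len(s) for s in sequences)
    padded = [s + [pad_id] * (max_len - len(s)) for s in sequences]
    mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in sequences]
    return padded, mask

# Two turns of different lengths; speaker IDs mark who produced each token
# (this is what token_type_ids encodes). All IDs here are made up.
input_ids = [[101, 7592, 102], [101, 2129, 2024, 2017, 102]]
speaker_ids = [[0, 0, 0], [1, 1, 1, 1, 1]]

padded_ids, attention_mask = pad_batch(input_ids)
token_type_ids, _ = pad_batch(speaker_ids)
```

Wrapping each of these lists in torch.tensor(...) yields inputs of the shape the doctest above passes to the model.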

generate(input_ids: Tensor, token_type_ids, attention_mask: Tensor, decoder_type='greedy')[source]

Takes a conversation history as input and returns its corresponding reply.

Parameters:
  • input_ids (torch.Tensor) – A batch of input-ids (dialogue context tokens).

  • token_type_ids (torch.Tensor) – Token type (speaker) ID for each token in input_ids.

  • attention_mask (torch.Tensor) – A batch of attention masks.

  • decoder_type (str) – The strategy for autoregressive decoding: either beam search or greedy.
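The following toy sketch (not the HuggingFace implementation) illustrates what decoder_type="greedy" means: at every step the single highest-scoring token is picked, whereas beam search instead keeps the num_beams best partial sequences. The score tables and the stand-in eos ID are invented for the example:

```python
def greedy_decode(step_scores, eos_id):
    """Greedy autoregressive decoding over precomputed per-step scores.

    step_scores: list of {token_id: score} dicts, one per decoding step.
    Picks the argmax token at each step and stops at the end-of-sequence ID.
    """
    out = []
    for scores in step_scores:
        best = max(scores, key=scores.get)  # argmax over the vocabulary
        out.append(best)
        if best == eos_id:
            break
    return out

# Three steps of fake scores; token 9 plays the role of eos_token_id.
scores = [{1: 0.2, 2: 0.7}, {3: 0.6, 9: 0.1}, {9: 0.9, 4: 0.05}]
print(greedy_decode(scores, eos_id=9))  # -> [2, 3, 9]
```

In the real model, the per-step scores come from the GPT language-model head rather than a fixed table, but the selection rule is the same.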

training: bool