speechbrain.lobes.models.huggingface_transformers.gpt module
This lobe enables the integration of HuggingFace's pretrained GPT2LMHeadModel.
Transformers from HuggingFace needs to be installed: https://huggingface.co/transformers/installation.html
- Authors
Pooneh Mousavi 2023
Simone Alghisi 2023
Summary
Classes:
GPT – This lobe enables the integration of HuggingFace pretrained GPT model.
Reference
- class speechbrain.lobes.models.huggingface_transformers.gpt.GPT(source, save_path, freeze=False, max_new_tokens=200, min_length=1, top_k=45, top_p=0.9, num_beams=8, eos_token_id=50258, early_stopping=True)[source]
Bases:
HFTransformersInterface
- This lobe enables the integration of HuggingFace pretrained GPT model.
- Transformers from HuggingFace needs to be installed: https://huggingface.co/transformers/installation.html
The model can be finetuned. It will automatically download the model from HuggingFace or use a local path.
- Parameters:
source (str) – HuggingFace hub name or local path of the model.
save_path (str) – Directory where the downloaded model is saved.
freeze (bool, default: False) – If True, the model parameters are frozen and not updated during training.
max_new_tokens (int, default: 200) – Maximum number of new tokens to generate.
min_length (int, default: 1) – Minimum length of the generated sequence.
top_k (int, default: 45) – Number of highest-probability tokens kept for top-k filtering during sampling.
top_p (float, default: 0.9) – Cumulative-probability threshold for nucleus (top-p) filtering.
num_beams (int, default: 8) – Number of beams used for beam-search decoding.
eos_token_id (int, default: 50258) – Id of the end-of-sequence token.
early_stopping (bool, default: True) – Whether beam search stops as soon as enough complete candidates are found.
Example
>>> import torch
>>> model_hub = "gpt2"
>>> save_path = "savedir"
>>> model = GPT(model_hub, save_path)
>>> tokens = torch.tensor([[1, 1]])
>>> tokens_type = torch.tensor([[1, 1]])
>>> attention_mask = torch.tensor([[1, 1]])
>>> outputs = model(tokens, tokens_type, attention_mask)
- forward(input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor)[source]
Takes as input a conversation history and returns the corresponding reply.
- Parameters:
input_ids (torch.Tensor) – A batch of input ids to transform to features.
token_type_ids (torch.Tensor) – Token type (speaker) id for each token in input_ids.
attention_mask (torch.Tensor) – A batch of attention masks.
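As a brief illustration (not part of the library), the attention_mask expected by forward can be built from padded input_ids; the pad id of 0 here is an assumption, not the tokenizer's actual value:

```python
import torch

def build_attention_mask(input_ids: torch.Tensor, pad_token_id: int = 0) -> torch.Tensor:
    """Return a mask with 1 for real tokens and 0 for padding positions."""
    return (input_ids != pad_token_id).long()

# Two dialogues padded to the same length (pad id 0 is an assumption).
input_ids = torch.tensor([[15496, 995, 0, 0],
                          [40, 716, 257, 3303]])
attention_mask = build_attention_mask(input_ids)
print(attention_mask.tolist())  # [[1, 1, 0, 0], [1, 1, 1, 1]]
```

The same mask shape (batch, seq_len) is what both forward and generate consume.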
- generate(input_ids: Tensor, token_type_ids, attention_mask: Tensor, decoder_type='greedy')[source]
Takes as input a conversation history and generates the corresponding reply.
- Parameters:
input_ids (torch.Tensor) – A batch of input ids which are the dialogue context tokens.
token_type_ids (torch.Tensor) – Token type (speaker) id for each token in input_ids.
attention_mask (torch.Tensor) – A batch of attention masks.
decoder_type (str) – Strategy for autoregressive decoding: either beam search or greedy.
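To clarify the two decoding strategies: greedy decoding appends the argmax token at each step, while beam search keeps num_beams partial hypotheses. A minimal greedy loop is sketched below over a toy next-token scorer; toy_logits and greedy_decode are hypothetical helpers for illustration, not the library's API:

```python
import torch

def greedy_decode(step_logits_fn, input_ids, eos_token_id, max_new_tokens=5):
    """Append the argmax token each step until EOS or the token budget is reached."""
    ids = input_ids.clone()
    for _ in range(max_new_tokens):
        logits = step_logits_fn(ids)                      # (batch, vocab) next-token scores
        next_tok = logits.argmax(dim=-1, keepdim=True)    # greedy: pick the best token
        ids = torch.cat([ids, next_tok], dim=-1)
        if (next_tok == eos_token_id).all():
            break
    return ids

# Toy scorer (hypothetical): always prefers token (last_token + 1) mod vocab.
def toy_logits(ids, vocab=10):
    logits = torch.zeros(ids.size(0), vocab)
    logits[torch.arange(ids.size(0)), (ids[:, -1] + 1) % vocab] = 1.0
    return logits

out = greedy_decode(toy_logits, torch.tensor([[3]]), eos_token_id=7)
print(out.tolist())  # [[3, 4, 5, 6, 7]]
```

Beam search replaces the single argmax with the top num_beams continuations per step, which is what the class's num_beams and early_stopping parameters control.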