speechbrain.lobes.models.huggingface_transformers.gpt module
This lobe enables the integration of the HuggingFace pretrained GPT2LMHeadModel model.
The HuggingFace Transformers library must be installed: https://huggingface.co/transformers/installation.html
- Authors
Pooneh Mousavi 2023
Simone Alghisi 2023
Summary
Classes:
GPT: This lobe enables the integration of a HuggingFace pretrained GPT model.
Reference
- class speechbrain.lobes.models.huggingface_transformers.gpt.GPT(source, save_path, freeze=False, max_new_tokens=200, min_length=1, top_k=45, top_p=0.9, num_beams=8, eos_token_id=50258, early_stopping=True)
Bases: HFTransformersInterface
- This lobe enables the integration of a HuggingFace pretrained GPT model.
- The HuggingFace Transformers library must be installed: https://huggingface.co/transformers/installation.html
The model can be fine-tuned. It is downloaded automatically from HuggingFace, or loaded from a local path.
- Parameters:
source (str) - HuggingFace hub name, e.g. "gpt2".
save_path (str) - Directory where the downloaded model is saved.
freeze (bool (default: False)) - If True, the model is frozen. If False, the model is trained alongside the rest of the pipeline.
max_new_tokens (int) - Maximum number of new tokens to generate.
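min_length (int) - Minimum length of the generated sequence.
top_k (int) - Number of highest-probability tokens kept for top-k sampling.
top_p (float) - Cumulative probability threshold for nucleus (top-p) sampling.
num_beams (int) - Number of beams used in beam search.
eos_token_id (int) - Index of the end-of-sequence token.
early_stopping (bool) - Whether beam search stops as soon as num_beams finished candidates have been found.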
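The top_k and top_p parameters follow the standard definitions of top-k and nucleus sampling. As a minimal sketch of what they do to a single next-token distribution, the pure-Python function below applies top-k first, then top-p, and renormalises the result. top_k_top_p_filter is a hypothetical helper, not part of SpeechBrain or HuggingFace; HuggingFace's own implementation works on batched logit tensors and may differ in edge cases such as ties. The defaults simply mirror the GPT class defaults.

```python
import math
from typing import Dict, List


def top_k_top_p_filter(logits: List[float], top_k: int = 45,
                       top_p: float = 0.9) -> Dict[int, float]:
    """Return {token_id: probability} restricted to the top-k tokens, then to
    the smallest prefix of them whose cumulative probability reaches top_p.
    Hypothetical illustration only."""
    # Softmax (subtracting the max keeps exp() numerically stable).
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token ids sorted from most to least likely, truncated to top_k.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    ranked = ranked[:top_k]

    # Nucleus step: keep tokens until their cumulative mass reaches top_p.
    kept: List[int] = []
    cumulative = 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalise so the surviving probabilities sum to 1.
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


print(sorted(top_k_top_p_filter([2.0, 1.0, 0.0, -1.0], top_k=4, top_p=0.95)))  # [0, 1, 2]
```

Sampling would then draw the next token from the returned distribution, for example with random.choices(list(dist), weights=list(dist.values())).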
Example
>>> import torch
>>> from speechbrain.lobes.models.huggingface_transformers.gpt import GPT
>>> model_hub = "gpt2"
>>> save_path = "savedir"
>>> model = GPT(model_hub, save_path)
>>> tokens = torch.tensor([[1, 1]])
>>> tokens_type = torch.tensor([[1, 1]])
>>> attention_mask = torch.tensor([[1, 1]])
>>> outputs = model(tokens, tokens_type, attention_mask)
- forward(input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor)
Takes as input a conversation history and returns the corresponding reply.
- Parameters:
input_ids (torch.Tensor) - A batch of input IDs to transform into features.
token_type_ids (torch.Tensor) - Token type (speaker) of each token in input_ids.
attention_mask (torch.Tensor) - A batch of attention masks.
- Returns:
output - The reply to the conversation.
- Return type:
torch.Tensor
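Both forward and generate expect three aligned tensors. The sketch below shows one plausible way to build them from an already-tokenised dialogue history, assuming the speakers alternate turn by turn. PAD_ID, SPEAKER_IDS and build_inputs are hypothetical illustrations, not SpeechBrain API; the real speaker token types depend on how the tokenizer was extended with special tokens.

```python
from typing import List, Tuple

# Hypothetical ids for illustration only; real values depend on the tokenizer.
PAD_ID = 0
SPEAKER_IDS = (1, 2)  # token-type id for speaker A and speaker B


def build_inputs(
    histories: List[List[List[int]]], left_pad: bool = False
) -> Tuple[List[List[int]], List[List[int]], List[List[int]]]:
    """Flatten each dialogue history (a list of tokenised turns) into
    input_ids, token_type_ids and attention_mask, padded to equal length."""
    rows = []
    for turns in histories:
        ids: List[int] = []
        types: List[int] = []
        for turn_index, turn in enumerate(turns):
            # Assumption: speakers alternate A, B, A, ...
            speaker = SPEAKER_IDS[turn_index % 2]
            ids.extend(turn)
            types.extend([speaker] * len(turn))
        rows.append((ids, types))

    max_len = max(len(ids) for ids, _ in rows)
    input_ids, token_type_ids, attention_mask = [], [], []
    for ids, types in rows:
        n_pad = max_len - len(ids)
        pad = [PAD_ID] * n_pad
        real = [1] * len(ids)  # 1 = real token, 0 = padding
        if left_pad:
            input_ids.append(pad + ids)
            token_type_ids.append(pad + types)
            attention_mask.append([0] * n_pad + real)
        else:
            input_ids.append(ids + pad)
            token_type_ids.append(types + pad)
            attention_mask.append(real + [0] * n_pad)
    return input_ids, token_type_ids, attention_mask


ids, types, mask = build_inputs([[[10, 11], [20]], [[30]]])
print(ids, types, mask)  # [[10, 11, 20], [30, 0, 0]] [[1, 1, 2], [1, 0, 0]] [[1, 1, 1], [1, 0, 0]]
```

The nested lists can be wrapped with torch.tensor before being passed to the model. The left_pad option is there because decoder-only models such as GPT-2 continue from the last position of each row, which is why left padding is commonly used for batched generation.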
- generate(input_ids: Tensor, token_type_ids, attention_mask: Tensor, decoder_type='greedy')
Takes as input a conversation history and returns the corresponding reply.
- Parameters:
input_ids (torch.Tensor) - A batch of input IDs holding the dialogue context tokens.
token_type_ids (torch.Tensor) - Token type (speaker) of each token in input_ids.
attention_mask (torch.Tensor) - A batch of attention masks.
decoder_type (str) - Strategy for autoregressive decoding: greedy decoding (the default, 'greedy') or beam search.
- Returns:
hyp - The generated conversation reply.
- Return type:
torch.Tensor
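The two decoding strategies selected by decoder_type can be illustrated without a real model. The sketch below runs standard greedy decoding and beam search over a toy next-token function; toy_step, TABLE and both decode functions are hypothetical stand-ins for the GPT model and for HuggingFace's generate, which works on batched tensors and supports many more options. The early-stopping rule mirrors the early_stopping description above, and, unlike many real implementations, this sketch applies no length normalisation.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for the GPT model: maps a token prefix to the
# log-probabilities of the possible next tokens.
NextTokenFn = Callable[[List[int]], Dict[int, float]]


def greedy_decode(step: NextTokenFn, prompt: List[int], eos_id: int,
                  max_new_tokens: int) -> List[int]:
    """Repeatedly append the single most likely next token."""
    seq = list(prompt)
    for _ in range(max_new_tokens):
        log_probs = step(seq)
        token = max(log_probs, key=log_probs.get)
        seq.append(token)
        if token == eos_id:
            break
    return seq[len(prompt):]


def beam_decode(step: NextTokenFn, prompt: List[int], eos_id: int,
                max_new_tokens: int, num_beams: int) -> List[int]:
    """Keep the num_beams highest-scoring partial replies at every step."""
    beams: List[Tuple[float, List[int]]] = [(0.0, list(prompt))]
    finished: List[Tuple[float, List[int]]] = []
    for _ in range(max_new_tokens):
        # Expand every live beam with every possible next token.
        candidates = [
            (score + lp, seq + [token])
            for score, seq in beams
            for token, lp in step(seq).items()
        ]
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, seq in candidates:
            if seq[-1] == eos_id:
                finished.append((score, seq))  # reply is complete
            else:
                beams.append((score, seq))
            if len(beams) == num_beams:
                break
        # Early stopping: quit once num_beams replies have reached EOS.
        if len(finished) >= num_beams or not beams:
            break
    # Scores are raw summed log-probabilities (no length normalisation).
    _, best_seq = max(finished + beams, key=lambda c: c[0])
    return best_seq[len(prompt):]


# Toy next-token table with made-up probabilities; token 3 plays EOS.
TABLE = {
    0: {1: 0.6, 2: 0.4},
    1: {3: 0.4, 2: 0.3, 1: 0.3},
    2: {3: 0.9, 1: 0.1},
}


def toy_step(seq: List[int]) -> Dict[int, float]:
    return {tok: math.log(p) for tok, p in TABLE[seq[-1]].items()}


greedy = greedy_decode(toy_step, [0], eos_id=3, max_new_tokens=5)
beam = beam_decode(toy_step, [0], eos_id=3, max_new_tokens=5, num_beams=2)
print(greedy, beam)  # [1, 3] [2, 3]
```

Greedy decoding commits to token 1 (probability 0.6) and ends with a reply of probability 0.6 × 0.4 = 0.24, while beam search with two beams also keeps token 2 alive and finds the reply [2, 3] with probability 0.4 × 0.9 = 0.36.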