speechbrain.nnet.adapters module

The SpeechBrain implementation of various pre-trained model adapters, e.g. LoRA and Houlsby.

Authors
  • Titouan Parcollet 2024

  • Peter Plantinga 2024

Summary

Classes:

AdaptedModel

Given any torch model, e.g. asr_brain.modules.Transformer, and an adapter class, e.g. HoulsbyAdapterLinear, this class replaces the target layers with instances of the adapter class that wrap them, preserving the original parameters.

HoulsbyAdapterLinear

This class implements the Houlsby Adapter as described in: 'Parameter-Efficient Transfer Learning for NLP' https://arxiv.org/abs/1902.00751

LoRA

This class implements the LoRA Adapter as described in: 'LoRA: Low-Rank Adaptation of Large Language Models' https://arxiv.org/abs/2106.09685

Functions:

is_layer_adaptable

Check if a layer is among the list of layers to be adapted.

replace_module

Replace a layer with a new module by assigning it to the layer's parent module.

Reference

class speechbrain.nnet.adapters.AdaptedModel(model_to_adapt: Module, adapter_class: Module, all_linear: bool = False, all_conv: bool = False, target_layers: list = [], unfrozen_layers: list = [], adapter_kwargs: dict = {}, manual_adapter_insertion: bool = False)

Bases: Module

Given any torch model, e.g. asr_brain.modules.Transformer, and an adapter class, e.g. HoulsbyAdapterLinear, this class replaces the target layers with instances of the adapter class that wrap them, preserving the original parameters.

Parameters:
  • model_to_adapt (nn.Module) – The base PyTorch model to add adapters to.

  • adapter_class (class) – An uninitialized adapter class from this SpeechBrain library, e.g. LoRA or HoulsbyAdapterLinear.

  • all_linear (bool) – Whether to add the adapter to all linear layers (default: False)

  • all_conv (bool) – Whether to add the adapter to all conv layers (default: False)

  • target_layers (list of str) – A list of module names in the given model that should be replaced. Supports Unix shell-style wildcards (*, ?, [seq], [!seq]) with fnmatch.

  • unfrozen_layers (list of str) – List of layers to be unfrozen during training. Supports Unix shell-style wildcards (*, ?, [seq], [!seq]) with fnmatch.

  • adapter_kwargs (dict) – Keyword arguments passed to the adapter class when it is instantiated.

  • manual_adapter_insertion (bool) – With the default value (False), the adapters are inserted at initialization. In some cases, however, it is preferable to delay the insertion, e.g. when pretrained parameters need to be loaded first. In that case, set this to True and call insert_adapters() manually after the parameters have been loaded.
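The wildcard matching used by target_layers and unfrozen_layers can be previewed with the standard fnmatch module alone. In this sketch the module names are hypothetical and select is an illustrative helper, not part of SpeechBrain.

```python
from fnmatch import fnmatch

# Hypothetical names, in the dotted form produced by model.named_modules().
names = [
    "encoder.layers.0.linear1",
    "encoder.layers.0.linear2",
    "encoder.layers.1.linear1",
    "encoder.layers.1.linear2",
    "decoder.linear",
]


def select(patterns, candidates):
    """Return the candidates matching at least one fnmatch-style pattern."""
    return [n for n in candidates if any(fnmatch(n, p) for p in patterns)]


# '*' matches any run of characters, dots included.
print(select(["encoder.layers.*.linear1"], names))
# '[!seq]' matches one character not in seq: here, every layer except 0.
print(select(["encoder.layers.[!0].*"], names))
# '?' matches exactly one character.
print(select(["decoder.linea?"], names))
# '[seq]' matches one character in seq, as in the class Example.
print(select(["layer[13]"], ["layer1", "layer2", "layer3"]))
```

Because '*' also matches dots, a broad pattern such as "encoder.*" selects nested submodules at every depth, so narrower patterns are usually safer.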

Example

>>> import torch
>>> from collections import OrderedDict
>>> from speechbrain.nnet.adapters import AdaptedModel, LoRA
>>> model = torch.nn.Sequential(
...   OrderedDict([
...     ("layer1", torch.nn.Linear(10, 20)),
...     ("layer2", torch.nn.Linear(20, 20)),
...     ("layer3", torch.nn.Linear(20, 10)),
...   ])
... )
>>> lora_model = AdaptedModel(
...   model_to_adapt=model,
...   adapter_class=LoRA,
...   target_layers=["layer[13]"],
...   unfrozen_layers=["layer2"],
...   adapter_kwargs={"rank": 2},
... )
>>> lora_model
AdaptedModel(
  (adapted_model): Sequential(
    (layer1): LoRA(
      (pretrained_module): Linear(in_features=10, out_features=20, bias=True)
      (adapter_down_proj): Linear(in_features=10, out_features=2, bias=False)
      (adapter_up_proj): Linear(in_features=2, out_features=20, bias=False)
    )
    (layer2): Linear(in_features=20, out_features=20, bias=True)
    (layer3): LoRA(
      (pretrained_module): Linear(in_features=20, out_features=10, bias=True)
      (adapter_down_proj): Linear(in_features=20, out_features=2, bias=False)
      (adapter_up_proj): Linear(in_features=2, out_features=10, bias=False)
    )
  )
)
insert_adapters()

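Inserts the adapters into the model. Performing the insertion in __init__ conflicts with the Pretrainer; in that case, set manual_adapter_insertion=True and call this function yourself. Ensure it is called exactly once before training. See the manual_adapter_insertion argument of __init__.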
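Wrapping a layer changes the names under which its pretrained parameters appear in the state dict, which is one plausible reason the insertion should wait until pretrained weights are loaded. This is a minimal plain-PyTorch sketch; HypotheticalAdapter is an illustrative stand-in, named after the pretrained_module attribute visible in the LoRA example above.

```python
import torch


class HypotheticalAdapter(torch.nn.Module):
    """Illustrative wrapper that keeps the original layer as a submodule."""

    def __init__(self, module):
        super().__init__()
        self.pretrained_module = module

    def forward(self, x):
        return self.pretrained_module(x)


model = torch.nn.Sequential(torch.nn.Linear(4, 4))
# A checkpoint taken from the un-adapted model.
pretrained_state = {k: v.clone() for k, v in model.state_dict().items()}

# Wrap the layer before loading, as automatic insertion would.
model[0] = HypotheticalAdapter(model[0])
print(list(model.state_dict().keys()))
# ['0.pretrained_module.weight', '0.pretrained_module.bias']

# The old keys no longer match, so a non-strict load skips them all.
result = model.load_state_dict(pretrained_state, strict=False)
print(result.missing_keys, result.unexpected_keys)
```

With strict=True the same load raises a RuntimeError instead; loading first and wrapping afterwards avoids the mismatch entirely.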

forward(*args, **kwargs)

Passes all arguments to the adapted model.

saver(path)

Saves only the trainable parameters.

loader(path, end_of_epoch)

Loads the base model plus the trained parameters.

parameter_transfer(path)

Avoids the warnings that would otherwise be raised because only the trained parameters are loaded.
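The idea behind saver, loader and parameter_transfer can be sketched in plain PyTorch: store only the parameters with requires_grad=True, then load them non-strictly so the intentionally missing frozen entries do not cause an error. save_trainable and load_trainable are hypothetical helpers, not the SpeechBrain implementation.

```python
import os
import tempfile

import torch

# Hypothetical model: layer 0 is frozen (pretrained), layer 1 is trainable.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
for p in model[0].parameters():
    p.requires_grad = False


def save_trainable(module, path):
    """Save only the parameters that require gradients."""
    trainable = {
        n: p.detach() for n, p in module.named_parameters() if p.requires_grad
    }
    torch.save(trainable, path)


def load_trainable(module, path):
    """Load trained parameters on top of the base weights already in module."""
    state = torch.load(path)
    # strict=False: frozen parameters are intentionally absent from the file.
    return module.load_state_dict(state, strict=False)


with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "trainable.ckpt")
    save_trainable(model, ckpt)
    saved_keys = sorted(torch.load(ckpt).keys())
    result = load_trainable(model, ckpt)

print(saved_keys)  # ['1.bias', '1.weight']
print(result.missing_keys)  # the frozen layer-0 parameters
```

The resulting checkpoint holds only the small trainable part, which is the main storage benefit of adapter-based fine-tuning.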

__getattr__(item)

Overrides __getattr__ to pass attribute accesses through to the pre-adapted model.
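Delegating attribute lookups from an nn.Module wrapper needs some care, because nn.Module already defines __getattr__ to resolve parameters, buffers and submodules. The sketch below shows one common pattern, assuming a hypothetical DelegatingWrapper class; it is not necessarily how AdaptedModel implements it.

```python
import torch


class DelegatingWrapper(torch.nn.Module):
    """Hypothetical wrapper that forwards unknown attributes to a wrapped model."""

    def __init__(self, model):
        super().__init__()
        self.adapted_model = model

    def __getattr__(self, item):
        try:
            # nn.Module.__getattr__ resolves registered parameters, buffers
            # and submodules, including self.adapted_model itself.
            return super().__getattr__(item)
        except AttributeError:
            # Anything else is looked up on the wrapped model.
            return getattr(self.adapted_model, item)


base = torch.nn.Sequential(torch.nn.Linear(4, 4))
base.d_model = 4  # plain (non-parameter) attribute of the wrapped model

wrapper = DelegatingWrapper(base)
print(wrapper.d_model)  # 4
```

Trying super().__getattr__ first avoids infinite recursion: self.adapted_model is found among the registered submodules before the fallback runs.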

speechbrain.nnet.adapters.is_layer_adaptable(name, module, all_linear, all_conv, target_layers)

Check if a layer is among the list of layers to be adapted.

Parameters:
  • name (str) – The name of the module to check.

  • module (torch.nn.Module) – The module to check.

  • all_linear (bool) – Whether all linear layers should be adapted.

  • all_conv (bool) – Whether all conv layers should be adapted.

  • target_layers (str or list of str) – See the target_layers argument of AdaptedModel.

Returns:

Whether the layer is to be adapted or not.

Return type:

bool
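A hypothetical reimplementation of the decision rule, assuming a layer qualifies when it is a Linear layer and all_linear is set, a convolution and all_conv is set, or its name matches any target_layers pattern. The function name and the exact convolution types checked are assumptions of this sketch.

```python
from collections import OrderedDict
from fnmatch import fnmatch

import torch

# Assumed set of convolution types; SpeechBrain may check others.
CONV_TYPES = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)


def is_layer_adaptable_sketch(name, module, all_linear, all_conv, target_layers):
    """Hypothetical version of the documented check."""
    if isinstance(target_layers, str):
        target_layers = [target_layers]
    return (
        (all_linear and isinstance(module, torch.nn.Linear))
        or (all_conv and isinstance(module, CONV_TYPES))
        or any(fnmatch(name, pattern) for pattern in target_layers)
    )


model = torch.nn.Sequential(OrderedDict([
    ("layer1", torch.nn.Linear(10, 20)),
    ("layer2", torch.nn.Linear(20, 20)),
    ("layer3", torch.nn.Linear(20, 10)),
    ("conv", torch.nn.Conv1d(10, 10, kernel_size=3)),
]))

# The root module has the empty name "" and matches none of the rules.
by_pattern = [n for n, m in model.named_modules()
              if is_layer_adaptable_sketch(n, m, False, False, ["layer[13]"])]
linear_only = [n for n, m in model.named_modules()
               if is_layer_adaptable_sketch(n, m, True, False, [])]
conv_only = [n for n, m in model.named_modules()
             if is_layer_adaptable_sketch(n, m, False, True, [])]
print(by_pattern, linear_only, conv_only)
```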

speechbrain.nnet.adapters.replace_module(model: Module, name: str, new_module: Module)

Replace a layer with a new module by assigning it to the layer's parent module. This is used to replace layers with an adapter layer wrapped around the original layer, so the old parameters are preserved and new ones are added.

Parameters:
  • model (nn.Module) – Model containing the module to be replaced.

  • name (str) – Name of the target module to replace.

  • new_module (nn.Module) – New module made of the old plus the new parameters.
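Replacing a module by parent assignment amounts to splitting the dotted name into a parent path and a child attribute, then calling setattr on the parent. A minimal sketch, assuming the hypothetical names replace_module_sketch and Wrapped; it relies on torch.nn.Module.get_submodule, which returns the model itself for the empty path.

```python
import torch


def replace_module_sketch(model, name, new_module):
    """Hypothetical version: assign new_module on the parent of `name`."""
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name)  # "" returns model itself
    setattr(parent, child_name, new_module)


class Wrapped(torch.nn.Module):
    """Illustrative wrapper that keeps the original layer and its parameters."""

    def __init__(self, module):
        super().__init__()
        self.pretrained_module = module

    def forward(self, x):
        return self.pretrained_module(x)


model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Linear(8, 2)),
)
old = model.get_submodule("1.1")
replace_module_sketch(model, "1.1", Wrapped(old))

print(type(model[1][1]).__name__)  # Wrapped
print(model[1][1].pretrained_module is old)  # True
```

Since the wrapper holds a reference to the original layer, its trained weights are kept untouched inside the new module.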

class speechbrain.nnet.adapters.HoulsbyAdapterLinear(target_linear, projection_size, activation=<class 'speechbrain.nnet.activations.Swish'>, bias=True)

Bases: Module

This class implements the Houlsby Adapter as described in: 'Parameter-Efficient Transfer Learning for NLP' https://arxiv.org/abs/1902.00751

Parameters:
  • target_linear (nn.Module) – Module corresponding to the pretrained Linear that will be wrapped with this adapter.

  • projection_size (int) – Size of the bottleneck projection layer (usually smaller than the dimension of target_linear).

  • activation (nn.Module) – The activation function. Default is Swish.

  • bias (bool) – Whether to use biases in the linear projections.
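A minimal sketch of a Houlsby-style bottleneck adapter around a frozen Linear: the pretrained output is projected down to projection_size, passed through the activation, projected back up and added as a residual. The class name, the zero initialization of the up projection, and applying the bottleneck to the output of target_linear are assumptions of this sketch; torch.nn.SiLU stands in for SpeechBrain's Swish (SiLU is Swish with β = 1).

```python
import torch


class HoulsbySketch(torch.nn.Module):
    """Hypothetical bottleneck adapter applied to a frozen Linear's output."""

    def __init__(self, target_linear, projection_size,
                 activation=torch.nn.SiLU, bias=True):
        super().__init__()
        out_features = target_linear.out_features
        self.pretrained_linear = target_linear
        for p in self.pretrained_linear.parameters():
            p.requires_grad = False  # keep the pretrained layer frozen
        self.down = torch.nn.Linear(out_features, projection_size, bias=bias)
        self.up = torch.nn.Linear(projection_size, out_features, bias=bias)
        self.activation = activation()
        # Zero-init the up projection so the adapter starts as an identity.
        torch.nn.init.zeros_(self.up.weight)
        if bias:
            torch.nn.init.zeros_(self.up.bias)

    def forward(self, x):
        h = self.pretrained_linear(x)
        # Residual bottleneck: h + up(act(down(h)))
        return h + self.up(self.activation(self.down(h)))


x = torch.rand((8, 60, 64))
adapt = HoulsbySketch(torch.nn.Linear(64, 64), 8)
out = adapt(x)
print(out.shape)  # torch.Size([8, 60, 64])
```

Only the two small projections are trainable here: 64 × 8 + 8 for the down projection and 8 × 64 + 64 for the up projection.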

Example

>>> import torch
>>> from speechbrain.nnet.adapters import HoulsbyAdapterLinear
>>> x = torch.rand((8, 60, 64))
>>> base_linear = torch.nn.Linear(64, 64)
>>> adapt = HoulsbyAdapterLinear(base_linear, 8)
>>> output = adapt(x)
>>> output.shape
torch.Size([8, 60, 64])
forward(x: Tensor)

Applies the HoulsbyAdapter to an input tensor x.

Parameters:

x (torch.Tensor) – Input tensor to the adapter module. Shape: [B, Time, X]

Returns:

The linear outputs.

Return type:

torch.Tensor

class speechbrain.nnet.adapters.LoRA(target_module, rank=16, alpha=1.0)

Bases: Module

This class implements the LoRA Adapter as described in: 'LoRA: Low-Rank Adaptation of Large Language Models' https://arxiv.org/abs/2106.09685

Parameters:
  • target_module (nn.Module) – Module corresponding to the pretrained layer that will be wrapped with this adapter. Works with nn.Linear and nn.Conv layers.

  • rank (int) – Rank of the low-rank update, i.e. the size of the intermediate projection (usually much smaller than the input and output sizes).

  • alpha (float) – Value used to control the scaling of the LoRA update. Default is 1.0.
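A minimal LoRA sketch for a frozen torch.nn.Linear, reusing the submodule names shown in the AdaptedModel example: h = W₀x + (alpha / rank) · B(A(x)). Following the LoRA paper, B (adapter_up_proj) starts at zero and the update is scaled by alpha / rank; whether SpeechBrain scales exactly this way is an assumption here, and the hypothetical LoRASketch class handles Linear layers only.

```python
import torch


class LoRASketch(torch.nn.Module):
    """Hypothetical LoRA wrapper for a frozen torch.nn.Linear."""

    def __init__(self, target_module, rank=16, alpha=1.0):
        super().__init__()
        self.pretrained_module = target_module
        for p in self.pretrained_module.parameters():
            p.requires_grad = False  # W0 stays frozen
        in_f = target_module.in_features
        out_f = target_module.out_features
        self.adapter_down_proj = torch.nn.Linear(in_f, rank, bias=False)  # A
        self.adapter_up_proj = torch.nn.Linear(rank, out_f, bias=False)  # B
        torch.nn.init.zeros_(self.adapter_up_proj.weight)  # B = 0 at init
        self.scaling = alpha / rank

    def forward(self, x):
        # h = W0 x + (alpha / rank) * B A x
        update = self.adapter_up_proj(self.adapter_down_proj(x))
        return self.pretrained_module(x) + self.scaling * update


x = torch.rand((8, 60, 64))
adapt = LoRASketch(torch.nn.Linear(64, 64), rank=8, alpha=4.0)
out = adapt(x)
print(out.shape)  # torch.Size([8, 60, 64])
```

With rank 8 the adapter trains 8 × 64 + 64 × 8 = 1024 weights instead of the 4160 weights and biases of the full layer, and because B starts at zero the wrapped layer initially reproduces the pretrained output.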

Example

>>> import torch
>>> from speechbrain.nnet.adapters import LoRA
>>> x = torch.rand((8, 60, 64))
>>> base_linear = torch.nn.Linear(64, 64)
>>> adapt = LoRA(base_linear, rank=8, alpha=4)
>>> output = adapt(x)
>>> output.shape
torch.Size([8, 60, 64])
forward(x: Tensor)

Applies the LoRA Adapter.

Parameters:

x (torch.Tensor) – Input tensor to the adapter module.

Returns:

The linear outputs.

Return type:

torch.Tensor