"""The SpeechBrain implementation of various pre-trained model adapters e.g.
LoRA, Houlsby
Authors
* Titouan Parcollet 2024
* Peter Plantinga 2024
"""
import warnings
from fnmatch import fnmatch
import torch
import torch.nn as nn
from speechbrain.nnet.activations import Swish
from speechbrain.utils import checkpoints
MHA_WARNING = """
Torch's native multi-head attention is not adaptable since it accesses layer
weights directly to pass to highly optimized fused kernels. We are excluding
all native Torch MHA layers from the list of layers to adapt.
"""
@checkpoints.register_checkpoint_hooks
class AdaptedModel(nn.Module):
"""Given any torch model, e.g. asr_brain.modules.Transformer, and an adapter
class, e.g. HoulsbyAdapter, this class will replace the target layers
with this new adapter class (while preserving the parameters).
Arguments
---------
model_to_adapt: nn.Module
The base PyTorch model to add adapters to.
adapter_class: class
An (uninitialized) adapter of this SpeechBrain library.
all_linear: bool
Whether to add the adapter to all linear layers (default: False)
all_conv: bool
Whether to add the adapter to all conv layers (default: False)
target_layers: list of str
A list of module names in the given model that should be replaced.
Supports Unix shell-style wildcards `(*, ?, [seq], [!seq])` with `fnmatch`.
unfrozen_layers: list of str
List of layers to be unfrozen during training.
Supports Unix shell-style wildcards `(*, ?, [seq], [!seq])` with `fnmatch`.
adapter_kwargs: dict
Keyword arguments to pass to the adapter constructor.
manual_adapter_insertion: bool
The default value (`False`) leads to the adapters being inserted at
the time of initialization. However, in some cases, it is preferable
to wait to insert the adapters, e.g. when pretrained parameters need to
be loaded. In this case, one can set this to `True` and call
`insert_adapters` manually after the parameters have been loaded
(see the sketch after the example below).
Example
-------
>>> from collections import OrderedDict
>>> model = torch.nn.Sequential(
... OrderedDict([
... ("layer1", torch.nn.Linear(10, 20)),
... ("layer2", torch.nn.Linear(20, 20)),
... ("layer3", torch.nn.Linear(20, 10)),
... ])
... )
>>> lora_model = AdaptedModel(
... model_to_adapt=model,
... adapter_class=LoRA,
... target_layers=["layer[13]"],
... unfrozen_layers=["layer2"],
... adapter_kwargs={"rank": 2},
... )
>>> lora_model
AdaptedModel(
(adapted_model): Sequential(
(layer1): LoRA(
(pretrained_module): Linear(in_features=10, out_features=20, bias=True)
(adapter_down_proj): Linear(in_features=10, out_features=2, bias=False)
(adapter_up_proj): Linear(in_features=2, out_features=20, bias=False)
)
(layer2): Linear(in_features=20, out_features=20, bias=True)
(layer3): LoRA(
(pretrained_module): Linear(in_features=20, out_features=10, bias=True)
(adapter_down_proj): Linear(in_features=20, out_features=2, bias=False)
(adapter_up_proj): Linear(in_features=2, out_features=10, bias=False)
)
)
)
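
For workflows where pretrained parameters must be loaded before the adapters
are inserted (e.g. with a Pretrainer), a minimal sketch of the manual flow,
assuming a hypothetical ``load_pretrained_weights`` step::

    adapted = AdaptedModel(
        model_to_adapt=model,
        adapter_class=LoRA,
        target_layers=["layer[13]"],
        adapter_kwargs={"rank": 2},
        manual_adapter_insertion=True,
    )
    load_pretrained_weights(model)  # hypothetical: load base weights first
    adapted.insert_adapters()       # then wrap the freshly loaded layers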
"""
def __init__(
self,
model_to_adapt: nn.Module,
adapter_class: nn.Module,
all_linear: bool = False,
all_conv: bool = False,
target_layers: list = [],
unfrozen_layers: list = [],
adapter_kwargs: dict = {},
manual_adapter_insertion: bool = False,
):
super().__init__()
# Collect and freeze layers
self.adapted_model = model_to_adapt
self.adapter_class = adapter_class
self.adapter_kwargs = adapter_kwargs
for param in model_to_adapt.parameters():
param.requires_grad = False
# Iterate modules to create list of layers to adapt
self.replace_layers = []
for name, module in model_to_adapt.named_modules():
if is_layer_adaptable(
name, module, all_linear, all_conv, target_layers
):
# Torch's MultiheadAttention is not adaptable due to an
# optimized fused kernel, warn if we find this.
parent_name = ".".join(name.split(".")[:-1])
parent = model_to_adapt.get_submodule(parent_name)
if isinstance(parent, torch.nn.MultiheadAttention):
warnings.warn(MHA_WARNING)
else:
self.replace_layers.append(name)
elif any(fnmatch(name, layer) for layer in unfrozen_layers):
for param in module.parameters():
param.requires_grad = True
# Some cases require a delay in adapter insertion, e.g. using Pretrainer
if not manual_adapter_insertion:
self.insert_adapters()
def insert_adapters(self):
"""If this is in `__init__` it conflicts with `Pretrainer`.
Ensure this function is called exactly once before training.
See ``__init__.manual_adapter_insertion``
"""
for name in self.replace_layers:
module = self.adapted_model.get_submodule(name)
new_module = self.adapter_class(module, **self.adapter_kwargs)
replace_module(self.adapted_model, name, new_module)
def forward(self, *args, **kwargs):
"""Pass arguments to adapted model."""
return self.adapted_model(*args, **kwargs)
@checkpoints.mark_as_saver
def saver(self, path):
"""Saves only the trainable parameters."""
# NOTE: In order to preserve the gradient info, we have to prevent `state_dict` from detaching
# all the parameters and buffers. The `keep_vars=True` does this, then we detach manually
state_dict = {
name: param.detach()
for name, param in self.state_dict(keep_vars=True).items()
if param.requires_grad
}
torch.save(state_dict, path)
@checkpoints.mark_as_loader
def loader(self, path, end_of_epoch):
"""Loads the base model plus trained params."""
del end_of_epoch
state_dict = torch.load(path, map_location="cpu")
self.load_state_dict(state_dict, strict=False)
@checkpoints.mark_as_transfer
def parameter_transfer(self, path):
"""Avoids warnings due to only loading trained params."""
self.loader(path, True)
def __getattr__(self, item):
"""Override getattr to pass item accesses to pre-adapted model."""
# Have to use super to get adapted model to avoid recursion
model = super().__getattr__("adapted_model")
if hasattr(model, item):
return getattr(model, item)
# Normal access
return super().__getattr__(item)
def is_layer_adaptable(name, module, all_linear, all_conv, target_layers):
"""Check if layer is among list of layers to be adapted.
Arguments
---------
name: str
The name of the module to check.
module: torch.nn.Module
The module to check.
all_linear: bool
Whether all linear layers should be adapted.
all_conv: bool
Whether all conv layers should be adapted.
target_layers: list of str
See ``AdaptedModel``.
Returns
-------
bool
Whether the layer is to be adapted or not.
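
A minimal illustration of the matching logic (module names here are
hypothetical)::

    is_layer_adaptable("enc.layers.0.fc", nn.Linear(4, 4), False, False, ["*.fc"])
    # -> True: the name matches the "*.fc" pattern
    is_layer_adaptable("enc.conv", nn.Conv1d(4, 4, 3), False, True, [])
    # -> True: all_conv is set and the module is a Conv1d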
"""
return (
    (all_linear and isinstance(module, nn.Linear))
    or (all_conv and isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)))
    or (name and any(fnmatch(name, layer) for layer in target_layers))
)
def replace_module(model: nn.Module, name: str, new_module: nn.Module):
"""Replace layer with a new module based on a parent assignation.
This is used to replace layers with an Adapter layer wrapped around
the original layer. Hence, old parameters are preserved and new ones are
added.
Arguments
---------
model: nn.Module
Model containing the module to be replaced.
name: str
Name of the target module to replace.
new_module: nn.Module
New module made of the old plus the new parameters.
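
For a hypothetical dotted module name, the call is roughly equivalent to an
attribute assignment on the parent module::

    replace_module(model, "encoder.layers.0.fc", new_module)
    # roughly: model.get_submodule("encoder.layers.0").fc = new_module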
"""
# If the model is only one level deep, just use the model
try:
parent_name, target_name = name.rsplit(".", 1)
parent_module = model.get_submodule(parent_name)
except ValueError:
parent_module = model
target_name = name
setattr(parent_module, target_name, new_module)
class HoulsbyAdapterLinear(nn.Module):
"""This class implements the Houlsby Adapter as described in:
'Parameter-Efficient Transfer Learning for NLP'
https://arxiv.org/abs/1902.00751
Arguments
---------
target_linear: nn.Module
Module corresponding to the pretrained Linear that will be wrapped with
this adapter.
projection_size: int
Size of the projection layer (usually smaller).
activation: nn.Module
The activation function. Default is Swish.
bias: bool
Whether to use biases in the linear projections.
Example
-------
>>> import torch
>>> x = torch.rand((8, 60, 64))
>>> base_linear = nn.Linear(64, 64)
>>> adapt = HoulsbyAdapterLinear(base_linear, 8)
>>> output = adapt(x)
>>> output.shape
torch.Size([8, 60, 64])
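
To insert this adapter into every linear layer of a model, one could use
``AdaptedModel`` (a sketch, reusing the arguments documented above)::

    adapted = AdaptedModel(
        model_to_adapt=model,
        adapter_class=HoulsbyAdapterLinear,
        all_linear=True,
        adapter_kwargs={"projection_size": 8},
    )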
"""
def __init__(
self,
target_linear,
projection_size,
activation=Swish,
bias=True,
):
super().__init__()
if not isinstance(target_linear, nn.Linear):
raise ValueError(
"HoulsbyLinear currently only supports linear layers, "
f"but instead got {type(target_linear)}."
)
output_size = target_linear.weight.data.shape[0]
device = target_linear.weight.device
self.pretrained_linear = target_linear
# Freeze the wrapped layer; requires_grad must be set on the parameters
for param in self.pretrained_linear.parameters():
    param.requires_grad = False
self.adapter_down_proj = nn.Linear(
output_size, projection_size, bias=bias, device=device
)
self.adapter_up_proj = nn.Linear(
projection_size, output_size, bias=bias, device=device
)
self.activation = activation()
if bias:
self.adapter_down_proj.bias.data.fill_(0.0)
self.adapter_up_proj.bias.data.fill_(0.0)
def forward(self, x: torch.Tensor):
"""Applies the HoulsbyAdapter to an input tensor `x`.
Arguments
---------
x: torch.Tensor
Input tensor to the adapter module. Shape: [B, Time, X]
Returns
-------
torch.Tensor
    The wrapped linear output with the adapter's bottleneck residual added.
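    In symbols, with ``W`` the pretrained linear, ``D``/``U`` the down/up
    projections and ``f`` the activation: ``W(x) + U(f(D(W(x))))``.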
"""
x_pretrained = self.pretrained_linear(x)
return (
self.adapter_up_proj(
self.activation(self.adapter_down_proj(x_pretrained))
)
+ x_pretrained
)
class LoRA(nn.Module):
"""This class implements the LoRA Adapter as described in:
'LoRA: Low-Rank Adaptation of Large Language Models'
https://arxiv.org/abs/2106.09685
Arguments
---------
target_module: nn.Module
Module corresponding to the pretrained layer that will be wrapped with
this adapter. Works with nn.Linear and nn.Conv
rank: int
Rank of the low-rank projections (usually much smaller than the layer dimensions).
alpha : float
Value used to control the scaling in LoRA. Default is one.
Example
-------
>>> import torch
>>> x = torch.rand((8, 60, 64))
>>> base_linear = nn.Linear(64, 64)
>>> adapt = LoRA(base_linear, 64, 4)
>>> output = adapt(x)
>>> output.shape
torch.Size([8, 60, 64])
"""
def __init__(self, target_module, rank=16, alpha=1.0):
super().__init__()
input_size = target_module.weight.data.shape[1]
output_size = target_module.weight.data.shape[0]
# Disable gradient for pretrained module
self.pretrained_module = target_module
for param in self.pretrained_module.parameters():
param.requires_grad = False
device = target_module.weight.device
self.adapter_down_proj = nn.Linear(
input_size, rank, bias=False, device=device
)
self.adapter_up_proj = nn.Linear(
rank, output_size, bias=False, device=device
)
self.adapter_up_proj.weight.data.fill_(0.0)
self.scaling = alpha / rank
def forward(self, x: torch.Tensor):
"""Applies the LoRA Adapter.
Arguments
---------
x: torch.Tensor
Input tensor to the adapter module.
Returns
-------
torch.Tensor
    The pretrained output with the scaled low-rank update added.
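    With ``W`` the pretrained module, ``A``/``B`` the down/up projections and
    ``s = alpha / rank``, the output is ``W(x) + s * B(A(x))``. Since ``B``
    is zero-initialized, the adapted layer initially reproduces the
    pretrained behaviour.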
"""
x_pretrained = self.pretrained_module(x)
x_lora = self.adapter_up_proj(self.adapter_down_proj(x)) * self.scaling
return x_pretrained + x_lora