
Neural Network Adapters for faster low-memory fine-tuning

This tutorial covers the SpeechBrain implementation of adapters such as LoRA. It shows how to integrate SpeechBrain-implemented adapters, custom adapters, and adapters from libraries such as PEFT into a pre-trained model.

Prerequisites

Introduction and Background

As pre-trained models become larger and more capable, there is growing interest in methods for adapting them for specific tasks in a memory-efficient way, within a reasonable time span. One such technique is freezing the original parameters and inserting a small number of additional parameters into the original model, which are called “adapters.” These adapters can often match the performance of full fine-tuning at a fraction of the parameter count, meaning faster and more memory-efficient fine-tuning [1]. One popular technique for doing this is known as Low-Rank Adaptation (LoRA) [2].
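
As a brief sketch of the idea behind LoRA [2], consider a single frozen linear layer. The symbols below ($W_0$, $A$, $B$, $r$, $\alpha$) follow the usual notation of the LoRA paper and are not SpeechBrain identifiers:

```latex
h = W_0 x + \frac{\alpha}{r} B A x,
\qquad W_0 \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}},\;
A \in \mathbb{R}^{r \times d_{\text{in}}},\;
B \in \mathbb{R}^{d_{\text{out}} \times r},\;
r \ll \min(d_{\text{in}}, d_{\text{out}})
```

Only $A$ and $B$ are trained, so the layer gains $r\,(d_{\text{in}} + d_{\text{out}})$ trainable weights instead of the $d_{\text{in}}\, d_{\text{out}}$ weights of $W_0$. Initializing $B$ to zero makes the adapted layer start out identical to the pretrained one.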

On the software side, HuggingFace has produced a popular library for adapters called PEFT [3]. Our implementation includes some of the features of this library and can also integrate PEFT adapters into a SpeechBrain model. We’ll start with a basic example YAML so you can try it yourself right away if you learn better from experimentation.

Relevant bibliography

  1. N. Houlsby, A. Giurgiu, S. Jastrzebski, B. Morrone, Q. De Laroussilhe, A. Gesmundo, M. Attariyan, and S. Gelly, “Parameter-efficient transfer learning for NLP.” In International Conference on Machine Learning, 2019.

  2. E.J. Hu, P. Wallis, Z. Allen-Zhu, Y. Li, S. Wang, L. Wang, and W. Chen, “LoRA: Low-rank adaptation of large language models.” In International Conference on Learning Representations, 2021.

  3. S. Mangrulkar, S. Gugger, L. Debut, Y. Belkada, S. Paul, and B. Bossan, “PEFT: State-of-the-art parameter-efficient fine-tuning methods.” GitHub Repository, 2022.

Too Long; Didn’t Read

The TL;DR of this tutorial is that you should use a section like this in your HyperPyYAML file to create a model with adapters:

adapted_model: !new:speechbrain.nnet.adapters.AdaptedModel
    model_to_adapt: !ref <model>
    adapter_class: !name:speechbrain.nnet.adapters.LoRA
    all_linear: True
    unfrozen_layers: ["conv_1d_*"]
    adapter_kwargs:
        rank: 8

Adding this section to the YAML takes the already-defined key model and adds a LoRA adapter to every linear layer, passing the keyword argument rank=8 to each adapter. The all_linear and all_conv arguments add adapters to all linear or all convolution layers, respectively. By default, this class freezes the parameters of all layers that are not adapted; the unfrozen_layers argument names additional layers that should also be trained, at the cost of a higher trainable parameter count. To adapt only specific layers, use the target_layers argument. Both unfrozen_layers and target_layers support Unix-style wildcards through Python’s fnmatch library.
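
To get a feel for how such wildcard patterns select layers, here is a minimal sketch using only the standard-library fnmatch module. The layer names and the matching_layers helper are made up for illustration; the exact matching logic inside AdaptedModel may differ.

```python
from fnmatch import fnmatch

# Hypothetical layer names, invented for this illustration only.
layer_names = [
    "conv_1d_0",
    "conv_1d_1",
    "rnn.layer_0",
    "dnn.linear_0",
    "dnn.linear_1",
]


def matching_layers(names, patterns):
    """Return the names that match at least one unix-style wildcard pattern."""
    return [name for name in names if any(fnmatch(name, p) for p in patterns)]


# Same pattern as `unfrozen_layers: ["conv_1d_*"]` in the YAML above.
unfrozen = matching_layers(layer_names, ["conv_1d_*"])
print(unfrozen)

# `*` also matches across dots, and `?` matches exactly one character.
print(matching_layers(layer_names, ["dnn.*", "rnn.layer_?"]))
```

Patterns are matched against the whole name, so "linear_*" on its own would not match "dnn.linear_0" in this sketch; a leading "*" would be needed.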

If the TL;DR isn’t enough and you need to work through an example in more detail, continue to the next section.

Detailed Tutorial

We’ll demonstrate how to use adapters on a template recipe, which includes everything necessary for full training. The first step is to pretrain a model so we can later add adapters.

!git clone --depth 1 --branch v1.0.2 https://github.com/speechbrain/speechbrain.git
!python -m pip install -e speechbrain
# In order to use speechbrain in this repo we have to add it to the path
import os, sys

sys.path.append(os.path.join(os.getcwd(), 'speechbrain'))
%cd speechbrain/templates/speech_recognition/ASR
/home/pplantinga/Documents/Repositories/speechbrain/docs/tutorials/nn/speechbrain/templates/speech_recognition/ASR
!python train.py train.yaml --number_of_epochs=1 --batch_size=2 --test_scorer "!ref <valid_scorer>" --enable_add_reverb=False --enable_add_noise=False #To speed up
INFO:speechbrain.utils.seed:Setting seed to 2602
WARNING:speechbrain.utils.train_logger:torchvision is not available - cannot save figures
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/autocast.py:68: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/CRDNN_BPE_960h_LM/2602
mini_librispeech_prepare - Preparation completed in previous run, skipping.
../data/noise/data.zip exists. Skipping download
../data/rir/data.zip exists. Skipping download
speechbrain.utils.fetching - Fetch lm.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch tokenizer.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch asr.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(path, map_location=device)
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/core.py:793: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  self.scaler = torch.cuda.amp.GradScaler(enabled=gradscaler_enabled)
speechbrain.core - ASR Model Statistics:
* Total Number of Trainable Parameters: 173.0M
* Total Number of Parameters: 173.0M
* Trainable Parameters represent 100.0000% of the total size.
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
speechbrain.augment.augmenter - No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch.
100%|████████████████████████| 760/760 [08:28<00:00,  1.49it/s, train_loss=1.35]
100%|█████████████████████████████████████████| 545/545 [01:28<00:00,  6.18it/s]
speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 1.35 - valid loss: 1.31, valid CER: 7.71, valid WER: 20.06
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/CRDNN_BPE_960h_LM/2602/save/CKPT+2024-10-08+11-08-06+00
speechbrain.utils.checkpoints - Loading a checkpoint from results/CRDNN_BPE_960h_LM/2602/save/CKPT+2024-10-08+11-08-06+00
100%|███████████████████████████████████████| 1310/1310 [09:25<00:00,  2.32it/s]
speechbrain.utils.train_logger - Epoch loaded: 1 - test loss: 1.30, test CER: 5.75, test WER: 17.57
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/CRDNN_BPE_960h_LM/2602/save/CKPT+latest

Inference

To verify that this works, let’s run inference on one file. This code is taken from transcribe_file.py:

from speechbrain.inference.ASR import EncoderDecoderASR
from speechbrain.utils.fetching import fetch, LocalStrategy

# Ensure all the needed files end up in the same place to load with the transcriber
save_dir = os.path.abspath("results/CRDNN_BPE_960h_LM/2602/save/CKPT+latest")
fetch("lm.ckpt", "speechbrain/asr-crdnn-rnnlm-librispeech", save_dir, local_strategy=LocalStrategy.SYMLINK)
fetch("tokenizer.ckpt", "speechbrain/asr-crdnn-rnnlm-librispeech", save_dir, local_strategy=LocalStrategy.SYMLINK)
fetch("inference.yaml", os.getcwd(), save_dir, local_strategy=LocalStrategy.SYMLINK)

transcriber = EncoderDecoderASR.from_hparams(source=save_dir, hparams_file="inference.yaml")
speech_file = "../data/LibriSpeech/dev-clean-2/1272/135031/1272-135031-0015.flac"
transcriber.transcribe_file(speech_file)
INFO:speechbrain.utils.fetching:Fetch lm.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
INFO:speechbrain.utils.fetching:Fetch tokenizer.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
INFO:speechbrain.utils.parameter_transfer:Loading pretrained files for: lm, tokenizer, model, normalizer
'THE METAL FOREST IS IN THE GREAT DOMED CAVERN THE LARGEST IN ALL OUR DOMINIONS REPLIED CALICO ⁇ '

Adding adapters

Now that we’ve verified that the model works, let’s add adapters. We need a new YAML file that adds adapters to the model, and then we train with this new file. To create it, we apply a patch to the original YAML file that changes only the parts needed to train the adapted model.

%%writefile train_lora.patch
--- train.yaml	2024-10-07 19:23:49.839501714 -0400
+++ train_lora.yaml	2024-10-07 19:25:40.340933091 -0400
@@ -30,7 +30,7 @@
 NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
 RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
 
-output_folder: !ref results/CRDNN_BPE_960h_LM/<seed>
+output_folder: !ref results/crdnn_lora/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
@@ -41,7 +41,7 @@
 # speechbrain HuggingFace repository. However, a local path pointing to a
 # directory containing the lm.ckpt and tokenizer.ckpt may also be specified
 # instead. E.g if you want to use your own LM / tokenizer.
-pretrained_path: speechbrain/asr-crdnn-rnnlm-librispeech
+pretrained_path: results/CRDNN_BPE_960h_LM/2602/save/CKPT+latest
 
 
 # Path where data manifest files will be stored. The data manifest files are created by the
@@ -481,10 +481,9 @@
     ctc_lin: !ref <ctc_lin>
     seq_lin: !ref <seq_lin>
     normalize: !ref <normalize>
-    lm_model: !ref <lm_model>
 
 # Gathering all the submodels in a single model object.
-model: !new:torch.nn.ModuleList
+model_pretrained: !new:torch.nn.ModuleList
     - - !ref <encoder>
       - !ref <embedding>
       - !ref <decoder>
@@ -629,8 +628,31 @@
     loadables:
         lm: !ref <lm_model>
         tokenizer: !ref <tokenizer>
-        model: !ref <model>
+        model: !ref <model_pretrained>
     paths:
         lm: !ref <pretrained_path>/lm.ckpt
         tokenizer: !ref <pretrained_path>/tokenizer.ckpt
-        model: !ref <pretrained_path>/asr.ckpt
+        model: !ref <pretrained_path>/model.ckpt
+
+new_encoder: !new:speechbrain.nnet.adapters.AdaptedModel
+    model_to_adapt: !ref <encoder>
+    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    all_linear: True
+    manual_adapter_insertion: True
+    adapter_kwargs:
+        rank: 8
+
+new_decoder: !new:speechbrain.nnet.adapters.AdaptedModel
+    model_to_adapt: !ref <decoder>
+    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    all_linear: True
+    manual_adapter_insertion: True
+    adapter_kwargs:
+        rank: 8
+
+model: !new:torch.nn.ModuleList
+    - - !ref <new_encoder>
+      - !ref <embedding>
+      - !ref <new_decoder>
+      - !ref <ctc_lin>
+      - !ref <seq_lin>
Overwriting train_lora.patch
!patch train.yaml -i train_lora.patch -o train_lora.yaml
patching file train_lora.yaml (read from train.yaml)

Because we load the pretrained parameters with the pretrainer, the adapters must be inserted after the pretrained parameters have been loaded.

This is the reason for manual_adapter_insertion: True in the YAML and for the following brief change to the training code:

%%writefile train_lora_py.patch
--- train.py	2024-10-07 14:57:21.534381751 -0400
+++ train_lora.py	2024-10-07 19:33:12.839895913 -0400
@@ -473,6 +473,8 @@
     # the path given in the YAML file). The tokenizer is loaded at the same time.
     hparams["pretrainer"].collect_files()
     hparams["pretrainer"].load_collected()
+    hparams["new_encoder"].insert_adapters()
+    hparams["new_decoder"].insert_adapters()
 
     # Trainer initialization
     asr_brain = ASR(
Overwriting train_lora_py.patch
!patch train.py -i train_lora_py.patch -o train_lora.py
patching file train_lora.py (read from train.py)
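
The ordering requirement above can be illustrated with a toy model of parameter names, using plain Python sets rather than real modules. This is only a sketch: the layer names are hypothetical, and the attribute names pretrained_module, adapter_down_proj, and adapter_up_proj are borrowed from the PoolLoRA example later in this tutorial, on the assumption that the built-in adapters wrap layers in a similar way.

```python
def keys_after_wrapping(keys, adapted_layers):
    """Mimic how wrapping a layer in an adapter renames its parameters.

    The original parameters move under `pretrained_module`, and the
    adapter contributes new parameters of its own.
    """
    new_keys = set()
    for key in keys:
        layer, _, param = key.rpartition(".")
        if layer in adapted_layers:
            new_keys.add(f"{layer}.pretrained_module.{param}")
        else:
            new_keys.add(key)
    for layer in adapted_layers:
        new_keys.add(f"{layer}.adapter_down_proj.weight")
        new_keys.add(f"{layer}.adapter_up_proj.weight")
    return new_keys


# Hypothetical parameter names saved from the un-adapted, pretrained model.
checkpoint_keys = {"dnn.linear_0.weight", "dnn.linear_0.bias", "rnn.weight_ih_l0"}

# Parameter names the model would expose if adapters were inserted first.
adapted_keys = keys_after_wrapping(checkpoint_keys, {"dnn.linear_0"})

# Checkpoint entries that no longer have a matching parameter to load into.
not_loadable = sorted(checkpoint_keys - adapted_keys)
print(not_loadable)
```

Under these assumptions, a checkpoint saved before adaptation no longer lines up with the adapted model's parameter names, which is why the pretrained weights are loaded first and insert_adapters() is called afterwards.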

Training the adapted model

Training works exactly as before, now using train_lora.py and train_lora.yaml. The adapted model is designed as a drop-in replacement for the original. Notice in the log below that only about 1.5% of the parameters are trainable (1.8M of 120.0M), compared with 100% in the pretraining run.

!python train_lora.py train_lora.yaml --number_of_epochs=1 --batch_size=2 --test_scorer "!ref <valid_scorer>" --enable_add_reverb=False --enable_add_noise=False #To speed up
INFO:speechbrain.utils.seed:Setting seed to 2602
WARNING:speechbrain.utils.train_logger:torchvision is not available - cannot save figures
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/crdnn_lora/2602
mini_librispeech_prepare - Preparation completed in previous run, skipping.
../data/noise/data.zip exists. Skipping download
../data/rir/data.zip exists. Skipping download
speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
speechbrain.core - ASR Model Statistics:
* Total Number of Trainable Parameters: 1.8M
* Total Number of Parameters: 120.0M
* Trainable Parameters represent 1.4807% of the total size.
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
speechbrain.augment.augmenter - No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch.
100%|███████████████████████████| 760/760 [04:09<00:00,  3.04it/s, train_loss=1]
100%|█████████████████████████████████████████| 545/545 [01:40<00:00,  5.42it/s]
speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 1.00 - valid loss: 1.26, valid CER: 7.29, valid WER: 19.16
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_lora/2602/save/CKPT+2024-10-08+11-23-53+00
speechbrain.utils.checkpoints - Loading a checkpoint from results/crdnn_lora/2602/save/CKPT+2024-10-08+11-23-53+00
100%|███████████████████████████████████████| 1310/1310 [12:55<00:00,  1.69it/s]
speechbrain.utils.train_logger - Epoch loaded: 1 - test loss: 1.26, test CER: 5.62, test WER: 17.05
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_lora/2602/save/CKPT+latest
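
To see why the trainable fraction in the log above is so small, a rough count for a single linear layer helps. A rank-r LoRA adapter adds one r × d_in matrix and one d_out × r matrix next to the frozen d_out × d_in weight. The layer size below is hypothetical and not taken from the CRDNN recipe.

```python
def lora_params(d_in, d_out, rank):
    """Number of trainable weights a rank-`rank` LoRA adapter adds to one linear layer."""
    return rank * (d_in + d_out)


# Hypothetical layer size, chosen only for illustration.
d_in, d_out, rank = 1024, 1024, 8

frozen = d_in * d_out                     # weights of the frozen linear layer
trainable = lora_params(d_in, d_out, rank)
fraction = trainable / frozen

print(f"{trainable} adapter weights vs {frozen} frozen weights ({fraction:.2%})")
```

The percentage reported by SpeechBrain is aggregated over every module in the model, so it will not match this single-layer figure exactly.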

Custom adapter

We designed this so that you can replace the SpeechBrain adapter with a PEFT adapter:

new_encoder: !new:speechbrain.nnet.adapters.AdaptedModel
    model_to_adapt: !ref <encoder>
-   adapter_class: !name:speechbrain.nnet.adapters.LoRA
+   adapter_class: !name:peft.tuners.lora.layer.Linear
    manual_adapter_insertion: True
    adapter_kwargs:
-       rank: 16
+       r: 16
+       adapter_name: lora

This trains the same kind of adapter as before, so there is no need to go through the whole process again. Perhaps more interesting is designing a custom adapter:

%%file pool_lora.py

import torch

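class PoolLoRA(torch.nn.Module):
    """LoRA variant that average-pools the input features before the down projection.

    Arguments
    ---------
    target_module: torch.nn.Module
        Pretrained linear layer to adapt; its parameters are frozen.
    stride: int
        Pooling factor applied to the feature dimension of the adapter input.
    rank: int
        Size of the adapter bottleneck.
    alpha: float
        Scaling numerator; the adapter output is multiplied by alpha / rank.
    """

    def __init__(self, target_module, stride=2, rank=16, alpha=1.0):
        super().__init__()

        input_size = target_module.weight.data.shape[1]
        output_size = target_module.weight.data.shape[0]
        device = target_module.weight.device

        # Disable gradient for pretrained module
        self.pretrained_module = target_module
        for param in self.pretrained_module.parameters():
            param.requires_grad = False

        # Pool over the last (feature) dimension, so the down projection
        # only sees input_size // stride features.
        self.adapter_down_scale = torch.nn.AvgPool1d(kernel_size=stride)
        self.adapter_down_proj = torch.nn.Linear(
            input_size // stride, rank, bias=False, device=device
        )
        self.adapter_up_proj = torch.nn.Linear(
            rank, output_size, bias=False, device=device
        )
        # Zero-init the up projection so the adapter starts as a no-op
        self.adapter_up_proj.weight.data.fill_(0.0)

        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor):
        """Applies the pooled LoRA adapter.

        Arguments
        ---------
        x: torch.Tensor
            Input tensor to the adapter module.

        Returns
        -------
        The linear outputs
        """
        x_pretrained = self.pretrained_module(x)

        x_downsample = self.adapter_down_proj(self.adapter_down_scale(x))
        x_pool_lora = self.adapter_up_proj(x_downsample)

        return x_pretrained + x_pool_lora * self.scaling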
Overwriting pool_lora.py
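
To make the data flow of PoolLoRA.forward concrete, here is a dependency-free sketch that applies the same steps to a single feature vector using plain Python lists. The helper names (avg_pool, linear, pool_lora_forward) and all weight values are invented for illustration; the real module operates on batched torch tensors.

```python
def avg_pool(x, stride):
    """Average non-overlapping windows of `stride` features."""
    return [sum(x[i:i + stride]) / stride
            for i in range(0, len(x) - stride + 1, stride)]


def linear(weight, x):
    """Bias-free linear layer; `weight` is indexed as [out_feature][in_feature]."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def pool_lora_forward(x, w_pretrained, w_down, w_up, stride, alpha, rank):
    """Frozen linear output plus the scaled, pooled low-rank update."""
    pretrained = linear(w_pretrained, x)
    update = linear(w_up, linear(w_down, avg_pool(x, stride)))
    scaling = alpha / rank
    return [p + u * scaling for p, u in zip(pretrained, update)]


# Toy sizes: 4 input features, 2 output features, stride 2, rank 1.
x = [1.0, 3.0, 2.0, 4.0]
w_pretrained = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
w_down = [[0.5, -0.5]]        # shape [rank][input_size // stride]
w_up_init = [[0.0], [0.0]]    # zero-initialized, as in PoolLoRA.__init__
w_up_trained = [[2.0], [0.0]]  # made-up values standing in for trained weights

out_init = pool_lora_forward(x, w_pretrained, w_down, w_up_init, 2, 1.0, 1)
out_trained = pool_lora_forward(x, w_pretrained, w_down, w_up_trained, 2, 1.0, 1)
print(out_init, out_trained)
```

With the up projection at zero, the adapted layer reproduces the pretrained output, so training starts from the behavior of the original model and only departs from it as the adapter weights are updated.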
%%writefile train_pool_lora.patch
--- train_lora.yaml	2024-10-07 22:44:02.767830301 -0400
+++ train_pool_lora.yaml	2024-10-07 22:41:30.602641301 -0400
@@ -30,7 +30,7 @@
 NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
 RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
 
-output_folder: !ref results/crdnn_lora/<seed>
+output_folder: !ref results/crdnn_pool_lora/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
@@ -636,19 +636,21 @@
 
 new_encoder: !new:speechbrain.nnet.adapters.AdaptedModel
     model_to_adapt: !ref <encoder>
-    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    adapter_class: !name:pool_lora.PoolLoRA
     all_linear: True
     manual_adapter_insertion: True
     adapter_kwargs:
-        rank: 8
+        stride: 2
+        rank: 16
 
 new_decoder: !new:speechbrain.nnet.adapters.AdaptedModel
     model_to_adapt: !ref <decoder>
-    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    adapter_class: !name:pool_lora.PoolLoRA
     all_linear: True
     manual_adapter_insertion: True
     adapter_kwargs:
-        rank: 8
+        stride: 2
+        rank: 16
 
 model: !new:torch.nn.ModuleList
     - - !ref <new_encoder>
Overwriting train_pool_lora.patch
!patch train_lora.yaml -i train_pool_lora.patch -o train_pool_lora.yaml
patching file train_pool_lora.yaml (read from train_lora.yaml)
!python train_lora.py train_pool_lora.yaml --number_of_epochs=1 --batch_size=2 --test_scorer "!ref <valid_scorer>" --enable_add_reverb=False --enable_add_noise=False #To speed up
INFO:speechbrain.utils.seed:Setting seed to 2602
WARNING:speechbrain.utils.train_logger:torchvision is not available - cannot save figures
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/crdnn_pool_lora/2602
mini_librispeech_prepare - Preparation completed in previous run, skipping.
../data/noise/data.zip exists. Skipping download
../data/rir/data.zip exists. Skipping download
speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/core.py:793: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  self.scaler = torch.cuda.amp.GradScaler(enabled=gradscaler_enabled)
speechbrain.core - ASR Model Statistics:
* Total Number of Trainable Parameters: 1.8M
* Total Number of Parameters: 120.0M
* Trainable Parameters represent 1.5210% of the total size.
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
speechbrain.augment.augmenter - No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch.
100%|████████████████████████| 760/760 [04:19<00:00,  2.93it/s, train_loss=0.98]
100%|█████████████████████████████████████████| 545/545 [01:44<00:00,  5.24it/s]
speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 9.80e-01 - valid loss: 1.26, valid CER: 7.18, valid WER: 18.92
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_pool_lora/2602/save/CKPT+2024-10-08+11-43-00+00
speechbrain.utils.checkpoints - Loading a checkpoint from results/crdnn_pool_lora/2602/save/CKPT+2024-10-08+11-43-00+00
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(path, map_location=device)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/nnet/schedulers.py:240: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  data = torch.load(path)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/processing/features.py:1311: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  stats = torch.load(path, map_location=device)
100%|███████████████████████████████████████| 1310/1310 [14:07<00:00,  1.55it/s]
speechbrain.utils.train_logger - Epoch loaded: 1 - test loss: 1.25, test CER: 5.61, test WER: 16.99
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_pool_lora/2602/save/CKPT+latest
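
To get a feel for why the trainable share reported under "ASR Model Statistics" in the log above is so small, here is a rough back-of-the-envelope sketch. It assumes the standard LoRA formulation [2], in which a frozen weight of shape (out, in) gains two low-rank factors of shapes (rank, in) and (out, rank). The helper names, layer size, and rank are hypothetical and chosen only for illustration; they are not taken from the recipe, and the trainable parameters in the run above may also include parameters other than LoRA factors.

```python
def lora_param_count(in_features: int, out_features: int, rank: int) -> int:
    """Parameters added by one LoRA adapter: A is (rank x in), B is (out x rank)."""
    return rank * (in_features + out_features)


def trainable_percent(trainable: int, total: int) -> float:
    """Share of parameters that receive gradients, in percent."""
    return 100.0 * trainable / total


# Hypothetical example: a 1024 x 1024 linear layer adapted with rank 16.
frozen = 1024 * 1024
added = lora_param_count(1024, 1024, rank=16)
print(f"frozen={frozen:,} added={added:,}")
print(f"trainable share: {trainable_percent(added, frozen + added):.2f}%")

# Rounded figures from the log above: 1.8M trainable out of 120.0M total.
print(f"log (rounded): {trainable_percent(1_800_000, 120_000_000):.2f}%")
```

The rounded log figures give roughly 1.5%, which is consistent with the 1.5210% that SpeechBrain reports from the exact parameter counts.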

Conclusion

That's it, thanks for following along! Go forth and make cool adapters.

Citing SpeechBrain

If you use SpeechBrain in your research or business, please cite it using the following BibTeX entry:

@misc{speechbrainV1,
  title={Open-Source Conversational AI with {SpeechBrain} 1.0},
  author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
  year={2024},
  eprint={2407.00463},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
  title={{SpeechBrain}: A General-Purpose Speech Toolkit},
  author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
  year={2021},
  eprint={2106.04624},
  archivePrefix={arXiv},
  primaryClass={eess.AS},
  note={arXiv:2106.04624}
}