Neural Network Adapters for faster low-memory fine-tuning
This tutorial covers the SpeechBrain implementation of adapters such as LoRA. It shows how to integrate SpeechBrain's built-in adapters, custom adapters, and adapters from libraries such as PEFT into a pre-trained model.
Prerequisites
A CUDA-enabled device for running this tutorial
Introduction and Background
As pre-trained models become larger and more capable, there is growing interest in methods for adapting them to specific tasks in a memory-efficient way and within a reasonable time span. One such technique is to freeze the original parameters and insert a small number of additional parameters, called "adapters," into the original model. These adapters can often match the performance of full fine-tuning at a fraction of the parameter count, meaning faster and more memory-efficient fine-tuning [1]. One popular technique for doing this is known as Low-Rank Adaptation (LoRA) [2].
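To make the low-rank idea concrete, here is a brief sketch of the LoRA update from [2]. The symbols ($W_0$, $A$, $B$, $r$, $\alpha$, $d_{\text{in}}$, $d_{\text{out}}$) are notation chosen for this sketch rather than names from the SpeechBrain code. The pre-trained weight $W_0$ stays frozen and only the two small matrices $A$ and $B$ are trained:

```latex
h = W_0 x + \frac{\alpha}{r} B A x,
\qquad
W_0 \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}},\;
B \in \mathbb{R}^{d_{\text{out}} \times r},\;
A \in \mathbb{R}^{r \times d_{\text{in}}},\;
r \ll \min(d_{\text{in}}, d_{\text{out}})
```

The adapter adds $r(d_{\text{in}} + d_{\text{out}})$ trainable parameters per layer instead of the $d_{\text{in}} d_{\text{out}}$ needed to update $W_0$ directly. For example, with $d_{\text{in}} = d_{\text{out}} = 1024$ and $r = 8$, that is 16,384 parameters versus 1,048,576, or about 1.6%.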
On the software side, HuggingFace has produced a popular library for adapters called PEFT [3]. Our implementation includes some of the features of this library and can also integrate PEFT adapters into a SpeechBrain model. We'll start with a basic example YAML so you can try it yourself right away if you learn better from experimentation.
Relevant bibliography
[1] N. Houlsby, A. Giurgiu, S. Jastrzebski, B. Morrone, Q. De Laroussilhe, A. Gesmundo, M. Attariyan, and S. Gelly, "Parameter-efficient transfer learning for NLP." In International Conference on Machine Learning, 2019.
[2] E.J. Hu, P. Wallis, Z. Allen-Zhu, Y. Li, S. Wang, L. Wang, and W. Chen, "LoRA: Low-rank adaptation of large language models." In International Conference on Learning Representations, 2021.
[3] S. Mangrulkar, S. Gugger, L. Debut, Y. Belkada, S. Paul, and B. Bossan, "PEFT: State-of-the-art parameter-efficient fine-tuning methods." GitHub Repository, 2022.
Too Long; Didn't Read
The TL;DR of this tutorial is that you should use a section like this in your HyperPyYAML file to create a model with adapters:
adapted_model: !new:speechbrain.nnet.adapters.AdaptedModel
    model_to_adapt: !ref <model>
    adapter_class: !name:speechbrain.nnet.adapters.LoRA
    all_linear: True
    unfrozen_layers: ["conv_1d_*"]
    adapter_kwargs:
        rank: 8
Adding this section to the YAML takes the already-defined key model and adds a LoRA adapter to every linear layer with the keyword argument rank=8.
The all_linear and all_conv arguments simply add adapters to all linear or all convolution layers, respectively.
By default, this class freezes the parameters of all layers that are not adapted, but the unfrozen_layers argument can be used to specify the names of layers that should also be trained, at the cost of a higher parameter count. Specific layers to adapt can be selected with the target_layers argument. Both of these arguments support unix-style wildcards through Python's fnmatch library.
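As a small illustration of how unix-style wildcards select layers, the sketch below uses Python's fnmatch directly. The layer names and the select_layers helper are hypothetical and exist only for this example; in practice the names would come from the modules of the model being adapted.

```python
from fnmatch import fnmatch

# Hypothetical layer names, for illustration only; real names come from
# the named modules of the model being adapted.
layer_names = ["conv_1d_0", "conv_1d_1", "rnn.linear_0", "dnn.linear_1", "output"]


def select_layers(names, patterns):
    """Return the names matching at least one unix-style wildcard pattern."""
    return [n for n in names if any(fnmatch(n, p) for p in patterns)]


# "*" matches any run of characters, "?" matches exactly one character.
print(select_layers(layer_names, ["conv_1d_*"]))
print(select_layers(layer_names, ["*linear_?"]))
```

The first call selects both conv_1d layers, which is the kind of pattern used for unfrozen_layers in the YAML above; the second selects the two linear layers regardless of their parent module.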
If the TL;DR isn't enough and you need to work through an example in more detail, continue to the next section.
Detailed Tutorial
We'll demonstrate how to use adapters on a template recipe, which includes everything necessary for full training. The first step is to pretrain a model so we can later add adapters.
!git clone --depth 1 --branch v1.0.2 https://github.com/speechbrain/speechbrain.git
!python -m pip install -e speechbrain
fatal: destination path 'speechbrain' already exists and is not an empty directory.
/home/pplantinga/Documents/Repositories/uvenv/bin/python: No module named pip
# In order to use speechbrain in this repo we have to add it to the path
import os, sys
sys.path.append(os.path.join(os.getcwd(), 'speechbrain'))
%cd speechbrain/templates/speech_recognition/ASR
/home/pplantinga/Documents/Repositories/speechbrain/docs/tutorials/nn/speechbrain/templates/speech_recognition/ASR
/home/pplantinga/Documents/Repositories/uvenv/lib/python3.12/site-packages/IPython/core/magics/osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.
self.shell.db['dhist'] = compress_dhist(dhist)[-100:]
!python train.py train.yaml --number_of_epochs=1 --batch_size=2 --test_scorer "!ref <valid_scorer>" --enable_add_reverb=False --enable_add_noise=False #To speed up
INFO:speechbrain.utils.seed:Setting seed to 2602
WARNING:speechbrain.utils.train_logger:torchvision is not available - cannot save figures
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/autocast.py:68: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/CRDNN_BPE_960h_LM/2602
mini_librispeech_prepare - Preparation completed in previous run, skipping.
../data/noise/data.zip exists. Skipping download
../data/rir/data.zip exists. Skipping download
speechbrain.utils.fetching - Fetch lm.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch lm.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch tokenizer.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch tokenizer.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch asr.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.fetching - Fetch asr.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(path, map_location=device)
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/core.py:793: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
self.scaler = torch.cuda.amp.GradScaler(enabled=gradscaler_enabled)
speechbrain.core - ASR Model Statistics:
* Total Number of Trainable Parameters: 173.0M
* Total Number of Parameters: 173.0M
* Trainable Parameters represent 100.0000% of the total size.
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
speechbrain.augment.augmenter - No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch.
100%|████████████████████████| 760/760 [08:28<00:00, 1.49it/s, train_loss=1.35]
100%|█████████████████████████████████████████| 545/545 [01:28<00:00, 6.18it/s]
speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 1.35 - valid loss: 1.31, valid CER: 7.71, valid WER: 20.06
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/CRDNN_BPE_960h_LM/2602/save/CKPT+2024-10-08+11-08-06+00
speechbrain.utils.checkpoints - Loading a checkpoint from results/CRDNN_BPE_960h_LM/2602/save/CKPT+2024-10-08+11-08-06+00
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(path, map_location=device)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/nnet/schedulers.py:240: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
data = torch.load(path)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/processing/features.py:1311: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
stats = torch.load(path, map_location=device)
100%|███████████████████████████████████████| 1310/1310 [09:25<00:00, 2.32it/s]
speechbrain.utils.train_logger - Epoch loaded: 1 - test loss: 1.30, test CER: 5.75, test WER: 17.57
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/CRDNN_BPE_960h_LM/2602/save/CKPT+latest
Inference
To prove that this is working, let's perform inference on one file. This code is taken from transcribe_file.py.
from speechbrain.inference.ASR import EncoderDecoderASR
from speechbrain.utils.fetching import fetch, LocalStrategy
# Ensure all the needed files end up in the same place to load with the transcriber
save_dir = os.path.abspath("results/CRDNN_BPE_960h_LM/2602/save/CKPT+latest")
fetch("lm.ckpt", "speechbrain/asr-crdnn-rnnlm-librispeech", save_dir, local_strategy=LocalStrategy.SYMLINK)
fetch("tokenizer.ckpt", "speechbrain/asr-crdnn-rnnlm-librispeech", save_dir, local_strategy=LocalStrategy.SYMLINK)
fetch("inference.yaml", os.getcwd(), save_dir, local_strategy=LocalStrategy.SYMLINK)
transcriber = EncoderDecoderASR.from_hparams(source=save_dir, hparams_file="inference.yaml")
speech_file = "../data/LibriSpeech/dev-clean-2/1272/135031/1272-135031-0015.flac"
transcriber.transcribe_file(speech_file)
INFO:speechbrain.utils.fetching:Fetch lm.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
INFO:speechbrain.utils.fetching:Fetch tokenizer.ckpt: Fetching from HuggingFace Hub 'speechbrain/asr-crdnn-rnnlm-librispeech' if not cached
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/autocast.py:68: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
INFO:speechbrain.utils.parameter_transfer:Loading pretrained files for: lm, tokenizer, model, normalizer
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(path, map_location=device)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/processing/features.py:1311: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
stats = torch.load(path, map_location=device)
'THE METAL FOREST IS IN THE GREAT DOMED CAVERN THE LARGEST IN ALL OUR DOMINIONS REPLIED CALICO β '
Adding adapters
Now that we've shown that the model is at least working, let's go ahead and add adapters. We need to create a new YAML file that adds adapters to the model and then train with this new YAML file. To do this, we'll patch the old YAML file, changing only the parts necessary to train the adapted model.
%%writefile train_lora.patch
--- train.yaml 2024-10-07 19:23:49.839501714 -0400
+++ train_lora.yaml 2024-10-07 19:25:40.340933091 -0400
@@ -30,7 +30,7 @@
 NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
 RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
-output_folder: !ref results/CRDNN_BPE_960h_LM/<seed>
+output_folder: !ref results/crdnn_lora/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
@@ -41,7 +41,7 @@
 # speechbrain HuggingFace repository. However, a local path pointing to a
 # directory containing the lm.ckpt and tokenizer.ckpt may also be specified
 # instead. E.g if you want to use your own LM / tokenizer.
-pretrained_path: speechbrain/asr-crdnn-rnnlm-librispeech
+pretrained_path: results/CRDNN_BPE_960h_LM/2602/save/CKPT+latest
 # Path where data manifest files will be stored. The data manifest files are created by the
@@ -481,10 +481,9 @@
     ctc_lin: !ref <ctc_lin>
     seq_lin: !ref <seq_lin>
     normalize: !ref <normalize>
-    lm_model: !ref <lm_model>
 # Gathering all the submodels in a single model object.
-model: !new:torch.nn.ModuleList
+model_pretrained: !new:torch.nn.ModuleList
     - - !ref <encoder>
       - !ref <embedding>
       - !ref <decoder>
@@ -629,8 +628,31 @@
     loadables:
         lm: !ref <lm_model>
         tokenizer: !ref <tokenizer>
-        model: !ref <model>
+        model: !ref <model_pretrained>
     paths:
         lm: !ref <pretrained_path>/lm.ckpt
         tokenizer: !ref <pretrained_path>/tokenizer.ckpt
-        model: !ref <pretrained_path>/asr.ckpt
+        model: !ref <pretrained_path>/model.ckpt
+
+new_encoder: !new:speechbrain.nnet.adapters.AdaptedModel
+    model_to_adapt: !ref <encoder>
+    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    all_linear: True
+    manual_adapter_insertion: True
+    adapter_kwargs:
+        rank: 8
+
+new_decoder: !new:speechbrain.nnet.adapters.AdaptedModel
+    model_to_adapt: !ref <decoder>
+    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    all_linear: True
+    manual_adapter_insertion: True
+    adapter_kwargs:
+        rank: 8
+
+model: !new:torch.nn.ModuleList
+    - - !ref <new_encoder>
+      - !ref <embedding>
+      - !ref <new_decoder>
+      - !ref <ctc_lin>
+      - !ref <seq_lin>
Overwriting train_lora.patch
!patch train.yaml -i train_lora.patch -o train_lora.yaml
patching file train_lora.yaml (read from train.yaml)
Because we load the pretrained parameters using the pretrainer, the adapters must be inserted after the pretrained parameters have been loaded. This is the reason for manual_adapter_insertion: True in the YAML and for the following brief change to the training code:
%%writefile train_lora_py.patch
--- train.py 2024-10-07 14:57:21.534381751 -0400
+++ train_lora.py 2024-10-07 19:33:12.839895913 -0400
@@ -473,6 +473,8 @@
     # the path given in the YAML file). The tokenizer is loaded at the same time.
     hparams["pretrainer"].collect_files()
     hparams["pretrainer"].load_collected()
+    hparams["new_encoder"].insert_adapters()
+    hparams["new_decoder"].insert_adapters()
     # Trainer initialization
     asr_brain = ASR(
Overwriting train_lora_py.patch
!patch train.py -i train_lora_py.patch -o train_lora.py
patching file train_lora.py (read from train.py)
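One plausible way to see why the insertion order matters: wrapping a layer in an adapter changes the names of its parameters, so a checkpoint saved from the unadapted model no longer lines up with the adapted one. The sketch below is a simplified simulation using plain strings rather than a real state dict. The layer names are hypothetical, and the attribute names pretrained_module, adapter_down_proj and adapter_up_proj are assumptions about how the wrapper stores its submodules.

```python
def adapted_keys(pretrained_keys, adapted_layers):
    """Simulate how parameter names change once layers are wrapped by an adapter.

    Assumes the wrapper keeps the original layer under ``pretrained_module``
    and adds ``adapter_down_proj`` and ``adapter_up_proj`` weights.
    """
    keys = []
    for key in pretrained_keys:
        layer, _, param = key.rpartition(".")
        if layer in adapted_layers:
            # The original parameter now lives one level deeper.
            keys.append(f"{layer}.pretrained_module.{param}")
        else:
            keys.append(key)
    for layer in sorted(adapted_layers):
        keys.append(f"{layer}.adapter_down_proj.weight")
        keys.append(f"{layer}.adapter_up_proj.weight")
    return keys


# Hypothetical parameter names saved before any adapter was inserted.
checkpoint_keys = ["dnn.linear.weight", "dnn.linear.bias", "conv.weight"]
model_keys = adapted_keys(checkpoint_keys, adapted_layers={"dnn.linear"})

missing = sorted(set(model_keys) - set(checkpoint_keys))
unexpected = sorted(set(checkpoint_keys) - set(model_keys))
print("missing from checkpoint:", missing)
print("unexpected in checkpoint:", unexpected)
```

Loading the checkpoint first and inserting the adapters afterwards avoids this mismatch, because the names still agree at load time.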
Training the adapted model
Training works identically to before, using the updated train_lora.py and train_lora.yaml files. The adapted model is designed to work as a drop-in replacement. Notice how the number of trainable parameters drops from 173.0M to 1.8M, close to 1% of the original count (about 1.5% of the adapted model's 120.0M total parameters).
!python train_lora.py train_lora.yaml --number_of_epochs=1 --batch_size=2 --test_scorer "!ref <valid_scorer>" --enable_add_reverb=False --enable_add_noise=False #To speed up
INFO:speechbrain.utils.seed:Setting seed to 2602
WARNING:speechbrain.utils.train_logger:torchvision is not available - cannot save figures
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/autocast.py:68: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/crdnn_lora/2602
mini_librispeech_prepare - Preparation completed in previous run, skipping.
../data/noise/data.zip exists. Skipping download
../data/rir/data.zip exists. Skipping download
speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(path, map_location=device)
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/core.py:793: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
self.scaler = torch.cuda.amp.GradScaler(enabled=gradscaler_enabled)
speechbrain.core - ASR Model Statistics:
* Total Number of Trainable Parameters: 1.8M
* Total Number of Parameters: 120.0M
* Trainable Parameters represent 1.4807% of the total size.
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
speechbrain.augment.augmenter - No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch.
100%|███████████████████████████| 760/760 [04:09<00:00, 3.04it/s, train_loss=1]
100%|█████████████████████████████████████████| 545/545 [01:40<00:00, 5.42it/s]
speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 1.00 - valid loss: 1.26, valid CER: 7.29, valid WER: 19.16
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_lora/2602/save/CKPT+2024-10-08+11-23-53+00
speechbrain.utils.checkpoints - Loading a checkpoint from results/crdnn_lora/2602/save/CKPT+2024-10-08+11-23-53+00
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(path, map_location=device)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/nnet/schedulers.py:240: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
data = torch.load(path)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/processing/features.py:1311: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
stats = torch.load(path, map_location=device)
100%|███████████████████████████████████████| 1310/1310 [12:55<00:00, 1.69it/s]
speechbrain.utils.train_logger - Epoch loaded: 1 - test loss: 1.26, test CER: 5.62, test WER: 17.05
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_lora/2602/save/CKPT+latest
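To put the two runs side by side, the short sketch below recomputes the savings from the rounded figures printed in the logs above. The dictionary layout and the mmss_to_seconds helper are only for this illustration, and the timings come from a single run on one machine, so treat the speedup as indicative rather than a benchmark.

```python
def mmss_to_seconds(mmss):
    """Convert a tqdm-style 'MM:SS' string to seconds."""
    minutes, seconds = mmss.split(":")
    return int(minutes) * 60 + int(seconds)


# Values copied from the training logs (parameter counts in millions, as printed).
baseline = {"trainable_m": 173.0, "total_m": 173.0, "epoch_time": "08:28"}
lora = {"trainable_m": 1.8, "total_m": 120.0, "epoch_time": "04:09"}

frac_of_original = 100 * lora["trainable_m"] / baseline["trainable_m"]
frac_of_adapted = 100 * lora["trainable_m"] / lora["total_m"]
speedup = mmss_to_seconds(baseline["epoch_time"]) / mmss_to_seconds(lora["epoch_time"])

print(f"{frac_of_original:.2f}% of the original trainable parameters")
print(f"{frac_of_adapted:.2f}% of the adapted model's parameters")
print(f"{speedup:.2f}x faster training epoch")
```

The log itself reports 1.4807% rather than 1.50% because it works from exact parameter counts, while this sketch uses the rounded values.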
Custom adapter
We designed this interface so that you can replace the SpeechBrain adapter with a PEFT adapter:
 new_encoder: !new:speechbrain.nnet.adapters.AdaptedModel
     model_to_adapt: !ref <encoder>
-    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    adapter_class: !name:peft.tuners.lora.layer.Linear
     manual_adapter_insertion: True
     adapter_kwargs:
-        rank: 16
+        r: 16
+        adapter_name: lora
But this trains exactly the same thing as before, so there is no need to go through the whole process again. Perhaps more interesting is designing a custom adapter:
%%file pool_lora.py
import torch


class PoolLoRA(torch.nn.Module):
    """LoRA variant that average-pools the input before the down-projection."""

    def __init__(self, target_module, stride=2, rank=16, alpha=1.0):
        super().__init__()

        input_size = target_module.weight.data.shape[1]
        output_size = target_module.weight.data.shape[0]

        # Disable gradient for pretrained module
        self.pretrained_module = target_module
        for param in self.pretrained_module.parameters():
            param.requires_grad = False
        device = target_module.weight.device

        # Pool over the last (feature) dimension, shrinking it by `stride`
        self.adapter_down_scale = torch.nn.AvgPool1d(kernel_size=stride)
        self.adapter_down_proj = torch.nn.Linear(
            input_size // stride, rank, bias=False, device=device
        )
        self.adapter_up_proj = torch.nn.Linear(
            rank, output_size, bias=False, device=device
        )
        # Zero the up-projection so the adapter initially leaves the output unchanged
        self.adapter_up_proj.weight.data.fill_(0.0)
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor):
        """Applies the LoRA Adapter.

        Arguments
        ---------
        x: torch.Tensor
            Input tensor to the adapter module.

        Returns
        -------
        The linear outputs
        """
        x_pretrained = self.pretrained_module(x)
        x_downsample = self.adapter_down_proj(self.adapter_down_scale(x))
        x_pool_lora = self.adapter_up_proj(x_downsample)
        return x_pretrained + x_pool_lora * self.scaling
Overwriting pool_lora.py
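To get a feel for what the pooling step changes, the sketch below counts the trainable weights that a single adapted linear layer gains. The two helper functions and the 1024-by-1024 layer size are hypothetical and chosen only for illustration; the formulas follow the shapes of adapter_down_proj and adapter_up_proj in the class above, with a plain LoRA assumed to use an unpooled down-projection.

```python
def lora_params(input_size, output_size, rank):
    """Trainable weights of a plain low-rank adapter (no biases)."""
    return input_size * rank + rank * output_size


def pool_lora_params(input_size, output_size, rank, stride):
    """Trainable weights of PoolLoRA: pooling shrinks the down-projection input."""
    return (input_size // stride) * rank + rank * output_size


# Hypothetical 1024 -> 1024 linear layer.
print(lora_params(1024, 1024, rank=8))
print(pool_lora_params(1024, 1024, rank=8, stride=2))
print(pool_lora_params(1024, 1024, rank=16, stride=2))
```

At equal rank, pooling with a stride of 2 halves the down-projection, so the adapter is a quarter smaller for a square layer. Doubling the rank to 16 while pooling makes it larger than a rank-8 LoRA again, so the rank and stride can be traded against each other.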
%%writefile train_pool_lora.patch
--- train_lora.yaml 2024-10-07 22:44:02.767830301 -0400
+++ train_pool_lora.yaml 2024-10-07 22:41:30.602641301 -0400
@@ -30,7 +30,7 @@
 NOISE_DATASET_URL: https://www.dropbox.com/scl/fi/a09pj97s5ifan81dqhi4n/noises.zip?rlkey=j8b0n9kdjdr32o1f06t0cw5b7&dl=1
 RIR_DATASET_URL: https://www.dropbox.com/scl/fi/linhy77c36mu10965a836/RIRs.zip?rlkey=pg9cu8vrpn2u173vhiqyu743u&dl=1
-output_folder: !ref results/crdnn_lora/<seed>
+output_folder: !ref results/crdnn_pool_lora/<seed>
 test_wer_file: !ref <output_folder>/wer_test.txt
 save_folder: !ref <output_folder>/save
 train_log: !ref <output_folder>/train_log.txt
@@ -636,19 +636,21 @@
 new_encoder: !new:speechbrain.nnet.adapters.AdaptedModel
     model_to_adapt: !ref <encoder>
-    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    adapter_class: !name:pool_lora.PoolLoRA
     all_linear: True
     manual_adapter_insertion: True
     adapter_kwargs:
-        rank: 8
+        stride: 2
+        rank: 16
 new_decoder: !new:speechbrain.nnet.adapters.AdaptedModel
     model_to_adapt: !ref <decoder>
-    adapter_class: !name:speechbrain.nnet.adapters.LoRA
+    adapter_class: !name:pool_lora.PoolLoRA
     all_linear: True
     manual_adapter_insertion: True
     adapter_kwargs:
-        rank: 8
+        stride: 2
+        rank: 16
 model: !new:torch.nn.ModuleList
     - - !ref <new_encoder>
Overwriting train_pool_lora.patch
!patch train_lora.yaml -i train_pool_lora.patch -o train_pool_lora.yaml
patching file train_pool_lora.yaml (read from train_lora.yaml)
!python train_lora.py train_pool_lora.yaml --number_of_epochs=1 --batch_size=2 --test_scorer "!ref <valid_scorer>" --enable_add_reverb=False --enable_add_noise=False #To speed up
INFO:speechbrain.utils.seed:Setting seed to 2602
WARNING:speechbrain.utils.train_logger:torchvision is not available - cannot save figures
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/autocast.py:68: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
wrapped_fwd = torch.cuda.amp.custom_fwd(fwd, cast_inputs=cast_inputs)
speechbrain.core - Beginning experiment!
speechbrain.core - Experiment folder: results/crdnn_pool_lora/2602
mini_librispeech_prepare - Preparation completed in previous run, skipping.
../data/noise/data.zip exists. Skipping download
../data/rir/data.zip exists. Skipping download
speechbrain.utils.parameter_transfer - Loading pretrained files for: lm, tokenizer, model
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(path, map_location=device)
speechbrain.core - Info: ckpt_interval_minutes arg from hparam file is used
speechbrain.core - Gradscaler enabled: False. Using precision: fp32.
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/core.py:793: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
self.scaler = torch.cuda.amp.GradScaler(enabled=gradscaler_enabled)
speechbrain.core - ASR Model Statistics:
* Total Number of Trainable Parameters: 1.8M
* Total Number of Parameters: 120.0M
* Trainable Parameters represent 1.5210% of the total size.
speechbrain.utils.checkpoints - Would load a checkpoint here, but none found yet.
speechbrain.utils.epoch_loop - Going into epoch 1
speechbrain.augment.augmenter - No augmentation is applied because the augmentation start index is greater than or equal to the number of examples in the input batch.
100%|████████████████████████| 760/760 [04:19<00:00, 2.93it/s, train_loss=0.98]
100%|█████████████████████████████████████████| 545/545 [01:44<00:00, 5.24it/s]
speechbrain.utils.train_logger - epoch: 1, lr: 1.00e+00 - train loss: 9.80e-01 - valid loss: 1.26, valid CER: 7.18, valid WER: 18.92
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_pool_lora/2602/save/CKPT+2024-10-08+11-43-00+00
speechbrain.utils.checkpoints - Loading a checkpoint from results/crdnn_pool_lora/2602/save/CKPT+2024-10-08+11-43-00+00
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/utils/checkpoints.py:199: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(path, map_location=device)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/nnet/schedulers.py:240: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
data = torch.load(path)
/home/pplantinga/Documents/Repositories/speechbrain/speechbrain/processing/features.py:1311: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
stats = torch.load(path, map_location=device)
100%|███████████████████████████████████████| 1310/1310 [14:07<00:00, 1.55it/s]
speechbrain.utils.train_logger - Epoch loaded: 1 - test loss: 1.25, test CER: 5.61, test WER: 16.99
speechbrain.utils.checkpoints - Saved an end-of-epoch checkpoint in results/crdnn_pool_lora/2602/save/CKPT+latest
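The model statistics in the log above show that only about 1.5% of the parameters are trained once the adapters are inserted. As a rough illustration of where a figure like that comes from, the sketch below does the arithmetic for a single linear layer. It assumes a LoRA-style adapter that adds two low-rank matrices, of shapes (rank, d_in) and (d_out, rank), next to a frozen weight of shape (d_out, d_in). The helper names `lora_param_count` and `trainable_percentage` and the 1024-wide layer are hypothetical, chosen only for this example, and the calculation assumes the reported total includes the adapter weights.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Parameters added by one low-rank adapter (hypothetical helper).

    Assumes two matrices: A with shape (rank, d_in) and B with shape
    (d_out, rank), so the count is rank * (d_in + d_out).
    """
    return rank * (d_in + d_out)


def trainable_percentage(trainable: float, total: float) -> float:
    """Share of parameters that receive gradients, in percent."""
    return 100.0 * trainable / total


# Illustrative layer size and rank (assumptions, not taken from the recipe).
d_in, d_out, rank = 1024, 1024, 16

frozen = d_in * d_out  # frozen base weight, bias ignored
adapter = lora_param_count(d_in, d_out, rank)  # trainable adapter weights
share = trainable_percentage(adapter, frozen + adapter)

print(f"adapter parameters: {adapter}")  # 32768
print(f"trainable share: {share:.4f}%")  # 3.0303%

# The rounded counts printed in the log give roughly the same ratio
# as the logged 1.5210%; the small gap comes from rounding to 0.1M.
logged = trainable_percentage(1.8e6, 120.0e6)
print(f"from rounded log values: {logged:.4f}%")  # 1.5000%
```

On a real PyTorch model, the same two totals can be obtained by summing `p.numel()` over `model.parameters()`, keeping only the parameters with `p.requires_grad` set for the trainable count.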
Conclusion
That's it, thanks for following along! Go forth and make cool adapters.
Citing SpeechBrain
If you use SpeechBrain in your research or business, please cite it using the following BibTeX entry:
@misc{speechbrainV1,
title={Open-Source Conversational AI with {SpeechBrain} 1.0},
author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
year={2024},
eprint={2407.00463},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
title={{SpeechBrain}: A General-Purpose Speech Toolkit},
author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
year={2021},
eprint={2106.04624},
archivePrefix={arXiv},
primaryClass={eess.AS},
note={arXiv:2106.04624}
}