The Brain Class
A fundamental aspect of deep learning involves iterating through a dataset multiple times and updating model parameters, commonly referred to as the "training loop." To streamline and organize this process, SpeechBrain offers a versatile framework in the form of the "Brain" class, implemented in speechbrain/core.py. In each recipe, this class is sub-classed, and its methods are overridden to tailor the implementation to the specific requirements of that recipe.
The core method of the Brain class is fit(), responsible for iterating through the dataset, performing updates to the model, and managing the training loop. To leverage fit(), at least two methods must be defined in the sub-class: compute_forward() and compute_objectives(). These methods handle, respectively, the forward computation that generates predictions and the calculation of the loss terms required for gradient computation.
Let’s explore a minimal example to illustrate this:
%%capture
# Installing SpeechBrain via pip
BRANCH = 'develop'
!python -m pip install git+https://github.com/speechbrain/speechbrain.git@$BRANCH
# Clone SpeechBrain repository
!git clone https://github.com/speechbrain/speechbrain/
import torch
import speechbrain as sb
class SimpleBrain(sb.Brain):
    def compute_forward(self, batch, stage):
        # Run the model on the batch inputs to produce predictions
        return self.modules.model(batch["input"])

    def compute_objectives(self, predictions, batch, stage):
        # Compare predictions to targets to produce a loss for backpropagation
        return torch.nn.functional.l1_loss(predictions, batch["target"])

model = torch.nn.Linear(in_features=10, out_features=10)
brain = SimpleBrain({"model": model}, opt_class=lambda x: torch.optim.SGD(x, 0.1))
data = [{"input": torch.rand(10, 10), "target": torch.rand(10, 10)}]
brain.fit(range(10), data)
With just around 10 lines of code, we can successfully train a neural model. This efficiency is possible because the Brain class handles the intricate details of training, such as managing train() and eval() states or computing and applying gradients. Furthermore, every step of the process can be overridden by adding methods to the sub-class, so even complex training procedures, such as those used for Generative Adversarial Networks (GANs), can be seamlessly integrated into the Brain class.
In this tutorial, we'll begin by elucidating the parameters of the Brain class. Subsequently, we'll delve into the fit() method, breaking it down step by step and highlighting the segments that can be overridden when necessary. These insights into the class's parameters and the fit() method form the foundation for understanding the functionality and versatility of the Brain class.
Arguments to the Brain class
The Brain class only takes 5 arguments, but each of these can be a little complex, so we explain them in detail here. The relevant code is just the __init__ definition:
def __init__(
self,
modules=None,
opt_class=None,
hparams=None,
run_opts=None,
checkpointer=None,
):
modules argument
This first argument takes a dictionary of torch modules. The Brain class takes this dictionary and converts it to a Torch ModuleDict. This provides a convenient way to move all parameters to the correct device, call train() and eval(), and wrap the modules in the appropriate distributed wrapper if necessary.
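For instance, here is a minimal sketch of how several modules might be passed and then accessed with dot notation inside the sub-class (the names encoder and decoder are arbitrary choices for this example, not a SpeechBrain convention):
class EncoderDecoderBrain(sb.Brain):
    def compute_forward(self, batch, stage):
        # Each key of the modules dict becomes an attribute of self.modules
        encoded = self.modules.encoder(batch["input"])
        return self.modules.decoder(encoded)

    def compute_objectives(self, predictions, batch, stage):
        return torch.nn.functional.l1_loss(predictions, batch["target"])

brain = EncoderDecoderBrain(
    modules={
        "encoder": torch.nn.Linear(10, 20),
        "decoder": torch.nn.Linear(20, 10),
    },
    opt_class=lambda x: torch.optim.SGD(x, 0.1),
)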
opt_class argument
The Brain class takes a function definition for a PyTorch optimizer, rather than a pre-constructed optimizer, because the Brain class automatically handles wrapping the module parameters in distributed wrappers if requested. This needs to happen before the parameters get passed to the optimizer constructor.
To pass a PyTorch optimizer constructor, a lambda can be used, as in the example at the beginning of this tutorial. More convenient, however, is the option used by most of the recipes in SpeechBrain: defining the constructor with HyperPyYAML. The !name: tag acts similarly to the lambda, creating a new constructor that can be used to make optimizers.
optimizer: !name:torch.optim.Adam
lr: 0.1
Of course, sometimes zero or multiple optimizers are required. In the case of multiple optimizers, the init_optimizers method can be overridden to initialize each one individually.
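As a rough sketch of such an override for a two-optimizer (e.g., GAN-style) setup, where the module names generator and discriminator and the hparams keys g_opt_class and d_opt_class are assumptions for this example rather than part of the SpeechBrain API:
class GANBrain(sb.Brain):
    def init_optimizers(self):
        # One optimizer per sub-module instead of a single shared one
        self.g_optimizer = self.hparams.g_opt_class(
            self.modules.generator.parameters()
        )
        self.d_optimizer = self.hparams.d_opt_class(
            self.modules.discriminator.parameters()
        )
        # Register both so the checkpointer can save and restore their state
        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("g_optimizer", self.g_optimizer)
            self.checkpointer.add_recoverable("d_optimizer", self.d_optimizer)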
hparams argument
The Brain class algorithm may depend on a set of hyperparameters that should be easy to control externally. This argument accepts a dictionary, which is made accessible to all the internal methods via "dot notation". An example follows:
class SimpleBrain(sb.Brain):
def compute_forward(self, batch, stage):
return self.modules.model(batch["input"])
def compute_objectives(self, predictions, batch, stage):
term1 = torch.nn.functional.l1_loss(predictions, batch["target1"])
term2 = torch.nn.functional.mse_loss(predictions, batch["target2"])
return self.hparams.weight1 * term1 + self.hparams.weight2 * term2
hparams = {"weight1": 0.7, "weight2": 0.3}
model = torch.nn.Linear(in_features=10, out_features=10)
brain = SimpleBrain(
modules={"model": model},
opt_class=lambda x: torch.optim.SGD(x, 0.1),
hparams=hparams,
)
data = [{
"input": torch.rand(10, 10),
"target1": torch.rand(10, 10),
"target2": torch.rand(10, 10),
}]
brain.fit(range(10), data)
run_opts argument
There are a large number of options for controlling the execution details of the fit() method, all of which can be passed via this argument. Some examples include enabling debug mode, setting the execution device, and configuring distributed execution. For a full list, see the Brain class documentation in speechbrain/core.py.
checkpointer argument
Finally, if you pass a SpeechBrain checkpointer to the Brain class, there are several operations that automatically get called:
The optimizer parameters are added to the checkpointer.
At the beginning of training, the most recent checkpoint is loaded and training is resumed from that point. If training is finished, this simply ends the training step and moves on to evaluation.
During training, checkpoints are saved every 15 minutes by default (this can be changed or disabled with an option in run_opts).
At the beginning of evaluation, the "best" checkpoint is loaded, as determined by the lowest or highest score on a metric recorded in the checkpoints.
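Here is a minimal sketch of how a checkpointer might be constructed and passed, assuming the SimpleBrain example from above (the directory name ./save is arbitrary):
from speechbrain.utils.checkpoints import Checkpointer

checkpointer = Checkpointer(
    checkpoints_dir="./save",       # where checkpoint files are written
    recoverables={"model": model},  # objects whose state should be saved
)
brain = SimpleBrain(
    modules={"model": model},
    opt_class=lambda x: torch.optim.SGD(x, 0.1),
    checkpointer=checkpointer,
)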
The fit() method
This method does a lot, but it actually takes only about 100 lines of code, so it is understandable by reading the code itself. We break it down section by section and explain what each part is doing. First, let's briefly go over the arguments:
def fit(
self,
epoch_counter,
train_set,
valid_set=None,
progressbar=None,
train_loader_kwargs={},
valid_loader_kwargs={},
):
The epoch_counter argument takes an iterator, so when fit() is called, the outer loop iterates this variable. This argument was co-designed with an EpochCounter class enabling storage of the epoch loop state. With this argument, we can restart experiments from where they left off.
The train_set and valid_set arguments take a Torch Dataset or DataLoader that will load the tensors needed for training. If a DataLoader is not passed, one will be constructed automatically (see next section).
The progressbar argument controls whether a tqdm progressbar is displayed showing progress through the dataset for each epoch.
The train_loader_kwargs and valid_loader_kwargs are passed to the make_dataloader method for making the DataLoader (see next section).
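Putting a few of these together, a fit() call might look like the following sketch, where train_data and valid_data stand in for whatever datasets a recipe defines:
from speechbrain.utils.epoch_loop import EpochCounter

# The counter iterates epochs 1..20 and its state can be checkpointed
epoch_counter = EpochCounter(limit=20)
brain.fit(
    epoch_counter,
    train_set=train_data,
    valid_set=valid_data,
    train_loader_kwargs={"batch_size": 8, "shuffle": True},
    valid_loader_kwargs={"batch_size": 8},
)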
Fit structure
With the arguments out of the way, we can start to look at the structure of this method. The fit() method is built around a handful of override-able calls, and we'll go over these one by one through the rest of the tutorial.
make_dataloader
The first step of the fit() method is to ensure that the data is in an appropriate format for iteration. Both the train_set and valid_set are passed along with their respective keyword arguments. Here's the actual code:
if not isinstance(train_set, DataLoader):
train_set = self.make_dataloader(
train_set, stage=sb.Stage.TRAIN, **train_loader_kwargs
)
if valid_set is not None and not isinstance(valid_set, DataLoader):
valid_set = self.make_dataloader(
valid_set,
stage=sb.Stage.VALID,
ckpt_prefix=None,
**valid_loader_kwargs,
)
By default, this method handles potential complications to DataLoader creation, such as creating a DistributedSampler for distributed execution. As with all the other methods in the fit() call, this can be overridden by creating a make_dataloader method in the Brain's sub-class definition.
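For example, a minimal sketch of such an override, which keeps the default behaviour but forces a batch size of 1 for validation (the keyword names mirror the call above; check the base-class signature before relying on them):
class CustomLoaderBrain(sb.Brain):
    def make_dataloader(
        self, dataset, stage, ckpt_prefix="dataloader-", **loader_kwargs
    ):
        # Hypothetical tweak: validate one example at a time
        if stage == sb.Stage.VALID:
            loader_kwargs["batch_size"] = 1
        # Fall back to the default DataLoader construction
        return super().make_dataloader(
            dataset, stage, ckpt_prefix=ckpt_prefix, **loader_kwargs
        )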
on_fit_start
Besides the dataloader, there’s some setup that needs to happen before the training can begin. Here’s the relevant code:
self.on_fit_start()
if progressbar is None:
progressbar = self.progressbar
The on_fit_start method takes care of a few important things, which can most easily be explained by sharing the code:
def on_fit_start(self):
self._compile_jit()
self._wrap_distributed()
self.init_optimizers()
if self.checkpointer is not None:
self.checkpointer.recover_if_possible(
device=torch.device(self.device)
)
Basically, this method ensures that the torch modules are prepared appropriately, including jit compilation, distributed wrapping, and initializing the optimizer with all of the relevant parameters. The optimizer initialization also adds the optimizer parameters to the checkpointer if there is one. Finally, this method loads the latest checkpoint in order to resume training if it was interrupted.
on_stage_start
This next section starts the epoch iteration and prepares for iterating over the training data. To adjust this preparation, one can override the on_stage_start method, which allows for things like creating containers to store training statistics (see the sketch after the code below).
for epoch in epoch_counter:
self.on_stage_start(Stage.TRAIN, epoch)
self.modules.train()
self.nonfinite_count = 0
if self.train_sampler is not None and hasattr(
self.train_sampler, "set_epoch"
):
self.train_sampler.set_epoch(epoch)
last_ckpt_time = time.time()
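For example, a minimal sketch of such an override, where error_stats is a hypothetical hparams entry providing a metric container (not something SpeechBrain requires):
class StatsBrain(sb.Brain):
    def on_stage_start(self, stage, epoch=None):
        # Create a fresh container for this stage's statistics
        if stage != sb.Stage.TRAIN:
            self.error_metrics = self.hparams.error_stats()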
Training loop
The longest blocks of code in this tutorial are devoted to the training and validation data loops. However, they really only do three important things:
Call fit_batch() on each batch in the DataLoader.
Track the average loss and report it.
Optionally save a checkpoint periodically so training can be resumed.
Here’s the code:
enable = progressbar and sb.utils.distributed.if_main_process()
with tqdm(
train_set, initial=self.step, dynamic_ncols=True, disable=not enable,
) as t:
for batch in t:
self.step += 1
loss = self.fit_batch(batch)
self.avg_train_loss = self.update_average(
loss, self.avg_train_loss
)
t.set_postfix(train_loss=self.avg_train_loss)
if self.debug and self.step == self.debug_batches:
break
if (
self.checkpointer is not None
and self.ckpt_interval_minutes > 0
and time.time() - last_ckpt_time
>= self.ckpt_interval_minutes * 60.0
):
run_on_main(self._save_intra_epoch_ckpt)
last_ckpt_time = time.time()
Perhaps the most important step is the fit_batch(batch) call, which we show a trimmed version of here:
def fit_batch(self, batch):
outputs = self.compute_forward(batch, Stage.TRAIN)
loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
loss.backward()
if self.check_gradients(loss):
self.optimizer.step()
self.optimizer.zero_grad()
return loss.detach().cpu()
This method calls the two most important methods for fitting, compute_forward and compute_objectives, both of which must be overridden in order to use the Brain class. The loss is then backpropagated, and the gradients are checked for non-finite values and excessively large norms before the update is applied (large norms are automatically clipped by default).
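When the default update logic is not enough, fit_batch itself can be overridden. Here is a rough sketch of gradient accumulation, assuming a hypothetical accum_steps hyperparameter (the gradient check is omitted for brevity):
class AccumBrain(sb.Brain):
    def fit_batch(self, batch):
        outputs = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
        # Scale the loss so accumulated gradients match one large batch
        (loss / self.hparams.accum_steps).backward()
        # self.step is incremented by fit() before each fit_batch call
        if self.step % self.hparams.accum_steps == 0:
            self.optimizer.step()
            self.optimizer.zero_grad()
        return loss.detach().cpu()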
on_stage_end
At the end of the training loop, the on_stage_end method is called for potential cleanup operations, such as reporting training statistics.
self.on_stage_end(Stage.TRAIN, self.avg_train_loss, epoch)
self.avg_train_loss = 0.0
self.step = 0
Validation loop
Much like the training loop, the validation loop iterates over the dataloader and handles one batch of data at a time. However, instead of calling fit_batch, this loop calls evaluate_batch, which does not backpropagate the gradient or apply any updates.
if valid_set is not None:
self.on_stage_start(Stage.VALID, epoch)
self.modules.eval()
avg_valid_loss = 0.0
with torch.no_grad():
for batch in tqdm(
valid_set, dynamic_ncols=True, disable=not enable
):
self.step += 1
loss = self.evaluate_batch(batch, stage=Stage.VALID)
avg_valid_loss = self.update_average(
loss, avg_valid_loss
)
if self.debug and self.step == self.debug_batches:
break
on_stage_end
This is the same method used at the end of the train stage, but this time it is executed only on the main process, since it often involves writing to files. Common uses include updating the learning rate, saving a checkpoint, and recording statistics for the epoch (see the sketch after the code below).
self.step = 0
run_on_main(
self.on_stage_end,
args=[Stage.VALID, avg_valid_loss, epoch],
)
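As a rough sketch of what such an override might look like, assuming a checkpointer is present and a "loss" entry is used as the selection metric (a common recipe pattern, not a requirement):
class MyBrain(sb.Brain):
    def on_stage_end(self, stage, stage_loss, epoch=None):
        if stage == sb.Stage.TRAIN:
            # Remember the training loss so it can be reported later
            self.train_loss = stage_loss
        elif stage == sb.Stage.VALID and self.checkpointer is not None:
            # Save a checkpoint tagged with the validation loss,
            # keeping only the best one
            self.checkpointer.save_and_keep_only(
                meta={"loss": stage_loss}, min_keys=["loss"]
            )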
The very last thing is a simple check for debug mode, to run only a few epochs.
if self.debug and epoch == self.debug_epochs:
break
Congrats, you now know how the fit() method works, and why it is a useful tool for running experiments. All of the parts of training a model are broken down and the annoying bits are taken care of, while full flexibility is still available by overriding any part of the Brain class.
The evaluate() method
This method iterates the test data in much the same way as the validation data of the fit() method, including calls to on_stage_start and on_stage_end. One additional method that is called is the on_evaluate_start() method, which by default loads the best checkpoint for evaluation.
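A minimal sketch of a call, where test_data is a placeholder dataset and min_key="loss" assumes a "loss" entry was recorded in the checkpoint metadata (as in the on_stage_end sketch above):
brain.evaluate(
    test_set=test_data,
    min_key="loss",  # load the checkpoint with the lowest recorded loss
    test_loader_kwargs={"batch_size": 8},
)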
Conclusion
The Brain class, and the fit() method in particular, were inspired by other popular Python libraries for statistics and machine learning, notably numpy, scipy, keras, and PyTorch Lightning.
As we add tutorials about more advanced usage of the Brain class, we will add links to them here. Some examples of planned tutorials:
Writing a GAN with the Brain class
Distributed training with the Brain class
Non-gradient-based usage of the Brain class
Citing SpeechBrain
If you use SpeechBrain in your research or business, please cite it using the following BibTeX entry:
@misc{speechbrainV1,
title={Open-Source Conversational AI with {SpeechBrain} 1.0},
author={Mirco Ravanelli and Titouan Parcollet and Adel Moumen and Sylvain de Langen and Cem Subakan and Peter Plantinga and Yingzhi Wang and Pooneh Mousavi and Luca Della Libera and Artem Ploujnikov and Francesco Paissan and Davide Borra and Salah Zaiem and Zeyu Zhao and Shucong Zhang and Georgios Karakasidis and Sung-Lin Yeh and Pierre Champion and Aku Rouhe and Rudolf Braun and Florian Mai and Juan Zuluaga-Gomez and Seyed Mahed Mousavi and Andreas Nautsch and Xuechen Liu and Sangeet Sagar and Jarod Duret and Salima Mdhaffar and Gaelle Laperriere and Mickael Rouvier and Renato De Mori and Yannick Esteve},
year={2024},
eprint={2407.00463},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2407.00463},
}
@misc{speechbrain,
title={{SpeechBrain}: A General-Purpose Speech Toolkit},
author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio},
year={2021},
eprint={2106.04624},
archivePrefix={arXiv},
primaryClass={eess.AS},
note={arXiv:2106.04624}
}