speechbrain.core module
Core SpeechBrain code for running experiments.
- Authors
Peter Plantinga 2020, 2023
Abdel Heba 2020
Mirco Ravanelli 2020
Aku Rouhe 2021
Andreas Nautsch 2022
Sylvain de Langen 2023
Adel Moumen 2023
Summary
Classes:
Brain – Brain class abstracts away the details of data loops.
Stage – Simple enum to track stage of experiments.
Functions:
create_experiment_directory – Create the output folder and relevant experimental files.
parse_arguments – Parse command-line arguments to the experiment.
Reference
- speechbrain.core.create_experiment_directory(experiment_directory, hyperparams_to_save=None, overrides={}, log_config='log-config.yaml', save_env_desc=True)[source]
Create the output folder and relevant experimental files.
- Parameters:
experiment_directory (str) – The place where the experiment directory should be created.
hyperparams_to_save (str) – A filename of a yaml file representing the parameters for this experiment. If passed, references are resolved, and the result is written to a file in the experiment directory called “hyperparams.yaml”.
overrides (dict) – A mapping of replacements made in the yaml file, to save in yaml.
log_config (str) – A yaml filename containing configuration options for the logger.
save_env_desc (bool) – If True, an environment state description is saved to a file called env.log in the experiment directory.
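For orientation, a minimal usage sketch; the directory name, yaml filename, and override values are illustrative, not required by the API:

from speechbrain.core import create_experiment_directory

# Creates results/demo, writes env.log into it, and saves a resolved
# copy of the hyperparameters as hyperparams.yaml inside it.
create_experiment_directory(
    experiment_directory="results/demo",
    hyperparams_to_save="hyperparams.yaml",
    overrides={"seed": 1234},
)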
- speechbrain.core.parse_arguments(arg_list=None)[source]
Parse command-line arguments to the experiment.
- Parameters:
arg_list (list, None) – A list of arguments to parse. If not given, this is read from sys.argv[1:]
- Returns:
param_file (str) – The location of the parameters file.
run_opts (dict) – Run options, such as distributed, device, etc.
overrides (dict) – The overrides to pass to load_hyperpyyaml.
Example
>>> argv = ['hyperparams.yaml', '--device', 'cuda:1', '--seed', '10']
>>> filename, run_opts, overrides = parse_arguments(argv)
>>> filename
'hyperparams.yaml'
>>> run_opts["device"]
'cuda:1'
>>> overrides
'seed: 10'
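In a typical recipe script, the parsed pieces are combined with load_hyperpyyaml and the experiment setup; a sketch of that pattern (the "output_folder" key in the yaml file is an assumption of this sketch):

import sys
from hyperpyyaml import load_hyperpyyaml
from speechbrain.core import create_experiment_directory, parse_arguments

# Read the hyperparameter file path plus run options and yaml overrides
hparams_file, run_opts, overrides = parse_arguments(sys.argv[1:])
with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

create_experiment_directory(
    experiment_directory=hparams["output_folder"],  # assumed yaml key
    hyperparams_to_save=hparams_file,
    overrides=overrides,
)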
- class speechbrain.core.Stage(value)[source]
Bases: Enum
Simple enum to track stage of experiments.
- TRAIN = 1
- VALID = 2
- TEST = 3
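The stage value is passed to methods such as compute_forward() and compute_objectives(), where it is commonly used for branching. A sketch (the augment entry in hparams is hypothetical, and compute_objectives() is omitted for brevity):

from speechbrain.core import Brain, Stage

class AugmentingBrain(Brain):
    def compute_forward(self, batch, stage):
        wavs = batch[0]
        if stage == Stage.TRAIN:
            # Hypothetical augmentation module, applied only while training
            wavs = self.hparams.augment(wavs)
        return self.modules.model(wavs)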
- class speechbrain.core.Brain(modules=None, opt_class=None, hparams=None, run_opts=None, checkpointer=None, profiler=None)[source]
Bases: object
Brain class abstracts away the details of data loops.
The primary purpose of the Brain class is the implementation of the fit() method, which iterates epochs and datasets for the purpose of “fitting” a set of modules to a set of data.
In order to use the fit() method, one should sub-class the Brain class and override any methods for which the default behavior does not match the use case. For a simple use case (e.g., training a single model with a single dataset), the only methods that need to be overridden are:
compute_forward()
compute_objectives()
The example below illustrates how overriding these two methods is done.
For more complicated use cases, such as multiple modules that need to be updated, the following methods can be overridden:
fit_batch()
evaluate_batch()
- Parameters:
modules (dict of str:torch.nn.Module pairs) – These modules are passed to the optimizer by default if they have trainable parameters, and will have train() / eval() called on them.
opt_class (torch.optim class) – A torch optimizer constructor that takes only the list of parameters (e.g., a lambda or partial function definition). By default, this will be passed all modules in modules at the beginning of the fit() method. This behavior can be changed by overriding the configure_optimizers() method.
hparams (dict) – Each key:value pair should consist of a string key and a hyperparameter that is used within the overridden methods. These will be accessible via an hparams attribute, using “dot” notation: e.g., self.hparams.model(x).
run_opts (dict) – A set of options to change the runtime environment, including:
- debug (bool) – If True, this will only iterate a few batches for all datasets, to ensure code runs without crashing.
- debug_batches (int) – Number of batches to run in debug mode. Default: 2.
- debug_epochs (int) – Number of epochs to run in debug mode. Default: 2. If a non-positive number is passed, all epochs are run.
- debug_persistently (bool) – Keep data stored during debug mode (not using /tmp). Default: False.
- jit (bool) – Enable to compile all modules using jit. Default: False.
- jit_module_keys (list of str) – List of keys in modules that should be jit compiled.
- compile (bool) – Enable to compile all modules using torch.compile. Default: False.
- compile_module_keys (list of str) – List of keys in modules that should be compiled using torch.compile. If torch.compile is unavailable, an error is raised.
- compile_mode (str) – One of default, reduce-overhead, max-autotune. Default: reduce-overhead.
- compile_using_fullgraph (bool) – Whether it is OK to break the model into several subgraphs. Default: False.
- compile_using_dynamic_shape_tracing (bool) – Use dynamic shape tracing for compilation. Default: False.
- distributed_backend (str) – One of nccl, gloo, mpi.
- device (str) – The location for performing computations.
- auto_mix_prec (bool) – If True, automatic mixed-precision is used. Activate it only with cuda.
- max_grad_norm (float) – The default implementation of fit_batch() uses clip_grad_norm_ with this value. Default: 5.
- nonfinite_patience (int) – Number of times to ignore non-finite losses before stopping. Default: 3.
- noprogressbar (bool) – Whether to turn off the progressbar when training. Default: False.
- ckpt_interval_minutes (float) – Amount of time between saving intra-epoch checkpoints, in minutes. Default: 15.0. If non-positive, these are not saved.
- ckpt_interval_steps (int) – Number of steps between saving intra-epoch checkpoints. If non-positive, these are not saved. Default: 0.
checkpointer (speechbrain.Checkpointer) – By default, this will be used to load checkpoints, and will have the optimizer added to continue training if interrupted.
profiler (torch.profiler.profile) – Context manager for profiling and benchmarking of training/inference steps. Default: None (skip profiling).
Example
>>> from torch.optim import SGD
>>> class SimpleBrain(Brain):
...     def compute_forward(self, batch, stage):
...         return self.modules.model(batch[0])
...     def compute_objectives(self, predictions, batch, stage):
...         return torch.nn.functional.l1_loss(predictions, batch[0])
>>> model = torch.nn.Linear(in_features=10, out_features=10)
>>> brain = SimpleBrain({"model": model}, opt_class=lambda x: SGD(x, 0.1))
>>> brain.fit(range(1), ([torch.rand(10, 10), torch.rand(10, 10)],))
- compute_forward(batch, stage)[source]
Forward pass, to be overridden by sub-classes.
- Parameters:
batch (torch.Tensor or tensors) – An element from the dataloader, including inputs for processing.
stage (Stage) – The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
- Returns:
The outputs after all processing is complete. Directly passed to compute_objectives().
- Return type:
torch.Tensor or Tensors
- compute_objectives(predictions, batch, stage)[source]
Compute loss, to be overridden by sub-classes.
- Parameters:
predictions (torch.Tensor or Tensors) – The output tensor or tensors to evaluate. Comes directly from compute_forward().
batch (torch.Tensor or tensors) – An element from the dataloader, including targets for comparison.
stage (Stage) – The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
- Returns:
loss – A tensor with the computed loss.
- Return type:
torch.Tensor
- on_stage_start(stage, epoch=None)[source]
Gets called when a stage starts.
Useful for defining class variables used during the stage.
- on_stage_end(stage, stage_loss, epoch=None)[source]
Gets called at the end of a stage.
Useful for computing stage statistics, saving checkpoints, etc.
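A sketch of a typical override pair, resetting per-stage state at the start and reporting at the end; the attribute names and print-based reporting are illustrative:

from speechbrain.core import Brain, Stage

class ReportingBrain(Brain):
    def on_stage_start(self, stage, epoch=None):
        # Reset any per-stage containers (e.g., metric trackers) here
        self.batch_count = 0

    def on_stage_end(self, stage, stage_loss, epoch=None):
        if stage == Stage.TRAIN:
            # Keep the train loss around for the validation report
            self.train_loss = stage_loss
        else:
            print(f"{stage} results at epoch {epoch}: loss={stage_loss:.4f}")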
- make_dataloader(dataset, stage, ckpt_prefix='dataloader-', **loader_kwargs)[source]
Creates DataLoaders for Datasets.
This is used by fit() and evaluate() if they just receive Datasets.
Alternatively, this can be called from outside the Brain subclass. In that case, the DataLoader should be passed to fit() in place of the dataset.
The Stage.TRAIN DataLoader is handled specially. It has extra args for shuffle and drop_last. In DDP a DistributedSampler is created (unless the dataset is an IterableDataset).
Note
Some important DataLoader arguments are passed via **loader_kwargs, e.g., batch_size, num_workers, pin_memory.
Note
By default, evaluate() specifies ckpt_prefix=None to stop the test DataLoader being added to the checkpointer. If you need to add a recoverable after saving checkpoints (e.g., at test time, after checkpointing the training), and still be able to recover reasonably, you should probably specify allow_partial_load=True.
- Parameters:
dataset (Dataset) – A set of data to use to create data loader. If the Dataset is a DynamicItemDataset, PaddedBatch is used as the default collate_fn, unless specified in loader_kwargs.
stage (Stage) – The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
ckpt_prefix (str, None) – Prefix to use for SaveableDataLoader Checkpoint name. The Stage name is added to this to create the full key. Set to None to not save the DataLoader.
**loader_kwargs (dict) – Additional keyword arguments to the DataLoader. E.g., batch_size, num_workers, pin_memory.
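As noted above, the loader can also be built manually and handed to fit(). A sketch continuing the SimpleBrain example from the class docstring (the random toy batch is purely illustrative):

import torch
from speechbrain.core import Stage

train_data = ([torch.rand(10, 10), torch.rand(10, 10)],)  # one toy batch
train_loader = brain.make_dataloader(
    train_data, stage=Stage.TRAIN, batch_size=1, num_workers=0
)
brain.fit(range(1), train_loader)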
- on_fit_start()[source]
Gets called at the beginning of fit(), on multiple processes if distributed_count > 0 and the backend is ddp.
Default implementation compiles the jit modules, initializes optimizers, and loads the latest checkpoint to resume training.
- init_optimizers()[source]
Called during on_fit_start(), this initializes optimizers after parameters are fully configured (e.g., DDP, jit).
The default implementation of this method depends on an optimizer class being passed at initialization that takes only a list of parameters (e.g., a lambda or a partial function definition). This creates a single optimizer that optimizes all trainable params.
Override this method if there are multiple optimizers.
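A sketch of such an override with two optimizers; the module names ("encoder", "decoder"), learning rates, and recoverable names are illustrative:

import torch
from speechbrain.core import Brain

class TwoOptimizerBrain(Brain):
    def init_optimizers(self):
        # One optimizer per sub-module instead of the single default optimizer
        self.enc_optimizer = torch.optim.Adam(
            self.modules.encoder.parameters(), lr=1e-3
        )
        self.dec_optimizer = torch.optim.Adam(
            self.modules.decoder.parameters(), lr=1e-4
        )
        # Register both so interrupted training can be resumed
        if self.checkpointer is not None:
            self.checkpointer.add_recoverable("enc_opt", self.enc_optimizer)
            self.checkpointer.add_recoverable("dec_opt", self.dec_optimizer)

With multiple optimizers, fit_batch() usually needs a matching override; see the sketch under fit_batch() below.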
- zero_grad(set_to_none=False)[source]
Sets the gradients of all optimized torch.Tensors to zero if set_to_none=False (default), or to None otherwise.
Setting gradients to None should save memory, e.g. during evaluate(), and thus a larger batch might be used.
- on_evaluate_start(max_key=None, min_key=None)[source]
Gets called at the beginning of evaluate().
Default implementation loads the best-performing checkpoint for evaluation, based on stored metrics.
- fit_batch(batch)[source]
Fit one batch, override to do multiple updates.
The default implementation depends on a few methods being defined with a particular behavior:
compute_forward()
compute_objectives()
Also depends on having optimizers passed at initialization.
- Parameters:
batch (list of torch.Tensors) – Batch of data to use for training. Default implementation assumes this batch has two elements: inputs and targets.
- Return type:
detached loss
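A sketch of such an override, performing one plain update with the two optimizers from the init_optimizers() sketch above (gradient clipping, mixed precision, and non-finite checks are omitted for brevity):

from speechbrain.core import Brain, Stage

class TwoOptimizerBrain(Brain):
    # init_optimizers() as in the sketch above, plus:
    def fit_batch(self, batch):
        outputs = self.compute_forward(batch, Stage.TRAIN)
        loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
        loss.backward()
        self.enc_optimizer.step()
        self.dec_optimizer.step()
        self.enc_optimizer.zero_grad()
        self.dec_optimizer.zero_grad()
        return loss.detach()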
- on_fit_batch_end(batch, outputs, loss, should_step)[source]
Called after fit_batch(), meant for calculating and logging metrics.
- Parameters:
batch (list of torch.Tensors) – Batch of data to use for training. Default implementation assumes this batch has two elements: inputs and targets.
outputs (list or dictionary of torch.Tensors) – Returned value of compute_forward().
loss (torch.Tensor) – Returned value of compute_objectives().
should_step (boolean) – Whether optimizer.step() was called or not.
- check_gradients(loss)[source]
Check if gradients are finite and not too large. Automatically clips large gradients.
- Parameters:
loss (tensor) – The loss tensor after backward() has been called, but before the optimizer's step().
- Returns:
Whether or not the optimizer step should be carried out.
- Return type:
bool
- evaluate_batch(batch, stage)[source]
Evaluate one batch; override for a procedure different from training.
The default implementation depends on two methods being defined with a particular behavior:
compute_forward()
compute_objectives()
- fit(epoch_counter, train_set, valid_set=None, progressbar=None, train_loader_kwargs={}, valid_loader_kwargs={})[source]
Iterate epochs and datasets to improve objective.
Relies on the existence of multiple functions that can (or should) be overridden. The following methods are used and expected to have a certain behavior:
fit_batch()
evaluate_batch()
update_average()
If the initialization was done with distributed_count > 0 and the distributed_backend is ddp, this will generally handle multiprocess logic, like splitting the training data into subsets for each device and only saving a checkpoint on the main process.
- Parameters:
epoch_counter (iterable) – Each call should return an integer indicating the epoch count.
train_set (Dataset, DataLoader) – A set of data to use for training. If a Dataset is given, a DataLoader is automatically created. If a DataLoader is given, it is used directly.
valid_set (Dataset, DataLoader) – A set of data to use for validation. If a Dataset is given, a DataLoader is automatically created. If a DataLoader is given, it is used directly.
train_loader_kwargs (dict) – Kwargs passed to make_dataloader() for making the train_loader (if train_set is a Dataset, not DataLoader). E.g., batch_size, num_workers. DataLoader kwargs are all valid.
valid_loader_kwargs (dict) – Kwargs passed to make_dataloader() for making the valid_loader (if valid_set is a Dataset, not DataLoader). E.g., batch_size, num_workers. DataLoader kwargs are all valid.
progressbar (bool) – Whether to display the progress of each epoch in a progressbar.
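Continuing the SimpleBrain example from the class docstring, a sketch of a call that adds validation data and DataLoader options (the random toy batches are purely illustrative):

import torch

brain.fit(
    epoch_counter=range(5),
    train_set=([torch.rand(10, 10), torch.rand(10, 10)],),
    valid_set=([torch.rand(10, 10), torch.rand(10, 10)],),
    train_loader_kwargs={"batch_size": 1, "num_workers": 0},
    valid_loader_kwargs={"batch_size": 1},
)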
- evaluate(test_set, max_key=None, min_key=None, progressbar=None, test_loader_kwargs={})[source]
Iterate test_set and evaluate brain performance. By default, loads the best-performing checkpoint (as recorded using the checkpointer).
- Parameters:
test_set (Dataset, DataLoader) – If a DataLoader is given, it is iterated directly. Otherwise passed to self.make_dataloader().
max_key (str) – Key to use for finding best checkpoint, passed to on_evaluate_start().
min_key (str) – Key to use for finding best checkpoint, passed to on_evaluate_start().
progressbar (bool) – Whether to display the progress in a progressbar.
test_loader_kwargs (dict) – Kwargs passed to make_dataloader() if test_set is not a DataLoader. NOTE: loader_kwargs["ckpt_prefix"] gets automatically overwritten to None (so that the test DataLoader is not added to the checkpointer).
- Return type:
average test loss
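Continuing the SimpleBrain example, a sketch of an evaluation call; the "loss" key only matters when a checkpointer with stored metrics is present and is illustrative here:

import torch

avg_test_loss = brain.evaluate(
    test_set=([torch.rand(10, 10), torch.rand(10, 10)],),
    min_key="loss",
    test_loader_kwargs={"batch_size": 1},
)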
- no_sync(use=True)[source]
Copies PyTorch's implementation for doing no_sync across all modules.
Explanation: torch.nn.parallel.DistributedDataParallel.no_sync() is a context manager for when one does not want to sync gradients, which happens when using both DDP and gradient accumulation. SpeechBrain's Brain class can contain multiple modules, and calling no_sync on each of them individually would be very awkward, therefore this context manager exists.
- Parameters:
use (bool) – If set to False, gradients will still be synced; useful to make the behaviour togglable.
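A sketch of the intended use, skipping gradient synchronization on the accumulation steps of a custom fit_batch(); the accumulation factor is illustrative, and the self.step counter and self.optimizer attribute set up by the default implementation are assumed:

from speechbrain.core import Brain, Stage

class AccumulatingBrain(Brain):
    def fit_batch(self, batch):
        accumulate = 4  # illustrative accumulation factor
        should_step = self.step % accumulate == 0
        outputs = self.compute_forward(batch, Stage.TRAIN)
        loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
        # Only sync gradients across DDP processes on stepping batches
        with self.no_sync(not should_step):
            (loss / accumulate).backward()
        if should_step:
            self.optimizer.step()
            self.zero_grad()
        return loss.detach()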