speechbrain.core module

Core SpeechBrain code for running experiments.

Authors
  • Peter Plantinga 2020

  • Abdel Heba 2020

  • Mirco Ravanelli 2020

  • Aku Rouhe 2021

  • Andreas Nautsch 2022

Summary

Classes:

Brain

Brain class abstracts away the details of data loops.

Stage

Simple enum to track stage of experiments.

Functions:

create_experiment_directory

Create the output folder and relevant experimental files.

parse_arguments

Parse command-line arguments to the experiment.

Reference

speechbrain.core.create_experiment_directory(experiment_directory, hyperparams_to_save=None, overrides={}, log_config='/home/docs/checkouts/readthedocs.org/user_builds/speechbrain/checkouts/v0.5.12/speechbrain/log-config.yaml', save_env_desc=True)

Create the output folder and relevant experimental files.

Parameters
  • experiment_directory (str) – The place where the experiment directory should be created.

  • hyperparams_to_save (str) – A filename of a yaml file representing the parameters for this experiment. If passed, references are resolved, and the result is written to a file in the experiment directory called “hyperparams.yaml”.

  • overrides (dict) – A mapping of replacements made in the yaml file. They are also applied to the copy saved as hyperparams.yaml.

  • log_config (str) – A yaml filename containing configuration options for the logger.

  • save_env_desc (bool) – If True, a description of the environment state is saved to a file called env.log in the experiment directory.
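
To make the resulting layout concrete, here is a rough, stdlib-only sketch of what ends up on disk. The helper sketch_create_experiment_directory is hypothetical, not SpeechBrain's implementation. It applies overrides with a naive line-by-line replacement instead of resolving YAML references, and it writes only the Python version and platform as a minimal stand-in for the environment description in env.log.

```python
import os
import platform
import sys
import tempfile


def sketch_create_experiment_directory(
    experiment_directory, hyperparams_text, overrides=None, save_env_desc=True
):
    """Hypothetical sketch, not SpeechBrain's implementation."""
    os.makedirs(experiment_directory, exist_ok=True)
    overrides = overrides or {}

    # Naive override: rewrite any "key: value" line whose key is in overrides.
    # (The real function resolves YAML references; this sketch does not.)
    lines = []
    for line in hyperparams_text.splitlines():
        key = line.split(":", 1)[0].strip()
        if ":" in line and key in overrides:
            line = f"{key}: {overrides[key]}"
        lines.append(line)
    with open(os.path.join(experiment_directory, "hyperparams.yaml"), "w") as fout:
        fout.write("\n".join(lines) + "\n")

    # Minimal stand-in for the environment description.
    if save_env_desc:
        with open(os.path.join(experiment_directory, "env.log"), "w") as fout:
            fout.write(f"Python version:\n{sys.version}\n")
            fout.write(f"Platform:\n{platform.platform()}\n")


with tempfile.TemporaryDirectory() as tmp:
    exp_dir = os.path.join(tmp, "results", "exp1")
    sketch_create_experiment_directory(exp_dir, "seed: 1234\nlr: 0.1", {"seed": 10})
    created = sorted(os.listdir(exp_dir))
    with open(os.path.join(exp_dir, "hyperparams.yaml")) as fin:
        saved = fin.read()

print(created)      # ['env.log', 'hyperparams.yaml']
print(repr(saved))  # 'seed: 10\nlr: 0.1\n'
```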

speechbrain.core.parse_arguments(arg_list=None)

Parse command-line arguments to the experiment.

Parameters

arg_list (list, None) – A list of arguments to parse. If not given, this is read from sys.argv[1:].

Returns

  • param_file (str) – The location of the parameters file.

  • run_opts (dict) – Run options, such as distributed, device, etc.

  • overrides (str) – The overrides, formatted as a YAML string, to pass to load_hyperpyyaml.

Example

>>> from speechbrain.core import parse_arguments
>>> argv = ['hyperparams.yaml', '--device', 'cuda:1', '--seed', '10']
>>> filename, run_opts, overrides = parse_arguments(argv)
>>> filename
'hyperparams.yaml'
>>> run_opts["device"]
'cuda:1'
>>> overrides
'seed: 10'
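
The split shown in the example (a positional parameter file, recognised run options, and leftover flags turned into YAML overrides) can be sketched with argparse.parse_known_args. This is a simplified, stdlib-only illustration: sketch_parse_arguments is hypothetical, it models only an assumed subset of run options with illustrative defaults, and it assumes every unknown flag is followed by exactly one value.

```python
import argparse


def sketch_parse_arguments(arg_list):
    """Hypothetical sketch: split argv into (param_file, run_opts, overrides)."""
    # allow_abbrev=False stops unknown flags being matched as abbreviations.
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("param_file")
    # Only a small, assumed subset of run options; defaults are illustrative.
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--debug", action="store_true")
    known, unknown = parser.parse_known_args(arg_list)

    run_opts = vars(known)
    param_file = run_opts.pop("param_file")

    # Leftover "--key value" pairs become YAML lines (one value per flag assumed).
    lines = [
        f"{flag.lstrip('-')}: {value}"
        for flag, value in zip(unknown[::2], unknown[1::2])
    ]
    return param_file, run_opts, "\n".join(lines)


filename, run_opts, overrides = sketch_parse_arguments(
    ["hyperparams.yaml", "--device", "cuda:1", "--seed", "10"]
)
print(filename)            # hyperparams.yaml
print(run_opts["device"])  # cuda:1
print(overrides)           # seed: 10
```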

class speechbrain.core.Stage(value)

Bases: Enum

Simple enum to track stage of experiments.

TRAIN = 1
VALID = 2
TEST = 3
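
Overridden methods usually branch on the stage argument. A minimal illustration follows; the enum is redefined locally with the documented members so the snippet runs without SpeechBrain, and the describe() helper is hypothetical.

```python
from enum import Enum


class Stage(Enum):
    """Local copy of the documented members, so this runs without SpeechBrain."""

    TRAIN = 1
    VALID = 2
    TEST = 3


def describe(stage):
    """Hypothetical helper showing a typical per-stage branch."""
    if stage == Stage.TRAIN:
        return "update weights"
    return "evaluate only"


print(describe(Stage.TRAIN))  # update weights
print(describe(Stage.TEST))   # evaluate only
print(Stage(2))               # Stage.VALID
print(Stage["TEST"].value)    # 3
```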

class speechbrain.core.Brain(modules=None, opt_class=None, hparams=None, run_opts=None, checkpointer=None, profiler=None)

Bases: object

Brain class abstracts away the details of data loops.

The primary purpose of the Brain class is the implementation of the fit() method, which iterates epochs and datasets for the purpose of “fitting” a set of modules to a set of data.

In order to use the fit() method, one should sub-class the Brain class and override any methods for which the default behavior does not match the use case. For a simple use case (e.g., training a single model with a single dataset) the only methods that need to be overridden are:

  • compute_forward()

  • compute_objectives()

The example below illustrates how overriding these two methods is done.

For more complicated use cases, such as multiple modules that need to be updated, the following methods can be overridden:

  • fit_batch()

  • evaluate_batch()

Parameters
  • modules (dict of str:torch.nn.Module pairs) – These modules are passed to the optimizer by default if they have trainable parameters, and will have train()/eval() called on them.

  • opt_class (torch.optim class) – A torch optimizer constructor that takes only the list of parameters (e.g., a lambda or partial function definition). By default, this will be passed the parameters of all modules in modules at the beginning of the fit() method. This behavior can be changed by overriding the configure_optimizers() method.

  • hparams (dict) – Each key:value pair should consist of a string key and a hyperparameter that is used within the overridden methods. These will be accessible via an hparams attribute, using “dot” notation: e.g., self.hparams.model(x).

  • run_opts (dict) –

    A set of options to change the runtime environment, including

    debug (bool)

    If True, this will only iterate a few batches for all datasets, to ensure code runs without crashing.

    debug_batches (int)

    Number of batches to run in debug mode. Default: 2.

    debug_epochs (int)

    Number of epochs to run in debug mode. Default: 2. If a non-positive number is passed, all epochs are run.

    jit_module_keys (list of str)

    List of keys in modules that should be jit compiled.

    distributed_backend (str)

    One of nccl, gloo, mpi.

    device (str)

    The location for performing computations.

    auto_mix_prec (bool)

    If True, automatic mixed precision is used. Only activate it when running on CUDA.

    max_grad_norm (float)

    Default implementation of fit_batch() uses clip_grad_norm_ with this value. Default: 5.

    nonfinite_patience (int)

    Number of times to ignore non-finite losses before stopping. Default: 3.

    noprogressbar (bool)

    Whether to turn off progressbar when training. Default: False.

    ckpt_interval_minutes (float)

    Time between saving intra-epoch checkpoints, in minutes. Default: 15.0. If non-positive, these are not saved.

    Typically, in a script this comes from speechbrain.parse_arguments, which has different defaults than Brain. If an option is not defined here (keep in mind that parse_arguments injects some options by default), then the option is also searched for in hparams (by key).

  • checkpointer (speechbrain.Checkpointer) – By default, this will be used to load checkpoints, and will have the optimizer added to continue training if interrupted.

  • profiler (torch.profiler.profile) – Context manager for profiling and benchmarking of training/inference steps. Default: None (skip profiling).
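
The lookup order for run options described above (run_opts first, then hparams by key, then the built-in default), together with dot-notation access to hparams, can be sketched in plain Python. RUN_OPT_DEFAULTS and resolve_run_opts are hypothetical names, the defaults are copied from the values listed above, and types.SimpleNamespace is used here as one possible way to get self.hparams.model-style access.

```python
from types import SimpleNamespace

# Hypothetical subset of defaults, copied from the values documented above.
RUN_OPT_DEFAULTS = {
    "debug": False,
    "debug_batches": 2,
    "max_grad_norm": 5.0,
    "nonfinite_patience": 3,
    "ckpt_interval_minutes": 15.0,
}


def resolve_run_opts(run_opts=None, hparams=None):
    """Hypothetical helper: run_opts wins, then hparams (by key), then default."""
    run_opts = run_opts or {}
    hparams = hparams or {}
    resolved = {}
    for key, default in RUN_OPT_DEFAULTS.items():
        if key in run_opts:
            resolved[key] = run_opts[key]
        elif key in hparams:
            resolved[key] = hparams[key]
        else:
            resolved[key] = default
    return resolved


hparams = {"lr": 0.1, "max_grad_norm": 1.0}
opts = resolve_run_opts(run_opts={"debug": True}, hparams=hparams)
print(opts["debug"], opts["max_grad_norm"], opts["nonfinite_patience"])  # True 1.0 3

# Dot-notation access, as with self.hparams.lr inside a Brain subclass.
ns = SimpleNamespace(**hparams)
print(ns.lr)  # 0.1
```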

Example

>>> import torch
>>> from speechbrain.core import Brain
>>> from torch.optim import SGD
>>> class SimpleBrain(Brain):
...     def compute_forward(self, batch, stage):
...         return self.modules.model(batch[0])
...     def compute_objectives(self, predictions, batch, stage):
...         return torch.nn.functional.l1_loss(predictions, batch[0])
>>> model = torch.nn.Linear(in_features=10, out_features=10)
>>> brain = SimpleBrain({"model": model}, opt_class=lambda x: SGD(x, 0.1))
>>> brain.fit(range(1), ([torch.rand(10, 10), torch.rand(10, 10)],))

compute_forward(batch, stage)

Forward pass, to be overridden by sub-classes.

Parameters
  • batch (torch.Tensor or tensors) – An element from the dataloader, including inputs for processing.

  • stage (Stage) – The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST

Returns

The outputs after all processing is complete. Directly passed to compute_objectives().

Return type

torch.Tensor or Tensors

compute_objectives(predictions, batch, stage)

Compute loss, to be overridden by sub-classes.

Parameters
  • predictions (torch.Tensor or Tensors) – The output tensor or tensors to evaluate. Comes directly from compute_forward().

  • batch (torch.Tensor or tensors) – An element from the dataloader, including targets for comparison.

  • stage (Stage) – The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST

Returns

loss – A tensor with the computed loss.

Return type

torch.Tensor

on_stage_start(stage, epoch=None)

Gets called when a stage starts.

Useful for defining class variables used during the stage.

Parameters
  • stage (Stage) – The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST

  • epoch (int) – The current epoch count.

on_stage_end(stage, stage_loss, epoch=None)

Gets called at the end of a stage.

Useful for computing stage statistics, saving checkpoints, etc.

Parameters
  • stage (Stage) – The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST

  • stage_loss (float) – The average loss over the completed stage.

  • epoch (int) – The current epoch count.

make_dataloader(dataset, stage, ckpt_prefix='dataloader-', **loader_kwargs)

Creates DataLoaders for Datasets.

This is used by fit() and evaluate() if they just receive Datasets.

Alternatively, this can be called from outside the Brain subclass. In that case, the DataLoader should be passed to fit() in place of the dataset.

The Stage.TRAIN DataLoader is handled specially. It has extra args for shuffle and drop_last. In DDP, a DistributedSampler is created (unless the dataset is an IterableDataset).

Note

Some important DataLoader arguments are passed via **loader_kwargs, e.g., batch_size, num_workers, pin_memory.

Note

By default, evaluate() specifies ckpt_prefix=None to stop the test DataLoader from being added to the checkpointer. If you need to add a recoverable after checkpoints have been saved (e.g., at test time, after checkpointing the training) and still be able to recover reasonably, you should probably specify allow_partial_load=True on the Checkpointer.

Parameters
  • dataset (Dataset) – A set of data to use to create data loader. If the Dataset is a DynamicItemDataset, PaddedBatch is used as the default collate_fn, unless specified in loader_kwargs.

  • stage (Stage) – The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST

  • ckpt_prefix (str, None) – Prefix to use for SaveableDataLoader Checkpoint name. The Stage name is added to this to create the full key. Set to None to not save the DataLoader.

  • **loader_kwargs (dict) – Additional keyword arguments to the DataLoader. E.g., batch_size, num_workers, pin_memory.
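
The ckpt_prefix rule can be sketched as follows, assuming the full key is the prefix concatenated with the Stage member's name (so the default prefix would yield keys such as dataloader-TRAIN). checkpoint_key is a hypothetical helper, and Stage is rebuilt locally with Enum's functional API so the snippet runs on its own.

```python
from enum import Enum

# Local stand-in for speechbrain.core.Stage (members TRAIN=1, VALID=2, TEST=3).
Stage = Enum("Stage", ["TRAIN", "VALID", "TEST"])


def checkpoint_key(stage, ckpt_prefix="dataloader-"):
    """Hypothetical helper: build the recoverable name, or None to skip saving."""
    if ckpt_prefix is None:
        return None  # e.g. evaluate() passes None for the test loader
    return ckpt_prefix + stage.name


print(checkpoint_key(Stage.TRAIN))          # dataloader-TRAIN
print(checkpoint_key(Stage.VALID, "dl-"))   # dl-VALID
print(checkpoint_key(Stage.TEST, None))     # None
```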

on_fit_start()

Gets called at the beginning of fit(), on multiple processes if distributed_count > 0 and the backend is ddp.

Default implementation compiles the jit modules, initializes optimizers, and loads the latest checkpoint to resume training.

init_optimizers()

Called during on_fit_start() to initialize optimizers after parameters are fully configured (e.g., DDP, jit).

The default implementation of this method depends on an optimizer class being passed at initialization that takes only a list of parameters (e.g., a lambda or a partial function definition). This creates a single optimizer that optimizes all trainable params.

Override this method if there are multiple optimizers.
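
The requirement that the optimizer class take only a parameter list is easiest to see with functools.partial. In this sketch ToySGD and Param are hypothetical stand-ins for a torch optimizer and tensor parameters, so the snippet runs without PyTorch; with real PyTorch one would bind, e.g., the learning rate of torch.optim.SGD in the same way.

```python
from dataclasses import dataclass
from functools import partial


@dataclass
class Param:
    """Hypothetical stand-in for a tensor parameter; only requires_grad matters."""

    name: str
    requires_grad: bool = True


class ToySGD:
    """Hypothetical optimizer stand-in with a torch-like (params, lr) signature."""

    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr


# The constructor must accept only the parameters, so lr is bound in advance.
opt_class = partial(ToySGD, lr=0.1)  # equivalent: lambda p: ToySGD(p, 0.1)

all_params = [Param("w"), Param("b"), Param("frozen", requires_grad=False)]
optimizer = opt_class(p for p in all_params if p.requires_grad)
print([p.name for p in optimizer.params], optimizer.lr)  # ['w', 'b'] 0.1
```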

on_evaluate_start(max_key=None, min_key=None)

Gets called at the beginning of evaluate().

Default implementation loads the best-performing checkpoint for evaluation, based on stored metrics.

Parameters
  • max_key (str) – Key to use for finding best checkpoint (higher is better). By default, passed to self.checkpointer.recover_if_possible().

  • min_key (str) – Key to use for finding best checkpoint (lower is better). By default, passed to self.checkpointer.recover_if_possible().
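
One way to picture how max_key and min_key select a checkpoint, assuming each checkpoint stores a metadata dict of metrics. pick_best and the metric names WER and ACC are hypothetical, and falling back to the most recent checkpoint when no key is given is an assumption of this sketch.

```python
def pick_best(checkpoint_metas, max_key=None, min_key=None):
    """Hypothetical sketch: choose one checkpoint's metadata by a stored metric."""
    if max_key is not None and min_key is not None:
        raise ValueError("Pass at most one of max_key and min_key.")
    if max_key is not None:
        return max(checkpoint_metas, key=lambda meta: meta[max_key])
    if min_key is not None:
        return min(checkpoint_metas, key=lambda meta: meta[min_key])
    return checkpoint_metas[-1]  # assumption: fall back to the most recent


metas = [
    {"epoch": 1, "WER": 20.1, "ACC": 0.71},
    {"epoch": 2, "WER": 18.4, "ACC": 0.74},
    {"epoch": 3, "WER": 19.0, "ACC": 0.76},
]
print(pick_best(metas, min_key="WER")["epoch"])  # 2
print(pick_best(metas, max_key="ACC")["epoch"])  # 3
print(pick_best(metas)["epoch"])                 # 3
```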

fit_batch(batch)

Fit one batch, override to do multiple updates.

The default implementation depends on a few methods being defined with a particular behavior:

  • compute_forward()

  • compute_objectives()

Also depends on having optimizers passed at initialization.

Parameters

batch (list of torch.Tensors) – Batch of data to use for training. Default implementation assumes this batch has two elements: inputs and targets.

Returns

The detached loss.

check_gradients(loss)

Check if gradients are finite and not too large.

Automatically clips large gradients.

Parameters

loss (tensor) – The loss tensor after backward() has been called but before the optimizers step().

Returns

Whether or not the optimizer step should be carried out.

Return type

bool
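
A plain-Python sketch of the two checks described above: counting non-finite losses against a patience budget (cf. the nonfinite_patience run option) and rescaling gradients whose global L2 norm exceeds max_grad_norm, in the spirit of torch.nn.utils.clip_grad_norm_. GradientGuard is hypothetical and works on lists of floats rather than tensors.

```python
import math


class GradientGuard:
    """Hypothetical sketch of a check_gradients-style guard on plain floats."""

    def __init__(self, max_grad_norm=5.0, nonfinite_patience=3):
        self.max_grad_norm = max_grad_norm
        self.nonfinite_patience = nonfinite_patience
        self.nonfinite_count = 0

    def check(self, loss, grads):
        """Return (should_step, grads), with grads clipped by global L2 norm."""
        if not math.isfinite(loss):
            self.nonfinite_count += 1
            if self.nonfinite_count > self.nonfinite_patience:
                raise ValueError("Loss is not finite and patience is exhausted.")
            return False, grads  # skip the optimizer step for this batch
        total_norm = math.sqrt(sum(g * g for g in grads))
        scale = self.max_grad_norm / (total_norm + 1e-6)
        if scale < 1.0:
            grads = [g * scale for g in grads]
        return True, grads


guard = GradientGuard(max_grad_norm=5.0)
ok_finite, clipped = guard.check(0.7, [3.0, 4.0, 12.0])  # norm 13, rescaled to ~5
print(ok_finite, round(math.sqrt(sum(g * g for g in clipped)), 3))  # True 5.0
ok_nan, _ = guard.check(float("nan"), [1.0])
print(ok_nan, guard.nonfinite_count)  # False 1
```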

evaluate_batch(batch, stage)

Evaluate one batch; override this if evaluation should follow a different procedure than training.

The default implementation depends on two methods being defined with a particular behavior:

  • compute_forward()

  • compute_objectives()

Parameters
  • batch (list of torch.Tensors) – Batch of data to use for evaluation. Default implementation assumes this batch has two elements: inputs and targets.

  • stage (Stage) – The stage of the experiment: Stage.VALID, Stage.TEST

Returns

The detached loss.

fit(epoch_counter, train_set, valid_set=None, progressbar=None, train_loader_kwargs={}, valid_loader_kwargs={})

Iterate epochs and datasets to improve objective.

Relies on the existence of multiple functions that can (or should) be overridden. The following methods are used and expected to have a certain behavior:

  • fit_batch()

  • evaluate_batch()

  • update_average()

If the initialization was done with distributed_count > 0 and the distributed_backend is ddp, this will generally handle multiprocess logic, like splitting the training data into subsets for each device and only saving a checkpoint on the main process.

Parameters
  • epoch_counter (iterable) – Each call should return an integer indicating the epoch count.

  • train_set (Dataset, DataLoader) – A set of data to use for training. If a Dataset is given, a DataLoader is automatically created. If a DataLoader is given, it is used directly.

  • valid_set (Dataset, DataLoader) – A set of data to use for validation. If a Dataset is given, a DataLoader is automatically created. If a DataLoader is given, it is used directly.

  • train_loader_kwargs (dict) – Kwargs passed to make_dataloader() for making the train_loader (if train_set is a Dataset, not DataLoader). E.g., batch_size, num_workers. DataLoader kwargs are all valid.

  • valid_loader_kwargs (dict) – Kwargs passed to make_dataloader() for making the valid_loader (if valid_set is a Dataset, not DataLoader). E.g., batch_size, num_workers. DataLoader kwargs are all valid.

  • progressbar (bool) – Whether to display the progress of each epoch in a progressbar.
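
To show how the overridable methods and hooks fit together, here is a framework-free sketch of the loop that fit() drives, fitting y = w * x with a hand-written gradient. ToyBrain is hypothetical; it omits checkpointing, DDP, progress bars, and gradient checks, and the per-stage step counter used by update_average is an assumption of this sketch.

```python
from enum import Enum

# Local stand-in for speechbrain.core.Stage (members TRAIN=1, VALID=2, TEST=3).
Stage = Enum("Stage", ["TRAIN", "VALID", "TEST"])


class ToyBrain:
    """Hypothetical, framework-free sketch of the fit() control flow."""

    def __init__(self, lr=0.1):
        self.w = 0.0
        self.lr = lr
        self.step = 0
        self.log = []

    def compute_forward(self, batch, stage):
        x, _ = batch
        return self.w * x

    def compute_objectives(self, prediction, batch, stage):
        _, y = batch
        return (prediction - y) ** 2

    def fit_batch(self, batch):
        x, y = batch
        prediction = self.compute_forward(batch, Stage.TRAIN)
        loss = self.compute_objectives(prediction, batch, Stage.TRAIN)
        grad = 2 * (prediction - y) * x  # d(loss)/dw for this batch
        self.w -= self.lr * grad  # plain SGD step
        return loss

    def evaluate_batch(self, batch, stage):
        prediction = self.compute_forward(batch, stage)
        return self.compute_objectives(prediction, batch, stage)

    def update_average(self, loss, avg_loss):
        # Running mean over the self.step batches seen so far in this stage.
        return avg_loss + (loss - avg_loss) / self.step

    def on_stage_start(self, stage, epoch=None):
        self.step = 0

    def on_stage_end(self, stage, stage_loss, epoch=None):
        self.log.append((epoch, stage.name, round(stage_loss, 4)))

    def run_stage(self, stage, dataset, epoch):
        self.on_stage_start(stage, epoch)
        avg_loss = 0.0
        for batch in dataset:
            self.step += 1
            if stage == Stage.TRAIN:
                loss = self.fit_batch(batch)
            else:
                loss = self.evaluate_batch(batch, stage)
            avg_loss = self.update_average(loss, avg_loss)
        self.on_stage_end(stage, avg_loss, epoch)

    def fit(self, epoch_counter, train_set, valid_set=None):
        for epoch in epoch_counter:
            self.run_stage(Stage.TRAIN, train_set, epoch)
            if valid_set is not None:
                self.run_stage(Stage.VALID, valid_set, epoch)


brain = ToyBrain(lr=0.1)
brain.fit(range(1, 3), train_set=[(1.0, 2.0), (2.0, 4.0)], valid_set=[(3.0, 6.0)])
print(brain.log)
# [(1, 'TRAIN', 7.12), (1, 'VALID', 0.9216), (2, 'TRAIN', 0.1823), (2, 'VALID', 0.0236)]
```

Each epoch runs the training stage and then, if a validation set is given, the validation stage; within a stage, on_stage_start() resets the counter, every batch goes through fit_batch() or evaluate_batch() followed by update_average(), and on_stage_end() receives the stage's average loss.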

evaluate(test_set, max_key=None, min_key=None, progressbar=None, test_loader_kwargs={})

Iterate test_set and evaluate brain performance. By default, loads the best-performing checkpoint (as recorded using the checkpointer).

Parameters
  • test_set (Dataset, DataLoader) – If a DataLoader is given, it is iterated directly. Otherwise passed to self.make_dataloader().

  • max_key (str) – Key to use for finding best checkpoint, passed to on_evaluate_start().

  • min_key (str) – Key to use for finding best checkpoint, passed to on_evaluate_start().

  • progressbar (bool) – Whether to display the progress in a progressbar.

  • test_loader_kwargs (dict) – Kwargs passed to make_dataloader() if test_set is not a DataLoader. NOTE: loader_kwargs["ckpt_prefix"] gets automatically overwritten to None (so that the test DataLoader is not added to the checkpointer).

Returns

The average test loss.

update_average(loss, avg_loss)

Update running average of the loss.

Parameters
  • loss (torch.Tensor) – The detached loss, a single float value.

  • avg_loss (float) – The current running average.

Returns

avg_loss – The average loss.

Return type

float
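
The update rule itself is not spelled out above. Assuming the method keeps a count n of the batches seen so far in the current stage (a counter that is not part of the documented signature), one standard incremental form, and why it reproduces the arithmetic mean of the batch losses l_1, ..., l_n, is:

```latex
\begin{aligned}
\bar{\ell}_n &= \bar{\ell}_{n-1} - \frac{\bar{\ell}_{n-1}}{n} + \frac{\ell_n}{n}
              = \frac{(n-1)\,\bar{\ell}_{n-1} + \ell_n}{n}, \qquad \bar{\ell}_0 = 0, \\
\text{if } \bar{\ell}_{n-1} &= \frac{1}{n-1}\sum_{i=1}^{n-1} \ell_i
  \;\Rightarrow\; \bar{\ell}_n = \frac{1}{n}\Bigl(\sum_{i=1}^{n-1} \ell_i + \ell_n\Bigr)
  = \frac{1}{n}\sum_{i=1}^{n} \ell_i .
\end{aligned}
```

For n = 1 the first line gives the first batch loss directly, and the second line carries the result forward by induction. Because each update needs only the previous average and the counter, no list of past losses has to be stored.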