speechbrain.core module
Core SpeechBrain code for running experiments.
- Authors
Peter Plantinga 2020, 2023
Abdel Heba 2020
Mirco Ravanelli 2020
Aku Rouhe 2021
Andreas Nautsch 2022
Sylvain de Langen 2023
Adel Moumen 2023, 2024
Summary
Classes:
Brain | Brain class abstracts away the details of data loops.
Stage | Simple enum to track stage of experiments.
Functions:
create_experiment_directory | Create the output folder and relevant experimental files.
Reference
- speechbrain.core.create_experiment_directory(experiment_directory, hyperparams_to_save=None, overrides={}, log_config='/home/docs/checkouts/readthedocs.org/user_builds/speechbrain/checkouts/latest/speechbrain/log-config.yaml', save_env_desc=True)[source]
Create the output folder and relevant experimental files.
- Parameters:
experiment_directory (str) – The place where the experiment directory should be created.
hyperparams_to_save (str) – A filename of a yaml file representing the parameters for this experiment. If passed, references are resolved, and the result is written to a file in the experiment directory called "hyperparams.yaml".
overrides (dict) – A mapping of replacements made in the yaml file, to save in yaml.
log_config (str) – A yaml filename containing configuration options for the logger.
save_env_desc (bool) – If True, an environment state description is saved to a file called env.log in the experiment directory.
- class speechbrain.core.Stage(*values)[source]
Bases: Enum
Simple enum to track stage of experiments.
- TRAIN = 1
- VALID = 2
- TEST = 3
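As a minimal sketch of how a stage enum like this is typically used, the snippet below branches on the stage inside two helper functions. The Stage class here is a local stand-in that mirrors the three documented members so the snippet runs without SpeechBrain installed; dropout_enabled and describe are hypothetical helpers, not part of the library.

```python
from enum import Enum


class Stage(Enum):
    """Local stand-in mirroring the three documented members."""

    TRAIN = 1
    VALID = 2
    TEST = 3


def dropout_enabled(stage):
    # Hypothetical helper: only the training stage uses dropout.
    return stage is Stage.TRAIN


def describe(stage, epoch=None):
    # Hypothetical helper: build a log label from the stage name.
    if epoch is None:
        return stage.name.lower()
    return f"{stage.name.lower()} (epoch {epoch})"
```

Comparing members with `is` works because each enum member is a singleton.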
- class speechbrain.core.Brain(modules=None, opt_class=None, hparams=None, run_opts=None, checkpointer=None)[source]
Bases: object
Brain class abstracts away the details of data loops.
The primary purpose of the Brain class is the implementation of the fit() method, which iterates epochs and datasets for the purpose of "fitting" a set of modules to a set of data.
In order to use the fit() method, one should sub-class the Brain class and override any methods for which the default behavior does not match the use case. For a simple use case (e.g., training a single model with a single dataset) the only methods that need to be overridden are:
- compute_forward()
- compute_objectives()
The example below illustrates how overriding these two methods is done.
For more complicated use cases, such as multiple modules that need to be updated, the following methods can be overridden:
- fit_batch()
- evaluate_batch()
- Parameters:
modules (dict[str, torch.nn.Module]) – These modules are passed to the optimizer by default if they have trainable parameters, and will have train()/eval() called on them.
opt_class (Optional[Type[torch.optim]]) – A torch optimizer constructor that takes only the list of parameters (e.g. a lambda or partial function definition). By default, this will be passed all modules in modules at the beginning of the fit() method. This behavior can be changed by overriding the configure_optimizers() method.
hparams (Optional[dict]) – Each key:value pair should consist of a string key and a hyperparameter that is used within the overridden methods. These will be accessible via an hparams attribute, using "dot" notation: e.g., self.hparams.model(x).
run_opts (Optional[Union[RunOptions, dict]]) – A set of options to change the runtime environment, see RunOptions for a list. Typically in a script this comes from speechbrain.parse_args, an alias for RunOptions.from_command_line_args. If an option is not defined here (keep in mind that parse_args will inject some options by default), then the option is also searched for in hparams (by key).
checkpointer (Optional[speechbrain.utils.checkpoints.Checkpointer]) – By default, this will be used to load checkpoints, and will have the optimizer added to continue training if interrupted.
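The "dot" notation for hparams and the run_opts-then-hparams lookup order can be sketched with the standard library alone. This is an illustration of the idea, not the library's code: lookup_option and its default argument are hypothetical names introduced here, and the use of SimpleNamespace is an assumption about one simple way to get attribute access from a dict.

```python
from types import SimpleNamespace

# A dict of hyperparameters exposed with attribute ("dot") access.
hparams = {"scalar": 5, "lr": 0.1}
hparams_ns = SimpleNamespace(**hparams)


def lookup_option(name, run_opts, hparams, default=None):
    """Hypothetical helper: prefer run_opts, then fall back to hparams."""
    if name in run_opts:
        return run_opts[name]
    if name in hparams:
        return hparams[name]
    # Falling back to a default is an assumption made for this sketch.
    return default
```

With this helper, `lookup_option("device", {"device": "cpu"}, {"device": "cuda"})` picks the run_opts value, while an option present only in hparams is still found.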
Example
>>> import torch
>>> from torch.optim import SGD
>>> from speechbrain.core import Brain
>>> class SimpleBrain(Brain):
...     def compute_forward(self, batch, stage):
...         # Scale the inputs by a hyperparameter, then apply the model.
...         return self.modules.model(batch[0] * self.hparams.scalar)
...
...     def compute_objectives(self, predictions, batch, stage):
...         # L1 distance between the predictions and the inputs.
...         return torch.nn.functional.l1_loss(predictions, batch[0])
>>> model = torch.nn.Linear(in_features=10, out_features=10)
>>> brain = SimpleBrain(
...     modules={"model": model},
...     opt_class=lambda x: SGD(x, lr=0.1),
...     hparams={"scalar": 5},
...     run_opts={"device": "cpu"},
... )
>>> brain.fit(range(1), ([torch.rand(10, 10), torch.rand(10, 10)],))
- compute_forward(batch, stage)[source]
Forward pass, to be overridden by sub-classes.
- Parameters:
batch (torch.Tensor or tensors) – An element from the dataloader, including inputs for processing.
stage (Stage) – The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
- Returns:
The outputs after all processing is complete. Directly passed to compute_objectives().
- Return type:
torch.Tensor or torch.Tensors
- compute_objectives(predictions, batch, stage)[source]
Compute loss, to be overridden by sub-classes.
- Parameters:
predictions (torch.Tensor or torch.Tensors) – The output tensor or tensors to evaluate. Comes directly from compute_forward().
batch (torch.Tensor or tensors) – An element from the dataloader, including targets for comparison.
stage (Stage) – The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
- Returns:
loss – A tensor with the computed loss.
- Return type:
torch.Tensor
- on_stage_start(stage, epoch=None)[source]
Gets called when a stage starts.
Useful for defining class variables used during the stage.
- on_stage_end(stage, stage_loss, epoch=None)[source]
Gets called at the end of a stage.
Useful for computing stage statistics, saving checkpoints, etc.
- make_dataloader(dataset, stage, ckpt_prefix='dataloader-', **loader_kwargs)[source]
Creates DataLoaders for Datasets.
This is used by fit() and evaluate() if they just receive Datasets.
Alternatively, this can be called from outside the Brain subclass. In that case, the DataLoader should be passed to fit() in place of the dataset.
The Stage.TRAIN DataLoader is handled specially. It has extra args for shuffle and drop_last. In DDP a DistributedSampler is created (unless the dataset is an IterableDataset).
Note
Some important DataLoader arguments are passed via **loader_kwargs, e.g., batch_size, num_workers, pin_memory.
Note
By default, evaluate() specifies ckpt_prefix=None to stop the test DataLoader being added to the checkpointer. If you need to add a recoverable after saving checkpoints (e.g., at test time, after checkpointing the training), and still be able to recover reasonably, you should probably specify allow_partial_load=True.
- Parameters:
dataset (Dataset) – A set of data to use to create data loader. If the Dataset is a DynamicItemDataset, PaddedBatch is used as the default collate_fn, unless specified in loader_kwargs.
stage (Stage) – The stage of the experiment: Stage.TRAIN, Stage.VALID, Stage.TEST
ckpt_prefix (str, None) – Prefix to use for SaveableDataLoader Checkpoint name. The Stage name is added to this to create the full key. Set to None to not save the DataLoader.
**loader_kwargs (dict) – Additional keyword arguments to the DataLoader. E.g., batch_size, num_workers, pin_memory.
- Return type:
DataLoader for the input dataset
- on_fit_start()[source]
Gets called at the beginning of fit(), on multiple processes if distributed_count > 0 and backend is ddp.
Default implementation compiles the jit modules, initializes optimizers, and loads the latest checkpoint to resume training.
- init_optimizers()[source]
Called during on_fit_start(), initialize optimizers after parameters are fully configured (e.g. DDP, jit).
The default implementation of this method depends on an optimizer class being passed at initialization that takes only a list of parameters (e.g., a lambda or a partial function definition). This creates a single optimizer that optimizes all trainable params.
Override this method if there are multiple optimizers.
- zero_grad(set_to_none=False)[source]
Sets the gradients of all optimized torch.Tensors to zero if set_to_none=False (default) or to None otherwise.
Setting gradients to None should save memory, e.g. during evaluate(), so a larger batch might be used.
- on_evaluate_start(max_key=None, min_key=None)[source]
Gets called at the beginning of evaluate().
Default implementation loads the best-performing checkpoint for evaluation, based on stored metrics.
- fit_batch(batch)[source]
Fit one batch, override to do multiple updates.
The default implementation depends on a few methods being defined with a particular behavior:
- compute_forward()
- compute_objectives()
- optimizers_step()
Also depends on having optimizers passed at initialization.
- Parameters:
batch (list of torch.Tensors) – Batch of data to use for training. Default implementation assumes this batch has two elements: inputs and targets.
- Return type:
detached loss
- check_loss_isfinite(loss)[source]
Check if the loss is finite.
If the loss is not finite, log a helpful message and increment the nonfinite_count. If the nonfinite_count exceeds the --nonfinite_patience threshold, stop the training and raise an error.
This check is particularly useful when the loss becomes NaN or inf, while the parameters and gradients remain finite. It helps prevent getting stuck in an infinite loop during training.
- Parameters:
loss (tensor) – The loss tensor after backward() has been called but before the optimizers step().
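The patience logic described for check_loss_isfinite() can be sketched in plain Python. This is a simplified stand-in rather than the library's implementation: the NonFiniteGuard class, the boolean return value, and the choice of ValueError are assumptions made for illustration, and the loss is a plain float instead of a tensor.

```python
import math


class NonFiniteGuard:
    """Hypothetical stand-in for the counter described above."""

    def __init__(self, nonfinite_patience=3):
        self.nonfinite_patience = nonfinite_patience
        self.nonfinite_count = 0

    def check(self, loss):
        """Return True for a finite loss; raise once patience is exceeded."""
        if math.isfinite(loss):
            return True
        # NaN and +/-inf both count as non-finite.
        self.nonfinite_count += 1
        if self.nonfinite_count > self.nonfinite_patience:
            raise ValueError(
                f"Loss was non-finite {self.nonfinite_count} times; stopping."
            )
        return False
```

A caller would skip the optimizer update whenever check() returns False, so a few bad batches are tolerated before training is aborted.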
- check_gradients()[source]
Checks if the gradients are finite. If not, it will emit a warning and set them to zero.
- freeze_optimizers(optimizers)[source]
By default, this method returns the passed optimizers. Override this method if you want to freeze some optimizers during training. To do so, return only the active optimizers.
- optimizers_step()[source]
Performs a step of gradient descent on the optimizers. This method is called every grad_accumulation_factor steps.
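To make "every grad_accumulation_factor steps" concrete, here is a small sketch of the stepping schedule. The helper names should_step and stepping_batches are hypothetical, and counting batches from 1 is an assumption of this sketch; the library's internal step counter may be defined differently.

```python
def should_step(step, grad_accumulation_factor):
    """True on every grad_accumulation_factor-th batch (step counts from 1)."""
    return step % grad_accumulation_factor == 0


def stepping_batches(num_batches, grad_accumulation_factor):
    """List the 1-based batch indices on which the optimizers would step."""
    return [
        step
        for step in range(1, num_batches + 1)
        if should_step(step, grad_accumulation_factor)
    ]
```

With 7 batches and a factor of 3, gradients accumulate over batches 1 to 3 and 4 to 6, and the optimizers step after batches 3 and 6.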
- on_fit_batch_start(batch, should_step)[source]
Called at the beginning of fit_batch().
This method is not called under the AMP context manager. Do not assume automatic casting of the input batch to a lower precision (e.g. fp16).
- Parameters:
batch (list of torch.Tensors) – Batch of data to use for training. Default implementation assumes this batch has two elements: inputs and targets.
should_step (boolean) – Whether optimizer.step() was called or not.
- on_fit_batch_end(batch, outputs, loss, should_step)[source]
Called after fit_batch().
- Parameters:
batch (list of torch.Tensors) – Batch of data to use for training. Default implementation assumes this batch has two elements: inputs and targets.
outputs (list or dictionary of torch.Tensors) – Returned value of compute_forward().
loss (torch.Tensor) – Returned value of compute_objectives().
should_step (boolean) – Whether optimizer.step() was called or not.
- evaluate_batch(batch, stage)[source]
Evaluate one batch, override for different procedure than train.
The default implementation depends on two methods being defined with a particular behavior:
- compute_forward()
- compute_objectives()
- fit(epoch_counter, train_set, valid_set=None, progressbar=None, train_loader_kwargs={}, valid_loader_kwargs={})[source]
Iterate epochs and datasets to improve objective.
Relies on the existence of multiple functions that can (or should) be overridden. The following methods are used and expected to have a certain behavior:
- fit_batch()
- evaluate_batch()
- update_average()
If the initialization was done with distributed_count > 0 and the distributed_backend is ddp, this will generally handle multiprocess logic, like splitting the training data into subsets for each device and only saving a checkpoint on the main process.
- Parameters:
epoch_counter (iterable) – Each call should return an integer indicating the epoch count.
train_set (Dataset, DataLoader) – A set of data to use for training. If a Dataset is given, a DataLoader is automatically created. If a DataLoader is given, it is used directly.
valid_set (Dataset, DataLoader) – A set of data to use for validation. If a Dataset is given, a DataLoader is automatically created. If a DataLoader is given, it is used directly.
progressbar (bool) – Whether to display the progress of each epoch in a progressbar.
train_loader_kwargs (dict) – Kwargs passed to make_dataloader() for making the train_loader (if train_set is a Dataset, not DataLoader). E.g., batch_size, num_workers. DataLoader kwargs are all valid.
valid_loader_kwargs (dict) – Kwargs passed to make_dataloader() for making the valid_loader (if valid_set is a Dataset, not DataLoader). E.g., batch_size, num_workers. DataLoader kwargs are all valid.
- Return type:
None
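The control flow that fit() relies on can be approximated by the toy loop below. It is a heavily simplified sketch under stated assumptions, not the SpeechBrain implementation: MiniBrain, _run_stage, and the local Stage enum are stand-ins, batches are plain floats that double as their own loss, update_average is assumed to be a running mean, and checkpointing, DDP handling, and progress bars are left out.

```python
from enum import Enum


class Stage(Enum):
    TRAIN = 1
    VALID = 2


class MiniBrain:
    """Toy stand-in for the loop structure; not the SpeechBrain class."""

    def __init__(self):
        self.log = []

    def fit_batch(self, batch):
        return float(batch)  # pretend the batch value is its training loss

    def evaluate_batch(self, batch, stage):
        return float(batch)  # pretend the batch value is its validation loss

    def update_average(self, loss, avg_loss, step):
        # Running mean over the batches seen so far (assumed formula).
        return avg_loss + (loss - avg_loss) / step

    def on_stage_start(self, stage, epoch=None):
        pass

    def on_stage_end(self, stage, stage_loss, epoch=None):
        self.log.append((epoch, stage.name, stage_loss))

    def _run_stage(self, stage, dataset, epoch):
        self.on_stage_start(stage, epoch)
        avg_loss = 0.0
        for step, batch in enumerate(dataset, start=1):
            if stage is Stage.TRAIN:
                loss = self.fit_batch(batch)
            else:
                loss = self.evaluate_batch(batch, stage)
            avg_loss = self.update_average(loss, avg_loss, step)
        self.on_stage_end(stage, avg_loss, epoch)

    def fit(self, epoch_counter, train_set, valid_set=None):
        for epoch in epoch_counter:
            self._run_stage(Stage.TRAIN, train_set, epoch)
            if valid_set is not None:
                self._run_stage(Stage.VALID, valid_set, epoch)


brain = MiniBrain()
brain.fit(range(1, 3), train_set=[1.0, 2.0, 3.0], valid_set=[4.0])
```

After the call, brain.log holds one entry per stage and epoch, which is the kind of information a real on_stage_end() override would use for statistics or checkpointing.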
- evaluate(test_set, max_key=None, min_key=None, progressbar=None, test_loader_kwargs={})[source]
Iterate test_set and evaluate brain performance. By default, loads the best-performing checkpoint (as recorded using the checkpointer).
- Parameters:
test_set (Dataset, DataLoader) – If a DataLoader is given, it is iterated directly. Otherwise passed to self.make_dataloader().
max_key (str) – Key to use for finding best checkpoint, passed to on_evaluate_start().
min_key (str) – Key to use for finding best checkpoint, passed to on_evaluate_start().
progressbar (bool) – Whether to display the progress in a progressbar.
test_loader_kwargs (dict) – Kwargs passed to make_dataloader() if test_set is not a DataLoader. NOTE: loader_kwargs["ckpt_prefix"] gets automatically overwritten to None (so that the test DataLoader is not added to the checkpointer).
- Return type:
average test loss
- no_sync(use=True)[source]
Copies PyTorch's implementation for doing no_sync across all modules.
Explanation: nn.module.no_sync() is a context manager for when one does not want to sync gradients, which happens when using both DDP and gradient accumulation. SpeechBrain's Brain class can contain multiple modules and calling no_sync on these individually would be very awkward, therefore this context manager exists.
- Parameters:
use (bool) – If set to False will still sync gradients, useful to make behavior toggleable.
- Yields:
None
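The idea of one toggleable context manager covering several modules can be sketched with contextlib.ExitStack. Note that this swaps in a different technique from the one described above (the method is documented as copying PyTorch's implementation); FakeModule and no_sync_all are hypothetical names used only to show the pattern without requiring torch.

```python
from contextlib import ExitStack, contextmanager


class FakeModule:
    """Hypothetical module exposing a no_sync() context manager."""

    def __init__(self):
        self.sync_enabled = True

    @contextmanager
    def no_sync(self):
        previous = self.sync_enabled
        self.sync_enabled = False
        try:
            yield
        finally:
            # Restore the previous state even if the body raised.
            self.sync_enabled = previous


@contextmanager
def no_sync_all(modules, use=True):
    """Enter no_sync() on every module at once; do nothing if use is False."""
    with ExitStack() as stack:
        if use:
            for module in modules:
                stack.enter_context(module.no_sync())
        yield
```

Inside `with no_sync_all(modules):` every module has syncing disabled, and passing `use=False` leaves them untouched, which keeps the calling code identical on accumulation and stepping batches.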