speechbrain.utils.repro module

Reproducibility tools

Author:
  • Artem Ploujnikov 2025

Summary

Classes:

SaveableGenerator

A wrapper that can be used to store the state of the random number generator in a checkpoint.

Reference

class speechbrain.utils.repro.SaveableGenerator(generators=None)

Bases: object

A wrapper that can be used to store the state of the random number generator in a checkpoint. It helps with reproducibility in long-running experiments that are resumed from checkpoints.

Currently, only CPU and CUDA devices are supported natively. To train on other hardware, consider implementing a custom generator.

On an unsupported device that does not use the Torch generator interface, the wrapper silently skips restoring the state instead of raising an error.
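The reproducibility this wrapper provides rests on snapshotting and restoring generator state. As a minimal sketch (not the class's actual implementation), PyTorch's `torch.Generator.get_state()` and `set_state()` are enough to rewind a generator so that it replays the same draws. The helper name `replay_after_restore` is hypothetical.

```python
import torch


def replay_after_restore(seed: int, n: int = 5):
    """Hypothetical helper: snapshot a CPU generator, draw, rewind, draw again."""
    gen = torch.Generator().manual_seed(seed)
    snapshot = gen.get_state()  # ByteTensor holding the complete RNG state
    first = torch.randint(0, 10, (n,), generator=gen).tolist()
    gen.set_state(snapshot)  # rewind the generator to the snapshot
    second = torch.randint(0, 10, (n,), generator=gen).tolist()
    return first, second


first, second = replay_after_restore(42)
print(first == second)  # True: restoring the state replays the same draws
```

Storing such snapshots in a checkpoint is what lets a resumed run continue the same random sequence instead of starting a fresh one.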

Typical usage in hparams:

```yaml
generator: !new:speechbrain.utils.repro.SaveableGenerator # <- Include the wrapper

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        lr_scheduler: !ref <lr_annealing>
        counter: !ref <epoch_counter>
        generator: !ref <generator>
```

Parameters:

generators (Mapping[str, Generator], optional) – A dictionary of named generator objects. If not provided, the default generators for CPU and CUDA are used.
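When `generators` is omitted, a mapping along the following lines could serve as the default. This is a sketch under assumptions: the helper `default_generator_map` and the key names `"default"` and `"cuda:<index>"` are hypothetical and may not match what `SaveableGenerator` uses internally.

```python
import torch


def default_generator_map():
    """Hypothetical helper: name -> generator for CPU and each visible CUDA device.

    The key names are assumptions made for illustration only.
    """
    generators = {"default": torch.default_generator}  # global CPU generator
    if torch.cuda.is_available():
        torch.cuda.init()  # ensures torch.cuda.default_generators is populated
        for idx in range(torch.cuda.device_count()):
            generators[f"cuda:{idx}"] = torch.cuda.default_generators[idx]
    return generators


gens = default_generator_map()
print(sorted(gens))  # ['default'] on a CPU-only machine
```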

Examples

>>> import torch
>>> from speechbrain.utils.repro import SaveableGenerator
>>> from speechbrain.utils.checkpoints import Checkpointer
>>> gena, genb = [torch.Generator().manual_seed(x) for x in [42, 24]]
>>> saveable_gen = SaveableGenerator(
...     generators={"a": gena, "b": genb}
... )
>>> tempdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(
...     tempdir,
...     recoverables={"generator": saveable_gen})
>>> torch.randint(0, 10, (1,), generator=gena).item()
2
>>> torch.randint(0, 10, (1,), generator=genb).item()
4
>>> _ = checkpointer.save_checkpoint()
>>> torch.randint(0, 10, (1,), generator=gena).item()
7
>>> torch.randint(0, 10, (1,), generator=genb).item()
5
>>> _ = checkpointer.recover_if_possible()
>>> torch.randint(0, 10, (1,), generator=gena).item()
7
>>> torch.randint(0, 10, (1,), generator=genb).item()
5

save(path)

Save the states of the wrapped generators for later recovery.

Parameters:

path (str, Path) – Where to save. An existing file at this path is overwritten.
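One plausible shape for `save(path)`, sketched as a standalone function: collect each named generator's state and write the whole mapping with `torch.save`. The function name `save_generator_states` and the on-disk layout (a plain dict of state tensors) are assumptions, not the documented checkpoint format.

```python
import tempfile
from pathlib import Path
from typing import Mapping, Union

import torch


def save_generator_states(
    generators: Mapping[str, torch.Generator], path: Union[str, Path]
) -> None:
    """Hypothetical sketch: write {name: state} to path, replacing any existing file."""
    states = {name: gen.get_state() for name, gen in generators.items()}
    torch.save(states, path)  # torch.save overwrites an existing file


# Usage: save two seeded generators and read the file back.
ckpt = Path(tempfile.mkdtemp()) / "generator.ckpt"
gens = {"a": torch.Generator().manual_seed(42), "b": torch.Generator().manual_seed(24)}
save_generator_states(gens, ckpt)
saved = torch.load(ckpt, weights_only=True)  # plain dict of uint8 tensors
print(sorted(saved))  # ['a', 'b']
```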

load(path, end_of_epoch)

Load the saved generator states whose corresponding devices are present.

Parameters:
  • path (str, Path) – Where to load from.

  • end_of_epoch (bool) – Whether the checkpoint was end-of-epoch or not.
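A matching sketch of `load`, assuming the file holds a plain dict mapping generator names to state tensors (an assumption, not the documented format). Entries without a matching generator are skipped, mirroring the silent-skip behavior described above for unsupported devices. The function name `load_generator_states` is hypothetical, and `end_of_epoch` is accepted only for interface compatibility.

```python
import tempfile
from pathlib import Path
from typing import Mapping, Union

import torch


def load_generator_states(
    generators: Mapping[str, torch.Generator],
    path: Union[str, Path],
    end_of_epoch: bool = False,
) -> None:
    """Hypothetical sketch: restore states for names present in both file and mapping."""
    del end_of_epoch  # unused in this sketch
    states = torch.load(path, weights_only=True)
    for name, state in states.items():
        gen = generators.get(name)
        if gen is not None:  # no matching generator (e.g. absent device): skip silently
            gen.set_state(state)


# Usage: save a state, advance the generator, then restore it.
gen = torch.Generator().manual_seed(42)
ckpt = Path(tempfile.mkdtemp()) / "generator.ckpt"
torch.save({"a": gen.get_state(), "cuda:7": gen.get_state()}, ckpt)
expected = torch.randint(0, 10, (3,), generator=gen).tolist()
load_generator_states({"a": gen}, ckpt)  # "cuda:7" has no match and is skipped
restored = torch.randint(0, 10, (3,), generator=gen).tolist()
print(restored == expected)  # True
```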