speechbrain.utils.repro module
Reproducibility tools
Author: Artem Ploujnikov 2025
Summary
Classes:
SaveableGenerator: A wrapper that can be used to store the state of the random number generator in a checkpoint.
Reference
class speechbrain.utils.repro.SaveableGenerator(generators=None)
Bases: object

A wrapper that can be used to store the state of the random number generator in a checkpoint. It helps with reproducibility in long-running experiments.
Currently, only CPU and CUDA devices are supported natively. If you need to train on other architectures, consider implementing a custom generator.
On an unsupported device that does not use the Torch generator interface, the wrapper will simply fail to restore the state; it will not raise an error.
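The save-and-restore idea behind this wrapper can be illustrated without PyTorch. The sketch below is a hypothetical, standard-library analogue: the class name `SaveableRandomStates`, its `save`/`load` methods, and the pickle file are illustrative assumptions, not SpeechBrain's API, and `random.Random` stands in for `torch.Generator` (whose state PyTorch exposes through `get_state()` and `set_state()`). It captures the state of each named generator, restores it later, and skips objects that cannot be restored instead of raising an error.

```python
import os
import pickle
import random
import tempfile
from typing import Mapping, Optional


class SaveableRandomStates:
    """Hypothetical stdlib analogue of SaveableGenerator.

    Uses random.Random objects in place of torch.Generator objects.
    """

    def __init__(self, generators: Optional[Mapping[str, random.Random]] = None):
        # Assumption: fall back to a single named generator when none are
        # given, standing in for the default CPU/CUDA generators.
        if generators is None:
            generators = {"default": random.Random()}
        self.generators = dict(generators)

    def save(self, path: str) -> None:
        # Capture the state of every generator that exposes getstate().
        states = {
            name: gen.getstate()
            for name, gen in self.generators.items()
            if hasattr(gen, "getstate")
        }
        with open(path, "wb") as f:
            pickle.dump(states, f)

    def load(self, path: str) -> None:
        with open(path, "rb") as f:
            states = pickle.load(f)
        for name, gen in self.generators.items():
            # Silently skip generators with no saved state or no way to
            # restore one, instead of raising an error.
            if name in states and hasattr(gen, "setstate"):
                gen.setstate(states[name])


# Usage: save, draw, restore, then draw again.
gena, genb = random.Random(42), random.Random(24)
saveable = SaveableRandomStates({"a": gena, "b": genb})
path = os.path.join(tempfile.mkdtemp(), "rng_states.pkl")

saveable.save(path)
after_save = (gena.randint(0, 9), genb.randint(0, 9))
saveable.load(path)
after_restore = (gena.randint(0, 9), genb.randint(0, 9))
print(after_save == after_restore)  # True
```

Because the restored state is exactly the one captured at save time, the draws after restoring repeat the draws made right after saving, which is the same property the doctest below checks for `torch.Generator` objects.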
Typical usage in hparams:
```yaml
generator: !new:speechbrain.utils.repro.SaveableGenerator # <-- Include the wrapper

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        lr_scheduler: !ref <lr_annealing>
        counter: !ref <epoch_counter>
        generator: !ref <generator>
```
Parameters:
generators (Mapping[str, Generator], optional): A dictionary of named generator objects. If not provided, the default generators for CPU and CUDA will be used.
Examples
>>> import torch
>>> from speechbrain.utils.repro import SaveableGenerator
>>> from speechbrain.utils.checkpoints import Checkpointer
>>> gena, genb = [torch.Generator().manual_seed(x) for x in [42, 24]]
>>> saveable_gen = SaveableGenerator(
...     generators={"a": gena, "b": genb}
... )
>>> tempdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(
...     tempdir,
...     recoverables={"generator": saveable_gen})
>>> torch.randint(0, 10, (1,), generator=gena).item()
2
>>> torch.randint(0, 10, (1,), generator=genb).item()
4
>>> _ = checkpointer.save_checkpoint()
>>> torch.randint(0, 10, (1,), generator=gena).item()
7
>>> torch.randint(0, 10, (1,), generator=genb).item()
5
>>> _ = checkpointer.recover_if_possible()
>>> torch.randint(0, 10, (1,), generator=gena).item()
7
>>> torch.randint(0, 10, (1,), generator=genb).item()
5