Source code for speechbrain.utils.repro

"""Reproducibility tools

Author:
    * Artem Ploujnikov 2025
"""

import re

import torch

import speechbrain as sb
from speechbrain.utils.logger import get_logger

# Module-level logger; SaveableGenerator.load uses it to warn about RNG
# states that cannot be restored on the current machine.
logger = get_logger(__name__)


@sb.utils.checkpoints.register_checkpoint_hooks
class SaveableGenerator:
    """A wrapper that can be used to store the state of
    the random number generator in a checkpoint. It helps
    with reproducibility in long-running experiments.

    Currently, this only supports CPU and CUDA devices natively.
    If you need training on other architectures, consider
    implementing a custom generator.

    Running it on an unsupported device not using the Torch
    generator interface will simply fail to restore the
    state but will not cause an error.

    Typical in hparams:
    ```yaml
    generator: !new:speechbrain.utils.repro.SaveableGenerator # <-- Include the wrapper

    checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
        checkpoints_dir: !ref <save_folder>
        recoverables:
            model: !ref <model>
            lr_scheduler: !ref <lr_annealing>
            counter: !ref <epoch_counter>
            generator: !ref <generator>
    ```

    Arguments
    ---------
    generators : Mapping[str, Generator], optional
        A dictionary of named generator objects; each value must provide
        ``get_state()`` and ``set_state(state)``. If not provided, the
        default CPU generator (key ``"default"``) and, when CUDA is
        available, the default generator of every visible CUDA device
        (keys ``"cuda:<idx>"``) will be used.

    Examples
    --------
    >>> import torch
    >>> from speechbrain.utils.repro import SaveableGenerator
    >>> from speechbrain.utils.checkpoints import Checkpointer
    >>> gena, genb = [torch.Generator().manual_seed(x) for x in [42, 24]]
    >>> saveable_gen = SaveableGenerator(
    ...     generators={"a": gena, "b": genb}
    ... )
    >>> tempdir = getfixture('tmpdir')
    >>> checkpointer = Checkpointer(
    ...     tempdir,
    ...     recoverables={"generator": saveable_gen})
    >>> torch.randint(0, 10, (1,), generator=gena).item()
    2
    >>> torch.randint(0, 10, (1,), generator=genb).item()
    4
    >>> _ = checkpointer.save_checkpoint()
    >>> torch.randint(0, 10, (1,), generator=gena).item()
    7
    >>> torch.randint(0, 10, (1,), generator=genb).item()
    5
    >>> _ = checkpointer.recover_if_possible()
    >>> torch.randint(0, 10, (1,), generator=gena).item()
    7
    >>> torch.randint(0, 10, (1,), generator=genb).item()
    5
    """

    def __init__(self, generators=None):
        if generators is None:
            generators = {"default": torch.default_generator}
            if torch.cuda.is_available():
                # One wrapper per visible device so each device's RNG
                # state is checkpointed independently.
                for idx in range(torch.cuda.device_count()):
                    generators[f"cuda:{idx}"] = _CudaDefaultGeneratorWrapper(
                        idx
                    )
        self.generators = generators

    @sb.utils.checkpoints.mark_as_saver
    def save(self, path):
        """Save the generator state for later recovery

        The state of every wrapped generator is stored in a single
        dictionary keyed by the generator's name.

        Arguments
        ---------
        path : str, Path
            Where to save. Will overwrite.
        """
        save_dict = {
            key: generator.get_state()
            for key, generator in self.generators.items()
        }
        torch.save(save_dict, path)

    @sb.utils.checkpoints.mark_as_loader
    def load(self, path, end_of_epoch):
        """Loads the generator state if the corresponding devices are present

        Each saved state is restored into the generator registered under
        the same key in ``self.generators``. States whose CUDA device is
        not present, or that have no matching generator, are skipped with
        a warning rather than raising.

        Arguments
        ---------
        path : str, Path
            Where to load from.
        end_of_epoch : bool
            Whether the checkpoint was end-of-epoch or not.
        """
        # The RNG state does not depend on epoch boundaries; the argument
        # exists only to satisfy the checkpoint loader signature.
        del end_of_epoch
        save_dict = torch.load(path)
        for key, state in save_dict.items():
            match = re.fullmatch(r"cuda:(\d+)", key)
            if match:
                if not torch.cuda.is_available():
                    logger.warning(
                        "Unable to restore RNG for %s, CUDA unavailable", key
                    )
                    continue
                idx = int(match.group(1))
                if idx > torch.cuda.device_count() - 1:
                    logger.warning(
                        "Unable to restore RNG for %s, device not found", key
                    )
                    continue

            generator = self.generators.get(key)
            if generator is None and key == "default":
                # Backward compatibility: a "default" state with no
                # registered "default" generator restores the global
                # torch RNG, as earlier versions always did.
                generator = torch.default_generator
            if generator is None:
                logger.warning(
                    "Unable to restore RNG for %s, no matching generator", key
                )
                continue
            generator.set_state(state)
class _CudaDefaultGeneratorWrapper: """A generator wrapper for default generators - because torch no longer exposes default_generators This class should not be used outside of SaveableGenerator Arguments --------- device : int|str The device index or identifier""" def __init__(self, device): self.device = device def get_state(self): """Returns the generator state Returns ------- result : torch.Tensor The generator state """ return torch.cuda.get_rng_state(self.device) def set_state(self, new_state): """ "Sets the generator state Arguments --------- new_state : dict The new state """ torch.cuda.set_rng_state(new_state, self.device)