speechbrain.utils.parameter_transfer module

Convenience functions for the simplest parameter transfer cases.

Use speechbrain.utils.checkpoints.Checkpointer to find a checkpoint and the path to the parameter file.

Authors
  • Aku Rouhe 2020

  • Andreas Nautsch 2023

  • Adel Moumen 2023

Summary

Classes:

Pretrainer

Orchestrates pretraining

Reference

class speechbrain.utils.parameter_transfer.Pretrainer(collect_in=None, loadables=None, paths=None, custom_hooks=None, conditions=None)

Bases: object

Orchestrates pretraining

First, it optionally collects files from some source (a local directory, a HuggingFace repository, or a base URL) into the collect_in directory, if one is specified.

Then, it calls the load hooks for each of those files.

Parameters:
  • collect_in (str or Path, optional) – Path to the directory where the files are to be collected. If None, files are referenced from the cache or directly, where possible (URLs will fail), and there is no centralized target directory containing all the files.

  • loadables (mapping) – Mapping from loadable key to object. This connects the keys to the actual object instances.

  • paths (mapping) – Mapping from loadable key to filepath. The last part of the path is treated as the file name; the rest is treated as a “source”, which can be either a directory path or a magic source such as a HuggingFace Hub ID. For example, sb/asr-crdnn-libri/lm.ckpt -> source=sb/asr-crdnn-libri, file=lm.ckpt. Note that when collecting, you can specify a default source, which is used for all loadables that don’t have a path specified.

  • custom_hooks (mapping) – Mapping from loadable key to parameter transfer hook function. If you want to use a custom loading function, specify it here.

  • conditions (mapping) – An optional mapping from loadable keys to condition values, useful for loading certain elements only if a flag is turned on.

set_collect_in(path)

Change the collection directory (the collect_in path).

add_loadables(loadables)

Update the loadables dict from the given mapping.

Parameters:

loadables (mapping) – Mapping from loadable key to object.

add_paths(paths)

Update the paths for different loadables.

When collecting parameters, paths specified here take precedence; the default source that can be passed when collecting is used only for loadables that don’t have a path specified.

Parameters:

paths (mapping) – Mapping from loadable key to filepath. The last part of the path is treated as the file name; the rest is treated as a “source”, which can be either a directory path or a magic source such as a HuggingFace Hub ID. For example, sb/asr-crdnn-libri/lm.ckpt -> source=sb/asr-crdnn-libri, file=lm.ckpt.

add_custom_hooks(custom_hooks)

Update the custom hooks.

When loading parameters, hooks here are preferred over class defaults.

Parameters:

custom_hooks (mapping) – Mapping from loadable key to parameter transfer hook function. If you want to use a custom loading function, specify it here.
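The precedence rule can be pictured with a small pure-Python sketch that does not depend on SpeechBrain. All names in it (`select_hook`, `default_hook`, `my_lm_hook`, `Dummy`) are hypothetical; `default_hook` stands in for whatever class default would otherwise apply, and the `hook(obj, path)` signature is an assumption of the sketch.

```python
from typing import Any, Callable, Dict, Mapping

# Assumed hook signature for this sketch: hook(obj, path) -> None
Hook = Callable[[Any, str], None]


def default_hook(obj: Any, path: str) -> None:
    """Hypothetical stand-in for a class-default loading function."""
    obj.loaded_from = ("default", path)


def select_hook(key: str, custom_hooks: Mapping[str, Hook]) -> Hook:
    """Prefer a custom hook registered under `key`; otherwise use the default."""
    return custom_hooks.get(key, default_hook)


class Dummy:
    """Toy loadable that just records how it was loaded."""

    loaded_from: Any = None


def my_lm_hook(obj: Any, path: str) -> None:
    """Hypothetical custom loading function for the "lm" loadable."""
    obj.loaded_from = ("custom", path)


hooks: Dict[str, Hook] = {"lm": my_lm_hook}
lm, asr = Dummy(), Dummy()
select_hook("lm", hooks)(lm, "ckpts/lm.ckpt")    # custom hook is registered
select_hook("asr", hooks)(asr, "ckpts/asr.ckpt")  # falls back to the default
print(lm.loaded_from)   # ('custom', 'ckpts/lm.ckpt')
print(asr.loaded_from)  # ('default', 'ckpts/asr.ckpt')
```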

add_conditions(conditions)

Update the conditions.

Parameters:

conditions (mapping) – Mapping from loadable keys to condition values, useful for loading certain elements only if a flag is turned on.

static split_path(path)

Splits a path into source and filename.

This also handles URLs and HuggingFace Hub paths, in addition to regular paths.

Parameters:

path (str)

Returns:

  • str – Source

  • str – Filename
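A minimal sketch of the splitting rule, assuming plain string input: it splits at the last `/`, which covers local paths, URLs and Hub IDs alike. It does not handle FetchSource objects, and treating a bare filename as living in `./` is an assumption of this sketch rather than documented behavior.

```python
from typing import Tuple


def split_path(path: str) -> Tuple[str, str]:
    """Split `path` into (source, filename) at the last "/" (sketch only)."""
    if "/" in path:
        source, filename = path.rsplit("/", 1)
        return source, filename
    # Assumption: a bare filename is looked up in the current directory.
    return "./", path


print(split_path("sb/asr-crdnn-libri/lm.ckpt"))
# ('sb/asr-crdnn-libri', 'lm.ckpt')
print(split_path("https://example.com/models/asr.ckpt"))
# ('https://example.com/models', 'asr.ckpt')
print(split_path("lm.ckpt"))
# ('./', 'lm.ckpt')
```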

collect_files(default_source=None, use_auth_token=False, local_strategy: LocalStrategy = LocalStrategy.SYMLINK)

Fetches parameters from known paths, falling back to default_source.

The actual parameter files may reside elsewhere, but this ensures a symlink in the self.collect_in directory. The symlink always uses the loadable key in its filename. This standardization makes it easier to orchestrate pretraining in, e.g., distributed setups.

Use default_source if you have everything organized neatly in one location, such as a HuggingFace Hub repository.
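The key-based naming can be illustrated with `pathlib` alone. This sketch assumes the link is named `<key>.ckpt` (matching the `<default_source>/<key>.ckpt` convention described under default_source); `link_into` is a hypothetical helper, not part of SpeechBrain, and it always symlinks regardless of local_strategy.

```python
import tempfile
from pathlib import Path


def link_into(collect_in: Path, key: str, actual_file: Path) -> Path:
    """Create collect_in/<key>.ckpt as a symlink to the real parameter file."""
    collect_in.mkdir(parents=True, exist_ok=True)
    link = collect_in / f"{key}.ckpt"  # file name derived from the loadable key
    if link.is_symlink() or link.exists():
        link.unlink()  # replace a stale link or file from an earlier run
    link.symlink_to(actual_file.resolve())
    return link


with tempfile.TemporaryDirectory() as tmp_name:
    root = Path(tmp_name)
    real = root / "downloads" / "model_v3_final.pt"  # arbitrary upstream name
    real.parent.mkdir()
    real.write_bytes(b"fake-weights")
    link = link_into(root / "pretrained", "asr", real)
    content = link.read_bytes()  # reads through the symlink
    is_link = link.is_symlink()

print(link.name, content, is_link)  # asr.ckpt b'fake-weights' True
```

Whatever the upstream file is called, consumers only need to know the loadable key to find it in collect_in.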

Parameters:
  • default_source (str or Path or FetchSource) – Used for each loadable that doesn’t already have a path specified. For example, if the loadable has key "asr", the file to look for is <default_source>/asr.ckpt.

  • use_auth_token (bool (default: False)) – If True, the HuggingFace auth token is used to load private models from the HuggingFace Hub. The default is False because the majority of models are public.

  • local_strategy (speechbrain.utils.fetching.LocalStrategy) – The fetching strategy to use, which controls the behavior of remote file fetching with regard to symlinking and copying. Ignored if no collect_in directory was specified. See speechbrain.utils.fetching.fetch() for further details.

Returns:

Mapping from loadable key to a local path from which the loadable’s parameters can be loaded. This is not used within this class, but may be useful to the caller.

Return type:

dict
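The fallback order described above (explicit paths first, then `<default_source>/<key>.ckpt`) can be sketched as follows. `resolve_sources` is a hypothetical helper that only computes (source, filename) pairs; it does not fetch anything, and raising `ValueError` when neither a path nor a default source is available is an assumption.

```python
from typing import Dict, Iterable, Mapping, Optional, Tuple


def resolve_sources(
    keys: Iterable[str],
    paths: Mapping[str, str],
    default_source: Optional[str] = None,
) -> Dict[str, Tuple[str, str]]:
    """Map each loadable key to a (source, filename) pair."""
    resolved: Dict[str, Tuple[str, str]] = {}
    for key in keys:
        if key in paths:
            # An explicit path wins: split it at the last "/".
            path = paths[key]
            if "/" in path:
                source, filename = path.rsplit("/", 1)
            else:
                source, filename = "./", path
            resolved[key] = (source, filename)
        elif default_source is not None:
            # Fallback: look for <default_source>/<key>.ckpt.
            resolved[key] = (default_source, f"{key}.ckpt")
        else:
            # Assumption of this sketch: no path and no default is an error.
            raise ValueError(f"no path and no default_source for {key!r}")
    return resolved


mapping = resolve_sources(
    ["asr", "lm"],
    paths={"lm": "sb/asr-crdnn-libri/lm.ckpt"},
    default_source="my-org/my-asr-repo",  # hypothetical Hub ID
)
print(mapping)
# {'asr': ('my-org/my-asr-repo', 'asr.ckpt'), 'lm': ('sb/asr-crdnn-libri', 'lm.ckpt')}
```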

is_loadable(name)

Returns True if no condition is defined for the specified loadable, or if its condition is true.

Parameters:

name (str) – the name of the loadable

Returns:

is_loadable – whether the item should be loaded

Return type:

bool
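The decision reduces to a dictionary lookup. A minimal sketch, assuming condition values are plain truthy or falsy values; `is_loadable` is written here as a free function taking the conditions mapping explicitly, unlike the method.

```python
from typing import Any, Mapping


def is_loadable(name: str, conditions: Mapping[str, Any]) -> bool:
    """True when `name` has no condition, or when its condition is truthy."""
    if name not in conditions:
        return True
    return bool(conditions[name])


use_lm = False  # e.g. a flag read from a config file
conditions = {"lm": use_lm}
print([k for k in ("asr", "lm") if is_loadable(k, conditions)])  # ['asr']
```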

load_collected()

Loads the files that have been collected.
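Putting the pieces together, the loading step can be pictured as a loop over the loadables that skips any whose condition is false and hands `<collect_in>/<key>.ckpt` to that loadable's hook. This is a toy sketch under stated assumptions: JSON files stand in for real checkpoints, `json_hook` stands in for a parameter transfer hook, and the `<key>.ckpt` naming is assumed rather than taken from this page.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Mapping


def load_collected(
    collect_in: Path,
    loadables: Mapping[str, Any],
    hooks: Mapping[str, Callable[[Any, Path], None]],
    conditions: Mapping[str, Any],
) -> None:
    """Apply each loadable's hook to collect_in/<key>.ckpt (sketch)."""
    for key, obj in loadables.items():
        if key in conditions and not conditions[key]:
            continue  # condition is false: skip this loadable
        hooks[key](obj, collect_in / f"{key}.ckpt")


def json_hook(obj: Dict[str, Any], path: Path) -> None:
    """Toy "parameter transfer": merge the JSON contents into a dict."""
    obj.update(json.loads(path.read_text()))


with tempfile.TemporaryDirectory() as tmp_name:
    collect_in = Path(tmp_name)
    (collect_in / "asr.ckpt").write_text(json.dumps({"w": 1.5}))
    (collect_in / "lm.ckpt").write_text(json.dumps({"w": -2.0}))
    asr: Dict[str, Any] = {}
    lm: Dict[str, Any] = {}
    load_collected(
        collect_in,
        loadables={"asr": asr, "lm": lm},
        hooks={"asr": json_hook, "lm": json_hook},
        conditions={"lm": False},  # the "lm" flag is turned off
    )

print(asr, lm)  # {'w': 1.5} {}
```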