speechbrain.utils.parameter_transfer module
Convenience functions for the simplest parameter transfer cases.
Use speechbrain.utils.checkpoints.Checkpointer to find a checkpoint and the path to the parameter file.
- Authors
Aku Rouhe 2020
Andreas Nautsch 2023
Adel Moumen 2023
Summary
Classes:
Pretrainer | Orchestrates pretraining
Reference
- class speechbrain.utils.parameter_transfer.Pretrainer(collect_in=None, loadables=None, paths=None, custom_hooks=None, conditions=None)
Bases: object
Orchestrates pretraining
First, optionally collects files from some source (local directory, Hugging Face repository, base URL) into the collect_in directory, if specified. Then, calls load hooks for each of those files.
- Parameters:
collect_in (str or Path, optional) – Path to the directory where the files are to be collected. If None, files are referenced from the cache or directly, if possible (URLs will fail), and there is no centralized target directory containing all the files.
loadables (mapping) – Mapping from loadable key to object. This connects the keys to the actual object instances.
paths (mapping) – Mapping from loadable key to filepath. The last part of the path is treated as the file name; the rest is treated as a “source”, which can be either a directory path or a magic source such as a Hugging Face Hub ID. For example, sb/asr-crdnn-libri/lm.ckpt -> source=sb/asr-crdnn-libri, file=lm.ckpt. Note that when collecting, you can specify a default source, which is used for all loadables that don’t have a path specified.
custom_hooks (mapping) – Mapping from loadable key to parameter transfer hook function. If you want to use a custom loading function, specify it here.
conditions (mapping) – An optional mapping from loadable keys to condition values, useful for loading certain elements only if a flag is turned on.
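To make the role of conditions concrete, here is a minimal plain-Python sketch of how a condition could gate which loadables are processed. It assumes that a key with no entry in conditions is always loaded and that a condition value is either a flag or a zero-argument callable; is_loadable and select_loadables are hypothetical helpers, not part of the Pretrainer API, and the string values stand in for real modules.

```python
from typing import Any, Callable, Dict, Mapping, Union

# Assumption: a condition is a plain flag or a zero-argument callable
# that is evaluated when loading happens.
Condition = Union[bool, Callable[[], bool]]


def is_loadable(key: str, conditions: Mapping[str, Condition]) -> bool:
    """Hypothetical helper: should the loadable `key` be loaded?"""
    if key not in conditions:
        # Keys without a condition are always loaded.
        return True
    condition = conditions[key]
    if callable(condition):
        return bool(condition())
    return bool(condition)


def select_loadables(
    loadables: Mapping[str, Any], conditions: Mapping[str, Condition]
) -> Dict[str, Any]:
    """Keep only the loadables whose condition (if any) is satisfied."""
    return {k: v for k, v in loadables.items() if is_loadable(k, conditions)}


# String placeholders stand in for real torch modules / tokenizers.
loadables = {"encoder": "enc-module", "lm": "lm-module", "tokenizer": "tok"}
conditions = {"lm": False, "tokenizer": lambda: True}
print(sorted(select_loadables(loadables, conditions)))  # ['encoder', 'tokenizer']
```

Allowing a callable as well as a flag lets the decision be deferred until load time, for example to check a configuration value that is only known later.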
- add_loadables(loadables)
Update the loadables dict from the given mapping.
- Parameters:
loadables (mapping) – Mapping from loadable key to object
- add_paths(paths)
Update the paths for different loadables.
When collecting parameters, paths here are preferred. Note that when collecting, you can specify a default source, which is used for all loadables that don’t have a path specified.
- Parameters:
paths (mapping) – Mapping from loadable key to filepath. The last part of the path is treated as the file name; the rest is treated as a “source”, which can be either a directory path or a magic source such as a Hugging Face Hub ID. For example, sb/asr-crdnn-libri/lm.ckpt -> source=sb/asr-crdnn-libri, file=lm.ckpt.
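The precedence between explicitly registered paths and the default source can be sketched as follows. This assumes that add_paths behaves like dict.update (later entries overwrite earlier ones) and that the fallback location is <default_source>/<key>.ckpt, as described for collect_files; resolve_location is a hypothetical helper and the my-org/... repository IDs are made up for illustration.

```python
from typing import Mapping, Optional


def resolve_location(
    key: str, paths: Mapping[str, str], default_source: Optional[str] = None
) -> str:
    """Hypothetical helper: where the file for loadable `key` is looked up."""
    if key in paths:
        # An explicitly registered path always takes precedence.
        return paths[key]
    if default_source is None:
        raise ValueError(f"No path for {key!r} and no default_source given")
    # Fallback: <default_source>/<key>.ckpt
    return f"{default_source}/{key}.ckpt"


paths = {"lm": "sb/asr-crdnn-libri/lm.ckpt"}
# Assumed add_paths semantics: a later update overwrites the earlier entry.
paths.update({"lm": "my-org/better-lm/lm.ckpt"})

print(resolve_location("lm", paths, "my-org/asr"))       # my-org/better-lm/lm.ckpt
print(resolve_location("encoder", paths, "my-org/asr"))  # my-org/asr/encoder.ckpt
```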
- add_custom_hooks(custom_hooks)
Update the custom hooks.
When loading parameters, hooks here are preferred over class defaults.
- Parameters:
custom_hooks (mapping) – Mapping from loadable key to parameter transfer hook function. If you want to use a custom loading function, specify it here.
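The preference of custom hooks over class defaults can be illustrated with a small dispatch sketch. It assumes that a hook is called as hook(obj, path) and that class defaults are looked up by the object's type along its method resolution order; DEFAULT_HOOKS, choose_hook and the two hook functions are hypothetical stand-ins, not SpeechBrain's registry, and plain dicts play the role of modules.

```python
from typing import Any, Callable, Dict, Mapping

# Assumed hook signature: fill `obj` from the parameter file at `path`.
Hook = Callable[[Any, str], None]

# Hypothetical registry of class-level default hooks, keyed by type.
DEFAULT_HOOKS: Dict[type, Hook] = {}


def default_dict_hook(obj: dict, path: str) -> None:
    # Stand-in for "read the parameters stored at `path` into `obj`".
    obj["loaded_from"] = path


DEFAULT_HOOKS[dict] = default_dict_hook


def choose_hook(key: str, obj: Any, custom_hooks: Mapping[str, Hook]) -> Hook:
    """Prefer a custom hook for `key`, else the first default along the MRO."""
    if key in custom_hooks:
        return custom_hooks[key]
    for cls in type(obj).__mro__:
        if cls in DEFAULT_HOOKS:
            return DEFAULT_HOOKS[cls]
    raise RuntimeError(f"No transfer hook available for {key!r}")


def custom_lm_hook(obj: dict, path: str) -> None:
    # A custom hook could, e.g., remap parameter names before loading.
    obj["loaded_from"] = f"custom:{path}"


loadables = {"encoder": {}, "lm": {}}
custom_hooks = {"lm": custom_lm_hook}
for key, obj in loadables.items():
    choose_hook(key, obj, custom_hooks)(obj, f"/tmp/{key}.ckpt")

print(loadables)
```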
- add_conditions(conditions)
Update the conditions.
- Parameters:
conditions (mapping) – Mapping from loadable keys to condition values, useful for loading certain elements only if a flag is turned on.
- static split_path(path)
Splits a path into source and filename.
This also handles URLs and Hugging Face Hub paths, in addition to regular paths.
- Parameters:
path (str)
- Returns:
str – Source
str – Filename
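A minimal sketch of the splitting rule for plain string paths: everything after the last "/" is the filename, everything before it is the source, which is why the same rule covers directory paths, URLs and Hub IDs. It assumes that a path without any "/" refers to a file in the current directory, and it does not handle FetchSource inputs or Windows-style backslash paths; this split_path is a standalone illustration, not the library's static method.

```python
from typing import Tuple


def split_path(path: str) -> Tuple[str, str]:
    """Split `path` at its last '/' into (source, filename).

    Assumption: a path with no '/' is treated as a file in the
    current directory.
    """
    if "/" in path:
        source, filename = path.rsplit("/", maxsplit=1)
        return source, filename
    return "./", path


print(split_path("sb/asr-crdnn-libri/lm.ckpt"))          # ('sb/asr-crdnn-libri', 'lm.ckpt')
print(split_path("https://example.com/models/lm.ckpt"))  # ('https://example.com/models', 'lm.ckpt')
print(split_path("lm.ckpt"))                             # ('./', 'lm.ckpt')
```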
- collect_files(default_source=None, use_auth_token=False, local_strategy: LocalStrategy = LocalStrategy.SYMLINK)
Fetches parameters from known paths, with default_source as the fallback.
The actual parameter files may reside elsewhere, but this ensures that a symlink to each of them exists in the self.collect_in directory. The symlink always uses the loadable key in the filename. This standardization makes it easier to orchestrate pretraining, e.g. on distributed setups.
Use the default_source if you have everything organized neatly into one location, like a Hugging Face Hub repo.
- Parameters:
default_source (str or Path or FetchSource) – This is used for each loadable which doesn’t have a path already specified. For example, if the loadable has key "asr", then the file to look for is <default_source>/asr.ckpt.
use_auth_token (bool (default: False)) – If True, the Hugging Face auth token will be used to load private models from the Hugging Face Hub. The default is False because the majority of models are public.
local_strategy (speechbrain.utils.fetching.LocalStrategy) – The fetching strategy to use, which controls the behavior of remote file fetching with regard to symlinking and copying. Ignored if a collect_in directory was not specified. See speechbrain.utils.fetching.fetch() for further details.
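For purely local sources, the collection step described above can be approximated with the standard library. This sketch assumes the collected file is named <key>.ckpt regardless of the original file name, and it reduces local_strategy to a symlink-or-copy flag; collect_local_files is a hypothetical helper that does no remote fetching, caching or authentication. Creating symlinks may require extra privileges on Windows.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Dict, Iterable, Mapping, Optional, Union

PathLike = Union[str, Path]


def collect_local_files(
    keys: Iterable[str],
    paths: Mapping[str, PathLike],
    collect_in: PathLike,
    default_source: Optional[PathLike] = None,
    symlink: bool = True,
) -> Dict[str, Path]:
    """Hypothetical helper: gather local files into `collect_in` as <key>.ckpt."""
    collect_in = Path(collect_in)
    collect_in.mkdir(parents=True, exist_ok=True)
    collected: Dict[str, Path] = {}
    for key in keys:
        if key in paths:
            source_file = Path(paths[key])  # explicit path wins
        elif default_source is not None:
            source_file = Path(default_source) / f"{key}.ckpt"
        else:
            raise ValueError(f"No path for {key!r} and no default_source")
        if not source_file.is_file():
            raise FileNotFoundError(source_file)
        # The collected name always uses the loadable key.
        target = collect_in / f"{key}.ckpt"
        if target.is_symlink() or target.exists():
            target.unlink()  # replace stale entries from earlier runs
        if symlink:
            target.symlink_to(source_file.resolve())
        else:
            shutil.copyfile(source_file, target)
        collected[key] = target
    return collected


with tempfile.TemporaryDirectory() as tmp:
    repo = Path(tmp) / "model_repo"
    repo.mkdir()
    (repo / "encoder.ckpt").write_text("encoder weights")
    (repo / "my_lm.pt").write_text("lm weights")

    collected = collect_local_files(
        keys=["encoder", "lm"],
        paths={"lm": repo / "my_lm.pt"},
        collect_in=Path(tmp) / "pretrained",
        default_source=repo,
    )
    # Summarize while the temporary directory still exists.
    summary = {
        key: (path.name, path.is_symlink(), path.read_text())
        for key, path in collected.items()
    }

print(summary)
```

Note how the language-model file, stored as my_lm.pt, is exposed as lm.ckpt: downstream code only needs to know the loadable key, not the original file name.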
- Returns:
Mapping from loadable key to a local path from which the loadable’s parameters can be loaded. This is not used in this class, but can possibly be helpful.
- Return type: