speechbrain.utils.parameter_transfer module
Convenience functions for the simplest parameter transfer cases.
Use speechbrain.utils.checkpoints.Checkpointer to find a checkpoint and the path to the parameter file.
- Authors
Aku Rouhe 2020
Reference
- class speechbrain.utils.parameter_transfer.Pretrainer(collect_in='./model_checkpoints', loadables=None, paths=None, custom_hooks=None)
Bases: object
Orchestrates pretraining.
First collects parameter file symlinks into the given directory. Then calls load hooks for each of those parameter files.
- Parameters
collect_in (str or Path) – Path to directory where the parameter file symlinks are collected.
loadables (mapping) – Mapping from loadable key to object. This connects the keys to the actual object instances.
paths (mapping) – Mapping from loadable key to filepath. The last part of the path is treated as the file name; the rest is treated as a “source”, which can be either a directory path or a magic source such as a Huggingface hub ID. For example, sb/asr-crdnn-libri/lm.ckpt is split into source=sb/asr-crdnn-libri and file=lm.ckpt. Note that when collecting, you can specify a default source, which is used for all loadables that don’t have a path specified.
custom_hooks (mapping) – Mapping from loadable key to parameter transfer hook function. If you want to use a custom loading function, specify it here.
- add_loadables(loadables)
Update the loadables dict from the given mapping.
- Parameters
loadables (mapping) – Mapping from loadable key to object
- add_paths(paths)
Update the paths for different loadables.
When collecting parameters, paths here are preferred. Note that when collecting, you can specify a default source, which is used for all loadables that don’t have a path specified.
- Parameters
paths (mapping) – Mapping from loadable key to filepath. The last part of the path is treated as the file name; the rest is treated as a “source”, which can be either a directory path or a magic source such as a Huggingface hub ID. For example, sb/asr-crdnn-libri/lm.ckpt is split into source=sb/asr-crdnn-libri and file=lm.ckpt.
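The precedence described above (an explicit path wins, otherwise the default source is used) can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the helper name resolve_locations is hypothetical, and the `<key>.ckpt` fallback filename follows the convention described for collect_files.

```python
def resolve_locations(loadable_keys, paths, default_source=None):
    """Hypothetical helper: decide (source, filename) for each loadable key."""
    resolved = {}
    for key in loadable_keys:
        if key in paths:
            # An explicit path is preferred: split on the last "/".
            source, _, filename = paths[key].rpartition("/")
        elif default_source is not None:
            # Fallback: look for "<key>.ckpt" in the default source.
            source, filename = default_source, f"{key}.ckpt"
        else:
            raise ValueError(f"No path and no default source for {key!r}")
        resolved[key] = (source, filename)
    return resolved


resolved = resolve_locations(
    ["asr", "lm"],
    paths={"lm": "sb/asr-crdnn-libri/lm.ckpt"},
    default_source="my_models",  # hypothetical directory name
)
```

Here "lm" resolves from its explicit path, while "asr" falls back to the default source.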
- add_custom_hooks(custom_hooks)
Update the custom hooks.
When loading parameters, hooks here are preferred over class defaults.
- Parameters
custom_hooks (mapping) – Mapping from loadable key to parameter transfer hook function. If you want to use a custom loading function, specify it here.
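The "custom hooks are preferred" rule can be illustrated with a small stand-alone sketch. Everything here is an assumption made for illustration: the hook signature (obj, path), the single default_hook standing in for the per-class defaults, and the plain dicts standing in for model objects. It is not the actual Pretrainer loading code.

```python
def default_hook(obj, path):
    """Stand-in for a class-default parameter transfer hook (hypothetical)."""
    obj["loaded_by"] = "default"
    obj["path"] = path


def custom_lm_hook(obj, path):
    """Stand-in for a user-supplied custom hook (hypothetical)."""
    obj["loaded_by"] = "custom"
    obj["path"] = path


def load_all(loadables, param_paths, custom_hooks):
    """Call the custom hook for a key if one is registered, else the default."""
    for key, obj in loadables.items():
        hook = custom_hooks.get(key, default_hook)
        hook(obj, param_paths[key])


# Plain dicts stand in for the objects whose parameters would be loaded.
loadables = {"asr": {}, "lm": {}}
param_paths = {"asr": "ckpts/asr.ckpt", "lm": "ckpts/lm.ckpt"}
load_all(loadables, param_paths, custom_hooks={"lm": custom_lm_hook})
```

Only "lm" has a custom hook registered, so "asr" is handled by the default.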
- static split_path(path)
Splits a path into source and filename.
This also handles URLs and Huggingface hub paths, in addition to regular paths.
- Parameters
path (str) – The path to split: a regular path, a URL, or a Huggingface hub path.
- Returns
str – Source
str – Filename
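As a rough illustration of the split, the sketch below separates everything before the last "/" from the final component. It is a simplified stand-in, not the library's split_path: the function name is hypothetical, and treating a bare filename as relative to "./" is an assumption.

```python
def split_source_and_filename(path):
    """Hypothetical helper: split on the last '/' into (source, filename)."""
    if "/" in path:
        source, filename = path.rsplit("/", maxsplit=1)
    else:
        # Assumption: a bare filename is taken relative to the current directory.
        source, filename = "./", path
    return source, filename


hub_example = split_source_and_filename("sb/asr-crdnn-libri/lm.ckpt")
url_example = split_source_and_filename("https://example.com/models/lm.ckpt")
bare_example = split_source_and_filename("lm.ckpt")
```

The same last-separator rule covers hub-style IDs and URLs, because in both cases the filename is the final path component.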
- collect_files(default_source=None)
Fetches parameters from known paths, falling back to default_source.
The actual parameter files may reside elsewhere, but this ensures a symlink in the self.collect_in directory. The symlink always uses the loadable key in the filename. This standardization makes it easier to orchestrate pretraining on e.g. distributed setups.
Use the default_source if you have everything organized neatly into one location, like a Huggingface hub repo.
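The symlink convention can be sketched with the standard library alone. This is a minimal sketch under stated assumptions, not the library's code: collect_symlinks is a hypothetical helper, the files are assumed to be already available locally, and the link name `<key>.ckpt` follows the naming described for default_source.

```python
import tempfile
from pathlib import Path


def collect_symlinks(collect_in, fetched_files):
    """Hypothetical helper: link each parameter file as <collect_in>/<key>.ckpt."""
    collect_in = Path(collect_in)
    collect_in.mkdir(parents=True, exist_ok=True)
    collected = {}
    for key, real_path in fetched_files.items():
        link = collect_in / f"{key}.ckpt"
        # Replace a stale link or file left over from an earlier run.
        if link.is_symlink() or link.exists():
            link.unlink()
        link.symlink_to(Path(real_path).resolve())
        collected[key] = link
    return collected


with tempfile.TemporaryDirectory() as tmp:
    # The real file lives elsewhere and has an unrelated name.
    real = Path(tmp) / "elsewhere" / "language_model_v3.bin"
    real.parent.mkdir()
    real.write_text("fake parameters")
    links = collect_symlinks(Path(tmp) / "model_checkpoints", {"lm": real})
    link_name = links["lm"].name
    link_is_symlink = links["lm"].is_symlink()
    link_content = links["lm"].read_text()
```

Because every link is named after its loadable key, later loading code only needs the collection directory and the key, regardless of where the file was originally stored.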
- Parameters
default_source (str or Path) – This is used for each loadable which doesn’t have a path already specified. If the loadable has key “asr”, then the file to look for is default_source/asr.ckpt
- Returns
Mapping from loadable key to a local path from which the loadable’s parameters can be loaded. The class itself does not use this mapping, but it may be useful to callers.
- Return type