speechbrain.utils.parameter_transfer module
Convenience functions for the simplest parameter transfer cases.
Use speechbrain.utils.checkpoints.Checkpointer to find a checkpoint and the path to the parameter file.
- Authors
Aku Rouhe 2020
Andreas Nautsch 2023
Summary
Classes:
Pretrainer | Orchestrates pretraining
Reference
- class speechbrain.utils.parameter_transfer.Pretrainer(collect_in='./model_checkpoints', loadables=None, paths=None, custom_hooks=None, conditions=None)[source]
Bases: object
Orchestrates pretraining
First collects parameter file symlinks into the given directory. Then calls load hooks for each of those parameter files.
- Parameters:
collect_in (str or Path) – Path to directory where the parameter file symlinks are collected.
loadables (mapping) – Mapping from loadable key to object. This connects the keys to the actual object instances.
paths (mapping) – Mapping from loadable key to filepath. The last part of the path is treated as the file name; the rest is treated as a “source”, which can be either a directory path or a magic source such as a Huggingface hub ID. For example, sb/asr-crdnn-libri/lm.ckpt is split into source=sb/asr-crdnn-libri and file=lm.ckpt. Note that when collecting, you can specify a default source, which is used for all loadables that don’t have a path specified.
custom_hooks (mapping) – Mapping from loadable key to parameter transfer hook function. If you want to use a custom loading function, specify it here.
conditions (mapping) – An optional mapping from loadable keys to condition values, useful for loading certain elements only when a flag is turned on.
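The two phases described above (collect symlinks, then call load hooks) can be illustrated with a minimal pure-Python sketch. It does not use SpeechBrain and is not the library's implementation: the helper names collect, load and read_text_hook, the "<key>.ckpt" link name, and the hook signature are all assumptions made for illustration.

```python
import os
import tempfile
from pathlib import Path


def collect(paths, collect_in):
    """Phase 1: symlink each parameter file into collect_in, named by key."""
    collect_in = Path(collect_in)
    collect_in.mkdir(parents=True, exist_ok=True)
    collected = {}
    for key, src in paths.items():
        link = collect_in / f"{key}.ckpt"  # assumed naming scheme
        if link.is_symlink() or link.exists():
            link.unlink()  # replace a stale link from an earlier run
        os.symlink(Path(src).resolve(), link)
        collected[key] = link
    return collected


def load(loadables, collected, hooks):
    """Phase 2: call the load hook of each loadable on its collected file."""
    for key, obj in loadables.items():
        hooks[key](obj, collected[key])


def read_text_hook(obj, path):
    # Toy hook (hypothetical signature): a dict stands in for a model.
    obj["params"] = Path(path).read_text()


with tempfile.TemporaryDirectory() as tmp:
    # The real parameter file lives somewhere else, under another name.
    source_file = Path(tmp) / "elsewhere" / "lm-final.ckpt"
    source_file.parent.mkdir()
    source_file.write_text("weights-v1")

    model = {}
    collected = collect({"lm": source_file}, Path(tmp) / "model_checkpoints")
    load({"lm": model}, collected, {"lm": read_text_hook})
    link_name = collected["lm"].name
    link_is_symlink = collected["lm"].is_symlink()
```

Because the link name depends only on the loadable key, the loading phase never needs to know where the original file came from.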
- add_loadables(loadables)[source]
Update the loadables dict from the given mapping.
- Parameters:
loadables (mapping) – Mapping from loadable key to object
- add_paths(paths)[source]
Update the paths for different loadables.
When collecting parameters, paths here are preferred. Note that when collecting, you can specify a default source, which is used for all loadables that don’t have a path specified.
- Parameters:
paths (mapping) – Mapping from loadable key to filepath. The last part of the path is treated as the file name; the rest is treated as a “source”, which can be either a directory path or a magic source such as a Huggingface hub ID. For example, sb/asr-crdnn-libri/lm.ckpt is split into source=sb/asr-crdnn-libri and file=lm.ckpt.
- add_custom_hooks(custom_hooks)[source]
Update the custom hooks.
When loading parameters, hooks here are preferred over class defaults.
- Parameters:
custom_hooks (mapping) – Mapping from loadable key to parameter transfer hook function. If you want to use a custom loading function, specify it here.
- add_conditions(conditions)[source]
Update the conditions.
- Parameters:
conditions (mapping) – Mapping from loadable keys to condition values, useful for loading certain elements only when a flag is turned on.
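As a rough illustration of flag-controlled loading, the sketch below filters loadable keys by a conditions mapping. The function is_enabled is hypothetical. Treating keys without a condition as always loadable, and accepting callables as conditions, are assumptions rather than documented behavior.

```python
def is_enabled(key, conditions):
    """Return True if the loadable `key` should be loaded."""
    if key not in conditions:
        return True  # assumption: no condition means "always load"
    condition = conditions[key]
    if callable(condition):
        return bool(condition())  # assumption: callables are evaluated lazily
    return bool(condition)


# Placeholder strings stand in for real model objects.
loadables = {"asr": "asr-model", "lm": "lm-model", "tokenizer": "tok-model"}
conditions = {"lm": False, "tokenizer": lambda: True}

selected = [key for key in loadables if is_enabled(key, conditions)]
```

Here the "lm" entry is skipped because its flag is off, while "asr" has no condition and is kept.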
- static split_path(path)[source]
Splits a path into a source and a filename.
This also handles URLs and Huggingface hub paths, in addition to regular paths.
- Parameters:
path (str) – The path to split into source and filename.
- Returns:
str – Source
str – Filename
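The splitting rule can be sketched in plain Python as follows. This is an illustration under assumptions, not the library's code: the function name split_source_and_filename is hypothetical, and returning "./" as the source for a bare file name is an assumed convention.

```python
def split_source_and_filename(path):
    """Split on the last "/" so that the tail is the file name."""
    if "/" in path:
        source, filename = path.rsplit("/", maxsplit=1)
        return source, filename
    # Assumption: a bare file name is taken to live in the current directory.
    return "./", path


hub_style = split_source_and_filename("sb/asr-crdnn-libri/lm.ckpt")
url_style = split_source_and_filename("https://example.com/models/lm.ckpt")
bare_name = split_source_and_filename("lm.ckpt")
```

Splitting on the last separator works for URLs and hub-style IDs alike, because only the final path component is treated as the file name.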
- collect_files(default_source=None, internal_ddp_handling=False)[source]
Fetches parameter files from the known paths, falling back to default_source for loadables that have no path specified.
The actual parameter files may reside elsewhere, but this ensures a symlink in the self.collect_in directory. The symlink always uses the loadable key in the filename. This standardization makes it easier to orchestrate pretraining on e.g. distributed setups.
Use the default_source if you have everything organized neatly into one location, like a Huggingface hub repo.
- Parameters:
default_source (str or Path or FetchSource) – This is used for each loadable which doesn’t have a path already specified. If the loadable has key “asr”, then the file to look for is default_source/asr.ckpt.
internal_ddp_handling (bool) – Whether the function should handle DDP itself, i.e. via run_on_main. (Default: False)
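The fallback to default_source can be sketched as a small resolution step. The helper resolve_locations and the repository ID "my-org/my-repo" are hypothetical, and raising an error when neither a path nor a default source is available is an assumption. Only the "<key>.ckpt" fallback name comes from the description above.

```python
def resolve_locations(loadable_keys, paths, default_source=None):
    """Map each loadable key to a (source, filename) pair."""
    resolved = {}
    for key in loadable_keys:
        if key in paths:
            # Explicit paths win: split off the last component as file name.
            source, _, filename = paths[key].rpartition("/")
            source = source or "./"
        elif default_source is not None:
            # Fallback: look for default_source/<key>.ckpt
            source, filename = default_source, f"{key}.ckpt"
        else:
            raise ValueError(f"No path or default source for {key!r}")
        resolved[key] = (source, filename)
    return resolved


locations = resolve_locations(
    ["asr", "lm"],
    paths={"lm": "sb/asr-crdnn-libri/lm.ckpt"},
    default_source="my-org/my-repo",  # hypothetical hub repo ID
)
```

In this sketch "lm" keeps its explicit source, while "asr" is looked up as asr.ckpt under the default source.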
- Returns:
Mapping from loadable key to a local path from which the loadable’s parameters can be loaded. The class itself does not use this return value, but it may be helpful to callers.
- Return type: