speechbrain.utils.parameter_transfer module

Convenience functions for the simplest parameter transfer cases.

Use speechbrain.utils.checkpoints.Checkpointer to find a checkpoint and the path to the parameter file.

Authors
  • Aku Rouhe 2020

  • Andreas Nautsch 2023

  • Adel Moumen 2023

Summary

Classes:

Pretrainer

Orchestrates pretraining

Reference

class speechbrain.utils.parameter_transfer.Pretrainer(collect_in='./model_checkpoints', loadables=None, paths=None, custom_hooks=None, conditions=None)

Bases: object

Orchestrates pretraining

First collects parameter file symlinks into the given directory. Then calls load hooks for each of those parameter files.

Parameters:
  • collect_in (str or Path) – Path to directory where the parameter file symlinks are collected.

  • loadables (mapping) – Mapping from loadable key to object. This connects the keys to the actual object instances.

  • paths (mapping) – Mapping from loadable key to filepath. The last part of the path is treated as the file name; the rest is treated as a “source”, which can be either a directory path or a magic source such as a Hugging Face Hub ID. For example, sb/asr-crdnn-libri/lm.ckpt -> source=sb/asr-crdnn-libri, file=lm.ckpt. Note that when collecting, you can specify a default source, which is used for all loadables that don’t have a path specified.

  • custom_hooks (mapping) – Mapping from loadable key to parameter transfer hook function. If you want to use a custom loading function, specify it here.

  • conditions (mapping) – An optional mapping from loadable keys to condition values, useful for loading certain elements only if a flag is turned on.

set_collect_in(path)

Changes the directory where the parameter file symlinks are collected.

add_loadables(loadables)

Update the loadables dict from the given mapping.

Parameters:

loadables (mapping) – Mapping from loadable key to object.

add_paths(paths)

Update the paths for different loadables.

When collecting parameters, paths given here take precedence over the default source. Note that when collecting, you can specify a default source, which is used for all loadables that don’t have a path specified.

Parameters:

paths (mapping) – Mapping from loadable key to filepath. The last part of the path is treated as the file name; the rest is treated as a “source”, which can be either a directory path or a magic source such as a Hugging Face Hub ID. For example, sb/asr-crdnn-libri/lm.ckpt -> source=sb/asr-crdnn-libri, file=lm.ckpt.
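The precedence between explicit paths and a default source can be sketched as below. This is a minimal sketch, assuming the fallback file name is `<key>.ckpt` as described for `collect_files`; `resolve_path` and the `my-org/my-model` source are hypothetical and not part of SpeechBrain's API.

```python
from typing import Mapping, Optional


def resolve_path(key: str, paths: Mapping[str, str], default_source: Optional[str]) -> str:
    """Hypothetical helper: pick the location to fetch a loadable from.

    An explicit entry in `paths` wins; otherwise fall back to
    `<default_source>/<key>.ckpt` (assumed naming convention).
    """
    if key in paths:
        return paths[key]
    if default_source is None:
        raise ValueError(f"No path for loadable {key!r} and no default_source given")
    return f"{default_source}/{key}.ckpt"


paths = {"lm": "sb/asr-crdnn-libri/lm.ckpt"}
print(resolve_path("lm", paths, "my-org/my-model"))   # sb/asr-crdnn-libri/lm.ckpt
print(resolve_path("asr", paths, "my-org/my-model"))  # my-org/my-model/asr.ckpt
```

Raising when neither a path nor a default source is available keeps a misconfigured loadable from being silently skipped.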

add_custom_hooks(custom_hooks)

Update the custom hooks.

When loading parameters, hooks here are preferred over class defaults.

Parameters:

custom_hooks (mapping) – Mapping from loadable key to parameter transfer hook function. If you want to use a custom loading function, specify it here.
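Hook selection can be sketched as a lookup that prefers the per-key custom hook and otherwise falls back to a default chosen by the object's type. In this sketch, `select_hook` and the `default_hooks` registry are hypothetical, and hooks are assumed to be callables taking `(obj, path)`; the real class's default hook resolution is not modeled in detail.

```python
from typing import Any, Callable, Mapping

# Assumed hook signature for this sketch: hook(obj, path).
Hook = Callable[[Any, str], None]


def select_hook(key: str, obj: Any, custom_hooks: Mapping[str, Hook],
                default_hooks: Mapping[type, Hook]) -> Hook:
    # A custom hook registered for this loadable key takes precedence.
    if key in custom_hooks:
        return custom_hooks[key]
    # Otherwise look up a default hook by the object's type, walking the MRO
    # so that subclasses inherit the hook registered for a base class.
    for cls in type(obj).__mro__:
        if cls in default_hooks:
            return default_hooks[cls]
    raise RuntimeError(f"Don't know how to load {type(obj).__name__}; add a custom hook")


calls = []
default_hooks = {dict: lambda obj, path: calls.append(("default", path))}
custom_hooks = {"lm": lambda obj, path: calls.append(("custom", path))}
select_hook("lm", {}, custom_hooks, default_hooks)({}, "lm.ckpt")
select_hook("asr", {}, custom_hooks, default_hooks)({}, "asr.ckpt")
print(calls)  # [('custom', 'lm.ckpt'), ('default', 'asr.ckpt')]
```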

add_conditions(conditions)

Update the conditions.

Parameters:

conditions (mapping) – Mapping from loadable keys to condition values, useful for loading certain elements only if a flag is turned on.
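The bookkeeping behind the constructor arguments and the `add_*` methods can be sketched as four plain dicts. This is a minimal sketch, assuming each method uses `dict.update` semantics, so a later call overrides an earlier entry for the same key; `PretrainerState` is a hypothetical stand-in, not the SpeechBrain class.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Mapping


@dataclass
class PretrainerState:
    # Each mapping is keyed by the loadable key; None arguments become empty dicts.
    loadables: Dict[str, Any] = field(default_factory=dict)
    paths: Dict[str, str] = field(default_factory=dict)
    custom_hooks: Dict[str, Any] = field(default_factory=dict)
    conditions: Dict[str, Any] = field(default_factory=dict)

    def add_loadables(self, loadables: Mapping[str, Any]) -> None:
        self.loadables.update(loadables)

    def add_paths(self, paths: Mapping[str, str]) -> None:
        self.paths.update(paths)

    def add_custom_hooks(self, custom_hooks: Mapping[str, Any]) -> None:
        self.custom_hooks.update(custom_hooks)

    def add_conditions(self, conditions: Mapping[str, Any]) -> None:
        self.conditions.update(conditions)


state = PretrainerState()
state.add_paths({"lm": "sb/asr-crdnn-libri/lm.ckpt"})
state.add_paths({"lm": "my-org/other-model/lm.ckpt"})  # same key: the later call wins
print(state.paths)  # {'lm': 'my-org/other-model/lm.ckpt'}
```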

static split_path(path)

Splits a path into source and filename.

This also handles URLs and Hugging Face Hub paths, in addition to regular paths.

Parameters:

path (str) – The path to split.

Returns:

  • str – Source

  • str – Filename
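The splitting rule for plain string paths can be sketched as follows. This is a minimal sketch, assuming the split happens at the last `/` and that a bare file name maps to the current directory `./` as its source; it does not model `FetchSource` objects or any other special handling the real method may perform.

```python
from typing import Tuple


def split_path(path: str) -> Tuple[str, str]:
    """Sketch: split "source/.../file" into (source, file) at the last slash."""
    if "/" in path:
        source, filename = path.rsplit("/", 1)
        return source, filename
    # Assumed fallback: a bare file name is looked up in the current directory.
    return "./", path


print(split_path("sb/asr-crdnn-libri/lm.ckpt"))
# ('sb/asr-crdnn-libri', 'lm.ckpt')
print(split_path("https://example.com/models/lm.ckpt"))
# ('https://example.com/models', 'lm.ckpt')
```

Splitting only at the last slash is what lets the same rule cover directory paths, Hub IDs, and URLs, since everything before the file name is kept intact as the source.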

collect_files(default_source=None, internal_ddp_handling=False)

Fetches parameters from known paths, falling back to default_source.

The actual parameter files may reside elsewhere, but this ensures a symlink in the self.collect_in directory. The symlink always uses the loadable key in the filename. This standardization makes it easier to orchestrate pretraining on e.g. distributed setups.

Use default_source if you have everything organized neatly in one location, such as a Hugging Face Hub repo.

Parameters:
  • default_source (str or Path or FetchSource) – Used for each loadable that doesn’t already have a path specified. If the loadable has key “asr”, then the file to look for is default_source/asr.ckpt.

  • internal_ddp_handling (bool) – Whether the function should handle DDP internally, i.e. via run_on_main. (Default: False)

Returns:

Mapping from loadable key to a local path from which loadable’s parameters can be loaded. This is not used in this class, but can possibly be helpful.

Return type:

dict
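The symlinking step can be sketched for the purely local case, where every source is a directory on disk. This sketch, built around the hypothetical `collect_local_files`, ignores remote sources (Hugging Face Hub, URLs), actual fetching, and DDP handling; it only shows how each file becomes reachable under a name derived from the loadable key (assumed here to be `<key>.ckpt`) and how the returned mapping is built.

```python
import pathlib
import tempfile
from typing import Dict, Iterable, Mapping, Optional


def collect_local_files(collect_in: str, keys: Iterable[str], paths: Mapping[str, str],
                        default_source: Optional[str] = None) -> Dict[str, pathlib.Path]:
    collect_dir = pathlib.Path(collect_in)
    collect_dir.mkdir(parents=True, exist_ok=True)
    loadable_paths = {}
    for key in keys:
        if key in paths:
            target = pathlib.Path(paths[key])  # explicit path takes precedence
        elif default_source is not None:
            target = pathlib.Path(default_source) / f"{key}.ckpt"
        else:
            raise ValueError(f"No path for {key!r} and no default_source given")
        # The link name always uses the loadable key, whatever the original file name.
        link = collect_dir / f"{key}.ckpt"
        if link.is_symlink() or link.exists():
            link.unlink()  # replace a stale link from an earlier run
        link.symlink_to(target.resolve())
        loadable_paths[key] = link
    return loadable_paths


with tempfile.TemporaryDirectory() as tmp:
    src = pathlib.Path(tmp, "src")
    src.mkdir()
    (src / "asr.ckpt").write_text("asr weights")
    (src / "my_lm_v2.ckpt").write_text("lm weights")
    out = collect_local_files(str(pathlib.Path(tmp, "collected")), ["asr", "lm"],
                              {"lm": str(src / "my_lm_v2.ckpt")}, default_source=str(src))
    print(sorted(p.name for p in out.values()))  # ['asr.ckpt', 'lm.ckpt']
    print(out["lm"].read_text())                 # lm weights
```

Because every link is named after its key, a later loading step can find each file without knowing where it originally came from.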

is_loadable(name)

Returns True if no condition is defined for the specified loadable, or if the condition is true.

Parameters:

name (str) – the name of the loadable

Returns:

is_loadable – whether the item should be loaded

Return type:

bool
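The check can be sketched as follows, assuming condition values are plain booleans (or other values evaluated for truthiness). The standalone `is_loadable` function here takes the conditions mapping explicitly and is a hypothetical stand-in for the method.

```python
from typing import Any, Mapping


def is_loadable(name: str, conditions: Mapping[str, Any]) -> bool:
    # No condition registered for this loadable: always load it.
    if name not in conditions:
        return True
    # Otherwise load only if the condition value is truthy.
    return bool(conditions[name])


conditions = {"lm": False, "tokenizer": True}
print(is_loadable("asr", conditions))        # True (no condition)
print(is_loadable("lm", conditions))         # False
print(is_loadable("tokenizer", conditions))  # True
```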

load_collected(device=None)

Loads the files that have been collected.

Parameters:

device (str) – Device on which to load, if you want to load directly to a specific device (otherwise, leave it as None).
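The loading step can be sketched as a loop over the loadables that skips those whose condition is false and hands each collected file to its hook. This is a minimal sketch, assuming collected files are named `<key>.ckpt` inside the collect directory and that hooks accept `(obj, path, device)`; both the standalone `load_collected` function and that hook signature are hypothetical. A real hook might, for instance, forward the device to whatever deserialization routine it uses.

```python
import pathlib
from typing import Any, Callable, Mapping, Optional

# Hypothetical hook signature for this sketch: hook(obj, path, device).
Hook = Callable[[Any, pathlib.Path, Optional[str]], None]


def load_collected(collect_in: str, loadables: Mapping[str, Any],
                   hooks: Mapping[str, Hook], conditions: Mapping[str, Any],
                   device: Optional[str] = None) -> None:
    for key, obj in loadables.items():
        # Skip loadables whose condition is registered and false.
        if key in conditions and not conditions[key]:
            continue
        # Files were collected under the loadable key, so the path is predictable.
        path = pathlib.Path(collect_in) / f"{key}.ckpt"
        hooks[key](obj, path, device)


loaded = {}


def record(obj, path, device):
    # Stand-in hook: just remember which file and device it was given.
    loaded[obj] = (path.name, device)


load_collected("model_checkpoints", {"asr": "asr_model", "lm": "lm_model"},
               {"asr": record, "lm": record}, {"lm": False}, device="cpu")
print(loaded)  # {'asr_model': ('asr.ckpt', 'cpu')}
```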