"""Convenience functions for the simplest parameter transfer cases.
Use `speechbrain.utils.checkpoints.Checkpointer` to find a checkpoint
and the path to the parameter file.
Authors
* Aku Rouhe 2020
* Andreas Nautsch 2023
* Adel Moumen 2023
"""
import pathlib
import platform
import warnings
from speechbrain.utils.checkpoints import (
DEFAULT_LOAD_HOOKS,
DEFAULT_TRANSFER_HOOKS,
PARAMFILE_EXT,
get_default_hook,
)
from speechbrain.utils.distributed import run_on_main
from speechbrain.utils.fetching import FetchSource, LocalStrategy, fetch
from speechbrain.utils.logger import get_logger
logger = get_logger(__name__)
[docs]
class Pretrainer:
    """Orchestrates pretraining

    First optionally collects files from some source (local directory,
    HuggingFace repository, base URL), into the `collect_in` directory, if
    specified.

    Then, calls load hooks for each of those files.

    Arguments
    ---------
    collect_in : str or Path, optional
        Path to directory where the files are to be collected.
        If `None`, then files will be referred to from cache or directly, if
        possible (URLs will fail). There will not be a centralized target
        directory with all the files.
    loadables : mapping
        Mapping from loadable key to object. This connects the keys to
        the actual object instances.
    paths : mapping
        Mapping from loadable key to filepath. The last part
        of the path is treated as file name, the rest of it
        is treated as a "source" which can be either a directory
        path or a magic source like Huggingface hub ID.
        e.g. sb/asr-crdnn-libri/lm.ckpt
        -> source=sb/asr-crdnn-libri, file=lm.ckpt
        Note that when collecting, you can specify a default source,
        which is used for all loadables that don't have a path specified.
    custom_hooks : mapping
        Mapping from loadable key to parameter transfer hook function. If you
        want to use a custom loading function, specify it here.
    conditions : mapping
        An optional mapping from loadable keys to condition values,
        useful for loading certain elements only if a flag is turned on.
        A condition may be a plain value (interpreted by truthiness) or a
        callable taking no arguments.
    """

    def __init__(
        self,
        collect_in=None,
        loadables=None,
        paths=None,
        custom_hooks=None,
        conditions=None,
    ):
        self.loadables = {}
        self.set_collect_in(collect_in)
        if loadables is not None:
            self.add_loadables(loadables)
        self.paths = {}
        if paths is not None:
            self.add_paths(paths)
        self.custom_hooks = {}
        if custom_hooks is not None:
            self.add_custom_hooks(custom_hooks)
        self.conditions = {}
        if conditions is not None:
            self.add_conditions(conditions)
        # Keys for which `collect_files` has rewritten `self.paths[key]` to
        # the fetched path; `load_collected` reads those paths directly.
        self.is_local = []

    def set_collect_in(self, path):
        """Change the collecting path

        Arguments
        ---------
        path : str or Path or None
            New collection directory. Stored as a `pathlib.Path`, or as
            `None` when no collection directory should be used.
        """
        self.collect_in = pathlib.Path(path) if path is not None else None

    def add_loadables(self, loadables):
        """Update the loadables dict from the given mapping.

        Arguments
        ---------
        loadables : mapping
            Mapping from loadable key to object
        """
        self.loadables.update(loadables)

    def add_paths(self, paths):
        """Update the paths for different loadables.

        When collecting parameters, paths here are preferred. Note that when
        collecting, you can specify a default source, which is used for all
        loadables that don't have a path specified.

        Arguments
        ---------
        paths : mapping
            Mapping from loadable key to filepath. The last part
            of the path is treated as file name, the rest of it
            is treated as a "source" which can be either a directory
            path or a magic source like Huggingface hub ID.
            e.g. sb/asr-crdnn-libri/lm.ckpt
            -> source=sb/asr-crdnn-libri, file=lm.ckpt
        """
        self.paths.update(paths)

    def add_custom_hooks(self, custom_hooks):
        """Update the custom hooks.

        When loading parameters, hooks here are preferred over class defaults.

        Arguments
        ---------
        custom_hooks : mapping
            Mapping from loadable key to parameter transfer hook function. If
            you want to use a custom loading function, specify it here.
        """
        self.custom_hooks.update(custom_hooks)

    def add_conditions(self, conditions):
        """Update the conditions.

        Arguments
        ---------
        conditions : mapping
            Mapping from loadable keys to condition values,
            useful for loading certain elements only if a flag is turned on
        """
        self.conditions.update(conditions)

    @staticmethod
    def split_path(path):
        """Splits a path to source and filename

        This also handles URLs and Huggingface hub paths, in addition to
        regular paths.

        Arguments
        ---------
        path : str or FetchSource
            The path to split. A `FetchSource` is unpacked into its two
            fields and only its path part is split.

        Returns
        -------
        str or FetchSource
            Source: everything before the last "/", or "./" when the path
            contains no "/". For a `FetchSource` input, the source is wrapped
            in a new `FetchSource` that keeps the original first field.
        str
            Filename: everything after the last "/".
        """

        def split(src):
            """Core function to split path."""
            if "/" in src:
                return src.rsplit("/", maxsplit=1)
            else:
                # Interpret as path to file in current directory.
                return "./", src

        if isinstance(path, FetchSource):
            fetch_from, fetch_path = path
            source, filename = split(fetch_path)
            return FetchSource(fetch_from, source), filename
        else:
            return split(path)

    def collect_files(
        self,
        default_source=None,
        use_auth_token=False,
        local_strategy: LocalStrategy = LocalStrategy.SYMLINK,
    ):
        """Fetches parameters from known paths with fallback default_source

        The actual parameter files may reside elsewhere, but this ensures a
        symlink in the self.collect_in directory. The symlink always uses the
        loadable key in the filename. This standardization makes it easier to
        orchestrate pretraining on e.g. distributed setups.

        Use the default_source if you have everything organized neatly into one
        location, like a Huggingface hub repo.

        Loadables whose condition evaluates to false (see `is_loadable`) are
        skipped. For every collected loadable, `self.paths[key]` is replaced
        by the fetched path and the key is recorded in `self.is_local`.

        Arguments
        ---------
        default_source : str or Path or FetchSource
            This is used for each loadable which doesn't have a path already
            specified.
            e.g. if the loadable has key `"asr"`, then the file to look for is
            `<default_source>/asr.ckpt`
        use_auth_token : bool (default: False)
            If true Huggingface's auth_token will be used to load private
            models from the HuggingFace Hub, default is False because the
            majority of models are public.
        local_strategy : speechbrain.utils.fetching.LocalStrategy
            The fetching strategy to use, which controls the behavior of remote
            file fetching with regards to symlinking and copying.
            Ignored if a `collect_in` directory was not specified.
            See :func:`speechbrain.utils.fetching.fetch` for further details.

        Returns
        -------
        dict
            Mapping from loadable key to a local path from which loadable's
            parameters can be loaded. This is not used in this class, but
            can possibly be helpful.

        Raises
        ------
        ValueError
            If a loadable has no entry in `self.paths` and no
            `default_source` was given.
        """
        if self.collect_in is not None:
            logger.debug(
                f"Collecting files (or symlinks) for pretraining in {self.collect_in}."
            )
            # parents=True: a nested collection directory must not fail just
            # because its ancestors have not been created yet.
            self.collect_in.mkdir(parents=True, exist_ok=True)

            if (
                platform.system() == "Windows"
                and local_strategy == LocalStrategy.SYMLINK
            ):
                warnings.warn(
                    "Requested Pretrainer collection using symlinks on Windows. "
                    "This might not work; see `LocalStrategy` documentation. "
                    "Consider unsetting `collect_in` in Pretrainer to avoid "
                    "symlinking altogether."
                )
        else:
            logger.debug(
                "Fetching files for pretraining (no collection directory set)"
            )

        loadable_paths = {}
        for name in self.loadables:
            if not self.is_loadable(name):
                continue

            save_filename = name + PARAMFILE_EXT
            if name in self.paths:
                source, filename = self.split_path(self.paths[name])
            elif default_source is not None:
                filename = save_filename
                source = default_source
            else:
                raise ValueError(
                    f"Path not specified for '{name}', "
                    "and no default_source given!"
                )

            fetch_kwargs = {
                "filename": filename,
                "source": source,
                "savedir": self.collect_in,
                "overwrite": False,
                "save_filename": save_filename,
                "use_auth_token": use_auth_token,
                "revision": None,
                "local_strategy": local_strategy,
            }

            path = None

            def run_fetch(**kwargs):
                """Very basic local wrapper to fetch to store the path in a
                local of collect_files

                Arguments
                ---------
                **kwargs : dict
                    Arguments to forward to fetch
                """
                nonlocal path
                path = fetch(**kwargs)

            # run fetch() on the main process, potentially performing downloading
            # which we do NOT want to happen concurrently.
            #
            # then, if there are any non-main processes, run fetch() on them to
            # resolve the path.
            #
            # path needs to be available only if it is a local source w/o symlink
            run_on_main(
                run_fetch,
                kwargs=fetch_kwargs,
                post_func=run_fetch,
                post_kwargs=fetch_kwargs,
            )

            loadable_paths[name] = path

            logger.debug(f'Set local path in self.paths["{name}"] = {path}')
            self.paths[name] = str(path)
            # Record each key once, even if collect_files is called repeatedly.
            if name not in self.is_local:
                self.is_local.append(name)

        return loadable_paths

    def is_loadable(self, name):
        """Returns True if no condition is defined for the specified
        loadable, or if its condition holds.

        A callable condition is called without arguments and its result is
        returned as is; any other condition is converted with `bool`.

        Arguments
        ---------
        name : str
            the name of the loadable

        Returns
        -------
        is_loadable : bool
            whether the item should be loaded
        """
        if name not in self.conditions:
            return True
        condition = self.conditions[name]
        if callable(condition):
            return condition()
        else:
            return bool(condition)

    def load_collected(self):
        """Loads the files that have been collected.

        For each loadable that passes `is_loadable`, the parameter file is
        taken from `self.paths` when the key was collected by
        `collect_files`, otherwise from `<collect_in>/<key><PARAMFILE_EXT>`.
        The load hooks are then called on the resulting files.

        Raises
        ------
        ValueError
            If a loadable was never collected and no `collect_in` directory
            is set to fall back on.
        """
        logger.info(
            f"Loading pretrained files for: {', '.join(self.loadables)}"
        )
        paramfiles = {}
        for name in self.loadables:
            if not self.is_loadable(name):
                continue
            filename = name + PARAMFILE_EXT

            if name in self.is_local:
                logger.debug(
                    f"Redirecting (loading from local path): {name} -> {self.paths[name]}"
                )
                paramfiles[name] = self.paths[name]
            elif self.collect_in is not None:
                paramfiles[name] = self.collect_in / filename
            else:
                raise ValueError(
                    f"Pretrainer has never collected `{name}`, did you forget a "
                    "call to `collect_files`? Could not fall back to "
                    "`collect_in`, as it was not specified (default is no longer "
                    '"model_checkpoints").'
                )
        self._call_load_hooks(paramfiles)

    def _call_load_hooks(self, paramfiles):
        """Finds the correct hook for every loadable and calls it.

        Hook precedence per loadable: custom hook, then the default transfer
        hook, then the default load hook.

        Arguments
        ---------
        paramfiles : mapping
            Mapping from loadable key to the parameter file to load from.

        Raises
        ------
        RuntimeError
            If no custom hook and no registered default hook exists for a
            loadable object.
        """
        for name, obj in self.loadables.items():
            if not self.is_loadable(name):
                continue
            loadpath = paramfiles[name]

            # First see if object has custom load hook:
            if name in self.custom_hooks:
                self.custom_hooks[name](obj, loadpath)
                continue
            # Try the default transfer hook:
            default_hook = get_default_hook(obj, DEFAULT_TRANSFER_HOOKS)
            if default_hook is not None:
                default_hook(obj, loadpath)
                continue
            # Otherwise find the default loader for that type:
            default_hook = get_default_hook(obj, DEFAULT_LOAD_HOOKS)
            if default_hook is not None:
                # Need to fake end-of-epoch:
                end_of_epoch = False
                default_hook(obj, loadpath, end_of_epoch)
                continue
            # If we got here, no custom hook or registered default hook exists.
            # Implicit concatenation (rather than a backslash continuation
            # inside the literal) keeps source indentation out of the message.
            msg = (
                f"Don't know how to load {type(obj)}. Register default hook "
                "or add custom hook for this object."
            )
            raise RuntimeError(msg)