speechbrain.utils.checkpoints module

This module implements a checkpoint saver and loader.

A checkpoint in an experiment usually needs to save the state of many different things: the model parameters, optimizer parameters, which epoch this is, etc. The save format for a checkpoint is a directory, where each of these separate saveable things gets its own file. Additionally, a special file holds meta information about the checkpoint (by default just the time of creation, but you can specify anything else you may wish, e.g. validation loss).

The interface for the checkpoint system requires you to specify what things to save. This approach is flexible and agnostic of how your experiment is actually run.

The interface requires you to specify names for each thing to save. This name is used to give the right parameter file to the right object when recovering.

Default saving and loading methods are only added for torch.nn.Modules (and their subclasses), and torch.optim.Optimizers. If those methods do not work for your object, you can specify your own saving and/or loading methods, either for a particular instance or for a class.

Example

>>> # Toy example Module:
>>> class Recoverable(torch.nn.Module):
...     def __init__(self, param):
...         super().__init__()
...         self.param = torch.nn.Parameter(torch.tensor([param]))
...     def forward(self, x):
...         return x * self.param
>>> model = Recoverable(1.)
>>> tempdir = getfixture('tmpdir')
>>> # In simple cases, the module aims to have a terse syntax,
>>> # consisting of three steps.
>>> # 1. Specifying where to save checkpoints and what is included in a
>>> # checkpoint:
>>> checkpointer = Checkpointer(tempdir, {"network": model})
>>> # 2. Recover from the latest checkpoint, if one is found:
>>> checkpointer.recover_if_possible()
>>> # Run your experiment:
>>> data = [(0.1, 0.9), (0.3, 0.8)]
>>> for example, target in data:
...     loss = (model(example) - target)**2
...     # 3. Save checkpoints, and keep by default just one, the newest:
...     ckpt = checkpointer.save_and_keep_only()

Authors
  • Aku Rouhe 2020

Summary

Classes:

Checkpoint

NamedTuple describing one saved checkpoint

Checkpointer

Saves checkpoints and recovers from them.

Functions:

average_checkpoints

Average parameters from multiple checkpoints.

average_state_dicts

Produces an average state_dict from an iterator over state_dicts.

ckpt_recency

Recency as Checkpoint importance metric.

get_default_hook

Finds the default save/load hook to use with the given object.

mark_as_loader

Method decorator which marks given method as checkpoint loading hook.

mark_as_saver

Method decorator which marks given method as the checkpoint saving hook.

mark_as_transfer

Method decorator which marks given method as a parameter transfer hook.

register_checkpoint_hooks

Class decorator which registers the load, save and transfer hooks.

torch_parameter_transfer

Non-strict Torch Module state_dict load.

torch_recovery

Loads a torch.nn.Module state_dict from the given path instantly.

torch_save

Saves the obj’s parameters to path.

Reference

speechbrain.utils.checkpoints.torch_recovery(obj, path, end_of_epoch, device=None)[source]

Loads a torch.nn.Module state_dict from the given path instantly.

This can be made the default for torch.nn.Modules with:

>>> DEFAULT_LOAD_HOOKS[torch.nn.Module] = torch_recovery

Parameters
  • obj (torch.nn.Module) – Instance for which to load the parameters.

  • path (str, pathlib.Path) – Path where to load from.

  • end_of_epoch (bool) – Whether the recovery comes from an end of epoch checkpoint.

  • device (str) – Torch device, where to map the loaded parameters.

Returns

Given object is modified in place.

Return type

None
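
A minimal sketch of calling these hooks directly, outside a Checkpointer (the filename and the use of the tmpdir fixture are illustrative, mirroring the examples elsewhere on this page):

>>> import os, torch
>>> tempdir = getfixture('tmpdir')
>>> lin = torch.nn.Linear(4, 4)
>>> save_path = os.path.join(str(tempdir), "lin.ckpt")
>>> torch_save(lin, save_path)
>>> torch_recovery(lin, save_path, end_of_epoch=False)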

speechbrain.utils.checkpoints.torch_save(obj, path)[source]

Saves the obj’s parameters to path.

Default save hook for torch.nn.Modules, for saving torch.nn.Module state_dicts.

Parameters
  • obj (torch.nn.Module) – Instance whose parameters to save.

  • path (str, pathlib.Path) – Path where to save to.

Returns

State dict is written to disk.

Return type

None

speechbrain.utils.checkpoints.torch_parameter_transfer(obj, path, device)[source]

Non-strict Torch Module state_dict load.

Loads a set of parameters from path to obj. If obj has layers for which parameters can't be found, only a warning is logged. The same applies if the path has parameters for layers that have no counterpart in obj.

Parameters
  • obj (torch.nn.Module) – Instance for which to load the parameters.

  • path (str) – Path where to load from.

  • device (str) – Torch device, where to map the loaded parameters.

Returns

The object is modified in place.

Return type

None
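
A brief hedged sketch of the transfer hook. The two modules here have identical shapes, so every parameter finds a counterpart; with mismatched layers, only warnings would be logged:

>>> import os, torch
>>> tempdir = getfixture('tmpdir')
>>> src = torch.nn.Linear(4, 4)
>>> tgt = torch.nn.Linear(4, 4)
>>> transfer_path = os.path.join(str(tempdir), "src.ckpt")
>>> torch_save(src, transfer_path)
>>> torch_parameter_transfer(tgt, transfer_path, device="cpu")
>>> torch.equal(src.weight, tgt.weight)
True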

speechbrain.utils.checkpoints.mark_as_saver(method)[source]

Method decorator which marks given method as the checkpoint saving hook.

See register_checkpoint_hooks for example.

Parameters

method (callable) – Method of the class to decorate. Must be callable with signature (instance, path) using positional arguments. This is satisfied by, for example: def saver(self, path):

Note

This will not add the hook (that is not possible via a method decorator); you must also decorate the class with @register_checkpoint_hooks. Only one method can be added as the hook.

speechbrain.utils.checkpoints.mark_as_loader(method)[source]

Method decorator which marks given method as checkpoint loading hook.

Parameters

method (callable) – Method of the class to decorate. Must be callable with signature (instance, path, end_of_epoch, device) using positional arguments. This is satisfied by, for example: def loader(self, path, end_of_epoch, device):

Note

This will not add the hook (that is not possible via a method decorator); you must also decorate the class with @register_checkpoint_hooks. Only one method can be added as the hook.

speechbrain.utils.checkpoints.mark_as_transfer(method)[source]

Method decorator which marks given method as a parameter transfer hook.

Parameters

method (callable) – Method of the class to decorate. Must be callable with signature (instance, path, device) using positional arguments. This is satisfied by, for example: def transfer(self, path, device):

Note

This will not add the hook (that is not possible via a method decorator); you must also decorate the class with @register_checkpoint_hooks. Only one method can be added as the hook.

Note

The transfer hook is prioritized over the loader hook by the Pretrainer. However, if no transfer hook is registered, the Pretrainer will use the loader hook.

speechbrain.utils.checkpoints.register_checkpoint_hooks(cls)[source]

Class decorator which registers the load, save and transfer hooks.

The hooks must have been marked with mark_as_loader and mark_as_saver, and possibly mark_as_transfer.

Parameters

cls (class) – Class to decorate

Example

>>> @register_checkpoint_hooks
... class CustomRecoverable:
...     def __init__(self, param):
...         self.param = int(param)
...
...     @mark_as_saver
...     def save(self, path):
...         with open(path, "w") as fo:
...             fo.write(str(self.param))
...
...     @mark_as_loader
...     def load(self, path, end_of_epoch, device=None):
...         del end_of_epoch  # Unused here
...         with open(path) as fi:
...             self.param = int(fi.read())
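
A hedged follow-up sketch: once decorated, the class can be used with a Checkpointer like any built-in recoverable (the names custom and ckptr are illustrative):

>>> custom = CustomRecoverable(5)
>>> ckptr = Checkpointer(getfixture('tmpdir'), {"custom": custom})
>>> _ = ckptr.save_checkpoint()
>>> custom.param = 0
>>> _ = ckptr.recover_if_possible()
>>> custom.param
5
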
speechbrain.utils.checkpoints.get_default_hook(obj, default_hooks)[source]

Finds the default save/load hook to use with the given object.

Follows the Method Resolution Order, i.e., if no hook is registered for the class of the object itself, it also searches the classes which the object inherits from.

Parameters
  • obj (instance) – Instance of a class.

  • default_hooks (dict) – Mapping from classes to (checkpointing hook) functions.

Returns

The correct method, or None if no method is registered.

Return type

callable or None

Example

>>> a = torch.nn.Module()
>>> get_default_hook(a, DEFAULT_SAVE_HOOKS) == torch_save
True
class speechbrain.utils.checkpoints.Checkpoint(path, meta, paramfiles)

Bases: tuple

NamedTuple describing one saved checkpoint

To select a checkpoint to load from among many checkpoints, Checkpoints are first filtered and sorted based on this namedtuple. Checkpointers put a pathlib.Path in path and a dict in meta. You can essentially add any info you want to meta when saving a checkpoint. The only default key in meta is “unixtime”. Checkpoint.paramfiles is a dict from recoverable name to parameter filepath.

meta

Alias for field number 1

paramfiles

Alias for field number 2

path

Alias for field number 0

speechbrain.utils.checkpoints.ckpt_recency(ckpt)[source]

Recency as Checkpoint importance metric.

This function can also act as an example of how to make checkpoint importance keyfuncs. This is a named function, but as you can see it could be easily implemented as a lambda in a pinch.
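
For instance, an equivalent keyfunc can be written inline, and a custom keyfunc can read any value you stored in the meta (a sketch; the "valid_loss" key below is hypothetical and would have to be added at save time):

>>> # Equivalent in spirit to ckpt_recency:
>>> recency = lambda ckpt: ckpt.meta["unixtime"]
>>> # A hypothetical keyfunc favoring low validation loss:
>>> prefer_low_loss = lambda ckpt: -ckpt.meta["valid_loss"]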

class speechbrain.utils.checkpoints.Checkpointer(checkpoints_dir, recoverables=None, custom_load_hooks=None, custom_save_hooks=None, allow_partial_load=False)[source]

Bases: object

Saves checkpoints and recovers from them.

Parameters
  • checkpoints_dir (str, pathlib.Path) – Path to directory where to save checkpoints.

  • recoverables (mapping, optional) – Objects to recover. They need a (unique) name: this is used to connect the parameters in a checkpoint to the correct recoverable. The name is also used in the filename of the savefile for the object's parameters. These can also be added with add_recoverable or add_recoverables, or just by modifying checkpointer.recoverables directly.

  • custom_load_hooks (mapping, optional) – A mapping from name [same as in recoverables] to function or method. Sets a custom loading hook for a particular object. The function/method must be callable with signature (instance, path) using positional arguments. This is satisfied by, for example: def loader(self, path):

  • custom_save_hooks (mapping, optional) – Mapping from name [same as in recoverables] to function or method. Sets a custom saving hook for a particular object. The function/method must be callable with signature (instance, path) using positional arguments. This is satisfied by, for example: def saver(self, path):

  • allow_partial_load (bool, optional) – If True, allows loading a checkpoint where a savefile is not found for every registered recoverable. In that case, only the found savefiles are loaded. When False, loading such a save will raise RuntimeError. (default: False)

Example

>>> import torch
>>> #SETUP:
>>> tempdir = getfixture('tmpdir')
>>> class Recoverable(torch.nn.Module):
...     def __init__(self, param):
...         super().__init__()
...         self.param = torch.nn.Parameter(torch.tensor([param]))
...     def forward(self, x):
...         return x * self.param
>>> recoverable = Recoverable(1.)
>>> recoverables = {'recoverable': recoverable}
>>> # SETUP DONE.
>>> checkpointer = Checkpointer(tempdir, recoverables)
>>> first_ckpt = checkpointer.save_checkpoint()
>>> recoverable.param.data = torch.tensor([2.])
>>> loaded_ckpt = checkpointer.recover_if_possible()
>>> # Parameter has been loaded:
>>> assert recoverable.param.data == torch.tensor([1.])
>>> # With this call, by default, oldest checkpoints are deleted:
>>> checkpointer.save_and_keep_only()
>>> assert first_ckpt not in checkpointer.list_checkpoints()
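
A hedged continuation of the example above: the meta mapping can carry any value you choose (the "error" key here is hypothetical), and that value can later be used to pick a checkpoint:

>>> _ = checkpointer.save_checkpoint(meta={"error": 0.2})
>>> _ = checkpointer.save_checkpoint(meta={"error": 0.1})
>>> best = checkpointer.find_checkpoint(min_key="error")
>>> best.meta["error"]
0.1
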
add_recoverable(name, obj, custom_load_hook=None, custom_save_hook=None)[source]

Register a recoverable with possible custom hooks.

Parameters
  • name (str) – Unique name for recoverable. Used to map savefiles to objects.

  • obj (instance) – The object to recover.

  • custom_load_hook (callable) – Called to load the object’s savefile. The function/method must be callable with signature (instance, path) using positional arguments. This is satisfied by, for example: def load(self, path):

  • custom_save_hook (callable) – Called to save the object’s parameters. The function/method must be callable with signature (instance, path) using positional arguments. This is satisfied by, for example: def saver(self, path):
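
A small self-contained sketch (the names are illustrative): registering an optimizer after the Checkpointer has been created, so it is saved and recovered alongside the model:

>>> import torch
>>> model = torch.nn.Linear(3, 3)
>>> checkpointer = Checkpointer(getfixture('tmpdir'), {"model": model})
>>> opt = torch.optim.SGD(model.parameters(), lr=0.1)
>>> checkpointer.add_recoverable("opt", opt)
>>> _ = checkpointer.save_checkpoint()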

add_recoverables(recoverables)[source]

Update the recoverables dict from the given mapping.

Parameters

recoverables (mapping) – Objects to recover. They need a (unique) name: this is used to connect the parameters in a checkpoint to the correct recoverable. The name is also used in the filename of the savefile for the object's parameters.

save_checkpoint(meta={}, end_of_epoch=True, name=None, verbosity=20)[source]

Saves a checkpoint.

The whole checkpoint becomes a directory. Saves each registered object’s parameters in a separate file. Also a meta file is added. The meta file by default has just the unixtime (seconds since unix epoch), but you can add anything relevant yourself. The meta information is later used to pick the checkpoint to load.

The value of end_of_epoch is saved in the meta. This can affect how epoch counters and dataset iterators load their state.

Parameters
  • meta (mapping, optional) – A mapping which is added to the meta file in the checkpoint. The key “unixtime” is included by default.

  • end_of_epoch (bool, optional) – Whether the checkpoint is at the end of an epoch. True by default. May affect loading.

  • name (str, optional) – Specify a custom name for your checkpoint. The name will still have a prefix added. If no name is given, a name is created from a timestamp and a random unique id.

  • verbosity (logging level) – Set the logging level for this save.

Returns

namedtuple [see above], the saved checkpoint.

Return type

Checkpoint

save_and_keep_only(meta={}, end_of_epoch=True, name=None, num_to_keep=1, keep_recent=True, importance_keys=[], max_keys=[], min_keys=[], ckpt_predicate=None, verbosity=20)[source]

Saves a checkpoint, then deletes the least important checkpoints.

Essentially this combines save_checkpoint() and delete_checkpoints() in one call, providing short syntax.

Parameters
  • meta (mapping, optional) – A mapping which is added to the meta file in the checkpoint. The key “unixtime” is included by default.

  • end_of_epoch (bool, optional) – Whether the checkpoint is at the end of an epoch. True by default. May affect loading.

  • name (str, optional) – Specify a custom name for your checkpoint. The name will still have a prefix added. If no name is given, a name is created from a timestamp and a random unique id.

  • num_to_keep (int, optional) – Number of checkpoints to keep. Defaults to 1. You can choose to keep 0; this deletes all checkpoints remaining after filtering. Must be >= 0.

  • keep_recent (bool, optional) – Whether to keep the most recent num_to_keep checkpoints.

  • importance_keys (list, optional) – A list of key functions used in sorting (see the sorted built-in). Each callable defines a sort order and num_to_keep checkpoints are kept for each callable. The checkpoints with the highest keys are kept. The functions are passed Checkpoint namedtuples (see above).

  • max_keys (list, optional) – A list of keys for which the highest value will be kept.

  • min_keys (list, optional) – A list of keys for which the lowest value will be kept.

  • ckpt_predicate (callable, optional) – Use this to exclude some checkpoints from deletion. Before any sorting, the list of checkpoints is filtered with this predicate. Only the checkpoints for which ckpt_predicate is True can be deleted. The function is called with Checkpoint namedtuples (see above).

Returns

Unlike save_checkpoint, this does not return anything, since we cannot guarantee that the saved checkpoint actually survives deletion.

Return type

None
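
A hedged sketch, assuming a checkpointer set up as in the class example and a hypothetical "loss" value computed by your training loop:

>>> # Keep the most recent checkpoint and, per min_keys, the one with the
>>> # lowest "loss" meta value; delete the rest.
>>> checkpointer.save_and_keep_only(meta={"loss": 0.3}, min_keys=["loss"])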

find_checkpoint(importance_key=None, max_key=None, min_key=None, ckpt_predicate=None)[source]

Picks a particular checkpoint from all available checkpoints.

If none of importance_key, max_key, and min_key is used, then the most recent checkpoint will be returned. No more than one of them may be used.

Most functionality is actually implemented in find_checkpoints() but this is kept as a useful interface.

Parameters
  • importance_key (callable, optional) – The key function used in sorting. The checkpoint with the highest returned value is picked. The function is called with Checkpoint namedtuples.

  • max_key (str, optional) – The checkpoint with the highest value for this key will be returned. Only checkpoints with this key will be considered!

  • min_key (str, optional) – The checkpoint with the lowest value for this key will be returned. Only checkpoints with this key will be considered!

  • ckpt_predicate (callable, optional) – Before sorting, the list of checkpoints is filtered with this predicate. See the filter builtin. The function is called with Checkpoint namedtuples (see above). By default, all checkpoints are considered.

Returns

  • Checkpoint – If found.

  • None – If no Checkpoints exist/remain after filtering.

find_checkpoints(importance_key=None, max_key=None, min_key=None, ckpt_predicate=None, max_num_checkpoints=None)[source]

Picks multiple checkpoints.

If none of importance_key, max_key, and min_key is used, then the most recent checkpoints will be returned. No more than one of these may be used.

Parameters
  • importance_key (callable, optional) – The key function used in sorting. The checkpoint with the highest returned value is picked. The function is called with Checkpoint namedtuples.

  • max_key (str, optional) – The checkpoint with the highest value for this key will be returned. Only checkpoints with this key will be considered!

  • min_key (str, optional) – The checkpoint with the lowest value for this key will be returned. Only checkpoints with this key will be considered!

  • ckpt_predicate (callable, optional) – Before sorting, the list of checkpoints is filtered with this predicate. See the filter builtin. The function is called with Checkpoint namedtuples (see above). By default, all checkpoints are considered.

  • max_num_checkpoints (int, None) – The maximum number of checkpoints to return, or None to return all found checkpoints.

Returns

List containing at most the max specified number of Checkpoints.

Return type

list
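
A hedged sketch, continuing the hypothetical "error" meta key used earlier:

>>> # At most the two checkpoints with the lowest "error" values:
>>> best_two = checkpointer.find_checkpoints(min_key="error", max_num_checkpoints=2)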

recover_if_possible(importance_key=None, max_key=None, min_key=None, ckpt_predicate=None, device=None)[source]

Picks a checkpoint and recovers from that, if one is found.

If a checkpoint is not found, no recovery is run.

If none of importance_key, max_key, and min_key is used, then the most recent checkpoint will be loaded. No more than one of them may be used.

Parameters
  • importance_key (callable, optional) – The key function used in sorting. The checkpoint with the highest returned value is loaded. The function is called with Checkpoint namedtuples.

  • max_key (str, optional) – The checkpoint with the highest value for this key will be loaded. Only checkpoints with this key will be considered!

  • min_key (str, optional) – The checkpoint with the lowest value for this key will be loaded. Only checkpoints with this key will be considered!

  • ckpt_predicate (callable, optional) – Before sorting, the list of checkpoints is filtered with this predicate. See the filter builtin. The function is called with Checkpoint namedtuples (see above). By default, all checkpoints are considered.

  • device (torch.device) – Device to load models to.

Returns

  • Checkpoint – If found.

  • None – If no Checkpoints exist/remain after filtering.
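
A hedged sketch, again assuming an "error" value was stored in the meta at save time:

>>> # Recover from the checkpoint with the lowest "error", mapping parameters to CPU:
>>> _ = checkpointer.recover_if_possible(min_key="error", device=torch.device("cpu"))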

load_checkpoint(checkpoint, device=None)[source]

Loads the specified checkpoint.

Parameters

  • checkpoint (Checkpoint) – Checkpoint to load.

  • device (torch.device, optional) – Device to map the loaded parameters to.

list_checkpoints()[source]

List all checkpoints in the checkpoints directory.

Returns

List of Checkpoint namedtuples (see above).

Return type

list

delete_checkpoints(*, num_to_keep=1, min_keys=None, max_keys=None, importance_keys=[ckpt_recency], ckpt_predicate=None, verbosity=20)[source]

Deletes least important checkpoints.

Since there can be many ways to define importance (e.g. lowest WER, lowest loss), the user should provide a list of sort key functions, each defining a particular importance order. In essence, each importance key function extracts one importance metric (higher is more important). For each of these orders, num_to_keep checkpoints are kept. However, if there is overlap between the orders' preserved checkpoints, the additional checkpoints are not preserved, so the total number of preserved checkpoints can be less than:

num_to_keep * len(importance_keys)
Parameters
  • num_to_keep (int, optional) – Number of checkpoints to keep. Defaults to 1. You can choose to keep 0; this deletes all checkpoints remaining after filtering. Must be >= 0.

  • min_keys (list, optional) – List of strings representing keys in the meta. The lowest of these values will be kept, up to num_to_keep.

  • max_keys (list, optional) – List of strings representing keys in the meta. The highest of these values will be kept, up to num_to_keep.

  • importance_keys (list, optional) – A list of key functions used in sorting (see the sorted built-in). Each callable defines a sort order and num_to_keep checkpoints are kept for each callable. To be clear, those with the highest key are kept. The functions are called with Checkpoint namedtuples (see above). See also the default (ckpt_recency, above). The default deletes all but the latest checkpoint.

  • ckpt_predicate (callable, optional) – Use this to exclude some checkpoints from deletion. Before any sorting, the list of checkpoints is filtered with this predicate. Only the checkpoints for which ckpt_predicate is True can be deleted. The function is called with Checkpoint namedtuples (see above).

  • verbosity (logging level) – Set logging level for this deletion.

Note

Must be called with keyword arguments, as a signoff that you know what you are doing. Deletion is permanent.
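
A hedged sketch of a keyword-only call, assuming the hypothetical "error" meta key from earlier:

>>> # Keep up to num_to_keep=2 checkpoints per importance criterion; "error" is
>>> # assumed to have been stored in the meta when saving.
>>> checkpointer.delete_checkpoints(num_to_keep=2, min_keys=["error"])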

speechbrain.utils.checkpoints.average_state_dicts(state_dicts)[source]

Produces an average state_dict from an iterator over state_dicts.

Note that this keeps two of the state_dicts in memory at any one time, which is the minimum memory requirement.

Parameters

state_dicts (iterator, list) – The state_dicts to average.

Returns

The averaged state_dict.

Return type

state_dict
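
A tiny hedged sketch on toy state_dicts:

>>> import torch
>>> sd1 = {"w": torch.tensor([1.0])}
>>> sd2 = {"w": torch.tensor([3.0])}
>>> avg = average_state_dicts([sd1, sd2])
>>> avg["w"]
tensor([2.])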

speechbrain.utils.checkpoints.average_checkpoints(checkpoint_list, recoverable_name, parameter_loader=torch.load, averager=average_state_dicts, device=None)[source]

Average parameters from multiple checkpoints.

Use Checkpointer.find_checkpoints() to get the list of checkpoints to average over. Averaging parameters from some of the last checkpoints in training has been shown to sometimes improve performance.

The default loader and averager work for standard PyTorch modules.

Parameters
  • checkpoint_list (list) – List of checkpoints to average.

  • recoverable_name (str) – The name of the recoverable, the parameters of which are loaded and averaged.

  • parameter_loader (function) – A function which takes a single argument, the path to a parameter file, and loads the parameters from that file. By default, torch.load, which produces state_dict dictionaries.

  • averager (function) – A function which takes an iterator over the parameters from each checkpoint, as loaded by parameter_loader, and produces their average. Note that the function is called with an iterator, so the length is initially unknown; the implementation should simply count the number of different parameter sets as they are yielded. See average_state_dicts above for an example; it is the default averager, and it averages state_dicts.

  • device (torch.device, optional) – Device to map the loaded parameters to.

Returns

The output of the averager function.

Return type

Any

Example

>>> # Consider this toy Module again:
>>> class Recoverable(torch.nn.Module):
...     def __init__(self, param):
...         super().__init__()
...         self.param = torch.nn.Parameter(torch.tensor([param]))
...     def forward(self, x):
...         return x * self.param
>>> # Now let's make some checkpoints:
>>> model = Recoverable(1.)
>>> tempdir = getfixture('tmpdir')
>>> checkpointer = Checkpointer(tempdir, {"model": model})
>>> for new_param in range(10):
...     model.param.data = torch.tensor([float(new_param)])
...     _ = checkpointer.save_checkpoint()  # Suppress output with assignment
>>> # Let's average the 3 latest checkpoints
>>> # (parameter values 7, 8, 9 -> avg=8)
>>> ckpt_list = checkpointer.find_checkpoints(max_num_checkpoints = 3)
>>> averaged_state = average_checkpoints(ckpt_list, "model")
>>> # Now load that state in the normal way:
>>> _ = model.load_state_dict(averaged_state)  # Suppress output
>>> model.param.data
tensor([8.])