Source code for speechbrain.pretrained.training

"""
Training utilities for pretrained models

Authors
* Artem Ploujnikov 2021
"""
import os
import shutil


[docs]def save_for_pretrained( hparams, min_key=None, max_key=None, pretrainer_key="pretrainer", checkpointer_key="checkpointer", ): """ Saves the necessary files for the pretrained model from the best checkpoint found. The goal of this function is to export the model for a Pretrainer Arguments --------- hparams: dict the hyperparameter file max_key : str Key to use for finding best checkpoint (higher is better). By default, passed to ``self.checkpointer.recover_if_possible()``. min_key : str Key to use for finding best checkpoint (lower is better). By def ault, passed to ``self.checkpointer.recover_if_possible()``. checkpointer_key: str the key under which the checkpointer is stored pretrained_key: str the key under which the pretrainer is stored """ if any(key not in hparams for key in [pretrainer_key, checkpointer_key]): raise ValueError( f"Incompatible hparams: a checkpointer with key {checkpointer_key}" f"and a pretrainer with key {pretrainer_key} are required" ) pretrainer = hparams[pretrainer_key] checkpointer = hparams[checkpointer_key] checkpoint = checkpointer.find_checkpoint(min_key=min_key, max_key=max_key) pretrainer_keys = set(pretrainer.loadables.keys()) checkpointer_keys = set(checkpoint.paramfiles.keys()) keys_to_save = pretrainer_keys & checkpointer_keys for key in keys_to_save: source_path = checkpoint.paramfiles[key] if not os.path.exists(source_path): raise ValueError( f"File {source_path} does not exist in the checkpoint" ) target_path = pretrainer.paths[key] dirname = os.path.dirname(target_path) if not os.path.exists(dirname): os.makedirs(dirname) if os.path.exists(target_path): os.remove(target_path) shutil.copyfile(source_path, target_path)