Source code for speechbrain.utils.pretrained

"""
Training utilities for pretrained models

Authors
* Artem Ploujnikov 2021
"""

import os
import shutil

from speechbrain.utils.logger import get_logger

logger = get_logger(__name__)


def save_for_pretrained(
    hparams,
    min_key=None,
    max_key=None,
    ckpt_predicate=None,
    pretrainer_key="pretrainer",
    checkpointer_key="checkpointer",
):
    """
    Saves the necessary files for the pretrained model from the best
    checkpoint found. The goal of this function is to export the model
    for a Pretrainer.

    Only the parameter files whose keys appear both in the pretrainer's
    ``loadables`` and in the checkpoint's ``paramfiles`` are copied; each
    one is written to ``pretrainer.paths[key]``.

    Arguments
    ---------
    hparams: dict
        the hyperparameter file
    min_key: str
        Key to use for finding best checkpoint (lower is better).
        Passed to ``checkpointer.find_checkpoint()``.
    max_key: str
        Key to use for finding best checkpoint (higher is better).
        Passed to ``checkpointer.find_checkpoint()``.
    ckpt_predicate: callable
        a filter predicate to locate checkpoints
    pretrainer_key: str
        the key under which the pretrainer is stored
    checkpointer_key: str
        the key under which the checkpointer is stored

    Returns
    -------
    saved: bool
        Whether the save was successful (``False`` when no matching
        checkpoint was found)

    Raises
    ------
    ValueError
        If ``hparams`` lacks the pretrainer or the checkpointer, or if a
        parameter file listed in the checkpoint does not exist on disk.
    """
    if pretrainer_key not in hparams or checkpointer_key not in hparams:
        raise ValueError(
            f"Incompatible hparams: a checkpointer with key {checkpointer_key} "
            f"and a pretrainer with key {pretrainer_key} are required"
        )
    pretrainer = hparams[pretrainer_key]
    checkpointer = hparams[checkpointer_key]
    checkpoint = checkpointer.find_checkpoint(
        min_key=min_key, max_key=max_key, ckpt_predicate=ckpt_predicate
    )
    if not checkpoint:
        logger.info(
            "Unable to find a matching checkpoint for min_key = %s, max_key = %s",
            min_key,
            max_key,
        )
        available = checkpointer.list_checkpoints()
        available_str = "\n".join(
            f"{ckpt.path}: {ckpt.meta}" for ckpt in available
        )
        logger.info("Available checkpoints: %s", available_str)
        return False

    logger.info("Saving checkpoint '%s' as a pretrained model", checkpoint.path)
    # Only export what the pretrainer knows how to load AND the checkpoint has.
    keys_to_save = set(pretrainer.loadables.keys()) & set(
        checkpoint.paramfiles.keys()
    )
    for key in keys_to_save:
        _copy_checkpoint_file(checkpoint.paramfiles[key], pretrainer.paths[key])
    return True


def _copy_checkpoint_file(source_path, target_path):
    """Copies one checkpoint parameter file to its pretrainer location.

    Arguments
    ---------
    source_path: str
        path of the parameter file inside the checkpoint
    target_path: str
        destination path expected by the pretrainer

    Raises
    ------
    ValueError
        If ``source_path`` does not exist.
    """
    if not os.path.exists(source_path):
        raise ValueError(f"File {source_path} does not exist in the checkpoint")
    dirname = os.path.dirname(target_path)
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    # lexists (not exists) so that a broken symlink is also removed.
    if os.path.lexists(target_path):
        # If the target already *is* the source file (and not a symlink to
        # it), removing it would destroy the only copy of the parameters.
        if not os.path.islink(target_path) and os.path.samefile(
            source_path, target_path
        ):
            return
        # Remove rather than overwrite: open() for writing follows symlinks,
        # so copying onto a link would write into whatever it points at.
        os.remove(target_path)
    shutil.copyfile(source_path, target_path)