Source code for speechbrain.utils.train_logger

"""Loggers for experiment monitoring.

Authors
 * Peter Plantinga 2020
"""
import logging

logger = logging.getLogger(__name__)


class TrainLogger:
    """Abstract class defining an interface for training loggers."""

    def log_stats(
        self,
        stats_meta,
        train_stats=None,
        valid_stats=None,
        test_stats=None,
        verbose=False,
    ):
        """Log the stats for one epoch.

        Arguments
        ---------
        stats_meta : dict of str:scalar pairs
            Meta information about the stats (e.g., epoch, learning-rate, etc.).
        train_stats : dict of str:list pairs
            Each loss type is represented with a str : list pair including all
            the values for the training pass.
        valid_stats : dict of str:list pairs
            Each loss type is represented with a str : list pair including all
            the values for the validation pass.
        test_stats : dict of str:list pairs
            Each loss type is represented with a str : list pair including all
            the values for the test pass.
        verbose : bool
            Whether to also put logging information to the standard logger.
        """
        raise NotImplementedError
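

# A minimal sketch (not part of the library) of a custom logger built on the
# interface above; the name PrintTrainLogger is hypothetical. It shows that a
# subclass only needs to implement log_stats with the signature defined by
# TrainLogger.
class PrintTrainLogger(TrainLogger):
    """Toy logger that prints each epoch's stats to stdout."""

    def log_stats(
        self,
        stats_meta,
        train_stats=None,
        valid_stats=None,
        test_stats=None,
        verbose=False,
    ):
        """See TrainLogger.log_stats()"""
        print(stats_meta, train_stats, valid_stats, test_stats)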


class FileTrainLogger(TrainLogger):
    """Text logger of training information.

    Arguments
    ---------
    save_file : str
        The file to use for logging train information.
    precision : int
        Number of decimal places to display. Default 2, e.g., 1.35e-05.
    """

    def __init__(self, save_file, precision=2):
        self.save_file = save_file
        self.precision = precision

    def _item_to_string(self, key, value, dataset=None):
        """Convert one item to string, handling floats"""
        if isinstance(value, float) and 1.0 < value < 100.0:
            value = f"{value:.{self.precision}f}"
        elif isinstance(value, float):
            value = f"{value:.{self.precision}e}"
        if dataset is not None:
            key = f"{dataset} {key}"
        return f"{key}: {value}"

    def _stats_to_string(self, stats, dataset=None):
        """Convert all stats to a single string summary"""
        return ", ".join(
            [self._item_to_string(k, v, dataset) for k, v in stats.items()]
        )

    def log_stats(
        self,
        stats_meta,
        train_stats=None,
        valid_stats=None,
        test_stats=None,
        verbose=True,
    ):
        """See TrainLogger.log_stats()"""
        string_summary = self._stats_to_string(stats_meta)
        for dataset, stats in [
            ("train", train_stats),
            ("valid", valid_stats),
            ("test", test_stats),
        ]:
            if stats is not None:
                string_summary += " - " + self._stats_to_string(stats, dataset)
        with open(self.save_file, "a") as fout:
            print(string_summary, file=fout)
        if verbose:
            logger.info(string_summary)
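

# Usage sketch (not part of the library): the helper below shows one call to
# FileTrainLogger. The path "train_log.txt" and the stat values are arbitrary
# choices for this example. With the default precision of 2, the call appends
# the line "epoch: 1, lr: 1.00e-03 - train loss: 3.35 - valid loss: 2.11".
def _example_file_logging():
    file_logger = FileTrainLogger(save_file="train_log.txt")
    file_logger.log_stats(
        stats_meta={"epoch": 1, "lr": 1e-3},
        train_stats={"loss": 3.35},
        valid_stats={"loss": 2.11},
    )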


class TensorboardLogger(TrainLogger):
    """Logs training information in the format required by TensorBoard.

    Arguments
    ---------
    save_dir : str
        A directory for storing all the relevant logs.

    Raises
    ------
    ImportError if TensorBoard is not installed.
    """

    def __init__(self, save_dir):
        self.save_dir = save_dir

        # Raises ImportError if TensorBoard is not installed
        from torch.utils.tensorboard import SummaryWriter

        self.writer = SummaryWriter(self.save_dir)
        self.global_step = {"train": {}, "valid": {}, "test": {}, "meta": 0}

    def log_stats(
        self,
        stats_meta,
        train_stats=None,
        valid_stats=None,
        test_stats=None,
        verbose=False,
    ):
        """See TrainLogger.log_stats()"""
        self.global_step["meta"] += 1
        for name, value in stats_meta.items():
            self.writer.add_scalar(name, value, self.global_step["meta"])

        for dataset, stats in [
            ("train", train_stats),
            ("valid", valid_stats),
            ("test", test_stats),
        ]:
            if stats is None:
                continue
            for stat, value_list in stats.items():
                if stat not in self.global_step[dataset]:
                    self.global_step[dataset][stat] = 0
                tag = f"{stat}/{dataset}"

                # Both single-value (per-epoch) and list (per-batch) logging
                # are supported; wrap a scalar in a list so one loop handles
                # both cases.
                if not isinstance(value_list, list):
                    value_list = [value_list]
                for value in value_list:
                    new_global_step = self.global_step[dataset][stat] + 1
                    self.writer.add_scalar(tag, value, new_global_step)
                    self.global_step[dataset][stat] = new_global_step
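

# Usage sketch (not part of the library): assumes torch and tensorboard are
# installed. Per-batch losses are passed as lists, so each value advances the
# global step for its tag ("loss/train" here gets steps 1-3); the directory
# name "tb_logs" is arbitrary, and the resulting curves can be inspected with
# `tensorboard --logdir tb_logs`.
def _example_tensorboard_logging():
    tb_logger = TensorboardLogger(save_dir="tb_logs")
    tb_logger.log_stats(
        stats_meta={"epoch": 1},
        train_stats={"loss": [3.9, 3.5, 3.1]},
        valid_stats={"loss": [3.0]},
    )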