"""Loggers for experiment monitoring.
Authors
* Peter Plantinga 2020
"""
import logging
logger = logging.getLogger(__name__)
class TrainLogger:
    """Abstract class defining an interface for training loggers.

    Subclasses are expected to override ``log_stats``; the implementation
    here only raises ``NotImplementedError``.
    """

    def log_stats(
        self,
        stats_meta,
        train_stats=None,
        valid_stats=None,
        test_stats=None,
        verbose=False,
    ):
        """Log the stats for one epoch.

        Arguments
        ---------
        stats_meta : dict of str:scalar pairs
            Meta information about the stats (e.g., epoch, learning-rate, etc.).
        train_stats : dict of str:list pairs
            Each loss type is represented with a str : list pair including
            all the values for the training pass.
        valid_stats : dict of str:list pairs
            Each loss type is represented with a str : list pair including
            all the values for the validation pass.
        test_stats : dict of str:list pairs
            Each loss type is represented with a str : list pair including
            all the values for the test pass.
        verbose : bool
            Whether to also put logging information to the standard logger.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError
class FileTrainLogger(TrainLogger):
    """Text logger of training information.

    Each call to ``log_stats`` appends one line to ``save_file``.

    Arguments
    ---------
    save_file : str
        The file to use for logging train information.
    precision : int
        Number of decimal places to display. Default 2, example: 1.35e-5.
    """

    def __init__(self, save_file, precision=2):
        self.save_file = save_file
        self.precision = precision

    def _item_to_string(self, key, value, dataset=None):
        """Convert one item to string, handling floats.

        Arguments
        ---------
        key : str
            Name of the item.
        value : object
            Value of the item. Floats are formatted with ``self.precision``
            decimals; any other value is interpolated unchanged.
        dataset : str, optional
            If given, it is prepended to ``key``, separated by a space.

        Returns
        -------
        str
            The item rendered as ``"<key>: <value>"``.
        """
        if isinstance(value, float):
            # Fixed-point only strictly between 1 and 100; every other float
            # (including exactly 1.0 and 100.0) uses scientific notation.
            spec = "f" if 1.0 < value < 100.0 else "e"
            value = f"{value:.{self.precision}{spec}}"
        if dataset is not None:
            key = f"{dataset} {key}"
        return f"{key}: {value}"

    def _stats_to_string(self, stats, dataset=None):
        """Convert all stats to a single string summary.

        Items are rendered with ``_item_to_string`` and joined with ``", "``
        in the iteration order of ``stats``.
        """
        return ", ".join(
            self._item_to_string(k, v, dataset) for k, v in stats.items()
        )

    def log_stats(
        self,
        stats_meta,
        train_stats=None,
        valid_stats=None,
        test_stats=None,
        verbose=True,
    ):
        """See TrainLogger.log_stats()

        The meta stats come first, followed by one " - "-separated section
        for each of train/valid/test that is not None. The resulting line is
        appended to ``self.save_file`` and, when ``verbose`` is true (the
        default for this logger), also sent to the module logger at INFO
        level.
        """
        sections = [self._stats_to_string(stats_meta)]
        sections.extend(
            self._stats_to_string(stats, dataset)
            for dataset, stats in (
                ("train", train_stats),
                ("valid", valid_stats),
                ("test", test_stats),
            )
            if stats is not None
        )
        string_summary = " - ".join(sections)

        # Append mode: earlier epochs' lines are kept.
        with open(self.save_file, "a") as fout:
            print(string_summary, file=fout)
        if verbose:
            logger.info(string_summary)
class TensorboardLogger(TrainLogger):
    """Logs training information in the format required by Tensorboard.

    Arguments
    ---------
    save_dir : str
        A directory for storing all the relevant logs.

    Raises
    ------
    ImportError if Tensorboard is not installed.
    """

    def __init__(self, save_dir):
        self.save_dir = save_dir

        # Imported here rather than at module level so that the ImportError
        # is only raised when this logger is actually instantiated.
        from torch.utils.tensorboard import SummaryWriter

        self.writer = SummaryWriter(self.save_dir)
        # "meta" is a single counter shared by all meta stats; each dataset
        # keeps one counter per stat name.
        self.global_step = {"train": {}, "valid": {}, "test": {}, "meta": 0}

    def log_stats(
        self,
        stats_meta,
        train_stats=None,
        valid_stats=None,
        test_stats=None,
        verbose=False,
    ):
        """See TrainLogger.log_stats()

        Every entry of ``stats_meta`` is written under its own name at the
        shared meta step, which is incremented once per call. Dataset stats
        are written under the tag ``"<stat>/<dataset>"``; a ``list`` value
        produces one point per element, any other value a single point, and
        each point advances that stat's own step counter by one.

        ``verbose`` is accepted for interface compatibility but is not used
        by this logger.
        """
        self.global_step["meta"] += 1
        for name, value in stats_meta.items():
            self.writer.add_scalar(name, value, self.global_step["meta"])

        for dataset, stats in (
            ("train", train_stats),
            ("valid", valid_stats),
            ("test", test_stats),
        ):
            if stats is None:
                continue
            steps = self.global_step[dataset]
            for stat, value_list in stats.items():
                steps.setdefault(stat, 0)
                tag = f"{stat}/{dataset}"
                # Both single value (per epoch) and list (per batch) logging
                # are supported; a single value is handled as a one-item list.
                if isinstance(value_list, list):
                    values = value_list
                else:
                    values = [value_list]
                for value in values:
                    new_global_step = steps[stat] + 1
                    self.writer.add_scalar(tag, value, new_global_step)
                    # Advance only after the write, as the original did.
                    steps[stat] = new_global_step