# Source code for speechbrain.utils.epoch_loop

"""Implements a checkpointable epoch counter (loop), optionally integrating early stopping.

Authors
 * Aku Rouhe 2020
 * Davide Borra 2021
"""
from .checkpoints import register_checkpoint_hooks
from .checkpoints import mark_as_saver
from .checkpoints import mark_as_loader
import logging

logger = logging.getLogger(__name__)


@register_checkpoint_hooks
class EpochCounter:
    """An epoch counter which can save and recall its state.

    Use this as the iterator for epochs.
    Note that this iterator gives you the numbers from [1 ... limit] not
    [0 ... limit-1] as range(limit) would.

    Arguments
    ---------
    limit : int
        Maximum number of epochs to iterate (inclusive).

    Example
    -------
    >>> from speechbrain.utils.checkpoints import Checkpointer
    >>> tmpdir = getfixture('tmpdir')
    >>> epoch_counter = EpochCounter(10)
    >>> recoverer = Checkpointer(tmpdir, {"epoch": epoch_counter})
    >>> recoverer.recover_if_possible()
    >>> # Now after recovery,
    >>> # the epoch starts from where it left off!
    >>> for epoch in epoch_counter:
    ...     # Run training...
    ...     ckpt = recoverer.save_checkpoint()
    """

    def __init__(self, limit):
        # Number of the last epoch entered; 0 means no epoch started yet.
        self.current = 0
        self.limit = int(limit)

    def __iter__(self):
        return self

    def __next__(self):
        """Advance to the next epoch number, or stop at the limit."""
        if self.current < self.limit:
            self.current += 1
            logger.info(f"Going into epoch {self.current}")
            return self.current
        raise StopIteration

    @mark_as_saver
    def _save(self, path):
        """Write the current epoch number to a checkpoint file."""
        with open(path, "w") as fo:
            fo.write(str(self.current))

    @mark_as_loader
    def _recover(self, path, end_of_epoch=True, device=None):
        """Restore the epoch number from a checkpoint file.

        Arguments
        ---------
        path : str or Path
            File written by ``_save``, containing one integer.
        end_of_epoch : bool
            If True, the saved epoch is considered finished and iteration
            resumes at the next one; if False, the saved epoch is re-run.
        device : None
            Unused; accepted for loader-interface compatibility.
        """
        # NOTE: end_of_epoch = True by default so that when
        # loaded in parameter transfer, this starts a new epoch.
        # However, parameter transfer to EpochCounter should
        # probably never be used really.
        del device  # Not used.
        with open(path) as fi:
            saved_value = int(fi.read())
            if end_of_epoch:
                self.current = saved_value
            else:
                # Step back one so the interrupted epoch is repeated.
                self.current = saved_value - 1
class EpochCounterWithStopper(EpochCounter):
    """An epoch counter which can save and recall its state, integrating an early stopper by tracking a target metric.

    Arguments
    ---------
    limit : int
        maximum number of epochs
    limit_to_stop : int
        maximum number of consecutive epochs without improvements in performance
    limit_warmup : int
        number of epochs to wait until start checking for early stopping
    direction : "max" or "min"
        direction to optimize the target metric

    Example
    -------
    >>> limit = 10
    >>> limit_to_stop = 5
    >>> limit_warmup = 2
    >>> direction = "min"
    >>> epoch_counter = EpochCounterWithStopper(limit,
    ...                                         limit_to_stop,
    ...                                         limit_warmup,
    ...                                         direction)
    >>> for epoch in epoch_counter:
    ...     # Run training...
    ...     # Track a validation metric,
    ...     current_valid_metric = 0
    ...     # get the current valid metric (get current_valid_metric)
    ...     if epoch_counter.should_stop(current=epoch,
    ...                                  current_metric=current_valid_metric,):
    ...         epoch_counter.current = epoch_counter.limit # skipping unpromising epochs
    """

    def __init__(self, limit, limit_to_stop, limit_warmup, direction):
        super().__init__(limit)
        self.limit_to_stop = limit_to_stop
        self.limit_warmup = limit_warmup
        self.direction = direction

        # Epoch index at which the best metric so far was observed.
        self.best_limit = 0
        # Minimum relative improvement required to count as "better".
        self.min_delta = 1e-6

        if self.limit_to_stop < 0:
            raise ValueError("Stopper 'limit_to_stop' must be >= 0")
        if self.limit_warmup < 0:
            raise ValueError("Stopper 'limit_warmup' must be >= 0")
        # th is the best metric seen so far; sign maps both directions
        # onto a single "smaller is better" comparison in should_stop.
        if self.direction == "min":
            self.th, self.sign = float("inf"), 1
        elif self.direction == "max":
            self.th, self.sign = -float("inf"), -1
        else:
            raise ValueError("Stopper 'direction' must be 'min' or 'max'")

    def should_stop(self, current, current_metric):
        """Returns True if training should stop (based on the performance metrics).

        Arguments
        ---------
        current : int
            Current epoch number (as yielded by iterating this counter).
        current_metric : float
            Value of the tracked validation metric for this epoch.

        Returns
        -------
        bool
            True when no improvement has been seen for at least
            ``limit_to_stop`` consecutive epochs (after warmup).
        """
        should_stop = False
        # No early stopping decisions during the warmup epochs.
        if current > self.limit_warmup:
            # Improvement test: sign folds "max" into the "min" comparison;
            # (1 - min_delta) demands a small relative improvement over th.
            if self.sign * current_metric < self.sign * (
                (1 - self.min_delta) * self.th
            ):
                self.best_limit = current
                self.th = current_metric
            should_stop = (current - self.best_limit) >= self.limit_to_stop
        return should_stop