Source code for speechbrain.utils.epoch_loop

"""Implements a checkpointable epoch counter (loop).

Authors
 * Aku Rouhe 2020
"""
from .checkpoints import register_checkpoint_hooks
from .checkpoints import mark_as_saver
from .checkpoints import mark_as_loader
import logging

logger = logging.getLogger(__name__)


@register_checkpoint_hooks
class EpochCounter:
    """An epoch counter which can save and recall its state.

    Use this as the iterator for epochs.
    Note that this iterator gives you the numbers from [1 ... limit] not
    [0 ... limit-1] as range(limit) would.

    Arguments
    ---------
    limit : int
        Number of the last epoch to run; the value is converted with
        ``int(limit)``.

    Example
    -------
    >>> from speechbrain.utils.checkpoints import Checkpointer
    >>> tmpdir = getfixture('tmpdir')
    >>> epoch_counter = EpochCounter(10)
    >>> recoverer = Checkpointer(tmpdir, {"epoch": epoch_counter})
    >>> recoverer.recover_if_possible()
    >>> # Now after recovery,
    >>> # the epoch starts from where it left off!
    >>> for epoch in epoch_counter:
    ...     # Run training...
    ...     ckpt = recoverer.save_checkpoint()
    """

    def __init__(self, limit):
        # ``current`` holds the number of the most recently started epoch;
        # 0 means that no epoch has been started yet.
        self.current = 0
        self.limit = int(limit)

    def __iter__(self):
        """Return the counter itself; it is its own iterator."""
        return self

    def __next__(self):
        """Advance to the next epoch and return its 1-based number.

        Raises
        ------
        StopIteration
            When ``limit`` epochs have already been started.
        """
        if self.current < self.limit:
            self.current += 1
            logger.info(f"Going into epoch {self.current}")
            return self.current
        raise StopIteration

    @mark_as_saver
    def _save(self, path):
        """Write the current epoch number to ``path`` as decimal text."""
        with open(path, "w") as fo:
            fo.write(str(self.current))

    @mark_as_loader
    def _recover(self, path, end_of_epoch=True, device=None):
        """Restore the epoch number previously written by ``_save``.

        Arguments
        ---------
        path : str
            File holding the saved epoch number.
        end_of_epoch : bool
            If True, the saved epoch is treated as finished and iteration
            continues from the following one. If False, the saved epoch is
            started again.
        device : object
            Ignored; accepted only so the signature matches what the
            loader is called with.
        """
        # NOTE: end_of_epoch = True by default so that when
        # loaded in parameter transfer, this starts a new epoch.
        # However, parameter transfer to EpochCounter should
        # probably never be used really.
        del device  # Not used.
        with open(path) as fi:
            saved_value = int(fi.read())
        if end_of_epoch:
            self.current = saved_value
        else:
            # Step back one so that the interrupted epoch is run again.
            # Clamp at 0: a checkpoint written before the first epoch stores
            # 0, and restoring -1 would make the iterator yield an epoch
            # numbered 0 and run one epoch more than ``limit``.
            self.current = max(saved_value - 1, 0)