speechbrain.utils.epoch_loop module

Implements a checkpointable epoch counter (loop), optionally integrating early stopping.

Authors
  • Aku Rouhe 2020

  • Davide Borra 2021

Summary

Classes:

EpochCounter

An epoch counter which can save and recall its state.

EpochCounterWithStopper

An epoch counter which can save and recall its state, integrating an early stopper by tracking a target metric.

Reference

class speechbrain.utils.epoch_loop.EpochCounter(limit)

Bases: object

An epoch counter which can save and recall its state.

Use this as the iterator for epochs. Note that this iterator yields the epoch numbers [1 … limit] (inclusive), not [0 … limit-1] as range(limit) would.
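
The 1-based iteration described above can be illustrated with a minimal, self-contained sketch. SimpleEpochCounter is a hypothetical class written for illustration, not SpeechBrain's implementation. It assumes the counter keeps the last epoch it handed out in a current attribute and its upper bound in limit (the attribute names used in the EpochCounterWithStopper example on this page).

```python
class SimpleEpochCounter:
    """Hypothetical minimal epoch iterator: yields 1, 2, ..., limit."""

    def __init__(self, limit):
        self.current = 0  # last epoch handed out; 0 means "not started yet"
        self.limit = int(limit)

    def __iter__(self):
        # The object is its own iterator, so its position persists between loops.
        return self

    def __next__(self):
        if self.current < self.limit:
            self.current += 1
            return self.current
        raise StopIteration


epochs = list(SimpleEpochCounter(5))
print(epochs)          # [1, 2, 3, 4, 5]
print(list(range(5)))  # [0, 1, 2, 3, 4]
```

Because the position lives in current instead of a fresh range object, iterating the same counter a second time yields nothing once it is exhausted. In this sketch, that single integer is all the state a checkpoint would need to store.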

Example

>>> from speechbrain.utils.checkpoints import Checkpointer
>>> tmpdir = getfixture('tmpdir')
>>> epoch_counter = EpochCounter(10)
>>> recoverer = Checkpointer(tmpdir, {"epoch": epoch_counter})
>>> recoverer.recover_if_possible()
>>> # Now after recovery,
>>> # the epoch starts from where it left off!
>>> for epoch in epoch_counter:
...     # Run training...
...     ckpt = recoverer.save_checkpoint()
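
In the example above, Checkpointer saves and restores the counter's state. The following framework-free sketch illustrates what "starts from where it left off" means, assuming the only state worth persisting is the last completed epoch number. ResumableEpochCounter, save, and recover are hypothetical names used only here; they are not SpeechBrain API, and the plain-text file format is likewise an assumption.

```python
import os
import tempfile


class ResumableEpochCounter:
    """Hypothetical epoch counter that persists `current` to a text file."""

    def __init__(self, limit):
        self.current = 0
        self.limit = int(limit)

    def __iter__(self):
        return self

    def __next__(self):
        if self.current < self.limit:
            self.current += 1
            return self.current
        raise StopIteration

    def save(self, path):
        # Store the last completed epoch number as plain text.
        with open(path, "w") as f:
            f.write(str(self.current))

    def recover(self, path):
        # Restore the stored epoch; the next call to __next__ yields current + 1.
        with open(path) as f:
            self.current = int(f.read())


ckpt_path = os.path.join(tempfile.mkdtemp(), "epoch.txt")

# First run: checkpoint at the end of every epoch, then "crash" after epoch 3.
first_run = ResumableEpochCounter(10)
for epoch in first_run:
    first_run.save(ckpt_path)
    if epoch == 3:
        break

# Second run: a fresh counter recovers and continues from epoch 4.
second_run = ResumableEpochCounter(10)
second_run.recover(ckpt_path)
resumed = list(second_run)
print(resumed)  # [4, 5, 6, 7, 8, 9, 10]
```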

class speechbrain.utils.epoch_loop.EpochCounterWithStopper(limit, limit_to_stop, limit_warmup, direction)

Bases: EpochCounter

An epoch counter which can save and recall its state, integrating an early stopper by tracking a target metric.

Parameters
  • limit (int) – maximum number of epochs

  • limit_to_stop (int) – maximum number of consecutive epochs without improvement in the target metric

  • limit_warmup (int) – number of epochs to wait before starting to check for early stopping

  • direction ("max" or "min") – whether the target metric should be maximized ("max") or minimized ("min")
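
How limit_to_stop, limit_warmup, and direction interact can be illustrated by replaying a list of per-epoch validation values. The sketch below assumes that metrics from warmup epochs are ignored entirely and that training stops once limit_to_stop consecutive checked epochs fail to strictly improve on the best value seen so far. first_stop_epoch is a hypothetical helper, and SpeechBrain's own comparison (for example, any tolerance it applies) may differ.

```python
def first_stop_epoch(metrics, limit_to_stop, limit_warmup, direction):
    """Hypothetical helper: return the 1-based epoch at which early stopping
    would trigger for the given per-epoch metric values, or None if it never does."""
    if direction not in ("min", "max"):
        raise ValueError("direction must be 'min' or 'max'")
    best_value, best_epoch = None, 0
    for epoch, value in enumerate(metrics, start=1):
        if epoch <= limit_warmup:
            continue  # warmup epochs are never checked
        improved = (
            best_value is None
            or (direction == "min" and value < best_value)
            or (direction == "max" and value > best_value)
        )
        if improved:
            best_value, best_epoch = value, epoch
        # Number of consecutive checked epochs since the last improvement.
        if epoch - best_epoch >= limit_to_stop:
            return epoch
    return None


valid_loss = [0.9, 0.8, 0.7, 0.65, 0.66, 0.67, 0.66, 0.70, 0.69, 0.68]
print(first_stop_epoch(valid_loss, limit_to_stop=3, limit_warmup=2, direction="min"))  # 7
```

Here the best loss (0.65) is reached at epoch 4, and epochs 5, 6, and 7 fail to beat it, so with limit_to_stop=3 the run would stop at epoch 7 instead of running all 10 epochs.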

Example

>>> limit = 10
>>> limit_to_stop = 5
>>> limit_warmup = 2
>>> direction = "min"
>>> epoch_counter = EpochCounterWithStopper(limit, limit_to_stop, limit_warmup, direction)
>>> for epoch in epoch_counter:
...     # Run training...
...     # Track a validation metric: compute it for this epoch
...     # (a placeholder value is used here).
...     current_valid_metric = 0
...     if epoch_counter.should_stop(current=epoch,
...                                  current_metric=current_valid_metric):
...         epoch_counter.current = epoch_counter.limit  # skip the remaining, unpromising epochs

should_stop(current, current_metric)

Returns True if training should stop, based on the tracked performance metric.
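
The stateful behaviour of should_stop, together with the epoch_counter.current = epoch_counter.limit trick from the example above, can be sketched in one self-contained class. StopperSketch is hypothetical and is not SpeechBrain's EpochCounterWithStopper. It assumes that warmup epochs never trigger a stop, that the best metric is tracked only after warmup, and that stopping triggers after limit_to_stop consecutive checked epochs without a strict improvement.

```python
class StopperSketch:
    """Hypothetical stand-in for an epoch counter with early stopping."""

    def __init__(self, limit, limit_to_stop, limit_warmup, direction):
        if direction not in ("min", "max"):
            raise ValueError("direction must be 'min' or 'max'")
        self.current = 0
        self.limit = int(limit)
        self.limit_to_stop = limit_to_stop
        self.limit_warmup = limit_warmup
        self.direction = direction
        self.best_metric = None  # best value seen after warmup
        self.best_epoch = 0      # epoch at which best_metric was observed

    def __iter__(self):
        return self

    def __next__(self):
        if self.current < self.limit:
            self.current += 1
            return self.current
        raise StopIteration

    def should_stop(self, current, current_metric):
        if current <= self.limit_warmup:
            return False  # still warming up: never stop
        better = self.best_metric is None or (
            current_metric < self.best_metric
            if self.direction == "min"
            else current_metric > self.best_metric
        )
        if better:
            self.best_metric, self.best_epoch = current_metric, current
        return current - self.best_epoch >= self.limit_to_stop


valid_loss = [0.9, 0.8, 0.7, 0.65, 0.66, 0.67, 0.66, 0.70, 0.69, 0.68]
counter = StopperSketch(limit=10, limit_to_stop=3, limit_warmup=2, direction="min")
epochs_run = []
for epoch in counter:
    epochs_run.append(epoch)
    if counter.should_stop(current=epoch, current_metric=valid_loss[epoch - 1]):
        counter.current = counter.limit  # the next __next__ call raises StopIteration
print(epochs_run)  # [1, 2, 3, 4, 5, 6, 7]
```

Setting current to limit does not break out of the loop body immediately; it only makes the following __next__ call end the iteration, so the current epoch finishes normally.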