speechbrain.utils.epoch_loop module

Implements a checkpointable epoch counter (loop), optionally integrating early stopping.

Authors
  • Aku Rouhe 2020

  • Davide Borra 2021

Summary

Classes:

EpochCounter

An epoch counter which can save and recall its state.

EpochCounterWithStopper

An epoch counter which can save and recall its state, integrating an early stopper by tracking a target metric.

Reference

class speechbrain.utils.epoch_loop.EpochCounter(limit)

Bases: object

An epoch counter which can save and recall its state.

Use this as the iterator over epochs. Note that it yields the numbers [1 … limit], not [0 … limit-1] as range(limit) would.
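The following standalone sketch shows 1-based iteration combined with save and recall. It is not SpeechBrain's implementation. The class SimpleEpochCounter and its save/recover methods are hypothetical stand-ins for the hooks a Checkpointer would call. The saved state is also assumed to be only the number of the last epoch started.

```python
import os
import tempfile


class SimpleEpochCounter:
    """Hypothetical 1-based epoch counter that can save and restore its position."""

    def __init__(self, limit):
        self.current = 0  # number of the last epoch started (0 = none yet)
        self.limit = int(limit)

    def __iter__(self):
        return self

    def __next__(self):
        # Yields 1, 2, ..., limit (unlike range(limit), which yields 0..limit-1).
        if self.current < self.limit:
            self.current += 1
            return self.current
        raise StopIteration

    def save(self, path):
        # Hypothetical saver: persist only the current epoch number.
        with open(path, "w") as f:
            f.write(str(self.current))

    def recover(self, path):
        # Hypothetical loader: restore the epoch number so iteration resumes after it.
        with open(path) as f:
            self.current = int(f.read())


# Usage: run three epochs, "crash", then resume with a fresh object.
ckpt_path = os.path.join(tempfile.mkdtemp(), "epoch.txt")
first_run = []
counter = SimpleEpochCounter(5)
for epoch in counter:
    first_run.append(epoch)
    counter.save(ckpt_path)
    if epoch == 3:
        break  # simulate an interruption after epoch 3 was checkpointed

resumed = SimpleEpochCounter(5)
resumed.recover(ckpt_path)
second_run = list(resumed)
print(first_run, second_run)  # [1, 2, 3] [4, 5]
```

Only the epoch number is persisted. A fresh counter that recovers it therefore continues with the next epoch, which is the resume-after-recovery pattern used with the Checkpointer.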

Parameters:

limit (int) – maximum number of epochs

Example

>>> from speechbrain.utils.checkpoints import Checkpointer
>>> tmpdir = getfixture('tmpdir')
>>> epoch_counter = EpochCounter(10)
>>> recoverer = Checkpointer(tmpdir, {"epoch": epoch_counter})
>>> recoverer.recover_if_possible()
>>> # Now after recovery,
>>> # the epoch starts from where it left off!
>>> for epoch in epoch_counter:
...     # Run training...
...     ckpt = recoverer.save_checkpoint()

class speechbrain.utils.epoch_loop.EpochCounterWithStopper(limit, limit_to_stop, limit_warmup, direction)

Bases: EpochCounter

An epoch counter which can save and recall its state, integrating an early stopper by tracking a target metric.

Parameters:
  • limit (int) – maximum number of epochs

  • limit_to_stop (int) – maximum number of consecutive epochs without improvement in the target metric before stopping

  • limit_warmup (int) – number of epochs to wait before starting to check for early stopping

  • direction ("max" or "min") – whether the target metric should be maximized ("max") or minimized ("min")
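Together, these parameters define a patience-with-warmup rule. The function below is a hypothetical, simplified sketch of that rule. The name epochs_run is illustrative, and SpeechBrain's exact notion of "improvement" may differ, for example in any tolerance it applies. In this sketch, metrics observed during the first limit_warmup epochs are ignored. The loop ends after limit_to_stop consecutive epochs without a strict improvement, or after limit epochs.

```python
def epochs_run(metrics, limit, limit_to_stop, limit_warmup, direction):
    """Hypothetical helper: list the epochs a patience-based stopper would run.

    metrics[i] is the validation metric observed at the end of epoch i + 1.
    Simplified assumption: any strict improvement counts (no tolerance).
    """
    if direction not in ("min", "max"):
        raise ValueError("direction must be 'min' or 'max'")
    best = None      # best metric seen after warmup
    best_epoch = 0   # epoch at which `best` was observed
    run = []
    for epoch in range(1, limit + 1):
        run.append(epoch)
        if epoch <= limit_warmup:
            continue  # warmup: the metric is not tracked yet
        metric = metrics[epoch - 1]
        if best is None or (metric < best if direction == "min" else metric > best):
            best, best_epoch = metric, epoch
        # Stop after limit_to_stop consecutive epochs without improvement.
        if epoch - best_epoch >= limit_to_stop:
            break
    return run


print(epochs_run([0.0] * 10, limit=10, limit_to_stop=5, limit_warmup=2, direction="min"))
# [1, 2, 3, 4, 5, 6, 7, 8]
```

With a constant metric and the settings limit=10, limit_to_stop=5, limit_warmup=2, the sketch runs epochs 1 to 8. This is the same sequence the example prints. The first tracked epoch (3) sets the best value, and five further epochs pass without improvement.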

Example

>>> limit = 10
>>> limit_to_stop = 5
>>> limit_warmup = 2
>>> direction = "min"
>>> epoch_counter = EpochCounterWithStopper(limit, limit_to_stop, limit_warmup, direction)
>>> for epoch in epoch_counter:
...     # Run training...
...     # Track a validation metric (insert calculation here)
...     current_valid_metric = 0
...     # Update epoch counter so that we stop at the appropriate time
...     epoch_counter.update_metric(current_valid_metric)
...     print(epoch)
1
2
3
4
5
6
7
8

__next__()

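Stop iteration once the early-stopping condition has been met. Otherwise, advance the counter as EpochCounter does, which still stops after limit epochs.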

update_metric(current_metric)

Update the state to reflect the most recent value of the tracked metric.

NOTE: This should be called only once per validation loop.

Parameters:

current_metric (float) – The metric used to make a stopping decision.
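The following sketch shows one way __next__ and update_metric can split the work. It is an illustration under stated assumptions, not the library's code. The class names Counter and CounterWithStopper and the attributes should_stop, best_score and best_epoch are all hypothetical. In this sketch, update_metric records whether patience has run out, and __next__ only checks that flag before deferring to the plain limit check.

```python
class Counter:
    """Hypothetical base counter yielding 1..limit."""

    def __init__(self, limit):
        self.current = 0
        self.limit = int(limit)

    def __iter__(self):
        return self

    def __next__(self):
        if self.current < self.limit:
            self.current += 1
            return self.current
        raise StopIteration


class CounterWithStopper(Counter):
    """Hypothetical sketch: update_metric sets a flag that __next__ checks."""

    def __init__(self, limit, limit_to_stop, limit_warmup, direction):
        super().__init__(limit)
        self.limit_to_stop = limit_to_stop
        self.limit_warmup = limit_warmup
        # Multiplying by sign turns both directions into "smaller is better".
        self.sign = {"min": 1, "max": -1}[direction]
        self.best_score = float("inf")  # best signed score so far
        self.best_epoch = 0
        self.should_stop = False

    def __next__(self):
        # Stop early if flagged; otherwise fall back to the plain limit check.
        if self.should_stop:
            raise StopIteration
        return super().__next__()

    def update_metric(self, current_metric):
        # Intended to be called once per epoch, after validation.
        if self.current <= self.limit_warmup:
            return  # still warming up: ignore the metric
        score = self.sign * current_metric
        if score < self.best_score:
            self.best_score = score
            self.best_epoch = self.current
        self.should_stop = self.current - self.best_epoch >= self.limit_to_stop


counter = CounterWithStopper(10, limit_to_stop=5, limit_warmup=2, direction="min")
seen = []
for epoch in counter:
    seen.append(epoch)
    counter.update_metric(0.0)  # constant metric, as in the documented example
print(seen)  # [1, 2, 3, 4, 5, 6, 7, 8]
```

The decision is made in update_metric, and the check happens in __next__. As a result, the stop takes effect at the start of the following iteration, so the epoch that triggered it still runs to completion.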