speechbrain.nnet.schedulers module

Schedulers for updating hyperparameters (such as learning rate).

Authors
  • Mirco Ravanelli 2020

  • Peter Plantinga 2020

  • Loren Lugosch 2020

Summary

Classes:

CyclicCosineScheduler

This is an implementation of the Cyclic-Cosine learning rate scheduler with warmup.

CyclicLRScheduler

This implements a cyclical learning rate policy (CLR).

IntervalScheduler

A simple scheduler implementation that sets the learning rate to specific values after a specific number of steps has been reached.

InverseSquareRootScheduler

The Inverse Square Root Scheduler, as defined in the T5 paper (https://arxiv.org/pdf/1910.10683.pdf).

LinearScheduler

Scheduler with linear annealing technique.

LinearWarmupScheduler

Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

NewBobScheduler

Scheduler with new-bob technique, used for LR annealing.

NoamIntervalScheduler

A combination of Noam Scheduler and Interval Scheduler.

NoamScheduler

This is an implementation of the transformer's learning rate scheduler with warmup.

ReduceLROnPlateau

Learning rate scheduler which decreases the learning rate if the loss function of interest gets stuck on a plateau, or starts to increase.

ScheduledLoss

A convenience class for switching to a different loss function on a schedule.

StepScheduler

Learning rate scheduler with step annealing technique.

TriStageLRSchedule

Warms up linearly, decays very slowly, and cools down linearly again at the end of training.

WarmAndExpDecayLRSchedule

Warms up linearly, then decays exponentially to lr * decay_factor over total_steps steps.

WarmCoolDecayLRSchedule

Warms up linearly, decays very slowly, and cools down linearly again at the end of training.

Functions:

update_learning_rate

Change the learning rate value within an optimizer.

Reference

speechbrain.nnet.schedulers.update_learning_rate(optimizer, new_lr, param_group=None)[source]

Change the learning rate value within an optimizer.

Parameters:
  • optimizer (torch.optim object) – Updates the learning rate for this optimizer.

  • new_lr (float) – The new value to use for the learning rate.

  • param_group (list of int) – The param group indices to update. If not provided, all groups are updated.

Example

>>> from torch.optim import SGD
>>> from speechbrain.nnet.linear import Linear
>>> model = Linear(n_neurons=10, input_size=10)
>>> optimizer = SGD(model.parameters(), lr=0.1)
>>> update_learning_rate(optimizer, 0.2)
>>> optimizer.param_groups[0]["lr"]
0.2
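
The param_group argument restricts the update to specific groups. A small hedged follow-on sketch (not part of the source doctest; it assumes a second parameter group added to the same optimizer, with hypothetical names):

>>> model2 = Linear(n_neurons=10, input_size=10)
>>> optimizer.add_param_group({"params": model2.parameters(), "lr": 0.1})
>>> update_learning_rate(optimizer, 0.3, param_group=[1])  # only group index 1 is changed
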
class speechbrain.nnet.schedulers.WarmAndExpDecayLRSchedule(lr, n_warmup_steps, total_steps, decay_factor=0.1)[source]

Bases: object

Warms up linearly, then decays exponentially to lr * decay_factor over total_steps steps.

Parameters:
  • lr (float) – The max learning rate to reach after warmup.

  • n_warmup_steps (int) – Number of warmup steps (following a linear increase).

  • total_steps (int) – Total number of steps (used to decay).

  • decay_factor (float) – Exponential decay factor; the learning rate decays toward lr * decay_factor by the end of total_steps. (default: 0.1)

Example

>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1,660,3])
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> output = model(inp_tensor)
>>> scheduler = WarmAndExpDecayLRSchedule(lr=1, n_warmup_steps=2, decay_factor=0.01, total_steps=6)
>>> scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.0
>>> scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.5
>>> scheduler(optim)
>>> optim.param_groups[0]["lr"]
1
>>> scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.31622776601683794
save(path)[source]

Saves the current metrics on the specified path.

load(path, end_of_epoch=False, device=None)[source]

Loads the needed information.

class speechbrain.nnet.schedulers.NewBobScheduler(initial_value, annealing_factor=0.5, improvement_threshold=0.0025, patient=0)[source]

Bases: object

Scheduler with new-bob technique, used for LR annealing.

The learning rate is annealed based on the validation performance. In particular, if (past_loss - current_loss) / past_loss < improvement_threshold, then lr = lr * annealing_factor.

Parameters:
  • initial_value (float) – The initial hyperparameter value.

  • annealing_factor (float) – The annealing factor used in the new-bob strategy.

  • improvement_threshold (float) – The relative improvement between losses below which annealing is applied in the new-bob strategy.

  • patient (int) – When the annealing condition is violated patient times, the learning rate is finally reduced.

Example

>>> scheduler = NewBobScheduler(initial_value=1.0)
>>> scheduler(metric_value=10.0)
(1.0, 1.0)
>>> scheduler(metric_value=2.0)
(1.0, 1.0)
>>> scheduler(metric_value=2.5)
(1.0, 0.5)
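
A hedged sketch of the typical usage pattern, feeding the new value back into an optimizer with update_learning_rate at the end of each validation epoch (the loss values below are hypothetical):

>>> from torch.optim import SGD
>>> from speechbrain.nnet.linear import Linear
>>> model = Linear(n_neurons=10, input_size=10)
>>> optimizer = SGD(model.parameters(), lr=1.0)
>>> scheduler = NewBobScheduler(initial_value=1.0)
>>> for valid_loss in [10.0, 2.0, 2.5]:
...     _, new_lr = scheduler(metric_value=valid_loss)
...     update_learning_rate(optimizer, new_lr)
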
__call__(metric_value)[source]

Returns the current and new value for the hyperparameter.

Parameters:

metric_value (int) – A number for determining whether to change the hyperparameter value.

save(path)[source]

Saves the current metrics on the specified path.

load(path, end_of_epoch=False)[source]

Loads the needed information.

class speechbrain.nnet.schedulers.LinearScheduler(initial_value, final_value, epoch_count)[source]

Bases: object

Scheduler with linear annealing technique.

The learning rate linearly decays over the specified number of epochs.

Parameters:
  • initial_value (float) – The value upon initialization.

  • final_value (float) – The value used when the epoch count reaches epoch_count - 1.

  • epoch_count (int) – Number of epochs.

Example

>>> scheduler = LinearScheduler(1.0, 0.0, 4)
>>> scheduler(current_epoch=1)
(1.0, 0.666...)
>>> scheduler(current_epoch=2)
(0.666..., 0.333...)
>>> scheduler(current_epoch=3)
(0.333..., 0.0)
>>> scheduler(current_epoch=4)
(0.0, 0.0)
__call__(current_epoch)[source]

Returns the current and new value for the hyperparameter.

Parameters:

current_epoch (int) – Number of times the dataset has been iterated.

class speechbrain.nnet.schedulers.LinearWarmupScheduler(initial_value, num_warmup_steps, num_training_steps)[source]

Bases: object

Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

Author: Ge Li 2022

Parameters:
  • initial_value (float) – The value upon initialization (lr0).

  • num_warmup_steps (int) – Number of warmup steps. The learning rate reaches lr0 at num_warmup_steps + 1 step.

  • num_training_steps (int) – The total number of training steps.

Example

>>> scheduler = LinearWarmupScheduler(1.0, 2, 4)
>>> scheduler.get_next_value()
0.0
>>> scheduler.get_next_value()
0.5
>>> scheduler.get_next_value()
1.0
>>> scheduler.get_next_value()
0.5
>>> scheduler.get_next_value()
0.0
calculate_lr(current_step)[source]

Returns the current and new value for the hyperparameter.

Parameters:

current_step (int) – Number of steps the model has been updated.

get_next_value()[source]

Returns the next learning rate value for the hyperparameter.

save(path)[source]

Saves the current metrics on the specified path.

load(path, end_of_epoch=False)[source]

Loads the needed information.

class speechbrain.nnet.schedulers.StepScheduler(initial_value, decay_factor=None, decay_drop=None, half_life=None)[source]

Bases: object

Learning rate scheduler with step annealing technique.

The hyperparameter's value decays over the epochs according to the selected decay_factor and decay_drop.

value = init_value * decay_factor ^ floor((1 + epoch) / decay_drop)

Parameters:
  • initial_value (float) – Initial value for the hyperparameter being updated.

  • decay_factor (float) – Factor multiplied with the initial_value

  • decay_drop (float) – Annealing factor (the decay of the hyperparameter value is faster with higher decay_drop values).

  • half_life (int) – A convenience parameter to set decay_factor such that the parameter will drop to half its value at the specified epoch. May not be used together with decay_factor or decay_drop

Example

>>> scheduler = StepScheduler(initial_value=1.0)
>>> scheduler(current_epoch=1)
(1.0, 0.5)
>>> scheduler(current_epoch=2)
(0.5, 0.5)
>>> scheduler(current_epoch=3)
(0.5, 0.25)
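
The half_life parameter is a convenience alternative to decay_factor and decay_drop; a minimal hedged sketch (no outputs shown, since the exact values depend on the implementation):

>>> scheduler = StepScheduler(initial_value=1.0, half_life=10)
>>> current_value, new_value = scheduler(current_epoch=1)
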
DEFAULT_DECAY_FACTOR = 0.5
DEFAULT_DECAY_DROP = 2
__call__(current_epoch)[source]

Returns current and new hyperparameter value.

Parameters:

current_epoch (int) – Number of times the dataset has been iterated.

class speechbrain.nnet.schedulers.NoamScheduler(lr_initial, n_warmup_steps, model_size=None)[source]

Bases: object

This is an implementation of the transformer's learning rate scheduler with warmup. Reference: https://arxiv.org/abs/1706.03762

Note: this scheduler anneals the lr at each update of the model's weights, and n_steps must be saved for restarting.

Parameters:
  • lr_initial (float) – Initial learning rate (i.e. the lr used at epoch 0).

  • n_warmup_steps (int) – Number of warm-up steps.

  • model_size (int) – Size of the transformer embed_dim. It is used to scale the maximum learning rate value reached by the scheduler, which is divided by model_size ** 0.5. If not specified, the maximum learning rate value is instead multiplied by n_warmup_steps ** 0.5.

Example

>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1,660,3])
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> output = model(inp_tensor)
>>> scheduler = NoamScheduler(optim.param_groups[0]["lr"], 3)
>>> curr_lr,next_lr=scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.3333333333333333
>>> curr_lr,next_lr=scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.6666666666666666
>>> curr_lr,next_lr=scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.9999999999999999
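
The values in the doctest above are consistent with the standard Noam rule; a hedged sketch of the implied schedule (assuming model_size is not given, so the peak is scaled by n_warmup_steps ** 0.5 as described above):

>>> def noam_lr(lr_initial, step, n_warmup_steps):
...     # lr grows linearly during warmup and then decays as step ** -0.5
...     scale = n_warmup_steps ** 0.5
...     return lr_initial * scale * min(step ** -0.5, step * n_warmup_steps ** -1.5)
>>> round(noam_lr(1.0, 2, 3), 4)
0.6667
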
__call__(opt)[source]
Parameters:

opt (optimizer) – The optimizer to update using this scheduler.

Returns:

  • current_lr (float) – The learning rate before the update.

  • lr (float) – The learning rate after the update.

save(path)[source]

Saves the current metrics on the specified path.

load(path, end_of_epoch=False)[source]

Loads the needed information.

class speechbrain.nnet.schedulers.NoamIntervalScheduler(lr_initial, n_warmup_steps, anneal_steps, anneal_rates, model_size=None)[source]

Bases: object

A combination of Noam Scheduler and Interval Scheduler. The scheduler behaves as a Noam Scheduler, and anneals the learning rate at designated steps with the designated decay rates.

Note: this scheduler anneals the lr at each update of the model's weights, and n_steps must be saved for restarting.

Parameters:
  • lr_initial (float) – Initial learning rate (i.e. the lr used at epoch 0).

  • n_warmup_steps (int) – Number of warm-up steps.

  • anneal_steps (list) – Pre-designed steps where the learning rate is to be annealed.

  • anneal_rates (list) – Pre-designed decay rate for each anneal step.

  • model_size (int) – Size of the transformer embed_dim. It is used to scale the maximum learning rate value reached by the scheduler, which is divided by model_size ** 0.5. If not specified, the maximum learning rate value is instead multiplied by n_warmup_steps ** 0.5.

Example

>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1,660,3])
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> output = model(inp_tensor)
>>> scheduler = NoamIntervalScheduler(
...    lr_initial=optim.param_groups[0]["lr"],
...    n_warmup_steps=3,
...    anneal_steps=[6, 9],
...    anneal_rates=[0.5, 0.1],
... )
>>> for _ in range(10):
...     curr_lr,next_lr=scheduler(optim)
...     print(optim.param_groups[0]["lr"])
0.3333333333333333
0.6666666666666666
0.9999999999999999
0.8660254037844386
0.7745966692414833
0.7071067811865475
0.3273268353539886
0.3061862178478973
0.28867513459481287
0.027386127875258306
__call__(opt)[source]
Parameters:

opt (optimizer) – The optimizer to update using this scheduler.

Returns:

  • current_lr (float) – The learning rate before the update.

  • lr (float) – The learning rate after the update.

save(path)[source]

Saves the current metrics on the specified path.

load(path, end_of_epoch=False, device=None)[source]

Loads the needed information.

class speechbrain.nnet.schedulers.CyclicCosineScheduler(n_warmup_steps, lr_initial=None, total_steps=100000)[source]

Bases: object

This is an implementation of the Cyclic-Cosine learning rate scheduler with warmup.

Reference: https://openreview.net/pdf?id=BJYwwY9ll

Note: this scheduler anneals the lr at each update of the model's weights, and n_steps must be saved for restarting.

Parameters:
  • lr_initial (float) – Initial learning rate (i.e. the lr used at epoch 0).

  • n_warmup_steps (int) – Number of warm up steps.

  • total_steps (int) – Total number of updating steps.

Example

>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1,660,3])
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> output = model(inp_tensor)
>>> scheduler = CyclicCosineScheduler(3, optim.param_groups[0]["lr"])
>>> curr_lr,next_lr=scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.9999999990130395
>>> curr_lr,next_lr=scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.9999999997532598
>>> curr_lr,next_lr=scheduler(optim)
>>> optim.param_groups[0]["lr"]
1.0
__call__(opt)[source]
Parameters:

opt (optimizer) – The optimizer to update using this scheduler.

Returns:

  • current_lr (float) – The learning rate before the update.

  • lr (float) – The learning rate after the update.

save(path)[source]

Saves the current metrics on the specified path.

load(path, end_of_epoch=False)[source]

Loads the needed information.

class speechbrain.nnet.schedulers.ReduceLROnPlateau(lr_min=1e-08, factor=0.5, patience=2, dont_halve_until_epoch=65)[source]

Bases: object

Learning rate scheduler which decreases the learning rate if the loss function of interest gets stuck on a plateau, or starts to increase. The difference from NewBobScheduler is that this one keeps a memory of the last step at which no improvement was observed, and compares against that particular loss value rather than the most recent one.

Parameters:
  • lr_min (float) – The minimum allowable learning rate.

  • factor (float) – Factor with which to reduce the learning rate.

  • patience (int) – How many epochs to wait before reducing the learning rate.

  • dont_halve_until_epoch (int) – The epoch before which the learning rate is not reduced.

Example

>>> from torch.optim import Adam
>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1,660,3])
>>> model = Linear(n_neurons=10, input_size=3)
>>> optim = Adam(lr=1.0, params=model.parameters())
>>> output = model(inp_tensor)
>>> scheduler = ReduceLROnPlateau(0.25, 0.5, 2, 1)
>>> curr_lr,next_lr=scheduler([optim],current_epoch=1, current_loss=10.0)
>>> curr_lr,next_lr=scheduler([optim],current_epoch=2, current_loss=11.0)
>>> curr_lr,next_lr=scheduler([optim],current_epoch=3, current_loss=13.0)
>>> curr_lr,next_lr=scheduler([optim],current_epoch=4, current_loss=14.0)
>>> next_lr
0.5
__call__(optim_list, current_epoch, current_loss)[source]
Parameters:
  • optim_list (list of optimizers) – The optimizers to update using this scheduler.

  • current_epoch (int) – Number of times the dataset has been iterated.

  • current_loss (int) – A number for determining whether to change the learning rate.

Returns:

  • current_lr (float) – The learning rate before the update.

  • next_lr (float) – The learning rate after the update.

save(path)[source]

Saves the current metrics on the specified path.

load(path, end_of_epoch=False)[source]

Loads the needed information.

class speechbrain.nnet.schedulers.CyclicLRScheduler(base_lr=0.001, max_lr=0.006, step_size=2000.0, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle')[source]

Bases: object

This implements a cyclical learning rate policy (CLR). The method cycles the learning rate between two boundaries with some constant frequency, as detailed in this paper (https://arxiv.org/abs/1506.01186). The amplitude of the cycle can be scaled on a per-iteration or per-cycle basis.

This class has three built-in policies, as put forth in the paper:

  • "triangular": A basic triangular cycle with no amplitude scaling.

  • "triangular2": A basic triangular cycle that scales the initial amplitude by half each cycle.

  • "exp_range": A cycle that scales the initial amplitude by gamma**(cycle iterations) at each cycle iteration.

For more detail, please see the reference paper.

Parameters:
  • base_lr (float) – initial learning rate which is the lower boundary in the cycle.

  • max_lr (float) – upper boundary in the cycle. Functionally, it defines the cycle amplitude (max_lr - base_lr). The lr at any cycle is the sum of base_lr and some scaling of the amplitude; therefore max_lr may not actually be reached depending on scaling function.

  • step_size (int) – Number of training iterations per half cycle. The authors suggest setting step_size to 2-8 times the number of training iterations per epoch.

  • mode (str) – one of {triangular, triangular2, exp_range}. Default ‘triangular’. Values correspond to policies detailed above. If scale_fn is not None, this argument is ignored.

  • gamma (float) – constant in ‘exp_range’ scaling function: gamma**(cycle iterations)

  • scale_fn (lambda function) – Custom scaling policy defined by a single-argument lambda function, where 0 <= scale_fn(x) <= 1 for all x >= 0. If provided, the mode parameter is ignored.

  • scale_mode (str) – {‘cycle’, ‘iterations’}. Defines whether scale_fn is evaluated on cycle number or cycle iterations (training iterations since start of cycle). Default is ‘cycle’.

Example

>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1,660,3])
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> output = model(inp_tensor)
>>> scheduler = CyclicLRScheduler(base_lr=0.1, max_lr=0.3, step_size=2)
>>> scheduler.on_batch_end(optim)
>>> optim.param_groups[0]["lr"]
0.2
>>> scheduler.on_batch_end(optim)
>>> optim.param_groups[0]["lr"]
0.3
>>> scheduler.on_batch_end(optim)
>>> optim.param_groups[0]["lr"]
0.2
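
A hedged sketch of supplying a custom scaling policy through scale_fn; this particular lambda is assumed to reproduce the built-in "triangular2" behaviour when scale_mode="cycle":

>>> scheduler = CyclicLRScheduler(
...     base_lr=0.1,
...     max_lr=0.3,
...     step_size=2,
...     scale_fn=lambda x: 1 / (2.0 ** (x - 1)),
...     scale_mode="cycle",
... )
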
clr(clr_iterations)[source]

Computes the cyclical learning rate for the given number of iterations.

on_batch_end(opt)[source]
Parameters:

opt (optimizers) – The optimizers to update using this scheduler.

save(path)[source]

Saves the current metrics on the specified path.

load(path, end_of_epoch=False)[source]

Loads the needed information.

class speechbrain.nnet.schedulers.IntervalScheduler(intervals)[source]

Bases: object

A simple scheduler implementation that sets the learning rate to specific values after a specific number of steps has been reached.

Parameters:

intervals (list) – A list of dictionaries of the form {"steps": <number of steps>, "lr": <learning rate>}, where "steps" indicates the global step count at which the given rate will apply.

Example

>>> import torch
>>> from speechbrain.nnet.schedulers import IntervalScheduler
>>> from speechbrain.nnet.linear import Linear
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> scheduler = IntervalScheduler(
...    intervals=[
...        {"steps": 2, "lr": 0.01},
...        {"steps": 5, "lr": 0.005},
...        {"steps": 9, "lr": 0.001}
...    ]
... )
>>> optim.param_groups[0]["lr"]
1
>>> for _ in range(10):
...     pre, post = scheduler(optim)
...     print(f"{pre} -> {post}")
1 -> 1
1 -> 0.01
0.01 -> 0.01
0.01 -> 0.01
0.01 -> 0.005
0.005 -> 0.005
0.005 -> 0.005
0.005 -> 0.005
0.005 -> 0.001
0.001 -> 0.001
__call__(opt)[source]
Parameters:

opt (optimizer) – The optimizer to update using this scheduler.

Returns:

  • current_lr (float) – The learning rate before the update.

  • lr (float) – The learning rate after the update.

save(path)[source]

Saves the current metrics on the specified path.

load(path, end_of_epoch=False)[source]

Loads the needed information.

class speechbrain.nnet.schedulers.InverseSquareRootScheduler(warmup_steps)[source]

Bases: object

The Inverse Square Root Scheduler, as defined in the T5 paper: https://arxiv.org/pdf/1910.10683.pdf

Parameters:

warmup_steps (int) – The number of steps over which the learning rate will be constant.
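
The source provides no doctest for this class; a minimal hedged usage sketch, assuming the same per-step calling convention as the other schedulers in this module:

>>> import torch
>>> from speechbrain.nnet.linear import Linear
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> scheduler = InverseSquareRootScheduler(warmup_steps=3)
>>> curr_lr, next_lr = scheduler(optim)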

__call__(opt)[source]

Returns the current and new hyperparameter value.

Parameters:

opt (optimizer) – The optimizer to update using this scheduler.

save(path)[source]

Saves the current metrics on the specified path.

class speechbrain.nnet.schedulers.WarmCoolDecayLRSchedule(lr, warmup, cooldown, total_steps, decay_factor=0.75, decay_every=100000)[source]

Bases: object

Warms up linearly, decays very slowly, and cools down linearly again at the end of training. This is a three-stage scheduler.

Reference

Scaling Vision Transformers: https://arxiv.org/abs/2106.04560

Parameters:
  • lr (float) – The max learning rate to reach after warmup.

  • warmup (int) – Number of warmup steps (following a linear increase).

  • cooldown (int) – Number of cooldown steps (following a linear decrease).

  • total_steps (int) – Total number of steps (used to decay).

  • decay_factor (float) – Decay factor applied every decay_every steps.

  • decay_every (int) – Apply the decay factor to the learning rate every decay_every steps.

Example

>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1,660,3])
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> output = model(inp_tensor)
>>> scheduler = WarmCoolDecayLRSchedule(lr=1, warmup=2, total_steps=6, decay_factor=0.5, decay_every=1, cooldown=1)
>>> optim.param_groups[0]["lr"]
1
>>> scheduler(optim, 1)
>>> optim.param_groups[0]["lr"]
0.5
>>> scheduler(optim, 2)
>>> optim.param_groups[0]["lr"]
1.0
>>> scheduler(optim, 3)
>>> optim.param_groups[0]["lr"]
0.5
>>> scheduler(optim, 4)
>>> optim.param_groups[0]["lr"]
0.25
>>> scheduler(optim, 5)
>>> optim.param_groups[0]["lr"]
0.12500000000000003
>>> scheduler(optim, 6)
>>> optim.param_groups[0]["lr"]
0.0
save(path)[source]

Saves the current metrics on the specified path.

load(path, end_of_epoch=False)[source]

Loads the needed information.

class speechbrain.nnet.schedulers.ScheduledLoss(schedule)[source]

Bases: Module

A convenience class for switching to a different loss function on a schedule.

Parameters:

schedule (list) – A list of dictionaries with the following keys:

  • loss_fn: the loss function to use

  • steps: the number of steps to apply it before switching to the next one

Example

>>> import torch
>>> from torch import nn
>>> loss_fn = ScheduledLoss(
...     schedule=[
...         {"steps": 3, "loss_fn": nn.MSELoss()},
...         {"steps": 2, "loss_fn": nn.L1Loss()},
...         {"loss_fn": nn.SmoothL1Loss()}
...     ]
... )
>>> x = torch.tensor([1., 2.])
>>> y = torch.tensor([1.5, 2.5])
>>> for idx in range(10):
...     loss = loss_fn(x, y)
...     print(loss.item())
0.25
0.25
0.25
0.5
0.5
0.125
0.125
0.125
0.125
0.125
forward(*args, **kwargs)[source]

Computes the loss at the specified step number. Any arguments passed to this will be passed on to the specified loss_fn.

Returns:

result – the loss value

Return type:

torch.Tensor

save(path)[source]

Saves the current state on the specified path.

load(path, end_of_epoch=False, device=None)[source]

Loads the needed information.

find_next_switch()[source]

Finds the threshold at which the next switch will occur, based on the schedule.

training: bool
class speechbrain.nnet.schedulers.TriStageLRSchedule(lr, warmup_steps, hold_steps, decay_steps, total_steps, init_lr_scale=0.01, final_lr_scale=0.05)[source]

Bases: object

Warms up linearly, decays very slowly, and cools down linearly again at the end of training. This is a three-stage scheduler. Reference: https://arxiv.org/pdf/1904.08779.pdf

Parameters:
  • lr (float) – The max learning rate to reach after warmup.

  • warmup_steps (int) – Number of warmup steps (following a linear increase).

  • hold_steps (int) – Number of holding steps (lr remains unchanged).

  • decay_steps (int) – Number of decay steps.

  • total_steps (int) – Total number of steps (used to decay).

  • init_lr_scale (float) – The initial learning rate scale during warmup phase.

  • final_lr_scale (float) – The final learning rate scale.

Example

>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1,660,3])
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> output = model(inp_tensor)
>>> scheduler = TriStageLRSchedule(lr=1, warmup_steps=2, hold_steps=2, decay_steps=2, total_steps=6, init_lr_scale=0.01, final_lr_scale=0.05)
>>> optim.param_groups[0]["lr"]
1
>>> scheduler(optim, 1)
>>> optim.param_groups[0]["lr"]
0.505
>>> scheduler(optim, 2)
>>> optim.param_groups[0]["lr"]
1
>>> scheduler(optim, 3)
>>> optim.param_groups[0]["lr"]
1
>>> scheduler(optim, 4)
>>> optim.param_groups[0]["lr"]
1.0
>>> scheduler(optim, 5)
>>> optim.param_groups[0]["lr"]
0.223606797749979
>>> scheduler(optim, 6)
>>> optim.param_groups[0]["lr"]
0.05000000000000001
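
The warmup value at step 1 in the doctest above is consistent with a linear interpolation from init_lr_scale * lr to lr over warmup_steps; a quick hedged check of that reading:

>>> init_lr = 0.01 * 1  # init_lr_scale * lr
>>> round(init_lr + (1 - init_lr) * 1 / 2, 3)  # step 1 of 2 warmup steps
0.505
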
__call__(opt, num_updates)[source]

Calculate the learning rate corresponding to the current step (num_updates).

save(path)[source]

Saves the current metrics on the specified path.

load(path, end_of_epoch=False, device=None)[source]

Loads the needed information.