speechbrain.nnet.schedulers module
Schedulers for updating hyperparameters (such as learning rate).
- Authors
Mirco Ravanelli 2020
Peter Plantinga 2020
Loren Lugosch 2020
Summary
Classes:
- CyclicCosineScheduler: An implementation of the Cyclic-Cosine learning rate scheduler with warmup.
- CyclicLRScheduler: Implements a cyclical learning rate policy (CLR).
- IntervalScheduler: A simple scheduler that sets the learning rate to specific values once a given number of steps has been reached.
- InverseSquareRootScheduler: The Inverse Square Root Scheduler, as defined in the T5 paper.
- LinearScheduler: Scheduler with linear annealing technique.
- LinearWarmupScheduler: A schedule whose learning rate decreases linearly to 0 after increasing linearly during a warmup period.
- NewBobScheduler: Scheduler with new-bob technique, used for LR annealing.
- NoamIntervalScheduler: A combination of the Noam Scheduler and the Interval Scheduler.
- NoamScheduler: An implementation of the transformer's learning rate scheduler with warmup.
- ReduceLROnPlateau: Learning rate scheduler that decreases the learning rate when the loss of interest gets stuck on a plateau or starts to increase.
- ScheduledLoss: A convenience class for switching to a different loss function on a schedule.
- StepScheduler: Learning rate scheduler with step annealing technique.
- WarmCoolDecayLRSchedule: Warms up linearly, decays slowly, and cools down linearly again at the end of training.
Functions:
- update_learning_rate: Change the learning rate value within an optimizer.
Reference
- speechbrain.nnet.schedulers.update_learning_rate(optimizer, new_lr, param_group=None)[source]
Change the learning rate value within an optimizer.
- Parameters:
optimizer (torch.optim object) – The optimizer whose learning rate should be updated.
new_lr (float) – The value to use for the new learning rate.
param_group (list of int) – The indices of the parameter groups to update. If not provided, all groups are updated.
Example
>>> from torch.optim import SGD
>>> from speechbrain.nnet.linear import Linear
>>> model = Linear(n_neurons=10, input_size=10)
>>> optimizer = SGD(model.parameters(), lr=0.1)
>>> update_learning_rate(optimizer, 0.2)
>>> optimizer.param_groups[0]["lr"]
0.2
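Continuing the example above, the param_group argument can restrict the update to specific parameter groups. The snippet below assumes that param_group takes a list of group indices (consistent with the default of None meaning "all groups"); it is an illustrative sketch, not an additional library-provided doctest.
>>> # Assumption: param_group is a list of parameter-group indices to update.
>>> update_learning_rate(optimizer, 0.05, param_group=[0])
>>> optimizer.param_groups[0]["lr"]
0.05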
- class speechbrain.nnet.schedulers.NewBobScheduler(initial_value, annealing_factor=0.5, improvement_threshold=0.0025, patient=0)[source]
Bases:
object
Scheduler with new-bob technique, used for LR annealing.
The learning rate is annealed based on the validation performance. In particular: if (past_loss - current_loss) / past_loss < improvement_threshold, then lr = lr * annealing_factor.
- Parameters:
initial_value (float) – The initial hyperparameter value.
annealing_factor (float) – The annealing factor used in the new-bob strategy.
improvement_threshold (float) – The relative improvement between losses below which annealing is applied in the new-bob strategy.
patient (int) – When the annealing condition is violated patient times, the learning rate is finally reduced.
Example
>>> scheduler = NewBobScheduler(initial_value=1.0)
>>> scheduler(metric_value=10.0)
(1.0, 1.0)
>>> scheduler(metric_value=2.0)
(1.0, 1.0)
>>> scheduler(metric_value=2.5)
(1.0, 0.5)
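The patient argument defers annealing: only after the improvement condition has been violated patient additional times is the value reduced. A sketch of how this is expected to play out, with the outcomes stated in comments rather than asserted (values not verified against the library):
>>> scheduler = NewBobScheduler(initial_value=1.0, patient=1)
>>> _, next_value = scheduler(metric_value=10.0)  # first call: only records the loss
>>> _, next_value = scheduler(metric_value=12.0)  # no improvement: patience is consumed, value kept at 1.0
>>> _, next_value = scheduler(metric_value=13.0)  # no improvement again: value expected to be halved to 0.5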
- class speechbrain.nnet.schedulers.LinearScheduler(initial_value, final_value, epoch_count)[source]
Bases:
object
Scheduler with linear annealing technique.
The learning rate linearly decays over the specified number of epochs.
- Parameters:
initial_value (float) – The value upon initialization.
final_value (float) – The value to reach at the final epoch.
epoch_count (int) – The number of epochs over which the value is annealed.
Example
>>> scheduler = LinearScheduler(1.0, 0.0, 4)
>>> scheduler(current_epoch=1)
(1.0, 0.666...)
>>> scheduler(current_epoch=2)
(0.666..., 0.333...)
>>> scheduler(current_epoch=3)
(0.333..., 0.0)
>>> scheduler(current_epoch=4)
(0.0, 0.0)
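The schedule is a plain linear interpolation between initial_value and final_value over epoch_count points, so the values in the example above can be reproduced with a one-liner (an illustrative equivalence, not the library's implementation):
>>> import torch
>>> torch.linspace(1.0, 0.0, 4)
tensor([1.0000, 0.6667, 0.3333, 0.0000])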
- class speechbrain.nnet.schedulers.LinearWarmupScheduler(initial_value, num_warmup_steps, num_training_steps)[source]
Bases:
object
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
- Authors
Ge Li 2022
- Parameters:
initial_value (float) – The initial learning rate (i.e. the peak value reached at the end of warmup).
num_warmup_steps (int) – The number of warmup steps during which the value increases linearly from 0.
num_training_steps (int) – The total number of training steps.
Example
>>> scheduler = LinearWarmupScheduler(1.0, 2, 4)
>>> scheduler.get_next_value()
0.0
>>> scheduler.get_next_value()
0.5
>>> scheduler.get_next_value()
1.0
>>> scheduler.get_next_value()
0.5
>>> scheduler.get_next_value()
0.0
- class speechbrain.nnet.schedulers.StepScheduler(initial_value, decay_factor=None, decay_drop=None, half_life=None)[source]
Bases:
object
Learning rate scheduler with step annealing technique.
The hyperparameter's value decays over the epochs with the selected decay_factor:
value = init_value * decay_factor ^ floor((1 + epoch) / decay_drop)
- Parameters:
initial_value (float) – Initial value for the hyperparameter being updated.
decay_factor (float) – Factor multiplied with the initial_value
decay_drop (float) – Number of epochs between consecutive decays (larger decay_drop values make the annealing slower).
half_life (int) – A convenience parameter that sets decay_factor such that the hyperparameter drops to half of its value at the specified epoch. May not be used together with decay_factor or decay_drop.
Example
>>> scheduler = StepScheduler(initial_value=1.0)
>>> scheduler(current_epoch=1)
(1.0, 0.5)
>>> scheduler(current_epoch=2)
(0.5, 0.5)
>>> scheduler(current_epoch=3)
(0.5, 0.25)
- DEFAULT_DECAY_FACTOR = 0.5
- DEFAULT_DECAY_DROP = 2
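The annealing formula given above can be checked numerically; the following sketch (illustrative only, not the class itself) reproduces the values printed by the doctest using the default decay_factor=0.5 and decay_drop=2:
>>> import math
>>> def step_value(init_value, epoch, decay_factor=0.5, decay_drop=2):
...     return init_value * decay_factor ** math.floor((1 + epoch) / decay_drop)
>>> [step_value(1.0, epoch) for epoch in (1, 2, 3)]
[0.5, 0.5, 0.25]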
- class speechbrain.nnet.schedulers.NoamScheduler(lr_initial, n_warmup_steps, model_size=None)[source]
Bases:
object
This is an implementation of the transformer's learning rate scheduler with warmup. Reference: https://arxiv.org/abs/1706.03762
Note: this scheduler anneals the lr at each update of the model's weights, and n_steps must be saved for restarting.
- Parameters:
lr_initial (float) – Initial learning rate (i.e. the lr used at epoch 0).
n_warmup_steps (int) – Number of warm-up steps.
model_size (int) – Size of the transformer embed_dim. It is used to scale the maximum learning rate value reached by the scheduler: that value is divided by model_size ** 0.5. If not specified, the maximum learning rate value is instead multiplied by n_warmup_steps ** 0.5.
Example
>>> import torch
>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1, 660, 3])
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> output = model(inp_tensor)
>>> scheduler = NoamScheduler(optim.param_groups[0]["lr"], 3)
>>> curr_lr, next_lr = scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.3333333333333333
>>> curr_lr, next_lr = scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.6666666666666666
>>> curr_lr, next_lr = scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.9999999999999999
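The rule behind these numbers is the standard Noam schedule: a linear warmup followed by an inverse square-root decay, scaled so that the peak equals lr_initial when model_size is not given. A rough reconstruction of that rule (inferred from the example above, not the library's code):
>>> def noam_lr(lr_initial, step, n_warmup_steps):
...     # Linear warmup up to n_warmup_steps, then ~1/sqrt(step) decay,
...     # scaled by sqrt(n_warmup_steps) so the peak equals lr_initial.
...     return lr_initial * n_warmup_steps ** 0.5 * min(step ** -0.5, step * n_warmup_steps ** -1.5)
>>> [round(noam_lr(1.0, step, 3), 4) for step in (1, 2, 3)]
[0.3333, 0.6667, 1.0]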
- class speechbrain.nnet.schedulers.NoamIntervalScheduler(lr_initial, n_warmup_steps, anneal_steps, anneal_rates, model_size=None)[source]
Bases:
object
A combination of the Noam Scheduler and the Interval Scheduler. The scheduler behaves as a Noam Scheduler and anneals the learning rate at the designated steps with the designated decay rates.
Note: this scheduler anneals the lr at each update of the model's weights, and n_steps must be saved for restarting.
- Parameters:
lr_initial (float) – Initial learning rate (i.e. the lr used at epoch 0).
n_warmup_steps (int) – Number of warm-up steps.
anneal_steps (list) – Pre-designed steps where the learning rate is to be annealed.
anneal_rates (list) – Pre-designed decay rate for each anneal step.
model_size (int) – Size of the transformer embed_dim. It is used to scale the maximum learning rate value reached by the scheduler: that value is divided by model_size ** 0.5. If not specified, the maximum learning rate value is instead multiplied by n_warmup_steps ** 0.5.
Example
>>> import torch
>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1, 660, 3])
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> output = model(inp_tensor)
>>> scheduler = NoamIntervalScheduler(
...     lr_initial=optim.param_groups[0]["lr"],
...     n_warmup_steps=3,
...     anneal_steps=[6, 9],
...     anneal_rates=[0.5, 0.1],
... )
>>> for _ in range(10):
...     curr_lr, next_lr = scheduler(optim)
...     print(optim.param_groups[0]["lr"])
0.3333333333333333
0.6666666666666666
0.9999999999999999
0.8660254037844386
0.7745966692414833
0.7071067811865475
0.3273268353539886
0.3061862178478973
0.28867513459481287
0.027386127875258306
- class speechbrain.nnet.schedulers.CyclicCosineScheduler(n_warmup_steps, lr_initial=None, total_steps=100000)[source]
Bases:
object
This is an implementation of the Cyclic-Cosine learning rate scheduler with warmup.
Reference: https://openreview.net/pdf?id=BJYwwY9ll
Note: this scheduler anneals the lr at each update of the model's weights, and n_steps must be saved for restarting.
- Parameters:
n_warmup_steps (int) – Number of warm-up steps.
lr_initial (float) – Initial learning rate (i.e. the lr used at epoch 0).
total_steps (int) – Total number of training steps.
Example
>>> import torch
>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1, 660, 3])
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> output = model(inp_tensor)
>>> scheduler = CyclicCosineScheduler(3, optim.param_groups[0]["lr"])
>>> curr_lr, next_lr = scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.9999999990130395
>>> curr_lr, next_lr = scheduler(optim)
>>> optim.param_groups[0]["lr"]
0.9999999997532598
>>> curr_lr, next_lr = scheduler(optim)
>>> optim.param_groups[0]["lr"]
1.0
- class speechbrain.nnet.schedulers.ReduceLROnPlateau(lr_min=1e-08, factor=0.5, patience=2, dont_halve_until_epoch=65)[source]
Bases:
object
Learning rate scheduler which decreases the learning rate if the loss function of interest gets stuck on a plateau, or starts to increase. The difference from NewBobScheduler is that this one keeps a memory of the last step at which no improvement was observed, and compares against that particular loss value rather than against the most recent loss.
- Parameters:
lr_min (float) – The minimum allowable learning rate.
factor (float) – The factor with which the learning rate is multiplied when it is reduced.
patience (int) – The number of epochs without improvement to wait before reducing the learning rate.
dont_halve_until_epoch (int) – The learning rate is not reduced before this epoch.
Example
>>> import torch
>>> from torch.optim import Adam
>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1, 660, 3])
>>> model = Linear(n_neurons=10, input_size=3)
>>> optim = Adam(lr=1.0, params=model.parameters())
>>> output = model(inp_tensor)
>>> scheduler = ReduceLROnPlateau(0.25, 0.5, 2, 1)
>>> curr_lr, next_lr = scheduler([optim], current_epoch=1, current_loss=10.0)
>>> curr_lr, next_lr = scheduler([optim], current_epoch=2, current_loss=11.0)
>>> curr_lr, next_lr = scheduler([optim], current_epoch=3, current_loss=13.0)
>>> curr_lr, next_lr = scheduler([optim], current_epoch=4, current_loss=14.0)
>>> next_lr
0.5
- class speechbrain.nnet.schedulers.CyclicLRScheduler(base_lr=0.001, max_lr=0.006, step_size=2000.0, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle')[source]
Bases:
object
This implements a cyclical learning rate policy (CLR). The method cycles the learning rate between two boundaries with some constant frequency, as detailed in this paper (https://arxiv.org/abs/1506.01186). The amplitude of the cycle can be scaled on a per-iteration or per-cycle basis.
This class has three built-in policies, as put forth in the paper:
- "triangular": A basic triangular cycle with no amplitude scaling.
- "triangular2": A basic triangular cycle that scales the initial amplitude by half each cycle.
- "exp_range": A cycle that scales the initial amplitude by gamma**(cycle iterations) at each cycle iteration.
For more detail, please see the reference paper.
- Parameters:
base_lr (float) – initial learning rate which is the lower boundary in the cycle.
max_lr (float) – upper boundary in the cycle. Functionally, it defines the cycle amplitude (max_lr - base_lr). The lr at any cycle is the sum of base_lr and some scaling of the amplitude; therefore max_lr may not actually be reached depending on scaling function.
step_size (int) – number of training iterations per half cycle. The authors suggest setting step_size to 2-8 times the number of training iterations per epoch.
mode (str) – one of {triangular, triangular2, exp_range}. Default ‘triangular’. Values correspond to policies detailed above. If scale_fn is not None, this argument is ignored.
gamma (float) – constant in ‘exp_range’ scaling function: gamma**(cycle iterations)
scale_fn (lambda function) – Custom scaling policy defined by a single-argument lambda function, where 0 <= scale_fn(x) <= 1 for all x >= 0. If provided, the mode parameter is ignored.
scale_mode (str) – {‘cycle’, ‘iterations’}. Defines whether scale_fn is evaluated on cycle number or cycle iterations (training iterations since start of cycle). Default is ‘cycle’.
Example
>>> import torch
>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1, 660, 3])
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> output = model(inp_tensor)
>>> scheduler = CyclicLRScheduler(base_lr=0.1, max_lr=0.3, step_size=2)
>>> scheduler.on_batch_end(optim)
>>> optim.param_groups[0]["lr"]
0.2
>>> scheduler.on_batch_end(optim)
>>> optim.param_groups[0]["lr"]
0.3
>>> scheduler.on_batch_end(optim)
>>> optim.param_groups[0]["lr"]
0.2
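Beyond the built-in modes, the documented scale_fn and scale_mode parameters allow custom amplitude policies. For instance, the "triangular2" behavior can be expressed explicitly with a per-cycle scaling function (an illustrative construction; no outputs are asserted):
>>> clr_builtin = CyclicLRScheduler(base_lr=0.001, max_lr=0.006, step_size=2000, mode="triangular2")
>>> clr_custom = CyclicLRScheduler(
...     base_lr=0.001, max_lr=0.006, step_size=2000,
...     scale_fn=lambda cycle: 1 / (2 ** (cycle - 1)),  # halve the amplitude every cycle
...     scale_mode="cycle",
... )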
- class speechbrain.nnet.schedulers.IntervalScheduler(intervals)[source]
Bases:
object
A simple scheduler implementation that sets the learning rate to specific values after a specific number of steps has been reached.
- Parameters:
intervals (list) – A list of dictionaries of the form {"steps": <number of steps>, "lr": <the learning rate>}, where 'steps' indicates the global step count at which a given rate will apply.
Example
>>> import torch
>>> from speechbrain.nnet.schedulers import IntervalScheduler
>>> from speechbrain.nnet.linear import Linear
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> scheduler = IntervalScheduler(
...     intervals=[
...         {"steps": 2, "lr": 0.01},
...         {"steps": 5, "lr": 0.005},
...         {"steps": 9, "lr": 0.001}
...     ]
... )
>>> optim.param_groups[0]["lr"]
1
>>> for _ in range(10):
...     pre, post = scheduler(optim)
...     print(f"{pre} -> {post}")
1 -> 1
1 -> 0.01
0.01 -> 0.01
0.01 -> 0.01
0.01 -> 0.005
0.005 -> 0.005
0.005 -> 0.005
0.005 -> 0.005
0.005 -> 0.001
0.001 -> 0.001
- class speechbrain.nnet.schedulers.InverseSquareRootScheduler(warmup_steps)[source]
Bases:
object
The Inverse Square Root Scheduler, as defined in the T5 paper (https://arxiv.org/pdf/1910.10683.pdf).
- Parameters:
warmup_steps (int) – The number of steps over which the learning rate will be constant.
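No doctest accompanies this class. As a rough illustration of the T5-style rule referenced above (a constant factor during warmup, then an inverse square-root decay), a standalone sketch of the schedule factor could look like the following; this is a reconstruction under that assumption, not the class's implementation:
>>> def inverse_sqrt_factor(step, warmup_steps):
...     # Constant during warmup, then decays as 1/sqrt(step) (illustrative only).
...     return 1.0 / max(step, warmup_steps) ** 0.5
>>> inverse_sqrt_factor(50, warmup_steps=100) == inverse_sqrt_factor(100, warmup_steps=100)
True
>>> inverse_sqrt_factor(400, warmup_steps=100)
0.05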
- class speechbrain.nnet.schedulers.WarmCoolDecayLRSchedule(lr, warmup, cooldown, total_steps, decay_factor=0.75, decay_every=100000)[source]
Bases:
object
Warms up linearly, then decays very slowly, and finally cools down linearly again at the end of training. This is a three-stage scheduler.
Reference: Scaling Vision Transformers, https://arxiv.org/abs/2106.04560
- Parameters:
lr (float) – The max learning rate to reach after warmup.
warmup (int) – Number of warmup steps (following a linear increase).
cooldown (int) – Number of cooldown steps (following a linear decrease).
total_steps (int) – Total number of steps (used to decay).
decay_factor (float) – Decay factor applied every decay_every steps.
decay_every (int) – Apply the decay factor to the learning rate every decay_every steps.
Example
>>> import torch
>>> from speechbrain.nnet.linear import Linear
>>> inp_tensor = torch.rand([1, 660, 3])
>>> model = Linear(input_size=3, n_neurons=4)
>>> optim = torch.optim.Adam(model.parameters(), lr=1)
>>> output = model(inp_tensor)
>>> scheduler = WarmCoolDecayLRSchedule(lr=1, warmup=2, total_steps=6, decay_factor=0.5, decay_every=1, cooldown=1)
>>> optim.param_groups[0]["lr"]
1
>>> scheduler(optim, 1)
>>> optim.param_groups[0]["lr"]
0.5
>>> scheduler(optim, 2)
>>> optim.param_groups[0]["lr"]
1.0
>>> scheduler(optim, 3)
>>> optim.param_groups[0]["lr"]
0.5
>>> scheduler(optim, 4)
>>> optim.param_groups[0]["lr"]
0.25
>>> scheduler(optim, 5)
>>> optim.param_groups[0]["lr"]
0.12500000000000003
>>> scheduler(optim, 6)
>>> optim.param_groups[0]["lr"]
0.0
- class speechbrain.nnet.schedulers.ScheduledLoss(schedule)[source]
Bases:
Module
A convenience class for switching to a different loss function on a schedule.
- Parameters:
schedule (list) – A list of dictionaries with the following keys: loss_fn (the loss function to use) and steps (the number of steps to apply it before switching to the next one).
Example
>>> import torch
>>> import torch.nn as nn
>>> loss_fn = ScheduledLoss(
...     schedule=[
...         {"steps": 3, "loss_fn": nn.MSELoss()},
...         {"steps": 2, "loss_fn": nn.L1Loss()},
...         {"loss_fn": nn.SmoothL1Loss()}
...     ]
... )
>>> x = torch.tensor([1., 2.])
>>> y = torch.tensor([1.5, 2.5])
>>> for idx in range(10):
...     loss = loss_fn(x, y)
...     print(loss.item())
0.25
0.25
0.25
0.5
0.5
0.125
0.125
0.125
0.125
0.125