cosine

allennlp.training.learning_rate_schedulers.cosine


CosineWithRestarts

@LearningRateScheduler.register("cosine")
class CosineWithRestarts(LearningRateScheduler):
 | def __init__(
 |     self,
 |     optimizer: torch.optim.Optimizer,
 |     t_initial: int,
 |     t_mul: float = 1.0,
 |     eta_min: float = 0.0,
 |     eta_mul: float = 1.0,
 |     last_epoch: int = -1
 | ) -> None

Cosine annealing with restarts.

This is described in the paper https://arxiv.org/abs/1608.03983. Note that early stopping should typically be avoided when using this schedule.

Registered as a LearningRateScheduler with name "cosine".

Parameters

  • optimizer : torch.optim.Optimizer
    This argument does not get an entry in a configuration file for the object.
  • t_initial : int
    The number of iterations (epochs) within the first cycle.
  • t_mul : float, optional (default = 1)
    Determines the number of iterations (epochs) in the i-th decay cycle, which is the length of the previous cycle multiplied by t_mul.
  • eta_min : float, optional (default = 0)
    The minimum learning rate.
  • eta_mul : float, optional (default = 1)
    Determines the initial learning rate for the i-th decay cycle, which is the initial learning rate of the previous cycle multiplied by eta_mul.
  • last_epoch : int, optional (default = -1)
    The index of the last epoch. This is used when restarting.
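
The way these parameters interact can be illustrated with a small, self-contained sketch. This is not the AllenNLP source: the function name cosine_with_restarts and its num_epochs argument are hypothetical, and truncating the cycle length to an integer (with a floor of 1) is an assumption made for this illustration. Within each cycle the rate follows a half cosine from the cycle's initial rate down toward eta_min; at each restart the cycle length is scaled by t_mul and the initial rate by eta_mul.

```python
import math


def cosine_with_restarts(base_lr, num_epochs, t_initial, t_mul=1.0, eta_min=0.0, eta_mul=1.0):
    """Return one learning rate per epoch (illustrative sketch, not the AllenNLP source)."""
    lrs = []
    n_restarts = 0          # number of restarts so far
    pos = 0                 # position within the current cycle
    cycle_len = t_initial   # length of the current cycle
    peak_lr = base_lr       # initial learning rate of the current cycle
    for _ in range(num_epochs):
        # Half-cosine factor: 1.0 at the start of a cycle, decaying toward 0.0.
        cos_factor = (1 + math.cos(math.pi * pos / cycle_len)) / 2
        lrs.append(eta_min + (peak_lr - eta_min) * cos_factor)
        pos += 1
        if pos >= cycle_len:
            # Restart: scale the cycle length by t_mul and the peak rate by eta_mul.
            n_restarts += 1
            pos = 0
            # Assumption: cycle lengths are truncated to whole epochs, minimum 1.
            cycle_len = max(1, int(t_initial * t_mul ** n_restarts))
            peak_lr = base_lr * eta_mul ** n_restarts
    return lrs


schedule = cosine_with_restarts(base_lr=1.0, num_epochs=6, t_initial=3, eta_mul=0.5)
print([round(lr, 3) for lr in schedule])
# [1.0, 0.75, 0.25, 0.5, 0.375, 0.125]
```

In this run the first cycle lasts three epochs and starts at 1.0; the second cycle restarts at 0.5 because eta_mul is 0.5.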

Example

Config for using the CosineWithRestarts Learning Rate Scheduler with the following arguments:

  • t_initial set to 5
  • t_mul set to 0.9
  • eta_min set to 1e-12
  • eta_mul set to 0.8
  • last_epoch set to 10

{
    ...
    "trainer": {
        ...
        "learning_rate_scheduler": {
            "type": "cosine",
            "t_initial": 5,
            "t_mul": 0.9,
            "eta_min": 1e-12,
            "eta_mul": 0.8,
            "last_epoch": 10
        },
        ...
    }
}
Note that you do NOT pass an optimizer key to the learning rate scheduler.

get_values

class CosineWithRestarts(LearningRateScheduler):
 | ...
 | def get_values(self)

Get updated learning rate.