allennlp.training.moving_average

class allennlp.training.moving_average.ExponentialMovingAverage(parameters: Iterable[Tuple[str, torch.Tensor]], decay: float = 0.9999, numerator: float = 1.0, denominator: float = 10.0)

Bases: allennlp.training.moving_average.MovingAverage

Creates shadow variables and maintains an exponential moving average of the model parameters.
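The idea of a "shadow variable" can be illustrated with a minimal pure-Python sketch. It assumes parameters are plain floats in a dict shared with the "model" (standing in for the `torch.Tensor`s the real class receives), and it assumes each shadow starts as a copy of the parameter's initial value. The class name `ShadowEMA` is hypothetical and not part of AllenNLP.

```python
from typing import Dict


class ShadowEMA:
    """Hypothetical tracker: one shadow value per named scalar parameter."""

    def __init__(self, params: Dict[str, float], decay: float = 0.9999) -> None:
        # Keep a reference to the live parameter dict, so later updates made
        # by "training" are visible when apply() runs.
        self.params = params
        self.decay = decay
        # Assumption: each shadow starts as a copy of the parameter's value.
        self.shadows = dict(params)

    def apply(self) -> None:
        # shadow <- decay * shadow + (1 - decay) * current value
        for name, value in self.params.items():
            self.shadows[name] = self.decay * self.shadows[name] + (1.0 - self.decay) * value


params = {"w": 0.0}
ema = ShadowEMA(params, decay=0.5)
params["w"] = 1.0  # pretend an optimizer step moved the weight
ema.apply()
print(ema.shadows["w"])  # 0.5
ema.apply()
print(ema.shadows["w"])  # 0.75
```

A small decay such as 0.5 is used only to keep the numbers readable; the documented default of 0.9999 makes the shadow move much more slowly than the live parameter.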

Parameters
parameters : Iterable[Tuple[str, Parameter]], required

The parameters whose averages we’ll be tracking.

decay : float, optional (default = 0.9999)

The decay rate that will be used if num_updates is not passed (and that will be used as an upper bound if num_updates is passed).

numerator : float, optional (default = 1.0)

The numerator used to compute the decay rate if num_updates is passed.

denominator : float, optional (default = 10.0)

The denominator used to compute the decay rate if num_updates is passed.

apply(self, num_updates: Union[int, NoneType] = None) → None

Update the exponential moving average of each parameter passed to the constructor (typically the model's trainable named parameters). Each shadow value is updated as shadow = decay * shadow + (1 - decay) * parameter.

The optional num_updates parameter lets the decay rate change over the course of training. If it is passed, the decay rate actually used is:

min(decay, (numerator + num_updates) / (denominator + num_updates))

(This logic is based on the TensorFlow exponential moving average: https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage)
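To see how num_updates shapes the decay rate, the formula above can be evaluated directly. This is a standalone sketch of the documented formula, using the constructor's default values; the function name `effective_decay` is hypothetical and not an AllenNLP API.

```python
from typing import Optional


def effective_decay(
    decay: float = 0.9999,
    numerator: float = 1.0,
    denominator: float = 10.0,
    num_updates: Optional[int] = None,
) -> float:
    """Decay rate per the formula documented for ExponentialMovingAverage.apply."""
    if num_updates is None:
        # Without num_updates, the fixed decay is used as-is.
        return decay
    # With num_updates, the ratio grows toward 1 and `decay` acts as an upper bound.
    return min(decay, (numerator + num_updates) / (denominator + num_updates))


for n in (None, 0, 90, 1_000_000):
    print(n, effective_decay(num_updates=n))
```

With the defaults, the rate starts low (0.1 at num_updates = 0, 0.91 at 90), so early averages follow the parameters closely; the 0.9999 cap only takes over once num_updates reaches roughly 90,000.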

class allennlp.training.moving_average.MovingAverage(parameters: Iterable[Tuple[str, torch.Tensor]])

Bases: allennlp.common.registrable.Registrable

Tracks a moving average of model parameters.

apply(self, num_updates: Union[int, NoneType] = None)

Update the moving averages based on the latest values of the parameters.

assign_average_value(self) → None

Replace all the parameter values with the averages. Save the current parameter values to restore later.

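default_implementation: str = 'exponential'

The name of the registered implementation used when no type is specified; here, ExponentialMovingAverage, which is registered as 'exponential'.

restore(self) → None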

Restore the backed-up (non-average) parameter values.
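A typical pattern is to call assign_average_value() before evaluating (for example, around a validation pass) and restore() afterwards, so training continues from the raw weights. The sketch below illustrates that contract in pure Python, assuming float parameters in a shared dict; `SwappableAverage` is a hypothetical name, not AllenNLP's implementation.

```python
from typing import Dict


class SwappableAverage:
    """Hypothetical sketch of the assign_average_value()/restore() contract."""

    def __init__(self, params: Dict[str, float], decay: float = 0.5) -> None:
        self.params = params          # live parameters, shared with the "model"
        self.decay = decay
        self.shadows = dict(params)   # running averages
        self.backups: Dict[str, float] = {}

    def apply(self) -> None:
        # shadow <- decay * shadow + (1 - decay) * current value
        for name, value in self.params.items():
            self.shadows[name] = self.decay * self.shadows[name] + (1.0 - self.decay) * value

    def assign_average_value(self) -> None:
        # Save the current values, then overwrite the live parameters with the averages.
        self.backups = dict(self.params)
        self.params.update(self.shadows)

    def restore(self) -> None:
        # Put the saved (non-averaged) values back.
        self.params.update(self.backups)


params = {"w": 0.0}
tracker = SwappableAverage(params)
params["w"] = 1.0               # pretend training moved the weight
tracker.apply()                 # shadow for "w" becomes 0.5
tracker.assign_average_value()
print(params["w"])              # evaluate with the averaged weight: 0.5
tracker.restore()
print(params["w"])              # back to the trained weight: 1.0
```

Backing up inside assign_average_value() rather than at construction time means restore() always returns the most recent trained values, not stale ones.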