allennlp.training.trainer

class allennlp.training.trainer.Trainer(model: allennlp.models.model.Model, optimizer: torch.optim.optimizer.Optimizer, iterator: allennlp.data.iterators.data_iterator.DataIterator, train_dataset: Iterable[allennlp.data.instance.Instance], validation_dataset: Optional[Iterable[allennlp.data.instance.Instance]] = None, patience: Optional[int] = None, validation_metric: str = '-loss', validation_iterator: allennlp.data.iterators.data_iterator.DataIterator = None, shuffle: bool = True, num_epochs: int = 20, serialization_dir: Optional[str] = None, num_serialized_models_to_keep: int = 20, keep_serialized_model_every_num_seconds: int = None, checkpointer: allennlp.training.checkpointer.Checkpointer = None, model_save_interval: float = None, cuda_device: Union[int, List] = -1, grad_norm: Optional[float] = None, grad_clipping: Optional[float] = None, learning_rate_scheduler: Optional[allennlp.training.learning_rate_schedulers.learning_rate_scheduler.LearningRateScheduler] = None, momentum_scheduler: Optional[allennlp.training.momentum_schedulers.momentum_scheduler.MomentumScheduler] = None, summary_interval: int = 100, histogram_interval: int = None, should_log_parameter_statistics: bool = True, should_log_learning_rate: bool = False, log_batch_size_period: Optional[int] = None, moving_average: Optional[allennlp.training.moving_average.MovingAverage] = None)

Bases: allennlp.training.trainer_base.TrainerBase
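The constructor's validation_metric argument carries a direction prefix: "-loss" (the default) means lower is better, while a "+" prefix means higher is better. Together with patience, it decides when training stops early. A minimal sketch of that bookkeeping follows. MetricTrackerSketch is a hypothetical name, and the rules it encodes are assumptions about the behaviour, not a copy of AllenNLP's code: a strict improvement resets the counter, training stops once patience epochs pass without improvement, and patience=None disables early stopping.

```python
from typing import Optional


class MetricTrackerSketch:
    """Hypothetical early-stopping helper; not AllenNLP's own class."""

    def __init__(self, validation_metric: str = "-loss", patience: Optional[int] = None) -> None:
        if not validation_metric or validation_metric[0] not in "+-":
            raise ValueError("validation_metric must start with '+' or '-'")
        # '-' means smaller values are better, '+' means larger values are better.
        self.should_decrease = validation_metric[0] == "-"
        self.metric_name = validation_metric[1:]
        self.patience = patience
        self.best: Optional[float] = None
        self.epochs_without_improvement = 0

    def add_metric(self, value: float) -> bool:
        """Record one epoch's validation value; return True if it is a new best."""
        improved = self.best is None or (
            value < self.best if self.should_decrease else value > self.best
        )
        if improved:
            self.best = value
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return improved

    def should_stop_early(self) -> bool:
        # With patience=None, training never stops early.
        return self.patience is not None and self.epochs_without_improvement >= self.patience


tracker = MetricTrackerSketch("-loss", patience=2)
for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.75]):
    tracker.add_metric(val_loss)
    if tracker.should_stop_early():
        print(f"stopping after epoch {epoch}")  # prints "stopping after epoch 3"
        break
```

In the usage above, the best loss (0.7) arrives at epoch 1, and two further epochs without a lower value exhaust a patience of 2.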

batch_loss(self, batch_group: List[Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]], for_training: bool) → torch.Tensor

Does a forward pass on the given batches (batch_group holds one batch per GPU) and returns the value stored under the "loss" key of the model's output. If for_training is True, the model's regularization penalty is also added to the loss.
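The contract of batch_loss can be restated in plain Python as a sketch. Here a single batch dict replaces batch_group, floats stand in for tensors, and the callables forward and regularization_penalty play the roles of the model's forward pass and its regularization term; all of these names are hypothetical. Device placement and multi-GPU handling are omitted. The behaviour when the output has no "loss" key (raise while training, return None otherwise) is an assumption of this sketch.

```python
from typing import Any, Callable, Dict, Optional


def batch_loss_sketch(
    forward: Callable[..., Dict[str, Any]],
    regularization_penalty: Callable[[], float],
    batch: Dict[str, Any],
    for_training: bool,
) -> Optional[float]:
    """Hypothetical restatement of batch_loss for one batch of plain floats."""
    output_dict = forward(**batch)  # forward pass; the model returns a dict
    if "loss" not in output_dict:
        if for_training:
            # Training needs a loss to backpropagate through.
            raise RuntimeError("model output has no 'loss' key")
        return None
    loss = output_dict["loss"]
    if for_training:
        # The penalty only affects the training objective, not evaluation.
        loss = loss + regularization_penalty()
    return loss


def toy_forward(x: float, y: float) -> Dict[str, float]:
    """Hypothetical model: squared error between x and y."""
    return {"loss": (x - y) ** 2}


eval_loss = batch_loss_sketch(toy_forward, lambda: 0.5, {"x": 3.0, "y": 1.0}, for_training=False)
train_loss = batch_loss_sketch(toy_forward, lambda: 0.5, {"x": 3.0, "y": 1.0}, for_training=True)
```

With these inputs, eval_loss is 4.0 and train_loss is 4.5, because only the training call adds the 0.5 penalty.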

classmethod from_params(model: allennlp.models.model.Model, serialization_dir: str, iterator: allennlp.data.iterators.data_iterator.DataIterator, train_data: Iterable[allennlp.data.instance.Instance], validation_data: Union[Iterable[allennlp.data.instance.Instance], NoneType], params: allennlp.common.params.Params, validation_iterator: allennlp.data.iterators.data_iterator.DataIterator = None) → 'Trainer'

This is the automatic implementation of from_params. Any class that subclasses FromParams (or Registrable, which itself subclasses FromParams) gets this implementation for free. If you want your class to be instantiated from params in the “obvious” way (pop off parameters and hand them to your constructor under the same names), this provides that functionality.

If you need more complex logic in your from_params method, you’ll have to implement your own method that overrides this one.
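The “obvious” behaviour described above (pop each constructor argument out of the parameters by name and call the constructor) can be sketched with Python's standard inspect module. This is a simplified illustration, not AllenNLP's implementation: a plain dict stands in for allennlp.common.params.Params, and the names construct_from_params_sketch and ToyConfig are hypothetical. The real method also uses type annotations to build nested FromParams arguments and registered subclasses, which this sketch does not model.

```python
import inspect
from typing import Any, Dict, Optional


def construct_from_params_sketch(cls: type, params: Dict[str, Any]) -> Any:
    """Pop constructor arguments out of `params` by name and pass them to `cls`."""
    kwargs: Dict[str, Any] = {}
    for name, parameter in inspect.signature(cls.__init__).parameters.items():
        # Skip `self` and any *args / **kwargs catch-alls.
        if name == "self" or parameter.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue
        if name in params:
            kwargs[name] = params.pop(name)
        elif parameter.default is inspect.Parameter.empty:
            raise KeyError(f"missing required parameter {name!r}")
        # Otherwise leave it out so the constructor's default applies.
    if params:
        # Anything left over did not match a constructor argument.
        raise ValueError(f"unexpected parameters: {sorted(params)}")
    return cls(**kwargs)


class ToyConfig:
    """Hypothetical class used only to exercise the sketch."""

    def __init__(self, num_epochs: int = 20, patience: Optional[int] = None) -> None:
        self.num_epochs = num_epochs
        self.patience = patience


config = construct_from_params_sketch(ToyConfig, {"num_epochs": 5})
```

The sketch mutates the dict it receives, matching the “pop off parameters” wording, and treats leftover keys as an error so that a misspelled option is not silently ignored.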

rescale_gradients(self) → Union[float, NoneType]

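This entry documents only the signature. Assume that rescale_gradients applies the constructor's grad_norm argument as a cap on the global L2 norm of all parameter gradients, in the style of PyTorch's torch.nn.utils.clip_grad_norm_, and returns the norm measured before scaling, or None when grad_norm is not set. Under that assumption, the arithmetic looks like the sketch below. Nested lists of floats stand in for the parameters' gradient tensors, and the function name is hypothetical.

```python
import math
from typing import List, Optional


def rescale_gradients_sketch(grads: List[List[float]], grad_norm: Optional[float]) -> Optional[float]:
    """Scale gradients in place so their global L2 norm is at most grad_norm.

    Returns the global norm measured before scaling, or None when rescaling is disabled.
    """
    if not grad_norm:
        return None  # no-op when grad_norm is unset (None or 0)
    # One norm across every parameter, not one norm per parameter.
    total_norm = math.sqrt(sum(g * g for param_grad in grads for g in param_grad))
    # Small epsilon guards against division by zero for all-zero gradients.
    clip_coef = grad_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for param_grad in grads:
            for i in range(len(param_grad)):
                param_grad[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]  # global norm is 5.0
norm = rescale_gradients_sketch(grads, grad_norm=1.0)
```

After the call, norm is 5.0 and the gradients are approximately [[0.6], [0.8]]. Scaling every gradient by the same coefficient preserves the update's direction while bounding its size.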
train(self) → Dict[str, Any]

Trains the supplied model with the supplied parameters and returns a dictionary of metrics.
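The overall shape of such a training run can be sketched as an epoch loop that collects per-epoch metrics into one flat dictionary. This is a sketch, not the library's code. The callables train_epoch and validate are hypothetical stand-ins for passes over the training and validation data, and the "best" epoch is chosen by lowest validation loss, matching the '-loss' default. Checkpointing, learning-rate and momentum schedulers, patience, and GPU handling are omitted. The key names (training_*, validation_*, best_epoch) are modelled loosely on the returned metrics and should be treated as illustrative.

```python
from typing import Any, Callable, Dict


def train_sketch(
    train_epoch: Callable[[int], Dict[str, float]],
    validate: Callable[[int], Dict[str, float]],
    num_epochs: int = 20,
) -> Dict[str, Any]:
    """Run num_epochs epochs and return a flat metrics dictionary."""
    metrics: Dict[str, Any] = {}
    best_loss = float("inf")
    for epoch in range(num_epochs):
        train_metrics = train_epoch(epoch)  # one pass over the training data
        val_metrics = validate(epoch)  # one pass over the validation data
        if val_metrics["loss"] < best_loss:  # '-loss': lower is better
            best_loss = val_metrics["loss"]
            metrics["best_epoch"] = epoch
            for key, value in val_metrics.items():
                metrics["best_validation_" + key] = value
        # Latest-epoch metrics overwrite the previous epoch's values.
        metrics["epoch"] = epoch
        for key, value in train_metrics.items():
            metrics["training_" + key] = value
        for key, value in val_metrics.items():
            metrics["validation_" + key] = value
    return metrics


val_losses = [0.5, 0.3, 0.4]
result = train_sketch(
    train_epoch=lambda epoch: {"loss": 1.0 / (epoch + 1)},
    validate=lambda epoch: {"loss": val_losses[epoch]},
    num_epochs=3,
)
```

Here result["best_epoch"] is 1 (validation loss 0.3), while result["validation_loss"] reports the final epoch's 0.4. Keeping both lets a caller compare the last model with the best one.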