Skip to content

trainer

allennlp.training.trainer

[SOURCE]


TrainerCheckpoint

class TrainerCheckpoint(NamedTuple)

model_state

class TrainerCheckpoint(NamedTuple):
 | ...
 | model_state: Dict[str, Any] = None

trainer_state

class TrainerCheckpoint(NamedTuple):
 | ...
 | trainer_state: Dict[str, Any] = None

Trainer

class Trainer(Registrable):
 | def __init__(
 |     self,
 |     serialization_dir: str = None,
 |     cuda_device: Optional[Union[int, torch.device]] = None,
 |     distributed: bool = False,
 |     local_rank: int = 0,
 |     world_size: int = 1
 | ) -> None

The base class for an AllenNLP trainer. It is deliberately general and places few constraints on how training is carried out. Subclasses should implement `train`, and will usually also need to implement `from_params`.

default_implementation

class Trainer(Registrable):
 | ...
 | default_implementation = "gradient_descent"

train

class Trainer(Registrable):
 | ...
 | def train(self) -> Dict[str, Any]

Train a model and return the results.

get_checkpoint_state

class Trainer(Registrable):
 | ...
 | def get_checkpoint_state(self) -> Optional[TrainerCheckpoint]

Returns a `TrainerCheckpoint`, a named tuple of (`model_state`, `trainer_state`), or `None`. The trainer state may contain several internal components (e.g., the state of an optimizer, a learning rate scheduler, etc.).

get_best_weights_path

class Trainer(Registrable):
 | ...
 | def get_best_weights_path(self) -> Optional[str]

Returns the path to the file containing the current best weights. The return type is `Optional[str]`, so this may be `None`.