[ allennlp.data.dataloader ]
TensorDict = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
def allennlp_collate(instances: List[Instance]) -> TensorDict
A DataLoader is responsible for generating batches of instances from a Dataset
or another source of data. This is essentially just an abstraction over torch.utils.data.DataLoader.
This class has only one required method,
__iter__(), which creates an iterable of
TensorDicts. Additionally, this class comes with a __len__() method
that just raises a
TypeError by default. When possible, this should be overridden
to return the number of batches that will be generated by the __iter__() method.
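The contract described above can be sketched without any AllenNLP or PyTorch dependency. In this minimal illustration, `SketchDataLoader` and `ListDataLoader` are hypothetical names, and batches are plain lists rather than TensorDicts; only the iteration and length behaviour is meant to mirror the description.

```python
from typing import Iterator, List


class SketchDataLoader:
    """Hypothetical base class: iteration is required, length is optional."""

    def __iter__(self) -> Iterator[List[int]]:
        raise NotImplementedError

    def __len__(self) -> int:
        # By default the number of batches is unknown.
        raise TypeError(f"{type(self).__name__} does not define a length")


class ListDataLoader(SketchDataLoader):
    """Hypothetical subclass that batches an in-memory list."""

    def __init__(self, items: List[int], batch_size: int) -> None:
        self.items = items
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[List[int]]:
        for start in range(0, len(self.items), self.batch_size):
            yield self.items[start:start + self.batch_size]

    def __len__(self) -> int:
        # Ceiling division: the final batch may be smaller than batch_size.
        return (len(self.items) + self.batch_size - 1) // self.batch_size


loader = ListDataLoader(list(range(10)), batch_size=4)
batches = list(loader)
```

Because the subclass knows the size of its data up front, it can report the number of batches; a loader over a stream of unknown length would keep the default and let `len()` fail with a `TypeError`.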
default_implementation = "pytorch_dataloader"
class PyTorchDataLoader(data.DataLoader, DataLoader):
    def __init__(
        self,
        dataset: data.Dataset,
        batch_size: int = 1,
        shuffle: bool = False,
        sampler: Sampler = None,
        batch_sampler: BatchSampler = None,
        num_workers: int = 0,
        collate_fn=allennlp_collate,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: int = 0,
        worker_init_fn=None,
        multiprocessing_context: str = None,
        batches_per_epoch: int = None
    )
A registrable version of the PyTorch DataLoader (torch.utils.data.DataLoader).
Firstly, this class exists so that we can construct a DataLoader
from a configuration file and have a different default collate_fn.
You can use this class directly in Python code, but doing so is identical to using the
PyTorch DataLoader with AllenNLP's custom collate function:
    from torch.utils.data import DataLoader
    from allennlp.data.dataloader import allennlp_collate

    # Construct a dataloader directly for a dataset which contains allennlp
    # Instances which have _already_ been indexed.
    my_loader = DataLoader(dataset, batch_size=32, collate_fn=allennlp_collate)
Secondly, this class adds a
batches_per_epoch parameter which, if given, determines the number
of batches after which an epoch ends. If this is
None, then an epoch is set to be one full pass
through your data. You might use this if you have a very large dataset and want more frequent
checkpoints and evaluations on validation data, for instance.
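The effect of batches_per_epoch can be illustrated with a small dependency-free sketch. `EpochLimitedLoader` is a hypothetical wrapper, not the AllenNLP implementation, and it assumes that a capped epoch resumes from where the previous one stopped, starting a new pass over the data once it runs out. It also assumes the underlying data is non-empty.

```python
from typing import Iterable, Iterator, List, Optional


class EpochLimitedLoader:
    """Hypothetical wrapper that ends an epoch after a fixed number of batches."""

    def __init__(
        self,
        batches: Iterable[List[int]],
        batches_per_epoch: Optional[int] = None,
    ) -> None:
        self._batches = batches
        self._batches_per_epoch = batches_per_epoch
        # Persistent iterator so a capped epoch resumes where the
        # previous one stopped (an assumption of this sketch).
        self._generator: Iterator[List[int]] = iter(batches)

    def __iter__(self) -> Iterator[List[int]]:
        if self._batches_per_epoch is None:
            # One epoch is one full pass through the data.
            yield from self._batches
            return
        for _ in range(self._batches_per_epoch):
            try:
                yield next(self._generator)
            except StopIteration:
                # Data exhausted mid-epoch: start a new pass and continue.
                self._generator = iter(self._batches)
                yield next(self._generator)


data = [[0, 1], [2, 3], [4, 5]]
full = list(EpochLimitedLoader(data))

capped = EpochLimitedLoader(data, batches_per_epoch=2)
first_epoch = list(capped)
second_epoch = list(capped)
```

With three batches of data and a cap of two, the first epoch covers the first two batches and the second epoch picks up the third before wrapping around, so checkpoints and validation runs happen more often than once per full pass.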
In a typical AllenNLP configuration file, the
dataset parameter does not get an entry under
"data_loader"; it is constructed separately.
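As an illustration of that layout, the sketch below writes a configuration as a Python dict instead of the usual configuration file. The reader name, data path, and parameter values are placeholders, and the "dataset_reader" and "train_data_path" keys are assumed here to be the entries from which the dataset is built; only the absence of a dataset entry under "data_loader" comes from the text above.

```python
# Hypothetical configuration, written as a Python dict for illustration.
config = {
    # Assumed: the dataset is built from these top-level entries.
    "dataset_reader": {"type": "my_reader"},    # placeholder reader name
    "train_data_path": "/path/to/train.jsonl",  # placeholder path
    "data_loader": {
        # "type" is omitted, so the default implementation
        # ("pytorch_dataloader") would be used.
        "batch_size": 32,
        "shuffle": True,
        "batches_per_epoch": 1000,
        # No "dataset" entry here: it is constructed separately.
    },
}

loader_params = config["data_loader"]
```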
    @classmethod
    def from_partial_objects(
        cls,
        dataset: data.Dataset,
        batch_size: int = 1,
        shuffle: bool = False,
        sampler: Lazy[Sampler] = None,
        batch_sampler: Lazy[BatchSampler] = None,
        num_workers: int = 0,
        pin_memory: bool = False,
        drop_last: bool = False,
        timeout: int = 0,
        worker_init_fn=None,
        multiprocessing_context: str = None,
        batches_per_epoch: int = None
    ) -> "PyTorchDataLoader"