multitask_scheduler
allennlp.data.data_loaders.multitask_scheduler
MultiTaskScheduler¶
class MultiTaskScheduler(Registrable)
A class that determines how to order instances within an epoch. This is used by the MultiTaskDataLoader. The main operation performed by this class is to take a dictionary of instance iterators, one for each dataset, and combine them into an iterator of batches, based on some scheduling algorithm (such as round robin, randomly choosing between available datasets, etc.). To control this behavior as training progresses, there is an update_from_epoch_metrics method available, which should be called from a Callback during training. Not all MultiTaskSchedulers will implement this method.
batch_instances¶
class MultiTaskScheduler(Registrable):
| ...
| def batch_instances(
| self,
| epoch_instances: Dict[str, Iterable[Instance]]
| ) -> Iterable[List[Instance]]
Given a dictionary of Iterable[Instance] for each dataset, combines them into an Iterable of batches of instances.
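The contract can be illustrated with a standalone sketch, where plain Python objects stand in for AllenNLP Instances. This is not the library's implementation, just a minimal round-robin strategy satisfying the same signature shape:

```python
from itertools import zip_longest
from typing import Dict, Iterable, List

def round_robin_batches(
    epoch_instances: Dict[str, Iterable],
    batch_size: int,
) -> Iterable[List]:
    """Take one item from each dataset in turn, emitting fixed-size batches."""
    iterators = [iter(it) for it in epoch_instances.values()]
    _sentinel = object()
    # Interleave one item from each dataset per round, skipping exhausted ones.
    interleaved = (
        item
        for group in zip_longest(*iterators, fillvalue=_sentinel)
        for item in group
        if item is not _sentinel
    )
    batch: List = []
    for instance in interleaved:
        batch.append(instance)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # final partial batch
        yield batch
```

With `{"a": [1, 2, 3], "b": ["x", "y"]}` and a batch size of 2, this yields `[1, "x"]`, `[2, "y"]`, then the partial batch `[3]`.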
update_from_epoch_metrics¶
class MultiTaskScheduler(Registrable):
| ...
| def update_from_epoch_metrics(
| self,
| epoch_metrics: Dict[str, Any]
| ) -> None
If you want to set the behavior of the scheduler based on current epoch metrics, you can do that by calling this method from a Callback. If your scheduling technique does not depend on epoch metrics, you do not need to implement this method.
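As a hypothetical example of what such a scheduler might do, the sketch below keeps per-dataset sampling weights and upweights datasets whose accuracy lags. The metric key format (`"<dataset>_accuracy"`) and the weighting rule are assumptions for illustration, not part of AllenNLP:

```python
from typing import Any, Dict, List

class MetricWeightedSampler:
    """Hypothetical scheduler state: sample more from datasets doing worse."""

    def __init__(self, dataset_names: List[str]) -> None:
        # Start with uniform sampling weights.
        self.weights = {name: 1.0 for name in dataset_names}

    def update_from_epoch_metrics(self, epoch_metrics: Dict[str, Any]) -> None:
        # Upweight datasets whose accuracy lags behind; keep a small floor
        # so no dataset is starved entirely.
        for name in self.weights:
            accuracy = epoch_metrics.get(f"{name}_accuracy", 0.0)
            self.weights[name] = max(1.0 - accuracy, 0.05)
```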
count_batches¶
class MultiTaskScheduler(Registrable):
| ...
| def count_batches(self, dataset_counts: Dict[str, int]) -> int
Given the number of instances per dataset, this returns the total number of batches the scheduler will return.
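For a scheduler with a single global batch size that interleaves all datasets, this count reduces to simple arithmetic over the total instance count. A sketch (assuming that strategy, with a hypothetical `drop_last` flag mirroring the constructors above):

```python
import math
from typing import Dict

def count_round_robin_batches(
    dataset_counts: Dict[str, int],
    batch_size: int,
    drop_last: bool = False,
) -> int:
    """Total batches when all datasets are interleaved into one stream."""
    total = sum(dataset_counts.values())
    if drop_last:
        return total // batch_size  # discard the final partial batch
    return math.ceil(total / batch_size)
```

For example, with counts `{"a": 5, "b": 2}` and batch size 2, there are 7 instances, so 4 batches (3 with `drop_last=True`).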
default_implementation¶
class MultiTaskScheduler(Registrable):
| ...
| default_implementation = "homogeneous_roundrobin"
RoundRobinScheduler¶
@MultiTaskScheduler.register("roundrobin")
class RoundRobinScheduler(MultiTaskScheduler):
| def __init__(self, batch_size: int, drop_last: bool = False)
Orders instances in a round-robin fashion, where we take one instance from every dataset in turn. When one dataset runs out, we continue iterating round-robin through the rest.
Registered as a MultiTaskScheduler with name "roundrobin".
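In a configuration file, this scheduler would typically be selected by its registered name inside the data loader block. The surrounding keys here are a sketch of a typical multi-task setup, not copied from a verified config:

```jsonnet
// Hypothetical data_loader fragment; adapt keys to your own config.
"data_loader": {
    "type": "multitask",
    "scheduler": {
        "type": "roundrobin",
        "batch_size": 16
    }
}
```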
batch_instances¶
class RoundRobinScheduler(MultiTaskScheduler):
| ...
| def batch_instances(
| self,
| epoch_instances: Dict[str, Iterable[Instance]]
| ) -> Iterable[List[Instance]]
count_batches¶
class RoundRobinScheduler(MultiTaskScheduler):
| ...
| def count_batches(self, dataset_counts: Dict[str, int]) -> int
HomogeneousRoundRobinScheduler¶
@MultiTaskScheduler.register("homogeneous_roundrobin")
class HomogeneousRoundRobinScheduler(MultiTaskScheduler):
| def __init__(
| self,
| batch_size: Union[int, Dict[str, int]],
| drop_last: bool = False
| )
Orders instances in a round-robin fashion, but grouped into batches composed entirely of instances from one dataset. We'll return one batch from one dataset, then another batch from a different dataset, etc. This is currently necessary in AllenNLP if your instances have different fields for different datasets, as we can't currently combine instances with different fields.
When one dataset runs out, we continue iterating round-robin through the rest.
If you want more fine-grained control over which datasets can be combined, it should be relatively straightforward to write your own scheduler, following this logic, which allows some datasets to be combined and others not.
Registered as a MultiTaskScheduler with name "homogeneous_roundrobin".
Parameters¶
- batch_size : Union[int, Dict[str, int]]
  Determines how many instances to group together in each dataset. If this is an int, the same value is used for all datasets; otherwise, the keys must correspond to the dataset names used elsewhere in the multi-task code.
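The described behavior, homogeneous batches alternating between datasets until all are exhausted, can be sketched in plain Python (again standing in for AllenNLP Instances; this mirrors the description above, not the library's actual code):

```python
from typing import Dict, Iterable, List, Union

def homogeneous_round_robin_batches(
    epoch_instances: Dict[str, Iterable],
    batch_size: Union[int, Dict[str, int]],
) -> Iterable[List]:
    """Yield single-dataset batches, rotating between datasets."""
    if isinstance(batch_size, int):
        # A single int applies the same batch size to every dataset.
        batch_size = {name: batch_size for name in epoch_instances}
    iterators = {name: iter(it) for name, it in epoch_instances.items()}
    while iterators:
        # Snapshot the names so we can drop exhausted datasets mid-loop.
        for name in list(iterators):
            batch: List = []
            for _ in range(batch_size[name]):
                try:
                    batch.append(next(iterators[name]))
                except StopIteration:
                    del iterators[name]  # dataset exhausted; keep rotating
                    break
            if batch:
                yield batch
```

With `{"a": [1, 2, 3], "b": ["x", "y", "z", "w"]}` and batch size 2, this yields `[1, 2]`, `["x", "y"]`, `[3]`, `["z", "w"]`: each batch is homogeneous, and "b" keeps going after "a" runs out.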
batch_instances¶
class HomogeneousRoundRobinScheduler(MultiTaskScheduler):
| ...
| def batch_instances(
| self,
| epoch_instances: Dict[str, Iterable[Instance]]
| ) -> Iterable[List[Instance]]
count_batches¶
class HomogeneousRoundRobinScheduler(MultiTaskScheduler):
| ...
| def count_batches(self, dataset_counts: Dict[str, int]) -> int