
allennlp.data.data_loaders.multitask_scheduler

MultiTaskScheduler

class MultiTaskScheduler(Registrable)

A class that determines how to order instances within an epoch. This is used by the MultiTaskDataLoader. The main operation performed by this class is to take a dictionary of instance iterators, one for each dataset, and combine them into an iterator of batches according to some scheduling algorithm (such as round-robin, or randomly choosing between available datasets). To adjust this behavior as training progresses, an update_from_epoch_metrics method is available, which should be called from a Callback during training. Not all MultiTaskSchedulers will implement this method.
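
For illustration, a minimal custom scheduler might look like the following sketch. The UniformRandomScheduler class and its "uniform_random" registration name are hypothetical; only the Registrable pattern and the method signatures follow the definitions documented below.

from typing import Dict, Iterable, List
import random

from allennlp.data.instance import Instance
from allennlp.data.data_loaders.multitask_scheduler import MultiTaskScheduler


@MultiTaskScheduler.register("uniform_random")  # hypothetical registration name
class UniformRandomScheduler(MultiTaskScheduler):
    """Fills each batch by repeatedly sampling one of the remaining datasets uniformly at random."""

    def __init__(self, batch_size: int) -> None:
        self.batch_size = batch_size

    def batch_instances(
        self, epoch_instances: Dict[str, Iterable[Instance]]
    ) -> Iterable[List[Instance]]:
        iterators = {name: iter(instances) for name, instances in epoch_instances.items()}
        batch: List[Instance] = []
        while iterators:
            name = random.choice(list(iterators))
            try:
                batch.append(next(iterators[name]))
            except StopIteration:
                # This dataset is exhausted; keep sampling from the rest.
                del iterators[name]
                continue
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

    def count_batches(self, dataset_counts: Dict[str, int]) -> int:
        total = sum(dataset_counts.values())
        return -(-total // self.batch_size)  # ceiling division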

batch_instances

class MultiTaskScheduler(Registrable):
 | ...
 | def batch_instances(
 |     self,
 |     epoch_instances: Dict[str, Iterable[Instance]]
 | ) -> Iterable[List[Instance]]

Given a dictionary of Iterable[Instance] for each dataset, combines them into an Iterable of batches of instances.

update_from_epoch_metrics

class MultiTaskScheduler(Registrable):
 | ...
 | def update_from_epoch_metrics(
 |     self,
 |     epoch_metrics: Dict[str, Any]
 | ) -> None

If you want to adjust the scheduler's behavior based on current epoch metrics, you can do so by calling this method from a Callback during training. If your scheduling technique does not depend on epoch metrics, you do not need to implement this method.
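
As a sketch of how this fits into training, a TrainerCallback could forward each epoch's metrics to the scheduler. The callback below is hypothetical: it assumes the trainer's data loader is a MultiTaskDataLoader that exposes its scheduler as a scheduler attribute, and the on_epoch signature is abbreviated from AllenNLP's TrainerCallback.

from typing import Any, Dict

from allennlp.training.callbacks import TrainerCallback


@TrainerCallback.register("update_multitask_scheduler")  # hypothetical registration name
class UpdateSchedulerCallback(TrainerCallback):
    def on_epoch(self, trainer, metrics: Dict[str, Any], epoch: int, is_primary: bool = True, **kwargs) -> None:
        # Assumes trainer.data_loader is a MultiTaskDataLoader with a `scheduler` attribute.
        if is_primary:
            trainer.data_loader.scheduler.update_from_epoch_metrics(metrics)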

count_batches

class MultiTaskScheduler(Registrable):
 | ...
 | def count_batches(self, dataset_counts: Dict[str, int]) -> int

Given the number of instances per dataset, returns the total number of batches the scheduler will produce.

default_implementation

class MultiTaskScheduler(Registrable):
 | ...
 | default_implementation = "homogeneous_roundrobin"

RoundRobinScheduler

@MultiTaskScheduler.register("roundrobin")
class RoundRobinScheduler(MultiTaskScheduler):
 | def __init__(self, batch_size: int, drop_last: bool = False)

Orders instances in a round-robin fashion, where we take one instance from every dataset in turn. When one dataset runs out, we continue iterating round-robin through the rest.

Registered as a MultiTaskScheduler with name "roundrobin".
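
As a usage sketch of the resulting ordering, the snippet below uses plain strings to stand in for Instance objects; the dataset names are made up, and the printed interleaving follows the behavior described above.

from allennlp.data.data_loaders.multitask_scheduler import RoundRobinScheduler

scheduler = RoundRobinScheduler(batch_size=4)
epoch_instances = {
    "task_a": ["a1", "a2", "a3", "a4", "a5"],  # strings standing in for Instances
    "task_b": ["b1", "b2", "b3"],
}
for batch in scheduler.batch_instances(epoch_instances):
    print(batch)
# Expected interleaving:
#   ['a1', 'b1', 'a2', 'b2']
#   ['a3', 'b3', 'a4', 'a5']   <- task_b ran out, so the rest comes from task_a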

batch_instances

class RoundRobinScheduler(MultiTaskScheduler):
 | ...
 | def batch_instances(
 |     self,
 |     epoch_instances: Dict[str, Iterable[Instance]]
 | ) -> Iterable[List[Instance]]

count_batches

class RoundRobinScheduler(MultiTaskScheduler):
 | ...
 | def count_batches(self, dataset_counts: Dict[str, int]) -> int
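
For intuition: because every dataset feeds a single mixed batch stream, with drop_last=False the count should come out to a ceiling division over the combined instance count. A quick sketch with made-up sizes:

import math

dataset_counts = {"task_a": 100, "task_b": 50}  # hypothetical sizes
batch_size = 8

total = sum(dataset_counts.values())  # 150 instances in one mixed stream
print(math.ceil(total / batch_size))  # 19 batches, the last one partial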

HomogeneousRoundRobinScheduler

@MultiTaskScheduler.register("homogeneous_roundrobin")
class HomogeneousRoundRobinScheduler(MultiTaskScheduler):
 | def __init__(
 |     self,
 |     batch_size: Union[int, Dict[str, int]],
 |     drop_last: bool = False
 | )

Orders instances in a round-robin fashion, but grouped into batches composed entirely of instances from one dataset. We'll return one batch from one dataset, then another batch from a different dataset, and so on. This is currently necessary in AllenNLP if your instances have different fields for different datasets, as instances with different fields cannot currently be combined into a single batch.

When one dataset runs out, we continue iterating round-robin through the rest.

If you want more fine-grained control over which datasets can be combined, it should be relatively straightforward to write your own scheduler following this logic, allowing some datasets to be combined in a batch and others kept separate.

Registered as a MultiTaskScheduler with name "homogeneous_roundrobin".

Parameters

  • batch_size : Union[int, Dict[str, int]]
    Determines how many instances to put in each batch for each dataset. If this is an int, the same value is used for all datasets; otherwise, the keys must correspond to the dataset names used elsewhere in the multi-task code (see the example below).
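
For example, a per-dataset batch size can be given as a dictionary. The dataset names below are placeholders; they must match the names used by the multi-task reader and data loader.

from allennlp.data.data_loaders.multitask_scheduler import HomogeneousRoundRobinScheduler

# Each batch contains instances from exactly one of these (hypothetical) datasets.
scheduler = HomogeneousRoundRobinScheduler(
    batch_size={"classification": 32, "tagging": 16},
    drop_last=False,
)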

batch_instances

class HomogeneousRoundRobinScheduler(MultiTaskScheduler):
 | ...
 | def batch_instances(
 |     self,
 |     epoch_instances: Dict[str, Iterable[Instance]]
 | ) -> Iterable[List[Instance]]

count_batches

class HomogeneousRoundRobinScheduler(MultiTaskScheduler):
 | ...
 | def count_batches(self, dataset_counts: Dict[str, int]) -> int
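
Since batches here never mix datasets, with drop_last=False the total should be the sum of per-dataset ceilings. A quick sketch with made-up numbers:

import math

dataset_counts = {"classification": 100, "tagging": 50}  # hypothetical sizes
batch_sizes = {"classification": 32, "tagging": 16}

# One batch stream per dataset, interleaved round-robin:
# ceil(100/32) + ceil(50/16) = 4 + 4 = 8 batches
print(sum(math.ceil(dataset_counts[name] / batch_sizes[name]) for name in dataset_counts))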