




def get_forward_arguments(module: torch.nn.Module) -> Set[str]


class MultiTaskModel(Model):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     backbone: Backbone,
 |     heads: Dict[str, Head],
 |     *,
 |     loss_weights: Dict[str, float] = None,
 |     arg_name_mapping: Dict[str, Dict[str, str]] = None,
 |     allowed_arguments: Dict[str, Set[str]] = None,
 |     initializer: InitializerApplicator = InitializerApplicator(),
 |     **kwargs
 | )

A MultiTaskModel consists of a Backbone that encodes its inputs in some way, then a collection of Heads that make predictions from the backbone-encoded inputs. The predictions of each Head are combined to compute a joint loss, which is then used for training.
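The sketch below shows one way the per-head losses could be aggregated into a joint loss, honoring the loss_weights parameter described below. It is a hypothetical helper (`combine_head_losses` is not part of this API). It assumes that heads without an explicit weight get weight 1.0 ("equal weighting"). Plain floats are used so the sketch runs without PyTorch; the same arithmetic works on torch.Tensor losses.

```python
from typing import Dict, Optional


def combine_head_losses(
    head_losses: Dict[str, float],
    loss_weights: Optional[Dict[str, float]] = None,
) -> float:
    """Hypothetical helper: weighted sum of per-head losses.

    Heads missing from ``loss_weights`` are assumed to get weight 1.0.
    """
    weights = loss_weights or {}
    total = 0.0
    for head_name, loss in head_losses.items():
        # Scale each head's loss by its weight before summing.
        total += weights.get(head_name, 1.0) * loss
    return total


print(combine_head_losses({"classifier1": 0.5, "classifier2": 2.0}, {"classifier2": 0.25}))
# 1.0
```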

This model works by taking **kwargs in forward and passing the appropriate arguments from them to the backbone and to each head. By default, we use the inspect module to infer which arguments go to which module, but you can specify these arguments yourself in case our inference gets it wrong.
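Here is a minimal sketch of inspect-based argument routing under stated assumptions. It is not this library's implementation. The helper names `forward_argument_names` and `route_kwargs`, and the toy `ToyBackbone`/`ToyHead` classes, are hypothetical stand-ins for get_forward_arguments, a Backbone, and a Head. Any object with a `forward` method works, so PyTorch is not needed. An explicit set, playing the role of one entry of allowed_arguments, overrides the inferred names.

```python
import inspect
from typing import Any, Dict, Optional, Set


def forward_argument_names(module: Any) -> Set[str]:
    """Hypothetical: names of the named parameters of ``module.forward``."""
    # For a bound method, inspect.signature already omits ``self``.
    signature = inspect.signature(module.forward)
    return {
        name
        for name, param in signature.parameters.items()
        if param.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    }


def route_kwargs(
    kwargs: Dict[str, Any],
    module: Any,
    allowed: Optional[Set[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical: keep only the kwargs that ``module.forward`` accepts."""
    names = allowed if allowed is not None else forward_argument_names(module)
    return {key: value for key, value in kwargs.items() if key in names}


class ToyBackbone:  # stand-in for a Backbone
    def forward(self, text, mask=None):
        return {"encoded": text}


class ToyHead:  # stand-in for a Head
    def forward(self, encoded, label=None):
        return {"loss": 0.0}


batch = {"text": [1, 2], "label": 1}
print(route_kwargs(batch, ToyBackbone()))  # {'text': [1, 2]}
print(route_kwargs(batch, ToyHead()))  # {'label': 1}
```

One design note: the sketch ignores a `**kwargs` parameter on forward. A module that accepts arbitrary keywords therefore needs an explicit allowed set to receive anything.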

It is the caller's responsibility to make sure that the backbone and all heads are compatible with each other, and with the input data that comes from a MultiTaskDatasetReader. We provide some arguments in this class and in MultiTaskDatasetReader to help route arguments in complex cases (e.g., you can rename arguments so that they match what the backbone and heads expect).


  • vocab : Vocabulary
  • backbone : Backbone
  • heads : Dict[str, Head]
  • loss_weights : Dict[str, float], optional (default = equal weighting)
    You can specify a weight for each head. We multiply that head's loss by this weight when aggregating losses across heads. In many cases this is equivalent to specifying a separate learning rate per head, and putting a weight on the loss is much easier than figuring out the right way to specify that in the optimizer.
  • arg_name_mapping : Dict[str, Dict[str, str]], optional (default = identity mapping)
    This mapping renames entries in the **kwargs dictionary passed to forward before the arguments are passed on to the backbone and heads. It is keyed by component: the top-level keys must match the keys passed in the heads parameter, plus a "backbone" key for the backbone. If your dataset readers use dataset-specific names for their keys, this lets you make them consistent. For example, this dictionary might look like this: {"backbone": {"question": "text", "review": "text"}, "classifier1": {"sentiment": "label"}, "classifier2": {"topic": "label"}}. In this particular example, two different inputs map to the same backbone key. This works as long as you are careful never to give both of those inputs in the same batch; if we see overlapping keys, we will crash. If you want to mix such inputs in the same batch, you need to handle that in your data code, not here; this model does not handle complex batching.
  • allowed_arguments : Dict[str, Set[str]], optional (default = inferred)
    The set of arguments that should be passed from **kwargs to the forward method of the backbone and of each head. If you provide this, its keys should match the keys given in the heads parameter, plus a "backbone" key for the backbone arguments. If not given, we use the inspect module to infer it. This inference is most likely to fail when a forward method has optional arguments that you want ignored; otherwise, you very likely don't need to worry about this argument.
  • initializer : InitializerApplicator, optional (default = InitializerApplicator())
    If provided, will be used to initialize the model parameters.
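The renaming behavior that arg_name_mapping describes can be sketched for a single component, using the "backbone" entry from the example above. The helper `rename_arguments` is hypothetical and is not part of this API. So is the choice of ValueError for the "overlapping keys" crash, since the actual exception type is not documented here.

```python
from typing import Any, Dict


def rename_arguments(kwargs: Dict[str, Any], mapping: Dict[str, str]) -> Dict[str, Any]:
    """Hypothetical: rename keys per ``mapping``; unmapped keys keep their name.

    Raises ValueError (an assumed exception type) if two input keys end up
    with the same name after renaming.
    """
    renamed: Dict[str, Any] = {}
    for key, value in kwargs.items():
        new_key = mapping.get(key, key)
        if new_key in renamed:
            raise ValueError(f"Overlapping key after renaming: {new_key!r}")
        renamed[new_key] = value
    return renamed


backbone_mapping = {"question": "text", "review": "text"}
print(rename_arguments({"question": "What?"}, backbone_mapping))
# {'text': 'What?'}
```

A batch containing both "question" and "review" would raise in this sketch, mirroring the documented crash on overlapping keys.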


class MultiTaskModel(Model):
 | ...
 | def forward(self, **kwargs) -> Dict[str, torch.Tensor]


class MultiTaskModel(Model):
 | ...
 | @overrides
 | def get_metrics(self, reset: bool = False) -> Dict[str, float]


class MultiTaskModel(Model):
 | ...
 | @overrides
 | def make_output_human_readable(
 |     self,
 |     output_dict: Dict[str, torch.Tensor]
 | ) -> Dict[str, torch.Tensor]