simple_tagger
allennlp.models.simple_tagger
SimpleTagger
@Model.register("simple_tagger")
class SimpleTagger(Model):
| def __init__(
| self,
| vocab: Vocabulary,
| text_field_embedder: TextFieldEmbedder,
| encoder: Seq2SeqEncoder,
| calculate_span_f1: bool = None,
| label_encoding: Optional[str] = None,
| label_namespace: str = "labels",
| verbose_metrics: bool = False,
| initializer: InitializerApplicator = InitializerApplicator(),
| **kwargs
| ) -> None
This SimpleTagger simply encodes a sequence of text with a stacked Seq2SeqEncoder, then predicts a tag for each token in the sequence.
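The embed, encode, then tag pipeline can be illustrated with a toy, pure-Python sketch. Everything here is hypothetical: the lookup table stands in for the TextFieldEmbedder, the neighbour-mixing function stands in for a Seq2SeqEncoder, and the weight matrix stands in for the projection to tag scores. None of these helper names come from AllenNLP.

```python
from typing import Dict, List

Vector = List[float]


def embed(tokens: List[str], table: Dict[str, Vector]) -> List[Vector]:
    # Hypothetical stand-in for the TextFieldEmbedder: one vector per token.
    return [table[t] for t in tokens]


def encode(vectors: List[Vector]) -> List[Vector]:
    # Hypothetical stand-in for a Seq2SeqEncoder: the output keeps the
    # sequence length, and each position mixes in half of its left neighbour.
    out = []
    for i, v in enumerate(vectors):
        prev = vectors[i - 1] if i > 0 else [0.0] * len(v)
        out.append([a + 0.5 * b for a, b in zip(v, prev)])
    return out


def project(vectors: List[Vector], weights: List[Vector]) -> List[Vector]:
    # Hypothetical stand-in for the output projection: one score per tag
    # (one row of `weights` per tag) at every position.
    return [[sum(w * x for w, x in zip(row, v)) for row in weights] for v in vectors]


def tag(tokens: List[str], table: Dict[str, Vector],
        weights: List[Vector], labels: List[str]) -> List[str]:
    # Score every position, then pick the highest-scoring tag per token.
    logits = project(encode(embed(tokens, table)), weights)
    return [labels[max(range(len(row)), key=row.__getitem__)] for row in logits]
```

For example, with table = {"Alice": [1.0, 0.0], "runs": [0.0, 1.0]}, weights = [[1.0, 0.0], [0.0, 1.0]] and labels = ["NOUN", "VERB"], tag(["Alice", "runs"], table, weights, labels) returns ["NOUN", "VERB"]. The point is only the shape of the computation: one prediction per input token.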
Registered as a Model with name "simple_tagger".
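Because the model is registered under "simple_tagger", a training configuration can select it by that type name. The fragment below is a sketch written as a Python dict that mirrors the JSON/Jsonnet structure; it would sit under the config's "model" key. The embedder and encoder choices ("embedding", "lstm") and all dimensions are illustrative assumptions, not recommended settings.

```python
# Illustrative "model" section of a training config, expressed as a Python dict.
# "type": "simple_tagger" picks this model via its registered name;
# every other value is an assumption chosen for the example.
model_config = {
    "type": "simple_tagger",
    "text_field_embedder": {
        "token_embedders": {
            "tokens": {"type": "embedding", "embedding_dim": 50}
        }
    },
    "encoder": {
        "type": "lstm",
        "input_size": 50,  # should equal the embedder's output dimension
        "hidden_size": 100,
        "bidirectional": True,
    },
    "calculate_span_f1": True,
    "label_encoding": "BIO",
}
```

Keeping the encoder's input size equal to the embedding dimension is the main consistency requirement between these two sub-configs.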
Parameters
- vocab : Vocabulary
  A Vocabulary, required in order to compute sizes for input/output projections.
- text_field_embedder : TextFieldEmbedder
  Used to embed the tokens TextField that the model receives as input.
- encoder : Seq2SeqEncoder
  The encoder (with its own internal stacking) that we will use in between embedding tokens and predicting output tags.
- calculate_span_f1 : bool, optional (default = None)
  Calculate span-level F1 metrics during training. If this is True, then label_encoding is required. If None and label_encoding is specified, this is set to True. If None and label_encoding is not specified, it defaults to False.
- label_encoding : str, optional (default = None)
  Label encoding to use when calculating span F1. Valid options are "BIO", "BIOUL", "IOB1", "BMES". Required if calculate_span_f1 is True.
- label_namespace : str, optional (default = "labels")
  This is needed to compute the SpanBasedF1Measure metric, if desired. Unless you did something unusual, the default value should be what you want.
- verbose_metrics : bool, optional (default = False)
  If True, metrics will be returned per label class in addition to the overall statistics.
- initializer : InitializerApplicator, optional (default = InitializerApplicator())
  Used to initialize the model parameters.
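The interaction between calculate_span_f1 and label_encoding described above can be sketched as a small resolution function. The function name, the use of ValueError and the up-front check against the list of valid encodings are assumptions for illustration; the library may raise its own error type or validate the encoding elsewhere.

```python
from typing import Optional

# The encodings listed as valid in the parameter description above.
VALID_ENCODINGS = {"BIO", "BIOUL", "IOB1", "BMES"}


def resolve_span_f1(calculate_span_f1: Optional[bool],
                    label_encoding: Optional[str]) -> bool:
    # Hypothetical helper. None means "decide from label_encoding":
    # enable span F1 if an encoding was given, disable it otherwise.
    if calculate_span_f1 is None:
        calculate_span_f1 = label_encoding is not None
    if calculate_span_f1:
        # Span F1 needs an encoding to recover spans from per-token tags.
        if label_encoding is None:
            raise ValueError("calculate_span_f1 is True, but no label_encoding was specified.")
        if label_encoding not in VALID_ENCODINGS:
            raise ValueError(f"Unknown label_encoding: {label_encoding!r}")
    return calculate_span_f1
```

So resolve_span_f1(None, "BIO") yields True, resolve_span_f1(None, None) yields False, and resolve_span_f1(True, None) is rejected.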
forward
class SimpleTagger(Model):
| ...
| def forward(
| self,
| tokens: TextFieldTensors,
| tags: torch.LongTensor = None,
| metadata: List[Dict[str, Any]] = None,
| ignore_loss_on_o_tags: bool = False
| ) -> Dict[str, torch.Tensor]
Parameters
- tokens : TextFieldTensors
  The batched output of TextField.as_tensor(), which should typically be passed directly to a TextFieldEmbedder. This output is a dictionary mapping TokenIndexer names to the tensors each indexer produces. At its most basic, using a SingleIdTokenIndexer named "tokens", this is: {"tokens": {"tokens": Tensor(batch_size, num_tokens)}}. This dictionary will have the same keys as were used for the TokenIndexers when you created the TextField representing your sequence. The dictionary is designed to be passed directly to a TextFieldEmbedder, which knows how to combine different word representations into a single vector per token in your input.
- tags : torch.LongTensor, optional (default = None)
  A torch tensor representing the sequence of integer gold class labels, of shape (batch_size, num_tokens).
- metadata : List[Dict[str, Any]], optional (default = None)
  Metadata containing the original words in the sentence to be tagged, under a "words" key.
- ignore_loss_on_o_tags : bool, optional (default = False)
  If True, we compute the loss only for actual spans in tags, and not on O tokens. This is useful for computing gradients of the loss on a single span, for interpretation / attacking.
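The effect of ignore_loss_on_o_tags can be pictured as a mask that starts from the padding mask and additionally switches off positions whose gold tag is O. This pure-Python sketch uses nested lists in place of tensors; the helper name is hypothetical, and the O index is passed in directly, whereas the model would presumably look it up in its label namespace (for example with Vocabulary.get_token_index).

```python
from typing import List


def loss_mask(tags: List[List[int]],
              padding_mask: List[List[bool]],
              o_index: int,
              ignore_loss_on_o_tags: bool) -> List[List[bool]]:
    # Hypothetical helper: True marks a position that contributes to the loss.
    if not ignore_loss_on_o_tags:
        # Default behaviour: every non-padding token counts.
        return [list(row) for row in padding_mask]
    # Keep a position only if it is real (not padding) AND its gold tag is not O.
    return [[keep and t != o_index for t, keep in zip(tag_row, mask_row)]
            for tag_row, mask_row in zip(tags, padding_mask)]
```

With tags [[0, 1, 2, 0]], padding mask [[True, True, True, False]] and O at index 0, the flag turns the mask into [[False, True, True, False]], so only the span tokens drive the gradient.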
Returns
- An output dictionary consisting of:
  - logits (torch.FloatTensor) : A tensor of shape (batch_size, num_tokens, tag_vocab_size) representing unnormalised log probabilities of the tag classes.
  - class_probabilities (torch.FloatTensor) : A tensor of shape (batch_size, num_tokens, tag_vocab_size) representing a distribution of the tag classes per word.
  - loss (torch.FloatTensor, optional) : A scalar loss to be optimised. Only present when tags is provided.
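The relationship between the returned entries can be sketched for a single sequence: class_probabilities is a softmax of logits over the tag dimension, and the loss compares those distributions against the gold tags on unmasked tokens. This is a simplified pure-Python illustration; the model itself computes the loss with allennlp.nn.util.sequence_cross_entropy_with_logits, whose averaging across the batch may differ from the plain token mean used here.

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def class_probabilities(logits: List[List[float]]) -> List[List[float]]:
    # One distribution over tags per token: softmax over the tag dimension.
    return [softmax(row) for row in logits]


def masked_nll(logits: List[List[float]], tags: List[int], mask: List[bool]) -> float:
    # Mean negative log-likelihood of the gold tag over unmasked tokens.
    # Assumes at least one token is unmasked.
    losses = [-math.log(softmax(row)[t])
              for row, t, keep in zip(logits, tags, mask) if keep]
    return sum(losses) / len(losses)
```

For logits [[0.0, 0.0], [2.0, 0.0]], gold tags [0, 0] and mask [True, False], only the first token counts, so the loss is -log(0.5) = log 2.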
make_output_human_readable
class SimpleTagger(Model):
| ...
| def make_output_human_readable(
| self,
| output_dict: Dict[str, torch.Tensor]
| ) -> Dict[str, torch.Tensor]
Does a simple position-wise argmax over each token, converts indices to string labels, and adds a "tags" key to the dictionary with the result.
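The decoding step described above can be sketched over nested lists. The helper name is hypothetical, and the plain index_to_label dict is an assumption standing in for the model's vocabulary lookup (in the library this would come from the label namespace, e.g. via Vocabulary.get_token_from_index).

```python
from typing import Dict, List


def make_human_readable(output: Dict[str, list],
                        index_to_label: Dict[int, str]) -> Dict[str, list]:
    # Hypothetical sketch: for each sequence, take the argmax tag index at
    # every position, map it to its string label, and store the result
    # under a new "tags" key alongside the existing entries.
    all_tags: List[List[str]] = []
    for sequence in output["class_probabilities"]:
        indices = [max(range(len(p)), key=p.__getitem__) for p in sequence]
        all_tags.append([index_to_label[i] for i in indices])
    output["tags"] = all_tags
    return output
```

For one sequence with probabilities [[0.1, 0.9], [0.8, 0.2]] and labels {0: "O", 1: "B-PER"}, the added "tags" entry is [["B-PER", "O"]].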
get_metrics
class SimpleTagger(Model):
| ...
| def get_metrics(self, reset: bool = False) -> Dict[str, float]
default_predictor
class SimpleTagger(Model):
| ...
| default_predictor = "sentence_tagger"