allennlp.modules.pruner

class allennlp.modules.pruner.Pruner(scorer: torch.nn.modules.module.Module)

Bases: torch.nn.modules.module.Module

This module scores the items in a list with a parameterised scoring function and prunes the list down to the top-scoring items, up to a given limit.

Parameters
scorer : torch.nn.Module, required.

A module which, given a tensor of shape (batch_size, num_items, embedding_size), produces a tensor of shape (batch_size, num_items, 1), representing a scalar score per item in the tensor.
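As an illustration of the shape contract described above, here is a minimal sketch of a module that could serve as a scorer. The sizes and the choice of a single linear layer are assumptions made for this example, not something the library prescribes.

```python
import torch

# Hypothetical sizes chosen only for illustration.
batch_size, num_items, embedding_size = 2, 5, 8

# A single linear layer is about the simplest module that satisfies the
# contract: it maps the last dimension from embedding_size to 1.
scorer = torch.nn.Linear(embedding_size, 1)

embeddings = torch.randn(batch_size, num_items, embedding_size)
scores = scorer(embeddings)  # shape: (batch_size, num_items, 1)
```

Any module with the same input and output shapes, such as a small feed-forward network ending in a projection to one dimension, should fit equally well.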

forward(self, embeddings: torch.FloatTensor, mask: torch.LongTensor, num_items_to_keep: Union[int, torch.LongTensor]) → Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor, torch.FloatTensor]

Extracts the top-k scoring items with respect to the scorer. The indices of the top-k items are also returned in their original order, not ordered by score, so that downstream components can rely on the original ordering (e.g., to know which spans are valid antecedents in a coreference resolution model). The same k may be used for every sentence in the minibatch, or a different k for each.
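The difference between score order and original order can be sketched in plain PyTorch. This is not the library's implementation, only an illustration of the idea; the score values and k are made up.

```python
import torch

# Hypothetical scores for two sentences with four items each.
scores = torch.tensor([[0.1, 0.7, 0.3, 0.9],
                       [0.8, 0.2, 0.6, 0.4]])
k = 2

# topk returns indices ordered by descending score...
_, top_indices = scores.topk(k, dim=1)

# ...so sorting the indices restores the items' original order.
top_indices_sorted, _ = top_indices.sort(dim=1)

# Look up the scores of the kept items, now in original order.
top_scores = scores.gather(1, top_indices_sorted)
```

In the first sentence the best item is at position 3 and the second best at position 1, so the score-ordered indices are `[3, 1]` while the original-order indices are `[1, 3]`.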

Parameters
embeddings : torch.FloatTensor, required.

A tensor of shape (batch_size, num_items, embedding_size), containing an embedding for each item in the list that we want to prune.

mask : torch.LongTensor, required.

A tensor of shape (batch_size, num_items), denoting unpadded elements of embeddings.

num_items_to_keep : Union[int, torch.LongTensor], required.

If a tensor of shape (batch_size), specifies the number of items to keep for each individual sentence in the minibatch. If an int, the same number of items is kept for every sentence.
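One way to picture the per-sentence case is to select the largest requested k for every sentence and then mask out the surplus positions. The sketch below follows that idea in plain PyTorch under made-up inputs; it is an assumption about how such behaviour can be achieved, not a copy of the library's code.

```python
import torch

# Hypothetical scores, padding mask, and per-sentence k values.
scores = torch.tensor([[0.5, 0.2, 0.9, 0.1],
                       [0.3, 0.8, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]])
num_items_to_keep = torch.tensor([3, 1])

# Select as many items as the greediest sentence needs.
max_k = int(num_items_to_keep.max())

# Push padded items to the bottom of the ranking.
masked_scores = scores.masked_fill(mask == 0, float("-inf"))
_, top_indices = masked_scores.topk(max_k, dim=1)

# Slot j of sentence i is valid only if j < num_items_to_keep[i].
positions = torch.arange(max_k).unsqueeze(0)
top_mask = (positions < num_items_to_keep.unsqueeze(1)).long()
```

Here the second sentence keeps a single item, so only the first of its three slots is marked valid in `top_mask`; the remaining slots hold filler indices that downstream code should ignore.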

Returns
top_embeddings : torch.FloatTensor

The representations of the top-k scoring items. Has shape (batch_size, max_num_items_to_keep, embedding_size), where max_num_items_to_keep is the largest number of items kept for any sentence in the minibatch.

top_mask : torch.LongTensor

The corresponding mask for top_embeddings. Has shape (batch_size, max_num_items_to_keep).

top_indices : torch.LongTensor

The indices of the top-k scoring items into the original embeddings tensor. This is returned because it can be useful to retain pointers to the original items, if each item is being scored by multiple distinct scorers, for instance. Has shape (batch_size, max_num_items_to_keep).

top_item_scores : torch.FloatTensor

The values of the top-k scoring items. Has shape (batch_size, max_num_items_to_keep, 1).
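To make the relationship between top_indices and the original embeddings tensor concrete, the following sketch gathers the selected rows by index. The tensor contents and index values are hypothetical, and this is only one possible way to perform the lookup in plain PyTorch.

```python
import torch

# Hypothetical sizes and contents chosen only for illustration.
batch_size, num_items, embedding_size = 2, 4, 3
embeddings = torch.arange(
    batch_size * num_items * embedding_size, dtype=torch.float
).view(batch_size, num_items, embedding_size)

# Hypothetical indices of the kept items, in original order.
top_indices = torch.tensor([[1, 3],
                            [0, 2]])

# gather needs an index for every element, so repeat each item index
# across the embedding dimension.
index = top_indices.unsqueeze(-1).expand(-1, -1, embedding_size)
top_embeddings = embeddings.gather(1, index)  # (batch_size, 2, embedding_size)
```

The same indices can be reused to look up other per-item tensors, which is what makes them useful when several scorers operate on the same items.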