
bart

allennlp_models.generation.models.bart



DecoderCacheType

DecoderCacheType = Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], ...]
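AllenNLP's BeamSearch passes state between steps as a flat Dict[str, torch.Tensor], so a nested cache of this type has to be flattened into dictionary entries and rebuilt at each step. The sketch below shows one way to do that round trip. It assumes each per-layer tuple holds the self-attention key and value followed by the cross-attention key and value; the key format `decoder_cache_{layer}_{index}` and the helper names are hypothetical, and plain strings stand in for tensors so the example runs without PyTorch.

```python
from typing import Dict, Tuple

# Stand-in for torch.Tensor so the sketch runs without PyTorch.
Tensor = str
# Mirrors DecoderCacheType: one 4-tuple per decoder layer.
LayerCache = Tuple[Tensor, Tensor, Tensor, Tensor]
DecoderCache = Tuple[LayerCache, ...]


def cache_to_dict(cache: DecoderCache) -> Dict[str, Tensor]:
    """Flatten the nested per-layer cache into flat, string-keyed entries."""
    flat: Dict[str, Tensor] = {}
    for layer_index, layer_cache in enumerate(cache):
        # Assumed order: self-attn key, self-attn value, cross-attn key, cross-attn value.
        assert len(layer_cache) == 4
        for tensor_index, tensor in enumerate(layer_cache):
            flat[f"decoder_cache_{layer_index}_{tensor_index}"] = tensor
    return flat


def dict_to_cache(flat: Dict[str, Tensor], num_layers: int) -> DecoderCache:
    """Rebuild the nested per-layer cache from the flat dictionary."""
    layers = []
    for layer_index in range(num_layers):
        prefix = f"decoder_cache_{layer_index}_"
        layers.append(tuple(flat[prefix + str(i)] for i in range(4)))
    return tuple(layers)


# Two hypothetical decoder layers, with strings naming each cached tensor.
cache = tuple(
    (f"L{n}.self_k", f"L{n}.self_v", f"L{n}.cross_k", f"L{n}.cross_v") for n in range(2)
)
flat = cache_to_dict(cache)
print(len(flat))  # 8 entries: 2 layers x 4 tensors
assert dict_to_cache(flat, num_layers=2) == cache
```

Flattening keeps every cached tensor as a top-level state entry, which is what lets a beam search reorder each one along the batch dimension without knowing about the nesting.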

BartEncoder

@Seq2SeqEncoder.register("bart_encoder")
class BartEncoder(Seq2SeqEncoder):
 | def __init__(self, model_name)

The BART encoder, implemented as a Seq2SeqEncoder. It assumes that its inputs are already embedded, so this module removes BART's token and position embeddings. For the typical use case of encoding your model's inputs with BART, including BART's own token and position embeddings, use PretrainedTransformerEmbedder(bart_model_name, sub_module="encoder") instead.

Parameters

  • model_name : str
    Name of the pre-trained BART model to use. Available options can be found in transformers.models.bart.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP.
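The two options described above might look as follows in an AllenNLP configuration, written here as Python dicts that mirror the Jsonnet structure. The model name `facebook/bart-large` is only an example checkpoint, and the surrounding keys a real configuration needs (dataset reader, text field embedder wiring, and so on) are omitted.

```python
import json

# Typical case: BART's own token/position embeddings plus its encoder stack.
# "pretrained_transformer" is the registered name of PretrainedTransformerEmbedder.
embedder_config = {
    "type": "pretrained_transformer",
    "model_name": "facebook/bart-large",  # example checkpoint
    "sub_module": "encoder",
}

# BartEncoder: a Seq2SeqEncoder that expects already-embedded inputs, e.g. from
# another embedder whose output dimension matches the encoder's get_input_dim().
encoder_config = {
    "type": "bart_encoder",  # from @Seq2SeqEncoder.register("bart_encoder")
    "model_name": "facebook/bart-large",
}

print(json.dumps(encoder_config, indent=2))
```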

get_input_dim

class BartEncoder(Seq2SeqEncoder):
 | ...
 | @overrides
 | def get_input_dim(self) -> int

get_output_dim

class BartEncoder(Seq2SeqEncoder):
 | ...
 | @overrides
 | def get_output_dim(self) -> int

is_bidirectional

class BartEncoder(Seq2SeqEncoder):
 | ...
 | @overrides
 | def is_bidirectional(self) -> bool

forward

class BartEncoder(Seq2SeqEncoder):
 | ...
 | @overrides
 | def forward(self, inputs: torch.Tensor, mask: torch.BoolTensor)

BART's underlying encoder returns a tuple. The first element is always the final-layer encoder states for each input token. Depending on the config, the second element contains the encoder states after each transformer layer, and the third element can contain the attention weights from each layer. Only the first element is used.
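Keeping only the final-layer states amounts to indexing the first element of the encoder's output. The sketch below uses nested lists in place of tensors; the shapes and the helper name `last_hidden_states` are illustrative assumptions, not the module's own code.

```python
from typing import Any, Sequence


def last_hidden_states(encoder_outputs: Sequence[Any]) -> Any:
    """Keep only the final-layer states; drop per-layer states and attentions."""
    # Element 0 is always present; elements 1 and 2 depend on the model config.
    return encoder_outputs[0]


batch_size, seq_len, hidden_dim = 2, 3, 4
# (batch_size, seq_len, hidden_dim) stand-in for the final-layer states.
final_states = [[[0.0] * hidden_dim for _ in range(seq_len)] for _ in range(batch_size)]
all_layer_states = [final_states, final_states]  # present only if requested
attentions = None                                 # placeholder for per-layer attentions

out = last_hidden_states((final_states, all_layer_states, attentions))
print(len(out), len(out[0]), len(out[0][0]))  # 2 3 4
```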

Bart

@Model.register("bart")
class Bart(Model):
 | def __init__(
 |     self,
 |     model_name: str,
 |     vocab: Vocabulary,
 |     beam_search: Lazy[BeamSearch] = Lazy(BeamSearch),
 |     indexer: PretrainedTransformerIndexer = None,
 |     encoder: Seq2SeqEncoder = None,
 |     **kwargs
 | )

BART model from the paper "BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension" (https://arxiv.org/abs/1910.13461). The Bart model here uses a language modeling head and thus can be used for text generation.

Parameters

  • model_name : str
    Name of the pre-trained BART model to use. Available options can be found in transformers.models.bart.modeling_bart.BART_PRETRAINED_MODEL_ARCHIVE_MAP.
  • vocab : Vocabulary
    Vocabulary containing source and target vocabularies.
  • beam_search : Lazy[BeamSearch], optional (default = Lazy(BeamSearch))
    This is used during inference to select the tokens of the decoded output sequence.
  • indexer : PretrainedTransformerIndexer, optional (default = None)
    Indexer used to convert decoded sequences of ids to sequences of tokens.
  • encoder : Seq2SeqEncoder, optional (default = None)
    Encoder to use in BART. By default, the original BART encoder is used.
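A model block for this class might look like the following, again as a Python dict mirroring Jsonnet. `beam_size` and `max_steps` are constructor parameters of AllenNLP's BeamSearch. Because the argument is Lazy, the model presumably fills in the arguments it knows itself (such as the end-of-sequence index) when it constructs the search, so the configuration lists only the search settings; that reading of the Lazy type, the model name, and the specific values are assumptions.

```python
model_config = {
    "type": "bart",                       # registered name from @Model.register("bart")
    "model_name": "facebook/bart-large",  # example checkpoint
    "beam_search": {
        "beam_size": 4,    # hypotheses kept per instance
        "max_steps": 140,  # upper bound on generated tokens
    },
    # "encoder": {...}  # optional Seq2SeqEncoder; omit to keep BART's own encoder
}
```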

forward

class Bart(Model):
 | ...
 | @overrides
 | def forward(
 |     self,
 |     source_tokens: TextFieldTensors,
 |     target_tokens: TextFieldTensors = None
 | ) -> Dict[str, torch.Tensor]

Performs the forward pass of BART.

Parameters

  • source_tokens : TextFieldTensors
    The source tokens for the encoder. We assume they are stored under the tokens key.
  • target_tokens : TextFieldTensors, optional (default = None)
    The target tokens for the decoder. We assume they are stored under the tokens key. If no target tokens are given, the source tokens are shifted to the right by 1.
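Shifting a sequence right by one means prepending a decoder start token and dropping the last position, so the decoder sees token t-1 as input when it predicts token t. The sketch below shows this on plain Python lists. The token ids and the start id of 2 are illustrative assumptions, and real implementations (for example in Hugging Face Transformers) also handle padding details that are omitted here.

```python
from typing import List


def shift_right(token_ids: List[int], decoder_start_id: int) -> List[int]:
    """Prepend the start id and drop the final token, keeping the length fixed."""
    return [decoder_start_id] + token_ids[:-1]


source = [0, 31414, 232, 2]  # hypothetical ids: <s> ... </s>
decoder_input = shift_right(source, decoder_start_id=2)
print(decoder_input)  # [2, 0, 31414, 232]
```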

Returns

  • Dict[str, torch.Tensor]
    During training, this dictionary contains the decoder_logits of shape (batch_size, max_target_length, target_vocab_size) and the loss. During inference, it contains predictions of shape (batch_size, max_decoding_steps) and log_probabilities of shape (batch_size,).
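AllenNLP's BeamSearch returns several hypotheses per instance, shaped (batch_size, beam_size, max_decoding_steps), plus one log-probability per hypothesis. Getting from there to the (batch_size, max_decoding_steps) predictions and (batch_size,) log_probabilities described above presumably means keeping the highest-scoring beam for each instance. The sketch below does that with plain lists; the function name and the toy numbers are illustrative.

```python
from typing import List, Tuple


def select_best_beam(
    predictions: List[List[List[int]]],  # (batch_size, beam_size, max_decoding_steps)
    log_probs: List[List[float]],        # (batch_size, beam_size)
) -> Tuple[List[List[int]], List[float]]:
    best_predictions, best_log_probs = [], []
    for beams, scores in zip(predictions, log_probs):
        # Index of the highest-scoring hypothesis for this instance.
        best = max(range(len(scores)), key=lambda i: scores[i])
        best_predictions.append(beams[best])
        best_log_probs.append(scores[best])
    return best_predictions, best_log_probs


preds = [[[5, 6, 2], [5, 7, 2]], [[8, 2, 1], [9, 9, 2]]]
scores = [[-1.2, -0.4], [-0.3, -2.0]]
best_preds, best_scores = select_best_beam(preds, scores)
print(best_preds)   # [[5, 7, 2], [8, 2, 1]]
print(best_scores)  # [-0.4, -0.3]
```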

take_step

class Bart(Model):
 | ...
 | def take_step(
 |     self,
 |     last_predictions: torch.Tensor,
 |     state: Dict[str, torch.Tensor],
 |     step: int
 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]

Takes a single decoding step during beam search.

Parameters

  • last_predictions : torch.Tensor
    The predicted token ids from the previous step. Shape: (group_size,)
  • state : Dict[str, torch.Tensor]
    The state required to generate the next set of predictions.
  • step : int
    The time step in beam search decoding.

Returns

  • Tuple[torch.Tensor, Dict[str, torch.Tensor]]
    A tuple containing logits for the next tokens of shape (group_size, target_vocab_size) and an updated state dictionary.
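A step function for AllenNLP's BeamSearch follows the contract above: it receives the last predicted id for every row of the group, the state dictionary, and the step index, and returns scores over the vocabulary for each row plus the updated state. BeamSearch treats these scores as log-probabilities. The toy sketch below replaces BART with a fixed bigram table over a hypothetical 4-token vocabulary, uses plain lists instead of tensors, and drives the step function with a greedy loop rather than a real beam search.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical vocabulary: 0 = <s>, 1 = "hello", 2 = "world", 3 = </s>.
BIGRAM_PROBS = [
    [0.0, 0.9, 0.1, 0.0],  # after <s>
    [0.0, 0.0, 0.8, 0.2],  # after "hello"
    [0.0, 0.1, 0.0, 0.9],  # after "world"
    [0.0, 0.0, 0.0, 1.0],  # after </s>: keep emitting </s>
]

State = Dict[str, List[List[int]]]


def take_step(
    last_predictions: List[int], state: State, step: int
) -> Tuple[List[List[float]], State]:
    """Return per-row log-probabilities over the vocabulary and the updated state."""
    log_probs = [
        [math.log(p) if p > 0 else -math.inf for p in BIGRAM_PROBS[token]]
        for token in last_predictions
    ]
    # Record each row's history; a real model would condition on it (or on a cache).
    history = [row + [token] for row, token in zip(state["history"], last_predictions)]
    return log_probs, {"history": history}


# Greedy loop standing in for BeamSearch; group_size is 2 here.
predictions = [0, 0]
state: State = {"history": [[], []]}
for step in range(3):
    log_probs, state = take_step(predictions, state, step)
    predictions = [max(range(len(row)), key=row.__getitem__) for row in log_probs]
    print(step, predictions)
```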

make_output_human_readable

class Bart(Model):
 | ...
 | @overrides
 | def make_output_human_readable(
 |     self,
 |     output_dict: Dict[str, torch.Tensor]
 | ) -> Dict[str, Any]

Parameters

  • output_dict : Dict[str, torch.Tensor]
    A dictionary containing a batch of predictions under the key predictions. The tensor should have shape (batch_size, max_sequence_length).

Returns

  • Dict[str, Any]
    Original output_dict with an additional predicted_tokens key that maps to a list of lists of tokens.
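Turning each row of predictions into tokens can be done by mapping ids through a vocabulary and stopping at the end-of-sequence id. The sketch below assumes a tiny hypothetical id-to-token list and end id; in the model itself the conversion goes through the configured indexer, and whether it truncates at the end token exactly this way is an assumption here.

```python
from typing import Any, Dict, List

ID_TO_TOKEN = ["<s>", "<pad>", "</s>", "hello", "world"]  # hypothetical vocabulary
END_ID = 2  # hypothetical end-of-sequence id


def ids_to_tokens(predictions: List[List[int]]) -> List[List[str]]:
    """Map each row of ids to tokens, stopping at (and excluding) the end id."""
    batch_tokens = []
    for row in predictions:
        tokens = []
        for token_id in row:
            if token_id == END_ID:
                break
            tokens.append(ID_TO_TOKEN[token_id])
        batch_tokens.append(tokens)
    return batch_tokens


output_dict: Dict[str, Any] = {"predictions": [[3, 4, 2, 1], [4, 2, 1, 1]]}
output_dict["predicted_tokens"] = ids_to_tokens(output_dict["predictions"])
print(output_dict["predicted_tokens"])  # [['hello', 'world'], ['world']]
```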

get_metrics

class Bart(Model):
 | ...
 | @overrides
 | def get_metrics(self, reset: bool = False) -> Dict[str, float]
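get_metrics follows the usual AllenNLP pattern: metric values accumulate over the batches seen so far, and passing reset=True returns the current values and then clears the accumulators (the trainer does this at the end of each epoch). The sketch below shows that pattern with a hypothetical exact-match metric and a toy model class; the metrics Bart actually reports are not listed on this page, so the metric name is an assumption.

```python
from typing import Dict, List


class ExactMatch:
    """Hypothetical metric: fraction of predictions equal to their targets."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def __call__(self, predictions: List[List[int]], targets: List[List[int]]) -> None:
        for pred, gold in zip(predictions, targets):
            self.correct += int(pred == gold)
            self.total += 1

    def get_metric(self, reset: bool = False) -> float:
        value = self.correct / self.total if self.total else 0.0
        if reset:
            self.correct = self.total = 0
        return value


class ToyModel:
    def __init__(self) -> None:
        self._exact_match = ExactMatch()

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"exact_match": self._exact_match.get_metric(reset)}


model = ToyModel()
model._exact_match([[1, 2], [3]], [[1, 2], [4]])  # batch 1: 1 of 2 correct
model._exact_match([[5]], [[5]])                  # batch 2: 1 of 1 correct
print(model.get_metrics())            # {'exact_match': 0.6666666666666666}
print(model.get_metrics(reset=True))  # same value, then the counts are cleared
print(model.get_metrics())            # {'exact_match': 0.0}
```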

default_predictor

class Bart(Model):
 | ...
 | default_predictor = "seq2seq"
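Setting default_predictor means that when a trained Bart archive is loaded for prediction without naming a predictor, AllenNLP uses the predictor registered as "seq2seq". Assuming that predictor reads JSON objects with a "source" field (worth checking against the predictor's own documentation), an input file can be written as JSON lines as sketched below; the file name and example sentences are hypothetical.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical inputs; each line is one JSON object with a "source" field.
sources = [
    "The quick brown fox jumps over the lazy dog.",
    "BART is a denoising autoencoder.",
]
input_path = Path(tempfile.gettempdir()) / "bart_inputs.jsonl"  # hypothetical file name
with input_path.open("w", encoding="utf-8") as f:
    for text in sources:
        f.write(json.dumps({"source": text}) + "\n")
```

The resulting file can then be passed to `allennlp predict` together with the trained model archive.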