masked_language_model
MaskedLanguageModelingReader#
class MaskedLanguageModelingReader(DatasetReader):
 | def __init__(
 |     self,
 |     tokenizer: Tokenizer = None,
 |     token_indexers: Dict[str, TokenIndexer] = None,
 |     **kwargs
 | ) -> None
Reads a text file and converts it into a `Dataset` suitable for training a masked language model.

The `Field`s that we create are the following: an input `TextField`, a mask position `ListField[IndexField]`, and a target token `TextField` (the target tokens aren't a single string of text, but we use a `TextField` so we can index the target tokens the same way as our input, typically with a single `PretrainedTransformerIndexer`). The mask position and target token lists are the same length.
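Concretely, an instance with these fields might be assembled along the lines of the following minimal sketch. The field names ("tokens", "mask_positions", "target_ids") and the single-[MASK] example are illustrative assumptions, not a verbatim excerpt of the reader:

```python
from allennlp.data import Instance
from allennlp.data.fields import IndexField, ListField, TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

indexers = {"tokens": SingleIdTokenIndexer()}

# Input sentence with a single [MASK] token at position 3.
input_tokens = [Token(t) for t in ["The", "quick", "brown", "[MASK]", "jumps"]]
input_field = TextField(input_tokens, indexers)

# One IndexField per [MASK] token, each pointing into the input TextField.
mask_positions = ListField([IndexField(3, input_field)])

# Targets are wrapped in a TextField so they can be indexed like the input.
target_field = TextField([Token("fox")], indexers)

instance = Instance(
    {"tokens": input_field, "mask_positions": mask_positions, "target_ids": target_field}
)
```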
NOTE: This is not fully functional! It was written to put together a demo for interpreting and attacking masked language modeling, not for actually training anything. `text_to_instance` is functional, but `_read` is not. To make this fully functional, you would want some sampling strategies for picking the locations for [MASK] tokens, and probably a bunch of efficiency / multi-processing stuff.
Parameters

- tokenizer : `Tokenizer`, optional (default = `WhitespaceTokenizer()`)
  We use this `Tokenizer` for the text. See `Tokenizer`.
- token_indexers : `Dict[str, TokenIndexer]`, optional (default = `{"tokens": SingleIdTokenIndexer()}`)
  We use this to define the input representation for the text, and to get ids for the mask targets. See `TokenIndexer`.
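For example, to pair the reader with a pretrained transformer (the setting the docstring has in mind), you might construct it along these lines. This is a sketch under assumptions: the import path for the reader and the model name vary across AllenNLP releases.

```python
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer

# Assumed import path; in newer releases the reader ships with allennlp-models instead.
from allennlp.data.dataset_readers import MaskedLanguageModelingReader

model_name = "bert-base-uncased"
reader = MaskedLanguageModelingReader(
    tokenizer=PretrainedTransformerTokenizer(model_name),
    token_indexers={"tokens": PretrainedTransformerIndexer(model_name)},
)
```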
text_to_instance#
class MaskedLanguageModelingReader(DatasetReader):
 | ...
 | @overrides
 | def text_to_instance(
 |     self,
 |     sentence: str = None,
 |     tokens: List[Token] = None,
 |     targets: List[str] = None
 | ) -> Instance
Parameters

- sentence : `str`, optional
  A sentence containing [MASK] tokens that should be filled in by the model. This input is superseded and ignored if `tokens` is given.
- tokens : `List[Token]`, optional
  An already-tokenized sentence containing some number of [MASK] tokens to be predicted.
- targets : `List[str]`, optional
  Contains the target tokens to be predicted. The length of this list should be the same as the number of [MASK] tokens in the input.
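As a usage sketch (assuming the `reader` constructed in the earlier sketch), `text_to_instance` can be called directly on a [MASK]-containing sentence together with its targets:

```python
# Assuming `reader` was constructed as in the earlier sketch.
instance = reader.text_to_instance(
    sentence="The quick brown [MASK] jumps over the lazy [MASK].",
    targets=["fox", "dog"],  # one target string per [MASK] token in the sentence
)
print(instance.fields.keys())
```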