bias_mitigator_applicator
allennlp.fairness.bias_mitigator_applicator
A Model wrapper to mitigate biases in contextual embeddings during finetuning on a downstream task and at test time.
Based on: Dev, S., Li, T., Phillips, J.M., & Srikumar, V. (2020). On Measuring and Mitigating Biased Inferences of Word Embeddings. ArXiv, abs/1908.09369.
BiasMitigatorApplicator¶
@Model.register("bias_mitigator_applicator")
class BiasMitigatorApplicator(Model):
| def __init__(
| self,
| vocab: Vocabulary,
| base_model: Model,
| bias_mitigator: Lazy[BiasMitigatorWrapper],
| **kwargs
| )
Wrapper class to apply bias mitigation to any pretrained Model.
Parameters¶
- vocab : Vocabulary
  Vocabulary of base model.
- base_model : Model
  Base model for which to mitigate biases.
- bias_mitigator : Lazy[BiasMitigatorWrapper]
  Bias mitigator to apply to base model.
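To make the wrapping idea concrete, here is a minimal plain-Python sketch, not AllenNLP code. It assumes the mitigator acts as a post-processing step on the base model's embeddings and that the wrapper forwards `forward` to the base model. `ToyBaseModel`, `ToyApplicator`, `embedding_hook`, and `project_out` are hypothetical names, and projecting out a single bias direction is only one illustrative choice of mitigation.

```python
from typing import Callable, Dict, List, Optional

Vector = List[float]


def project_out(vector: Vector, bias_direction: Vector) -> Vector:
    """Remove the component of `vector` that lies along `bias_direction`."""
    norm_sq = sum(b * b for b in bias_direction)
    scale = sum(v * b for v, b in zip(vector, bias_direction)) / norm_sq
    return [v - scale * b for v, b in zip(vector, bias_direction)]


class ToyBaseModel:
    """Hypothetical stand-in for a pretrained model with an embedding step."""

    def __init__(self, embeddings: Dict[str, Vector]):
        self.embeddings = embeddings
        # Optional post-processing applied to every embedding (assumption).
        self.embedding_hook: Optional[Callable[[Vector], Vector]] = None

    def embed(self, token: str) -> Vector:
        vector = self.embeddings[token]
        if self.embedding_hook is not None:
            vector = self.embedding_hook(vector)
        return vector

    def forward(self, tokens: List[str]) -> Dict[str, List[Vector]]:
        return {"embeddings": [self.embed(t) for t in tokens]}


class ToyApplicator:
    """Hypothetical wrapper: installs a mitigator on the base model's embeddings."""

    def __init__(self, base_model: ToyBaseModel, bias_mitigator: Callable[[Vector], Vector]):
        self.base_model = base_model
        # The hook stays installed, so it applies in training and at test time.
        self.base_model.embedding_hook = bias_mitigator

    def forward(self, *args, **kwargs):
        # Pass the call through to the wrapped model unchanged.
        return self.base_model.forward(*args, **kwargs)


bias_direction = [1.0, 0.0]
base = ToyBaseModel({"nurse": [0.5, 2.0], "doctor": [-0.5, 2.0]})
wrapped = ToyApplicator(base, lambda v: project_out(v, bias_direction))
output = wrapped.forward(["nurse", "doctor"])
```

In this toy setup both embeddings lose their first coordinate, so the two tokens become indistinguishable along the assumed bias direction while the rest of the vector is untouched.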
train¶
class BiasMitigatorApplicator(Model):
| ...
| def train(self, mode: bool = True)
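The signature mirrors `torch.nn.Module.train`, where `mode=True` selects training mode and `mode=False` selects evaluation mode. A wrapper around another model typically has to pass that mode on to the parts it wraps. The plain-Python sketch below shows the pattern under that assumption; `ToyModule` and `ToyWrapper` are hypothetical stand-ins, not the library's classes.

```python
class ToyModule:
    """Hypothetical stand-in for a module that tracks a training flag."""

    def __init__(self):
        self.training = True

    def train(self, mode: bool = True):
        self.training = mode
        return self

    def eval(self):
        # Evaluation mode is training mode switched off.
        return self.train(False)


class ToyWrapper(ToyModule):
    """Hypothetical wrapper that keeps its wrapped parts in the same mode."""

    def __init__(self, base_model: ToyModule, bias_mitigator: ToyModule):
        super().__init__()
        self.base_model = base_model
        self.bias_mitigator = bias_mitigator

    def train(self, mode: bool = True):
        super().train(mode)
        # Propagate the mode so the wrapped parts never disagree with the wrapper.
        self.base_model.train(mode)
        self.bias_mitigator.train(mode)
        return self


wrapper = ToyWrapper(ToyModule(), ToyModule())
wrapper.eval()
modes = (
    wrapper.training,
    wrapper.base_model.training,
    wrapper.bias_mitigator.training,
)
```

Because `eval()` is defined in terms of `train(False)`, overriding `train` alone is enough to cover both transitions.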
forward¶
class BiasMitigatorApplicator(Model):
| ...
| def forward(self, *args, **kwargs)
forward_on_instance¶
class BiasMitigatorApplicator(Model):
| ...
| def forward_on_instance(self, *args, **kwargs)
forward_on_instances¶
class BiasMitigatorApplicator(Model):
| ...
| def forward_on_instances(self, *args, **kwargs)
get_regularization_penalty¶
class BiasMitigatorApplicator(Model):
| ...
| def get_regularization_penalty(self, *args, **kwargs)
get_parameters_for_histogram_logging¶
class BiasMitigatorApplicator(Model):
| ...
| def get_parameters_for_histogram_logging(self, *args, **kwargs)
get_parameters_for_histogram_tensorboard_logging¶
class BiasMitigatorApplicator(Model):
| ...
| def get_parameters_for_histogram_tensorboard_logging(
| self,
| *args,
| **kwargs
| )
make_output_human_readable¶
class BiasMitigatorApplicator(Model):
| ...
| def make_output_human_readable(self, *args, **kwargs)
get_metrics¶
class BiasMitigatorApplicator(Model):
| ...
| def get_metrics(self, *args, **kwargs)
extend_embedder_vocab¶
class BiasMitigatorApplicator(Model):
| ...
| def extend_embedder_vocab(self, *args, **kwargs)