adversarial_bias_mitigator
allennlp.fairness.adversarial_bias_mitigator
A Model wrapper to adversarially mitigate biases in predictions produced by a pretrained model for a downstream task.
The documentation and explanations are heavily based on Zhang, B.H., Lemoine, B., & Mitchell, M. (2018). Mitigating Unwanted Biases with Adversarial Learning. Proceedings of the 2018 AAAI/ACM Conference on AI, Ethics, and Society, and on the Mitigating Unwanted Biases in Word Embeddings with Adversarial Learning colab notebook.
Adversarial networks mitigate some biases based on the idea that predicting an outcome Y given an input X should ideally be independent of some protected variable Z. Informally, "knowing Y would not help you predict Z any better than chance" (Zaldivar et al., 2018). This can be achieved using two networks in a series, where the first attempts to predict Y using X as input, and the second attempts to use the predicted value of Y to recover Z. Please refer to Figure 1 of Mitigating Unwanted Biases with Adversarial Learning. Ideally, we would like the first network to predict Y without permitting the second network to predict Z any better than chance.
For common NLP tasks, it's usually clear what X and Y are, but Z is not always available. We can construct our own Z by:
- computing a bias direction (e.g. for binary gender)
- computing the inner product of static sentence embeddings and the bias direction
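The two steps for constructing Z can be sketched in plain Python. This is only an illustration: the helper names `two_means_direction` and `protected_variable` are hypothetical, the embeddings are toy values, and taking the bias direction as the normalized difference between two group means is one possible choice assumed here, not necessarily the one a given BiasDirectionWrapper uses.

```python
import math


def two_means_direction(group_a, group_b):
    """Hypothetical helper: unit vector pointing from the mean of group_b to the mean of group_a."""
    dim = len(group_a[0])
    mean_a = [sum(vec[i] for vec in group_a) / len(group_a) for i in range(dim)]
    mean_b = [sum(vec[i] for vec in group_b) / len(group_b) for i in range(dim)]
    diff = [a - b for a, b in zip(mean_a, mean_b)]
    norm = math.sqrt(sum(d * d for d in diff))
    return [d / norm for d in diff]


def protected_variable(sentence_embedding, bias_direction):
    """Inner product of a static sentence embedding with the bias direction."""
    return sum(s * b for s, b in zip(sentence_embedding, bias_direction))


# Toy seed embeddings for two groups (made-up values, for illustration only).
group_a = [[1.0, 0.0, 0.5], [0.8, 0.2, 0.5]]
group_b = [[-1.0, 0.0, 0.5], [-0.8, 0.2, 0.5]]

direction = two_means_direction(group_a, group_b)

# Z for a single (made-up) sentence embedding.
z = protected_variable([0.5, 0.3, 0.1], direction)
```

Here Z is a real-valued score per sentence, which is why a regression adversary is a natural fit for recovering it.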
Training adversarial networks is extremely difficult. It is important to:
- lower the step size of both the predictor and adversary to train both models slowly to avoid parameters diverging,
- initialize the parameters of the adversary to be small to avoid the predictor overfitting against a sub-optimal adversary,
- increase the adversary’s learning rate to prevent divergence if the predictor is too good at hiding the protected variable from the adversary.
AdversarialBiasMitigator¶
@Model.register("adversarial_bias_mitigator")
class AdversarialBiasMitigator(Model):
| def __init__(
| self,
| vocab: Vocabulary,
| predictor: Model,
| adversary: Model,
| bias_direction: BiasDirectionWrapper,
| predictor_output_key: str,
| **kwargs
| )
Wrapper class to adversarially mitigate biases in any pretrained Model.
Parameters¶
- vocab : Vocabulary
  Vocabulary of predictor.
- predictor : Model
  Model for which to mitigate biases.
- adversary : Model
  Model that attempts to recover protected variable values from predictor's predictions.
- bias_direction : BiasDirectionWrapper
  Bias direction used by adversarial bias mitigator.
- predictor_output_key : str
  Key corresponding to output in output_dict of predictor that should be passed as input to adversary.
Note
adversary must use the same vocab as predictor, if it requires a vocab.
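As a rough sketch, the constructor arguments above could be laid out as a configuration dictionary like the one below. The registered names "adversarial_bias_mitigator" and "feedforward_regression_adversary" come from this page; everything else is an assumption for illustration: the "..." entries are placeholders that are deliberately left unfilled, "probs" is a hypothetical output key, and the feedforward dimensions are made up. vocab is omitted on the assumption that it is supplied separately when the model is built.

```python
# Illustrative layout only; not a complete or validated configuration.
mitigator_config = {
    "type": "adversarial_bias_mitigator",
    # Placeholder: configuration of the pretrained model to debias.
    "predictor": {"type": "..."},
    "adversary": {
        "type": "feedforward_regression_adversary",
        # Assumed sizes: maps a 3-dimensional prediction to a single score.
        "feedforward": {
            "input_dim": 3,
            "num_layers": 1,
            "hidden_dims": 1,
            "activations": "linear",
        },
    },
    # Placeholder: configuration of a BiasDirectionWrapper.
    "bias_direction": {"type": "..."},
    # Hypothetical key; it must name an entry in the predictor's output_dict.
    "predictor_output_key": "probs",
}
```

The value stored under predictor_output_key is what the adversary receives as input, so the adversary's input dimension has to match the shape of that output.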
train¶
class AdversarialBiasMitigator(Model):
| ...
| def train(self, mode: bool = True)
forward¶
class AdversarialBiasMitigator(Model):
| ...
| def forward(self, *args, **kwargs)
forward_on_instance¶
class AdversarialBiasMitigator(Model):
| ...
| def forward_on_instance(self, *args, **kwargs)
forward_on_instances¶
class AdversarialBiasMitigator(Model):
| ...
| def forward_on_instances(self, *args, **kwargs)
get_regularization_penalty¶
class AdversarialBiasMitigator(Model):
| ...
| def get_regularization_penalty(self, *args, **kwargs)
get_parameters_for_histogram_logging¶
class AdversarialBiasMitigator(Model):
| ...
| def get_parameters_for_histogram_logging(self, *args, **kwargs)
get_parameters_for_histogram_tensorboard_logging¶
class AdversarialBiasMitigator(Model):
| ...
| def get_parameters_for_histogram_tensorboard_logging(
| self,
| *args,
| **kwargs
| )
make_output_human_readable¶
class AdversarialBiasMitigator(Model):
| ...
| def make_output_human_readable(self, *args, **kwargs)
get_metrics¶
class AdversarialBiasMitigator(Model):
| ...
| def get_metrics(self, *args, **kwargs)
extend_embedder_vocab¶
class AdversarialBiasMitigator(Model):
| ...
| def extend_embedder_vocab(self, *args, **kwargs)
FeedForwardRegressionAdversary¶
@Model.register("feedforward_regression_adversary")
class FeedForwardRegressionAdversary(Model):
| def __init__(
| self,
| vocab: Vocabulary,
| feedforward: FeedForward,
| initializer: Optional[InitializerApplicator] = InitializerApplicator(),
| **kwargs
| ) -> None
This Model implements a simple feedforward regression adversary.
Registered as a Model with name "feedforward_regression_adversary".
Parameters¶
- vocab : Vocabulary
- feedforward : FeedForward
  A feedforward layer.
- initializer : Optional[InitializerApplicator], optional (default = InitializerApplicator())
  If provided, will be used to initialize the model parameters.
forward¶
class FeedForwardRegressionAdversary(Model):
| ...
| def forward(
| self,
| input: torch.FloatTensor,
| label: torch.FloatTensor
| ) -> Dict[str, torch.Tensor]
Parameters¶
- input : torch.FloatTensor
  A tensor of size (batch_size, ...).
- label : torch.FloatTensor
  A tensor of the same size as input.
Returns¶
- An output dictionary consisting of:
  - loss : torch.FloatTensor
    A scalar loss to be optimised.
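The shape of this computation can be sketched without any framework. The sketch below assumes a single linear layer and a mean squared error loss; the page itself only says "regression" and "scalar loss", so both choices, along with the function names and the toy numbers, are assumptions for illustration.

```python
def linear(batch, weights, bias):
    """Apply one linear unit to each row of the batch."""
    return [sum(w * x for w, x in zip(weights, row)) + bias for row in batch]


def regression_adversary_forward(inputs, labels, weights, bias):
    """Hypothetical stand-in for forward: predict the protected variable, return a scalar loss."""
    predictions = linear(inputs, weights, bias)
    squared_errors = [(p - y) ** 2 for p, y in zip(predictions, labels)]
    # Mean squared error over the batch (assumed loss).
    return {"loss": sum(squared_errors) / len(squared_errors)}


output = regression_adversary_forward(
    inputs=[[0.9, 0.1], [0.2, 0.8]],  # e.g. the predictor's outputs for two instances
    labels=[0.5, -0.5],               # protected variable values Z for the same instances
    weights=[1.0, -1.0],
    bias=0.0,
)
```

A low loss means the adversary recovers Z well from the predictor's outputs, which is exactly what the predictor is trained to prevent.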
AdversarialBiasMitigatorBackwardCallback¶
@TrainerCallback.register("adversarial_bias_mitigator_backward")
class AdversarialBiasMitigatorBackwardCallback(TrainerCallback):
| def __init__(
| self,
| serialization_dir: str,
| adversary_loss_weight: float = 1.0
| ) -> None
Performs backpropagation for adversarial bias mitigation. While the adversary's gradients are computed normally, the predictor's gradients are computed such that updates to the predictor's parameters will not aid the adversary and will make it more difficult for the adversary to recover protected variables.
Note
Intended to be used with AdversarialBiasMitigator. trainer.model is expected to have predictor and adversary data members.
Parameters¶
- adversary_loss_weight : float, optional (default = 1.0)
  Quantifies how difficult predictor makes it for adversary to recover protected variables.
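The predictor's gradient adjustment can be illustrated with the update rule from Zhang et al. (2018): remove the component of the predictor's gradient that lies along the adversary's gradient, then subtract a weighted copy of the adversary's gradient. The plain-Python sketch below operates on flat lists instead of tensors; the function name is hypothetical, and treating adversary_loss_weight as the weight on the final term is an assumption about how the callback uses it.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def adjust_predictor_grad(predictor_grad, adversary_grad, adversary_loss_weight=1.0):
    """Hypothetical helper: the predictor gradient after the adversarial adjustment.

    predictor_grad: gradient of the predictor loss w.r.t. the predictor's parameters.
    adversary_grad: gradient of the adversary loss w.r.t. the same parameters.
    """
    norm = math.sqrt(dot(adversary_grad, adversary_grad))
    if norm == 0.0:
        # Nothing to project out or push against.
        return list(predictor_grad)
    unit = [g / norm for g in adversary_grad]
    # Remove the component that would help the adversary.
    coeff = dot(predictor_grad, unit)
    projected = [p - coeff * u for p, u in zip(predictor_grad, unit)]
    # Subtract the weighted adversary gradient so that a descent step
    # on the result increases the adversary's loss.
    return [p - adversary_loss_weight * a for p, a in zip(projected, adversary_grad)]


adjusted = adjust_predictor_grad([1.0, 1.0], [2.0, 0.0], adversary_loss_weight=0.5)
```

In this toy case the first coordinate of the predictor's gradient is aligned with the adversary's gradient, so it is projected out and replaced by a term pointing against the adversary; the second coordinate is left untouched.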
on_backward¶
class AdversarialBiasMitigatorBackwardCallback(TrainerCallback):
| ...
| def on_backward(
| self,
| trainer: GradientDescentTrainer,
| batch_outputs: Dict[str, torch.Tensor],
| backward_called: bool,
| **kwargs
| ) -> bool