adversarial_bias_mitigator

allennlp.fairness.adversarial_bias_mitigator

A Model wrapper to adversarially mitigate biases in predictions produced by a pretrained model for a downstream task.

The documentation and explanations are heavily based on Zhang, B.H., Lemoine, B., & Mitchell, M. (2018). Mitigating Unwanted Biases with Adversarial Learning. Proceedings of the 2018 AAAI/ACM Conference on AI, Ethics, and Society, and on the Mitigating Unwanted Biases in Word Embeddings with Adversarial Learning colab notebook.

Adversarial networks mitigate some biases based on the idea that predicting an outcome Y given an input X should ideally be independent of some protected variable Z. Informally, "knowing Y would not help you predict Z any better than chance" (Zaldivar et al., 2018). This can be achieved using two networks in series: the first attempts to predict Y using X as input, and the second attempts to use the predicted value of Y to recover Z. See Figure 1 of Mitigating Unwanted Biases with Adversarial Learning. Ideally, the first network predicts Y without permitting the second network to predict Z any better than chance.

For common NLP tasks, it's usually clear what X and Y are, but Z is not always available. We can construct our own Z by:

  1. computing a bias direction (e.g. for binary gender)

  2. computing the inner product of static sentence embeddings and the bias direction
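The two steps above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the bias direction here is a hypothetical difference of the mean embeddings of two definitional word groups (one common choice; a BiasDirectionWrapper may compute it differently), and the names `two_means_direction` and `protected_variable` are illustrative only.

```python
import math
from typing import List, Sequence

Vector = List[float]


def mean(vectors: Sequence[Sequence[float]]) -> Vector:
    """Component-wise mean of equally sized vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def normalize(vector: Sequence[float]) -> Vector:
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def two_means_direction(
    group_a: Sequence[Sequence[float]], group_b: Sequence[Sequence[float]]
) -> Vector:
    """Hypothetical bias direction: unit vector from the mean of group_b to the
    mean of group_a (e.g. embeddings of male-gendered vs female-gendered words)."""
    diff = [a - b for a, b in zip(mean(group_a), mean(group_b))]
    return normalize(diff)


def protected_variable(
    sentence_embeddings: Sequence[Sequence[float]], direction: Sequence[float]
) -> Vector:
    """Z for each sentence: inner product of its static embedding with the direction."""
    return [sum(e * d for e, d in zip(emb, direction)) for emb in sentence_embeddings]


direction = two_means_direction([[1.0, 0.0], [3.0, 0.0]], [[-1.0, 0.0], [-3.0, 0.0]])
z = protected_variable([[0.5, 2.0], [-1.5, 1.0]], direction)
print(direction, z)  # [1.0, 0.0] [0.5, -1.5]
```

Each sentence's Z is a single real number, which is why a regression adversary is a natural fit for recovering it.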

Training adversarial networks is extremely difficult. It is important to:

  1. lower the step size of both the predictor and the adversary so that both models train slowly and their parameters do not diverge,

  2. initialize the parameters of the adversary to be small to avoid the predictor overfitting against a sub-optimal adversary,

  3. increase the adversary’s learning rate to prevent divergence if the predictor is too good at hiding the protected variable from the adversary.
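The tips above can be made concrete with a few hypothetical helpers: small-scale uniform initialization for the adversary's weights, small starting step sizes for both models, and a rule for raising the adversary's learning rate. The third tip needs a signal that the predictor is hiding Z too well; the signal assumed here (not taken from the library) is that the adversary's mean squared error has reached the chance-level baseline, i.e. the variance of Z, which is the error of always predicting its mean. All names and numeric values are illustrative.

```python
import random
from statistics import pvariance
from typing import List, Sequence


def small_uniform_init(n: int, scale: float = 1e-3, seed: int = 0) -> List[float]:
    """Tip 2 (hypothetical helper): draw n adversary weights uniformly from [-scale, scale]."""
    rng = random.Random(seed)
    return [rng.uniform(-scale, scale) for _ in range(n)]


def next_adversary_lr(
    lr: float,
    adversary_mse: float,
    protected_values: Sequence[float],
    factor: float = 2.0,
    max_lr: float = 1e-3,
) -> float:
    """Tip 3 (hypothetical heuristic): raise the adversary's learning rate, up to
    max_lr, when its error is no better than always predicting the mean of Z."""
    chance_mse = pvariance(protected_values)
    if adversary_mse >= chance_mse:
        return min(lr * factor, max_lr)
    return lr


# Tip 1: both models start with small step sizes (illustrative values).
predictor_lr, adversary_lr = 1e-5, 1e-5
weights = small_uniform_init(4)
# Var(Z) = 1.0625 here, so an adversary MSE of 1.3 is no better than chance.
adversary_lr = next_adversary_lr(
    adversary_lr, adversary_mse=1.3, protected_values=[0.5, -1.5, 1.0, -1.0]
)
```

Capping the learning rate keeps the third tip from undoing the first one.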

AdversarialBiasMitigator

@Model.register("adversarial_bias_mitigator")
class AdversarialBiasMitigator(Model):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     predictor: Model,
 |     adversary: Model,
 |     bias_direction: BiasDirectionWrapper,
 |     predictor_output_key: str,
 |     **kwargs
 | )

Wrapper class to adversarially mitigate biases in any pretrained Model.

Parameters

  • vocab : Vocabulary
    Vocabulary of predictor.
  • predictor : Model
    Model for which to mitigate biases.
  • adversary : Model
    Model that attempts to recover protected variable values from predictor's predictions.
  • bias_direction : BiasDirectionWrapper
    Bias direction used by adversarial bias mitigator.
  • predictor_output_key : str
    Key of the entry in the predictor's output_dict that should be passed as input to the adversary.

Note

The adversary must use the same vocab as the predictor if it requires a vocab.
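A minimal sketch of how a wrapper like this could route data, assuming plain Python callables in place of the predictor and adversary Models and assuming Z has already been computed (for example from the bias direction): the predictor's output under predictor_output_key becomes the adversary's input, Z is the adversary's label, and the adversary's outputs are merged into the predictor's output dictionary. The `adversary_` key prefix and every function name here are assumptions for illustration, not a documented contract.

```python
from typing import Any, Callable, Dict, List, Sequence

OutputDict = Dict[str, Any]


def mitigator_forward(
    predictor: Callable[..., OutputDict],
    adversary: Callable[[Any, Any], OutputDict],
    predictor_output_key: str,
    protected_label: Sequence[float],
    *args: Any,
    **kwargs: Any,
) -> OutputDict:
    """Run the predictor, feed the chosen output to the adversary, merge results."""
    predictor_output = predictor(*args, **kwargs)
    adversary_output = adversary(predictor_output[predictor_output_key], protected_label)
    merged = dict(predictor_output)
    # Prefix adversary keys so its "loss" does not overwrite the predictor's "loss".
    for key, value in adversary_output.items():
        merged["adversary_" + key] = value
    return merged


# Toy stand-ins: the predictor scales each input, the adversary reports mean squared error.
def toy_predictor(xs: List[float]) -> OutputDict:
    return {"probs": [x / 10 for x in xs], "loss": 0.7}


def toy_adversary(inputs: Sequence[float], labels: Sequence[float]) -> OutputDict:
    return {"loss": sum((i - l) ** 2 for i, l in zip(inputs, labels)) / len(labels)}


out = mitigator_forward(toy_predictor, toy_adversary, "probs", [0.5, -1.5], [5, 5])
print(out)  # {'probs': [0.5, 0.5], 'loss': 0.7, 'adversary_loss': 2.0}
```

Keeping both losses in one dictionary lets a training loop see the predictor's loss and the adversary's loss for the same batch.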

train

class AdversarialBiasMitigator(Model):
 | ...
 | def train(self, mode: bool = True)

forward

class AdversarialBiasMitigator(Model):
 | ...
 | def forward(self, *args, **kwargs)

forward_on_instance

class AdversarialBiasMitigator(Model):
 | ...
 | def forward_on_instance(self, *args, **kwargs)

forward_on_instances

class AdversarialBiasMitigator(Model):
 | ...
 | def forward_on_instances(self, *args, **kwargs)

get_regularization_penalty

class AdversarialBiasMitigator(Model):
 | ...
 | def get_regularization_penalty(self, *args, **kwargs)

get_parameters_for_histogram_logging

class AdversarialBiasMitigator(Model):
 | ...
 | def get_parameters_for_histogram_logging(self, *args, **kwargs)

get_parameters_for_histogram_tensorboard_logging

class AdversarialBiasMitigator(Model):
 | ...
 | def get_parameters_for_histogram_tensorboard_logging(
 |     self,
 |     *args,
 |     **kwargs
 | )

make_output_human_readable

class AdversarialBiasMitigator(Model):
 | ...
 | def make_output_human_readable(self, *args, **kwargs)

get_metrics

class AdversarialBiasMitigator(Model):
 | ...
 | def get_metrics(self, *args, **kwargs)

extend_embedder_vocab

class AdversarialBiasMitigator(Model):
 | ...
 | def extend_embedder_vocab(self, *args, **kwargs)

FeedForwardRegressionAdversary

@Model.register("feedforward_regression_adversary")
class FeedForwardRegressionAdversary(Model):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     feedforward: FeedForward,
 |     initializer: Optional[InitializerApplicator] = InitializerApplicator(),
 |     **kwargs
 | ) -> None

This Model implements a simple feedforward regression adversary.

Registered as a Model with name "feedforward_regression_adversary".

Parameters

  • vocab : Vocabulary
  • feedforward : FeedForward
    A feedforward layer.
  • initializer : Optional[InitializerApplicator], optional (default = InitializerApplicator())
    If provided, will be used to initialize the model parameters.

forward

class FeedForwardRegressionAdversary(Model):
 | ...
 | def forward(
 |     self,
 |     input: torch.FloatTensor,
 |     label: torch.FloatTensor
 | ) -> Dict[str, torch.Tensor]

Parameters

  • input : torch.FloatTensor
    A tensor of size (batch_size, ...).
  • label : torch.FloatTensor
    A tensor of the same size as input.

Returns

  • An output dictionary consisting of:
    • loss : torch.FloatTensor
      A scalar loss to be optimised.
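As a rough sketch of what a regression adversary's forward pass computes, the block below assumes a single linear layer in place of the configured FeedForward and mean squared error as the regression loss (both are assumptions; the registered model may differ). Each input row is mapped to a prediction, and the loss is the average squared difference from the label, which yields the scalar described above.

```python
from typing import Dict, List, Sequence

Matrix = List[List[float]]


def linear(rows: Sequence[Sequence[float]], weight: Matrix, bias: Sequence[float]) -> Matrix:
    """Apply y = W x + b to every row; W has shape (out_dim, in_dim)."""
    return [
        [sum(w * x for w, x in zip(w_row, row)) + b for w_row, b in zip(weight, bias)]
        for row in rows
    ]


def regression_adversary_forward(
    inputs: Matrix, labels: Matrix, weight: Matrix, bias: Sequence[float]
) -> Dict[str, float]:
    """Hypothetical forward: predict Z from the predictor's output, return a scalar MSE loss."""
    predictions = linear(inputs, weight, bias)
    squared = [
        (p - y) ** 2
        for p_row, y_row in zip(predictions, labels)
        for p, y in zip(p_row, y_row)
    ]
    return {"loss": sum(squared) / len(squared)}


# batch_size = 2 with one feature per row; the label has the same size as the input.
output = regression_adversary_forward(
    inputs=[[0.5], [0.25]], labels=[[1.0], [0.0]], weight=[[2.0]], bias=[0.0]
)
print(output)  # {'loss': 0.125}
```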

AdversarialBiasMitigatorBackwardCallback

@TrainerCallback.register("adversarial_bias_mitigator_backward")
class AdversarialBiasMitigatorBackwardCallback(TrainerCallback):
 | def __init__(
 |     self,
 |     serialization_dir: str,
 |     adversary_loss_weight: float = 1.0
 | ) -> None

Performs backpropagation for adversarial bias mitigation. While the adversary's gradients are computed normally, the predictor's gradients are computed such that updates to the predictor's parameters will not aid the adversary and will make it more difficult for the adversary to recover protected variables.

Note

Intended to be used with AdversarialBiasMitigator. trainer.model is expected to have predictor and adversary data members.

Parameters

  • adversary_loss_weight : float, optional (default = 1.0)
    Weight on the adversary's gradient in the predictor's update; larger values make the predictor work harder to keep the adversary from recovering protected variables.
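The predictor update described for this callback follows the rule in Zhang et al. (2018): remove from the predictor's gradient its component along the adversary's gradient, then subtract adversary_loss_weight (α) times the adversary's gradient, giving ∇L_P − proj_{∇L_A} ∇L_P − α ∇L_A, where both gradients are taken with respect to the predictor's weights. A minimal sketch on one flattened gradient vector, in plain Python rather than on torch parameters, might look like this; the function name is hypothetical.

```python
import math
from typing import List, Sequence


def debiased_predictor_grad(
    predictor_grad: Sequence[float],
    adversary_grad: Sequence[float],
    adversary_loss_weight: float = 1.0,
) -> List[float]:
    """Return g_P - proj_{g_A}(g_P) - alpha * g_A for one flattened parameter.

    g_P: gradient of the predictor's loss w.r.t. the predictor's weights.
    g_A: gradient of the adversary's loss w.r.t. the same weights.
    """
    norm = math.sqrt(sum(a * a for a in adversary_grad))
    if norm == 0.0:
        # Avoid dividing by zero; both correction terms are zero in this case anyway.
        return list(predictor_grad)
    unit = [a / norm for a in adversary_grad]
    # Component of g_P along g_A; removing it means the step cannot help the adversary.
    along = sum(p * u for p, u in zip(predictor_grad, unit))
    return [
        p - along * u - adversary_loss_weight * a
        for p, u, a in zip(predictor_grad, unit, adversary_grad)
    ]


print(debiased_predictor_grad([1.0, 1.0], [0.0, 2.0], adversary_loss_weight=0.5))
# [1.0, -1.0]
```

A descent step along [1.0, -1.0] still follows the predictor's gradient in the first coordinate, while in the second coordinate it moves along the adversary's gradient, which increases the adversary's loss. With α = 0 the result is merely orthogonal to the adversary's gradient.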

on_backward

class AdversarialBiasMitigatorBackwardCallback(TrainerCallback):
 | ...
 | def on_backward(
 |     self,
 |     trainer: GradientDescentTrainer,
 |     batch_outputs: Dict[str, torch.Tensor],
 |     backward_called: bool,
 |     **kwargs
 | ) -> bool
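The bool returned by on_backward reports whether the backward pass has been performed. The sketch below shows how a trainer might consume it, using hypothetical stand-ins instead of GradientDescentTrainer. Two behaviours are assumptions rather than documented facts: the trainer falls back to an ordinary backward pass only when no callback returns True, and the callback refuses to run if backward was already called. The check for predictor and adversary members mirrors the note above.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, List


class BackwardCallbackSketch:
    """Hypothetical stand-in for AdversarialBiasMitigatorBackwardCallback."""

    def __init__(self, adversary_loss_weight: float = 1.0) -> None:
        self.adversary_loss_weight = adversary_loss_weight
        self.batches_handled = 0

    def on_backward(
        self,
        trainer: Any,
        batch_outputs: Dict[str, Any],
        backward_called: bool,
        **kwargs: Any,
    ) -> bool:
        if backward_called:
            # Assumed policy: never run a second backward pass for the same batch.
            raise RuntimeError("backward pass already performed for this batch")
        model = trainer.model
        if not (hasattr(model, "predictor") and hasattr(model, "adversary")):
            raise ValueError("trainer.model must have predictor and adversary members")
        # The real callback would compute both gradients here and combine them,
        # weighting the adversary's gradient by self.adversary_loss_weight.
        self.batches_handled += 1
        return True  # tell the trainer that backward has been done


def run_backward(
    callbacks: List[Any],
    trainer: Any,
    batch_outputs: Dict[str, Any],
    default_backward: Callable[[], None],
) -> None:
    """Assumed trainer loop: the plain backward runs only if no callback handled it."""
    backward_called = False
    for callback in callbacks:
        backward_called |= callback.on_backward(trainer, batch_outputs, backward_called)
    if not backward_called:
        default_backward()


trainer = SimpleNamespace(model=SimpleNamespace(predictor=object(), adversary=object()))
plain_calls: List[str] = []
callback = BackwardCallbackSketch()
run_backward([callback], trainer, {}, lambda: plain_calls.append("loss.backward"))
run_backward([], trainer, {}, lambda: plain_calls.append("loss.backward"))
print(callback.batches_handled, plain_calls)  # 1 ['loss.backward']
```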