simple_influence
allennlp.interpret.influence_interpreters.simple_influence
SimpleInfluence¶
@InfluenceInterpreter.register("simple-influence")
class SimpleInfluence(InfluenceInterpreter):
| def __init__(
| self,
| model: Model,
| train_data_path: DatasetReaderInput,
| train_dataset_reader: DatasetReader,
| *, test_dataset_reader: Optional[DatasetReader] = None,
| train_data_loader: Lazy[DataLoader] = Lazy(SimpleDataLoader.from_dataset_reader),
| test_data_loader: Lazy[DataLoader] = Lazy(SimpleDataLoader.from_dataset_reader),
| params_to_freeze: List[str] = None,
| cuda_device: int = -1,
| lissa_batch_size: int = 8,
| damping: float = 3e-3,
| num_samples: int = 1,
| recursion_depth: Union[float, int] = 0.25,
| scale: float = 1e4
| ) -> None
Registered as an InfluenceInterpreter
with name "simple-influence".
This goes through every example in the train set to calculate the influence score. It uses LiSSA (Linear time Stochastic Second-Order Algorithm) to approximate the inverse of the Hessian used for the influence score calculation.
Parameters¶
- lissa_batch_size : int, optional (default = 8)
  The batch size to use for LiSSA. According to Koh, P.W., & Liang, P. (2017), using batched samples in the approximation gives better stability.
- damping : float, optional (default = 3e-3)
  A hyperparameter for LiSSA. A damping term is added in case the Hessian approximated during LiSSA has negative eigenvalues.
- num_samples : int, optional (default = 1)
  A hyperparameter for LiSSA that determines how many rounds of the recursion process to run for the approximation.
- recursion_depth : Union[float, int], optional (default = 0.25)
  A hyperparameter for LiSSA that determines the depth of the recursion. If a float, it means X% of the training examples; if an int, it means recurse X times.
- scale : float, optional (default = 1e4)
  A hyperparameter for LiSSA, tuned so that the Taylor expansion converges. It is applied to scale down the loss during LiSSA to ensure that H <= I, where H is the Hessian and I is the identity matrix. See footnote 2 of Koh, P.W., & Liang, P. (2017).
Note
We choose the same default values for the LiSSA hyperparameters as Han, Xiaochuang et al. (2020).
get_inverse_hvp_lissa¶
def get_inverse_hvp_lissa(
vs: Sequence[torch.Tensor],
model: Model,
used_params: Sequence[torch.Tensor],
lissa_data_loader: DataLoader,
damping: float,
num_samples: int,
scale: float
) -> torch.Tensor
This function approximates the product of the inverse of the Hessian and the vectors vs using LiSSA.
Adapted from github.com/kohpangwei/influence-release, the repo for Koh, P.W., & Liang, P. (2017), and github.com/xhan77/influence-function-analysis, the repo for Han, Xiaochuang et al. (2020).
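The LiSSA recursion itself is short: starting from h_0 = v, it iterates h_t = v + (I - (H + damping·I)/scale)·h_{t-1}, whose fixed point is scale·(H + damping·I)⁻¹v, so dividing by scale at the end recovers the inverse-HVP. Below is a hypothetical toy sketch (not the library code) that uses an explicit NumPy Hessian to make the fixed point checkable; the real algorithm only ever needs Hessian-vector products, never H itself:

```python
import numpy as np

def lissa_inverse_hvp(hessian, v, damping=3e-3, scale=25.0, recursion_depth=1000):
    """Approximate (H + damping*I)^-1 @ v with the LiSSA recursion (toy version)."""
    h = v.copy()  # h_0 = v
    for _ in range(recursion_depth):
        # h_t = v + h_{t-1} - (H + damping*I) h_{t-1} / scale
        h = v + h - (hessian @ h + damping * h) / scale
    # The fixed point is scale * (H + damping*I)^-1 v, so undo the scaling.
    return h / scale

# A small symmetric positive-definite "Hessian"; scale is chosen so that
# the iteration contracts (all eigenvalues of H_damped / scale are below 1).
H = np.array([[3.0, 0.5], [0.5, 2.0]])
v = np.array([1.0, -1.0])
approx = lissa_inverse_hvp(H, v)
exact = np.linalg.solve(H + 3e-3 * np.eye(2), v)
```

This also shows why `scale` matters: if the largest eigenvalue of the damped Hessian exceeds `scale`, the update is no longer a contraction and the recursion diverges.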
get_hvp¶
def get_hvp(
loss: torch.Tensor,
params: Sequence[torch.Tensor],
vectors: Sequence[torch.Tensor]
) -> Tuple[torch.Tensor, ...]
Get a Hessian-Vector Product (HVP) Hv for each Hessian H of the loss with respect to one of the parameter tensors in params and the corresponding vector v in vectors.
Parameters¶
- loss : torch.Tensor
  The loss calculated from the output of the model.
- params : Sequence[torch.Tensor]
  Tunable and used parameters in the model with respect to which we calculate the gradient and Hessian.
- vectors : Sequence[torch.Tensor]
  The list of vectors for calculating the HVP.
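The standard way to compute an HVP without materializing the Hessian is the double-backward trick: Hv = ∇_θ((∇_θ loss)·v), i.e. take the gradient with `create_graph=True` and then differentiate its dot product with v. A minimal self-contained sketch of that trick (illustrative, not the library's implementation), using a toy quadratic whose Hessian is known:

```python
import torch

# Toy parameters and a loss with known Hessian H = [[2, 1], [1, 4]].
theta = torch.tensor([1.0, 2.0], requires_grad=True)
loss = theta[0] ** 2 + theta[0] * theta[1] + 2 * theta[1] ** 2

# First backward pass, keeping the graph so the gradient is differentiable.
grads = torch.autograd.grad(loss, theta, create_graph=True)[0]

# Second backward pass on (grads . v) yields H @ v.
v = torch.tensor([1.0, -1.0])
hvp = torch.autograd.grad(torch.dot(grads, v), theta)[0]
# For this quadratic, H @ v = [1.0, -3.0].
```

The cost is two backward passes per vector, which is what makes LiSSA feasible on models where forming H explicitly would be hopeless.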