@MatrixAttention.register("linear")
class LinearMatrixAttention(MatrixAttention):
    def __init__(
        self,
        tensor_1_dim: int,
        tensor_2_dim: int,
        combination: str = "x,y",
        activation: Activation = None
    ) -> None
This MatrixAttention takes two matrices as input and returns a matrix of attentions
by performing a dot product between a vector of weights and some
combination of the two input matrices, followed by an (optional) activation function. The
combination used is configurable.
If the two vectors are x and y, we allow the following kinds of combinations: x, y, x*y, x+y,
x-y, x/y, where each of those binary operations is performed elementwise. You can list as many
combinations as you want, comma separated. For example, you might give x,y,x*y as the
combination parameter to this class. The computed similarity function would then be
w^T [x; y; x*y] + b, where w is a vector of weights, b is a bias parameter, and [;] is vector
concatenation.
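A minimal pure-Python sketch of the similarity for a single pair of vectors may make the combination string concrete. It assumes each comma-separated piece is either x, y, or a three-character elementwise expression such as x*y; the helper names combine and similarity are hypothetical and are not part of the library.

```python
from typing import Callable, Dict, List

Vector = List[float]

# Elementwise binary operations allowed between x and y (assumed set: * + - /).
_OPS: Dict[str, Callable[[float, float], float]] = {
    "*": lambda a, b: a * b,
    "+": lambda a, b: a + b,
    "-": lambda a, b: a - b,
    "/": lambda a, b: a / b,
}


def _piece(x: Vector, y: Vector, piece: str) -> Vector:
    """Evaluate one piece of the combination string, e.g. "x" or "x*y"."""
    if piece == "x":
        return list(x)
    if piece == "y":
        return list(y)
    if len(piece) == 3 and piece[1] in _OPS:
        left = _piece(x, y, piece[0])
        right = _piece(x, y, piece[2])
        if len(left) != len(right):
            raise ValueError(f"dimensions must match for {piece!r}")
        op = _OPS[piece[1]]
        return [op(a, b) for a, b in zip(left, right)]
    raise ValueError(f"invalid combination piece: {piece!r}")


def combine(x: Vector, y: Vector, combination: str) -> Vector:
    """Concatenate the pieces of `combination`, e.g. [x; y; x*y] for "x,y,x*y"."""
    out: Vector = []
    for piece in combination.split(","):
        out.extend(_piece(x, y, piece.strip()))
    return out


def similarity(x: Vector, y: Vector, w: Vector, b: float, combination: str) -> float:
    """Compute w^T [combined pieces] + b for one (x, y) pair."""
    combined = combine(x, y, combination)
    if len(combined) != len(w):
        raise ValueError("weight vector length does not match the combined dimension")
    return sum(wi * ci for wi, ci in zip(w, combined)) + b
```

For example, combine([1.0, 2.0], [3.0, 4.0], "x,y,x*y") yields [1.0, 2.0, 3.0, 4.0, 3.0, 8.0], so w needs six entries for this combination.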
Note that if you want a bilinear similarity function with a diagonal weight matrix W, where the
similarity function is computed as x * w * y + b (with w the diagonal of W), you can
accomplish that with this class by using "x*y" for combination.
Registered as a MatrixAttention with name "linear".
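The diagonal-bilinear equivalence noted earlier can be checked numerically: with combination "x*y", w^T (x*y) + b expands to the sum over i of x_i w_i y_i plus b, which is exactly x^T W y + b when W = diag(w). The sketch below is illustrative only; diag, bilinear, and linear_on_product are hypothetical helper names.

```python
from typing import List

Matrix = List[List[float]]


def diag(w: List[float]) -> Matrix:
    """Build the square matrix W whose diagonal is w and whose other entries are 0."""
    n = len(w)
    return [[w[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


def bilinear(x: List[float], W: Matrix, y: List[float], b: float) -> float:
    """General bilinear similarity x^T W y + b."""
    return sum(x[i] * W[i][j] * y[j] for i in range(len(x)) for j in range(len(y))) + b


def linear_on_product(x: List[float], w: List[float], y: List[float], b: float) -> float:
    """Linear similarity w^T (x*y) + b, i.e. the combination "x*y"."""
    return sum(wi * xi * yi for wi, xi, yi in zip(w, x, y)) + b
```

With a diagonal W, every off-diagonal term in bilinear is zero, so both functions agree; a full (non-diagonal) W cannot be expressed this way.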
- tensor_1_dim : int, required
The dimension of the first tensor, x, described above. This is x.size()[-1], the length of the vector that will go into the similarity computation. We need this so we can build weight vectors correctly.
- tensor_2_dim : int, required
The dimension of the second tensor, y, described above. This is y.size()[-1], the length of the vector that will go into the similarity computation. We need this so we can build weight vectors correctly.
- combination : str, optional (default = "x,y")
A comma-separated list of combinations of x and y, as described above.
- activation : Activation, optional (default = None)
An activation function applied after the w^T [x; y] + b calculation. If omitted, the activation is linear, i.e. no activation is applied.
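Because w has one entry per element of the concatenated combination, its length follows from tensor_1_dim, tensor_2_dim, and combination. A sketch of that bookkeeping, assuming (as the elementwise description implies) that a binary piece such as x*y requires both operands to have the same size; combined_dim is a hypothetical helper name.

```python
def combined_dim(combination: str, tensor_1_dim: int, tensor_2_dim: int) -> int:
    """Length of [piece_1; piece_2; ...] for a combination string like "x,y,x*y"."""
    dims = {"x": tensor_1_dim, "y": tensor_2_dim}

    def piece_dim(piece: str) -> int:
        if piece in dims:
            return dims[piece]
        if len(piece) == 3 and piece[1] in "*+-/":
            left, right = piece_dim(piece[0]), piece_dim(piece[2])
            # Elementwise operations need operands of equal length (assumption).
            if left != right:
                raise ValueError(f"tensor dims must match for operation {piece[1]!r}")
            return left
        raise ValueError(f"invalid combination piece: {piece!r}")

    return sum(piece_dim(p.strip()) for p in combination.split(","))
```

With the default "x,y", the weight vector has tensor_1_dim + tensor_2_dim entries; "x,y,x*y" with both dims equal to d gives 3d.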
class LinearMatrixAttention(MatrixAttention):
    ...
    def reset_parameters(self)
class LinearMatrixAttention(MatrixAttention):
    ...
    def forward(
        self,
        matrix_1: torch.Tensor,
        matrix_2: torch.Tensor
    ) -> torch.Tensor
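For orientation, here is a pure-Python sketch of what forward computes for a single batch element: every row of matrix_1 is paired with every row of matrix_2, producing a num_rows_1 by num_rows_2 score matrix. In the batched tensor setting the inputs are assumed to be shaped (batch_size, num_rows_1, tensor_1_dim) and (batch_size, num_rows_2, tensor_2_dim), with output (batch_size, num_rows_1, num_rows_2). The function name linear_matrix_attention, the explicit w and b arguments, and passing activation as a plain callable are assumptions for illustration, not the library's API.

```python
from typing import Callable, List, Optional

Matrix = List[List[float]]

# Elementwise binary operations allowed in the combination string (assumed set).
_OPS = {
    "*": lambda a, b: a * b,
    "+": lambda a, b: a + b,
    "-": lambda a, b: a - b,
    "/": lambda a, b: a / b,
}


def _combine(x: List[float], y: List[float], combination: str) -> List[float]:
    """Concatenate the pieces of `combination` for one pair of rows (x, y)."""
    def piece(p: str) -> List[float]:
        if p == "x":
            return x
        if p == "y":
            return y
        if len(p) != 3 or p[1] not in _OPS:
            raise ValueError(f"invalid combination piece: {p!r}")
        left, right = piece(p[0]), piece(p[2])
        return [_OPS[p[1]](a, b) for a, b in zip(left, right)]

    out: List[float] = []
    for p in combination.split(","):
        out.extend(piece(p.strip()))
    return out


def linear_matrix_attention(
    matrix_1: Matrix,
    matrix_2: Matrix,
    w: List[float],
    b: float,
    combination: str = "x,y",
    activation: Optional[Callable[[float], float]] = None,
) -> Matrix:
    """Return scores[i][j] = activation(w^T combine(matrix_1[i], matrix_2[j]) + b)."""
    act = activation or (lambda s: s)  # default: linear, i.e. no activation
    scores: Matrix = []
    for row_1 in matrix_1:
        scores.append([
            act(sum(wi * ci for wi, ci in zip(w, _combine(row_1, row_2, combination))) + b)
            for row_2 in matrix_2
        ])
    return scores
```

The nested loop makes the all-pairs structure explicit; a tensor implementation would typically replace it with broadcasting over the row dimensions.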