vilbert_vqa
allennlp_models.vision.models.vilbert_vqa
VqaVilbert#
@Model.register("vqa_vilbert")
@Model.register("vqa_vilbert_from_huggingface", constructor="from_huggingface_model_name")
class VqaVilbert(VisionTextModel):
| def __init__(
| self,
| vocab: Vocabulary,
| text_embeddings: TransformerEmbeddings,
| image_embeddings: ImageFeatureEmbeddings,
| encoder: BiModalEncoder,
| pooled_output_dim: int,
| fusion_method: str = "sum",
| dropout: float = 0.1,
| label_namespace: str = "answers",
| *,
| ignore_text: bool = False,
| ignore_image: bool = False
| ) -> None
Model for VQA task based on the VilBERT paper.
Parameters#
- vocab : `Vocabulary`
- text_embeddings : `TransformerEmbeddings`
- image_embeddings : `ImageFeatureEmbeddings`
- encoder : `BiModalEncoder`
- pooled_output_dim : `int`
- fusion_method : `str`, optional (default = `"sum"`)
- dropout : `float`, optional (default = `0.1`)
- label_namespace : `str`, optional (default = `"answers"`)
forward#
class VqaVilbert(VisionTextModel):
| ...
| @overrides
| def forward(
| self,
| box_features: torch.Tensor,
| box_coordinates: torch.Tensor,
| box_mask: torch.Tensor,
| question: TextFieldTensors,
| labels: Optional[torch.Tensor] = None,
| label_weights: Optional[torch.Tensor] = None
| ) -> Dict[str, torch.Tensor]
get_metrics#
class VqaVilbert(VisionTextModel):
| ...
| @overrides
| def get_metrics(self, reset: bool = False) -> Dict[str, float]
make_output_human_readable#
class VqaVilbert(VisionTextModel):
| ...
| @overrides
| def make_output_human_readable(
| self,
| output_dict: Dict[str, torch.Tensor]
| ) -> Dict[str, torch.Tensor]
default_predictor#
class VqaVilbert(VisionTextModel):
| ...
| default_predictor = "vilbert_vqa"