Skip to content

vilbert_image_retrieval

allennlp_models.vision.models.vilbert_image_retrieval

[SOURCE]


ImageRetrievalVilbert#

@Model.register("vilbert_ir")
@Model.register("vilbert_ir_from_huggingface", constructor="from_huggingface_model_name")
class ImageRetrievalVilbert(VisionTextModel):
 | def __init__(
 |     self,
 |     vocab: Vocabulary,
 |     text_embeddings: TransformerEmbeddings,
 |     image_embeddings: ImageFeatureEmbeddings,
 |     encoder: BiModalEncoder,
 |     pooled_output_dim: int,
 |     fusion_method: str = "mul",
 |     dropout: float = 0.1,
 |     k: int = 1,
 |     *,
 |     ignore_text: bool = False,
 |     ignore_image: bool = False
 | ) -> None

Model for the image retrieval task, based on the ViLBERT paper.

Parameters

  • vocab : Vocabulary
  • text_embeddings : TransformerEmbeddings
  • image_embeddings : ImageFeatureEmbeddings
  • encoder : BiModalEncoder
  • pooled_output_dim : int
  • fusion_method : str, optional (default = "mul")
  • dropout : float, optional (default = 0.1)
  • label_namespace : str, optional (default = "answers") (NOTE: not accepted by the constructor signature shown above; this entry may be stale)
  • k : int, optional (default = 1)
  • ignore_text : bool, optional (default = False)
  • ignore_image : bool, optional (default = False)

forward#

class ImageRetrievalVilbert(VisionTextModel):
 | ...
 | def forward(
 |     self,
 |     box_features: torch.Tensor,
 |     box_coordinates: torch.Tensor,
 |     box_mask: torch.Tensor,
 |     caption: TextFieldTensors,
 |     label: torch.Tensor
 | ) -> Dict[str, torch.Tensor]

get_metrics#

class ImageRetrievalVilbert(VisionTextModel):
 | ...
 | def get_metrics(self, reset: bool = False) -> Dict[str, float]

make_output_human_readable#

class ImageRetrievalVilbert(VisionTextModel):
 | ...
 | def make_output_human_readable(
 |     self,
 |     output_dict: Dict[str, torch.Tensor]
 | ) -> Dict[str, torch.Tensor]

default_predictor#

class ImageRetrievalVilbert(VisionTextModel):
 | ...
 | default_predictor = "vilbert_ir"