A very general abstract class representing a metric which can be accumulated.
class Metric(Registrable): | ... | supports_distributed = False
class Metric(Registrable): | ... | def __call__( | self, | predictions: torch.Tensor, | gold_labels: torch.Tensor, | mask: Optional[torch.BoolTensor] | )
- predictions :
A tensor of predictions.
- gold_labels :
A tensor of gold labels to evaluate the predictions against.
- mask :
torch.BoolTensor, optional
A mask can be passed to handle metrics computed over potentially padded elements, such as sequence labels.
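A minimal sketch of how an accumulating metric might use the mask in `__call__`, assuming `predictions` already holds label indices with the same shape as `gold_labels`. `MaskedAccuracy` is a hypothetical standalone class, not a library subclass of `Metric`. Masked-out (padded) positions are excluded from both the numerator and the denominator.

```python
from typing import Optional

import torch


class MaskedAccuracy:
    """Hypothetical accumulating accuracy metric (illustration only)."""

    def __init__(self) -> None:
        self.correct = 0.0
        self.total = 0.0

    def __call__(
        self,
        predictions: torch.Tensor,
        gold_labels: torch.Tensor,
        mask: Optional[torch.BoolTensor] = None,
    ) -> None:
        # Detach so the accumulated values never keep an autograd graph alive.
        predictions, gold_labels = predictions.detach(), gold_labels.detach()
        if mask is None:
            # No mask given: every element counts.
            mask = torch.ones_like(gold_labels, dtype=torch.bool)
        # Only positions where the mask is True can count as correct.
        matches = (predictions == gold_labels) & mask
        self.correct += matches.sum().item()
        # Padded positions (mask == False) are also left out of the total.
        self.total += mask.sum().item()
```

With a batch of two padded sequences, only the three unmasked positions are scored; of those, two predictions match.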
class Metric(Registrable): | ... | def get_metric(self, reset: bool)
Compute and return the metric. When reset is True, the accumulated state is also reset after the value is computed.
class Metric(Registrable): | ... | def reset(self) -> None
Reset any accumulators or internal state.
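The interplay between `get_metric(reset)` and `reset()` can be sketched with a hypothetical running-mean metric (standalone, not a library class). The value is computed from the accumulators first and only then, if `reset` is true, the accumulators are cleared. This is the usual pattern for reporting a metric at the end of an epoch and starting the next one fresh.

```python
import torch


class RunningMean:
    """Hypothetical metric that averages every value it has been given."""

    def __init__(self) -> None:
        self.reset()

    def __call__(self, values: torch.Tensor) -> None:
        values = values.detach()
        self._sum += values.sum().item()
        self._count += values.numel()

    def get_metric(self, reset: bool) -> float:
        # Compute the metric from the current accumulators first ...
        value = self._sum / self._count if self._count > 0 else 0.0
        # ... then optionally clear them so the next period starts fresh.
        if reset:
            self.reset()
        return value

    def reset(self) -> None:
        # Clear all accumulated state.
        self._sum = 0.0
        self._count = 0
```

Returning `0.0` when nothing has been accumulated is a choice made for this sketch; a real metric might raise an error instead.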
class Metric(Registrable): | ... | @staticmethod | def detach_tensors(*tensors: torch.Tensor) -> Iterable[torch.Tensor]
Passing gradient-tracking tensors to a Metric can cause a large memory leak: any reference the metric keeps holds the whole computation graph alive and prevents it from being garbage collected. This method returns the tensors detached from the graph.
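A sketch of the detaching behaviour as a standalone function, assuming it simply calls `torch.Tensor.detach()` on each argument. Passing non-tensor values (such as a `None` mask) through unchanged is an assumption of this sketch. `detach()` returns a tensor that shares storage with the original, so no data is copied.

```python
from typing import Iterable

import torch


def detach_tensors(*tensors: torch.Tensor) -> Iterable[torch.Tensor]:
    # Lazily detach every tensor from the autograd graph; non-tensor values
    # (for example a mask passed as None) are yielded unchanged.
    return (x.detach() if isinstance(x, torch.Tensor) else x for x in tensors)
```

Typical use at the start of a metric's `__call__`:

```python
logits = torch.randn(2, 3, requires_grad=True)
predictions = logits * 2  # still attached to the graph
predictions, mask = detach_tensors(predictions, None)
```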