class Metric(Registrable)

A very general abstract class representing a metric which can be accumulated.
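The interface boils down to three methods: accumulate statistics with `__call__`, read the result with `get_metric`, and clear state with `reset`. Below is a minimal, hypothetical sketch of that pattern; it uses plain Python lists instead of `torch.Tensor` and skips the `Registrable` machinery, so `BinaryAccuracy` is an illustration of the contract, not AllenNLP code.

```python
from typing import List, Optional


class BinaryAccuracy:
    """Hypothetical illustration of the Metric interface, not AllenNLP code."""

    def __init__(self) -> None:
        self._correct = 0
        self._total = 0

    def __call__(
        self,
        predictions: List[int],
        gold_labels: List[int],
        mask: Optional[List[bool]] = None,
    ) -> None:
        # Accumulate statistics for one batch; skip masked-out positions.
        if mask is None:
            mask = [True] * len(predictions)
        for pred, gold, keep in zip(predictions, gold_labels, mask):
            if keep:
                self._correct += int(pred == gold)
                self._total += 1

    def get_metric(self, reset: bool = False) -> float:
        # Compute the metric from the accumulators; optionally clear them.
        accuracy = self._correct / self._total if self._total > 0 else 0.0
        if reset:
            self.reset()
        return accuracy

    def reset(self) -> None:
        self._correct = 0
        self._total = 0
```

A typical loop calls the metric once per batch and reads it out with `get_metric(reset=True)` at the end of an epoch.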


class Metric(Registrable):
 | ...
 | supports_distributed = False


class Metric(Registrable):
 | ...
 | def __call__(
 |     self,
 |     predictions: torch.Tensor,
 |     gold_labels: torch.Tensor,
 |     mask: Optional[torch.BoolTensor] = None
 | )


  • predictions : torch.Tensor
    A tensor of predictions.
  • gold_labels : torch.Tensor
    A tensor corresponding to some gold label to evaluate against.
  • mask : torch.BoolTensor, optional (default = None)
    An optional mask, used to handle metrics computed over potentially padded elements, such as sequence labels.
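To make the role of `mask` concrete, here is a hedged sketch of how padded positions are typically excluded when scoring sequence labels. The function and the flatten-and-filter logic are illustrative, not AllenNLP's implementation, and plain nested lists stand in for tensors.

```python
def masked_token_accuracy(predictions, gold_labels, mask):
    # All three arguments are lists of equal-length (padded) rows.
    correct = total = 0
    for pred_row, gold_row, mask_row in zip(predictions, gold_labels, mask):
        for pred, gold, keep in zip(pred_row, gold_row, mask_row):
            if keep:  # only count positions the mask marks as real tokens
                correct += int(pred == gold)
                total += 1
    return correct / total if total else 0.0


# Two sequences padded to length 4; the second has only 2 real tokens.
preds = [[1, 2, 3, 0], [4, 5, 0, 0]]
golds = [[1, 2, 9, 9], [4, 9, 9, 9]]
mask = [[True, True, True, False], [True, True, False, False]]
print(masked_token_accuracy(preds, golds, mask))  # 3 correct of 5 real: 0.6
```

Without the mask, the padding positions (where `golds` holds a filler value) would be scored as errors and drag the accuracy down.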


class Metric(Registrable):
 | ...
 | def get_metric(self, reset: bool)

Compute and return the metric. Optionally also call self.reset.
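The `reset` flag lets callers either peek at the running value mid-epoch or read it and clear the accumulators at an epoch boundary. A small sketch of that contract, using a hypothetical `RunningTotal` metric rather than anything from AllenNLP:

```python
class RunningTotal:
    """Hypothetical metric that sums the values it is fed."""

    def __init__(self) -> None:
        self._total = 0.0

    def __call__(self, value: float) -> None:
        self._total += value

    def get_metric(self, reset: bool = False) -> float:
        # Read the accumulated value first, then optionally clear state.
        total = self._total
        if reset:
            self.reset()
        return total

    def reset(self) -> None:
        self._total = 0.0


metric = RunningTotal()
for batch_loss in [0.5, 0.25, 0.25]:
    metric(batch_loss)
print(metric.get_metric(reset=False))  # mid-epoch peek: 1.0
print(metric.get_metric(reset=True))   # epoch end: 1.0, then state clears
print(metric.get_metric())             # fresh epoch: 0.0
```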


class Metric(Registrable):
 | ...
 | def reset(self) -> None

Reset any accumulators or internal state.


class Metric(Registrable):
 | ...
 | @staticmethod
 | def detach_tensors(*tensors: torch.Tensor) -> Iterable[torch.Tensor]

If gradient-tracking tensors are passed to a Metric, they keep the entire computation graph alive and cause a huge memory leak, because the graph cannot be garbage-collected. This method ensures the tensors are detached.
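A sketch of the detach-before-accumulating pattern this helper enforces. `FakeTensor` is a stand-in so the example runs without torch; with real `torch.Tensor` inputs the same `.detach()` call is what severs the tensor from the computation graph.

```python
from typing import Any, Iterable


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, for illustration only."""

    def __init__(self, value: float, requires_grad: bool = False) -> None:
        self.value = value
        self.requires_grad = requires_grad

    def detach(self) -> "FakeTensor":
        # Return a copy that no longer tracks gradients, like torch's detach.
        return FakeTensor(self.value, requires_grad=False)


def detach_tensors(*tensors: Any) -> Iterable[Any]:
    # Detach anything tensor-like; pass other values through unchanged.
    return (t.detach() if hasattr(t, "detach") else t for t in tensors)


pred = FakeTensor(0.9, requires_grad=True)
gold = FakeTensor(1.0)
pred, gold = detach_tensors(pred, gold)
print(pred.requires_grad)  # False: safe to store without leaking the graph
```

A Metric subclass would typically call this at the top of `__call__`, so every tensor it stores in its accumulators is already graph-free.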