def decode_mst(
    energy: numpy.ndarray,
    length: int,
    has_labels: bool = True
) -> Tuple[numpy.ndarray, numpy.ndarray]

Note: Counter to typical intuition, this function decodes the maximum spanning tree.

Decode the optimal spanning tree using the Chu-Liu-Edmonds algorithm for maximum spanning arborescences on directed graphs.
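The contract-and-expand idea behind Chu-Liu-Edmonds can be illustrated with the self-contained sketch below. It is written independently of this module: it uses plain dicts instead of the in-place NumPy bookkeeping that `chu_liu_edmonds` describes, and it assumes node 0 is the root and that `scores[h][d]` scores the edge from head `h` to dependent `d`. The names `find_cycle`, `max_arborescence` and `decode` are hypothetical.

```python
from typing import Dict, List, Optional, Tuple

Edge = Tuple[int, int]  # (head, dependent)


def find_cycle(parent: Dict[int, int]) -> Optional[List[int]]:
    """Return the nodes of one cycle in a {dependent: head} map, or None."""
    for start in parent:
        path: List[int] = []
        node = start
        # Follow head pointers until we reach the root (no entry) or revisit a node.
        while node in parent and node not in path:
            path.append(node)
            node = parent[node]
        if node in path:
            return path[path.index(node):]
    return None


def max_arborescence(nodes: List[int], score: Dict[Edge, float], root: int) -> Dict[int, int]:
    """Chu-Liu-Edmonds: maximum spanning arborescence as a {dependent: head} map."""
    # 1. Greedily keep the best incoming edge of every non-root node.
    best: Dict[int, int] = {}
    for dep in nodes:
        if dep != root:
            best[dep] = max((s, h) for (h, d), s in score.items() if d == dep)[1]
    cycle = find_cycle(best)
    if cycle is None:
        return best
    # 2. Contract the cycle into a fresh node `c`.
    c = max(nodes) + 1
    in_cycle = set(cycle)
    cycle_score = sum(score[(best[v], v)] for v in cycle)
    new_score: Dict[Edge, float] = {}
    origin: Dict[Edge, Edge] = {}  # contracted edge -> edge it stands for
    for (h, d), s in score.items():
        if h in in_cycle and d in in_cycle:
            continue
        if d in in_cycle:
            # Entering the cycle at d replaces the cycle edge into d.
            edge, value = (h, c), cycle_score + s - score[(best[d], d)]
        elif h in in_cycle:
            edge, value = (c, d), s
        else:
            edge, value = (h, d), s
        if edge not in new_score or value > new_score[edge]:
            new_score[edge] = value
            origin[edge] = (h, d)
    # 3. Solve the smaller problem recursively.
    sub = max_arborescence([v for v in nodes if v not in in_cycle] + [c], new_score, root)
    # 4. Expand: map edges back, then keep the cycle edges that were not broken.
    result: Dict[int, int] = {}
    for d, h in sub.items():
        orig_h, orig_d = origin[(h, d)]
        result[orig_d] = orig_h
    for v in cycle:
        result.setdefault(v, best[v])
    return result


def decode(scores: List[List[float]]) -> List[int]:
    """scores[h][d] is the score of edge h -> d; node 0 is the root."""
    n = len(scores)
    edges = {(h, d): scores[h][d] for h in range(n) for d in range(1, n) if h != d}
    parent = max_arborescence(list(range(n)), edges, root=0)
    return [-1] + [parent[d] for d in range(1, n)]
```

For example, with `scores = [[0, 5, 1, 1], [0, 0, 10, 2], [0, 10, 0, 8], [0, 1, 1, 0]]` the greedy step picks the cycle 1 ⇄ 2; contracting it and expanding again yields the heads `[-1, 0, 1, 2]` (total score 23), which beats entering the cycle at node 2 (total 19).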


Parameters
  • energy : numpy.ndarray
    A NumPy array of shape (num_labels, timesteps, timesteps) containing the energy (score) of each edge, indexed as energy[label, head, dependent]. If has_labels is False, the array should have shape (timesteps, timesteps) instead.
  • length : int
    The unpadded length of this sequence. The energy may come from a padded batch, so only its first length rows and columns are used.
  • has_labels : bool, optional (default = True)
    Whether energy has a leading label dimension. If it does, each edge is scored by its highest-scoring label.
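The effect of length and has_labels on the input can be sketched as follows. This is a hypothetical, NumPy-free helper (`prepare_energy`), not part of this module; it assumes the energy[label][head][dependent] layout described above, crops the padding, and collapses the label axis by keeping each edge's best score together with the index of the label that produced it.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def prepare_energy(
    energy, length: int, has_labels: bool = True
) -> Tuple[Matrix, Optional[List[List[int]]]]:
    """Crop padded energy to `length` and collapse the label axis (hypothetical helper)."""
    if not has_labels:
        # energy[h][d]: keep only the first `length` rows and columns.
        return [row[:length] for row in energy[:length]], None
    num_labels = len(energy)
    scores: Matrix = []
    labels: List[List[int]] = []
    for h in range(length):
        score_row, label_row = [], []
        for d in range(length):
            per_label = [energy[lab][h][d] for lab in range(num_labels)]
            # Index of the best label; ties go to the lowest index.
            best = max(range(num_labels), key=per_label.__getitem__)
            score_row.append(per_label[best])
            label_row.append(best)
        scores.append(score_row)
        labels.append(label_row)
    return scores, labels
```

With NumPy arrays the same reduction can be written with slicing plus `max(axis=0)` and `argmax(axis=0)`; the explicit loops are used here only to keep the sketch dependency-free.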


def chu_liu_edmonds(
    length: int,
    score_matrix: numpy.ndarray,
    current_nodes: List[bool],
    final_edges: Dict[int, int],
    old_input: numpy.ndarray,
    old_output: numpy.ndarray,
    representatives: List[Set[int]]
)

Applies the Chu-Liu-Edmonds algorithm recursively to a graph whose edge weights are given by score_matrix.

Note that this function operates in place: it modifies its arguments rather than returning new values.


Parameters
  • length : int
    The number of nodes.
  • score_matrix : numpy.ndarray
    The score matrix, where score_matrix[i, j] is the score of the edge from head node i to dependent node j.
  • current_nodes : List[bool]
    A flag for each node indicating whether it is currently a representative in the graph. At its most basic, a representative stands for a single node, but as the algorithm progresses, individual nodes come to represent collapsed cycles in the graph.
  • final_edges : Dict[int, int]
    An empty dictionary that is populated with the edges of the maximum spanning tree, mapping each child node to its parent.
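  • old_input : numpy.ndarray
    For each edge (i, j) of the current, possibly contracted, graph, the head node of the original edge it stands for. Before any contraction, old_input[i, j] = i.
  • old_output : numpy.ndarray
    For each edge (i, j) of the current graph, the dependent node of the original edge it stands for. Before any contraction, old_output[i, j] = j.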
  • representatives : List[Set[int]]
    For each node, the set of original nodes it represents at the current iteration.
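How current_nodes, representatives, old_input and old_output can stay consistent when a cycle is collapsed is sketched below with plain lists. The helpers `init_bookkeeping` and `contract_cycle` are hypothetical and not this module's code; the sketch assumes `score[h][d]` scores the edge h -> d, that `parents[v]` is the greedily chosen head of each node, and that the first node of the cycle becomes its representative.

```python
from typing import List, Set, Tuple


def init_bookkeeping(length: int) -> Tuple[List[bool], List[Set[int]], List[List[int]], List[List[int]]]:
    """Initial state: every node is active and represents only itself."""
    current_nodes = [True] * length
    representatives: List[Set[int]] = [{node} for node in range(length)]
    # Before any contraction, edge (i, j) stands for the original edge i -> j.
    old_input = [[i] * length for i in range(length)]
    old_output = [list(range(length)) for _ in range(length)]
    return current_nodes, representatives, old_input, old_output


def contract_cycle(cycle, score, parents, current_nodes, representatives, old_input, old_output):
    """Collapse `cycle` into its first node, updating every argument in place."""
    rep = cycle[0]
    cycle_weight = sum(score[parents[v]][v] for v in cycle)
    for node in range(len(current_nodes)):
        if node in cycle or not current_nodes[node]:
            continue
        # Best edge leaving the cycle towards `node`.
        src = max(cycle, key=lambda v: score[v][node])
        score[rep][node] = score[src][node]
        old_input[rep][node] = old_input[src][node]
        old_output[rep][node] = old_output[src][node]
        # Best edge entering the cycle from `node`, net of the cycle edge it breaks.
        dst = max(cycle, key=lambda v: score[node][v] - score[parents[v]][v])
        score[node][rep] = cycle_weight + score[node][dst] - score[parents[dst]][dst]
        old_input[node][rep] = old_input[node][dst]
        old_output[node][rep] = old_output[node][dst]
    # Every other cycle node is now represented by `rep`.
    for v in cycle[1:]:
        current_nodes[v] = False
        representatives[rep] |= representatives[v]
```

After contracting the cycle {1, 2} in a 4-node graph, node 2 is deactivated, representatives[1] becomes {1, 2}, and the edge (1, 3) remembers through old_input and old_output that it really stands for the original edge 2 -> 3.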


Returns
  • Nothing. All arguments are modified in place, and the resulting tree is written to final_edges.
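Because the result is left in final_edges, a caller has to turn that {child: parent} dictionary into per-position arrays itself. The hypothetical helper below (`edges_to_heads`) sketches one way to do this; it assumes the root is stored with parent -1, that padded positions may keep a placeholder head of 0, and that label ids are looked up as label_ids[parent][child].

```python
from typing import Dict, List, Optional, Tuple


def edges_to_heads(
    final_edges: Dict[int, int],
    max_length: int,
    label_ids: Optional[List[List[int]]] = None,
) -> Tuple[List[int], Optional[List[int]]]:
    """Turn a {child: parent} map into per-position head (and label) arrays."""
    heads = [0] * max_length  # padded positions keep a placeholder head of 0
    head_type = [0] * max_length if label_ids is not None else None
    for child, parent in final_edges.items():
        heads[child] = parent
        # The root (parent -1) has no incoming edge, so it gets no label.
        if head_type is not None and parent >= 0:
            head_type[child] = label_ids[parent][child]
    return heads, head_type
```

For instance, `edges_to_heads({0: -1, 1: 0, 2: 1, 3: 2}, 5)` yields the heads `[-1, 0, 1, 2, 0]`, where the last entry is padding.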