# hybrid_tree_utils.py
from collections import defaultdict, deque
from itertools import chain

from .constants import (
    EMPTY,
    TOKENS,
    LEMMAS,
    TAGS,
    SPINES,
    ANCHORS,
    ANCHOR_HS,
    HEADS,
    DEPRELS,
)

class TreeNode(object):
    """A node of a (hybrid) tree.

    Terminal nodes (no children) carry an ``(orth, lemma, tag)`` tuple as
    their ``category``; non-terminal nodes carry a category label string.
    ``from_index``/``to_index`` hold the node's token span (terminals are
    built with ``to_index == from_index + 1``).
    """

    def __init__(self, nid, category, is_head, from_index, to_index, head_index=None, deprel=None, attributes=None, children=None):
        self.nid = nid
        self.parent = None
        self.category = category
        # True if this node is the head child of its parent.
        self.is_head = is_head
        self.from_index = from_index
        self.to_index = to_index
        self.head_index = head_index
        self.deprel = deprel
        # Fresh containers per instance so nodes never share mutable state.
        self.attributes = attributes if attributes is not None else {}
        self.children = children if children is not None else []
        for child in self.children:
            child.parent = self

    def add_child(self, child):
        """Append *child* to this node's children and set its parent link."""
        self.children.append(child)
        child.parent = self

    def get_yield(self):
        """Return the terminal nodes under this node, in child order."""
        if not self.children:
            return [self]
        return list(chain.from_iterable(child.get_yield() for child in self.children))

    def is_continuous(self):
        """Return True if the yield's ``from_index`` values are non-decreasing.

        Only the order of the yield is checked: for a non-root node, a gap that
        is filled by tokens outside the subtree is not detected.
        """
        idx = [token.from_index for token in self.get_yield()]
        return idx == sorted(idx)

    def get_root(self):
        """Follow parent links up and return the topmost node."""
        root = self
        while root.parent is not None:
            root = root.parent
        return root

    def get_head_child(self):
        """Return the unique child marked ``is_head``, or None for a terminal.

        Raises:
            AssertionError: if a non-terminal does not have exactly one head child.
        """
        if self.children:
            heads = [child for child in self.children if child.is_head]
            assert len(heads) == 1
            return heads[0]
        return None

    def get_head_token(self):
        """Follow head children down from this node and return the head terminal."""
        node = self
        while node.children:
            node = node.get_head_child()
        return node

    def make_evalb_friendly(self, s):
        """Replace spaces and parentheses so *s* is safe inside bracketed output."""
        return s.replace(' ', '_').replace('(', 'LPAR').replace(')', 'RPAR')

    def to_brackets(self, features=(), mark_head=True, mark_head_terminals=False, morph_tags=False, dummy_pre=False):
        """Render the subtree as a bracketed string.

        Args:
            features: attribute names appended to non-terminal labels with '_'.
            mark_head: prefix head non-terminals with '*'.
            mark_head_terminals: not implemented; raises if True.
            morph_tags: wrap each terminal as ``(tag orth)``.
            dummy_pre: wrap a terminal that has siblings as ``(DUMMY_PRE ...)``.

        Raises:
            NotImplementedError: if *mark_head_terminals* is True.
        """
        if mark_head_terminals:
            raise NotImplementedError
        if not self.children:
            cat = self.make_evalb_friendly(self.category[0])
            if morph_tags:
                cat = f'({self.category[2]} {cat})'
            # A parentless terminal (single-node tree) or an only child needs
            # no dummy preterminal; check the flag first so a missing parent
            # is never dereferenced.
            if not dummy_pre or self.parent is None or len(self.parent.children) == 1:
                return cat
            return f'(DUMMY_PRE {cat})'
        # NOTE(review): this previously used EMPTY_VAL, which is not imported
        # anywhere in this module (NameError). EMPTY is the only such constant
        # imported here; confirm it is the intended placeholder for a missing
        # attribute.
        cat = '_'.join([self.category] + [self.attributes.get(f, EMPTY) for f in features])
        if mark_head and self.is_head:
            cat = '*' + cat
        cat = self.make_evalb_friendly(cat)
        return f'({cat} {" ".join(child.to_brackets(features=features, mark_head=mark_head, mark_head_terminals=mark_head_terminals, morph_tags=morph_tags, dummy_pre=dummy_pre) for child in self.children)})'

    def pretty_print(self, tab='', features=()):
        """Print the subtree to stdout, one node per line, indented by depth."""
        # NOTE(review): EMPTY substitutes for the never-imported EMPTY_VAL (see to_brackets).
        print(f'{tab}[{self.nid}] {"*" if self.is_head else ""}{self.category}{[self.attributes.get(f, EMPTY) for f in features]}')
        for child in self.children:
            child.pretty_print(tab=tab + '    ', features=features)

def _do_mark_heads(tree, dependency_heads):
    """Recursively set ``is_head`` on the head child of every non-terminal.

    A terminal's head token is its own ``from_index``. For a non-terminal,
    the head child is the child whose head token's dependency head
    (``dependency_heads[token]``) is not among the children's head tokens.

    Returns:
        The index of *tree*'s head token.

    Raises:
        AssertionError: unless exactly one child qualifies as head.
    """
    if not tree.children:
        return tree.from_index
    sub_heads = [_do_mark_heads(sub, dependency_heads) for sub in tree.children]
    winners = []
    for sub, token in zip(tree.children, sub_heads):
        if dependency_heads[token] in sub_heads:
            # Governed from inside this constituent: not the head child.
            continue
        sub.is_head = True
        winners.append(token)
    assert len(winners) == 1
    return winners[0]

def _mark_heads(tree, dependency_heads):
    """Mark head children throughout *tree* using *dependency_heads*.

    Thin wrapper around _do_mark_heads that discards the root's head-token
    index; *tree* is modified in place and nothing is returned.
    """
    _do_mark_heads(tree, dependency_heads)
    return None

def _rearrange(tree):
    """Bottom-up, sort children by ``from_index`` and recompute spans.

    Each non-terminal gets the ``from_index`` of its leftmost child and the
    largest ``to_index`` among its children. Terminals are left unchanged.

    Returns:
        *tree* itself (modified in place).
    """
    if not tree.children:
        return tree
    for child in tree.children:
        _rearrange(child)
    ordered = sorted(tree.children, key=lambda node: node.from_index)
    tree.children = ordered
    tree.from_index = ordered[0].from_index
    tree.to_index = max(node.to_index for node in ordered)
    return tree

def tree_from_dataset_instance(instance, dataset_features):
    """Build a TreeNode tree from a dataset instance.

    ``instance['nonterminals']`` is a list of entries; entry 0 is taken as the
    root. An entry whose ``'cat'`` is None wraps exactly one token (its single
    ``'children'`` item is a token index) and becomes a terminal node with
    category ``(orth, lemma, tag)``. Any other entry becomes a non-terminal
    whose ``'children'`` are indices into ``nonterminals``.

    Head children are marked from ``instance['heads']`` and the tree is then
    passed through _rearrange (children sorted, spans recomputed).

    Returns:
        The root TreeNode.
    """
    nonterminals = instance['nonterminals']
    nodes = []

    # Create one node per entry. Terminals get their token span right away;
    # non-terminal spans are placeholders until _rearrange fills them in.
    for i, nonterminal in enumerate(nonterminals):
        category = nonterminal['cat']
        if category is None:
            assert len(nonterminal['children']) == 1
            token_idx = nonterminal['children'][0]
            orth = instance[TOKENS][token_idx]
            lemma = instance[LEMMAS][token_idx]
            # Tags are stored as integer ids; convert back to the tag string.
            tag = dataset_features[TAGS].feature.int2str(instance[TAGS][token_idx])
            category = (orth, lemma, tag)
            nodes.append(TreeNode(i, category, False, token_idx, token_idx + 1))
        else:
            nodes.append(TreeNode(i, category, False, 0, 0))

    # Link each non-terminal to its children.
    for nonterminal, parent in zip(nonterminals, nodes):
        if nonterminal['cat'] is not None:
            for child_idx in nonterminal['children']:
                parent.add_child(nodes[child_idx])

    tree = nodes[0]
    # Mark the heads according to the dependency relations.
    _mark_heads(tree, instance['heads'])

    return _rearrange(tree)

def get_heads(matrix):
    """Decode a list of head indices from a score *matrix*.

    The matrix is passed through ``normalize`` and ``add_root``, then ``mst``
    yields ``(head, dependent)`` pairs. Pairs with dependent 0 are dropped,
    and both indices are shifted down by one, so a head of 0 becomes -1.
    Index 0 is presumably the artificial root added by ``add_root``; confirm.

    Returns:
        The heads ordered by dependent index.

    NOTE(review): ``normalize``, ``add_root`` and ``mst`` are neither defined
    nor imported in the visible source, so calling this raises NameError
    unless they are provided elsewhere.
    """
    scores = add_root(normalize(matrix))
    heads = {}
    for hd, dep in mst(scores):
        if dep > 0:
            heads[dep - 1] = hd - 1
    return [heads[dep] for dep in sorted(heads)]

def make_head_path(path, token, lemma, tag, index):
    """Build the spine of the token at position *index*.

    The bottom is a terminal node with category ``(token, lemma, tag)``
    spanning ``[index, index + 1)``. If *path* is not EMPTY, it is split on
    '_' and each label becomes a non-terminal, with the first label on top.
    Every node below the top is marked as its parent's head; the top node is
    not.

    Returns:
        The top node of the spine (the terminal itself when *path* is EMPTY).
    """
    top = TreeNode(0, (token, lemma, tag), False, index, index + 1)
    if path != EMPTY:
        for label in path.split('_')[::-1]:
            top.is_head = True
            top = TreeNode(0, label, False, None, None, children=[top])
    return top

def append_dependent(head_path, dep_path, anchor_cat, anchor_h, decompress=False):
    """Attach the dependent spine *dep_path* to a node on *head_path*'s head chain.

    The head chain is *head_path*, then its head child, that node's head
    child, and so on down to the terminal. The dependent is attached under a
    chain node whose category equals *anchor_cat*. *anchor_h* picks which
    matching node to use, counted from the bottom (1 = lowest match). When
    *anchor_h* is None, the highest match is used.

    If *anchor_h* exceeds the number of matches, it is clamped to the highest
    match, unless *decompress* is True. In that case, extra *anchor_cat* nodes
    are inserted above the highest match so that the requested level exists.

    If no chain node matches *anchor_cat*, the dependent is attached to the
    top of *head_path* and the mismatch is reported.

    Returns:
        ``(head_path, problem)``: the possibly new top node of the head spine,
        and either None or a ``(dep_path, anchor_cat)`` tuple describing a
        missing anchor.

    Raises:
        AssertionError: if a chain node does not have exactly one head child.
            The spine is pretty-printed first to aid debugging.
    """
    problem = None
    head_chain = [head_path]
    while head_chain[-1].children:
        heads = [child for child in head_chain[-1].children if child.is_head]
        if len(heads) != 1:
            # Dump the offending spine before failing.
            head_path.pretty_print()
            head_chain[-1].pretty_print()
        assert len(heads) == 1
        head_chain.append(heads[0])
    # head_chain runs top -> bottom, so matching_heads[0] is the highest match.
    matching_heads = [hd for hd in head_chain if hd.category == anchor_cat]
    if not matching_heads:
        anchor_node = head_path
        problem = (dep_path, anchor_cat)
    else:
        n_matches = len(matching_heads)
        if anchor_h is None or (anchor_h > n_matches and not decompress):
            # Fallback: take the highest match.
            anchor_h = n_matches
        elif anchor_h > n_matches:
            # decompress: stack extra anchor_cat nodes above the highest match.
            for _ in range(anchor_h - n_matches):
                child = matching_heads[0]
                parent = child.parent  # captured before the new node re-parents child
                new_node = TreeNode(0, anchor_cat, child.is_head, None, None, children=[child])
                child.is_head = True
                if parent is not None:
                    parent.children.remove(child)
                    parent.add_child(new_node)
                if head_path is child:
                    head_path = new_node
                matching_heads.insert(0, new_node)
        anchor_node = matching_heads[-anchor_h]
    anchor_node.add_child(dep_path)
    dep_path.is_head = False
    return head_path, problem

# Maps a coarse POS (the part of a morphosyntactic tag before the first ':')
# to the category label used for the dummy preterminal built by add_dummy_pre.
# NOTE(review): every entry below is commented out, so this dict is empty and
# add_dummy_pre raises KeyError for any POS. Confirm whether this is intended.
POS2PRE = {
    #'aglt' : 'aglt',
    #'conj' : 'spójnik',
    #'dig' : 'formalicz',
    #'fin' : 'formaczas',
    #'interp' : 'punct',
    #'num' : 'formalicz',
    #'praet' : 'formaczas',
    #'_' : '???',
}

def add_dummy_pre(path):
    """Wrap a terminal-only spine *path* in a dummy preterminal node.

    The preterminal's label is looked up in POS2PRE using the coarse POS,
    i.e. the terminal's tag (``category[2]``) up to the first ':'. The wrapped
    terminal is marked as the head child, so that head-chain traversal (as in
    append_dependent) finds exactly one head under the new node.

    Returns:
        The new preterminal node, spanning the same tokens as *path*.

    Raises:
        KeyError: if the coarse POS has no entry in POS2PRE (*path* is left
            unmodified in that case).
    """
    pos = path.category[2].split(':')[0]
    label = f'{POS2PRE[pos]}'
    # The terminal is the preterminal's only child, so it must be its head.
    path.is_head = True
    # Pass the span and children explicitly. The positional form
    # TreeNode(0, label, True, {}, [path]) would bind {} to from_index and
    # [path] to to_index, leaving the node without children.
    return TreeNode(0, label, True, path.from_index, path.to_index, children=[path])

def check_no_cycles(heads, roots=(-1, None)):
    """Return True if every token in *heads* is reachable from a root marker.

    ``heads[i]`` is the head index of token *i*. Root tokens point to one of
    the *roots* markers. By default both ``-1`` and ``None`` are accepted:
    get_heads maps the artificial root to -1, and reconstruct_tree uses None.
    The check walks breadth-first from the root markers. A token that is never
    reached sits on, or hangs off, a cycle (or has an out-of-range head), so
    the result is False.
    """
    children_dict = defaultdict(set)
    for i, hd in enumerate(heads):
        children_dict[hd].add(i)
    visited = set(roots)
    queue = deque(roots)
    while queue:
        children = children_dict.pop(queue.popleft(), set())
        if children & visited:
            return False
        visited.update(children)
        queue.extend(children)
    # Every reached head has been popped from children_dict. Any key left
    # over is a head that is unreachable from the root markers.
    return not children_dict

# find any <EMPTY>’s children and reattach them to <EMPTY>’s head
def try_reattach(heads, spines, tokens):
    """Reattach the dependents of EMPTY-spine tokens to those tokens' heads.

    Repeats until no token with an EMPTY spine is the head of any token.
    Progress is printed to stdout.

    Args:
        heads: head index of each token (presumably None for the root, as in
            reconstruct_tree).
        spines: spine string of each token.
        tokens: token strings, used only in the diagnostic output.

    Returns:
        The resulting heads. *heads* itself is not modified; it is returned
        unchanged when nothing needs reattaching.

    Raises:
        AssertionError: if an EMPTY-spine token to be bypassed has no head.
        RuntimeError: if a reattachment would create a cycle, or if the loop
            does not converge within 50 rounds.
    """
    def pending(current):
        # Indices of EMPTY-spine tokens that still head some token.
        head_set = set(current)
        return [i for i, spine in enumerate(spines) if spine == EMPTY and i in head_set]

    to_reattach = pending(heads)
    rounds = 0
    while to_reattach:
        rounds += 1
        if rounds == 50:
            raise RuntimeError('try_reattach did not converge after 50 rounds')
        print('------- TO REATTACH:')
        for i in to_reattach:
            print('    ===>', i, tokens[i], '->', [t for t, h in zip(tokens, heads) if h == i])
        tr = to_reattach[0]
        hd_idx = heads[tr]
        assert hd_idx is not None
        # Bypass tr: each of its dependents now hangs off tr's own head.
        new_heads = [hd_idx if hd == tr else hd for hd in heads]
        if not check_no_cycles(new_heads):
            print('can’t reattach (cycle!)')
            raise RuntimeError('can’t reattach (cycle!)')
        heads = new_heads
        to_reattach = pending(heads)

    return heads

def reconstruct_tree(tokens, tags, decompress=False, root_label='ROOT'):
    """Rebuild a tree from per-token head, spine and anchor predictions.

    Args:
        tokens: token strings.
        tags: mapping with per-token sequences under HEADS, DEPRELS, SPINES,
            ANCHORS and ANCHOR_HS. TAGS and LEMMAS are optional and default
            to '_' for every token. A head of None marks the root. An anchor
            height of '<ROOT>' is read as None.
        decompress: forwarded to append_dependent. When True, missing anchor
            levels are created instead of clamped.
        root_label: label added to the top of the root token's spine (and
            stripped from any non-root spine).

    Returns:
        ``(tree, problems)``: the root node after _rearrange, and a list that
        holds ``(dep_path, anchor_cat)`` tuples for unmatched anchors, plus the
        string 'reattach' if try_reattach changed any head.

    Raises:
        ValueError: if no token has a head of None.
    """
    morph_tags = tags.get(TAGS, ['_' for _ in tokens])
    lemmas = tags.get(LEMMAS, ['_' for _ in tokens])
    heads = tags[HEADS]
    deprels = tags[DEPRELS]
    spines = list(tags[SPINES])
    anchors = tags[ANCHORS]
    anchor_hs = tags[ANCHOR_HS]
    problems = []

    # Add root_label to the root's spine and strip it from non-root spines.
    for i, (head, spine) in enumerate(zip(heads, spines)):
        if head is None and root_label not in spine:
            spines[i] = (f'{root_label}_' + spines[i]).replace(f'_{EMPTY}', '')
        if head is not None and root_label in spine:
            assert spines[i].startswith(root_label)
            spines[i] = EMPTY if spines[i] == root_label else spines[i].replace(f'{root_label}_', '')

    reattach = False
    new_heads = try_reattach(heads, spines, tokens)
    if new_heads != heads:
        heads = new_heads
        reattach = True

    head_paths = [
        make_head_path(p, tok, lemma, tag, i) for i, (p, tok, lemma, tag)
        in enumerate(zip(spines, tokens, lemmas, morph_tags))
    ]

    # A terminal-only spine that heads other tokens gets a dummy preterminal,
    # so that its dependents have a non-terminal to attach to.
    head_set = set(heads)
    head_paths = [
        add_dummy_pre(hp) if isinstance(hp.category, tuple) and i in head_set else hp
        for i, hp in enumerate(head_paths)
    ]

    anchor_hs = [int(h) if h != '<ROOT>' else None for h in anchor_hs]
    root = None
    for i, head in enumerate(heads):
        if deprels[i] is not None:
            head_paths[i].attributes['deprel'] = deprels[i]
        if head is None:
            root = i
        else:
            # append_dependent may put a new node on top of the head's spine.
            head_paths[head], problem = append_dependent(
                head_paths[head], head_paths[i], anchors[i], anchor_hs[i], decompress=decompress)
            if problem:
                problems.append(problem)
    if reattach:
        problems.append('reattach')

    if root is None:
        raise ValueError('reconstruct_tree: no token has a head of None (no root)')
    tree = _rearrange(head_paths[root])

    return tree, problems

def make_tree(tokens, tags, root_label, decompress=True):
    """Reconstruct a tree via reconstruct_tree and return only the tree.

    Unlike reconstruct_tree, *decompress* defaults to True here, and the list
    of problems is discarded.
    """
    tree, _ = reconstruct_tree(tokens, tags, decompress=decompress, root_label=root_label)
    return tree

def _node2dict(tree):
    """Convert *tree* recursively into plain dicts.

    Every node gets ``is_head`` and ``span`` (``{'from': ..., 'to': ...}``).
    It also gets ``deprel`` when set and ``attributes`` when non-empty.
    Non-terminals add ``category`` and ``children``. Terminals add ``orth``,
    ``base`` and ``tag``, taken from their ``(orth, lemma, tag)`` category.
    """
    node = {
        'is_head' : tree.is_head,
        'span' : {'from' : tree.from_index, 'to' : tree.to_index},
    }
    if tree.deprel is not None:
        node['deprel'] = tree.deprel
    if tree.attributes:
        # Copy so that mutating the exported dict cannot mutate the tree.
        node['attributes'] = dict(tree.attributes)
    if tree.children:
        node['category'] = tree.category
        node['children'] = [_node2dict(child) for child in tree.children]
    else:
        node.update(zip(('orth', 'base', 'tag'), tree.category))
    return node

def tree2dict(tree, metadata=None):
    """Return ``{'tree': <dict form of tree>, 'metadata': metadata}``.

    When *metadata* is omitted (or None), a fresh empty dict is used for each
    call. A ``{}`` default would be one dict shared by every call and exposed
    in every result.
    """
    if metadata is None:
        metadata = {}
    return {'tree' : _node2dict(tree), 'metadata' : metadata}