# hybrid_tree_utils.py 10.4 KB
from collections import defaultdict

from . import tree_utils

from .constants import (
    EMPTY,
    UPPERCASE,
    LEMMAS,
    LEMMA_CASES,
    LEMMA_RULES,
    TAGS,
    HEADS,
    DEPRELS,
    SPINES,
    ANCHORS,
    ANCHOR_HS,
)

def make_lemma(token, case, rule):
    """Derive a lemma from ``token`` by applying an edit ``rule``.

    ``rule`` has the form ``"<cut_prefix>_<cut>_<suffix>"``: the number of
    leading characters to drop, the number of trailing characters to drop,
    and the text to append afterwards.  Only the first two underscores are
    separators, so the suffix itself may contain underscores.

    The token is lower-cased before the rule is applied.  When ``case``
    equals ``UPPERCASE`` the first character of the result is upper-cased.
    If the rule yields an empty string, the original ``token`` is returned
    unchanged as a failsafe.
    """
    prefix_field, cut_field, ending = rule.split('_', maxsplit=2)
    drop_front = int(prefix_field)
    drop_back = int(cut_field)

    stem = token.lower()
    if drop_front:
        stem = stem[drop_front:]
    # A zero count must be skipped: stem[:-0] would yield an empty string.
    if drop_back:
        stem = stem[:-drop_back]
    result = stem + ending

    # failsafe: never produce an empty lemma
    if not result:
        return token
    if case == UPPERCASE:
        return result[0].upper() + result[1:]
    return result

def correct_lemma(token, lemma, tag, morfeusz):
    """Check ``lemma`` against the analyses returned by ``morfeusz``.

    ``morfeusz.analyse(token)`` is expected to yield items whose element
    ``[2]`` is a sequence starting with ``(orth, base, tag, ...)``.  Only
    analyses whose orth equals ``token`` and whose tag equals ``tag`` are
    considered.  The part of the base form after the first ``':'`` is
    dropped, unless the orth itself contains a ``':'``.

    If there are matching analyses and none of their lemmas equals
    ``lemma`` (case-insensitively), the alphabetically first matching lemma
    is returned, with a message printed when there was more than one
    candidate.  Otherwise ``lemma`` is returned unchanged.
    """
    candidates = set()
    for analysis in morfeusz.analyse(token):
        reading = analysis[2]
        if reading[0] != token or reading[2] != tag:
            continue
        base = reading[1]
        if ':' in reading[0]:
            candidates.add(base)
        else:
            candidates.add(base.split(':')[0])

    if not candidates:
        return lemma

    ordered = sorted(candidates)
    lowered = [candidate.lower() for candidate in ordered]
    if lemma.lower() in lowered:
        return lemma

    if len(ordered) != 1:
        print(token, lemma, tag, '->', ordered)
        print('    ---> >1 matching lemma, will take alphabetically first!')
    return ordered[0]

def normalize(chart):
    """Return a row-normalised copy of the square matrix ``chart``.

    Each of the first ``len(chart)`` cells of a row is divided by the sum of
    those cells.  The divisor is floored at ``1e-12`` so that an all-zero
    row does not cause a division by zero.
    """
    size = len(chart)
    result = []
    for i in range(size):
        cells = [chart[i][j] for j in range(size)]
        divisor = max(1e-12, sum(cells))
        result.append([cell / divisor for cell in cells])
    return result

def add_root(chart):
    """Embed an N x N ``chart`` into an (N+1) x (N+1) chart with a root.

    Index 0 of the result is the artificial root; original index ``k``
    becomes ``k + 1``.  The diagonal value ``chart[k][k]`` is moved to
    column 0 of row ``k + 1``; every other cell keeps its value at the
    shifted position.  Row 0 and the new diagonal stay ``0.0``.
    """
    size = len(chart)
    rooted = [[0.0] * (size + 1) for _ in range(size + 1)]
    for row in range(size):
        for col in range(size):
            target_col = 0 if row == col else col + 1
            rooted[row + 1][target_col] = chart[row][col]
    return rooted

def mst(chart):
    """Greedily select a rooted tree of (head, dependent) edges from ``chart``.

    ``chart[dep][head]`` is the weight of attaching ``dep`` under ``head``;
    node 0 is treated as the root.  Edges are taken in order of decreasing
    weight.  After an edge is accepted, every edge that would give its
    dependent a second head or close a cycle is discarded, and once one
    edge leaving node 0 has been accepted all other edges leaving node 0 are
    discarded as well (single root).  Selection stops at the first
    zero-weight edge.

    Returns:
        A set of ``(head, dependent)`` pairs of size ``len(chart) - 1``.

    Raises:
        AssertionError: if fewer than ``len(chart) - 1`` edges could be
            selected.
    """
    N = len(chart)
    # Edge tuples are (weight, head, dependent); (head, dependent) is unique
    # per edge, so a descending sort has no ties to worry about.
    edges = sorted(
        (
            (weight, head, dep)
            for dep, row in enumerate(chart)
            for head, weight in enumerate(row)
        ),
        reverse=True,
    )
    # (a, b) in paths means b is reachable from a through accepted edges;
    # every node trivially reaches itself.
    paths = {(k, k) for k in range(N)}
    tree = set()
    while edges and len(tree) < N - 1:
        weight, head, dep = edges[0]
        if weight == 0:
            break
        tree.add((head, dep))
        # there can be only one root!
        if head == 0:
            edges = [e for e in edges if e[1] != 0]
        # Everything that reaches `head` now also reaches everything
        # reachable from `dep`.  Two linear scans replace the former
        # all-pairs loop over `paths`.
        ancestors = [src for src, dst in paths if dst == head]
        descendants = [dst for src, dst in paths if src == dep]
        paths.add((head, dep))
        paths.update((a, d) for a in ancestors for d in descendants)
        # Drop edges into `dep` (it has a head now) and edges that would
        # close a cycle (the dependent already reaches the head).
        edges = [e for e in edges if e[2] != dep and (e[2], e[1]) not in paths]
    assert len(tree) == N - 1
    return tree

def get_heads(matrix):
    """Decode a list of head indices from a square score ``matrix``.

    The matrix is row-normalised, extended with an artificial root via
    ``add_root`` and decoded with ``mst``.  Indices are shifted back by one,
    so a head of ``-1`` denotes the root.  The result lists the head of each
    dependent in increasing dependent order.
    """
    rooted_chart = add_root(normalize(matrix))
    head_of = {}
    for head, dependent in mst(rooted_chart):
        if dependent > 0:
            head_of[dependent - 1] = head - 1
    return [head_of[dependent] for dependent in sorted(head_of)]

def make_head_path(path, token, lemma, tag, index):
    """Build the unary chain of head nodes (spine) above one token.

    The terminal node carries ``(token, lemma, tag)`` as its category and
    ``index`` in its features.  ``path`` is an underscore-separated list of
    categories, topmost first; each becomes a head node wrapping the node
    below it.  If ``path`` equals ``EMPTY``, the bare terminal is returned.
    """
    terminal = tree_utils.TreeNode(0, (token, lemma, tag), True, {'index' : index}, [])
    if path == EMPTY:
        return terminal
    top = terminal
    # Build bottom-up: the last category in the path sits directly above the
    # terminal.
    for category in path.split('_')[::-1]:
        top = tree_utils.TreeNode(0, category, True, {}, [top])
    return top

def append_dependent(head_path, dep_path, anchor_cat, anchor_h, decompress=False):
    """Attach ``dep_path`` as a non-head child somewhere on ``head_path``'s spine.

    The head chain of ``head_path`` is followed from the top down through
    the unique ``is_head`` child of each node.  Among the chain nodes whose
    category equals ``anchor_cat``, the ``anchor_h``-th one counted from the
    bottom is chosen as the attachment site.

    Args:
        head_path: top node of the head's spine.
        dep_path: top node of the dependent's subtree.
        anchor_cat: category of the node to attach to.
        anchor_h: 1-based position, counted from the bottom, among the
            matching chain nodes; ``None`` selects the highest one.
        decompress: when ``anchor_h`` exceeds the number of matching nodes,
            ``True`` inserts additional ``anchor_cat`` nodes above the
            highest match until there are enough; ``False`` falls back to
            the highest match.

    Returns:
        ``(head_path, problem)``.  ``head_path`` is the (possibly new) top
        of the head's spine.  ``problem`` is ``None``, or
        ``(dep_path, anchor_cat)`` when no node of category ``anchor_cat``
        was found and the dependent was attached to the top instead.

    Raises:
        AssertionError: if a node on the chain does not have exactly one
            head child (both offending trees are pretty-printed first).
    """
    problem = None
    head_chain = [head_path]
    while head_chain[-1].children:
        heads = [child for child in head_chain[-1].children if child.is_head]
        # Explicit check rather than `assert` inside a bare `except`: it
        # still runs under `python -O` and does not intercept unrelated
        # exceptions such as KeyboardInterrupt.
        if len(heads) != 1:
            head_path.pretty_print()
            head_chain[-1].pretty_print()
            raise AssertionError(
                f'expected exactly one head child, found {len(heads)}')
        head_chain.append(heads[0])
    matching_heads = [hd for hd in head_chain if hd.category == anchor_cat]
    if not matching_heads:
        # No anchor of the requested category: attach to the top and report.
        anchor_node = head_path
        problem = (dep_path, anchor_cat)
    else:
        if anchor_h is None or (anchor_h > len(matching_heads) and not decompress):
            # fallback: take the highest
            anchor_h = len(matching_heads)
        elif anchor_h > len(matching_heads):
            # decompress: stack extra anchor_cat nodes on the highest match
            expand = anchor_h - len(matching_heads)
            for _ in range(expand):
                child = matching_heads[0]
                parent = child.parent
                # The new node takes over the child's head status; the child
                # becomes the head of the new node.
                new_node = tree_utils.TreeNode(0, anchor_cat, child.is_head, {}, [child])
                child.is_head = True
                if parent:
                    parent.children.remove(child)
                    parent.add_child(new_node)
                if head_path == child:
                    head_path = new_node
                matching_heads.insert(0, new_node)
        anchor_node = matching_heads[-anchor_h]
    anchor_node.add_child(dep_path)
    dep_path.is_head = False
    return head_path, problem

def rearrange(tree, compress=False):
    """Order children left-to-right and annotate every node with its span.

    Leaves get ``from = index`` and ``to = index + 1`` in their features.
    Inner nodes have their children sorted by ``from`` and receive the
    ``from`` of their first child and the largest ``to`` among the children.
    The tree is modified in place and returned.  ``compress`` is only
    forwarded to the recursive calls.
    """
    if not tree.children:
        position = tree.features['index']
        tree.features['from'] = position
        tree.features['to'] = position + 1
        return tree

    ordered = [rearrange(kid, compress=compress) for kid in tree.children]
    ordered.sort(key=lambda kid: kid.features['from'])
    tree.children = ordered
    tree.features['from'] = tree.children[0].features['from']
    tree.features['to'] = max(kid.features['to'] for kid in tree.children)
    return tree

# Category of the dummy preterminal created by add_dummy_pre, keyed by the
# part of speech (the first ':'-separated field of the morphological tag).
POS2PRE = {
    'aglt': 'aglt',
    'conj': 'spójnik',
    'dig': 'formalicz',
    'fin': 'formaczas',
    'interp': 'punct',
    'num': 'formalicz',
    'praet': 'formaczas',
    '_': '???',
}


def add_dummy_pre(path):
    """Wrap the terminal ``path`` in a preterminal chosen from ``POS2PRE``.

    The part of speech is read from the third element of ``path.category``
    (the morphological tag), up to its first ``':'``.  Raises ``KeyError``
    when that part of speech has no entry in ``POS2PRE``.
    """
    morph_tag = path.category[2]
    pos = morph_tag.partition(':')[0]
    label = POS2PRE[pos]
    return tree_utils.TreeNode(0, label, True, {}, [path])

def check_no_cycles(heads):
    """Return True iff every token reaches the root (-1) by following heads.

    ``heads[i]`` is the head index of token ``i``; ``-1`` marks the root.
    Because each token has exactly one head, a cycle can never be entered
    from the root: tokens on a cycle (or hanging below one) are simply
    unreachable from ``-1``.  The traversal therefore removes every head it
    visits from ``children_dict``; anything left over afterwards belongs to
    a part of the graph that is detached from the root.

    The previous version returned ``not children`` (the last, necessarily
    empty, child set) and so reported True for every input.
    """
    children_dict = defaultdict(set)
    for i, hd in enumerate(heads):
        children_dict[hd].add(i)
    visited = {-1}
    pending = [-1]
    while pending:
        node = pending.pop()
        children = children_dict.pop(node, set())
        # Defensive: with one head per token a node cannot be reached twice.
        if children & visited:
            return False
        visited |= children
        pending.extend(children)
    # Leftover entries are heads that were never reached from the root.
    return not children_dict

def try_reattach(heads, spines, tokens):
    """Move the dependents of every ``EMPTY``-spined token up to its head.

    A token needs reattachment when its spine equals ``EMPTY`` and it occurs
    in ``heads`` (i.e. something depends on it).  Such tokens are processed
    one at a time: all their dependents are re-pointed at the token's own
    head, on a copy of ``heads``.  Progress is printed for each round.

    Returns ``heads`` itself when nothing had to be changed, otherwise a
    modified copy.

    Raises:
        AssertionError: if a token to be bypassed is itself the root.
        ZeroDivisionError: deliberately, when ``check_no_cycles`` rejects
            the new heads or on reaching the 50th round.
    """
    def pending(current_heads):
        return [
            idx for idx, spine in enumerate(spines)
            if spine == EMPTY and idx in current_heads
        ]

    rounds = 0
    to_reattach = pending(heads)
    while to_reattach:
        rounds += 1
        if rounds == 50:
            # deliberate hard stop (ZeroDivisionError) on runaway looping
            1 / 0
        print('------- TO REATTACH:')
        for idx in to_reattach:
            print('    ===>', idx, tokens[idx], '->', [tok for tok, hd in zip(tokens, heads) if hd == idx])
        bypassed = to_reattach[0]
        grandparent = heads[bypassed]
        assert grandparent != -1
        new_heads = heads.copy()
        for position, hd in enumerate(heads):
            if hd == bypassed:
                new_heads[position] = grandparent
        if not check_no_cycles(new_heads):
            print('can’t reattach (cycle!)')
            # deliberate hard stop (ZeroDivisionError)
            1 / 0
        heads = new_heads
        to_reattach = pending(heads)

    return heads

def reconstruct_tree(tokens, tags, decompress=False, root_label='ROOT'):
    """Rebuild a constituency tree from per-token dependency and spine tags.

    Args:
        tokens: the token strings.
        tags: mapping that must contain ``HEADS``, ``DEPRELS``, ``SPINES``,
            ``ANCHORS`` and ``ANCHOR_HS``; ``TAGS`` and ``LEMMAS`` are
            optional and default to ``'_'`` for every token.  A head of
            ``-1`` marks the root token.  Anchor heights are converted with
            ``int`` unless they are the string ``'<ROOT>'``, which becomes
            ``None``.
        decompress: forwarded to ``append_dependent``.
        root_label: category expected at the top of the root token's spine.

    Returns:
        ``(tree, problems)``: the tree after ``rearrange`` and a list that
        contains every problem reported by ``append_dependent`` plus the
        string ``'reattach'`` if ``try_reattach`` changed the heads.

    Notes:
        If no token has head ``-1`` the final lookup ``head_paths[None]``
        fails with ``TypeError``.  If several tokens have head ``-1`` the
        last one is used as the root.
    """
    morph_tags = tags.get(TAGS, ['_' for _ in tokens])
    lemmas = tags.get(LEMMAS, ['_' for _ in tokens])
    heads = tags[HEADS]
    deprels = tags[DEPRELS]
    spines = tags[SPINES]
    anchors = tags[ANCHORS]
    anchor_hs = tags[ANCHOR_HS]
    problems = []
    spines = list(spines)

    # Make the spines agree with the heads: the root token's spine must
    # carry root_label, and no other token's spine may.
    for i, (head, spine) in enumerate(zip(heads, spines)):
        if head == -1 and root_label not in spine:
            # Prepend the label; an EMPTY spine becomes just the label.
            spines[i] = (f'{root_label}_' + spines[i]).replace(f'_{EMPTY}', '')
        if head != -1 and root_label in spine:
            assert spines[i].startswith(root_label)
            spines[i] = EMPTY if spines[i] == root_label else spines[i].replace(f'{root_label}_', '')

    # Dependents of tokens with an EMPTY spine are moved to the head above.
    reattach = False
    new_heads = try_reattach(heads, spines, tokens)
    if new_heads != heads:
        heads = new_heads
        reattach = True

    head_paths = [
        make_head_path(p, tok, lemma, tag, i) for i, (p, tok, lemma, tag)
        in enumerate(zip(spines, tokens, lemmas, morph_tags))
    ]

    # if a spine is some other spine’s head and consists of a terminal only,
    # add a dummy preterminal to append to
    head_paths = [
        add_dummy_pre(hp) if (isinstance(hp.category, tuple) and i in heads) else hp
        for i, hp in enumerate(head_paths)
    ]

    anchor_hs = [int(h) if h != '<ROOT>' else None for h in anchor_hs]
    root = None
    for i, head in enumerate(heads):
        if deprels[i] is not None:
            head_paths[i].features['deprel'] = deprels[i]
        if head == -1:
            root = i
        else:
            # append_dependent may return a new top node for the head's
            # spine, so the stored head path is refreshed every time.
            head_paths[head], problem = append_dependent(
                head_paths[head], head_paths[i], anchors[i], anchor_hs[i], decompress=decompress)
            if problem:
                problems.append(problem)
    if reattach:
        problems.append('reattach')
    return rearrange(head_paths[root], compress=(not decompress)), problems

def make_tree(tokens, tags, root_label, decompress=True):
    """Return only the tree built by ``reconstruct_tree``.

    The list of problems that ``reconstruct_tree`` also returns is
    discarded.  Note that ``decompress`` defaults to ``True`` here.
    """
    built, _problems = reconstruct_tree(
        tokens,
        tags,
        root_label=root_label,
        decompress=decompress,
    )
    return built

# copied and modified from parser_server.py
def _node2dict(tree):
    children = [_node2dict(child) for child in tree.children]
    leaves = tree.get_yield()
    #tok_indices = [leaf.features['index'] for leaf in leaves]
    node = {
        'is_head' : tree.is_head,
        'span' : {'from' : tree.features['from'], 'to' : tree.features['to']},
    }
    if 'deprel' in tree.features:
        node['deprel'] = tree.features['deprel']
    if tree.features:
        node['attributes'] = tree.features
    if children:
        node.update({
            'category' : tree.category,
            'children' : children,
        })
    else:
        node.update(dict(zip(('orth', 'base', 'tag'), tree.category)))
    return node

def tree2dict(tree, metadata=None):
    """Return ``{'tree': <dict form of tree>, 'metadata': metadata}``.

    ``metadata`` defaults to a fresh empty dict on every call.  The previous
    ``metadata={}`` default was a single shared object that ended up inside
    every result, so mutating one result's metadata leaked into later calls.
    """
    if metadata is None:
        metadata = {}
    return {'tree' : _node2dict(tree), 'metadata' : metadata}