# hybrid_tree_utils.py
from collections import defaultdict
from itertools import chain

from .constants import (
    EMPTY,
    TOKENS,
    LEMMAS,
    UPPERCASE,
    TAGS,
    SPINES,
    ANCHORS,
    ANCHOR_HS,
    HEADS,
    DEPRELS,
)

class TreeNode(object):
    """A node of a headed constituency tree.

    Nodes without children are terminals; ``to_brackets`` reads their
    ``category`` as a sequence whose first item is the token form and whose
    third item is the morphological tag.  Nodes with children are
    nonterminals whose ``category`` is a string label.
    """

    def __init__(self, nid, category, is_head, from_index, to_index, head_index=None, deprel=None, attributes=None, children=None):
        """Store the node's fields and adopt any ``children`` passed in."""
        self.nid = nid
        self.parent = None
        self.category = category
        self.is_head = is_head
        self.from_index = from_index
        self.to_index = to_index
        self.head_index = head_index
        self.deprel = deprel
        self.attributes = attributes if attributes is not None else {}
        self.children = children if children is not None else []
        for child in self.children:
            child.parent = self

    def add_child(self, child):
        """Append ``child`` as the last child and set its parent link."""
        self.children.append(child)
        child.parent = self

    def get_yield(self):
        """Return the terminal nodes below this node, in child order."""
        if not self.children:
            return [self]
        return list(chain.from_iterable(child.get_yield() for child in self.children))

    def is_continuous(self):
        """Return True if the yield's ``from_index`` values are in ascending order."""
        idx = [token.from_index for token in self.get_yield()]
        return idx == sorted(idx)

    def get_root(self):
        """Follow parent links upwards and return the topmost node."""
        root = self
        while root.parent is not None:
            root = root.parent
        return root

    def get_head_child(self):
        """Return the unique child flagged ``is_head``, or None for a terminal.

        Raises AssertionError if the number of head children is not exactly one.
        """
        if self.children:
            heads = [child for child in self.children if child.is_head]
            assert(len(heads) == 1)
            return heads[0]
        return None

    def get_head_token(self):
        """Return the terminal reached by repeatedly descending to the head child."""
        node = self
        while node.children:
            node = node.get_head_child()
        return node

    def make_evalb_friendly(self, s):
        """Replace spaces and round brackets in ``s`` so it is safe inside a bracketing."""
        return s.replace(' ', '_').replace('(', 'LPAR').replace(')', 'RPAR')

    def to_brackets(self, features=(), mark_head=True, mark_head_terminals=False, morph_tags=False, dummy_pre=False):
        """Serialise the subtree as a bracketed string.

        features: attribute names whose values are joined onto nonterminal
            labels with ``_`` (``EMPTY`` is used when an attribute is missing).
        mark_head: prefix head nonterminals with ``*``.
        mark_head_terminals: not supported; a truthy value raises
            NotImplementedError.
        morph_tags: wrap each terminal as ``(tag form)``.
        dummy_pre: wrap each terminal that has siblings as ``(DUMMY_PRE ...)``.
        """
        if mark_head_terminals:
            raise NotImplementedError
        if not self.children:
            cat = self.make_evalb_friendly(self.category[0])
            if morph_tags:
                cat = f'({self.category[2]} {cat})'
            # A parentless terminal has no siblings; treat it like an only child
            # instead of failing on ``self.parent.children``.
            only_child = self.parent is None or len(self.parent.children) == 1
            if only_child or not dummy_pre:
                return cat
            return f'(DUMMY_PRE {cat})'
        # NOTE(review): the fallback used to be the undefined name EMPTY_VAL;
        # EMPTY is the only such constant this module imports -- confirm it is
        # the intended placeholder.
        cat = '_'.join([self.category] + [self.attributes.get(f, EMPTY) for f in features])
        if mark_head and self.is_head:
            cat = '*' + cat
        cat = self.make_evalb_friendly(cat)
        inner = ' '.join(
            child.to_brackets(
                features=features,
                mark_head=mark_head,
                mark_head_terminals=mark_head_terminals,
                morph_tags=morph_tags,
                dummy_pre=dummy_pre,
            )
            for child in self.children
        )
        return f'({cat} {inner})'

    def pretty_print(self, tab='', features=()):
        """Print the subtree, one node per line, indenting each level by four spaces."""
        # NOTE(review): EMPTY replaces the undefined name EMPTY_VAL -- confirm.
        print(f'{tab}[{self.nid}] {"*" if self.is_head else ""}{self.category}{[self.attributes.get(f, EMPTY) for f in features]}')
        for child in self.children:
            child.pretty_print(tab=tab + '    ', features=features)

def make_lemma(token, case, rule):
    """Build a lemma from ``token`` by applying an edit ``rule``.

    ``rule`` has the form ``<cut_prefix>_<cut>_<suffix>``: the lowercased
    token loses ``cut_prefix`` leading and ``cut`` trailing characters, then
    ``suffix`` (which may itself contain underscores) is appended.  If the
    result is empty, the original ``token`` is returned unchanged.  When
    ``case`` equals ``UPPERCASE`` the first character is uppercased.
    """
    prefix_len, suffix_len, ending = rule.split('_', maxsplit=2)
    prefix_len = int(prefix_len)
    suffix_len = int(suffix_len)

    stem = token.lower()
    if prefix_len:
        stem = stem[prefix_len:]
    if suffix_len:
        # guarded because ``stem[:-0]`` would yield an empty string
        stem = stem[:-suffix_len]
    candidate = stem + ending

    # failsafe: never return an empty lemma
    if not candidate:
        return token
    if case == UPPERCASE:
        return candidate[0].upper() + candidate[1:]
    return candidate

def correct_lemma(token, lemma, tag, morfeusz):
    """Replace ``lemma`` with an analyser-provided one when they disagree.

    ``morfeusz.analyse(token)`` is expected to yield interpretations whose
    third item is indexable as ``(form, lemma, tag, ...)`` -- presumably a
    Morfeusz analyser; verify against the caller.  Only interpretations whose
    form equals ``token`` and whose tag equals ``tag`` are considered.  Unless
    the form contains ``:``, everything from the first ``:`` on is cut from
    the analyser's lemma.

    Returns ``lemma`` unchanged when there is no matching interpretation or
    when it equals (case-insensitively) one of the matching lemmas; otherwise
    returns the alphabetically first matching lemma, printing a notice when
    there was more than one to choose from.
    """
    candidates = set()
    for interp in morfeusz.analyse(token):
        analysis = interp[2]
        if analysis[0] != token or analysis[2] != tag:
            continue
        base = analysis[1]
        if ':' not in analysis[0]:
            base = base.split(':')[0]
        candidates.add(base)

    if not candidates:
        return lemma
    if lemma.lower() in [cand.lower() for cand in candidates]:
        return lemma

    ranked = sorted(candidates)
    if len(ranked) != 1:
        print(token, lemma, tag, '->', ranked)
        print('    ---> >1 matching lemma, will take alphabetically first!')
    return ranked[0]

def normalize(chart):
    """Return a copy of the square ``chart`` with every row divided by its sum.

    Only the first ``len(chart)`` columns of each row are read.  Row sums are
    clamped from below at 1e-12 so that an all-zero row does not divide by zero.
    """
    size = len(chart)
    result = []
    for i in range(size):
        row = chart[i]
        total = max(1e-12, sum(row[j] for j in range(size)))
        result.append([row[j] / total for j in range(size)])
    return result

def add_root(chart):
    """Return an (N+1)x(N+1) chart with an extra row and column at index 0.

    Off-diagonal entries ``chart[i][j]`` move to ``[i + 1][j + 1]``; each
    diagonal entry ``chart[i][i]`` moves to column 0 of row ``i + 1``.  All
    remaining cells, including row 0 and the new diagonal, are 0.0.
    """
    size = len(chart)
    rooted = [[0.0] * (size + 1) for _ in range(size + 1)]
    for dep in range(size):
        for head in range(size):
            column = 0 if head == dep else head + 1
            rooted[dep + 1][column] = chart[dep][head]
    return rooted

def mst(chart):
    """Greedily pick a set of ``(head, dependent)`` edges from a weight chart.

    ``chart[dep][head]`` is the weight of the edge head -> dep.  Edges are
    taken in descending weight order; after each pick, every edge that would
    give the same dependent a second head or would close a cycle is dropped,
    and once node 0 has a dependent all its other outgoing edges are dropped
    too.  Selection stops at a zero-weight edge.

    Raises AssertionError unless exactly ``len(chart) - 1`` edges were chosen.
    """
    size = len(chart)
    candidates = sorted(
        (weight, head, dep)
        for dep, row in enumerate(chart)
        for head, weight in enumerate(row)
    )
    candidates.reverse()

    # (a, b) in reachable  <=>  b is reachable from a via chosen edges (or a == b)
    reachable = {(k, k) for k in range(size)}
    chosen = set()
    while candidates and len(chosen) < size - 1:
        weight, head, dep = candidates[0]
        if weight == 0:
            break
        chosen.add((head, dep))
        # there can be only one root!
        if head == 0:
            candidates = [edge for edge in candidates if edge[1] != 0]
        # everything that reaches ``head`` now also reaches whatever ``dep`` reaches
        sources = [src for src, dst in reachable if dst == head]
        targets = [dst for src, dst in reachable if src == dep]
        reachable.add((head, dep))
        reachable.update((src, dst) for src in sources for dst in targets)
        candidates = [
            edge for edge in candidates
            if edge[2] != dep and (edge[2], edge[1]) not in reachable
        ]
    assert(len(chosen) == size - 1)
    return chosen

def _do_mark_heads(tree, dependency_heads):
    if not tree.children:
        return tree.from_index
    child_head_idx = [_do_mark_heads(child, dependency_heads) for child in tree.children]
    heads = []
    for child, child_head_id in zip(tree.children, child_head_idx):
        if dependency_heads[child_head_id] not in child_head_idx:
            child.is_head = True
            heads.append(child_head_id)
    assert(len(heads) == 1)
    return heads[0]

def _mark_heads(tree, dependency_heads):
    """Flag head children throughout ``tree`` in place.

    Delegates to ``_do_mark_heads`` and discards the lexical-head index it returns.
    """
    _do_mark_heads(tree, dependency_heads)

def _rearrange(tree):
    if tree.children:
        children = [_rearrange(child) for child in tree.children]
        tree.children = sorted(children, key=lambda child: child.from_index)
        tree.from_index = tree.children[0].from_index
        tree.to_index = max(child.to_index for child in tree.children)
    return tree

def tree_from_dataset_instance(instance, dataset_features):
    """Build a ``TreeNode`` tree from one dataset instance.

    ``instance['nonterminals']`` is a list of dicts with ``'cat'`` and
    ``'children'`` keys.  An entry whose ``'cat'`` is None stands for a
    terminal: it must have exactly one child, which is a token index into the
    instance's token, lemma and tag columns.  Any other entry is a nonterminal
    whose children are indices into the same list.  Tag ids are turned into
    strings with ``dataset_features[TAGS].feature.int2str`` (presumably a
    class-label feature; verify against the dataset definition).

    Entry 0 becomes the root.  Heads are marked from ``instance['heads']`` and
    the children are ordered by position before the root is returned.
    """
    nonterminals = instance['nonterminals']

    def build_node(idx, spec):
        label = spec['cat']
        if label is not None:
            # span is filled in later by _rearrange
            return TreeNode(idx, label, False, 0, 0)
        assert(len(spec['children']) == 1)
        token_idx = spec['children'][0]
        terminal = (
            instance[TOKENS][token_idx],
            instance[LEMMAS][token_idx],
            dataset_features[TAGS].feature.int2str(instance[TAGS][token_idx]),
        )
        return TreeNode(idx, terminal, False, token_idx, token_idx + 1)

    # create a node for each nonterminal
    nodes = [build_node(idx, spec) for idx, spec in enumerate(nonterminals)]

    # link every nonterminal node to its children
    for spec, node in zip(nonterminals, nodes):
        if spec['cat'] is None:
            continue
        for child_idx in spec['children']:
            node.add_child(nodes[child_idx])

    root = nodes[0]
    # mark the heads according to dependency relations
    _mark_heads(root, instance['heads'])

    return _rearrange(root)

def get_heads(matrix):
    """Decode a list of head indices from a square score ``matrix``.

    The matrix is row-normalised, extended with a root node and passed to
    ``mst``.  The result lists, in ascending dependent order, the head index
    of each dependent that received an edge, with None for the one attached
    to the added root.
    """
    rooted = add_root(normalize(matrix))
    head_of = {}
    for hd, dep in mst(rooted):
        if dep > 0:
            # shift back by one to undo the root row/column added by add_root
            head_of[dep - 1] = hd - 1 if hd > 0 else None
    return [head_of[dep] for dep in sorted(head_of)]

def make_head_path(path, token, lemma, tag, index):
    """Build the unary chain of nodes (spine) projected by one token.

    ``path`` is a ``_``-separated list of categories from top to bottom, or
    ``EMPTY`` for a bare terminal.  The terminal spans ``index`` to
    ``index + 1``; the nonterminals above it are created with a None span.
    Every node except the topmost one is flagged as head.  Returns the
    topmost node.
    """
    top = TreeNode(0, (token, lemma, tag), False, index, index + 1)
    if path == EMPTY:
        return top
    labels = path.split('_')
    for label in labels[::-1]:
        top.is_head = True
        top = TreeNode(0, label, False, None, None, children=[top])
    return top

def append_dependent(head_path, dep_path, anchor_cat, anchor_h, decompress=False):
    """Attach ``dep_path`` under a node on the head chain of ``head_path``.

    The head chain is ``head_path`` followed by the unique ``is_head`` child
    at every level.  The anchor is picked among the chain nodes whose category
    equals ``anchor_cat``, counting ``anchor_h`` from the lowest one (1 = the
    lowest match).  If ``anchor_h`` is None or exceeds the number of matches,
    the highest match is used -- unless ``decompress`` is true, in which case
    extra ``anchor_cat`` nodes are inserted above the highest match until
    there are ``anchor_h`` of them.  If nothing matches, the dependent is
    attached to ``head_path`` itself and the mismatch is reported.

    Returns ``(head_path, problem)``: ``head_path`` is the (possibly new)
    top of the head's structure and ``problem`` is None or the tuple
    ``(dep_path, anchor_cat)``.  Both trees are modified in place.
    Re-raises the error (after printing the structure) if a chain node does
    not have exactly one head child.
    """
    problem = None
    anchor_node = head_path
    # Collect the head chain from the top of the spine down to the terminal.
    head_chain = [head_path]
    while head_chain[-1].children:
        heads = [child for child in head_chain[-1].children if child.is_head]
        try:
            assert len(heads) == 1
            head_chain.append(heads[0])
        except:
            # Dump the offending structure before propagating the error.
            head_path.pretty_print()
            head_chain[-1].pretty_print()
            raise
    # Ordered top-down, so index -1 is the lowest matching node.
    matching_heads = [hd for hd in head_chain if hd.category == anchor_cat]
    if not matching_heads:
        anchor_node = head_path
        problem = (dep_path, anchor_cat)
    else:
        # fallback: take the highest
        if anchor_h is None:
            anchor_h = len(matching_heads)
        elif anchor_h > len(matching_heads) and not decompress:
            anchor_h = len(matching_heads)
        elif anchor_h > len(matching_heads) and decompress:
            expand = anchor_h - len(matching_heads)
            for i in range(expand):
                # Wrap the current highest match in a new node of the same
                # category; the new node inherits the old node's head flag
                # and the old node becomes its head child.
                child = matching_heads[0]
                parent = child.parent
                new_node = TreeNode(0, anchor_cat, child.is_head, None, None, children=[child])
                child.is_head = True
                if parent:
                    parent.children.remove(child)
                    parent.add_child(new_node)
                # If the wrapped node was the top of the structure, the new
                # node is the top from now on.
                if head_path == child:
                    head_path = new_node
                matching_heads.insert(0, new_node)
        anchor_node = matching_heads[-anchor_h]
    anchor_node.add_child(dep_path)
    dep_path.is_head = False
    return head_path, problem

# Maps the part-of-speech prefix of a morphological tag (the part before the
# first ':') to the category of the preterminal created by add_dummy_pre.
# Every entry is currently commented out, so the table is empty and any
# lookup in add_dummy_pre raises KeyError.
POS2PRE = {
    #'aglt' : 'aglt',
    #'conj' : 'spójnik',
    #'dig' : 'formalicz',
    #'fin' : 'formaczas',
    #'interp' : 'punct',
    #'num' : 'formalicz',
    #'praet' : 'formaczas',
    #'_' : '???',
}

def add_dummy_pre(path):
    """Wrap the terminal ``path`` in a preterminal node looked up in ``POS2PRE``.

    The lookup key is the part of the terminal's tag (``path.category[2]``)
    before the first ``:``.  The terminal becomes the head child of the new
    node, mirroring what ``make_head_path`` does for every wrapped node, so
    that ``append_dependent`` finds exactly one head child when it walks the
    chain.  The new node is created with ``is_head=True`` and a None span,
    like the nonterminals built by ``make_head_path``.

    Raises KeyError if the part of speech has no entry in ``POS2PRE``.
    """
    pos = path.category[2].split(':')[0]
    # Look the category up first so a missing entry leaves ``path`` untouched.
    category = f'{POS2PRE[pos]}'
    path.is_head = True
    # The child list must be passed by keyword: the 4th and 5th positional
    # parameters of TreeNode are from_index and to_index, not attributes and
    # children.
    return TreeNode(0, category, True, None, None, children=[path])

def check_no_cycles(heads):
    """Return True if following head links never revisits a node.

    ``heads[i]`` is the index of node ``i``'s head.  A head that is None or
    lies outside ``range(len(heads))`` (e.g. -1) is treated as the root
    attachment and ends the walk.

    The previous breadth-first version started from -1 and could never
    return False: each node belongs to exactly one child set, so no node was
    ever met twice, and nodes unreachable from -1 were never inspected.
    """
    size = len(heads)

    def is_node(idx):
        return idx is not None and 0 <= idx < size

    cleared = set()  # nodes already known to lead to a root without looping
    for start in range(size):
        trail = set()
        node = start
        while is_node(node) and node not in cleared:
            if node in trail:
                return False
            trail.add(node)
            node = heads[node]
        cleared |= trail
    return True

# find any <EMPTY>’s children and reattach them to <EMPTY>’s head
def try_reattach(heads, spines, tokens):
    """Move dependents of tokens whose spine is ``EMPTY`` up to the next head.

    While some token with an ``EMPTY`` spine still occurs in ``heads``, the
    first such token is taken and every token headed by it is re-attached to
    that token's own head.  Progress is reported on stdout.

    Returns the original ``heads`` object when nothing had to change,
    otherwise a modified copy.  Aborts with ZeroDivisionError on the 50th
    round or when ``check_no_cycles`` rejects the new heads, and with
    AssertionError when the token to bypass has no head itself.
    """
    def empty_spine_heads(current):
        return [idx for idx, spine in enumerate(spines) if spine == EMPTY and idx in current]

    rounds = 0
    pending = empty_spine_heads(heads)
    while pending:
        rounds += 1
        if rounds == 50:
            # deliberate hard stop: give up instead of looping forever
            1/0
        print('------- TO REATTACH:')
        for idx in pending:
            dependents = [tok for tok, hd in zip(tokens, heads) if hd == idx]
            print('    ===>', idx, tokens[idx], '->', dependents)

        target = pending[0]
        new_head = heads[target]
        assert(new_head is not None)
        candidate = heads.copy()
        for dep, hd in enumerate(heads):
            if hd == target:
                candidate[dep] = new_head

        if not check_no_cycles(candidate):
            print('can’t reattach (cycle!)')
            # deliberate hard stop
            1/0
        heads = candidate
        pending = empty_spine_heads(heads)

    return heads

def reconstruct_tree(tokens, tags, decompress=False, root_label='ROOT'):
    """Assemble a constituency tree from per-token tag sequences.

    ``tags`` maps the tag-type constants to per-token sequences.  Heads,
    dependency relations, spines, anchors and anchor heights are required;
    morphological tags and lemmas default to ``'_'`` for every token.

    Steps:
      1. make the spine of the root token (head None) start with
         ``root_label`` and strip that label from every other spine;
      2. re-attach dependents of tokens with an ``EMPTY`` spine
         (``try_reattach``);
      3. build one head path per token (``make_head_path``);
      4. attach every non-root path to its head's path
         (``append_dependent``), collecting reported problems;
      5. order children by position (``_rearrange``).

    Returns ``(tree, problems)``.  ``problems`` holds the tuples reported by
    ``append_dependent`` plus the string ``'reattach'`` if step 2 changed
    any head.
    """
    morph_tags = tags.get(TAGS, ['_' for _ in tokens])
    lemmas = tags.get(LEMMAS, ['_' for _ in tokens])
    heads = tags[HEADS]
    deprels = tags[DEPRELS]
    # work on a copy: spines are rewritten below
    spines = list(tags[SPINES])
    anchors = tags[ANCHORS]
    anchor_hs = tags[ANCHOR_HS]
    problems = []

    for i, (head, spine) in enumerate(zip(heads, spines)):
        if (head is None and root_label not in spine):
            # root token without the root label: prepend it (an EMPTY spine
            # becomes just the root label)
            spines[i] = (f'{root_label}_' + spines[i]).replace(f'_{EMPTY}', '')
        if (head is not None and root_label in spine):
            # non-root token carrying the root label: strip it
            assert(spines[i].startswith(root_label))
            spines[i] = EMPTY if spines[i] == root_label else spines[i].replace(f'{root_label}_', '')

    new_heads = try_reattach(heads, spines, tokens)
    reattach = new_heads != heads
    if reattach:
        heads = new_heads

    head_paths = [
        make_head_path(p, tok, lemma, tag, i) for i, (p, tok, lemma, tag)
        in enumerate(zip(spines, tokens, lemmas, morph_tags))
    ]

    # if a spine is some other spine’s head and consists of a terminal only, add a dummy preterminal to append to
    head_paths = [
        add_dummy_pre(hp) if (isinstance(hp.category, tuple) and i in heads) else hp
        for i, hp in enumerate(head_paths)
    ]

    # '<ROOT>' marks a missing anchor height; append_dependent treats None as
    # "take the highest matching node"
    anchor_hs = [int(h) if h != '<ROOT>' else None for h in anchor_hs]
    root = None
    for i, head in enumerate(heads):
        if deprels[i] is not None:
            head_paths[i].attributes['deprel'] = deprels[i]
        if head is None:
            root = i
        else:
            # append_dependent may return a new top node for the head's path
            head_paths[head], problem = append_dependent(
                head_paths[head], head_paths[i], anchors[i], anchor_hs[i], decompress=decompress)
            if problem:
                problems.append(problem)
    if reattach:
        problems.append('reattach')

    tree = _rearrange(head_paths[root])

    return tree, problems

def make_tree(tokens, tags, root_label, decompress=True):
    """Return only the tree built by ``reconstruct_tree``, dropping its problem list.

    Note that ``decompress`` defaults to True here, unlike in ``reconstruct_tree``.
    """
    result, _problems = reconstruct_tree(
        tokens,
        tags,
        decompress=decompress,
        root_label=root_label,
    )
    return result

def _node2dict(tree):
    children = [_node2dict(child) for child in tree.children]
    leaves = tree.get_yield()
    node = {
        'is_head' : tree.is_head,
        'span' : {'from' : tree.from_index, 'to' : tree.to_index},
    }
    if tree.attributes:
        node['attributes'] = tree.attributes
        if 'deprel' in tree.attributes:
            node['deprel'] = tree.attributes['deprel']
    if children:
        node.update({
            'category' : tree.category,
            'children' : children,
        })
    else:
        node.update(dict(zip(('orth', 'base', 'tag'), tree.category)))
    return node

def tree2dict(tree, metadata=None):
    """Return ``{'tree': <nested dict for tree>, 'metadata': metadata}``.

    When ``metadata`` is omitted, a fresh empty dict is used for each call,
    so callers that modify the returned metadata do not affect later calls.
    """
    if metadata is None:
        metadata = {}
    return {'tree' : _node2dict(tree), 'metadata' : metadata}