tree_utils.py 2.29 KB
from itertools import chain

def get_node_text(node):
    """Return the concatenated ``data`` of every direct child of *node*.

    *node* is presumably an ``xml.dom`` node (it must expose ``childNodes``);
    each direct child must have a ``data`` attribute (e.g. Text / CDATA
    nodes), otherwise attribute access fails. Only direct children are read;
    nothing is recursed into.
    """
    pieces = [child.data for child in node.childNodes]
    return ''.join(pieces)

# Placeholder written by TreeNode.pretty_print / TreeNode.to_brackets for any
# requested feature name that is missing from a node's ``features`` mapping.
EMPTY_VAL = 'x'

class TreeNode(object):
    """A node of a labelled tree that can be printed or exported as brackets.

    Internal nodes (nodes with children) carry a string ``category``. Leaves
    carry an indexable ``category``: item 0 is the surface token, and item 2 is
    the tag used by ``to_brackets`` when ``morph_tags`` is set. The layout is
    presumably ``(word, ..., morph_tag)``; confirm against the code that builds
    the trees.
    """

    def __init__(self, nid, category, is_head, features, children):
        self.nid = nid              # identifier shown by pretty_print
        self.parent = None          # filled in by the parent's __init__ / add_child
        self.category = category
        self.is_head = is_head      # head nodes are marked with '*' in output
        self.features = features    # mapping; read with .get(name, EMPTY_VAL)
        # Stored by reference (not copied): add_child appends to this same list.
        self.children = children
        for child in self.children:
            child.parent = self

    def add_child(self, child):
        """Append *child* to this node's children and set its parent link."""
        self.children.append(child)
        child.parent = self

    def get_yield(self):
        """Return the leaves under this node, in left-to-right order.

        A leaf's yield is a one-element list containing the leaf itself.
        """
        if not self.children:
            return [self]
        return list(chain.from_iterable(child.get_yield() for child in self.children))

    def get_root(self):
        """Follow parent links upward and return the topmost ancestor."""
        root = self
        while root.parent is not None:
            root = root.parent
        return root

    def pretty_print(self, tab='', features=()):
        """Print this subtree to stdout, one node per line.

        Each line is ``[nid] <'*' if head><category><list of feature values>``.
        Children are indented by four more spaces than their parent.

        Args:
            tab: prefix for this node's line.
            features: feature names whose values are listed on each line.
                Missing features show as EMPTY_VAL.
        """
        feature_values = [self.features.get(f, EMPTY_VAL) for f in features]
        head_mark = '*' if self.is_head else ''
        print(f'{tab}[{self.nid}] {head_mark}{self.category}{feature_values}')
        for child in self.children:
            child.pretty_print(tab=tab + '    ', features=features)

    def make_evalb_friendly(self, s):
        """Escape *s* for bracket output.

        Spaces become ``_``, ``(`` becomes ``LPAR`` and ``)`` becomes ``RPAR``.
        """
        return s.replace(' ', '_').replace('(', 'LPAR').replace(')', 'RPAR')

    def to_brackets(self, features=(), mark_head=True, mark_head_terminals=False,
                    morph_tags=False, dummy_pre=True):
        """Render the subtree rooted at this node as a bracketed string.

        Args:
            features: feature names appended (joined with ``_``) to each
                internal node's label. Missing features render as EMPTY_VAL.
            mark_head: prefix the labels of head internal nodes with ``*``.
            mark_head_terminals: not supported; must be false.
            morph_tags: wrap each leaf token as ``(category[2] token)``.
            dummy_pre: wrap a leaf in ``(DUMMY_PRE ...)`` unless it is the
                sole child of its parent.

        Returns:
            The bracketed string. Internal labels and leaf tokens are escaped
            with make_evalb_friendly.

        Raises:
            NotImplementedError: if ``mark_head_terminals`` is true.
        """
        if mark_head_terminals:
            # Head marking on terminals was never implemented.
            raise NotImplementedError('mark_head_terminals is not supported')
        if not self.children:
            cat = self.make_evalb_friendly(self.category[0])
            if morph_tags:
                cat = f'({self.category[2]} {cat})'
            # When the leaf is its parent's only child, the parent already acts
            # as its preterminal. Every other leaf, including a parentless one,
            # gets a dummy preterminal when dummy_pre is set.
            is_only_child = self.parent is not None and len(self.parent.children) == 1
            if is_only_child or not dummy_pre:
                return cat
            return f'(DUMMY_PRE {cat})'
        label = '_'.join([self.category] + [self.features.get(f, EMPTY_VAL) for f in features])
        if mark_head and self.is_head:
            label = '*' + label
        label = self.make_evalb_friendly(label)
        rendered_children = ' '.join(
            child.to_brackets(features=features, mark_head=mark_head,
                              mark_head_terminals=mark_head_terminals,
                              morph_tags=morph_tags, dummy_pre=dummy_pre)
            for child in self.children
        )
        return f'({label} {rendered_children})'