tree_utils.py
2.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from itertools import chain
def get_node_text(node):
    """Return the concatenated ``data`` of every direct child of *node*.

    *node* must expose a ``childNodes`` sequence whose items each carry a
    ``data`` string -- e.g. an ``xml.dom`` element holding only text/CDATA
    children (assumed from usage; a child without ``data``, such as a nested
    element, raises ``AttributeError``).
    """
    pieces = [child.data for child in node.childNodes]
    return ''.join(pieces)
# Placeholder rendered for a requested feature name that is missing from a
# node's ``features`` mapping (used by TreeNode.pretty_print / to_brackets).
EMPTY_VAL = "x"
class TreeNode(object):
    """A labelled tree node with a back-pointer to its parent.

    Internal nodes (those with children) carry a string ``category``.  For
    leaves, ``to_brackets`` reads ``category[0]`` as the token text and, with
    ``morph_tags=True``, ``category[2]`` as a tag -- presumably a tuple built
    by the tree reader; TODO confirm its full layout.

    Attributes:
        nid: Node identifier, shown by ``pretty_print``.
        parent: The parent ``TreeNode``, or ``None`` for the root.
        category: Node label (see above for leaves).
        is_head: Flag rendered as a leading ``*`` on the label.
        features: Mapping queried with ``.get(name, EMPTY_VAL)``.
        children: List of child nodes in order; ``add_child`` appends to it.
    """

    def __init__(self, nid, category, is_head, features, children):
        self.nid = nid
        self.parent = None
        self.category = category
        self.is_head = is_head
        self.features = features
        # The caller's list is stored as-is (not copied); add_child() later
        # mutates that same list.  Kept for backward compatibility.
        self.children = children
        for child in self.children:
            child.parent = self

    def add_child(self, child):
        """Append *child* as the last child and set its parent to this node."""
        self.children.append(child)
        child.parent = self

    def get_yield(self):
        """Return the leaves under this node, left to right.

        A leaf returns ``[self]``.  Uses an explicit stack instead of
        recursion so very deep trees do not hit the recursion limit, and so
        no intermediate per-level lists are built.
        """
        leaves = []
        stack = [self]
        while stack:
            node = stack.pop()
            if node.children:
                # Push in reverse so the leftmost child is processed first.
                stack.extend(reversed(node.children))
            else:
                leaves.append(node)
        return leaves

    def get_root(self):
        """Follow ``parent`` links up and return the topmost node."""
        root = self
        while root.parent is not None:
            root = root.parent
        return root

    def pretty_print(self, tab='', features=()):
        """Print this subtree, one node per line, indented one space per level.

        Each line is ``[nid] <*><category><list of feature values>``, where
        the ``*`` appears only for head nodes and missing features show as
        ``EMPTY_VAL``.

        Args:
            tab: Prefix for this node's line; children get one more space.
            features: Names of features whose values are listed per node.
        """
        values = [self.features.get(f, EMPTY_VAL) for f in features]
        head_mark = '*' if self.is_head else ''
        print(f'{tab}[{self.nid}] {head_mark}{self.category}{values}')
        for child in self.children:
            child.pretty_print(tab=tab + ' ', features=features)

    def make_evalb_friendly(self, s):
        """Return *s* with spaces -> ``_``, ``(`` -> ``LPAR``, ``)`` -> ``RPAR``.

        These characters would otherwise break the bracketed format
        (presumably consumed by the EVALB scorer, given the name).
        """
        return s.replace(' ', '_').replace('(', 'LPAR').replace(')', 'RPAR')

    def to_brackets(self, features=(), mark_head=True, mark_head_terminals=False,
                    morph_tags=False, dummy_pre=True):
        """Serialize this subtree as a bracketed string.

        Args:
            features: Feature names appended, ``_``-joined, to internal-node
                labels; missing features are rendered as ``EMPTY_VAL``.
            mark_head: Prefix ``*`` to labels of internal nodes whose
                ``is_head`` is true.
            mark_head_terminals: Not supported; passing True raises.
            morph_tags: Render each leaf as ``(tag token)`` using
                ``category[2]`` as the tag.
            dummy_pre: Wrap a leaf in ``(DUMMY_PRE ...)`` unless it is its
                parent's only child.

        Returns:
            The bracketed string; labels, tokens and tags are escaped with
            ``make_evalb_friendly``.

        Raises:
            NotImplementedError: If ``mark_head_terminals`` is true.
        """
        if mark_head_terminals:
            raise NotImplementedError('mark_head_terminals is not supported')

        if not self.children:
            cat = self.make_evalb_friendly(self.category[0])
            if morph_tags:
                # str() keeps the old f-string rendering for non-str tags.
                tag = self.make_evalb_friendly(str(self.category[2]))
                cat = f'({tag} {cat})'
            # An only child is left bare (presumably its parent then acts as
            # its preterminal).  A parentless leaf (single-token tree) has no
            # such parent, so it is wrapped like a leaf that has siblings.
            is_only_child = (self.parent is not None
                             and len(self.parent.children) == 1)
            if is_only_child or not dummy_pre:
                return cat
            return f'(DUMMY_PRE {cat})'

        labels = [self.category]
        labels.extend(str(self.features.get(f, EMPTY_VAL)) for f in features)
        cat = '_'.join(labels)
        if mark_head and self.is_head:
            cat = '*' + cat
        cat = self.make_evalb_friendly(cat)
        inner = ' '.join(
            child.to_brackets(features=features, mark_head=mark_head,
                              mark_head_terminals=mark_head_terminals,
                              morph_tags=morph_tags, dummy_pre=dummy_pre)
            for child in self.children)
        return f'({cat} {inner})'