Blame view

swigra/disambiguator-pcfg/node.py 5.65 KB
Jan Lupa authored
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# -*- encoding: utf-8 -*-
__author__ = 'nika'

import random, time
# Re-seed the shared module-level RNG from the current wall-clock time when
# this module is imported; Node.getRandomChildren() draws from it.
# NOTE(review): this reseeds the process-wide `random` state, overriding any
# seed another module set before this import — confirm that is intended.
random.seed(a=time.time())

class Node(object):
    """A parse-tree node built from the attribute mapping of a node element.

    Holds the node's span (``fro``..``to``), its grammatical ``category``,
    a dict of feature ``arguments``, its ``children`` and ``parents``.
    """

    # Argument keys that getArgDict() never turns into features.
    _NON_FEATURE_ARGS = ("rekcja", "poz")

    def __init__(self, _args):
        """Initialise the node from an attribute mapping.

        :param _args: mapping of attribute name to string value, e.g.
            ``nid="11" from="3" to="5" subtrees="1" chosen="false"``.
            ``nid``, ``from``, ``to`` and ``subtrees`` are required;
            ``chosen`` is optional and only the exact value ``"false"``
            marks the node as not chosen.
        :raises KeyError: if a required attribute is missing.
        :raises ValueError: if ``from``, ``to`` or ``subtrees`` is not an
            integer literal.
        """
        self.nid = _args["nid"]
        self.fro = int(_args["from"])
        self.to = int(_args["to"])
        # Correctly spelled name; the historical ``subtress`` spelling is
        # still available through the property alias below.
        self.subtrees = int(_args["subtrees"])
        # A missing ``chosen`` attribute counts as chosen.
        self.chosen = not ("chosen" in _args and _args["chosen"] == "false")
        self.terminal = False
        self.nonterminal = False
        # Feature name -> value, filled through addArgument().
        self.arguments = {}
        # Stays 'terminal' until setCategory() is called.
        self.category = u'terminal'
        # A list filled via addChildren(); note that getChildrenForActPcfg()
        # expects a dict here instead (see its docstring).
        self.children = []
        self.parents = {}

    @property
    def subtress(self):
        """Misspelled alias of ``subtrees``, kept for existing callers."""
        return self.subtrees

    @subtress.setter
    def subtress(self, value):
        self.subtrees = value

    def __unicode__(self):
        return "NID: "+unicode(self.nid)+" "+unicode(self.arguments)

    def getFromToCat(self):
        """Return ``u'<from>@<to>@<category>'`` for this node."""
        return u'@'.join((unicode(self.getFrom()),
                          unicode(self.getTo()),
                          unicode(self.getCategory())))

    def getArgDict(self, prefix):
        """Return a binary feature dict describing category and arguments.

        Keys are ``prefix@<category>`` plus ``prefix@<arg>@<value>`` for every
        argument not listed in ``_NON_FEATURE_ARGS``; every value is 1.
        Argument values are concatenated directly, so they must be strings.
        """
        features = {prefix + '@' + unicode(self.category): 1}
        for key, value in self.arguments.items():
            if key not in self._NON_FEATURE_ARGS:
                features[prefix + '@' + key + '@' + value] = 1
        return features

    def make_exp_data(self, leaves, dom):
        """Return unigram features for a sequence of leaf mappings.

        For each leaf this adds ``1gram@base@<base>`` (only when the leaf has
        a ``base`` key) and ``1gram@tag@<part>`` for every ':'-separated part
        of its ``tag``. Every value is 1.

        :param leaves: iterable of mappings with a required ``"tag"`` key and
            an optional ``"base"`` key.
        :param dom: unused; only the (disabled) bigram features read it. Kept
            for interface compatibility.
        :raises KeyError: if a leaf has no ``"tag"``.
        """
        features = {}
        for leaf in leaves:
            try:
                features['1gram@base@' + leaf["base"]] = 1
            except KeyError:
                pass  # a leaf without a base form contributes only tag features
            for tag_part in leaf["tag"].split(':'):
                features['1gram@tag@' + tag_part] = 1
        # Bigram/trigram features (tag and base n-grams) were disabled.
        return features

    def isTerminal(self):
        """Return True once setTerminal() has been called."""
        return self.terminal

    def isChosen(self):
        """Return the ``chosen`` flag parsed in __init__."""
        return self.chosen

    def addChildren(self, child):
        """Append ``child`` to ``self.children``."""
        self.children.append(child)

    def addArgument(self, arg, value):
        """Set feature ``arg`` to ``value``, overwriting any previous value."""
        self.arguments[arg] = value

    def setTerminal(self):
        """Mark the node as terminal."""
        self.terminal = True

    def setNonterminal(self):
        """Mark the node as nonterminal."""
        self.nonterminal = True

    def setCategory(self, _cat):
        """Replace the node's category."""
        self.category = _cat

    def getRandomChildren(self):
        """Return one child chosen with the module-level ``random`` RNG.

        :raises ValueError: if the node has no children (randint(0, -1)).
        """
        return self.children[random.randint(0, len(self.children) - 1)]

    def getAllChildren(self):
        """Return the children container itself (not a copy)."""
        return self.children

    def getChildrenForActPcfg(self):
        """Return ``[(centre, u'true'), (production, u'false'), ...]``.

        NOTE(review): this indexes ``self.children`` with string keys, so it
        assumes callers have replaced the list created in __init__ with a dict
        holding ``'centre'`` and ``'productions'``; with that list it raises
        TypeError — confirm against callers.
        """
        pairs = [(self.children['centre'], u'true')]
        pairs.extend((child, u'false') for child in self.children['productions'])
        return pairs

    def getRawChildren(self):
        """Same result as getChildrenForActPcfg(); kept as a public alias."""
        return self.getChildrenForActPcfg()

    def getCategory(self):
        """Return the category as unicode.

        For ``fw`` nodes the ``tfw`` argument is appended after '@' (KeyError
        if it is missing). Case, gender, number and person used to be
        appended too; that refinement is disabled.
        """
        cat = unicode(self.category)
        if self.category == "fw":
            cat += '@' + unicode(self.arguments['tfw'])
        return cat

    def getRodzaj(self):
        """Return the ``rodzaj`` argument, or None if absent."""
        return self.arguments.get('rodzaj')

    def getLiczba(self):
        """Return the ``liczba`` argument, or None if absent."""
        return self.arguments.get('liczba')

    def getOsoba(self):
        """Return the ``osoba`` argument, or None if absent."""
        return self.arguments.get('osoba')

    def getExtCategory(self):
        """Return the extended category; currently identical to getCategory()."""
        return unicode(self.getCategory())

    def getNode(self):
        """Return a ``(category, children)`` pair."""
        return self.getCategory(), self.children

    def getArguments(self):
        """Return the argument dict itself (not a copy)."""
        return self.arguments

    def getID(self):
        """Return the node id as given in the ``nid`` attribute."""
        return self.nid

    def getFrom(self):
        """Return the start of the node's span."""
        return self.fro

    def getTo(self):
        """Return the end of the node's span."""
        return self.to

    def equals(self, node):
        """Return True if ``node`` matches this node's arguments, span and category.

        The argument check is one-directional: every argument of ``self`` must
        exist in ``node`` with an equal value, while extra arguments on
        ``node`` are ignored. The span/category comparison only runs when the
        arguments match.
        """
        other_args = node.arguments
        for arg, value in self.arguments.items():
            if arg not in other_args or not value == other_args[arg]:
                return False
        return self.equals_from_to_cat(node)

    def equals_from_to_cat(self, node):
        """Return True if both nodes share start, end and category."""
        return (self.getFrom() == node.getFrom()
                and self.getTo() == node.getTo()
                and self.getCategory() == node.getCategory())

    def equals_from_to(self, node):
        """Return True if both nodes share start and end."""
        return self.getFrom() == node.getFrom() and self.getTo() == node.getTo()

    def overlaps(self, node):
        """Return True if the closed intervals [from, to] of both nodes intersect.

        NOTE(review): bounds are inclusive, so adjacent spans such as 3-5 and
        5-7 count as overlapping — confirm this is intended.
        """
        return self.getFrom() <= node.getTo() and self.getTo() >= node.getFrom()