'''
Created on Oct 8, 2013
@author: mlenart
'''
import state
import register
import logging
class FSA(object):
    '''
    A finite state automaton built incrementally from (word, data) entries
    that are added in strictly increasing order of their encoded form.

    Once a word is followed by a larger one, the part of its path that no
    longer shares a prefix with the new word cannot change any more; those
    states are passed through a register (see _replaceOrRegister) so that
    equivalent states are shared.
    '''

    def __init__(self, encoder, tagset=None, encodeData=True, encodeWords=True):
        '''
        @param encoder: object providing encodeWord, encodeData and decodeData
        @param tagset: stored unchanged on the instance
        @param encodeData: when False, data is stored and returned unchanged
        @param encodeWords: when False, words are used unchanged as label
            sequences
        '''
        self.encodeWord = encoder.encodeWord if encodeWords else lambda x: x
        self.encodeData = encoder.encodeData if encodeData else lambda x: x
        self.decodeData = encoder.decodeData if encodeData else lambda x: x

        # Encoded form of the most recently added word; None before the first
        # entry and again after close().
        self.encodedPrevWord = None

        self.tagset = tagset

        self.initialState = state.State()
        self.register = register.Register()

        # label -> occurrence count (filled by addEntry and rebuilt by train)
        self.label2Freq = {}

        # number of entries added so far
        self.n = 0
        self.closed = False

    def tryToRecognize(self, word, addFreq=False):
        '''
        Encode word, run it through the initial state's tryToRecognize and
        return the result passed through decodeData.

        @param addFreq: forwarded unchanged to State.tryToRecognize
        '''
        # NOTE(review): if State.tryToRecognize can return None for unknown
        # words, confirm that decodeData accepts None.
        return self.decodeData(
            self.initialState.tryToRecognize(self.encodeWord(word), addFreq))

    def addEntry(self, word, data):
        '''
        Add one (word, data) entry.

        Entries must be added before close() and in strictly increasing
        order of their encoded form; data must not be None.
        '''
        assert not self.closed
        assert data is not None
        encodedWord = self.encodeWord(word)
        # There is no predecessor for the first entry. Test for None
        # explicitly: ordering against None raises TypeError on Python 3.
        assert self.encodedPrevWord is None or encodedWord > self.encodedPrevWord
        self._addSorted(encodedWord, self.encodeData(data))
        self.encodedPrevWord = encodedWord

        self.n += 1

        # debug: periodic progress report
        if self.n % 10000 == 0:
            logging.info('%s', word)
            logging.info('%s', self.register.getStatesNum())

        for label in encodedWord:
            self.label2Freq[label] = self.label2Freq.get(label, 0) + 1

    def close(self):
        '''
        Finish construction.

        Passes the path of the last added word through the register and
        forbids further addEntry() calls. At least one entry must have been
        added.
        '''
        assert self.n > 0
        assert not self.closed
        self.initialState = self._replaceOrRegister(self.initialState, self.encodedPrevWord)
        self.encodedPrevWord = None
        self.closed = True

    def train(self, trainData):
        '''
        Rebuild label2Freq from trainData.

        Each word is also recognized with addFreq=True, presumably so that
        states accumulate usage counts (see State.tryToRecognize).

        @param trainData: iterable of words
        '''
        self.label2Freq = {}
        for idx, word in enumerate(trainData):
            self.tryToRecognize(word, addFreq=True)
            for label in self.encodeWord(word):
                self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
            if idx % 100000 == 0:
                logging.info('%s', idx)

    def dfs(self):
        '''Yield the states produced by State.dfs from the initial state (fresh visited set).'''
        for s in self.initialState.dfs(set()):
            yield s

    def getStatesNum(self):
        '''Return the number of states held by the register.'''
        return self.register.getStatesNum()

    def getTransitionsNum(self):
        '''Return the total number of transitions over all states yielded by dfs().'''
        return sum(len(s.transitionsMap) for s in self.dfs())

    def _addSorted(self, encodedWord, data):
        '''
        Insert encodedWord, which is larger than every word added so far,
        with its already-encoded data (a bytearray).

        The longest prefix of encodedWord already in the automaton is
        followed. The rest of the previous word's path cannot change any
        more because input is sorted, so it is handed to _replaceOrRegister.
        A fresh chain of states is then created for the remaining suffix of
        encodedWord.
        '''
        assert self.encodedPrevWord is None or self.encodedPrevWord < encodedWord
        assert isinstance(data, bytearray)

        q = self.initialState
        i = 0
        # Bound by len(encodedWord): encodedWord[i] is only valid for
        # i < len(encodedWord), e.g. the empty word must stop immediately.
        while i < len(encodedWord) and q.hasNext(encodedWord[i]):
            q = q.getNext(encodedWord[i])
            i += 1

        if self.encodedPrevWord is not None and i < len(self.encodedPrevWord):
            label = self.encodedPrevWord[i]
            q.setTransition(
                label,
                self._replaceOrRegister(q.getNext(label), self.encodedPrevWord[i + 1:]))

        while i < len(encodedWord):
            q.setTransition(encodedWord[i], state.State())
            q = q.getNext(encodedWord[i])
            i += 1

        assert q.encodedData is None
        q.encodedData = data

    def _replaceOrRegister(self, q, encodedWord):
        '''
        Recursively process the states along encodedWord below q (deepest
        first).

        Returns the register's equivalent state when one exists; otherwise
        registers q and returns it. The recursion depth equals
        len(encodedWord).
        '''
        if encodedWord:
            label = encodedWord[0]
            q.setTransition(
                label,
                self._replaceOrRegister(q.getNext(label), encodedWord[1:]))
        if self.register.containsEquivalentState(q):
            return self.register.getEquivalentState(q)
        self.register.addState(q)
        return q

    def calculateOffsets(self, sizeCounter):
        '''Delegate offset computation to the initial state's calculateOffsets.'''
        self.initialState.calculateOffsets(sizeCounter)

    def debug(self):
        '''Call debug() on every state yielded by dfs().'''
        for s in self.dfs():
            s.debug()