|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
'''
Created on Oct 8, 2013
@author: mlenart
'''
import state
import register
import logging
class FSA(object):
    '''
    A finite state automaton built incrementally from sorted entries.

    Usage: call addEntry(word, data) once per entry, in strictly increasing
    order of the encoded word, then call close() exactly once. While building,
    the states on the previous word's path that later (sorted) words cannot
    extend are replaced by equivalent states kept in a register, or added to
    it (see _replaceOrRegister). This follows the incremental sorted-input
    construction of Daciuk et al.
    '''

    def __init__(self, encoder, encodeData=True, encodeWords=True):
        '''
        :param encoder: object providing encodeWord, encodeData and
            decodeData callables.
        :param encodeData: if False, data is stored and returned unchanged
            (identity instead of encoder.encodeData / encoder.decodeData).
        :param encodeWords: if False, words are used unchanged
            (identity instead of encoder.encodeWord).
        '''
        identity = lambda x: x
        self.encodeWord = encoder.encodeWord if encodeWords else identity
        self.encodeData = encoder.encodeData if encodeData else identity
        self.decodeData = encoder.decodeData if encodeData else identity

        # Encoded form of the most recently added word; None until the
        # first addEntry() and again after close().
        self.encodedPrevWord = None

        self.initialState = state.State()
        self.register = register.Register()

        # label -> number of occurrences, filled by addEntry() and train()
        self.label2Freq = {}

        # number of entries added so far
        self.n = 0
        # set by close(); no entries may be added afterwards
        self.closed = False

    def tryToRecognize(self, word, addFreq=False):
        '''
        Look up word and return its decoded data.

        The word is encoded and passed, together with addFreq, to the initial
        state's tryToRecognize. The result is then run through decodeData.
        What comes back for an unknown word depends on State.tryToRecognize
        and the decoder, neither of which is defined in this class.
        '''
        encodedWord = self.encodeWord(word)
        return self.decodeData(self.initialState.tryToRecognize(encodedWord, addFreq))

    def addEntry(self, word, data):
        '''
        Add one (word, data) entry.

        Requirements:
        - The encoded word must sort strictly after the previously added one.
        - data must not be None.
        - close() must not have been called yet.

        Also updates label2Freq with the labels of the encoded word.
        '''
        assert not self.closed
        assert data is not None
        encodedWord = self.encodeWord(word)
        # encodedPrevWord is None before the first entry; ordering against
        # None raises TypeError on Python 3, so skip the check in that case.
        assert self.encodedPrevWord is None or encodedWord > self.encodedPrevWord
        self._addSorted(encodedWord, self.encodeData(data))
        self.encodedPrevWord = encodedWord

        self.n += 1

        # debug: log every one of the first entries, then progressively less often
        if self.n < 10 or (self.n < 10000 and self.n % 1000 == 0) or self.n % 10000 == 0:
            logging.info(u'%d %s', self.n, word)

        for label in encodedWord:
            self.label2Freq[label] = self.label2Freq.get(label, 0) + 1

    def close(self):
        '''
        Finish construction.

        Runs _replaceOrRegister over the last added word's path starting at
        the initial state, clears encodedPrevWord, and marks the automaton
        closed so that addEntry() is rejected afterwards.
        '''
        assert self.n > 0
        assert not self.closed
        self.initialState = self._replaceOrRegister(self.initialState, self.encodedPrevWord)
        self.encodedPrevWord = None
        self.closed = True

    def train(self, trainData):
        '''
        Rebuild label frequencies from an iterable of words.

        label2Freq is reset and recounted from the encoded labels of every
        word in trainData. Each word is also looked up with
        tryToRecognize(word, addFreq=True); presumably this lets the states
        record how often they are used (see State.tryToRecognize). Progress
        is logged every 100000 words.
        '''
        self.label2Freq = {}
        for idx, word in enumerate(trainData):
            self.tryToRecognize(word, addFreq=True)
            for label in self.encodeWord(word):
                self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
            if idx % 100000 == 0:
                logging.info(str(idx))

    def dfs(self):
        '''Yield the states produced by the initial state's dfs traversal.'''
        # 's' rather than 'state' so the imported state module is not shadowed
        for s in self.initialState.dfs(set()):
            yield s

    def getStatesNum(self):
        '''Return the number of states as reported by the register.'''
        return self.register.getStatesNum()

    def getTransitionsNum(self):
        '''Return the total number of transitions over all states visited by dfs().'''
        return sum(len(s.transitionsMap) for s in self.dfs())

    def _addSorted(self, encodedWord, data):
        '''
        Insert encodedWord with its already-encoded data.

        Assumes encodedWord sorts strictly after every previously added word.
        1. Walk the longest prefix of encodedWord already in the automaton.
        2. Because input is sorted, later words cannot extend the previous
           word's path beyond that prefix. Hand that part to
           _replaceOrRegister.
        3. Create fresh states for the rest of encodedWord and attach data to
           the final state.
        '''
        assert self.encodedPrevWord is None or self.encodedPrevWord < encodedWord
        assert isinstance(data, bytearray)

        q = self.initialState
        i = 0
        # Must be '<': at i == len(encodedWord) there is no label to index.
        while i < len(encodedWord) and q.hasNext(encodedWord[i]):
            q = q.getNext(encodedWord[i])
            i += 1

        prevWord = self.encodedPrevWord
        if prevWord and i < len(prevWord):
            label = prevWord[i]
            q.setTransition(
                label,
                self._replaceOrRegister(q.getNext(label), prevWord[i + 1:]))

        while i < len(encodedWord):
            q.setTransition(encodedWord[i], state.State())
            q = q.getNext(encodedWord[i])
            i += 1

        # a duplicate word would reach a state that already carries data
        assert q.encodedData is None
        q.encodedData = data

    def _replaceOrRegister(self, q, encodedWord):
        '''
        Canonicalize the states on the path from q along encodedWord.

        Processing is deepest state first. Each state is swapped for an
        equivalent state already in the register if one exists; otherwise it
        is added to the register. Returns the state that should take q's
        place.
        '''
        if encodedWord:
            label = encodedWord[0]
            q.setTransition(
                label,
                self._replaceOrRegister(q.getNext(label), encodedWord[1:]))
        if self.register.containsEquivalentState(q):
            return self.register.getEquivalentState(q)
        self.register.addState(q)
        return q

    def calculateOffsets(self, sizeCounter):
        '''Delegate offset calculation to the initial state, passing sizeCounter through.'''
        self.initialState.calculateOffsets(sizeCounter)

    def debug(self):
        '''Call debug() on every state yielded by the initial state's dfs.'''
        for s in self.initialState.dfs(set()):
            s.debug()
|