fsa.py
3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
'''
Created on Oct 8, 2013
@author: mlenart
'''
import state
import register
import logging
class FSA(object):
    '''
    A finite state automaton built incrementally from sorted (word, data)
    entries.

    Entries are added with addEntry() in strictly increasing order of their
    encoded word.  Parts of the automaton that no later word can extend are
    merged with equivalent states kept in a register (_replaceOrRegister),
    i.e. the incremental construction of minimal acyclic automata from
    sorted input (Daciuk et al.).  close() must be called after the last
    entry to minimize the path of the final word.
    '''

    def __init__(self, encoder, encodeData=True, encodeWords=True):
        '''
        :param encoder: object providing encodeWord, encodeData and
            decodeData callables.
        :param encodeData: when False, data is stored and returned unchanged
            (it must then already be a bytearray, see _addSorted).
        :param encodeWords: when False, words are used unchanged.
        '''
        def _identity(x):
            return x
        self.encodeWord = encoder.encodeWord if encodeWords else _identity
        self.encodeData = encoder.encodeData if encodeData else _identity
        self.decodeData = encoder.decodeData if encodeData else _identity
        # encoded form of the most recently added word; None before the first
        self.encodedPrevWord = None
        self.initialState = state.State()
        self.register = register.Register()
        # label -> number of occurrences in added (or trained) words
        self.label2Freq = {}
        # number of entries added so far
        self.n = 0
        self.closed = False

    def tryToRecognize(self, word, addFreq=False):
        '''
        Look up word in the automaton.

        Encodes word, delegates to the initial state's tryToRecognize and
        passes the result through decodeData.  What is returned for an
        unknown word depends on state.State / the encoder (presumably None;
        verify).
        '''
        return self.decodeData(
            self.initialState.tryToRecognize(self.encodeWord(word), addFreq))

    def addEntry(self, word, data):
        '''
        Add word with its associated data.

        Words must arrive in strictly increasing order of their encoded form,
        data must not be None, and close() must not have been called yet.
        '''
        assert not self.closed
        assert data is not None
        encodedWord = self.encodeWord(word)
        # Explicit None check: comparing a sequence with None raises
        # TypeError on Python 3.
        assert self.encodedPrevWord is None or encodedWord > self.encodedPrevWord
        self._addSorted(encodedWord, self.encodeData(data))
        self.encodedPrevWord = encodedWord
        self.n += 1
        # progress log: entries 1-9, then every 1000th below 10000,
        # then every 10000th
        if self.n < 10 or (self.n < 10000 and self.n % 1000 == 0) or self.n % 10000 == 0:
            logging.info(u'%d %s', self.n, word)
        for label in encodedWord:
            self.label2Freq[label] = self.label2Freq.get(label, 0) + 1

    def close(self):
        '''
        Finish construction.

        Runs _replaceOrRegister over the whole path of the last added word,
        starting from the initial state (which may itself be replaced by an
        equivalent registered state).  No entries may be added afterwards.
        '''
        assert self.n > 0
        assert not self.closed
        self.initialState = self._replaceOrRegister(self.initialState, self.encodedPrevWord)
        self.encodedPrevWord = None
        self.closed = True

    def train(self, trainData):
        '''
        Recount label frequencies over an iterable of words.

        Resets label2Freq, then for every word calls tryToRecognize with
        addFreq=True and counts the word's encoded labels.  Progress is
        logged every 100000 words.
        '''
        self.label2Freq = {}
        for idx, word in enumerate(trainData):
            self.tryToRecognize(word, addFreq=True)
            for label in self.encodeWord(word):
                self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
            if idx % 100000 == 0:
                logging.info('%d', idx)

    def dfs(self):
        '''Yield every state produced by the initial state's dfs traversal.'''
        for s in self.initialState.dfs(set()):
            yield s

    def getStatesNum(self):
        '''
        Return the number of states held in the register.

        States on the last word's path are only registered by close().
        '''
        return self.register.getStatesNum()

    def getTransitionsNum(self):
        '''Return the total number of transitions over all reachable states.'''
        return sum(len(s.transitionsMap) for s in self.initialState.dfs(set()))

    def _addSorted(self, encodedWord, data):
        '''
        Insert encodedWord -> data, where encodedWord sorts after every
        previously added word.

        1. Follow the longest prefix of encodedWord already present.
        2. Minimize the previous word's path below that prefix; since input
           is sorted, no later word can extend it.
        3. Append fresh states for the remaining suffix and store data in
           the last one.
        '''
        assert self.encodedPrevWord is None or self.encodedPrevWord < encodedWord
        assert isinstance(data, bytearray)
        q = self.initialState
        i = 0
        # `<`, not `<=`: encodedWord[len(encodedWord)] would raise IndexError
        # (e.g. when the empty word is added).
        while i < len(encodedWord) and q.hasNext(encodedWord[i]):
            q = q.getNext(encodedWord[i])
            i += 1
        prev = self.encodedPrevWord
        if prev and i < len(prev):
            label = prev[i]
            q.setTransition(
                label,
                self._replaceOrRegister(q.getNext(label), prev[i + 1:]))
        while i < len(encodedWord):
            q.setTransition(encodedWord[i], state.State())
            q = q.getNext(encodedWord[i])
            i += 1
        assert q.encodedData is None
        q.encodedData = data

    def _replaceOrRegister(self, q, encodedWord):
        '''
        Minimize, bottom-up, the path spelled by encodedWord starting at q.

        Children along the path are processed first; then q is replaced by
        an equivalent state already in the register, or registered itself.

        :return: the state that should take q's place.
        '''
        if encodedWord:
            label = encodedWord[0]
            q.setTransition(
                label,
                self._replaceOrRegister(q.getNext(label), encodedWord[1:]))
        if self.register.containsEquivalentState(q):
            return self.register.getEquivalentState(q)
        self.register.addState(q)
        return q

    def calculateOffsets(self, sizeCounter):
        '''Delegate offset calculation to the initial state (presumably for serialization; verify).'''
        self.initialState.calculateOffsets(sizeCounter)

    def debug(self):
        '''Call debug() on every state reachable from the initial state.'''
        for s in self.initialState.dfs(set()):
            s.debug()