fsa.py
4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
'''
Created on Oct 8, 2013
@author: mlenart
'''
import state
import register
import logging
class FSA(object):
    '''
    A finite state automaton built incrementally from (word, data) pairs
    whose encoded words arrive in sorted order; states are minimized on the
    fly through a register of equivalent states.
    '''

    def __init__(self, encoder):
        '''
        @param encoder: object providing encodeWord, encodeData and
            decodeData; the three bound methods are cached on the instance.
        '''
        self.encodeWord = encoder.encodeWord
        self.encodeData = encoder.encodeData
        self.decodeData = encoder.decodeData
        # Last encoded word added by feed(); None before the first word and
        # after feed() returns.
        self.encodedPrevWord = None
        self.initialState = state.State()
        self.register = register.Register()
        # Label -> occurrence count. Label 0 is pinned to infinity
        # (presumably so it always ranks as the most frequent label; confirm).
        self.label2Freq = {0: float('inf')}

    def tryToRecognize(self, word, addFreq=False):
        '''
        Encode word, run it through the automaton and decode the result.

        @param word: word to look up.
        @param addFreq: forwarded to State.tryToRecognize (train() passes True).
        @return: self.decodeData applied to what the initial state's
            tryToRecognize returns for the encoded word.
        '''
        return self.decodeData(
            self.initialState.tryToRecognize(self.encodeWord(word), addFreq))

    def feed(self, input):
        '''
        Build the automaton from (word, data) pairs.

        Encoded words must be non-decreasing; a pair whose encoded word equals
        the previous one is skipped. When the input is exhausted the path of
        the last word is minimized, so the automaton is complete on return.

        @param input: iterable of (word, data) pairs; data must not be None.
            A str/unicode data value is wrapped into a one-element list.
        '''
        for n, (word, data) in enumerate(input, start=1):
            assert data is not None
            if isinstance(data, (str, unicode)):
                data = [data]
            encodedWord = self.encodeWord(word)
            # The incremental construction requires sorted input.
            # NOTE: asserts are stripped under -O, where unsorted input would
            # silently produce a wrong automaton.
            assert encodedWord >= self.encodedPrevWord
            if encodedWord > self.encodedPrevWord:
                self._addSorted(encodedWord, self.encodeData(data))
                self.encodedPrevWord = encodedWord
                assert self.tryToRecognize(word) == data
                if n % 10000 == 0:
                    logging.info(word)
                for label in encodedWord:
                    self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
        # Minimize the last word's path. Using encodedPrevWord instead of the
        # loop variable avoids an unbound `word` on empty input; `or []`
        # turns the "nothing added" None into an empty path.
        self.initialState = self._replaceOrRegister(
            self.initialState, self.encodedPrevWord or [])
        self.encodedPrevWord = None

    def train(self, trainData):
        '''
        Reset label frequencies and recount them from trainData, also walking
        each word through the automaton with addFreq=True.

        @param trainData: iterable of words.
        '''
        self.label2Freq = {0: float('inf')}
        for idx, word in enumerate(trainData):
            self.tryToRecognize(word, addFreq=True)
            for label in self.encodeWord(word):
                self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
            if idx % 100000 == 0:
                logging.info(str(idx))

    def dfs(self):
        '''Yield the states produced by the initial state's dfs traversal.'''
        for s in self.initialState.dfs(set()):
            yield s

    def getStatesNum(self):
        '''@return: number of states held by the register.'''
        return self.register.getStatesNum()

    def getTransitionsNum(self):
        '''@return: total size of transitionsMap over all traversed states.'''
        return sum(len(s.transitionsMap) for s in self.initialState.dfs(set()))

    def _addSorted(self, encodedWord, data):
        '''
        Add encodedWord, which must be strictly greater than the previous
        word, and attach data to its final state.

        Follows the longest prefix of encodedWord already in the automaton,
        minimizes the branch of the previous word that diverges there, then
        appends fresh states for the remaining suffix.
        '''
        assert self.encodedPrevWord < encodedWord
        q = self.initialState
        i = 0
        # `<`, not `<=`: when every label of encodedWord already has a
        # transition (e.g. an empty encoded word), `<=` indexed past the end.
        while i < len(encodedWord) and q.hasNext(encodedWord[i]):
            q = q.getNext(encodedWord[i])
            i += 1
        prev = self.encodedPrevWord
        if prev and i < len(prev):
            # Later (larger) words cannot pass through prev[i] at this point,
            # so the previous word's remaining path can be minimized now.
            nextState = q.getNext(prev[i])
            q.setTransition(
                prev[i],
                self._replaceOrRegister(nextState, prev[i + 1:]))
        while i < len(encodedWord):
            q.setTransition(encodedWord[i], state.State())
            q = q.getNext(encodedWord[i])
            i += 1
        assert q.encodedData is None
        q.encodedData = data

    def _replaceOrRegister(self, q, encodedWord):
        '''
        Minimize the path spelled by encodedWord starting at q, deepest
        state first.

        Each state on the path is swapped for an equivalent registered state
        if one exists, otherwise it is added to the register.

        @return: the state to use in place of q.
        '''
        if encodedWord:
            label = encodedWord[0]
            q.setTransition(
                label,
                self._replaceOrRegister(q.getNext(label), encodedWord[1:]))
        if self.register.containsEquivalentState(q):
            return self.register.getEquivalentState(q)
        self.register.addState(q)
        return q

    def calculateOffsets(self, sizeCounter):
        '''
        Set reverseOffset and offset on every traversed state.

        reverseOffset is the running total of sizeCounter(s) in traversal
        order, including s itself; offset is the grand total minus
        reverseOffset.

        @param sizeCounter: callable mapping a state to its size.
        '''
        # One traversal, reused for both passes (the original walked twice).
        states = list(self.initialState.dfs(set()))
        total = 0
        for s in states:
            total += sizeCounter(s)
            s.reverseOffset = total
        for s in states:
            s.offset = total - s.reverseOffset