|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
'''
Created on Oct 8, 2013
@author: mlenart
'''
import state
import register
import logging
class FSA(object):
    '''
    A finite state automaton built incrementally from (word, data) pairs
    whose encoded words arrive in strictly increasing order.

    Construction follows the replace-or-register scheme for minimal acyclic
    automata (Daciuk et al.). Each new word is appended after the prefix it
    shares with the previous word. The part of the previous word's path
    that can no longer change is then merged into equivalent states kept
    in self.register.
    '''

    def __init__(self, encoder, tagset=None):
        '''
        @param encoder: object providing the encodeWord, encodeData and
            decodeData callables used to convert words and their data.
        @param tagset: stored as self.tagset; not used inside this class.
        '''
        self.encodeWord = encoder.encodeWord
        self.encodeData = encoder.encodeData
        self.decodeData = encoder.decodeData
        # Encoded form of the most recently added word; None outside feed().
        self.encodedPrevWord = None

        self.tagset = tagset
        self.initialState = state.State()
        # States already merged into the minimized part of the automaton.
        self.register = register.Register()
        # Encoded label -> number of occurrences in the words seen by
        # feed() / train().
        self.label2Freq = {}

    def tryToRecognize(self, word, addFreq=False):
        '''
        Encode word, look it up from the initial state and return the
        result passed through decodeData.

        What comes back for unknown words depends on State.tryToRecognize
        and decodeData (presumably None -- confirm in state.py).

        @param addFreq: forwarded to State.tryToRecognize. train() passes
            True, presumably to count state usage.
        '''
        encodedWord = self.encodeWord(word)
        return self.decodeData(self.initialState.tryToRecognize(encodedWord, addFreq))

    def feed(self, input):
        '''
        Build the automaton from an iterable of (word, data) pairs.

        Encoded words must be strictly increasing (which also means
        unique), and data must not be None. After the last pair, the path
        of that word is minimized and encodedPrevWord is reset to None.
        An empty iterable leaves the automaton unchanged.

        @raise ValueError: if data is None or the words are not strictly
            increasing.
        '''
        for n, (word, data) in enumerate(input, start=1):
            if data is None:
                raise ValueError('data for word %r is None' % (word,))
            encodedWord = self.encodeWord(word)
            # Explicit check rather than assert: asserts are stripped under
            # "python -O", and ordering against None fails on Python 3.
            if self.encodedPrevWord is not None and not encodedWord > self.encodedPrevWord:
                raise ValueError('words must be unique and sorted; got %r out of order' % (word,))
            self._addSorted(encodedWord, self.encodeData(data))
            self.encodedPrevWord = encodedWord

            if n % 10000 == 0:
                logging.info(word)
                logging.info(str(self.register.getStatesNum()))

            for label in encodedWord:
                self.label2Freq[label] = self.label2Freq.get(label, 0) + 1

        # _addSorted never minimizes the path of the last word, so do it here.
        # Use the stored encoding: the loop variable is unbound for empty input.
        if self.encodedPrevWord is not None:
            self.initialState = self._replaceOrRegister(self.initialState, self.encodedPrevWord)
        self.encodedPrevWord = None

    def train(self, trainData):
        '''
        Replay every word of trainData through tryToRecognize(addFreq=True).
        Also rebuild label2Freq from scratch over those words.
        '''
        self.label2Freq = {}
        for idx, word in enumerate(trainData):
            self.tryToRecognize(word, addFreq=True)
            for label in self.encodeWord(word):
                self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
            if idx % 100000 == 0:
                logging.info(str(idx))

    def dfs(self):
        '''Yield the states produced by initialState.dfs, in that order.'''
        for s in self.initialState.dfs(set()):
            yield s

    def getStatesNum(self):
        '''Return the number of registered (minimized) states.'''
        return self.register.getStatesNum()

    def getTransitionsNum(self):
        '''Return the total number of transitions over all reachable states.'''
        return sum(len(s.transitionsMap) for s in self.initialState.dfs(set()))

    def _addSorted(self, encodedWord, data):
        '''
        Add one encoded word, which must be greater than the previous one,
        together with its already encoded data.

        The method proceeds in three steps:
        1. Follow the longest existing prefix of encodedWord from the
           initial state.
        2. Minimize the rest of the previous word's path below that point.
           Because input is sorted, that part can no longer gain
           transitions.
        3. Create fresh states for the remaining labels of encodedWord and
           attach data to the final one.
        '''
        assert self.encodedPrevWord is None or self.encodedPrevWord < encodedWord
        q = self.initialState
        i = 0
        # The bound must be "<": indexing at len(encodedWord) would raise
        # IndexError (e.g. for an empty word).
        while i < len(encodedWord) and q.hasNext(encodedWord[i]):
            q = q.getNext(encodedWord[i])
            i += 1
        prev = self.encodedPrevWord
        if prev and i < len(prev):
            nextState = q.getNext(prev[i])
            q.setTransition(prev[i], self._replaceOrRegister(nextState, prev[i + 1:]))
        # Append new states for the part of encodedWord not yet in the automaton.
        while i < len(encodedWord):
            q.setTransition(encodedWord[i], state.State())
            q = q.getNext(encodedWord[i])
            i += 1
        assert q.encodedData is None
        q.encodedData = data

    def _replaceOrRegister(self, q, encodedWord):
        '''
        Minimize the path spelling encodedWord below state q, bottom-up.

        The successor along encodedWord is processed first (recursively).
        Then q is either replaced by an equivalent registered state or
        registered itself.

        @return: the state that should stand in for q.
        '''
        if encodedWord:
            label = encodedWord[0]
            child = q.getNext(label)
            q.setTransition(label, self._replaceOrRegister(child, encodedWord[1:]))
        if self.register.containsEquivalentState(q):
            return self.register.getEquivalentState(q)
        self.register.addState(q)
        return q

    def calculateOffsets(self, sizeCounter):
        '''
        Set s.reverseOffset and s.offset on every state yielded by
        initialState.dfs, using sizeCounter(s) as the size of each state.

        - reverseOffset is the running size total, in dfs order, up to and
          including s.
        - offset is the total size of the states yielded after s, so the
          last-yielded state gets offset 0.

        Assumes State.dfs yields the same order on both passes -- confirm.
        '''
        total = 0
        for s in self.initialState.dfs(set()):
            total += sizeCounter(s)
            s.reverseOffset = total
        for s in self.initialState.dfs(set()):
            s.offset = total - s.reverseOffset
|