Commit 03c175746e2e52f5f3ae9df7c917a8023f592c87
1 parent
6361f38f
praca nad budową automatu ze zmiennymi długościami offsetów
git-svn-id: svn://svn.nlp.ipipan.waw.pl/morfeusz/morfeusz@7 ff4e3ee1-f430-4e82-ade0-24591c43f1fd
Showing
5 changed files
with
148 additions
and
23 deletions
fsabuilder/fsa/buildfsa.py
... | ... | @@ -57,6 +57,9 @@ def parseOptions(): |
57 | 57 | action='store_true', |
58 | 58 | default=False, |
59 | 59 | help='visualize result') |
60 | + parser.add_option('--train-file', | |
61 | + dest='trainFile', | |
62 | + help='A text file used for training. Should contain words from some large corpus - one word in each line') | |
60 | 63 | |
61 | 64 | opts, args = parser.parse_args() |
62 | 65 | |
... | ... | @@ -89,26 +92,30 @@ def parseOptions(): |
89 | 92 | |
90 | 93 | def readEncodedInput(inputFile): |
91 | 94 | with codecs.open(inputFile, 'r', 'utf8') as f: |
92 | - for line in f.readlines(): | |
95 | + for line in f: | |
93 | 96 | word, interps = line.strip().split() |
94 | 97 | yield word, interps.split(u'|') |
95 | 98 | |
96 | 99 | def readPolimorfInput(inputFile, encoder): |
97 | 100 | with codecs.open(inputFile, 'r', 'utf8') as f: |
98 | - for entry in convertinput.convertPolimorf(f.readlines(), lambda (word, interp): encoder.word2SortKey(word)): | |
101 | + for entry in convertinput.convertPolimorf(f, lambda (word, interp): encoder.word2SortKey(word)): | |
99 | 102 | yield entry |
100 | 103 | |
101 | 104 | def readPlainInput(inputFile, encoder): |
102 | 105 | with codecs.open(inputFile, 'r', 'utf8') as f: |
103 | - for line in sorted(f.readlines(), key=encoder.word2SortKey): | |
106 | + for line in sorted(f, key=encoder.word2SortKey): | |
104 | 107 | word = line.strip() |
105 | 108 | yield word, '' |
106 | 109 | |
110 | +def readTrainData(trainFile): | |
111 | + with codecs.open(trainFile, 'r', 'utf8') as f: | |
112 | + for line in f: | |
113 | + yield line.strip() | |
114 | + | |
107 | 115 | if __name__ == '__main__': |
108 | 116 | opts = parseOptions() |
109 | 117 | encoder = encode.Encoder() |
110 | 118 | fsa = FSA(encoder) |
111 | - serializer = SimpleSerializer() | |
112 | 119 | |
113 | 120 | inputData = { |
114 | 121 | InputFormat.ENCODED: readEncodedInput(opts.inputFile), |
... | ... | @@ -117,13 +124,19 @@ if __name__ == '__main__': |
117 | 124 | }[opts.inputFormat] |
118 | 125 | |
119 | 126 | logging.info('feeding FSA with data ...') |
120 | - fsa.feed(inputData) | |
127 | + fsa.feed(inputData, appendZero=True) | |
128 | + if opts.trainFile: | |
129 | + logging.info('training with '+opts.trainFile+' ...') | |
130 | + fsa.train(readTrainData(opts.trainFile)) | |
131 | + logging.info('done training') | |
132 | + serializer = SimpleSerializer(fsa) | |
121 | 133 | logging.info('states num: '+str(fsa.getStatesNum())) |
122 | - | |
134 | + logging.info('accepting states num: '+str(len([s for s in fsa.initialState.dfs(set()) if s.isAccepting()]))) | |
135 | + logging.info('sink states num: '+str(len([s for s in fsa.initialState.dfs(set()) if len(s.transitionsMap.items()) == 0]))) | |
123 | 136 | { |
124 | 137 | OutputFormat.CPP: serializer.serialize2CppFile, |
125 | 138 | OutputFormat.BINARY: serializer.serialize2BinaryFile |
126 | - }[opts.outputFormat](fsa, opts.outputFile) | |
139 | + }[opts.outputFormat](opts.outputFile) | |
127 | 140 | |
128 | 141 | if opts.visualize: |
129 | 142 | Visualizer().visualize(fsa) |
... | ... |
fsabuilder/fsa/encode.py
... | ... | @@ -10,14 +10,18 @@ class Encoder(object): |
10 | 10 | ''' |
11 | 11 | |
12 | 12 | |
13 | - def __init__(self, encoding='utf8'): | |
13 | + def __init__(self, encoding='utf8', appendZero): | |
14 | 14 | ''' |
15 | 15 | Constructor |
16 | 16 | ''' |
17 | 17 | self.encoding = encoding |
18 | + self.appendZero = appendZero | |
18 | 19 | |
19 | 20 | def encodeWord(self, word): |
20 | - return bytearray(word, self.encoding) | |
21 | + res = bytearray(word, self.encoding) | |
22 | + if self.appendZero: | |
23 | + res.append(0) | |
24 | + return res | |
21 | 25 | |
22 | 26 | def encodeData(self, data): |
23 | 27 | return bytearray(u'|'.join(data).encode(self.encoding)) + bytearray([0]) |
... | ... |
fsabuilder/fsa/fsa.py
... | ... | @@ -21,6 +21,7 @@ class FSA(object): |
21 | 21 | self.encodedPrevWord = None |
22 | 22 | self.initialState = state.State() |
23 | 23 | self.register = register.Register() |
24 | + self.label2Freq = {0: float('inf')} | |
24 | 25 | |
25 | 26 | def tryToRecognize(self, word, addFreq=False): |
26 | 27 | return self.decodeData(self.initialState.tryToRecognize(self.encodeWord(word), addFreq)) |
... | ... | @@ -40,12 +41,21 @@ class FSA(object): |
40 | 41 | if n % 10000 == 0: |
41 | 42 | logging.info(word) |
42 | 43 | allWords.append(word) |
44 | + for label in encodedWord: | |
45 | + self.label2Freq[label] = self.label2Freq.get(label, 0) + 1 | |
43 | 46 | |
44 | 47 | self.initialState = self._replaceOrRegister(self.initialState, self.encodeWord(word)) |
45 | 48 | self.encodedPrevWord = None |
46 | 49 | |
47 | - for w in allWords: | |
48 | - self.tryToRecognize(w, True) | |
50 | +# for w in allWords: | |
51 | +# self.tryToRecognize(w, True) | |
52 | + | |
53 | + def train(self, trainData): | |
54 | + self.label2Freq = {0: float('inf')} | |
55 | + for word in trainData: | |
56 | + self.tryToRecognize(word, addFreq=True) | |
57 | + for label in self.encodeWord(word): | |
58 | + self.label2Freq[label] = self.label2Freq.get(label, 0) + 1 | |
49 | 59 | |
50 | 60 | def getStatesNum(self): |
51 | 61 | return self.register.getStatesNum() |
... | ... |
fsabuilder/fsa/serializer.py
... | ... | @@ -4,16 +4,22 @@ Created on Oct 20, 2013 |
4 | 4 | @author: mlenart |
5 | 5 | ''' |
6 | 6 | |
7 | +import logging | |
8 | + | |
7 | 9 | class Serializer(object): |
8 | 10 | |
9 | - def __init__(self): | |
10 | - pass | |
11 | + def __init__(self, fsa): | |
12 | + self.fsa = fsa | |
13 | + self.label2Count = {} | |
14 | + for state in self.fsa.initialState.dfs(): | |
15 | + for label, _ in state.transitionsMap.iteritems(): | |
16 | + self.label2Count[label] = self.label2Count.get(label, 0) + 1 | |
11 | 17 | |
12 | - def serialize2CppFile(self, fsa, fname): | |
18 | + def serialize2CppFile(self, fname): | |
13 | 19 | res = [] |
14 | - fsa.calculateOffsets(sizeCounter=lambda state: self.getStateSize(state)) | |
20 | + self.fsa.calculateOffsets(sizeCounter=lambda state: self.getStateSize(state)) | |
15 | 21 | res.append('const unsigned char DEFAULT_FSA[] = {') |
16 | - for idx, state in enumerate(sorted(fsa.initialState.dfs(set()), key=lambda state: state.offset)): | |
22 | + for idx, state in enumerate(sorted(self.fsa.initialState.dfs(set()), key=lambda state: state.offset)): | |
17 | 23 | res.append('// state '+str(idx)) |
18 | 24 | partRes = [] |
19 | 25 | for byte in self.state2bytearray(state): |
... | ... | @@ -24,10 +30,10 @@ class Serializer(object): |
24 | 30 | with open(fname, 'w') as f: |
25 | 31 | f.write('\n'.join(res)) |
26 | 32 | |
27 | - def serialize2BinaryFile(self, fsa, fname): | |
33 | + def serialize2BinaryFile(self, fname): | |
28 | 34 | res = bytearray() |
29 | - fsa.calculateOffsets(sizeCounter=lambda state: self.getStateSize(state)) | |
30 | - for state in sorted(fsa.initialState.dfs(set()), key=lambda state: state.offset): | |
35 | + self.fsa.calculateOffsets(sizeCounter=lambda state: self.getStateSize(state)) | |
36 | + for state in sorted(self.fsa.initialState.dfs(set()), key=lambda state: state.offset): | |
31 | 37 | # res.append('// state '+str(idx)) |
32 | 38 | res.extend(self.state2bytearray(state)) |
33 | 39 | with open(fname, 'wb') as f: |
... | ... | @@ -77,11 +83,103 @@ class SimpleSerializer(Serializer): |
77 | 83 | |
78 | 84 | def _transitionsData2bytearray(self, state): |
79 | 85 | res = bytearray() |
80 | - # must sort that strange way because it must be sorted according to char, not unsigned char | |
81 | - for byte, nextState in sorted(state.transitionsMap.iteritems(), key=lambda (_, state): -state.freq): | |
86 | +# logging.debug('next') | |
87 | + for byte, nextState in sorted(state.transitionsMap.iteritems(), key=lambda (label, next): (-next.freq, -self.label2Count[label])): | |
88 | +# logging.debug(byte) | |
82 | 89 | res.append(byte) |
83 | 90 | offset = nextState.offset |
84 | 91 | res.append(offset & 0x0000FF) |
85 | 92 | res.append((offset & 0x00FF00) >> 8) |
86 | 93 | res.append((offset & 0xFF0000) >> 16) |
87 | 94 | return res |
95 | + | |
96 | +class VLengthSerializer(Serializer): | |
97 | + | |
98 | + LAST_FLAG = 128 | |
99 | + | |
100 | + def __init__(self, fsa): | |
101 | + super(VLengthSerializer, self).__init__(fsa) | |
102 | + self.statesTable = list(reversed(fsa.dfs(set()))) | |
103 | + self.state2Index = dict([(state, idx) for (idx, state) in enumerate(self.statesTable)]) | |
104 | + | |
105 | + | |
106 | + def getStateSize(self, state): | |
107 | + return len(self.state2bytearray(state)) | |
108 | + | |
109 | + def getDataSize(self, state): | |
110 | + assert type(state.encodedData) == bytearray or not state.isAccepting() | |
111 | + return len(state.encodedData) if state.isAccepting() else 0 | |
112 | + | |
113 | + def state2bytearray(self, state): | |
114 | + res = bytearray() | |
115 | + res.extend(self._stateData2bytearray(state)) | |
116 | + res.extend(self._transitionsData2bytearray(state)) | |
117 | + return res | |
118 | + | |
119 | + def _stateData2bytearray(self, state): | |
120 | + res = bytearray() | |
121 | + if state.isAccepting(): | |
122 | + res.extend(state.encodedData) | |
123 | + return res | |
124 | + | |
125 | + def _transitionsData2bytearray(self, state): | |
126 | + res = bytearray() | |
127 | + sortedLabels = list(sorted(self.fsa.label2Freq.iteritems(), key=lambda label, freq: (-freq, label))) | |
128 | + label2Index = dict([(label, sortedLabels.index(label)) for label in sortedLabels][:30]) | |
129 | + transitions = sorted(state.transitionsMap.iteritems(), key=lambda (label, _): (-next.freq, -self.label2Count[label])) | |
130 | + thisIdx = self.state2Index[state] | |
131 | + | |
132 | + if len(transitions) == 0: | |
133 | + assert state.isAccepting() | |
134 | + return bytearray() | |
135 | + else: | |
136 | + offsetToStateAfterThis = 0 | |
137 | + stateAfterThis = self.statesTable[thisIdx + 1] | |
138 | + for reversedN, (label, nextState) in enumerate(reversed(transitions)): | |
139 | + assert nextState.reverseOffset is not None | |
140 | + n = len(transitions) - reversedN | |
141 | + | |
142 | + popularLabel = label2Index[label] < 31 | |
143 | + firstByte = (label2Index[label] + 1) if popularLabel else 0 | |
144 | + | |
145 | +# if state.isAccepting(): | |
146 | +# firstByte |= VLengthSerializer.ACCEPTING_FLAG | |
147 | + | |
148 | + last = len(transitions) == n | |
149 | + next = last and stateAfterThis == nextState | |
150 | + | |
151 | + if last: | |
152 | + firstByte |= VLengthSerializer.LAST_FLAG | |
153 | + | |
154 | + offsetSize = 0 | |
155 | + offset = 0 | |
156 | + if not next: | |
157 | + offsetSize = 1 | |
158 | + offset = (stateAfterThis.reverseOffset - nextState.reverseOffset) + offsetSize | |
159 | + if offset >= 256: | |
160 | + offset += 1 | |
161 | + offsetSize += 1 | |
162 | + if offset >= 256 * 256: | |
163 | + offset += 1 | |
164 | + offsetSize += 1 | |
165 | + assert offset < 256 * 256 * 256 #TODO - przerobić na jakiś porządny wyjątek | |
166 | + | |
167 | + firstByte |= (32 * offsetSize) | |
168 | + | |
169 | + res.append(firstByte) | |
170 | + if not popularLabel: | |
171 | + res.append(label) | |
172 | + if offsetSize >= 1: | |
173 | + res.append(offset & 0x0000FF) | |
174 | + if offsetSize >= 2: | |
175 | + res.append((offset & 0x00FF00) >> 8) | |
176 | + if offsetSize == 3: | |
177 | + res.append((offset & 0xFF0000) >> 16) | |
178 | + return res | |
179 | +# currReverseOffset = nextState.reverseOffset | |
180 | +# res.append(byte) | |
181 | +# offset = nextState.offset | |
182 | +# res.append(offset & 0x0000FF) | |
183 | +# res.append((offset & 0x00FF00) >> 8) | |
184 | +# res.append((offset & 0xFF0000) >> 16) | |
185 | +# return res | |
88 | 186 | \ No newline at end of file |
... | ... |
fsabuilder/fsa/state.py
... | ... | @@ -43,9 +43,9 @@ class State(object): |
43 | 43 | else: |
44 | 44 | return self.encodedData |
45 | 45 | |
46 | - def dfs(self, alreadyVisited): | |
46 | + def dfs(self, alreadyVisited=set(), sortKey=lambda (_, state): -state.freq): | |
47 | 47 | if not self in alreadyVisited: |
48 | - for _, state in sorted(self.transitionsMap.iteritems(), key=lambda (_, state): -state.freq): | |
48 | + for _, state in sorted(self.transitionsMap.iteritems(), key=sortKey): | |
49 | 49 | for state1 in state.dfs(alreadyVisited): |
50 | 50 | yield state1 |
51 | 51 | alreadyVisited.add(self) |
... | ... |