Commit 03c175746e2e52f5f3ae9df7c917a8023f592c87

Authored by Michał Lenart
1 parent 6361f38f

praca nad budową automatu ze zmiennymi długościami offsetów

git-svn-id: svn://svn.nlp.ipipan.waw.pl/morfeusz/morfeusz@7 ff4e3ee1-f430-4e82-ade0-24591c43f1fd
fsabuilder/fsa/buildfsa.py
... ... @@ -57,6 +57,9 @@ def parseOptions():
57 57 action='store_true',
58 58 default=False,
59 59 help='visualize result')
  60 + parser.add_option('--train-file',
  61 + dest='trainFile',
  62 + help='A text file used for training. Should contain words from some large corpus - one word in each line')
60 63  
61 64 opts, args = parser.parse_args()
62 65  
... ... @@ -89,26 +92,30 @@ def parseOptions():
89 92  
90 93 def readEncodedInput(inputFile):
91 94 with codecs.open(inputFile, 'r', 'utf8') as f:
92   - for line in f.readlines():
  95 + for line in f:
93 96 word, interps = line.strip().split()
94 97 yield word, interps.split(u'|')
95 98  
96 99 def readPolimorfInput(inputFile, encoder):
97 100 with codecs.open(inputFile, 'r', 'utf8') as f:
98   - for entry in convertinput.convertPolimorf(f.readlines(), lambda (word, interp): encoder.word2SortKey(word)):
  101 + for entry in convertinput.convertPolimorf(f, lambda (word, interp): encoder.word2SortKey(word)):
99 102 yield entry
100 103  
101 104 def readPlainInput(inputFile, encoder):
102 105 with codecs.open(inputFile, 'r', 'utf8') as f:
103   - for line in sorted(f.readlines(), key=encoder.word2SortKey):
  106 + for line in sorted(f, key=encoder.word2SortKey):
104 107 word = line.strip()
105 108 yield word, ''
106 109  
  110 +def readTrainData(trainFile):
  111 + with codecs.open(trainFile, 'r', 'utf8') as f:
  112 + for line in f:
  113 + yield line.strip()
  114 +
107 115 if __name__ == '__main__':
108 116 opts = parseOptions()
109 117 encoder = encode.Encoder()
110 118 fsa = FSA(encoder)
111   - serializer = SimpleSerializer()
112 119  
113 120 inputData = {
114 121 InputFormat.ENCODED: readEncodedInput(opts.inputFile),
... ... @@ -117,13 +124,19 @@ if __name__ == '__main__':
117 124 }[opts.inputFormat]
118 125  
119 126 logging.info('feeding FSA with data ...')
120   - fsa.feed(inputData)
  127 + fsa.feed(inputData, appendZero=True)
  128 + if opts.trainFile:
  129 + logging.info('training with '+opts.trainFile+' ...')
  130 + fsa.train(readTrainData(opts.trainFile))
  131 + logging.info('done training')
  132 + serializer = SimpleSerializer(fsa)
121 133 logging.info('states num: '+str(fsa.getStatesNum()))
122   -
  134 + logging.info('accepting states num: '+str(len([s for s in fsa.initialState.dfs(set()) if s.isAccepting()])))
  135 + logging.info('sink states num: '+str(len([s for s in fsa.initialState.dfs(set()) if len(s.transitionsMap.items()) == 0])))
123 136 {
124 137 OutputFormat.CPP: serializer.serialize2CppFile,
125 138 OutputFormat.BINARY: serializer.serialize2BinaryFile
126   - }[opts.outputFormat](fsa, opts.outputFile)
  139 + }[opts.outputFormat](opts.outputFile)
127 140  
128 141 if opts.visualize:
129 142 Visualizer().visualize(fsa)
... ...
fsabuilder/fsa/encode.py
... ... @@ -10,14 +10,18 @@ class Encoder(object):
10 10 '''
11 11  
12 12  
13   - def __init__(self, encoding='utf8'):
  13 + def __init__(self, encoding='utf8', appendZero):
14 14 '''
15 15 Constructor
16 16 '''
17 17 self.encoding = encoding
  18 + self.appendZero = appendZero
18 19  
19 20 def encodeWord(self, word):
20   - return bytearray(word, self.encoding)
  21 + res = bytearray(word, self.encoding)
  22 + if self.appendZero:
  23 + res.append(0)
  24 + return res
21 25  
22 26 def encodeData(self, data):
23 27 return bytearray(u'|'.join(data).encode(self.encoding)) + bytearray([0])
... ...
fsabuilder/fsa/fsa.py
... ... @@ -21,6 +21,7 @@ class FSA(object):
21 21 self.encodedPrevWord = None
22 22 self.initialState = state.State()
23 23 self.register = register.Register()
  24 + self.label2Freq = {0: float('inf')}
24 25  
25 26 def tryToRecognize(self, word, addFreq=False):
26 27 return self.decodeData(self.initialState.tryToRecognize(self.encodeWord(word), addFreq))
... ... @@ -40,12 +41,21 @@ class FSA(object):
40 41 if n % 10000 == 0:
41 42 logging.info(word)
42 43 allWords.append(word)
  44 + for label in encodedWord:
  45 + self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
43 46  
44 47 self.initialState = self._replaceOrRegister(self.initialState, self.encodeWord(word))
45 48 self.encodedPrevWord = None
46 49  
47   - for w in allWords:
48   - self.tryToRecognize(w, True)
  50 +# for w in allWords:
  51 +# self.tryToRecognize(w, True)
  52 +
  53 + def train(self, trainData):
  54 + self.label2Freq = {0: float('inf')}
  55 + for word in trainData:
  56 + self.tryToRecognize(word, addFreq=True)
  57 + for label in self.encodeWord(word):
  58 + self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
49 59  
50 60 def getStatesNum(self):
51 61 return self.register.getStatesNum()
... ...
fsabuilder/fsa/serializer.py
... ... @@ -4,16 +4,22 @@ Created on Oct 20, 2013
4 4 @author: mlenart
5 5 '''
6 6  
  7 +import logging
  8 +
7 9 class Serializer(object):
8 10  
9   - def __init__(self):
10   - pass
  11 + def __init__(self, fsa):
  12 + self.fsa = fsa
  13 + self.label2Count = {}
  14 + for state in self.fsa.initialState.dfs():
  15 + for label, _ in state.transitionsMap.iteritems():
  16 + self.label2Count[label] = self.label2Count.get(label, 0) + 1
11 17  
12   - def serialize2CppFile(self, fsa, fname):
  18 + def serialize2CppFile(self, fname):
13 19 res = []
14   - fsa.calculateOffsets(sizeCounter=lambda state: self.getStateSize(state))
  20 + self.fsa.calculateOffsets(sizeCounter=lambda state: self.getStateSize(state))
15 21 res.append('const unsigned char DEFAULT_FSA[] = {')
16   - for idx, state in enumerate(sorted(fsa.initialState.dfs(set()), key=lambda state: state.offset)):
  22 + for idx, state in enumerate(sorted(self.fsa.initialState.dfs(set()), key=lambda state: state.offset)):
17 23 res.append('// state '+str(idx))
18 24 partRes = []
19 25 for byte in self.state2bytearray(state):
... ... @@ -24,10 +30,10 @@ class Serializer(object):
24 30 with open(fname, 'w') as f:
25 31 f.write('\n'.join(res))
26 32  
27   - def serialize2BinaryFile(self, fsa, fname):
  33 + def serialize2BinaryFile(self, fname):
28 34 res = bytearray()
29   - fsa.calculateOffsets(sizeCounter=lambda state: self.getStateSize(state))
30   - for state in sorted(fsa.initialState.dfs(set()), key=lambda state: state.offset):
  35 + self.fsa.calculateOffsets(sizeCounter=lambda state: self.getStateSize(state))
  36 + for state in sorted(self.fsa.initialState.dfs(set()), key=lambda state: state.offset):
31 37 # res.append('// state '+str(idx))
32 38 res.extend(self.state2bytearray(state))
33 39 with open(fname, 'wb') as f:
... ... @@ -77,11 +83,103 @@ class SimpleSerializer(Serializer):
77 83  
78 84 def _transitionsData2bytearray(self, state):
79 85 res = bytearray()
80   - # must sort that strange way because it must be sorted according to char, not unsigned char
81   - for byte, nextState in sorted(state.transitionsMap.iteritems(), key=lambda (_, state): -state.freq):
  86 +# logging.debug('next')
  87 + for byte, nextState in sorted(state.transitionsMap.iteritems(), key=lambda (label, next): (-next.freq, -self.label2Count[label])):
  88 +# logging.debug(byte)
82 89 res.append(byte)
83 90 offset = nextState.offset
84 91 res.append(offset & 0x0000FF)
85 92 res.append((offset & 0x00FF00) >> 8)
86 93 res.append((offset & 0xFF0000) >> 16)
87 94 return res
  95 +
  96 +class VLengthSerializer(Serializer):
  97 +
  98 + LAST_FLAG = 128
  99 +
  100 + def __init__(self, fsa):
  101 + super(VLengthSerializer, self).__init__(fsa)
  102 + self.statesTable = list(reversed(fsa.dfs(set())))
  103 + self.state2Index = dict([(state, idx) for (idx, state) in enumerate(self.statesTable)])
  104 +
  105 +
  106 + def getStateSize(self, state):
  107 + return len(self.state2bytearray(state))
  108 +
  109 + def getDataSize(self, state):
  110 + assert type(state.encodedData) == bytearray or not state.isAccepting()
  111 + return len(state.encodedData) if state.isAccepting() else 0
  112 +
  113 + def state2bytearray(self, state):
  114 + res = bytearray()
  115 + res.extend(self._stateData2bytearray(state))
  116 + res.extend(self._transitionsData2bytearray(state))
  117 + return res
  118 +
  119 + def _stateData2bytearray(self, state):
  120 + res = bytearray()
  121 + if state.isAccepting():
  122 + res.extend(state.encodedData)
  123 + return res
  124 +
  125 + def _transitionsData2bytearray(self, state):
  126 + res = bytearray()
  127 + sortedLabels = list(sorted(self.fsa.label2Freq.iteritems(), key=lambda label, freq: (-freq, label)))
  128 + label2Index = dict([(label, sortedLabels.index(label)) for label in sortedLabels][:30])
  129 + transitions = sorted(state.transitionsMap.iteritems(), key=lambda (label, _): (-next.freq, -self.label2Count[label]))
  130 + thisIdx = self.state2Index[state]
  131 +
  132 + if len(transitions) == 0:
  133 + assert state.isAccepting()
  134 + return bytearray()
  135 + else:
  136 + offsetToStateAfterThis = 0
  137 + stateAfterThis = self.statesTable[thisIdx + 1]
  138 + for reversedN, (label, nextState) in enumerate(reversed(transitions)):
  139 + assert nextState.reverseOffset is not None
  140 + n = len(transitions) - reversedN
  141 +
  142 + popularLabel = label2Index[label] < 31
  143 + firstByte = (label2Index[label] + 1) if popularLabel else 0
  144 +
  145 +# if state.isAccepting():
  146 +# firstByte |= VLengthSerializer.ACCEPTING_FLAG
  147 +
  148 + last = len(transitions) == n
  149 + next = last and stateAfterThis == nextState
  150 +
  151 + if last:
  152 + firstByte |= VLengthSerializer.LAST_FLAG
  153 +
  154 + offsetSize = 0
  155 + offset = 0
  156 + if not next:
  157 + offsetSize = 1
  158 + offset = (stateAfterThis.reverseOffset - nextState.reverseOffset) + offsetSize
  159 + if offset >= 256:
  160 + offset += 1
  161 + offsetSize += 1
  162 + if offset >= 256 * 256:
  163 + offset += 1
  164 + offsetSize += 1
  165 + assert offset < 256 * 256 * 256 #TODO - przerobić na jakiś porządny wyjątek
  166 +
  167 + firstByte |= (32 * offsetSize)
  168 +
  169 + res.append(firstByte)
  170 + if not popularLabel:
  171 + res.append(label)
  172 + if offsetSize >= 1:
  173 + res.append(offset & 0x0000FF)
  174 + if offsetSize >= 2:
  175 + res.append((offset & 0x00FF00) >> 8)
  176 + if offsetSize == 3:
  177 + res.append((offset & 0xFF0000) >> 16)
  178 + return res
  179 +# currReverseOffset = nextState.reverseOffset
  180 +# res.append(byte)
  181 +# offset = nextState.offset
  182 +# res.append(offset & 0x0000FF)
  183 +# res.append((offset & 0x00FF00) >> 8)
  184 +# res.append((offset & 0xFF0000) >> 16)
  185 +# return res
88 186 \ No newline at end of file
... ...
fsabuilder/fsa/state.py
... ... @@ -43,9 +43,9 @@ class State(object):
43 43 else:
44 44 return self.encodedData
45 45  
46   - def dfs(self, alreadyVisited):
  46 + def dfs(self, alreadyVisited=set(), sortKey=lambda (_, state): -state.freq):
47 47 if not self in alreadyVisited:
48   - for _, state in sorted(self.transitionsMap.iteritems(), key=lambda (_, state): -state.freq):
  48 + for _, state in sorted(self.transitionsMap.iteritems(), key=sortKey):
49 49 for state1 in state.dfs(alreadyVisited):
50 50 yield state1
51 51 alreadyVisited.add(self)
... ...