Blame view

fsabuilder/morfeuszbuilder/fsa/serializer.py 16.9 KB
Michał Lenart authored
1
2
3
4
5
6
'''
Created on Oct 20, 2013

@author: mlenart
'''
Michał Lenart authored
7
import logging
Michał Lenart authored
8
from state import State
Michał Lenart authored
9
from morfeuszbuilder.utils.serializationUtils import *
Michał Lenart authored
10
Michał Lenart authored
11
class Serializer(object):
    """Abstract base for serializers turning an FSA into Morfeusz binary data.

    Byte stream produced by fsa2bytearray:

    * prologue (serializePrologue): magic number (4 bytes, big-endian),
      format version (1 byte) and implementation code (1 byte);
    * size of the FSA section (written with ``htonl``) followed by the FSA
      section itself: the subclass-specific FSA prologue, then every state
      ordered by its offset;
    * epilogue (serializeEpilogue): tagset and segmentation rules data.

    Subclasses implement getImplementationCode, serializeFSAPrologue,
    getStateSize, stateData2bytearray and transitionsData2bytearray.
    """

    MAGIC_NUMBER = 0x8fc2bc1b

    def __init__(self, fsa, headerFilename="data/default_fsa.hpp"):
        """
        :param fsa: automaton to serialize
        :param headerFilename: header #included by the C++ file written by
            serialize2CppFile
        """
        self._fsa = fsa
        self.headerFilename = headerFilename

    @property
    def fsa(self):
        """The automaton being serialized."""
        return self._fsa

    def getVersion(self):
        """Return the Morfeusz file format version that is being encoded."""
        return 10

    def serialize2CppFile(self, fname, generator, segmentationRulesData):
        """Write the serialized automaton to *fname* as a C++ byte array.

        The array is named DEFAULT_SYNTH_FSA when *generator* is true and
        DEFAULT_FSA otherwise. The tagset is taken from ``self.fsa.tagset``.
        """
        arrayName = 'DEFAULT_SYNTH_FSA' if generator else 'DEFAULT_FSA'
        res = ['\n', '#include "%s"' % self.headerFilename, '\n', '\n',
               'extern const unsigned char %s[] = {' % arrayName, '\n']
        data = self.fsa2bytearray(
            tagsetData=self.serializeTagset(self.fsa.tagset),
            segmentationRulesData=segmentationRulesData)
        for byte in data:
            res.append(hex(byte))
            res.append(',')
        res.extend(['\n', '};', '\n'])
        with open(fname, 'w') as f:
            f.write(''.join(res))

    def serialize2BinaryFile(self, fname, segmentationRulesData):
        """Write the serialized automaton (with the tagset of self.fsa) to
        *fname* in binary mode."""
        with open(fname, 'wb') as f:
            f.write(self.fsa2bytearray(
                tagsetData=self.serializeTagset(self.fsa.tagset),
                segmentationRulesData=segmentationRulesData))

    def getStateSize(self, state):
        """Return the serialized size of *state* in bytes.

        Used as the size counter passed to fsa.calculateOffsets.
        """
        raise NotImplementedError('Not implemented')

    def fsa2bytearray(self, tagsetData, segmentationRulesData):
        """Serialize the whole automaton plus the additional data.

        :param tagsetData: bytearray produced by serializeTagset (may be empty)
        :param segmentationRulesData: bytearray or None, appended last
        :return: bytearray with prologue, sized FSA section and epilogue
        """
        res = bytearray()
        res.extend(self.serializePrologue())
        fsaData = bytearray()
        # The FSA prologue is serialized before any state: VLengthSerializer1
        # builds its short-label table there, and state serialization uses it.
        fsaData.extend(self.serializeFSAPrologue())
        self.fsa.calculateOffsets(sizeCounter=lambda state: self.getStateSize(state))
        for state in sorted(self.fsa.dfs(), key=lambda s: s.offset):
            fsaData.extend(self.state2bytearray(state))
        res.extend(htonl(len(fsaData)))
        res.extend(fsaData)
        res.extend(self.serializeEpilogue(tagsetData, segmentationRulesData))
        return res

    def _serializeTags(self, tagsMap):
        """Serialize a name -> number map.

        Layout: entry count (``htons``), then for each entry in ascending
        number order: its number (``htons``), the name encoded with
        fsa.encodeWord and a terminating zero byte.
        """
        res = bytearray()
        res.extend(htons(len(tagsMap)))
        for tag, tagnum in sorted(tagsMap.items(), key=lambda item: item[1]):
            res.extend(htons(tagnum))
            res.extend(self.fsa.encodeWord(tag))
            res.append(0)
        return res

    def serializeTagset(self, tagset):
        """Serialize tagset data: tags map followed by names map.

        Returns an empty bytearray when *tagset* is falsy (e.g. None).
        """
        res = bytearray()
        if tagset:
            res.extend(self._serializeTags(tagset._tag2tagnum))
            res.extend(self._serializeTags(tagset._name2namenum))
        return res

    def serializePrologue(self):
        """Return magic number (4 bytes, big-endian), version and
        implementation code (1 byte each)."""
        res = bytearray()

        # serialize magic number in big-endian order
        for shift in (24, 16, 8, 0):
            res.append((Serializer.MAGIC_NUMBER >> shift) & 0xFF)

        # serialize version number
        res.append(self.getVersion())

        # serialize implementation code
        res.append(self.getImplementationCode())

        return res

    def serializeEpilogue(self, tagsetData, segmentationRulesData):
        """Serialize the data stored after the FSA section.

        Layout: size of *tagsetData* (``htonl``), *tagsetData* itself, then
        *segmentationRulesData*. Only the tagset size is written. The
        segmentation data is not length-prefixed here (presumably the reader
        consumes it up to the end of the buffer - TODO confirm). Empty or None
        data adds no bytes, but the tagset size field is always written.
        """
        res = bytearray()
        tagsetDataSize = len(tagsetData) if tagsetData else 0
        res.extend(htonl(tagsetDataSize))

        # add additional data itself
        if tagsetDataSize:
            assert isinstance(tagsetData, bytearray)
            res.extend(tagsetData)

        if segmentationRulesData:
            assert isinstance(segmentationRulesData, bytearray)
            res.extend(segmentationRulesData)

        return res

    def state2bytearray(self, state):
        """Serialize one state: its own data followed by its transitions."""
        res = bytearray()
        res.extend(self.stateData2bytearray(state))
        res.extend(self.transitionsData2bytearray(state))
        return res

    def getSortedTransitions(self, state):
        """Return state's (label, nextState) pairs, most frequent labels first.

        Sort key: state.label2Freq[label], then self.fsa.label2Freq[label],
        both descending. Labels missing from a table count as 0, and ties keep
        the transitionsMap iteration order (the sort is stable).
        """
        def frequencyKey(transition):
            label = transition[0]
            return (-state.label2Freq.get(label, 0),
                    -self.fsa.label2Freq.get(label, 0))
        return sorted(state.transitionsMap.items(), key=frequencyKey)

    def serializeFSAPrologue(self):
        """Return the implementation-specific bytes placed before the states."""
        raise NotImplementedError('Not implemented')

    def stateData2bytearray(self, state):
        """Return the serialized header/data part of *state*."""
        raise NotImplementedError('Not implemented')

    def transitionsData2bytearray(self, state):
        """Return the serialized transitions of *state*."""
        raise NotImplementedError('Not implemented')

    def getImplementationCode(self):
        """Return the byte identifying the serializer implementation."""
        raise NotImplementedError('Not implemented')

class SimpleSerializer(Serializer):
    """Fixed-size transition serializer.

    Implementation code is 0, or 128 when per-transition data is serialized.

    State layout: a header byte (ACCEPTING_FLAG | transition count), the
    encoded data of an accepting state, then for every transition its label
    byte, the target state's offset as 3 big-endian bytes and - when
    serializeTransitionsData is set - one byte from state.transitionsDataMap.
    """

    def __init__(self, fsa, serializeTransitionsData=False):
        super(SimpleSerializer, self).__init__(fsa)
        self.ACCEPTING_FLAG = 128
        self.serializeTransitionsData = serializeTransitionsData

    def getImplementationCode(self):
        return 128 if self.serializeTransitionsData else 0

    def serializeFSAPrologue(self):
        # this format has no FSA-level prologue
        return bytearray()

    def getStateSize(self, state):
        # header byte + per transition (label + 3 offset bytes [+ 1 data byte])
        # + data of an accepting state
        bytesPerTransition = 5 if self.serializeTransitionsData else 4
        return 1 + bytesPerTransition * len(state.transitionsMap) + self.getDataSize(state)

    def getDataSize(self, state):
        return len(state.encodedData) if state.isAccepting() else 0

    def stateData2bytearray(self, state):
        # The transition count shares the header byte with ACCEPTING_FLAG. A
        # count of 128 or more would silently overwrite the flag bit, so reject
        # it explicitly (an assert would be stripped under -O).
        if not 0 <= state.transitionsNum < self.ACCEPTING_FLAG:
            raise ValueError('too many transitions for this format: %r' % (state.transitionsNum,))
        firstByte = state.transitionsNum
        if state.isAccepting():
            firstByte |= self.ACCEPTING_FLAG
        assert 0 < firstByte < 256
        res = bytearray()
        res.append(firstByte)
        if state.isAccepting():
            res.extend(state.encodedData)
        return res

    def transitionsData2bytearray(self, state):
        res = bytearray()
        for label, nextState in self.getSortedTransitions(state):
            offset = nextState.offset
            # Offsets get exactly 3 bytes; masking would silently truncate a
            # larger offset and make the transition point at the wrong state.
            if not 0 <= offset < (1 << 24):
                raise ValueError('state offset does not fit in 3 bytes: %r' % (offset,))
            res.append(label)
            res.append((offset >> 16) & 0xFF)
            res.append((offset >> 8) & 0xFF)
            res.append(offset & 0xFF)
            if self.serializeTransitionsData:
                transitionData = state.transitionsDataMap[label]
                assert 0 <= transitionData < 256
                res.append(transitionData)
        return res
class VLengthSerializer1(Serializer):
    """Variable-length serializer (implementation code 1).

    FSA prologue: a 256-byte table giving each label byte its "short label"
    (1..63 for the 63 most frequent labels, 0 for the rest), then '^'.

    State layout: a header byte (ACCEPTING_FLAG | ARRAY_FLAG | transition
    count, the count being below 64), the encoded data of an accepting state,
    then its transitions in one of two forms:

    * list form: per transition a byte ``(shortLabel << 2) | offsetSize``,
      the full label byte when shortLabel is 0, then offsetSize (0..3)
      big-endian bytes of a relative offset computed from reverseOffset values;
    * array form (useArrays and state.serializeAsArray, see
      _chooseArrayStates): 64 four-byte entries (a zero byte plus the target's
      offset in 3 bytes) indexed by short label, followed by a trimmed copy
      of the state serialized in list form.
    """

    def __init__(self, fsa, useArrays=False):
        super(VLengthSerializer1, self).__init__(fsa)
        # States in reversed DFS order. The last transition of a state needs no
        # offset bytes when it targets the state that follows it in this table.
        self.statesTable = list(reversed(list(fsa.dfs())))
        self.state2Index = dict((state, idx) for idx, state in enumerate(self.statesTable))
        self._chooseArrayStates()
        self.useArrays = useArrays
        # label -> short label (1..63); built by serializeFSAPrologue
        self.label2ShortLabel = None

        self.ACCEPTING_FLAG = 0b10000000
        self.ARRAY_FLAG = 0b01000000

    def getImplementationCode(self):
        return 1

    def serializeFSAPrologue(self):
        res = bytearray()

        # labels sorted by popularity (descending frequency, then label value)
        sortedLabels = [label for label, _ in sorted(
            self.fsa.label2Freq.items(), key=lambda item: (-item[1], item[0]))]

        # popular labels table: the 63 most frequent labels get short labels 1..63
        self.label2ShortLabel = dict(
            (label, shortLabel) for shortLabel, label in enumerate(sortedLabels[:63], 1))

        logging.debug(dict([(chr(label), shortLabel) for label, shortLabel in self.label2ShortLabel.items()]))

        # short label of every possible label byte (0 = no short label)
        for label in range(256):
            res.append(self.label2ShortLabel.get(label, 0))
        # write a magic char before initial state
        res.append(ord('^'))

        return res

    def getStateSize(self, state):
        return len(self.state2bytearray(state))

    def getDataSize(self, state):
        assert isinstance(state.encodedData, bytearray) or not state.isAccepting()
        return len(state.encodedData) if state.isAccepting() else 0

    def stateShouldBeAnArray(self, state):
        return self.useArrays and state.serializeAsArray

    def stateData2bytearray(self, state):
        # The transition count shares the header byte with the two flags and
        # must stay below 64. Raise instead of assert so that running under -O
        # cannot turn an oversized state into corrupted output.
        if state.transitionsNum >= 64:
            raise ValueError('too many transitions for this format: %r' % (state.transitionsNum,))
        firstByte = 0
        if state.isAccepting():
            firstByte |= self.ACCEPTING_FLAG
        if self.stateShouldBeAnArray(state):
            firstByte |= self.ARRAY_FLAG
        firstByte |= state.transitionsNum
        assert firstByte < 256 and firstByte > 0
        res = bytearray()
        res.append(firstByte)
        if state.isAccepting():
            res.extend(state.encodedData)
        return res

    def _transitions2ListBytes(self, state, originalState=None):
        """Serialize the transitions of *state* in list form.

        :param originalState: state whose position in statesTable locates the
            state stored after this one. Defaults to *state*;
            _transitions2ArrayBytes passes the original when serializing a
            trimmed copy.
        """
        res = bytearray()
        transitions = self.getSortedTransitions(state)
        thisIdx = self.state2Index[originalState if originalState is not None else state]
        logging.debug('state %s', state.offset)
        if not transitions:
            assert state.isAccepting()
            return bytearray()

        stateAfterThis = self.statesTable[thisIdx + 1]
        # Build the list back to front: each transition's offset depends on
        # the number of list bytes that follow it (len(res)).
        for reversedN, (label, nextState) in enumerate(reversed(transitions)):
            assert nextState.reverseOffset is not None
            assert stateAfterThis.reverseOffset is not None
            logging.debug('next state reverse: %s', nextState.reverseOffset)
            logging.debug('after state reverse: %s', stateAfterThis.reverseOffset)

            hasShortLabel = label in self.label2ShortLabel
            firstByte = (self.label2ShortLabel[label] if hasShortLabel else 0) << 2

            # The last transition needs no offset bytes (offset 0) when it
            # targets the state stored right after this one.
            last = reversedN == 0
            isNext = last and stateAfterThis == nextState
            offset = (stateAfterThis.reverseOffset - nextState.reverseOffset) + len(res)
            assert offset > 0 or isNext
            if offset >= 256 * 256 * 256:
                raise ValueError('transition offset does not fit in 3 bytes: %d' % offset)
            if offset >= 256 * 256:
                offsetSize = 3
            elif offset >= 256:
                offsetSize = 2
            elif offset > 0:
                offsetSize = 1
            else:
                offsetSize = 0
            assert offsetSize > 0 or isNext
            firstByte |= offsetSize

            transitionBytes = bytearray()
            transitionBytes.append(firstByte)
            if not hasShortLabel:
                transitionBytes.append(label)
            # serialize offset in big-endian order, using offsetSize bytes
            for shift in range(8 * (offsetSize - 1), -1, -8):
                transitionBytes.append((offset >> shift) & 0xFF)
            res[0:0] = transitionBytes
            logging.debug('inserted transition at beginning %s -> %s', chr(label), offset)

        return res

    def _trimState(self, state):
        """Return a list-form copy of *state* (serializeAsArray = False).

        The copy shares encodedData, offsets and transition targets with
        *state*. All transitions are currently kept. Dropping the labels that
        already have a short label (and so appear in the array) is disabled.
        """
        # NOTE(review): label2Freq is not copied, so getSortedTransitions orders
        # the copy's transitions by FSA-wide frequencies only (or fails if
        # State() does not define label2Freq) - confirm this is intended.
        newState = State()
        newState.encodedData = state.encodedData
        newState.reverseOffset = state.reverseOffset
        newState.offset = state.offset
        newState.transitionsMap = dict(state.transitionsMap.items())
        newState.serializeAsArray = False
        return newState

    def _transitions2ArrayBytes(self, state):
        res = bytearray()
        # Slot i holds the target offset for short label i. Slot 0 stays unused
        # because short labels start at 1.
        array = [0] * 64
        for label, nextState in state.transitionsMap.items():
            if label in self.label2ShortLabel:
                array[self.label2ShortLabel[label]] = nextState.offset
        logging.debug(array)
        for offset in array:
            # falsy offsets (e.g. None, presumably before offsets are assigned)
            # are written as 0
            offset = offset if offset else 0
            # NOTE(review): offsets >= 2**24 would be silently truncated here
            res.append(0)
            res.append((offset & 0xFF0000) >> 16)
            res.append((offset & 0x00FF00) >> 8)
            res.append(offset & 0x0000FF)
        # List-form copy of the state: header byte without ARRAY_FLAG, data and
        # transitions, positioned using the original state's table index.
        trimmedState = self._trimState(state)
        res.extend(self.stateData2bytearray(trimmedState))
        res.extend(self._transitions2ListBytes(trimmedState, originalState=state))
        return res

    def transitionsData2bytearray(self, state):
        if self.stateShouldBeAnArray(state):
            return self._transitions2ArrayBytes(state)
        else:
            return self._transitions2ListBytes(state)

    def _chooseArrayStates(self):
        """Mark the initial state and the states one and two transitions away
        from it as candidates for the array representation."""
        initialState = self.fsa.initialState
        for state1 in initialState.transitionsMap.values():
            for state2 in state1.transitionsMap.values():
                state2.serializeAsArray = True
            state1.serializeAsArray = True
        initialState.serializeAsArray = True

class VLengthSerializer2(Serializer):
    """Variable-length serializer (implementation code 2).

    A state is serialized as its encoded data (accepting states only)
    followed by its transitions. A state without transitions gets a single
    LAST_FLAG byte instead. Each transition is a label byte, a flags byte
    and zero or more extra offset bytes:

    * flags byte: HAS_REMAINING_FLAG (more offset bytes follow),
      ACCEPTING_FLAG (the target state is accepting), LAST_FLAG (last
      transition of the state) and 5 offset bits. These bits hold the whole
      offset when it is below 32, otherwise its high-order bits (_first5Bits).
    * extra bytes (_getOffsetBytes): the rest of the offset in 7-bit groups,
      most significant first, with HAS_REMAINING_FLAG set on all but the last.
    """

    def __init__(self, fsa, useArrays=False):
        # useArrays is accepted but not used by this serializer
        super(VLengthSerializer2, self).__init__(fsa)
        orderedStates = list(fsa.dfs())
        orderedStates.reverse()
        self.statesTable = orderedStates
        self.state2Index = dict((st, position) for position, st in enumerate(orderedStates))

        self.HAS_REMAINING_FLAG = 128
        self.ACCEPTING_FLAG = 64
        self.LAST_FLAG = 32

    def serializeFSAPrologue(self):
        # no FSA-level prologue in this format
        return bytearray()

    def getImplementationCode(self):
        return 2

    def getStateSize(self, state):
        return len(self.state2bytearray(state))

    def stateData2bytearray(self, state):
        data = bytearray()
        if not state.isAccepting():
            return data
        data.extend(state.encodedData)
        return data

    def _first5Bits(self, number):
        # strip 7-bit groups until what is left fits into 5 bits
        highBits = number
        while highBits >= 32:
            highBits = highBits >> 7
        return highBits

    def _getOffsetBytes(self, offset):
        groups = bytearray()
        remaining = offset
        while remaining >= 32:
            group = remaining & 0x7F
            remaining >>= 7
            # every group except the least significant one (emitted first)
            # signals that more bytes follow
            if groups:
                group |= self.HAS_REMAINING_FLAG
            groups.insert(0, group)
            logging.debug(remaining)
        return groups

    def _transitions2ListBytes(self, state):
        encoded = bytearray()
        followingIdx = self.state2Index[state] + 1
        transitions = self.getSortedTransitions(state)
        if not transitions:
            assert state.isAccepting()
            encoded.append(self.LAST_FLAG)
            return encoded

        followingState = self.statesTable[followingIdx]
        lastPosition = len(transitions) - 1
        # Walk the transitions back to front: each offset depends on the number
        # of bytes already emitted for the transitions after it.
        for position in range(lastPosition, -1, -1):
            label, nextState = transitions[position]
            assert nextState.reverseOffset is not None
            assert followingState.reverseOffset is not None
            logging.debug('next state reverse: '+str(nextState.reverseOffset))
            logging.debug('after state reverse: '+str(followingState.reverseOffset))

            flags = 0
            if position == lastPosition:
                flags |= self.LAST_FLAG
            if nextState.isAccepting():
                flags |= self.ACCEPTING_FLAG
            offset = (followingState.reverseOffset - nextState.reverseOffset) + len(encoded)
            if offset < 32:
                flags |= offset
            else:
                flags |= self.HAS_REMAINING_FLAG | self._first5Bits(offset)

            transitionBytes = bytearray()
            transitionBytes.append(label)
            transitionBytes.append(flags)
            transitionBytes.extend(self._getOffsetBytes(offset))
            encoded[0:0] = transitionBytes
            logging.debug('inserted transition at beginning '+chr(label)+' -> '+str(offset))

        return encoded

    def transitionsData2bytearray(self, state):
        return self._transitions2ListBytes(state)