Blame view

fsabuilder/morfeuszbuilder/fsa/serializer.py 18.4 KB
Michał Lenart authored
1
2
3
4
5
6
'''
Created on Oct 20, 2013

@author: mlenart

Serializers that turn an FSA, together with the tagset, qualifiers and
segmentation rules data, into the Morfeusz binary format. The result can be
written either as a raw binary file or as a C++ source file that defines a
byte array.
'''
Michał Lenart authored
7
import logging
Michał Lenart authored
8
from state import State
Michał Lenart authored
9
from morfeuszbuilder.utils import limits, exceptions
Michał Lenart authored
10
from morfeuszbuilder.utils.serializationUtils import *
Michał Lenart authored
11
Michał Lenart authored
12
13
14
15
16
class SerializationMethod(object):
    """Names of the available FSA serialization formats.

    Each value is a key accepted by Serializer.getSerializer().
    """
    SIMPLE, V1, V2 = 'SIMPLE', 'V1', 'V2'

class Serializer(object):
    """Base class for the binary FSA serializers.

    This class assembles the whole output file in this order:
    prologue (magic number, version, implementation code),
    FSA data length, FSA data (FSA prologue + every state),
    then the epilogue (tagset, qualifiers and segmentation rules data).

    Subclasses provide the per-state encoding through these hooks:
    getStateSize, stateData2bytearray, transitionsData2bytearray,
    serializeFSAPrologue and getImplementationCode.

    NOTE: htonl/htons come from morfeuszbuilder.utils.serializationUtils
    (star import); presumably 32/16-bit big-endian encoders - confirm there.
    """

    MAGIC_NUMBER = 0x8fc2bc1b

    def __init__(self, fsa):
        self._fsa = fsa
        # Additional data written to the epilogue; normally set by getSerializer().
        self.tagset = None
        self.qualifiersMap = None
        self.segmentationRulesData = None

    @staticmethod
    def getSerializer(serializationMethod, fsa, tagset, qualifiersMap, segmentationRulesData):
        """Create the serializer for serializationMethod and attach the additional data.

        Raises KeyError when serializationMethod is not a SerializationMethod value.
        """
        serializerClass = {
            SerializationMethod.SIMPLE: SimpleSerializer,
            SerializationMethod.V1: VLengthSerializer1,
            SerializationMethod.V2: VLengthSerializer2,
        }[serializationMethod]
        res = serializerClass(fsa)
        res.tagset = tagset
        res.qualifiersMap = qualifiersMap
        res.segmentationRulesData = segmentationRulesData
        return res

    @property
    def fsa(self):
        return self._fsa

    def getVersion(self):
        """Return the Morfeusz file format version written into the prologue."""
        return 18

    def serialize2CppFile(self, fname, isGenerator, headerFilename="data/default_fsa.hpp"):
        """Write the serialized automaton to fname as a C++ source file.

        The file includes headerFilename and defines, inside namespace
        morfeusz, the byte array DEFAULT_SYNTH_FSA (isGenerator true) or
        DEFAULT_FSA (isGenerator false).
        """
        arrayName = 'DEFAULT_SYNTH_FSA' if isGenerator else 'DEFAULT_FSA'
        res = [
            '\n',
            '#include "%s"' % headerFilename,
            '\n',
            'namespace morfeusz {\n',
            '\n',
            'extern const unsigned char %s[] = {' % arrayName,
            '\n',
        ]
        # The data is serialized before the file is opened, so a failure leaves no partial file.
        res.extend(hex(byte) + ',' for byte in self.fsa2bytearray(isGenerator))
        res.extend(['\n', '};', '\n', '}', '\n'])
        with open(fname, 'w') as f:
            f.write(''.join(res))

    def serialize2BinaryFile(self, fname, isGenerator):
        """Write the serialized automaton to fname as raw bytes."""
        with open(fname, 'wb') as f:
            f.write(self.fsa2bytearray(isGenerator))

    def getStateSize(self, state):
        """Return the number of bytes state occupies in the FSA data (subclass hook)."""
        raise NotImplementedError('Not implemented')

    def fsa2bytearray(self, isGenerator):
        """Return the complete serialized file as a bytearray.

        isGenerator is not used by this method. Side effect: the FSA's state
        offsets are (re)computed via fsa.calculateOffsets().
        """
        tagsetData = self.serializeTagset(self.tagset)
        qualifiersData = self.serializeQualifiersMap()
        segmentationRulesData = self.segmentationRulesData
        res = bytearray()
        res.extend(self.serializePrologue())
        # The FSA prologue is built before offsets are calculated. Subclasses may
        # set up data there that getStateSize() depends on.
        fsaData = bytearray()
        fsaData.extend(self.serializeFSAPrologue())
        self.fsa.calculateOffsets(sizeCounter=lambda state: self.getStateSize(state))
        for state in sorted(self.fsa.dfs(), key=lambda s: s.offset):
            fsaData.extend(self.state2bytearray(state))
        res.extend(htonl(len(fsaData)))
        res.extend(fsaData)
        res.extend(self.serializeEpilogue(tagsetData, qualifiersData, segmentationRulesData))
        return res

    def _serializeTags(self, tagsMap):
        """Serialize a name -> number map.

        Layout: the entry count (htons), then for each entry in increasing
        number order: the number (htons), the name encoded by
        fsa.encodeWord, and a terminating zero byte.
        """
        res = bytearray()
        res.extend(htons(len(tagsMap)))
        for tag, tagnum in sorted(tagsMap.items(), key=lambda item: item[1]):
            res.extend(htons(tagnum))
            res.extend(self.fsa.encodeWord(tag))
            res.append(0)
        return res

    def serializeTagset(self, tagset):
        """Serialize the tagset's tag map and name map; empty when tagset is falsy."""
        res = bytearray()
        if tagset:
            res.extend(self._serializeTags(tagset._tag2tagnum))
            res.extend(self._serializeTags(tagset._name2namenum))
        return res

    def serializeQualifiersMap(self):
        """Serialize self.qualifiersMap (qualifier collection -> number).

        Layout: the entry count (htons), then for each entry in increasing
        number order: the number of qualifiers (one byte), followed by each
        qualifier UTF-8 encoded and zero-terminated. A map left at None is
        treated as empty, matching how serializeTagset tolerates a missing
        tagset.
        """
        qualifiersMap = self.qualifiersMap if self.qualifiersMap is not None else {}
        res = bytearray()
        res.extend(htons(len(qualifiersMap)))
        for qualifiers, _ in sorted(qualifiersMap.items(), key=lambda item: item[1]):
            res.append(len(qualifiers))
            for q in qualifiers:
                res.extend(q.encode('utf8'))
                res.append(0)
        return res

    def serializePrologue(self):
        """Return the file header.

        Layout: the 4-byte big-endian magic number, the version byte, and the
        implementation code byte.
        """
        res = bytearray()
        for shift in (24, 16, 8, 0):
            res.append((Serializer.MAGIC_NUMBER >> shift) & 0xFF)
        res.append(self.getVersion())
        res.append(self.getImplementationCode())
        return res

    def serializeEpilogue(self, tagsetData, qualifiersData, segmentationRulesData):
        """Return the size field followed by the tagset, qualifiers and segmentation data.

        The size field (htonl) counts only the tagset and qualifiers data. The
        segmentation rules data is appended after them but is not included in
        that size.
        """
        res = bytearray()
        tagsetDataSize = len(tagsetData) if tagsetData else 0
        qualifiersDataSize = len(qualifiersData) if qualifiersData else 0
        res.extend(htonl(tagsetDataSize + qualifiersDataSize))
        for data in (tagsetData, qualifiersData, segmentationRulesData):
            if data:
                assert isinstance(data, bytearray)
                res.extend(data)
        return res

    def state2bytearray(self, state):
        """Serialize one state: its header/data bytes followed by its transitions."""
        res = bytearray()
        res.extend(self.stateData2bytearray(state))
        res.extend(self.transitionsData2bytearray(state))
        return res

    def getSortedTransitions(self, state):
        """Return state's (label, nextState) pairs in serialization order.

        Pairs are ordered by decreasing per-state label frequency, then by
        decreasing FSA-wide label frequency. The sort is stable, so remaining
        ties keep transitionsMap iteration order.
        """
        def frequencyKey(item):
            label = item[0]
            return (-state.label2Freq.get(label, 0), -self.fsa.label2Freq.get(label, 0))
        return sorted(state.transitionsMap.items(), key=frequencyKey)

    def serializeFSAPrologue(self):
        """Return the bytes written before the first state (subclass hook)."""
        raise NotImplementedError('Not implemented')

    def stateData2bytearray(self, state):
        """Return the state's header/data bytes (subclass hook)."""
        raise NotImplementedError('Not implemented')

    def transitionsData2bytearray(self, state):
        """Return the state's transition bytes (subclass hook)."""
        raise NotImplementedError('Not implemented')

    def getImplementationCode(self):
        """Return the implementation code byte written into the prologue (subclass hook)."""
        raise NotImplementedError('Not implemented')

class SimpleSerializer(Serializer):
    """Fixed-size serializer (implementation code 0, or 128 with transition data).

    Each state is laid out as follows:
    * one header byte, ACCEPTING_FLAG | transition count;
    * the state's encodedData, when the state is accepting;
    * one record per transition: the label byte and the 3-byte big-endian
      nextState.offset. When serializeTransitionsData is set, each record
      also carries one extra byte from state.transitionsDataMap.
    """

    def __init__(self, fsa, serializeTransitionsData=False):
        super(SimpleSerializer, self).__init__(fsa)
        self.ACCEPTING_FLAG = 128
        self.serializeTransitionsData = serializeTransitionsData

    def getImplementationCode(self):
        return 128 if self.serializeTransitionsData else 0

    def serializeFSAPrologue(self):
        return bytearray()

    def getStateSize(self, state):
        """Return header byte + transition records + accepting data, in bytes."""
        bytesPerTransition = 5 if self.serializeTransitionsData else 4
        return 1 + bytesPerTransition * len(state.transitionsMap) + self.getDataSize(state)

    def getDataSize(self, state):
        return len(state.encodedData) if state.isAccepting() else 0

    def stateData2bytearray(self, state):
        """Return the header byte, followed by encodedData for accepting states.

        Fails via exceptions.validate when the transition count does not fit
        below ACCEPTING_FLAG.
        """
        # The transition count shares the header byte with ACCEPTING_FLAG. A count
        # >= 128 would silently set that bit, so it is rejected instead.
        exceptions.validate(
            state.transitionsNum < self.ACCEPTING_FLAG,
            u'Cannot build the automaton - a state has %d transitions, the maximum is %d'
            % (state.transitionsNum, self.ACCEPTING_FLAG - 1))
        res = bytearray()
        firstByte = state.transitionsNum
        if state.isAccepting():
            firstByte |= self.ACCEPTING_FLAG
        assert 0 < firstByte < 256
        res.append(firstByte)
        if state.isAccepting():
            res.extend(state.encodedData)
        return res

    def transitionsData2bytearray(self, state):
        """Return one fixed-size record per transition, in getSortedTransitions order.

        Fails via exceptions.validate when a target offset does not fit in 3 bytes.
        """
        maxOffset = 256 * 256 * 256
        res = bytearray()
        for label, nextState in self.getSortedTransitions(state):
            offset = nextState.offset
            # Only 3 bytes are reserved for the offset; larger values used to be truncated silently.
            exceptions.validate(
                offset < maxOffset,
                u'Cannot build the automaton - it would exceed its max size which is %d' % maxOffset)
            res.append(label)
            res.append((offset >> 16) & 0xFF)
            res.append((offset >> 8) & 0xFF)
            res.append(offset & 0xFF)
            if self.serializeTransitionsData:
                transitionData = state.transitionsDataMap[label]
                assert 0 <= transitionData < 256
                res.append(transitionData)
        return res

class VLengthSerializer1(Serializer):
    """Variable-length serializer (implementation code 1).

    Layout produced by this class:

    * FSA prologue: a 256-byte table that maps every label value to its short
      label, followed by the marker byte '^'. The 63 most frequent labels in
      fsa.label2Freq get short labels 1..63; all others get 0.
    * State header: when the transition count is below 127, one byte holding
      ACCEPTING_FLAG | count. Otherwise the byte's low bits are 127 and the
      count follows in a second byte. An accepting state's encodedData comes
      next.
    * Transitions, list form: one byte (shortLabel << 2) | offsetSize, then
      the raw label byte if the label has no short label, then offsetSize
      (0..3) big-endian offset bytes.
    * Transitions, array form: used only when useArrays is set and
      _chooseArrayStates flagged the state. It consists of 64 four-byte slots
      indexed by short label, followed by a trimmed copy of the state in list
      form.
      NOTE(review): no bit in the state header marks array-form states, so a
      reader cannot tell the two forms apart from the header alone. Confirm
      before enabling useArrays.
    """

    def __init__(self, fsa, useArrays=False):
        super(VLengthSerializer1, self).__init__(fsa)
        # States in reverse DFS order. The entry that follows a state in this
        # table is the base for that state's relative transition offsets.
        self.statesTable = list(reversed(list(fsa.dfs())))
        self.state2Index = {state: idx for idx, state in enumerate(self.statesTable)}
        self._chooseArrayStates()
        self.useArrays = useArrays
        # Filled in by serializeFSAPrologue(); it must run before any state is serialized.
        self.label2ShortLabel = None
        self.ACCEPTING_FLAG = 0b10000000

    def getImplementationCode(self):
        return 1

    def serializeFSAPrologue(self):
        """Build the short-label table and return the FSA prologue bytes.

        Side effect: sets self.label2ShortLabel.
        """
        res = bytearray()
        # Labels sorted by decreasing frequency; ties are broken by label value.
        sortedLabels = [label for label, _ in sorted(self.fsa.label2Freq.items(),
                                                     key=lambda item: (-item[1], item[0]))]
        # The 63 most popular labels get short labels 1..63; 0 means "no short label".
        self.label2ShortLabel = {label: idx + 1 for idx, label in enumerate(sortedLabels[:63])}
        logging.debug(dict((chr(label), shortLabel) for label, shortLabel in self.label2ShortLabel.items()))
        # One byte per possible label value.
        for label in range(256):
            res.append(self.label2ShortLabel.get(label, 0))
        # Magic char before the initial state.
        res.append(ord('^'))
        return res

    def getStateSize(self, state):
        return len(self.state2bytearray(state))

    def getDataSize(self, state):
        assert isinstance(state.encodedData, bytearray) or not state.isAccepting()
        return len(state.encodedData) if state.isAccepting() else 0

    def stateShouldBeAnArray(self, state):
        return self.useArrays and state.serializeAsArray

    def stateData2bytearray(self, state):
        """Return the state header (flags and transition count), then encodedData if accepting."""
        res = bytearray()
        firstByte = self.ACCEPTING_FLAG if state.isAccepting() else 0
        if state.transitionsNum < 127:
            res.append(firstByte | state.transitionsNum)
        else:
            # 127 in the low bits means "the real count is in the next byte".
            res.append(firstByte | 127)
            res.append(state.transitionsNum)
        if state.isAccepting():
            res.extend(state.encodedData)
        return res

    def _transitions2ListBytes(self, state, originalState=None):
        """Serialize state's transitions in list form.

        originalState, when given, is the state whose position in statesTable
        is used. This lets state be a trimmed copy that is absent from
        state2Index. A state without transitions must be accepting and yields
        no bytes. Fails via exceptions.validate when an offset needs more than
        3 bytes.
        """
        res = bytearray()
        transitions = self.getSortedTransitions(state)
        thisIdx = self.state2Index[originalState if originalState is not None else state]
        logging.debug('state %s', state.offset)
        if not transitions:
            assert state.isAccepting()
            return bytearray()

        stateAfterThis = self.statesTable[thisIdx + 1]
        maxOffset = 256 * 256 * 256
        # Transitions are emitted back to front. This way each offset can
        # include len(res), the bytes of the transitions that follow it.
        for reversedN, (label, nextState) in enumerate(reversed(transitions)):
            assert nextState.reverseOffset is not None
            assert stateAfterThis.reverseOffset is not None
            logging.debug('next state reverse: %s', nextState.reverseOffset)
            logging.debug('after state reverse: %s', stateAfterThis.reverseOffset)

            hasShortLabel = label in self.label2ShortLabel
            firstByte = (self.label2ShortLabel[label] if hasShortLabel else 0) << 2

            # Only the state's final transition may target the state right after it
            # with offset 0 (an offset encoded in zero bytes).
            isNext = reversedN == 0 and stateAfterThis == nextState
            offset = (stateAfterThis.reverseOffset - nextState.reverseOffset) + len(res)
            assert offset > 0 or isNext
            exceptions.validate(
                offset < maxOffset,
                u'Cannot build the automaton - it would exceed its max size which is %d' % maxOffset)
            offsetSize = 0
            if offset > 0:
                offsetSize += 1
            if offset >= 256:
                offsetSize += 1
            if offset >= 256 * 256:
                offsetSize += 1

            transitionBytes = bytearray([firstByte | offsetSize])
            if not hasShortLabel:
                transitionBytes.append(label)
            # Offset in big-endian order, using offsetSize bytes.
            if offsetSize == 3:
                transitionBytes.append((offset >> 16) & 0xFF)
            if offsetSize >= 2:
                transitionBytes.append((offset >> 8) & 0xFF)
            if offsetSize >= 1:
                transitionBytes.append(offset & 0xFF)
            res[0:0] = transitionBytes
            logging.debug('inserted transition at beginning %s -> %s', chr(label), offset)
        return res

    def _trimState(self, state):
        """Return a new State that copies state's data, offsets and transitions and is never an array."""
        newState = State()
        newState.encodedData = state.encodedData
        newState.reverseOffset = state.reverseOffset
        newState.offset = state.offset
        newState.transitionsMap = dict(state.transitionsMap)
        newState.serializeAsArray = False
        return newState

    def _transitions2ArrayBytes(self, state):
        """Serialize state's transitions in array form.

        Layout: 64 slots of (zero byte + 24-bit big-endian target offset),
        indexed by short label, followed by the header and list-form
        transitions of a trimmed copy of state.
        """
        res = bytearray()
        # Slot i holds the target offset for short label i. Slot 0 stays unused
        # because short labels start at 1.
        array = [0] * 64
        for label, nextState in state.transitionsMap.items():
            if label in self.label2ShortLabel:
                array[self.label2ShortLabel[label]] = nextState.offset
        logging.debug(array)
        for offset in array:
            offset = offset or 0  # a falsy offset (e.g. None) is written as 0
            res.append(0)
            res.append((offset & 0xFF0000) >> 16)
            res.append((offset & 0x00FF00) >> 8)
            res.append(offset & 0x0000FF)
        trimmed = self._trimState(state)
        # Previously this called the nonexistent self._stateData2bytearray (AttributeError).
        res.extend(self.stateData2bytearray(trimmed))
        res.extend(self._transitions2ListBytes(trimmed, originalState=state))
        return res

    def transitionsData2bytearray(self, state):
        if self.stateShouldBeAnArray(state):
            return self._transitions2ArrayBytes(state)
        return self._transitions2ListBytes(state)

    def _chooseArrayStates(self):
        """Flag the initial state and states up to two transitions away as array candidates.

        The flag only takes effect when useArrays is set (see stateShouldBeAnArray).
        """
        initialState = self.fsa.initialState
        for state1 in initialState.transitionsMap.values():
            for state2 in state1.transitionsMap.values():
                state2.serializeAsArray = True
            state1.serializeAsArray = True
        initialState.serializeAsArray = True

class VLengthSerializer2(Serializer):
    """Variable-length serializer (implementation code 2).

    The FSA prologue is empty. Each state is written as its encodedData (when
    accepting), followed by its transitions.

    Each transition starts with two bytes:
    * the label byte;
    * a flags byte. Bit HAS_REMAINING_FLAG is set when the offset needs more
      than 5 bits. ACCEPTING_FLAG is set when the target state is accepting.
      LAST_FLAG is set on the state's last transition. The low 5 bits hold
      the offset, or its most significant part.

    Offsets of 32 or more continue in 7-bit groups, most significant first;
    every group except the final one carries HAS_REMAINING_FLAG. A state
    without transitions is written as a single LAST_FLAG byte.
    """

    def __init__(self, fsa, useArrays=False):
        super(VLengthSerializer2, self).__init__(fsa)
        # useArrays is accepted for signature compatibility; this format ignores it.
        # States in reverse DFS order. The entry that follows a state is the
        # base for its transition offsets.
        table = list(fsa.dfs())
        table.reverse()
        self.statesTable = table
        self.state2Index = dict((st, i) for i, st in enumerate(table))
        # Bit layout of the second byte of every transition.
        self.HAS_REMAINING_FLAG, self.ACCEPTING_FLAG, self.LAST_FLAG = 0b10000000, 0b01000000, 0b00100000

    def serializeFSAPrologue(self):
        return bytearray()

    def getImplementationCode(self):
        return 2

    def getStateSize(self, state):
        return len(self.state2bytearray(state))

    def stateData2bytearray(self, state):
        """Return a copy of the state's encodedData, or nothing for a non-accepting state."""
        data = bytearray()
        if not state.isAccepting():
            return data
        data.extend(state.encodedData)
        return data

    def _first5Bits(self, number):
        """Return the most significant part of number that remains after removing 7-bit groups."""
        value = number
        while not value < 32:
            value = value >> 7
        return value

    def _getOffsetBytes(self, offset):
        """Return the 7-bit continuation groups of offset, most significant first.

        The least significant group carries no flag; every more significant
        group is marked with HAS_REMAINING_FLAG. Offsets below 32 yield no
        bytes.
        """
        groups = bytearray()
        rest = offset
        isFirstGroup = True
        while rest >= 32:
            group = rest & 0x7F
            rest >>= 7
            if isFirstGroup:
                isFirstGroup = False
            else:
                group |= self.HAS_REMAINING_FLAG
            groups.insert(0, group)
            logging.debug(rest)
        return groups

    def _transitions2ListBytes(self, state):
        """Serialize all transitions of state, building the output back to front."""
        encoded = bytearray()
        ownIndex = self.state2Index[state]
        ordered = self.getSortedTransitions(state)
        if not ordered:
            assert state.isAccepting()
            encoded.append(self.LAST_FLAG)
            return encoded

        successor = self.statesTable[ownIndex + 1]
        lastIndex = len(ordered) - 1
        # Walk from the last transition to the first. Each offset then includes
        # the bytes already emitted for the transitions that follow it.
        for position in range(lastIndex, -1, -1):
            label, target = ordered[position]
            assert target.reverseOffset is not None
            assert successor.reverseOffset is not None
            logging.debug('next state reverse: ' + str(target.reverseOffset))
            logging.debug('after state reverse: ' + str(successor.reverseOffset))

            flags = 0
            if position == lastIndex:
                flags |= self.LAST_FLAG
            if target.isAccepting():
                flags |= self.ACCEPTING_FLAG
            distance = (successor.reverseOffset - target.reverseOffset) + len(encoded)
            if distance < 32:
                flags |= distance
            else:
                flags |= self.HAS_REMAINING_FLAG | self._first5Bits(distance)

            chunk = bytearray([label, flags])
            chunk.extend(self._getOffsetBytes(distance))
            encoded[0:0] = chunk
            logging.debug('inserted transition at beginning ' + chr(label) + ' -> ' + str(distance))

        return encoded

    def transitionsData2bytearray(self, state):
        return self._transitions2ListBytes(state)