|
1
2
3
4
5
6
|
'''
Created on Oct 20, 2013
@author: mlenart
'''
|
|
7
|
import logging
|
|
8
|
from state import State
|
|
9
|
from morfeuszbuilder.utils.serializationUtils import *
|
|
10
|
|
|
11
|
class Serializer(object):
    """Base class that turns a finite-state automaton into a byte stream.

    The stream produced by :meth:`fsa2bytearray` consists of a prologue
    (magic number, format version, implementation code), the length of the
    FSA data followed by the FSA data itself, and an epilogue (tagset data
    size, tagset data, segmentation rules data).

    Subclasses provide the concrete state layout by implementing
    ``getImplementationCode``, ``serializeFSAPrologue``, ``getStateSize``,
    ``stateData2bytearray`` and ``transitionsData2bytearray``.
    """

    MAGIC_NUMBER = 0x8fc2bc1b

    def __init__(self, fsa, headerFilename="data/default_fsa.hpp"):
        """Remember the automaton and the header named in generated C++ code.

        :param fsa: the automaton to serialize
        :param headerFilename: path put into the ``#include`` line written by
            :meth:`serialize2CppFile`
        """
        self._fsa = fsa
        self.headerFilename = headerFilename

    @property
    def fsa(self):
        """The automaton given to the constructor (read-only)."""
        return self._fsa

    def getVersion(self):
        """Return the Morfeusz file format version that is being encoded."""
        return 10

    def serialize2CppFile(self, fname, generator, segmentationRulesData):
        """Write the serialized automaton to *fname* as a C++ byte array.

        :param fname: path of the file to (over)write
        :param generator: when truthy the array is named
            ``DEFAULT_SYNTH_FSA``, otherwise ``DEFAULT_FSA``
        :param segmentationRulesData: passed through to :meth:`fsa2bytearray`
        """
        if generator:
            declaration = 'extern const unsigned char DEFAULT_SYNTH_FSA[] = {'
        else:
            declaration = 'extern const unsigned char DEFAULT_FSA[] = {'

        res = [
            '\n',
            '#include "%s"' % self.headerFilename,
            '\n',
            '\n',
            declaration,
            '\n',
        ]
        serialized = self.fsa2bytearray(
            tagsetData=self.serializeTagset(self.fsa.tagset),
            segmentationRulesData=segmentationRulesData)
        for byte in serialized:
            # every byte is followed by a comma, including the last one
            res.append(hex(byte))
            res.append(',')
        res.append('\n')
        res.append('};')
        res.append('\n')
        with open(fname, 'w') as f:
            f.write(''.join(res))

    def serialize2BinaryFile(self, fname, segmentationRulesData):
        """Write the serialized automaton to *fname* as raw bytes.

        :param fname: path of the file to (over)write
        :param segmentationRulesData: passed through to :meth:`fsa2bytearray`
        """
        with open(fname, 'wb') as f:
            f.write(self.fsa2bytearray(
                tagsetData=self.serializeTagset(self.fsa.tagset),
                segmentationRulesData=segmentationRulesData))

    def getStateSize(self, state):
        """Return the serialized size of *state*; subclasses must override."""
        raise NotImplementedError('Not implemented')

    def serializeFSAPrologue(self):
        """Return bytes placed before the first state; subclasses must override.

        Declared here because :meth:`fsa2bytearray` calls it on every
        serializer.
        """
        raise NotImplementedError('Not implemented')

    def fsa2bytearray(self, tagsetData, segmentationRulesData):
        """Serialize the whole automaton and return it as a ``bytearray``.

        :param tagsetData: already serialized tagset (may be empty or None)
        :param segmentationRulesData: already serialized segmentation rules
            (may be empty or None)
        """
        res = bytearray()
        res.extend(self.serializePrologue())

        fsaData = bytearray()
        # The FSA prologue is built before offsets are calculated: a subclass
        # may set up state there that its getStateSize() relies on.
        fsaData.extend(self.serializeFSAPrologue())
        self.fsa.calculateOffsets(sizeCounter=lambda state: self.getStateSize(state))
        for state in sorted(self.fsa.dfs(), key=lambda s: s.offset):
            fsaData.extend(self.state2bytearray(state))

        # htonl comes from the project's serializationUtils; presumably it
        # yields the value as network-order bytes -- confirm there.
        res.extend(htonl(len(fsaData)))
        res.extend(fsaData)
        res.extend(self.serializeEpilogue(tagsetData, segmentationRulesData))
        return res

    def _serializeTags(self, tagsMap):
        """Serialize a ``name -> number`` mapping, ordered by number.

        Output: the number of entries, then for every entry its number, the
        encoded name and a terminating zero byte.
        """
        res = bytearray()
        numOfTags = len(tagsMap)
        res.extend(htons(numOfTags))
        # items() plus an index-based key works on both Python 2 and 3
        # (tuple-parameter lambdas and iteritems() are Python 2 only).
        for tag, tagnum in sorted(tagsMap.items(), key=lambda item: item[1]):
            res.extend(htons(tagnum))
            res.extend(self.fsa.encodeWord(tag))
            res.append(0)
        return res

    def serializeTagset(self, tagset):
        """Serialize tagset data; returns an empty ``bytearray`` for a falsy tagset."""
        res = bytearray()
        if tagset:
            res.extend(self._serializeTags(tagset._tag2tagnum))
            res.extend(self._serializeTags(tagset._name2namenum))
        return res

    def serializePrologue(self):
        """Return the 6-byte file prologue.

        Layout: the magic number as four bytes in big-endian order, one byte
        with the format version, one byte with the implementation code.
        """
        res = bytearray()
        # serialize magic number in big-endian order
        res.append((Serializer.MAGIC_NUMBER & 0xFF000000) >> 24)
        res.append((Serializer.MAGIC_NUMBER & 0x00FF0000) >> 16)
        res.append((Serializer.MAGIC_NUMBER & 0x0000FF00) >> 8)
        res.append(Serializer.MAGIC_NUMBER & 0x000000FF)
        # serialize version number
        res.append(self.getVersion())
        # serialize implementation code
        res.append(self.getImplementationCode())
        return res

    def serializeEpilogue(self, tagsetData, segmentationRulesData):
        """Return the data that follows the serialized states.

        Layout: the tagset data size, the tagset data (if any), the
        segmentation rules data (if any).
        """
        res = bytearray()
        tagsetDataSize = len(tagsetData) if tagsetData else 0
        segmentationDataSize = len(segmentationRulesData) if segmentationRulesData else 0
        res.extend(htonl(tagsetDataSize))

        # add additional data itself
        if tagsetDataSize:
            assert isinstance(tagsetData, bytearray)
            res.extend(tagsetData)

        # NOTE(review): no size prefix is written for the segmentation data
        # here; presumably that data carries its own header -- confirm.
        if segmentationDataSize:
            assert isinstance(segmentationRulesData, bytearray)
            res.extend(segmentationRulesData)
        return res

    def state2bytearray(self, state):
        """Serialize one state: its own data followed by its transitions."""
        res = bytearray()
        res.extend(self.stateData2bytearray(state))
        res.extend(self.transitionsData2bytearray(state))
        return res

    def getSortedTransitions(self, state):
        """Return ``(label, nextState)`` pairs of *state*, most frequent first.

        Transitions are ordered by descending frequency of the label within
        the state, ties broken by descending frequency of the label in the
        whole automaton.  Labels missing from a frequency table count as 0.
        """
        def defaultKey(item):
            label = item[0]
            return (-state.label2Freq.get(label, 0),
                    -self.fsa.label2Freq.get(label, 0))

        return sorted(state.transitionsMap.items(), key=defaultKey)

    def stateData2bytearray(self, state):
        """Return the bytes describing *state* itself; subclasses must override."""
        raise NotImplementedError('Not implemented')

    def transitionsData2bytearray(self, state):
        """Return the bytes describing transitions of *state*; subclasses must override."""
        raise NotImplementedError('Not implemented')

    def getImplementationCode(self):
        """Return the one-byte code identifying the state layout; subclasses must override."""
        raise NotImplementedError('Not implemented')
class SimpleSerializer(Serializer):
    """Serializer with a fixed-width layout.

    Every state starts with one header byte (accepting flag combined with the
    number of transitions), optionally followed by the state's encoded data.
    Each transition takes one label byte and a three-byte big-endian offset
    of the target state, plus one extra data byte when
    ``serializeTransitionsData`` is set.
    """

    def __init__(self, fsa, serializeTransitionsData=False):
        """
        :param fsa: the automaton to serialize
        :param serializeTransitionsData: when true, one byte taken from
            ``state.transitionsDataMap`` is written after every transition
        """
        super(SimpleSerializer, self).__init__(fsa)
        self.ACCEPTING_FLAG = 128
        self.serializeTransitionsData = serializeTransitionsData

    def getImplementationCode(self):
        """Return 128 when transition data is serialized, 0 otherwise."""
        if self.serializeTransitionsData:
            return 128
        return 0

    def serializeFSAPrologue(self):
        """This layout needs no data before the first state."""
        return bytearray()

    def getStateSize(self, state):
        """Return the number of bytes *state* occupies once serialized."""
        # label + 3 offset bytes, plus one data byte if it is serialized
        bytesPerTransition = 5 if self.serializeTransitionsData else 4
        transitionsSize = bytesPerTransition * len(state.transitionsMap.keys())
        return 1 + transitionsSize + self.getDataSize(state)

    def getDataSize(self, state):
        """Return the length of the encoded data; 0 for non-accepting states."""
        if not state.isAccepting():
            return 0
        return len(state.encodedData)

    def stateData2bytearray(self, state):
        """Return the header byte of *state*, plus its data if it is accepting."""
        out = bytearray()
        header = 0
        accepting = state.isAccepting()
        if accepting:
            header |= self.ACCEPTING_FLAG
        # NOTE(review): the transition count shares the byte with
        # ACCEPTING_FLAG (128); a count of 128 or more would overlap the flag
        # bit and is not rejected by the assertion below -- confirm the
        # intended limit.
        header |= state.transitionsNum
        assert 0 < header < 256
        out.append(header)
        if accepting:
            out.extend(state.encodedData)
        return out

    def transitionsData2bytearray(self, state):
        """Return the serialized transitions of *state* in sorted order."""
        out = bytearray()
        withData = self.serializeTransitionsData
        for label, target in self.getSortedTransitions(state):
            out.append(label)
            # target offset as three bytes, most significant first
            targetOffset = target.offset
            out.append((targetOffset >> 16) & 0xFF)
            out.append((targetOffset >> 8) & 0xFF)
            out.append(targetOffset & 0xFF)
            if withData:
                extra = state.transitionsDataMap[label]
                assert extra >= 0
                assert extra < 256
                out.append(extra)
        return out
|
|
194
|
class VLengthSerializer1(Serializer):
    """Serializer with variable-length transitions and short label codes.

    The 63 most frequent labels get a short code (1..63) that is stored in
    the first byte of a transition together with the number of offset bytes;
    other labels are written as an extra byte.  Offsets are relative and take
    0 to 3 bytes.  States flagged with ``serializeAsArray`` may additionally
    be written with a 64-entry lookup table when ``useArrays`` is enabled.
    """

    def __init__(self, fsa, useArrays=False):
        """
        :param fsa: the automaton to serialize
        :param useArrays: when true, states marked by
            :meth:`_chooseArrayStates` are serialized with a lookup table
        """
        super(VLengthSerializer1, self).__init__(fsa)
        # states in reversed DFS order, and the position of each state in it
        self.statesTable = list(reversed(list(fsa.dfs())))
        self.state2Index = dict((state, idx) for (idx, state) in enumerate(self.statesTable))
        self._chooseArrayStates()
        self.useArrays = useArrays
        # filled in by serializeFSAPrologue()
        self.label2ShortLabel = None

        self.ACCEPTING_FLAG = 0b10000000
        self.ARRAY_FLAG = 0b01000000

    def getImplementationCode(self):
        """Return the implementation code of this layout (1)."""
        return 1

    def serializeFSAPrologue(self):
        """Build the short-label table and return it as the FSA prologue.

        The prologue is 256 bytes (the short code of every possible label
        value, 0 when the label has none) followed by the byte ``'^'``.
        Also sets ``self.label2ShortLabel``.
        """
        res = bytearray()

        # labels sorted by popularity (ties broken by label value);
        # items() plus an index-based key works on both Python 2 and 3
        byPopularity = sorted(self.fsa.label2Freq.items(),
                              key=lambda item: (-item[1], item[0]))
        sortedLabels = [label for (label, freq) in byPopularity]

        # popular labels table: codes 1..63, 0 is reserved for "no short label"
        self.label2ShortLabel = dict(
            (label, shortLabel)
            for shortLabel, label in enumerate(sortedLabels[:63], 1))

        logging.debug(dict([(chr(label), shortLabel) for label, shortLabel in self.label2ShortLabel.items()]))

        # write the short label of every byte value (zeros for the remaining ones)
        for label in range(256):
            res.append(self.label2ShortLabel.get(label, 0))

        # write a magic char before initial state
        res.append(ord('^'))
        return res

    def getStateSize(self, state):
        """Return the serialized size of *state* by actually serializing it."""
        return len(self.state2bytearray(state))

    def getDataSize(self, state):
        """Return the length of the encoded data; 0 for non-accepting states."""
        assert type(state.encodedData) == bytearray or not state.isAccepting()
        return len(state.encodedData) if state.isAccepting() else 0

    def stateShouldBeAnArray(self, state):
        """Tell whether *state* is to be written with a lookup table."""
        return self.useArrays and state.serializeAsArray

    def stateData2bytearray(self, state):
        """Return the header byte of *state*, plus its data if it is accepting.

        Header byte: accepting flag (bit 7), array flag (bit 6) and the
        number of transitions in the low six bits.
        """
        assert state.transitionsNum < 64
        res = bytearray()
        firstByte = 0
        if state.isAccepting():
            firstByte |= self.ACCEPTING_FLAG
        if self.stateShouldBeAnArray(state):
            firstByte |= self.ARRAY_FLAG
        firstByte |= state.transitionsNum
        assert firstByte < 256 and firstByte > 0
        res.append(firstByte)
        if state.isAccepting():
            res.extend(state.encodedData)
        return res

    def _transitions2ListBytes(self, state, originalState=None):
        """Serialize the transitions of *state* as a list.

        :param state: state whose transitions are written
        :param originalState: state used to find the position in
            ``statesTable``; defaults to *state* (needed for trimmed copies,
            which are not in the table)
        """
        res = bytearray()
        transitions = self.getSortedTransitions(state)
        thisIdx = self.state2Index[originalState if originalState is not None else state]
        logging.debug('state '+str(state.offset))
        if len(transitions) == 0:
            assert state.isAccepting()
            return bytearray()
        else:
            stateAfterThis = self.statesTable[thisIdx + 1]
            # Transitions are built from the last one backwards and inserted
            # at the front, so len(res) is the size of what follows the
            # transition currently being written.
            for reversedN, (label, nextState) in enumerate(reversed(transitions)):
                transitionBytes = bytearray()
                assert nextState.reverseOffset is not None
                assert stateAfterThis.reverseOffset is not None
                logging.debug('next state reverse: '+str(nextState.reverseOffset))
                logging.debug('after state reverse: '+str(stateAfterThis.reverseOffset))

                n = len(transitions) - reversedN
                hasShortLabel = label in self.label2ShortLabel
                # first byte: short label (or 0) in the upper six bits,
                # number of offset bytes in the lower two
                firstByte = self.label2ShortLabel[label] if hasShortLabel else 0
                firstByte <<= 2
                last = len(transitions) == n
                isNext = last and stateAfterThis == nextState
                offsetSize = 0
                offset = (stateAfterThis.reverseOffset - nextState.reverseOffset) + len(res)
                assert offset > 0 or isNext
                if offset > 0:
                    offsetSize += 1
                if offset >= 256:
                    offsetSize += 1
                if offset >= 256 * 256:
                    offsetSize += 1
                assert offset < 256 * 256 * 256  # TODO: turn this into a proper exception
                assert offsetSize <= 3
                assert offsetSize > 0 or isNext
                firstByte |= offsetSize

                transitionBytes.append(firstByte)
                if not hasShortLabel:
                    # no short code: the full label follows the first byte
                    transitionBytes.append(label)
                # serialize offset in big-endian order
                if offsetSize == 3:
                    transitionBytes.append((offset & 0xFF0000) >> 16)
                if offsetSize >= 2:
                    transitionBytes.append((offset & 0x00FF00) >> 8)
                if offsetSize >= 1:
                    transitionBytes.append(offset & 0x0000FF)
                for b in reversed(transitionBytes):
                    res.insert(0, b)
                logging.debug('inserted transition at beginning '+chr(label)+' -> '+str(offset))
            return res

    def _trimState(self, state):
        """Return a copy of *state* that is never serialized as an array."""
        newState = State()
        newState.encodedData = state.encodedData
        newState.reverseOffset = state.reverseOffset
        newState.offset = state.offset
        newState.transitionsMap = dict(state.transitionsMap.items())
        newState.serializeAsArray = False
        return newState

    def _transitions2ArrayBytes(self, state):
        """Serialize the transitions of *state* as a lookup table plus a list.

        The table has 64 four-byte entries indexed by short label: a zero
        byte followed by the three-byte big-endian offset of the target state
        (all zeros when there is no such transition).  It is followed by the
        ordinary list form of a trimmed copy of the state.
        """
        res = bytearray()
        array = [0] * 64
        for label, nextState in state.transitionsMap.items():
            if label in self.label2ShortLabel:
                shortLabel = self.label2ShortLabel[label]
                array[shortLabel] = nextState.offset
        logging.debug(array)
        for entry in array:
            offset = entry if entry else 0
            res.append(0)
            res.append((offset & 0xFF0000) >> 16)
            res.append((offset & 0x00FF00) >> 8)
            res.append(offset & 0x0000FF)
        trimmed = self._trimState(state)
        # was self._stateData2bytearray(...), a method that does not exist
        res.extend(self.stateData2bytearray(trimmed))
        res.extend(self._transitions2ListBytes(trimmed, originalState=state))
        return res

    def transitionsData2bytearray(self, state):
        """Return the serialized transitions of *state* (array or list form)."""
        if self.stateShouldBeAnArray(state):
            return self._transitions2ArrayBytes(state)
        else:
            return self._transitions2ListBytes(state)

    def _chooseArrayStates(self):
        """Mark the initial state and the states up to two steps from it as array candidates."""
        for state1 in self.fsa.initialState.transitionsMap.values():
            for state2 in state1.transitionsMap.values():
                state2.serializeAsArray = True
            state1.serializeAsArray = True
        self.fsa.initialState.serializeAsArray = True
|
|
353
354
355
356
357
358
359
360
361
362
363
364
|
class VLengthSerializer2(Serializer):
    """Serializer that stores transition offsets as variable-length numbers.

    A state is its encoded data (if accepting) followed by its transitions.
    Each transition is a label byte, a byte carrying three flags plus five
    bits of the offset, and -- for offsets of 32 or more -- further bytes with
    seven offset bits each.
    """

    def __init__(self, fsa, useArrays=False):
        """
        :param fsa: the automaton to serialize
        :param useArrays: accepted but not used by this serializer
        """
        super(VLengthSerializer2, self).__init__(fsa)
        # states in reversed DFS order, and the position of each state in it
        reversedStates = list(fsa.dfs())
        reversedStates.reverse()
        self.statesTable = reversedStates
        self.state2Index = dict((st, pos) for (pos, st) in enumerate(self.statesTable))

        self.HAS_REMAINING_FLAG = 128
        self.ACCEPTING_FLAG = 64
        self.LAST_FLAG = 32

    def serializeFSAPrologue(self):
        """This layout needs no data before the first state."""
        return bytearray()

    def getImplementationCode(self):
        """Return the implementation code of this layout (2)."""
        return 2

    def getStateSize(self, state):
        """Return the serialized size of *state* by actually serializing it."""
        serialized = self.state2bytearray(state)
        return len(serialized)

    def stateData2bytearray(self, state):
        """Return the encoded data of an accepting state, otherwise nothing."""
        data = bytearray()
        if state.isAccepting():
            data.extend(state.encodedData)
        return data

    def _first5Bits(self, number):
        """Drop seven bits at a time from *number* until it is below 32."""
        head = number
        while head >= 32:
            head >>= 7
        return head

    def _getOffsetBytes(self, offset):
        """Return the continuation bytes of *offset*, most significant first.

        Seven bits are taken per byte while the remaining value is at least
        32; every byte except the least significant one carries
        ``HAS_REMAINING_FLAG``.  Offsets below 32 produce no bytes.
        """
        encoded = bytearray()
        rest = offset
        isLowestGroup = True
        while rest >= 32:
            group = rest & 0b01111111
            rest >>= 7
            if isLowestGroup:
                # the least significant group ends the number: no flag
                isLowestGroup = False
            else:
                group |= self.HAS_REMAINING_FLAG
            encoded.insert(0, group)
            logging.debug(rest)
        return encoded

    def _transitions2ListBytes(self, state):
        """Serialize the transitions of *state*.

        A state without transitions (it must be accepting) is written as the
        single byte ``LAST_FLAG``.
        """
        out = bytearray()
        position = self.state2Index[state]
        transitions = self.getSortedTransitions(state)

        if not len(transitions):
            assert state.isAccepting()
            out.append(self.LAST_FLAG)
            return out

        following = self.statesTable[position + 1]
        # Built from the last transition backwards and prepended, so len(out)
        # is the size of what follows the transition being written.
        for stepsFromEnd, (label, target) in enumerate(reversed(transitions)):
            chunk = bytearray()
            assert target.reverseOffset is not None
            assert following.reverseOffset is not None
            logging.debug('next state reverse: '+str(target.reverseOffset))
            logging.debug('after state reverse: '+str(following.reverseOffset))

            flagsByte = 0
            if stepsFromEnd == 0:
                # the first one handled here is the last transition of the state
                flagsByte |= self.LAST_FLAG
            if target.isAccepting():
                flagsByte |= self.ACCEPTING_FLAG

            offset = (following.reverseOffset - target.reverseOffset) + len(out)
            if offset < 32:
                # fits entirely in the five low bits
                flagsByte |= offset
            else:
                flagsByte |= self.HAS_REMAINING_FLAG
                flagsByte |= self._first5Bits(offset)

            chunk.append(label)
            chunk.append(flagsByte)
            chunk.extend(self._getOffsetBytes(offset))
            # prepend the whole chunk, keeping its byte order
            out[0:0] = chunk
            logging.debug('inserted transition at beginning '+chr(label)+' -> '+str(offset))
        return out

    def transitionsData2bytearray(self, state):
        """Return the serialized transitions of *state*."""
        return self._transitions2ListBytes(state)
|