|
1
2
3
4
5
6
|
'''
Created on Oct 20, 2013
@author: mlenart
'''
|
|
7
|
import logging
|
|
8
|
from state import State
|
|
9
|
from morfeuszbuilder.utils import limits, exceptions
|
|
10
|
from morfeuszbuilder.utils.serializationUtils import *
|
|
11
|
|
|
12
13
14
15
16
|
class SerializationMethod(object):
    """Names of the available automaton serialization formats.

    The values are the keys accepted by ``Serializer.getSerializer``.
    """
    SIMPLE, V1, V2 = 'SIMPLE', 'V1', 'V2'
class Serializer(object):
    """Base class for writers of the binary Morfeusz automaton format.

    ``fsa2bytearray`` produces, in order:
      * the prologue: 4-byte big-endian MAGIC_NUMBER, a version byte and an
        implementation-code byte;
      * the length of the FSA section (encoded with ``htonl``);
      * the FSA section: ``serializeFSAPrologue()`` followed by every state
        (ordered by ``state.offset``);
      * the epilogue: tagset, qualifiers and segmentation-rules data.

    Subclasses implement the format-specific hooks ``getImplementationCode``,
    ``serializeFSAPrologue``, ``getStateSize``, ``stateData2bytearray`` and
    ``transitionsData2bytearray``.
    """

    MAGIC_NUMBER = 0x8fc2bc1b

    def __init__(self, fsa):
        self._fsa = fsa
        # Epilogue payloads; filled in by getSerializer().
        self.tagset = None
        self.qualifiersMap = None
        self.segmentationRulesData = None

    @staticmethod
    def getSerializer(serializationMethod, fsa, tagset, qualifiersMap, segmentationRulesData):
        """Build the serializer for ``serializationMethod`` (a SerializationMethod value).

        Raises KeyError when ``serializationMethod`` is not a known method.
        """
        serializerClass = {
            SerializationMethod.SIMPLE: SimpleSerializer,
            SerializationMethod.V1: VLengthSerializer1,
            SerializationMethod.V2: VLengthSerializer2,
        }[serializationMethod]
        res = serializerClass(fsa)
        res.tagset = tagset
        res.qualifiersMap = qualifiersMap
        res.segmentationRulesData = segmentationRulesData
        return res

    @property
    def fsa(self):
        return self._fsa

    def getVersion(self):
        """Return the Morfeusz file format version that is being encoded."""
        return 18

    def serialize2CppFile(self, fname, isGenerator, headerFilename="data/default_fsa.hpp"):
        """Write the serialized automaton as a C++ byte-array definition into ``fname``.

        The array is named DEFAULT_SYNTH_FSA for generators and DEFAULT_FSA
        otherwise, and lives in namespace ``morfeusz``.
        """
        arrayName = 'DEFAULT_SYNTH_FSA' if isGenerator else 'DEFAULT_FSA'
        res = [
            '\n',
            '#include "%s"' % headerFilename,
            '\n',
            'namespace morfeusz {\n',
            '\n',
            'extern const unsigned char %s[] = {' % arrayName,
            '\n',
        ]
        # every byte is followed by a comma (including the last one)
        res.extend(hex(byte) + ',' for byte in self.fsa2bytearray(isGenerator))
        res.extend(['\n', '};', '\n', '}', '\n'])
        with open(fname, 'w') as f:
            f.write(''.join(res))

    def serialize2BinaryFile(self, fname, isGenerator):
        """Write the raw serialized automaton bytes into ``fname``."""
        with open(fname, 'wb') as f:
            f.write(self.fsa2bytearray(isGenerator))

    def fsa2bytearray(self, isGenerator):
        """Return the complete serialized automaton (prologue, FSA section, epilogue).

        ``isGenerator`` is accepted for interface compatibility; it is not
        used by this method.
        """
        tagsetData = self.serializeTagset(self.tagset)
        qualifiersData = self.serializeQualifiersMap()
        segmentationRulesData = self.segmentationRulesData

        res = bytearray()
        res.extend(self.serializePrologue())

        # The FSA prologue must be produced before offsets are computed:
        # subclasses may initialise state needed by getStateSize() there.
        fsaData = bytearray()
        fsaData.extend(self.serializeFSAPrologue())
        self.fsa.calculateOffsets(sizeCounter=self.getStateSize)
        for state in sorted(self.fsa.dfs(), key=lambda s: s.offset):
            fsaData.extend(self.state2bytearray(state))

        res.extend(htonl(len(fsaData)))
        res.extend(fsaData)
        res.extend(self.serializeEpilogue(tagsetData, qualifiersData, segmentationRulesData))
        return res

    def _serializeTags(self, tagsMap):
        """Serialize a {tag: tagnum} map as a count followed by (tagnum, tag, NUL) records.

        Records are ordered by tagnum; tags are encoded with ``fsa.encodeWord``.
        """
        res = bytearray()
        res.extend(htons(len(tagsMap)))
        for tag, tagnum in sorted(tagsMap.items(), key=lambda item: item[1]):
            res.extend(htons(tagnum))
            res.extend(self.fsa.encodeWord(tag))
            res.append(0)
        return res

    def serializeTagset(self, tagset):
        """Serialize the tag table and the name table of ``tagset``.

        Returns an empty bytearray when ``tagset`` is falsy.
        """
        res = bytearray()
        if tagset:
            res.extend(self._serializeTags(tagset._tag2tagnum))
            res.extend(self._serializeTags(tagset._name2namenum))
        return res

    def serializeQualifiersMap(self):
        """Serialize ``self.qualifiersMap`` ({qualifiers: number}) ordered by number.

        Each entry is a one-byte element count followed by the UTF-8 encoded,
        NUL-terminated qualifiers.
        """
        res = bytearray()
        res.extend(htons(len(self.qualifiersMap)))
        for qualifiers, _ in sorted(self.qualifiersMap.items(), key=lambda item: item[1]):
            # the element count is stored in a single byte
            exceptions.validate(
                len(qualifiers) < 256,
                u'Cannot build the automaton - a qualifiers set has %d elements, but at most 255 are supported'
                % len(qualifiers))
            res.append(len(qualifiers))
            for q in qualifiers:
                res.extend(q.encode('utf8'))
                res.append(0)
        return res

    def serializePrologue(self):
        """Return MAGIC_NUMBER (big-endian), the version byte and the implementation code byte."""
        res = bytearray()
        for shift in (24, 16, 8, 0):
            res.append((Serializer.MAGIC_NUMBER >> shift) & 0xFF)
        res.append(self.getVersion())
        res.append(self.getImplementationCode())
        return res

    def serializeEpilogue(self, tagsetData, qualifiersData, segmentationRulesData):
        """Return the size header followed by the tagset, qualifiers and segmentation data.

        NOTE: the size header counts only tagset and qualifiers data; the
        segmentation rules data is appended after them but not included in it.
        """
        tagsetDataSize = len(tagsetData) if tagsetData else 0
        qualifiersDataSize = len(qualifiersData) if qualifiersData else 0
        res = bytearray()
        res.extend(htonl(tagsetDataSize + qualifiersDataSize))
        for data in (tagsetData, qualifiersData, segmentationRulesData):
            if data:
                assert isinstance(data, bytearray)
                res.extend(data)
        return res

    def state2bytearray(self, state):
        """Serialize one state: its own data followed by its transitions."""
        res = bytearray()
        res.extend(self.stateData2bytearray(state))
        res.extend(self.transitionsData2bytearray(state))
        return res

    def getSortedTransitions(self, state):
        """Return ``state``'s (label, nextState) pairs in serialization order.

        Pairs are ordered by descending ``state.label2Freq`` and then by
        descending ``fsa.label2Freq``; labels missing from a map count as 0.
        Remaining ties keep the iteration order of ``state.transitionsMap``.
        """
        fsaLabel2Freq = self.fsa.label2Freq

        def sortKey(transition):
            label = transition[0]
            return (-state.label2Freq.get(label, 0), -fsaLabel2Freq.get(label, 0))

        return sorted(state.transitionsMap.items(), key=sortKey)

    # --- format-specific hooks -------------------------------------------

    def getStateSize(self, state):
        raise NotImplementedError('Not implemented')

    def serializeFSAPrologue(self):
        raise NotImplementedError('Not implemented')

    def stateData2bytearray(self, state):
        raise NotImplementedError('Not implemented')

    def transitionsData2bytearray(self, state):
        raise NotImplementedError('Not implemented')

    def getImplementationCode(self):
        raise NotImplementedError('Not implemented')
class SimpleSerializer(Serializer):
    """Fixed-width serializer (implementation code 0, or 128 with transition data).

    Each state is one header byte (ACCEPTING_FLAG | transitionsNum), then its
    encoded data when accepting, then one fixed-width record per transition:
    the label byte, a 3-byte big-endian target offset and, when
    ``serializeTransitionsData`` is set, one transition-data byte.
    """

    def __init__(self, fsa, serializeTransitionsData=False):
        super(SimpleSerializer, self).__init__(fsa)
        # high bit of the state header byte marks accepting states
        self.ACCEPTING_FLAG = 128
        self.serializeTransitionsData = serializeTransitionsData

    def getImplementationCode(self):
        return 128 if self.serializeTransitionsData else 0

    def serializeFSAPrologue(self):
        return bytearray()

    def getStateSize(self, state):
        """Return the serialized size of ``state``: header + transitions + data."""
        transitionSize = 5 if self.serializeTransitionsData else 4
        return 1 + transitionSize * len(state.transitionsMap) + self.getDataSize(state)

    def getDataSize(self, state):
        return len(state.encodedData) if state.isAccepting() else 0

    def stateData2bytearray(self, state):
        """Return the header byte followed by the encoded data of accepting states."""
        # The transition count shares the header byte with ACCEPTING_FLAG, so it
        # must stay below it; otherwise the accepting bit would be corrupted.
        exceptions.validate(
            state.transitionsNum < self.ACCEPTING_FLAG,
            u'Cannot build the automaton - a state has %d transitions, but the SIMPLE format allows at most %d'
            % (state.transitionsNum, self.ACCEPTING_FLAG - 1))
        firstByte = state.transitionsNum
        if state.isAccepting():
            firstByte |= self.ACCEPTING_FLAG
        assert 0 < firstByte < 256
        res = bytearray([firstByte])
        if state.isAccepting():
            res.extend(state.encodedData)
        return res

    def transitionsData2bytearray(self, state):
        """Return the fixed-width transition records of ``state``."""
        maxOffset = 256 * 256 * 256
        res = bytearray()
        for label, nextState in self.getSortedTransitions(state):
            offset = nextState.offset
            # offsets are stored in 3 bytes; larger values would be truncated
            exceptions.validate(
                offset < maxOffset,
                u'Cannot build the automaton - it would exceed its max size which is %d' % maxOffset)
            res.append(label)
            res.append((offset & 0xFF0000) >> 16)
            res.append((offset & 0x00FF00) >> 8)
            res.append(offset & 0x0000FF)
            if self.serializeTransitionsData:
                transitionData = state.transitionsDataMap[label]
                assert 0 <= transitionData < 256
                res.append(transitionData)
        return res
class VLengthSerializer1(Serializer):
    """Variable-length serializer, implementation code 1.

    The FSA prologue is a 256-byte table mapping every label to a short label
    (1..63 for the 63 most frequent labels, 0 otherwise), followed by '^'.
    Each state is a header byte (ACCEPTING_FLAG | transition count, where 127
    means "the count follows in the next byte"), the encoded data of accepting
    states, and its transitions. Each transition is a byte holding
    ``shortLabel << 2 | offsetSize``, an explicit label byte when the label has
    no short label, and an ``offsetSize``-byte big-endian relative offset
    computed from the states' ``reverseOffset`` values.
    """

    def __init__(self, fsa, useArrays=False):
        super(VLengthSerializer1, self).__init__(fsa)
        # States in reversed fsa.dfs() order; _transitions2ListBytes treats the
        # entry at index + 1 as the state following the current one.
        self.statesTable = list(reversed(list(fsa.dfs())))
        self.state2Index = {state: idx for idx, state in enumerate(self.statesTable)}
        self._chooseArrayStates()
        self.useArrays = useArrays
        # Filled in by serializeFSAPrologue(), which fsa2bytearray() calls
        # before any state is serialized.
        self.label2ShortLabel = None
        self.ACCEPTING_FLAG = 0b10000000

    def getImplementationCode(self):
        return 1

    def serializeFSAPrologue(self):
        """Assign short labels and return the 256-byte short-label table plus '^'."""
        # labels sorted by popularity (descending frequency, then label value)
        sortedLabels = [label for label, _ in sorted(
            self.fsa.label2Freq.items(), key=lambda item: (-item[1], item[0]))]
        # the 63 most popular labels get short labels 1..63; 0 means "none"
        self.label2ShortLabel = {label: idx + 1 for idx, label in enumerate(sortedLabels[:63])}
        logging.debug(dict([(chr(label), shortLabel) for label, shortLabel in self.label2ShortLabel.items()]))
        res = bytearray(self.label2ShortLabel.get(label, 0) for label in range(256))
        # magic char before the initial state
        res.append(ord('^'))
        return res

    def getStateSize(self, state):
        return len(self.state2bytearray(state))

    def getDataSize(self, state):
        assert type(state.encodedData) == bytearray or not state.isAccepting()
        return len(state.encodedData) if state.isAccepting() else 0

    def stateShouldBeAnArray(self, state):
        return self.useArrays and state.serializeAsArray

    def stateData2bytearray(self, state):
        """Return the header byte(s) and the encoded data of accepting states."""
        firstByte = self.ACCEPTING_FLAG if state.isAccepting() else 0
        res = bytearray()
        if state.transitionsNum < 127:
            res.append(firstByte | state.transitionsNum)
        else:
            # 127 is an escape value: the real count is stored in the next byte
            res.append(firstByte | 127)
            res.append(state.transitionsNum)
        if state.isAccepting():
            res.extend(state.encodedData)
        return res

    def _transitions2ListBytes(self, state, originalState=None):
        """Encode ``state``'s transitions as a variable-length list.

        ``originalState`` (when given) is used instead of ``state`` to find the
        position in statesTable; _transitions2ArrayBytes passes a trimmed copy.
        """
        res = bytearray()
        transitions = self.getSortedTransitions(state)
        thisIdx = self.state2Index[originalState if originalState is not None else state]
        logging.debug('state %s', state.offset)
        if not transitions:
            assert state.isAccepting()
            return res

        stateAfterThis = self.statesTable[thisIdx + 1]
        # Transitions are encoded last-to-first and prepended, so len(res) is
        # the number of bytes already emitted after the current transition.
        for reversedN, (label, nextState) in enumerate(reversed(transitions)):
            assert nextState.reverseOffset is not None
            assert stateAfterThis.reverseOffset is not None
            logging.debug('next state reverse: %s', nextState.reverseOffset)
            logging.debug('after state reverse: %s', stateAfterThis.reverseOffset)

            hasShortLabel = label in self.label2ShortLabel
            firstByte = (self.label2ShortLabel[label] if hasShortLabel else 0) << 2
            last = reversedN == 0
            # only the last transition may point at stateAfterThis with a zero offset
            isNext = last and stateAfterThis == nextState
            offset = (stateAfterThis.reverseOffset - nextState.reverseOffset) + len(res)
            assert offset > 0 or isNext
            offsetSize = 0
            if offset > 0:
                offsetSize += 1
            if offset >= 256:
                offsetSize += 1
            if offset >= 256 * 256:
                offsetSize += 1
            exceptions.validate(
                offset < 256 * 256 * 256,
                u'Cannot build the automaton - it would exceed its max size which is %d' % (256 * 256 * 256))
            assert offsetSize <= 3
            firstByte |= offsetSize

            transitionBytes = bytearray([firstByte])
            if not hasShortLabel:
                transitionBytes.append(label)
            # offset in big-endian order, using only offsetSize bytes
            if offsetSize == 3:
                transitionBytes.append((offset & 0xFF0000) >> 16)
            if offsetSize >= 2:
                transitionBytes.append((offset & 0x00FF00) >> 8)
            if offsetSize >= 1:
                transitionBytes.append(offset & 0x0000FF)
            res[0:0] = transitionBytes
            logging.debug('inserted transition at beginning %s -> %s', chr(label), offset)
        return res

    def _trimState(self, state):
        """Return a new State copying ``state``'s data, offsets and transitions map."""
        newState = State()
        newState.encodedData = state.encodedData
        newState.reverseOffset = state.reverseOffset
        newState.offset = state.offset
        newState.transitionsMap = dict(state.transitionsMap)
        newState.serializeAsArray = False
        return newState

    def _transitions2ArrayBytes(self, state):
        """Encode ``state`` as a 64-entry offset array followed by a list encoding.

        Array slot ``shortLabel`` holds the absolute offset of the target of
        that short label (0 when absent), written as a zero byte + 3 bytes.
        """
        res = bytearray()
        array = [0] * 64
        for label, nextState in state.transitionsMap.items():
            if label in self.label2ShortLabel:
                array[self.label2ShortLabel[label]] = nextState.offset
        logging.debug(array)
        for offset in array:
            offset = offset or 0
            res.append(0)
            res.append((offset & 0xFF0000) >> 16)
            res.append((offset & 0x00FF00) >> 8)
            res.append(offset & 0x0000FF)
        trimmed = self._trimState(state)
        # previously called the nonexistent _stateData2bytearray (AttributeError)
        res.extend(self.stateData2bytearray(trimmed))
        res.extend(self._transitions2ListBytes(trimmed, originalState=state))
        return res

    def transitionsData2bytearray(self, state):
        if self.stateShouldBeAnArray(state):
            return self._transitions2ArrayBytes(state)
        return self._transitions2ListBytes(state)

    def _chooseArrayStates(self):
        """Mark the initial state and the states up to two transitions from it as array candidates."""
        initialState = self.fsa.initialState
        for state1 in initialState.transitionsMap.values():
            for state2 in state1.transitionsMap.values():
                state2.serializeAsArray = True
            state1.serializeAsArray = True
        initialState.serializeAsArray = True
class VLengthSerializer2(Serializer):
    """Variable-length serializer, implementation code 2.

    A state is its encoded data (accepting states only) followed by its
    transitions. Each transition is the raw label byte, a flags byte
    (HAS_REMAINING_FLAG, ACCEPTING_FLAG of the target state, LAST_FLAG) whose
    low 5 bits hold the high-order part of the relative offset, and then the
    remaining offset as 7-bit groups. A state without transitions is written as
    a single LAST_FLAG byte.
    """

    def __init__(self, fsa, useArrays=False):
        # useArrays is accepted for interface compatibility and not used here
        super(VLengthSerializer2, self).__init__(fsa)
        # States in reversed fsa.dfs() order; the entry at index + 1 is treated
        # as the state following the current one.
        self.statesTable = list(reversed(list(fsa.dfs())))
        self.state2Index = {state: idx for idx, state in enumerate(self.statesTable)}
        # flag bits of a transition's second byte (low 5 bits carry offset bits)
        self.HAS_REMAINING_FLAG = 128
        self.ACCEPTING_FLAG = 64
        self.LAST_FLAG = 32

    def serializeFSAPrologue(self):
        return bytearray()

    def getImplementationCode(self):
        return 2

    def getStateSize(self, state):
        return len(self.state2bytearray(state))

    def stateData2bytearray(self, state):
        """Return the encoded data of accepting states (empty otherwise)."""
        res = bytearray()
        if state.isAccepting():
            res.extend(state.encodedData)
        return res

    def _first5Bits(self, number):
        """Return what is left of ``number`` after dropping 7-bit groups until it is < 32."""
        n = number
        while n >= 32:
            n >>= 7
        return n

    def _getOffsetBytes(self, offset):
        """Return the 7-bit groups of ``offset`` not covered by _first5Bits, most significant first.

        Every group except the least significant one carries HAS_REMAINING_FLAG.
        """
        res = bytearray()
        remaining = offset
        isLowestGroup = True
        while remaining >= 32:
            nextByte = remaining & 0b01111111
            remaining >>= 7
            if isLowestGroup:
                isLowestGroup = False
            else:
                nextByte |= self.HAS_REMAINING_FLAG
            res.insert(0, nextByte)
            logging.debug(remaining)
        return res

    def _transitions2ListBytes(self, state):
        """Encode ``state``'s transitions (or a lone LAST_FLAG byte when it has none)."""
        res = bytearray()
        thisIdx = self.state2Index[state]
        transitions = self.getSortedTransitions(state)
        if not transitions:
            assert state.isAccepting()
            res.append(self.LAST_FLAG)
            return res

        stateAfterThis = self.statesTable[thisIdx + 1]
        # Transitions are encoded last-to-first and prepended, so len(res) is
        # the number of bytes already emitted after the current transition.
        for reversedN, (label, nextState) in enumerate(reversed(transitions)):
            assert nextState.reverseOffset is not None
            assert stateAfterThis.reverseOffset is not None
            logging.debug('next state reverse: %s', nextState.reverseOffset)
            logging.debug('after state reverse: %s', stateAfterThis.reverseOffset)

            secondByte = 0
            if reversedN == 0:
                secondByte |= self.LAST_FLAG
            if nextState.isAccepting():
                secondByte |= self.ACCEPTING_FLAG
            offset = (stateAfterThis.reverseOffset - nextState.reverseOffset) + len(res)
            if offset >= 32:
                secondByte |= self.HAS_REMAINING_FLAG
                secondByte |= self._first5Bits(offset)
            else:
                secondByte |= offset
            transitionBytes = bytearray([label, secondByte])
            transitionBytes.extend(self._getOffsetBytes(offset))
            res[0:0] = transitionBytes
            logging.debug('inserted transition at beginning %s -> %s', chr(label), offset)
        return res

    def transitionsData2bytearray(self, state):
        return self._transitions2ListBytes(state)