'''
Created on Oct 20, 2013

@author: mlenart
'''

import logging
from state import State
from morfeuszbuilder.utils.serializationUtils import *


class Serializer(object):
    """Base class turning an FSA into a flat byte image.

    The image produced by :meth:`fsa2bytearray` is laid out as: prologue
    (magic number, version, implementation code), the length of the FSA
    section, the FSA section itself (implementation-specific prologue followed
    by the states ordered by offset), and finally the epilogue (tagset and
    segmentation-rules data).

    Subclasses provide the state encoding by overriding ``getStateSize``,
    ``serializeFSAPrologue``, ``stateData2bytearray``,
    ``transitionsData2bytearray`` and ``getImplementationCode``.
    """

    MAGIC_NUMBER = 0x8fc2bc1b

    def __init__(self, fsa, headerFilename="data/default_fsa.hpp"):
        """Remember the automaton and the header named in generated C++ files."""
        self._fsa = fsa
        self.headerFilename = headerFilename

    @property
    def fsa(self):
        """The automaton being serialized."""
        return self._fsa

    def getVersion(self):
        """Return the Morfeusz file format version that is being encoded."""
        return 10

    def serialize2CppFile(self, fname, generator, segmentationRulesData):
        """Write the serialized automaton to ``fname`` as a C++ byte array.

        The array is named ``DEFAULT_SYNTH_FSA`` when ``generator`` is truthy
        and ``DEFAULT_FSA`` otherwise.
        """
        if generator:
            arrayDecl = 'extern const unsigned char DEFAULT_SYNTH_FSA[] = {'
        else:
            arrayDecl = 'extern const unsigned char DEFAULT_FSA[] = {'
        res = [
            '\n',
            '#include "%s"' % self.headerFilename,
            '\n',
            '\n',
            arrayDecl,
            '\n',
        ]
        data = self.fsa2bytearray(
            tagsetData=self.serializeTagset(self.fsa.tagset),
            segmentationRulesData=segmentationRulesData)
        for byte in data:
            res.append(hex(byte))
            res.append(',')
        res.append('\n')
        res.append('};')
        res.append('\n')
        with open(fname, 'w') as f:
            f.write(''.join(res))

    def serialize2BinaryFile(self, fname, segmentationRulesData):
        """Write the serialized automaton to ``fname`` as raw bytes."""
        with open(fname, 'wb') as f:
            f.write(self.fsa2bytearray(
                tagsetData=self.serializeTagset(self.fsa.tagset),
                segmentationRulesData=segmentationRulesData))

    def getStateSize(self, state):
        """Return the number of bytes ``state`` occupies (subclass hook)."""
        raise NotImplementedError('Not implemented')

    def serializeFSAPrologue(self):
        """Return the implementation-specific bytes preceding the states.

        Called by :meth:`fsa2bytearray`; every concrete serializer must
        override it.
        """
        raise NotImplementedError('Not implemented')

    def fsa2bytearray(self, tagsetData, segmentationRulesData):
        """Build and return the complete byte image of the automaton."""
        res = bytearray()
        res.extend(self.serializePrologue())
        fsaData = bytearray()
        # The FSA prologue must be produced before offsets are calculated:
        # some serializers set up state in it that their size computation needs.
        fsaData.extend(self.serializeFSAPrologue())
        self.fsa.calculateOffsets(sizeCounter=lambda state: self.getStateSize(state))
        for state in sorted(self.fsa.dfs(), key=lambda s: s.offset):
            fsaData.extend(self.state2bytearray(state))
        res.extend(htonl(len(fsaData)))
        res.extend(fsaData)
        res.extend(self.serializeEpilogue(tagsetData, segmentationRulesData))
        return res

    def _serializeTags(self, tagsMap):
        """Serialize a ``name -> number`` map.

        Output: the entry count (via ``htons``), then for each entry in
        ascending number order its number (via ``htons``), the name encoded
        with ``fsa.encodeWord`` and a terminating zero byte.
        """
        res = bytearray()
        res.extend(htons(len(tagsMap)))
        # item[1] is the number; an index-based key works on Python 2 and 3.
        for tag, tagnum in sorted(tagsMap.items(), key=lambda item: item[1]):
            res.extend(htons(tagnum))
            res.extend(self.fsa.encodeWord(tag))
            res.append(0)
        return res

    def serializeTagset(self, tagset):
        """Serialize tagset data; returns an empty bytearray for a falsy tagset."""
        res = bytearray()
        if tagset:
            res.extend(self._serializeTags(tagset._tag2tagnum))
            res.extend(self._serializeTags(tagset._name2namenum))
        return res

    def serializePrologue(self):
        """Return magic number (big-endian), version and implementation code."""
        res = bytearray()
        # serialize magic number in big-endian order
        for shift in (24, 16, 8, 0):
            res.append((Serializer.MAGIC_NUMBER >> shift) & 0xFF)
        # serialize version number
        res.append(self.getVersion())
        # serialize implementation code
        res.append(self.getImplementationCode())
        return res

    def serializeEpilogue(self, tagsetData, segmentationRulesData):
        """Return the trailing section of the image.

        Writes the tagset data size (via ``htonl``), then the tagset data and
        the segmentation rules data, each only when non-empty. No size prefix
        for the segmentation data is written here.
        """
        res = bytearray()
        tagsetDataSize = len(tagsetData) if tagsetData else 0
        segmentationDataSize = len(segmentationRulesData) if segmentationRulesData else 0
        res.extend(htonl(tagsetDataSize))
        # add additional data itself
        if tagsetDataSize:
            assert isinstance(tagsetData, bytearray)
            res.extend(tagsetData)
        if segmentationDataSize:
            assert isinstance(segmentationRulesData, bytearray)
            res.extend(segmentationRulesData)
        return res

    def state2bytearray(self, state):
        """Return the bytes of ``state``: its data followed by its transitions."""
        res = bytearray()
        res.extend(self.stateData2bytearray(state))
        res.extend(self.transitionsData2bytearray(state))
        return res

    def getSortedTransitions(self, state):
        """Return ``(label, nextState)`` pairs of ``state`` as a sorted list.

        Pairs are ordered by descending per-state label frequency, then by
        descending global label frequency (missing frequencies count as 0).
        """
        def defaultKey(transition):
            label = transition[0]
            return (-state.label2Freq.get(label, 0),
                    -self.fsa.label2Freq.get(label, 0))
        return sorted(state.transitionsMap.items(), key=defaultKey)

    def stateData2bytearray(self, state):
        """Return the bytes describing the state itself (subclass hook)."""
        raise NotImplementedError('Not implemented')

    def transitionsData2bytearray(self, state):
        """Return the bytes describing the state's transitions (subclass hook)."""
        raise NotImplementedError('Not implemented')

    def getImplementationCode(self):
        """Return the one-byte code identifying the encoding (subclass hook)."""
        raise NotImplementedError('Not implemented')


class SimpleSerializer(Serializer):
    """Fixed-width encoding: a header byte per state and 4 (or 5) bytes per
    transition (label, 3-byte big-endian target offset, optional data byte)."""

    def __init__(self, fsa, serializeTransitionsData=False):
        super(SimpleSerializer, self).__init__(fsa)
        self.ACCEPTING_FLAG = 128
        self.serializeTransitionsData = serializeTransitionsData

    def getImplementationCode(self):
        """Return 128 when transition data is serialized, 0 otherwise."""
        return 0 if not self.serializeTransitionsData else 128

    def serializeFSAPrologue(self):
        """This encoding has no FSA prologue."""
        return bytearray()

    def getStateSize(self, state):
        """Return 1 header byte + bytes per transition + the state's data size."""
        bytesPerTransition = 5 if self.serializeTransitionsData else 4
        return 1 + bytesPerTransition * len(state.transitionsMap) + self.getDataSize(state)

    def getDataSize(self, state):
        """Return the length of the encoded data of an accepting state, else 0."""
        return len(state.encodedData) if state.isAccepting() else 0

    def stateData2bytearray(self, state):
        """Return the header byte (accepting flag | transition count) followed
        by the encoded data of an accepting state."""
        res = bytearray()
        firstByte = 0
        if state.isAccepting():
            firstByte |= self.ACCEPTING_FLAG
        firstByte |= state.transitionsNum
        assert 0 < firstByte < 256
        res.append(firstByte)
        if state.isAccepting():
            res.extend(state.encodedData)
        return res

    def transitionsData2bytearray(self, state):
        """Return label + 3-byte big-endian offset (+ data byte) per transition."""
        res = bytearray()
        for label, nextState in self.getSortedTransitions(state):
            res.append(label)
            offset = nextState.offset
            res.append((offset & 0xFF0000) >> 16)
            res.append((offset & 0x00FF00) >> 8)
            res.append(offset & 0x0000FF)
            if self.serializeTransitionsData:
                transitionData = state.transitionsDataMap[label]
                assert 0 <= transitionData < 256
                res.append(transitionData)
        return res


class VLengthSerializer1(Serializer):
    """Variable-length encoding with short labels for the most frequent
    transition labels and, optionally, array-encoded states."""

    def __init__(self, fsa, useArrays=False):
        super(VLengthSerializer1, self).__init__(fsa)
        self.statesTable = list(reversed(list(fsa.dfs())))
        self.state2Index = dict(
            (state, idx) for idx, state in enumerate(self.statesTable))
        self._chooseArrayStates()
        self.useArrays = useArrays
        # Filled in by serializeFSAPrologue().
        self.label2ShortLabel = None

        self.ACCEPTING_FLAG = 0b10000000
        self.ARRAY_FLAG = 0b01000000

    def getImplementationCode(self):
        return 1

    def serializeFSAPrologue(self):
        """Build the short-label table and return it as the FSA prologue.

        The prologue is 256 bytes (for every possible label its short label,
        or 0 when it has none) followed by the character ``'^'``.
        """
        res = bytearray()
        # labels sorted by popularity (most frequent first, ties by label)
        sortedLabels = [
            label for label, _freq in sorted(
                self.fsa.label2Freq.items(),
                key=lambda item: (-item[1], item[0]))]
        # popular labels table: the 63 most frequent labels get numbers 1..63
        self.label2ShortLabel = dict(
            (label, idx + 1) for idx, label in enumerate(sortedLabels[:63]))

        logging.debug(dict(
            (chr(label), shortLabel)
            for label, shortLabel in self.label2ShortLabel.items()))

        # write the table; labels without a short label get zero
        for label in range(256):
            res.append(self.label2ShortLabel.get(label, 0))

        # write a magic char before initial state
        res.append(ord('^'))

        return res

    def getStateSize(self, state):
        """Return the size of ``state`` by actually serializing it."""
        return len(self.state2bytearray(state))

    def getDataSize(self, state):
        """Return the length of the encoded data of an accepting state, else 0."""
        assert isinstance(state.encodedData, bytearray) or not state.isAccepting()
        return len(state.encodedData) if state.isAccepting() else 0

    def stateShouldBeAnArray(self, state):
        """Tell whether ``state`` is to be encoded as an array."""
        return self.useArrays and state.serializeAsArray

    def stateData2bytearray(self, state):
        """Return the header byte (accepting flag | array flag | transition
        count) followed by the encoded data of an accepting state."""
        # The two top bits of the header are flags, leaving 6 bits for the count.
        assert state.transitionsNum < 64
        res = bytearray()
        firstByte = 0
        if state.isAccepting():
            firstByte |= self.ACCEPTING_FLAG
        if self.stateShouldBeAnArray(state):
            firstByte |= self.ARRAY_FLAG
        firstByte |= state.transitionsNum
        assert 0 < firstByte < 256
        res.append(firstByte)
        if state.isAccepting():
            res.extend(state.encodedData)
        return res

    def _transitions2ListBytes(self, state, originalState=None):
        """Encode the transitions of ``state`` as a list.

        Every transition is: a first byte holding the short label (or 0)
        shifted left by 2 and the offset size (0-3) in the two low bits; the
        full label byte when there is no short label; then the offset in
        big-endian order on ``offsetSize`` bytes. Transitions are generated
        from the last one backwards, because each offset includes the length
        of the bytes that follow it.

        ``originalState``, when given, is the state used to look up the
        position in ``statesTable`` (``state`` may be a trimmed copy).
        """
        res = bytearray()
        transitions = self.getSortedTransitions(state)
        thisIdx = self.state2Index[originalState if originalState is not None else state]
        logging.debug('state %s', state.offset)
        if len(transitions) == 0:
            assert state.isAccepting()
            return bytearray()

        stateAfterThis = self.statesTable[thisIdx + 1]
        numTransitions = len(transitions)
        for reversedN, (label, nextState) in enumerate(reversed(transitions)):
            transitionBytes = bytearray()
            assert nextState.reverseOffset is not None
            assert stateAfterThis.reverseOffset is not None
            logging.debug('next state reverse: %s', nextState.reverseOffset)
            logging.debug('after state reverse: %s', stateAfterThis.reverseOffset)

            n = numTransitions - reversedN
            hasShortLabel = label in self.label2ShortLabel
            firstByte = self.label2ShortLabel[label] if hasShortLabel else 0
            firstByte <<= 2

            last = numTransitions == n
            isNext = last and stateAfterThis == nextState
            offsetSize = 0
            offset = (stateAfterThis.reverseOffset - nextState.reverseOffset) + len(res)
            assert offset > 0 or isNext
            if offset > 0:
                offsetSize += 1
            if offset >= 256:
                offsetSize += 1
            if offset >= 256 * 256:
                offsetSize += 1
            # TODO: turn this into a proper exception
            assert offset < 256 * 256 * 256
            assert offsetSize <= 3
            assert offsetSize > 0 or isNext
            firstByte |= offsetSize

            transitionBytes.append(firstByte)
            if not hasShortLabel:
                transitionBytes.append(label)
            # serialize offset in big-endian order
            if offsetSize == 3:
                transitionBytes.append((offset & 0xFF0000) >> 16)
            if offsetSize >= 2:
                transitionBytes.append((offset & 0x00FF00) >> 8)
            if offsetSize >= 1:
                transitionBytes.append(offset & 0x0000FF)
            # Prepend the whole transition in one step.
            res[0:0] = transitionBytes
            logging.debug('inserted transition at beginning %s -> %s', chr(label), offset)
        return res

    def _trimState(self, state):
        """Return a copy of ``state`` that is not marked for array encoding."""
        newState = State()
        newState.encodedData = state.encodedData
        newState.reverseOffset = state.reverseOffset
        newState.offset = state.offset
        newState.transitionsMap = dict(state.transitionsMap.items())
        newState.serializeAsArray = False
        return newState

    def _transitions2ArrayBytes(self, state):
        """Encode the transitions of ``state`` as an array.

        Output: 64 four-byte slots indexed by short label (a zero byte and
        the 3-byte big-endian target offset, zero when unused), followed by
        the list encoding of a trimmed copy of the state.
        """
        res = bytearray()
        array = [0] * 64
        for label, nextState in state.transitionsMap.items():
            if label in self.label2ShortLabel:
                shortLabel = self.label2ShortLabel[label]
                array[shortLabel] = nextState.offset
        logging.debug(array)
        for slot in array:
            offset = slot or 0
            res.append(0)
            res.append((offset & 0xFF0000) >> 16)
            res.append((offset & 0x00FF00) >> 8)
            res.append(offset & 0x0000FF)
        trimmed = self._trimState(state)
        # This used to call the non-existent ``_stateData2bytearray`` and
        # therefore raised AttributeError for every array state.
        res.extend(self.stateData2bytearray(trimmed))
        res.extend(self._transitions2ListBytes(trimmed, originalState=state))
        return res

    def transitionsData2bytearray(self, state):
        """Encode transitions as an array or as a list, depending on the state."""
        if self.stateShouldBeAnArray(state):
            return self._transitions2ArrayBytes(state)
        return self._transitions2ListBytes(state)

    def _chooseArrayStates(self):
        """Mark the initial state and the states reachable from it in one or
        two transitions with ``serializeAsArray = True``."""
        for state1 in self.fsa.initialState.transitionsMap.values():
            for state2 in state1.transitionsMap.values():
                state2.serializeAsArray = True
            state1.serializeAsArray = True
        self.fsa.initialState.serializeAsArray = True


class VLengthSerializer2(Serializer):
    """Variable-length encoding in which each transition carries the flags
    (last / accepting / has-remaining) and a variable-length offset."""

    def __init__(self, fsa, useArrays=False):
        # ``useArrays`` is accepted for signature parity with
        # VLengthSerializer1 but is not used by this encoding.
        super(VLengthSerializer2, self).__init__(fsa)
        self.statesTable = list(reversed(list(fsa.dfs())))
        self.state2Index = dict(
            (state, idx) for idx, state in enumerate(self.statesTable))

        self.HAS_REMAINING_FLAG = 128
        self.ACCEPTING_FLAG = 64
        self.LAST_FLAG = 32

    def serializeFSAPrologue(self):
        """This encoding has no FSA prologue."""
        return bytearray()

    def getImplementationCode(self):
        return 2

    def getStateSize(self, state):
        """Return the size of ``state`` by actually serializing it."""
        return len(self.state2bytearray(state))

    def stateData2bytearray(self, state):
        """Return the encoded data of an accepting state, else nothing."""
        res = bytearray()
        if state.isAccepting():
            res.extend(state.encodedData)
        return res

    def _first5Bits(self, number):
        """Shift ``number`` right by 7 bits at a time until it is below 32."""
        n = number
        while n >= 32:
            n >>= 7
        return n

    def _getOffsetBytes(self, offset):
        """Return the continuation bytes of ``offset``.

        7-bit groups are taken from the least significant end for as long as
        the remaining value is at least 32, and emitted most significant group
        first. Every group except the final (least significant) one carries
        ``HAS_REMAINING_FLAG``. The bits left over are the ones returned by
        ``_first5Bits``.
        """
        res = bytearray()
        remaining = offset
        lastByte = True
        while remaining >= 32:
            nextByte = remaining & 0b01111111
            remaining >>= 7
            if not lastByte:
                nextByte |= self.HAS_REMAINING_FLAG
            else:
                lastByte = False
            res.insert(0, nextByte)
            logging.debug(remaining)
        return res

    def _transitions2ListBytes(self, state):
        """Encode the transitions of ``state``.

        Every transition is: the label byte; a second byte with the LAST,
        ACCEPTING (target state accepts) and HAS_REMAINING flags plus five
        offset bits; then the continuation bytes from ``_getOffsetBytes``.
        A state without transitions is encoded as the single byte
        ``LAST_FLAG``. Transitions are generated from the last one backwards,
        because each offset includes the length of the bytes that follow it.
        """
        res = bytearray()
        thisIdx = self.state2Index[state]
        transitions = self.getSortedTransitions(state)
        if len(transitions) == 0:
            assert state.isAccepting()
            res.append(self.LAST_FLAG)
            return res

        stateAfterThis = self.statesTable[thisIdx + 1]
        numTransitions = len(transitions)
        for reversedN, (label, nextState) in enumerate(reversed(transitions)):
            transitionBytes = bytearray()
            assert nextState.reverseOffset is not None
            assert stateAfterThis.reverseOffset is not None
            logging.debug('next state reverse: %s', nextState.reverseOffset)
            logging.debug('after state reverse: %s', stateAfterThis.reverseOffset)

            n = numTransitions - reversedN
            firstByte = label

            last = numTransitions == n
            secondByte = 0
            if last:
                secondByte |= self.LAST_FLAG
            if nextState.isAccepting():
                secondByte |= self.ACCEPTING_FLAG
            offset = (stateAfterThis.reverseOffset - nextState.reverseOffset) + len(res)
            if offset >= 32:
                secondByte |= self.HAS_REMAINING_FLAG
                secondByte |= self._first5Bits(offset)
            else:
                secondByte |= offset

            transitionBytes.append(firstByte)
            transitionBytes.append(secondByte)
            transitionBytes.extend(self._getOffsetBytes(offset))
            # Prepend the whole transition in one step.
            res[0:0] = transitionBytes
            logging.debug('inserted transition at beginning %s -> %s', chr(label), offset)
        return res

    def transitionsData2bytearray(self, state):
        """Encode the transitions of ``state`` as a list."""
        return self._transitions2ListBytes(state)