Commit 6679f8ba37e476aff1df115fd306b74602f61caa
1 parent: c2243119
- działająca podstawowa wersja z tablicą w każdych dwóch pierwszych stanach
git-svn-id: svn://svn.nlp.ipipan.waw.pl/morfeusz/morfeusz@11 ff4e3ee1-f430-4e82-ade0-24591c43f1fd
Showing 12 changed files with 674 additions and 138 deletions
CMakeLists.txt
| ... | ... | @@ -9,6 +9,6 @@ add_subdirectory (morfeusz) |
| 9 | 9 | |
| 10 | 10 | file(COPY fsabuilder testfiles DESTINATION .) |
| 11 | 11 | |
| 12 | -add_test (TestBuildFSA python fsabuilder/fsa/buildfsa.py -i testfiles/dict.txt -o testfiles/test.fsa -t SPELL --input-format=PLAIN --output-format=BINARY) | |
| 12 | +add_test (TestBuildFSA python fsabuilder/fsa/buildfsa.py -i testfiles/dict.txt -o testfiles/test.fsa -t SPELL --input-format=PLAIN --output-format=BINARY --use-arrays) | |
| 13 | 13 | add_test (TestRecognize fsa/test_recognize testfiles/test.fsa testfiles/dict.txt) |
| 14 | 14 | add_test (TestNOTRecognize fsa/test_not_recognize testfiles/test.fsa testfiles/out_of_dict.txt) |
| ... | ... |
fsa/_fsa_impl.hpp
| ... | ... | @@ -115,7 +115,7 @@ bool FSA<T>::tryToRecognize(const char* input, T& value) const { |
| 115 | 115 | i++; |
| 116 | 116 | } |
| 117 | 117 | // input[i] == '\0' |
| 118 | - currState.proceedToNext(0); | |
| 118 | +// currState.proceedToNext(0); | |
| 119 | 119 | |
| 120 | 120 | if (currState.isAccepting()) { |
| 121 | 121 | value = currState.getValue(); |
| ... | ... |
fsa/_fsaimpl.hpp
0 → 100644
| 1 | +/* | |
| 2 | + * File: _vfsa_impl.hpp | |
| 3 | + * Author: lennyn | |
| 4 | + * | |
| 5 | + * Created on October 29, 2013, 9:57 PM | |
| 6 | + */ | |
| 7 | + | |
| 8 | +#ifndef _VFSA_IMPL_HPP | |
| 9 | +#define _VFSA_IMPL_HPP | |
| 10 | + | |
| 11 | +#include <algorithm> | |
| 12 | +#include <utility> | |
| 13 | +#include <iostream> | |
| 14 | +#include <netinet/in.h> | |
| 15 | +#include "fsa.hpp" | |
| 16 | + | |
| 17 | +using namespace std; | |
| 18 | + | |
| 19 | +#pragma pack(push) /* push current alignment to stack */ | |
| 20 | +#pragma pack(1) /* set alignment to 1 byte boundary */ | |
| 21 | + | |
| 22 | +struct StateData2 { | |
| 23 | + unsigned transitionsNum: 6; | |
| 24 | + unsigned array : 1; | |
| 25 | + unsigned accepting : 1; | |
| 26 | +}; | |
| 27 | + | |
| 28 | +struct TransitionData2 { | |
| 29 | + unsigned offsetSize : 2; | |
| 30 | + unsigned shortLabel : 6; | |
| 31 | +}; | |
| 32 | + | |
| 33 | + | |
| 34 | +#pragma pack(pop) /* restore original alignment from stack */ | |
| 35 | + | |
| 36 | +template <class T> | |
| 37 | +int FSAImpl<T>::getMagicNumberOffset() { | |
| 38 | + return 0; | |
| 39 | +} | |
| 40 | + | |
| 41 | +template <class T> | |
| 42 | +int FSAImpl<T>::getVersionNumOffset() { | |
| 43 | + return getMagicNumberOffset() + sizeof (MAGIC_NUMBER); | |
| 44 | +} | |
| 45 | + | |
| 46 | +template <class T> | |
| 47 | +int FSAImpl<T>::getPopularCharsOffset() { | |
| 48 | + return getVersionNumOffset() + sizeof (VERSION_NUM); | |
| 49 | +} | |
| 50 | + | |
| 51 | +template <class T> | |
| 52 | +int FSAImpl<T>::getInitialStateOffset() { | |
| 53 | + return getPopularCharsOffset() + 256 + 1; | |
| 54 | +} | |
| 55 | + | |
| 56 | +template <class T> | |
| 57 | +vector<unsigned char> FSAImpl<T>::initializeChar2PopularCharIdx(const unsigned char* ptr) { | |
| 58 | + return vector<unsigned char>(ptr + getPopularCharsOffset(), ptr + getPopularCharsOffset() + 256); | |
| 59 | +} | |
| 60 | + | |
| 61 | +template <class T> | |
| 62 | +FSAImpl<T>::FSAImpl(const unsigned char* ptr, const Deserializer<T>& deserializer) | |
| 63 | +: FSA<T>(ptr + getInitialStateOffset(), deserializer), | |
| 64 | +label2ShortLabel(initializeChar2PopularCharIdx(ptr)) { | |
| 65 | + uint32_t magicNumber = ntohl(*((uint32_t*) ptr + getMagicNumberOffset())); | |
| 66 | + if (magicNumber != MAGIC_NUMBER) { | |
| 67 | + throw FSAException("Invalid magic number"); | |
| 68 | + } | |
| 69 | + unsigned char versionNum = *(ptr + getVersionNumOffset()); | |
| 70 | + if (versionNum != VERSION_NUM) { | |
| 71 | + throw FSAException("Invalid version number"); | |
| 72 | + } | |
| 73 | + // cerr << "initial state offset " << getInitialStateOffset() << endl; | |
| 74 | +} | |
| 75 | + | |
| 76 | +template <class T> | |
| 77 | +FSAImpl<T>::~FSAImpl() { | |
| 78 | + | |
| 79 | +} | |
| 80 | + | |
| 81 | +template <class T> | |
| 82 | +void FSAImpl<T>::reallyDoProceed( | |
| 83 | + const unsigned char* statePtr, | |
| 84 | + State<T>& state) const { | |
| 85 | +// const unsigned char stateByte = *statePtr; | |
| 86 | + StateData2* sd = (StateData2*) statePtr; | |
| 87 | + if (sd->accepting) { | |
| 88 | +// cerr << "ACCEPTING" << endl; | |
| 89 | + T object; | |
| 90 | + int size = this->deserializer.deserialize(statePtr + 1, object); | |
| 91 | + state.setNext(statePtr - this->startPtr, object, size); | |
| 92 | + } | |
| 93 | + else { | |
| 94 | + state.setNext(statePtr - this->startPtr); | |
| 95 | + } | |
| 96 | +} | |
| 97 | + | |
| 98 | +template <class T> | |
| 99 | +void FSAImpl<T>::doProceedToNextByList( | |
| 100 | + const char c, | |
| 101 | + const unsigned char shortLabel, | |
| 102 | + const unsigned char* ptr, | |
| 103 | + const unsigned int transitionsNum, | |
| 104 | + State<T>& state) const { | |
| 105 | + register const unsigned char* currPtr = ptr; | |
| 106 | + // TransitionData* foundTransition = (TransitionData*) (fromPointer + transitionsTableOffset); | |
| 107 | + bool found = false; | |
| 108 | + TransitionData2 td; | |
| 109 | + for (unsigned int i = 0; i < transitionsNum; i++) { | |
| 110 | + // const_cast<Counter*>(&counter)->increment(1); | |
| 111 | + td = *((TransitionData2*) currPtr); | |
| 112 | + if (td.shortLabel == shortLabel) { | |
| 113 | + if (shortLabel == 0) { | |
| 114 | + currPtr++; | |
| 115 | + char label = (char) *currPtr; | |
| 116 | + if (label == c) { | |
| 117 | + found = true; | |
| 118 | + break; | |
| 119 | + } | |
| 120 | + else { | |
| 121 | + currPtr += td.offsetSize + 1; | |
| 122 | + } | |
| 123 | + } else { | |
| 124 | + found = true; | |
| 125 | + break; | |
| 126 | + } | |
| 127 | + } | |
| 128 | + else { | |
| 129 | + if (td.shortLabel == 0) { | |
| 130 | + currPtr++; | |
| 131 | + } | |
| 132 | + currPtr += td.offsetSize + 1; | |
| 133 | + } | |
| 134 | + } | |
| 135 | + // const_cast<Counter*>(&counter)->increment(foundTransition - transitionsStart + 1); | |
| 136 | + if (!found) { | |
| 137 | +// cerr << "SINK for " << c << endl; | |
| 138 | + state.setNextAsSink(); | |
| 139 | + } else { | |
| 140 | + currPtr++; | |
| 141 | +// cerr << "offset size " << td.offsetSize << endl; | |
| 142 | +// cerr << "offset " << offset << endl; | |
| 143 | + switch (td.offsetSize) { | |
| 144 | + case 0: | |
| 145 | + break; | |
| 146 | + case 1: | |
| 147 | + currPtr += *currPtr + 1; | |
| 148 | + break; | |
| 149 | + case 2: | |
| 150 | + currPtr += ntohs(*((uint16_t*) currPtr)) + 2; | |
| 151 | + break; | |
| 152 | + case 3: | |
| 153 | + currPtr += (((unsigned int) ntohs(*((uint16_t*) currPtr))) << 8) + currPtr[2] + 3; | |
| 154 | + break; | |
| 155 | + } | |
| 156 | +// cerr << "FOUND " << c << " " << currPtr - this->startPtr << endl; | |
| 157 | + reallyDoProceed(currPtr, state); | |
| 158 | + } | |
| 159 | +} | |
| 160 | + | |
| 161 | +template <class T> | |
| 162 | +void FSAImpl<T>::doProceedToNextByArray( | |
| 163 | + const unsigned char shortLabel, | |
| 164 | + const uint32_t* ptr, | |
| 165 | + State<T>& state) const { | |
| 166 | + uint32_t offset = ntohl(ptr[shortLabel]); | |
| 167 | + if (offset != 0) { | |
| 168 | + const unsigned char* currPtr = this->startPtr + offset; | |
| 169 | + reallyDoProceed(currPtr, state); | |
| 170 | + } | |
| 171 | + else { | |
| 172 | + state.setNextAsSink(); | |
| 173 | + } | |
| 174 | +} | |
| 175 | + | |
| 176 | +template <class T> | |
| 177 | +void FSAImpl<T>::proceedToNext(const char c, State<T>& state) const { | |
| 178 | +// if (c <= 'z' && 'a' <= c) | |
| 179 | +// cerr << "NEXT " << c << " from " << state.getOffset() << endl; | |
| 180 | +// else | |
| 181 | +// cerr << "NEXT " << (short) c << " from " << state.getOffset() << endl; | |
| 182 | + const unsigned char* fromPointer = this->startPtr + state.getOffset(); | |
| 183 | + unsigned char shortLabel = this->label2ShortLabel[(const unsigned char) c]; | |
| 184 | + unsigned int transitionsTableOffset = 1; | |
| 185 | + if (state.isAccepting()) { | |
| 186 | + transitionsTableOffset += state.getValueSize(); | |
| 187 | +// cerr << "transitionsTableOffset " << transitionsTableOffset + state.getOffset() << " because value is " << state.getValue() << endl; | |
| 188 | + } | |
| 189 | + StateData2* sd = (StateData2*) (fromPointer); | |
| 190 | +// cerr << "transitions num=" << sd->transitionsNum << endl; | |
| 191 | + if (sd->array) { | |
| 192 | + if (shortLabel > 0) { | |
| 193 | + this->doProceedToNextByArray( | |
| 194 | + shortLabel, | |
| 195 | + (uint32_t*) (fromPointer + transitionsTableOffset), | |
| 196 | + state); | |
| 197 | + } | |
| 198 | + else { | |
| 199 | + reallyDoProceed((unsigned char*) fromPointer + transitionsTableOffset + 256, state); | |
| 200 | + proceedToNext(c, state); | |
| 201 | + } | |
| 202 | + } | |
| 203 | + else { | |
| 204 | + this->doProceedToNextByList( | |
| 205 | + c, | |
| 206 | + shortLabel, | |
| 207 | + (unsigned char*) (fromPointer + transitionsTableOffset), | |
| 208 | + sd->transitionsNum, | |
| 209 | + state); | |
| 210 | + } | |
| 211 | +} | |
| 212 | + | |
| 213 | +#endif /* _VFSA_IMPL_HPP */ | |
| 214 | + | |
| ... | ... |
fsa/_vfsa_impl.hpp
| ... | ... | @@ -19,11 +19,11 @@ using namespace std; |
| 19 | 19 | #pragma pack(push) /* push current alignment to stack */ |
| 20 | 20 | #pragma pack(1) /* set alignment to 1 byte boundary */ |
| 21 | 21 | |
| 22 | -//struct VTransitionData { | |
| 23 | -// unsigned label : 5; | |
| 24 | -// unsigned offsetSize : 2; | |
| 25 | -// unsigned last : 1; | |
| 26 | -//}; | |
| 22 | +struct StateData2 { | |
| 23 | + unsigned transitionsNum : 6; | |
| 24 | + unsigned next : 1; | |
| 25 | + unsigned accepting : 1; | |
| 26 | +}; | |
| 27 | 27 | |
| 28 | 28 | #pragma pack(pop) /* restore original alignment from stack */ |
| 29 | 29 | |
| ... | ... | @@ -49,12 +49,13 @@ int FSAImpl<T>::getInitialStateOffset() { |
| 49 | 49 | |
| 50 | 50 | template <class T> |
| 51 | 51 | vector<unsigned char> FSAImpl<T>::initializeChar2PopularCharIdx(const unsigned char* ptr) { |
| 52 | - vector<unsigned char> res(256, FSAImpl<bool>::POPULAR_CHARS_NUM); | |
| 53 | - const unsigned char* popularChars = ptr + getPopularCharsOffset(); | |
| 54 | - for (unsigned int i = 0; i < POPULAR_CHARS_NUM; i++) { | |
| 55 | - res[popularChars[i]] = i; | |
| 56 | - } | |
| 57 | - return res; | |
| 52 | + // vector<unsigned char> res(256, FSAImpl<bool>::POPULAR_CHARS_NUM); | |
| 53 | + // const unsigned char* popularChars = ptr + getPopularCharsOffset(); | |
| 54 | + // for (unsigned int i = 0; i < POPULAR_CHARS_NUM; i++) { | |
| 55 | + // res[popularChars[i]] = i; | |
| 56 | + // } | |
| 57 | + // return res; | |
| 58 | + return vector<unsigned char>(); | |
| 58 | 59 | } |
| 59 | 60 | |
| 60 | 61 | template <class T> |
| ... | ... | @@ -79,94 +80,165 @@ FSAImpl<T>::~FSAImpl() { |
| 79 | 80 | |
| 80 | 81 | template <class T> |
| 81 | 82 | void FSAImpl<T>::proceedToNext(const char c, State<T>& state) const { |
| 82 | - // if (c <= 'z' && 'a' <= c) | |
| 83 | - // cerr << "NEXT " << c << " from " << state.getOffset() << endl; | |
| 84 | - // else | |
| 85 | - // cerr << "NEXT " << (short) c << " from " << state.getOffset() << endl; | |
| 83 | +// if (c <= 'z' && 'a' <= c) | |
| 84 | +// cerr << "NEXT " << c << " from " << state.getOffset() << endl; | |
| 85 | +// else | |
| 86 | +// cerr << "NEXT " << (short) c << " from " << state.getOffset() << endl; | |
| 86 | 87 | const unsigned char* fromPointer = this->startPtr + state.getOffset(); |
| 87 | - unsigned int transitionsTableOffset = 0; | |
| 88 | + int transitionsTableOffset = sizeof (StateData2); | |
| 88 | 89 | if (state.isAccepting()) { |
| 89 | 90 | transitionsTableOffset += state.getValueSize(); |
| 90 | - // cerr << "transitionsTableOffset " << transitionsTableOffset + state.getOffset() << " because value is " << state.getValue() << endl; | |
| 91 | +// cerr << "transitionsTableOffset " << transitionsTableOffset + state.getOffset() << " because value is " << state.getValue() << endl; | |
| 91 | 92 | } |
| 92 | - | |
| 93 | + StateData2 stateData = *(StateData2*) (fromPointer); | |
| 94 | +// cerr << "transitions num=" << stateData.transitionsNum << endl; | |
| 95 | + register unsigned char* currPtr = (unsigned char*) (fromPointer + transitionsTableOffset); | |
| 96 | + // TransitionData* foundTransition = (TransitionData*) (fromPointer + transitionsTableOffset); | |
| 93 | 97 | bool found = false; |
| 94 | -// bool failed = false; | |
| 95 | - unsigned int requiredShortLabel = char2PopularCharIdx[(unsigned char) c]; | |
| 96 | - // cerr << "NEXT " << c << " " << (int) shortLabel << endl; | |
| 97 | -// VTransitionData* td; | |
| 98 | -// unsigned char transitionByte = *currPtr; | |
| 99 | - unsigned int offsetSize; | |
| 100 | - register const unsigned char* currPtr = fromPointer + transitionsTableOffset; | |
| 101 | - | |
| 102 | - while (!found) { | |
| 103 | - | |
| 104 | - register unsigned char firstByte = *currPtr; | |
| 105 | - | |
| 106 | - unsigned int shortLabel = firstByte & 0b00011111; | |
| 107 | - bool last = (firstByte & 0b10000000); | |
| 108 | - offsetSize = (firstByte & 0b01100000) >> 5; | |
| 109 | - | |
| 110 | - const_cast<FSAImpl<T>*>(this)->counter.increment(1); | |
| 111 | - | |
| 112 | - if (shortLabel != requiredShortLabel) { | |
| 113 | - if (last || shortLabel == POPULAR_CHARS_NUM) { | |
| 114 | - break; | |
| 115 | - } | |
| 116 | - currPtr += offsetSize + 1; | |
| 117 | - if (shortLabel == POPULAR_CHARS_NUM) { | |
| 118 | - currPtr++; | |
| 119 | - } | |
| 120 | - } | |
| 121 | - else if (shortLabel != POPULAR_CHARS_NUM) { | |
| 98 | + bool next = stateData.next; | |
| 99 | + for (unsigned int i = 0; i < stateData.transitionsNum; i++) { | |
| 100 | +// cerr << *currPtr << endl; | |
| 101 | + if ((char) *currPtr == c) { | |
| 122 | 102 | found = true; |
| 123 | - currPtr++; | |
| 103 | + next = next && i + 1 == stateData.transitionsNum; | |
| 104 | + break; | |
| 105 | + } else { | |
| 106 | + // unsigned int offsetSize = currPtr[1] & 0b00000011; | |
| 107 | + currPtr += (currPtr[1] & 0b00000011) + 2; | |
| 124 | 108 | } |
| 125 | - else { | |
| 109 | + } | |
| 110 | + // const_cast<Counter*>(&counter)->increment(foundTransition - transitionsStart + 1); | |
| 111 | + if (!found) { | |
| 112 | +// cerr << "SINK for " << c << endl; | |
| 113 | + state.setNextAsSink(); | |
| 114 | + } | |
| 115 | + else { | |
| 116 | + currPtr++; | |
| 117 | + if (!next) { | |
| 118 | + unsigned int offsetSize = *currPtr & 0b00000011; | |
| 119 | + unsigned int offset = *currPtr >> 2; | |
| 120 | +// cerr << "offset size " << offsetSize << endl; | |
| 121 | +// cerr << "offset " << offset << endl; | |
| 126 | 122 | currPtr++; |
| 127 | - char realLabel = (char) *currPtr; | |
| 128 | - if (realLabel != c) { | |
| 129 | - if (last) { | |
| 123 | + // currPtr += (*currPtr >> 2) + 1; | |
| 124 | + switch (offsetSize) { | |
| 125 | + case 0: | |
| 126 | + currPtr += offset; | |
| 127 | + break; | |
| 128 | + case 1: | |
| 129 | + currPtr += (offset << 8) + *currPtr + 1; | |
| 130 | + break; | |
| 131 | + case 2: | |
| 132 | + currPtr += (offset << 16) + ntohs(*((uint16_t*) currPtr)) + 2; | |
| 133 | + break; | |
| 134 | + case 3: | |
| 135 | + currPtr += (offset << 24) + (((unsigned int) ntohs(*((uint16_t*) currPtr))) << 8) + currPtr[2] + 3; | |
| 130 | 136 | break; |
| 131 | - } | |
| 132 | - currPtr += offsetSize + 1; | |
| 133 | - } | |
| 134 | - else { | |
| 135 | - found = true; | |
| 136 | - currPtr++; | |
| 137 | 137 | } |
| 138 | 138 | } |
| 139 | - } | |
| 140 | - | |
| 141 | - if (found) { | |
| 142 | - switch (offsetSize) { | |
| 143 | - case 0: | |
| 144 | - break; | |
| 145 | - case 1: | |
| 146 | - currPtr += *currPtr + 1; | |
| 147 | - break; | |
| 148 | - case 2: | |
| 149 | - currPtr += ntohs(*((uint16_t*) currPtr)) + 2; | |
| 150 | - break; | |
| 151 | - case 3: | |
| 152 | - currPtr += (((unsigned int) ntohs(*((uint16_t*) currPtr))) << 8) + currPtr[2] + 3; | |
| 153 | - break; | |
| 154 | - } | |
| 155 | - bool accepting = c == '\0'; | |
| 156 | - if (accepting) { | |
| 157 | - T value; | |
| 158 | - int valueSize = this->deserializer.deserialize(currPtr, value); | |
| 159 | - currPtr += valueSize; | |
| 160 | - state.setNext(currPtr - this->startPtr, value, valueSize); | |
| 161 | - } | |
| 162 | - else { | |
| 139 | +// cerr << "FOUND " << c << " " << currPtr - this->startPtr << endl; | |
| 140 | + StateData* nextStateData = (StateData*) (currPtr); | |
| 141 | + if (nextStateData->accepting) { | |
| 142 | +// cerr << "ACCEPTING" << endl; | |
| 143 | + T object; | |
| 144 | + int size = this->deserializer.deserialize(currPtr + sizeof (StateData), object); | |
| 145 | + state.setNext(currPtr - this->startPtr, object, size); | |
| 146 | + } else { | |
| 163 | 147 | state.setNext(currPtr - this->startPtr); |
| 164 | 148 | } |
| 165 | 149 | } |
| 166 | - else { | |
| 167 | - state.setNextAsSink(); | |
| 168 | - } | |
| 169 | 150 | } |
| 170 | 151 | |
| 152 | +//template <class T> | |
| 153 | +//void FSAImpl<T>::proceedToNext(const char c, State<T>& state) const { | |
| 154 | +// // if (c <= 'z' && 'a' <= c) | |
| 155 | +// // cerr << "NEXT " << c << " from " << state.getOffset() << endl; | |
| 156 | +// // else | |
| 157 | +// // cerr << "NEXT " << (short) c << " from " << state.getOffset() << endl; | |
| 158 | +// const unsigned char* fromPointer = this->startPtr + state.getOffset(); | |
| 159 | +// unsigned int transitionsTableOffset = 0; | |
| 160 | +// if (state.isAccepting()) { | |
| 161 | +// transitionsTableOffset += state.getValueSize(); | |
| 162 | +// // cerr << "transitionsTableOffset " << transitionsTableOffset + state.getOffset() << " because value is " << state.getValue() << endl; | |
| 163 | +// } | |
| 164 | +// | |
| 165 | +// bool found = false; | |
| 166 | +//// bool failed = false; | |
| 167 | +// unsigned int requiredShortLabel = char2PopularCharIdx[(unsigned char) c]; | |
| 168 | +// // cerr << "NEXT " << c << " " << (int) shortLabel << endl; | |
| 169 | +//// VTransitionData* td; | |
| 170 | +//// unsigned char transitionByte = *currPtr; | |
| 171 | +// unsigned int offsetSize; | |
| 172 | +// register const unsigned char* currPtr = fromPointer + transitionsTableOffset; | |
| 173 | +// | |
| 174 | +// while (!found) { | |
| 175 | +// | |
| 176 | +// register unsigned char firstByte = *currPtr; | |
| 177 | +// | |
| 178 | +// unsigned int shortLabel = firstByte & 0b00011111; | |
| 179 | +// bool last = (firstByte & 0b10000000); | |
| 180 | +// offsetSize = (firstByte & 0b01100000) >> 5; | |
| 181 | +// | |
| 182 | +// const_cast<FSAImpl<T>*>(this)->counter.increment(1); | |
| 183 | +// | |
| 184 | +// if (shortLabel != requiredShortLabel) { | |
| 185 | +// if (last) { | |
| 186 | +// break; | |
| 187 | +// } | |
| 188 | +// currPtr += offsetSize + 1; | |
| 189 | +// if (shortLabel == POPULAR_CHARS_NUM) { | |
| 190 | +// currPtr++; | |
| 191 | +// } | |
| 192 | +// } | |
| 193 | +// else if (shortLabel != POPULAR_CHARS_NUM) { | |
| 194 | +// found = true; | |
| 195 | +// currPtr++; | |
| 196 | +// } | |
| 197 | +// else { | |
| 198 | +// currPtr++; | |
| 199 | +// char realLabel = (char) *currPtr; | |
| 200 | +// if (realLabel != c) { | |
| 201 | +// if (last) { | |
| 202 | +// break; | |
| 203 | +// } | |
| 204 | +// currPtr += offsetSize + 1; | |
| 205 | +// } | |
| 206 | +// else { | |
| 207 | +// found = true; | |
| 208 | +// currPtr++; | |
| 209 | +// } | |
| 210 | +// } | |
| 211 | +// } | |
| 212 | +// | |
| 213 | +// if (found) { | |
| 214 | +// switch (offsetSize) { | |
| 215 | +// case 0: | |
| 216 | +// break; | |
| 217 | +// case 1: | |
| 218 | +// currPtr += *currPtr + 1; | |
| 219 | +// break; | |
| 220 | +// case 2: | |
| 221 | +// currPtr += ntohs(*((uint16_t*) currPtr)) + 2; | |
| 222 | +// break; | |
| 223 | +// case 3: | |
| 224 | +// currPtr += (((unsigned int) ntohs(*((uint16_t*) currPtr))) << 8) + currPtr[2] + 3; | |
| 225 | +// break; | |
| 226 | +// } | |
| 227 | +// bool accepting = c == '\0'; | |
| 228 | +// if (accepting) { | |
| 229 | +// T value; | |
| 230 | +// int valueSize = this->deserializer.deserialize(currPtr, value); | |
| 231 | +// currPtr += valueSize; | |
| 232 | +// state.setNext(currPtr - this->startPtr, value, valueSize); | |
| 233 | +// } | |
| 234 | +// else { | |
| 235 | +// state.setNext(currPtr - this->startPtr); | |
| 236 | +// } | |
| 237 | +// } | |
| 238 | +// else { | |
| 239 | +// state.setNextAsSink(); | |
| 240 | +// } | |
| 241 | +//} | |
| 242 | + | |
| 171 | 243 | #endif /* _VFSA_IMPL_HPP */ |
| 172 | 244 | |
| ... | ... |
fsa/fsa.hpp
| ... | ... | @@ -119,20 +119,36 @@ public: |
| 119 | 119 | } |
| 120 | 120 | |
| 121 | 121 | static const uint32_t MAGIC_NUMBER = 0x8fc2bc1b; |
| 122 | - static const unsigned char VERSION_NUM = 1; | |
| 123 | - static const unsigned int POPULAR_CHARS_NUM = 31; | |
| 122 | + static const unsigned char VERSION_NUM = 4; | |
| 123 | + | |
| 124 | + static const unsigned char ACCEPTING_FLAG = 0b10000000; | |
| 125 | + static const unsigned char ARRAY_FLAG = 0b01000000; | |
| 126 | + static const unsigned char TRANSITIONS_NUM_MASK = 0b00111111; | |
| 124 | 127 | |
| 125 | 128 | protected: |
| 126 | 129 | void proceedToNext(const char c, State<T>& state) const; |
| 127 | 130 | private: |
| 128 | 131 | Counter counter; |
| 129 | - const std::vector<unsigned char> char2PopularCharIdx; | |
| 132 | + const std::vector<unsigned char> label2ShortLabel; | |
| 130 | 133 | |
| 131 | 134 | static int getMagicNumberOffset(); |
| 132 | 135 | static int getVersionNumOffset(); |
| 133 | 136 | static int getPopularCharsOffset(); |
| 134 | 137 | static int getInitialStateOffset(); |
| 135 | 138 | static std::vector<unsigned char> initializeChar2PopularCharIdx(const unsigned char* ptr); |
| 139 | + void doProceedToNextByList( | |
| 140 | + const char c, | |
| 141 | + const unsigned char shortLabel, | |
| 142 | + const unsigned char* ptr, | |
| 143 | + const unsigned int transitionsNum, | |
| 144 | + State<T>& state) const; | |
| 145 | + void doProceedToNextByArray( | |
| 146 | + const unsigned char shortLabel, | |
| 147 | + const uint32_t* ptr, | |
| 148 | + State<T>& state) const; | |
| 149 | + void reallyDoProceed( | |
| 150 | + const unsigned char* statePtr, | |
| 151 | + State<T>& state) const; | |
| 136 | 152 | }; |
| 137 | 153 | |
| 138 | 154 | /** |
| ... | ... | @@ -201,7 +217,8 @@ private: |
| 201 | 217 | }; |
| 202 | 218 | |
| 203 | 219 | #include "_fsa_impl.hpp" |
| 204 | -#include "_vfsa_impl.hpp" | |
| 220 | +#include "_fsaimpl.hpp" | |
| 221 | +//#include "_vfsa_impl.hpp" | |
| 205 | 222 | #include "_state_impl.hpp" |
| 206 | 223 | |
| 207 | 224 | #endif /* FSA_HPP */ |
| ... | ... |
fsa/test_speed.cpp
| ... | ... | @@ -30,13 +30,14 @@ int main(int argc, char** argv) { |
| 30 | 30 | int unrecognized = 0; |
| 31 | 31 | while (ifs.getline(line, 65536, '\n')) { |
| 32 | 32 | char* val; |
| 33 | -// cout << line << endl; | |
| 33 | +// cerr << line << endl; | |
| 34 | 34 | if (fsa.tryToRecognize(line, val)) { |
| 35 | 35 | // printf("%s: *OK*\n", line); |
| 36 | 36 | recognized++; |
| 37 | 37 | } |
| 38 | 38 | else { |
| 39 | 39 | unrecognized++; |
| 40 | +// exit(1); | |
| 40 | 41 | // printf("%s: NOT FOUND\n", line); |
| 41 | 42 | } |
| 42 | 43 | } |
| ... | ... |
fsabuilder/fsa/buildfsa.py
| ... | ... | @@ -11,12 +11,10 @@ import codecs |
| 11 | 11 | import encode |
| 12 | 12 | import convertinput |
| 13 | 13 | from fsa import FSA |
| 14 | -from serializer import VLengthSerializer | |
| 14 | +from serializer import VLengthSerializer2, VLengthSerializer3 | |
| 15 | 15 | from visualizer import Visualizer |
| 16 | 16 | from optparse import OptionParser |
| 17 | 17 | |
| 18 | -logging.basicConfig(level=logging.INFO) | |
| 19 | - | |
| 20 | 18 | class OutputFormat(): |
| 21 | 19 | BINARY = 'BINARY' |
| 22 | 20 | CPP = 'CPP' |
| ... | ... | @@ -52,6 +50,11 @@ def parseOptions(): |
| 52 | 50 | parser.add_option('--output-format', |
| 53 | 51 | dest='outputFormat', |
| 54 | 52 | help='output format - BINARY or CPP') |
| 53 | + parser.add_option('--use-arrays', | |
| 54 | + dest='useArrays', | |
| 55 | + action='store_true', | |
| 56 | + default=False, | |
| 57 | + help='store states reachable by 2 transitions in arrays (should speed up recognition)') | |
| 55 | 58 | parser.add_option('--visualize', |
| 56 | 59 | dest='visualize', |
| 57 | 60 | action='store_true', |
| ... | ... | @@ -60,6 +63,11 @@ def parseOptions(): |
| 60 | 63 | parser.add_option('--train-file', |
| 61 | 64 | dest='trainFile', |
| 62 | 65 | help='A text file used for training. Should contain words from some large corpus - one word in each line') |
| 66 | + parser.add_option('--debug', | |
| 67 | + dest='debug', | |
| 68 | + action='store_true', | |
| 69 | + default=False, | |
| 70 | + help='output some debugging info') | |
| 63 | 71 | |
| 64 | 72 | opts, args = parser.parse_args() |
| 65 | 73 | |
| ... | ... | @@ -114,6 +122,10 @@ def readTrainData(trainFile): |
| 114 | 122 | |
| 115 | 123 | if __name__ == '__main__': |
| 116 | 124 | opts = parseOptions() |
| 125 | + if opts.debug: | |
| 126 | + logging.basicConfig(level=logging.DEBUG) | |
| 127 | + else: | |
| 128 | + logging.basicConfig(level=logging.INFO) | |
| 117 | 129 | encoder = encode.Encoder() |
| 118 | 130 | fsa = FSA(encoder) |
| 119 | 131 | |
| ... | ... | @@ -129,16 +141,19 @@ if __name__ == '__main__': |
| 129 | 141 | logging.info('training with '+opts.trainFile+' ...') |
| 130 | 142 | fsa.train(readTrainData(opts.trainFile)) |
| 131 | 143 | logging.info('done training') |
| 132 | - serializer = VLengthSerializer(fsa) | |
| 144 | + serializer = VLengthSerializer3(fsa, useArrays=opts.useArrays) | |
| 133 | 145 | logging.info('states num: '+str(fsa.getStatesNum())) |
| 134 | 146 | logging.info('transitions num: '+str(fsa.getTransitionsNum())) |
| 135 | 147 | logging.info('accepting states num: '+str(len([s for s in fsa.initialState.dfs(set()) if s.isAccepting()]))) |
| 136 | 148 | logging.info('sink states num: '+str(len([s for s in fsa.initialState.dfs(set()) if len(s.transitionsMap.items()) == 0]))) |
| 149 | + logging.info('array states num: '+str(len([s for s in fsa.dfs() if s.serializeAsArray]))) | |
| 137 | 150 | { |
| 138 | 151 | OutputFormat.CPP: serializer.serialize2CppFile, |
| 139 | 152 | OutputFormat.BINARY: serializer.serialize2BinaryFile |
| 140 | 153 | }[opts.outputFormat](opts.outputFile) |
| 141 | 154 | logging.info('size: '+str(fsa.initialState.reverseOffset)) |
| 155 | +# for s in fsa.dfs(): | |
| 156 | +# logging.debug('%d %s' % (s.offset, str(s.transitionsMap))) | |
| 142 | 157 | # for s in fsa.initialState.dfs(set()): |
| 143 | 158 | # logging.info(s.offset) |
| 144 | 159 | if opts.visualize: |
| ... | ... |
fsabuilder/fsa/encode.py
fsabuilder/fsa/fsa.py
| ... | ... | @@ -21,7 +21,7 @@ class FSA(object): |
| 21 | 21 | self.encodedPrevWord = None |
| 22 | 22 | self.initialState = state.State() |
| 23 | 23 | self.register = register.Register() |
| 24 | - self.label2Freq = {0: float('inf')} | |
| 24 | + self.label2Freq = {} | |
| 25 | 25 | |
| 26 | 26 | def tryToRecognize(self, word, addFreq=False): |
| 27 | 27 | return self.decodeData(self.initialState.tryToRecognize(self.encodeWord(word), addFreq)) |
| ... | ... | @@ -52,7 +52,7 @@ class FSA(object): |
| 52 | 52 | # self.tryToRecognize(w, True) |
| 53 | 53 | |
| 54 | 54 | def train(self, trainData): |
| 55 | - self.label2Freq = {0: float('inf')} | |
| 55 | + self.label2Freq = {} | |
| 56 | 56 | for idx, word in enumerate(trainData): |
| 57 | 57 | self.tryToRecognize(word, addFreq=True) |
| 58 | 58 | for label in self.encodeWord(word): |
| ... | ... | @@ -115,6 +115,4 @@ class FSA(object): |
| 115 | 115 | state.reverseOffset = currReverseOffset |
| 116 | 116 | for state in self.initialState.dfs(set()): |
| 117 | 117 | state.offset = currReverseOffset - state.reverseOffset |
| 118 | - | |
| 119 | - | |
| 120 | 118 | |
| 121 | 119 | \ No newline at end of file |
| ... | ... |
fsabuilder/fsa/serializer.py
| ... | ... | @@ -5,6 +5,7 @@ Created on Oct 20, 2013 |
| 5 | 5 | ''' |
| 6 | 6 | |
| 7 | 7 | import logging |
| 8 | +from state import State | |
| 8 | 9 | |
| 9 | 10 | class Serializer(object): |
| 10 | 11 | |
| ... | ... | @@ -162,10 +163,18 @@ class VLengthSerializer(Serializer): |
| 162 | 163 | if state.isAccepting(): |
| 163 | 164 | res.extend(state.encodedData) |
| 164 | 165 | return res |
| 165 | - | |
| 166 | + | |
| 167 | + def getKey(self, state, label): | |
| 168 | + res = (-state.label2Freq.get(label, 0), -self.fsa.label2Freq.get(label, 0)) | |
| 169 | +# logging.info(chr(label)) | |
| 170 | +# logging.info(res) | |
| 171 | + return res | |
| 172 | + | |
| 166 | 173 | def _transitionsData2bytearray(self, state): |
| 167 | 174 | res = bytearray() |
| 168 | - transitions = sorted(state.transitionsMap.iteritems(), key=lambda (label, nextState): (self.label2Index.get(label, float('inf')), -nextState.freq, -self.label2Count[label])) | |
| 175 | +# logging.info(self.fsa.label2Freq) | |
| 176 | + transitions = list(sorted(state.transitionsMap.iteritems(), key=lambda (label, nextState): self.getKey(state, label))) | |
| 177 | +# logging.info(str([chr(label) for label, _ in transitions])) | |
| 169 | 178 | thisIdx = self.state2Index[state] |
| 170 | 179 | logging.debug('state '+str(state.offset)) |
| 171 | 180 | if len(transitions) == 0: |
| ... | ... | @@ -225,9 +234,9 @@ class VLengthSerializer(Serializer): |
| 225 | 234 | class VLengthSerializer2(Serializer): |
| 226 | 235 | |
| 227 | 236 | MAGIC_NUMBER = 0x8fc2bc1b |
| 228 | - VERSION = 2 | |
| 229 | - FINAL_FLAG = 0b10000000 | |
| 230 | - LAST_FLAG = 0b01000000 | |
| 237 | + VERSION = 3 | |
| 238 | + ACCEPTING_FLAG = 0b10000000 | |
| 239 | + NEXT_FLAG = 0b01000000 | |
| 231 | 240 | |
| 232 | 241 | def __init__(self, fsa): |
| 233 | 242 | super(VLengthSerializer2, self).__init__(fsa) |
| ... | ... | @@ -238,13 +247,13 @@ class VLengthSerializer2(Serializer): |
| 238 | 247 | res = bytearray() |
| 239 | 248 | |
| 240 | 249 | # serialize magic number in big-endian order |
| 241 | - res.append((VLengthSerializer.MAGIC_NUMBER & 0xFF000000) >> 24) | |
| 242 | - res.append((VLengthSerializer.MAGIC_NUMBER & 0x00FF0000) >> 16) | |
| 243 | - res.append((VLengthSerializer.MAGIC_NUMBER & 0x0000FF00) >> 8) | |
| 244 | - res.append(VLengthSerializer.MAGIC_NUMBER & 0x000000FF) | |
| 250 | + res.append((VLengthSerializer2.MAGIC_NUMBER & 0xFF000000) >> 24) | |
| 251 | + res.append((VLengthSerializer2.MAGIC_NUMBER & 0x00FF0000) >> 16) | |
| 252 | + res.append((VLengthSerializer2.MAGIC_NUMBER & 0x0000FF00) >> 8) | |
| 253 | + res.append(VLengthSerializer2.MAGIC_NUMBER & 0x000000FF) | |
| 245 | 254 | |
| 246 | 255 | # serialize version number |
| 247 | - res.append(VLengthSerializer.VERSION) | |
| 256 | + res.append(VLengthSerializer2.VERSION) | |
| 248 | 257 | |
| 249 | 258 | return res |
| 250 | 259 | |
| ... | ... | @@ -262,20 +271,37 @@ class VLengthSerializer2(Serializer): |
| 262 | 271 | return res |
| 263 | 272 | |
| 264 | 273 | def _stateData2bytearray(self, state): |
| 274 | + assert len(state.transitionsMap) < 64 | |
| 265 | 275 | res = bytearray() |
| 276 | + firstByte = 0 | |
| 277 | + if state.isAccepting(): | |
| 278 | + firstByte |= VLengthSerializer2.ACCEPTING_FLAG | |
| 279 | + transitions = list(sorted(state.transitionsMap.iteritems(), key=lambda (label, nextState): self.getKey(state, label))) | |
| 280 | + if transitions: | |
| 281 | + lastLabel, lastNextState = transitions[-1] | |
| 282 | + if self.state2Index[lastNextState] == self.state2Index[state] + 1: | |
| 283 | + firstByte |= VLengthSerializer2.NEXT_FLAG | |
| 284 | + firstByte |= len(state.transitionsMap) | |
| 285 | + assert firstByte < 256 and firstByte > 0 | |
| 286 | + res.append(firstByte) | |
| 266 | 287 | if state.isAccepting(): |
| 267 | 288 | res.extend(state.encodedData) |
| 268 | 289 | return res |
| 290 | + | |
| 291 | + def getKey(self, state, label): | |
| 292 | + res = (-state.label2Freq.get(label, 0), -self.fsa.label2Freq.get(label, 0)) | |
| 293 | +# logging.info(chr(label)) | |
| 294 | +# logging.info(res) | |
| 295 | + return res | |
| 269 | 296 | |
| 270 | 297 | def _transitionsData2bytearray(self, state): |
| 271 | 298 | res = bytearray() |
| 272 | - transitions = sorted(state.transitionsMap.iteritems(), key=lambda (label, nextState): (-nextState.freq, -self.label2Count[label])) | |
| 299 | + transitions = list(sorted(state.transitionsMap.iteritems(), key=lambda (label, nextState): self.getKey(state, label))) | |
| 273 | 300 | thisIdx = self.state2Index[state] |
| 274 | 301 | logging.debug('state '+str(state.offset)) |
| 275 | 302 | if len(transitions) == 0: |
| 276 | 303 | assert state.isAccepting() |
| 277 | -# flags | |
| 278 | - return bytearray(0, ) | |
| 304 | + return bytearray() | |
| 279 | 305 | else: |
| 280 | 306 | stateAfterThis = self.statesTable[thisIdx + 1] |
| 281 | 307 | for reversedN, (label, nextState) in enumerate(reversed(transitions)): |
| ... | ... | @@ -284,36 +310,178 @@ class VLengthSerializer2(Serializer): |
| 284 | 310 | assert stateAfterThis.reverseOffset is not None |
| 285 | 311 | logging.debug('next state reverse: '+str(nextState.reverseOffset)) |
| 286 | 312 | logging.debug('after state reverse: '+str(stateAfterThis.reverseOffset)) |
| 287 | - n = len(transitions) - reversedN | |
| 288 | - | |
| 289 | - popularLabel = label in self.label2Index | |
| 290 | - firstByte = self.label2Index[label] if popularLabel else 31 | |
| 291 | 313 | |
| 292 | - last = len(transitions) == n | |
| 293 | - next = last and stateAfterThis == nextState | |
| 314 | + firstByte = label | |
| 294 | 315 | |
| 295 | - if last: | |
| 296 | - firstByte |= VLengthSerializer.LAST_FLAG | |
| 316 | + n = len(transitions) - reversedN | |
| 297 | 317 | |
| 298 | - offsetSize = 0 | |
| 299 | - offset = 0 | |
| 300 | - if not next: | |
| 301 | - offsetSize = 1 | |
| 302 | -# nextState.offset - stateAfterThis.offset | |
| 303 | - offset = (stateAfterThis.reverseOffset - nextState.reverseOffset) + offsetSize + len(res) - 1 | |
| 318 | + last = len(transitions) == n | |
| 319 | + isNext = last and stateAfterThis == nextState | |
| 320 | + if not isNext: | |
| 321 | + offsetSize = 0 | |
| 322 | + # offset = 0 | |
| 323 | + offset = (stateAfterThis.reverseOffset - nextState.reverseOffset) + len(res) | |
| 304 | 324 | assert offset > 0 |
| 305 | - if offset >= 256: | |
| 306 | -# offset += 1 | |
| 325 | + if offset >= 64: | |
| 307 | 326 | offsetSize += 1 |
| 308 | - if offset >= 256 * 256: | |
| 309 | -# offset += 1 | |
| 327 | + if offset >= 256 * 64: | |
| 310 | 328 | offsetSize += 1 |
| 311 | - assert offset < 256 * 256 * 256 #TODO - przerobic na jakis porzadny wyjatek | |
| 329 | + if offset >= 256 * 256 * 64: | |
| 330 | + offsetSize += 1 | |
| 331 | + assert offset < 256 * 256 * 256 * 64 #TODO - przerobic na jakis porzadny wyjatek | |
| 312 | 332 | |
| 313 | - firstByte |= (32 * offsetSize) | |
| 333 | + secondByte = offsetSize | |
| 334 | + secondByte |= (offset >> (offsetSize * 8)) << 2 | |
| 335 | + | |
| 336 | + transitionBytes.append(firstByte) | |
| 337 | + transitionBytes.append(secondByte) | |
| 338 | + # serialize offset in big-endian order | |
| 339 | + if offsetSize == 3: | |
| 340 | + transitionBytes.append((offset & 0x00FF0000) >> 16) | |
| 341 | + if offsetSize >= 2: | |
| 342 | + transitionBytes.append((offset & 0x0000FF00) >> 8) | |
| 343 | + if offsetSize >= 1: | |
| 344 | + transitionBytes.append(offset & 0x000000FF) | |
| 345 | + for b in reversed(transitionBytes): | |
| 346 | + res.insert(0, b) | |
| 347 | + logging.debug('inserted transition at beginning '+chr(label)+' -> '+str(offset)) | |
| 348 | + else: | |
| 349 | + logging.debug('inserted transition at beginning '+chr(label)+' -> NEXT') | |
| 350 | + res.insert(0, firstByte) | |
| 351 | + return res | |
| 352 | + | |
| 353 | +class VLengthSerializer3(Serializer): | |
| 354 | + | |
| 355 | + MAGIC_NUMBER = 0x8fc2bc1b | |
| 356 | + VERSION = 4 | |
| 357 | + ACCEPTING_FLAG = 0b10000000 | |
| 358 | + ARRAY_FLAG = 0b01000000 | |
| 359 | + | |
| 360 | + def __init__(self, fsa, useArrays): | |
| 361 | + super(VLengthSerializer3, self).__init__(fsa) | |
| 362 | + self.statesTable = list(reversed(list(fsa.dfs()))) | |
| 363 | + self.state2Index = dict([(state, idx) for (idx, state) in enumerate(self.statesTable)]) | |
| 364 | + self._chooseArrayStates() | |
| 365 | + self.useArrays = useArrays | |
| 366 | + | |
| 367 | + def serializePrologue(self): | |
| 368 | + res = bytearray() | |
| 369 | + | |
| 370 | + # serialize magic number in big-endian order | |
| 371 | + res.append((VLengthSerializer3.MAGIC_NUMBER & 0xFF000000) >> 24) | |
| 372 | + res.append((VLengthSerializer3.MAGIC_NUMBER & 0x00FF0000) >> 16) | |
| 373 | + res.append((VLengthSerializer3.MAGIC_NUMBER & 0x0000FF00) >> 8) | |
| 374 | + res.append(VLengthSerializer3.MAGIC_NUMBER & 0x000000FF) | |
| 375 | + | |
| 376 | + # serialize version number | |
| 377 | + res.append(VLengthSerializer3.VERSION) | |
| 378 | + | |
| 379 | + # labels sorted by popularity | |
| 380 | + self.sortedLabels = [label for (label, freq) in sorted(self.fsa.label2Freq.iteritems(), key=lambda (label, freq): (-freq, label))] | |
| 381 | + remainingChars = [c for c in range(256) if not c in self.sortedLabels] | |
| 382 | +# while len(self.sortedLabels) < 256: | |
| 383 | +# self.sortedLabels.append(remainingChars.pop()) | |
| 384 | + | |
| 385 | + # popular labels table | |
| 386 | + self.label2ShortLabel = dict([(label, self.sortedLabels.index(label) + 1) for label in self.sortedLabels[:63]]) | |
| 387 | + | |
| 388 | + logging.debug(dict([(chr(label), shortLabel) for label, shortLabel in self.label2ShortLabel.items()])) | |
| 389 | + for label in range(256): | |
| 390 | + res.append(self.label2ShortLabel.get(label, 0)) | |
| 391 | + | |
| 392 | + res.append(ord('^')) | |
| 393 | + | |
| 394 | + return res | |
| 395 | + | |
| 396 | + def getStateSize(self, state): | |
| 397 | + return len(self.state2bytearray(state)) | |
| 398 | + | |
| 399 | + def getDataSize(self, state): | |
| 400 | + assert type(state.encodedData) == bytearray or not state.isAccepting() | |
| 401 | + return len(state.encodedData) if state.isAccepting() else 0 | |
| 402 | + | |
| 403 | + def state2bytearray(self, state): | |
| 404 | + res = bytearray() | |
| 405 | + res.extend(self._stateData2bytearray(state)) | |
| 406 | + res.extend(self._transitionsData2bytearray(state)) | |
| 407 | + return res | |
| 408 | + | |
| 409 | + def stateShouldBeAnArray(self, state): | |
| 410 | +# return False | |
| 411 | +# return len(state.transitionsMap) >= 13 | |
| 412 | + return self.useArrays and state.serializeAsArray | |
| 413 | + | |
| 414 | + def _stateData2bytearray(self, state): | |
| 415 | + assert len(state.transitionsMap) < 64 | |
| 416 | + res = bytearray() | |
| 417 | + firstByte = 0 | |
| 418 | + if state.isAccepting(): | |
| 419 | + firstByte |= VLengthSerializer3.ACCEPTING_FLAG | |
| 420 | +# transitions = list(sorted(state.transitionsMap.iteritems(), key=lambda (label, nextState): self.getKey(state, label))) | |
| 421 | +# if transitions: | |
| 422 | +# lastLabel, lastNextState = transitions[-1] | |
| 423 | +# if self.state2Index[lastNextState] == self.state2Index[state] + 1: | |
| 424 | +# firstByte |= VLengthSerializer3.NEXT_FLAG | |
| 425 | + if self.stateShouldBeAnArray(state): | |
| 426 | + firstByte |= VLengthSerializer3.ARRAY_FLAG | |
| 427 | + firstByte |= len(state.transitionsMap) | |
| 428 | + assert firstByte < 256 and firstByte > 0 | |
| 429 | + res.append(firstByte) | |
| 430 | + if state.isAccepting(): | |
| 431 | + res.extend(state.encodedData) | |
| 432 | + return res | |
| 433 | + | |
| 434 | + def getKey(self, state, label): | |
| 435 | + res = (-state.label2Freq.get(label, 0)) | |
| 436 | +# logging.info(chr(label)) | |
| 437 | +# logging.info(res) | |
| 438 | + return res | |
| 439 | + | |
| 440 | + def _transitions2ListBytes(self, state, originalState=None): | |
| 441 | + res = bytearray() | |
| 442 | + transitions = list(sorted(state.transitionsMap.iteritems(), key=lambda (label, nextState): self.getKey(state, label))) | |
| 443 | + thisIdx = self.state2Index[originalState if originalState is not None else state] | |
| 444 | + logging.debug('state '+str(state.offset)) | |
| 445 | + if len(transitions) == 0: | |
| 446 | + assert state.isAccepting() | |
| 447 | + return bytearray() | |
| 448 | + else: | |
| 449 | + stateAfterThis = self.statesTable[thisIdx + 1] | |
| 450 | + for reversedN, (label, nextState) in enumerate(reversed(transitions)): | |
| 451 | + transitionBytes = bytearray() | |
| 452 | + assert nextState.reverseOffset is not None | |
| 453 | + assert stateAfterThis.reverseOffset is not None | |
| 454 | + logging.debug('next state reverse: '+str(nextState.reverseOffset)) | |
| 455 | + logging.debug('after state reverse: '+str(stateAfterThis.reverseOffset)) | |
| 456 | + | |
| 457 | +# firstByte = label | |
| 458 | + | |
| 459 | + n = len(transitions) - reversedN | |
| 460 | + hasShortLabel = label in self.label2ShortLabel | |
| 461 | + firstByte = self.label2ShortLabel[label] if hasShortLabel else 0 | |
| 462 | + firstByte <<= 2 | |
| 463 | + | |
| 464 | + last = len(transitions) == n | |
| 465 | + isNext = last and stateAfterThis == nextState | |
| 466 | + offsetSize = 0 | |
| 467 | +# offset = 0 | |
| 468 | + offset = (stateAfterThis.reverseOffset - nextState.reverseOffset) + len(res) | |
| 469 | + assert offset > 0 or isNext | |
| 470 | + if offset > 0: | |
| 471 | + offsetSize += 1 | |
| 472 | + if offset >= 256: | |
| 473 | + offsetSize += 1 | |
| 474 | + if offset >= 256 * 256: | |
| 475 | + offsetSize += 1 | |
| 476 | + assert offset < 256 * 256 * 256 #TODO - przerobic na jakis porzadny wyjatek | |
| 477 | + assert offsetSize <= 3 | |
| 478 | + assert offsetSize > 0 or isNext | |
| 479 | + firstByte |= offsetSize | |
| 480 | +# secondByte = offsetSize | |
| 481 | +# secondByte |= (offset >> (offsetSize * 8)) << 2 | |
| 314 | 482 | |
| 315 | 483 | transitionBytes.append(firstByte) |
| 316 | - if not popularLabel: | |
| 484 | + if not hasShortLabel: | |
| 317 | 485 | transitionBytes.append(label) |
| 318 | 486 | # serialize offset in big-endian order |
| 319 | 487 | if offsetSize == 3: |
| ... | ... | @@ -325,4 +493,47 @@ class VLengthSerializer2(Serializer): |
| 325 | 493 | for b in reversed(transitionBytes): |
| 326 | 494 | res.insert(0, b) |
| 327 | 495 | logging.debug('inserted transition at beginning '+chr(label)+' -> '+str(offset)) |
| 496 | + | |
| 328 | 497 | return res |
| 498 | + | |
| 499 | + def _trimState(self, state): | |
| 500 | + newState = State() | |
| 501 | + newState.encodedData = state.encodedData | |
| 502 | + newState.reverseOffset = state.reverseOffset | |
| 503 | + newState.offset = state.offset | |
| 504 | + newState.transitionsMap = dict([(label, nextState) for (label, nextState) in state.transitionsMap.iteritems()]) | |
| 505 | +# newState.transitionsMap = dict([(label, nextState) for (label, nextState) in state.transitionsMap.iteritems() if not label in self.label2ShortLabel or not self.label2ShortLabel[label] in range(1,64)]) | |
| 506 | + newState.serializeAsArray = False | |
| 507 | + return newState | |
| 508 | + | |
| 509 | + def _transition2ArrayBytes(self, state): | |
| 510 | + res = bytearray() | |
| 511 | + array = [0] * 64 | |
| 512 | + for label, nextState in state.transitionsMap.iteritems(): | |
| 513 | + if label in self.label2ShortLabel: | |
| 514 | + shortLabel = self.label2ShortLabel[label] | |
| 515 | + array[shortLabel] = nextState.offset | |
| 516 | + logging.debug(array) | |
| 517 | + for offset in map(lambda x: x if x else 0, array): | |
| 518 | + res.append(0) | |
| 519 | + res.append((offset & 0xFF0000) >> 16) | |
| 520 | + res.append((offset & 0x00FF00) >> 8) | |
| 521 | + res.append(offset & 0x0000FF) | |
| 522 | + res.extend(self._stateData2bytearray(self._trimState(state))) | |
| 523 | + res.extend(self._transitions2ListBytes(self._trimState(state), originalState=state)) | |
| 524 | + return res | |
| 525 | + | |
| 526 | + def _transitionsData2bytearray(self, state): | |
| 527 | + if self.stateShouldBeAnArray(state): | |
| 528 | + return self._transition2ArrayBytes(state) | |
| 529 | + else: | |
| 530 | + return self._transitions2ListBytes(state) | |
| 531 | + | |
| 532 | + def _chooseArrayStates(self): | |
| 533 | + for state1 in self.fsa.initialState.transitionsMap.values(): | |
| 534 | + for state2 in state1.transitionsMap.values(): | |
| 535 | +# for state3 in state2.transitionsMap.values(): | |
| 536 | +# state3.serializeAsArray = True | |
| 537 | + state2.serializeAsArray = True | |
| 538 | + state1.serializeAsArray = True | |
| 539 | + self.fsa.initialState.serializeAsArray = True | |
| ... | ... |
fsabuilder/fsa/state.py
| ... | ... | @@ -15,6 +15,8 @@ class State(object): |
| 15 | 15 | self.encodedData = None |
| 16 | 16 | self.reverseOffset = None |
| 17 | 17 | self.offset = None |
| 18 | + self.label2Freq = {} | |
| 19 | + self.serializeAsArray = False | |
| 18 | 20 | |
| 19 | 21 | def setTransition(self, byte, nextState): |
| 20 | 22 | self.transitionsMap[byte] = nextState |
| ... | ... | @@ -25,6 +27,7 @@ class State(object): |
| 25 | 27 | def getNext(self, byte, addFreq=False): |
| 26 | 28 | if addFreq: |
| 27 | 29 | self.freq += 1 |
| 30 | + self.label2Freq[byte] = self.label2Freq.get(byte, 0) + 1 | |
| 28 | 31 | return self.transitionsMap.get(byte, None) |
| 29 | 32 | |
| 30 | 33 | def getRegisterKey(self): |
| ... | ... |
nbproject/configurations.xml
| ... | ... | @@ -2,6 +2,7 @@ |
| 2 | 2 | <configurationDescriptor version="90"> |
| 3 | 3 | <logicalFolder name="root" displayName="root" projectFiles="true" kind="ROOT"> |
| 4 | 4 | <df root="fsa" name="0"> |
| 5 | + <in>_fsaimpl.hpp</in> | |
| 5 | 6 | <in>test_not_recognize.cpp</in> |
| 6 | 7 | <in>test_recognize.cpp</in> |
| 7 | 8 | <in>test_speed.cpp</in> |
| ... | ... | @@ -38,7 +39,7 @@ |
| 38 | 39 | <buildCommandWorkingDir>build</buildCommandWorkingDir> |
| 39 | 40 | <buildCommand>${MAKE} -f Makefile</buildCommand> |
| 40 | 41 | <cleanCommand>${MAKE} -f Makefile clean</cleanCommand> |
| 41 | - <executablePath>build/fsa/test_speed</executablePath> | |
| 42 | + <executablePath>build/fsa/test_dict</executablePath> | |
| 42 | 43 | </makeTool> |
| 43 | 44 | </makefileType> |
| 44 | 45 | <folder path="0"> |
| ... | ... | @@ -56,13 +57,17 @@ |
| 56 | 57 | </incDir> |
| 57 | 58 | </ccTool> |
| 58 | 59 | </folder> |
| 59 | - <item path="fsa/test_not_recognize.cpp" ex="false" tool="1" flavor2="4"> | |
| 60 | + <item path="fsa/_fsaimpl.hpp" ex="false" tool="3" flavor2="0"> | |
| 61 | + </item> | |
| 62 | + <item path="fsa/test_not_recognize.cpp" ex="false" tool="1" flavor2="8"> | |
| 60 | 63 | <ccTool> |
| 61 | 64 | </ccTool> |
| 62 | 65 | </item> |
| 63 | - <item path="fsa/test_recognize.cpp" ex="false" tool="1" flavor2="0"> | |
| 66 | + <item path="fsa/test_recognize.cpp" ex="false" tool="1" flavor2="8"> | |
| 67 | + <ccTool> | |
| 68 | + </ccTool> | |
| 64 | 69 | </item> |
| 65 | - <item path="fsa/test_speed.cpp" ex="false" tool="1" flavor2="4"> | |
| 70 | + <item path="fsa/test_speed.cpp" ex="false" tool="1" flavor2="8"> | |
| 66 | 71 | <ccTool> |
| 67 | 72 | </ccTool> |
| 68 | 73 | </item> |
| ... | ... |