Commit 6361f38ffa48b79b633b34734351f70422e43a45
1 parent: a22a7344
sensowniejsze sortowanie wychodzących przejść.
git-svn-id: svn://svn.nlp.ipipan.waw.pl/morfeusz/morfeusz@6 ff4e3ee1-f430-4e82-ade0-24591c43f1fd
Showing 4 changed files with 54 additions and 26 deletions.
fsabuilder/fsa/buildfsa.py
... | ... | @@ -11,7 +11,7 @@ import codecs |
11 | 11 | import encode |
12 | 12 | import convertinput |
13 | 13 | from fsa import FSA |
14 | -from serializer import SimpleSerializerWithStringValues | |
14 | +from serializer import SimpleSerializer | |
15 | 15 | from visualizer import Visualizer |
16 | 16 | from optparse import OptionParser |
17 | 17 | |
... | ... | @@ -26,6 +26,10 @@ class InputFormat(): |
26 | 26 | POLIMORF = 'POLIMORF' |
27 | 27 | PLAIN = 'PLAIN' |
28 | 28 | |
29 | +class FSAType(): | |
30 | + MORPH = 'MORPH' | |
31 | + SPELL = 'SPELL' | |
32 | + | |
29 | 33 | def parseOptions(): |
30 | 34 | """ |
31 | 35 | Parses commandline args |
... | ... | @@ -39,9 +43,12 @@ def parseOptions(): |
39 | 43 | dest='outputFile', |
40 | 44 | metavar='FILE', |
41 | 45 | help='path to output file') |
46 | + parser.add_option('-t', '--fsa-type', | |
47 | + dest='fsaType', | |
48 | + help='result FSA type - MORPH (for morphological analysis) or SPELL (for simple spell checker)') | |
42 | 49 | parser.add_option('--input-format', |
43 | 50 | dest='inputFormat', |
44 | - help='input format - ENCODED or POLIMORF') | |
51 | + help='input format - ENCODED, POLIMORF or PLAIN') | |
45 | 52 | parser.add_option('--output-format', |
46 | 53 | dest='outputFormat', |
47 | 54 | help='output format - BINARY or CPP') |
... | ... | @@ -53,14 +60,30 @@ def parseOptions(): |
53 | 60 | |
54 | 61 | opts, args = parser.parse_args() |
55 | 62 | |
56 | - if None in [opts.inputFile, opts.outputFile, opts.outputFormat, opts.inputFormat]: | |
63 | + if None in [opts.inputFile, opts.outputFile, opts.outputFormat, opts.inputFormat, opts.fsaType]: | |
57 | 64 | parser.print_help() |
58 | 65 | exit(1) |
59 | 66 | if not opts.outputFormat.upper() in [OutputFormat.BINARY, OutputFormat.CPP]: |
60 | - print >> sys.stderr, 'output format must be one of ('+str([OutputFormat.BINARY, OutputFormat.CPP])+')' | |
67 | + logging.error('output format must be one of ('+str([OutputFormat.BINARY, OutputFormat.CPP])+')') | |
68 | + parser.print_help() | |
61 | 69 | exit(1) |
62 | 70 | if not opts.inputFormat.upper() in [InputFormat.ENCODED, InputFormat.POLIMORF, InputFormat.PLAIN]: |
63 | - print >> sys.stderr, 'input format must be one of ('+str([InputFormat.ENCODED, InputFormat.POLIMORF])+')' | |
71 | + logging.error('input format must be one of ('+str([InputFormat.ENCODED, InputFormat.POLIMORF, InputFormat.PLAIN])+')') | |
72 | + parser.print_help() | |
73 | + exit(1) | |
74 | + if not opts.fsaType.upper() in [FSAType.MORPH, FSAType.SPELL]: | |
75 | + logging.error('input format must be one of ('+str([InputFormat.ENCODED, InputFormat.POLIMORF])+')') | |
76 | + parser.print_help() | |
77 | + exit(1) | |
78 | + if opts.inputFormat.upper() == FSAType.MORPH \ | |
79 | + and not opts.inputFormat.upper() in [InputFormat.ENCODED, InputFormat.POLIMORF]: | |
80 | + logging.error('input format for morph analysis FSA must be one of ('+str([InputFormat.ENCODED, InputFormat.POLIMORF])+')') | |
81 | + parser.print_help() | |
82 | + exit(1) | |
83 | + if opts.inputFormat.upper() == FSAType.SPELL \ | |
84 | + and not opts.inputFormat.upper() in [InputFormat.PLAIN]: | |
85 | + logging.error('input format for simple spelling FSA must be '+InputFormat.PLAIN) | |
86 | + parser.print_help() | |
64 | 87 | exit(1) |
65 | 88 | return opts |
66 | 89 | |
... | ... | @@ -85,7 +108,7 @@ if __name__ == '__main__': |
85 | 108 | opts = parseOptions() |
86 | 109 | encoder = encode.Encoder() |
87 | 110 | fsa = FSA(encoder) |
88 | - serializer = SimpleSerializerWithStringValues() | |
111 | + serializer = SimpleSerializer() | |
89 | 112 | |
90 | 113 | inputData = { |
91 | 114 | InputFormat.ENCODED: readEncodedInput(opts.inputFile), |
... | ... | @@ -96,9 +119,12 @@ if __name__ == '__main__': |
96 | 119 | logging.info('feeding FSA with data ...') |
97 | 120 | fsa.feed(inputData) |
98 | 121 | logging.info('states num: '+str(fsa.getStatesNum())) |
99 | - if opts.outputFormat == 'CPP': | |
100 | - serializer.serialize2CppFile(fsa, opts.outputFile) | |
101 | - else: | |
102 | - serializer.serialize2BinaryFile(fsa, opts.outputFile) | |
122 | + | |
123 | + { | |
124 | + OutputFormat.CPP: serializer.serialize2CppFile, | |
125 | + OutputFormat.BINARY: serializer.serialize2BinaryFile | |
126 | + }[opts.outputFormat](fsa, opts.outputFile) | |
127 | + | |
103 | 128 | if opts.visualize: |
104 | 129 | Visualizer().visualize(fsa) |
130 | + | |
... | ... |
fsabuilder/fsa/fsa.py
... | ... | @@ -22,11 +22,12 @@ class FSA(object): |
22 | 22 | self.initialState = state.State() |
23 | 23 | self.register = register.Register() |
24 | 24 | |
25 | - def tryToRecognize(self, word): | |
26 | - return self.decodeData(self.initialState.tryToRecognize(self.encodeWord(word))) | |
25 | + def tryToRecognize(self, word, addFreq=False): | |
26 | + return self.decodeData(self.initialState.tryToRecognize(self.encodeWord(word), addFreq)) | |
27 | 27 | |
28 | 28 | def feed(self, input): |
29 | 29 | |
30 | + allWords = [] | |
30 | 31 | for n, (word, data) in enumerate(input, start=1): |
31 | 32 | assert data is not None |
32 | 33 | if type(data) in [str, unicode]: |
... | ... | @@ -38,9 +39,13 @@ class FSA(object): |
38 | 39 | assert self.tryToRecognize(word) == data |
39 | 40 | if n % 10000 == 0: |
40 | 41 | logging.info(word) |
42 | + allWords.append(word) | |
41 | 43 | |
42 | 44 | self.initialState = self._replaceOrRegister(self.initialState, self.encodeWord(word)) |
43 | 45 | self.encodedPrevWord = None |
46 | + | |
47 | + for w in allWords: | |
48 | + self.tryToRecognize(w, True) | |
44 | 49 | |
45 | 50 | def getStatesNum(self): |
46 | 51 | return self.register.getStatesNum() |
... | ... |
fsabuilder/fsa/serializer.py
... | ... | @@ -54,7 +54,8 @@ class SimpleSerializer(Serializer): |
54 | 54 | return 1 + 4 * len(state.transitionsMap.keys()) + self.getDataSize(state) |
55 | 55 | |
56 | 56 | def getDataSize(self, state): |
57 | - raise NotImplementedError('Not implemented') | |
57 | + assert type(state.encodedData) == bytearray or not state.isAccepting() | |
58 | + return len(state.encodedData) if state.isAccepting() else 0 | |
58 | 59 | |
59 | 60 | def state2bytearray(self, state): |
60 | 61 | res = bytearray() |
... | ... | @@ -77,17 +78,10 @@ class SimpleSerializer(Serializer): |
77 | 78 | def _transitionsData2bytearray(self, state): |
78 | 79 | res = bytearray() |
79 | 80 | # must sort that strange way because it must be sorted according to char, not unsigned char |
80 | - for byte, nextState in sorted(state.transitionsMap.iteritems(), key=lambda (c, _): c if (c >= 0 and c < 128) else c - 256): | |
81 | + for byte, nextState in sorted(state.transitionsMap.iteritems(), key=lambda (_, state): -state.freq): | |
81 | 82 | res.append(byte) |
82 | 83 | offset = nextState.offset |
83 | 84 | res.append(offset & 0x0000FF) |
84 | 85 | res.append((offset & 0x00FF00) >> 8) |
85 | 86 | res.append((offset & 0xFF0000) >> 16) |
86 | 87 | return res |
87 | - | |
88 | -class SimpleSerializerWithStringValues(SimpleSerializer): | |
89 | - | |
90 | - def getDataSize(self, state): | |
91 | - assert type(state.encodedData) == bytearray or not state.isAccepting() | |
92 | - return len(state.encodedData) if state.isAccepting() else 0 | |
93 | - | |
94 | 88 | \ No newline at end of file |
... | ... |
fsabuilder/fsa/state.py
... | ... | @@ -11,6 +11,7 @@ class State(object): |
11 | 11 | |
12 | 12 | def __init__(self): |
13 | 13 | self.transitionsMap = {} |
14 | + self.freq = 0 | |
14 | 15 | self.encodedData = None |
15 | 16 | self.reverseOffset = None |
16 | 17 | self.offset = None |
... | ... | @@ -21,7 +22,9 @@ class State(object): |
21 | 22 | def hasNext(self, byte): |
22 | 23 | return byte in self.transitionsMap |
23 | 24 | |
24 | - def getNext(self, byte): | |
25 | + def getNext(self, byte, addFreq=False): | |
26 | + if addFreq: | |
27 | + self.freq += 1 | |
25 | 28 | return self.transitionsMap.get(byte, None) |
26 | 29 | |
27 | 30 | def getRegisterKey(self): |
... | ... | @@ -30,11 +33,11 @@ class State(object): |
30 | 33 | def isAccepting(self): |
31 | 34 | return self.encodedData is not None |
32 | 35 | |
33 | - def tryToRecognize(self, word): | |
36 | + def tryToRecognize(self, word, addFreq=False): | |
34 | 37 | if word: |
35 | - nextState = self.getNext(word[0]) | |
38 | + nextState = self.getNext(word[0], addFreq) | |
36 | 39 | if nextState: |
37 | - return nextState.tryToRecognize(word[1:]) | |
40 | + return nextState.tryToRecognize(word[1:], addFreq) | |
38 | 41 | else: |
39 | 42 | return False |
40 | 43 | else: |
... | ... | @@ -42,7 +45,7 @@ class State(object): |
42 | 45 | |
43 | 46 | def dfs(self, alreadyVisited): |
44 | 47 | if not self in alreadyVisited: |
45 | - for _, state in sorted(self.transitionsMap.iteritems()): | |
48 | + for _, state in sorted(self.transitionsMap.iteritems(), key=lambda (_, state): -state.freq): | |
46 | 49 | for state1 in state.dfs(alreadyVisited): |
47 | 50 | yield state1 |
48 | 51 | alreadyVisited.add(self) |
... | ... |