# -*- coding: utf-8 -*-
# Source: fsabuilder/morfeuszbuilder/fsa/fsa.py (4.19 KB)
# Authored by Michał Lenart (text recovered from a repository "blame" view).
'''
Created on Oct 8, 2013

@author: mlenart
'''

import state
import register
import logging

class FSA(object):
    '''
    A finite state automaton built incrementally from sorted entries.

    Words are inserted with addEntry() in strictly increasing order of their
    encoded form. While inserting, the part of the previous word's path that
    is no longer shared with the new word is passed to _replaceOrRegister(),
    which swaps each state for an equivalent one from the register when one
    exists. close() does the same for the path of the last word and must be
    called once all entries have been added.
    '''

    def __init__(self, encoder, tagset=None, encodeData=True, encodeWords=True):
        '''
        Create an empty automaton.

        encoder     -- object providing encodeWord, encodeData and decodeData
        tagset      -- stored on the instance as-is; not used by this class
        encodeData  -- when False, data is stored and returned unencoded
        encodeWords -- when False, words are used as keys unencoded
        '''
        identity = lambda x: x
        self.encodeWord = encoder.encodeWord if encodeWords else identity
        self.encodeData = encoder.encodeData if encodeData else identity
        self.decodeData = encoder.decodeData if encodeData else identity
        # Encoded form of the most recently added word (None before the first
        # entry and again after close()).
        self.encodedPrevWord = None

        self.tagset = tagset
        self.initialState = state.State()
        self.register = register.Register()
        # Number of occurrences of every label seen in added/trained words.
        self.label2Freq = {}
        # Number of entries added so far.
        self.n = 0
        self.closed = False

    def tryToRecognize(self, word, addFreq=False):
        '''
        Look the word up, starting from the initial state.

        Returns the decoded result of the initial state's tryToRecognize()
        for the encoded word; addFreq is forwarded to that call unchanged.
        '''
        encodedWord = self.encodeWord(word)
        return self.decodeData(
            self.initialState.tryToRecognize(encodedWord, addFreq))

    def addEntry(self, word, data):
        '''
        Add a (word, data) pair to the automaton.

        The automaton must not be closed, data must not be None and the
        encoded word must be strictly greater than the previously added one.
        '''
        assert not self.closed
        assert data is not None
        encodedWord = self.encodeWord(word)
        # An explicit None check keeps the first insertion working on
        # Python 3, where None cannot be ordered against other values.
        assert (self.encodedPrevWord is None
                or encodedWord > self.encodedPrevWord)
        self._addSorted(encodedWord, self.encodeData(data))
        self.encodedPrevWord = encodedWord

        self.n += 1

        # debug: periodic progress report
        if self.n % 10000 == 0:
            logging.info(word)
            logging.info(str(self.register.getStatesNum()))
        for label in encodedWord:
            self.label2Freq[label] = self.label2Freq.get(label, 0) + 1

    def close(self):
        '''
        Finish construction after the last entry has been added.

        Runs _replaceOrRegister() along the path of the last added word and
        marks the automaton as closed. Requires at least one entry and may be
        called only once.
        '''
        assert self.n > 0
        assert not self.closed
        self.initialState = self._replaceOrRegister(
            self.initialState, self.encodedPrevWord)
        self.encodedPrevWord = None
        self.closed = True

    def train(self, trainData):
        '''
        Recompute label frequencies from an iterable of words.

        label2Freq is reset, then every word is looked up with addFreq=True
        and each label of its encoded form is counted.
        '''
        self.label2Freq = {}
        for idx, word in enumerate(trainData):
            self.tryToRecognize(word, addFreq=True)
            for label in self.encodeWord(word):
                self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
            if idx % 100000 == 0:
                logging.info(str(idx))

    def dfs(self):
        '''
        Yield the states produced by the initial state's dfs() traversal.
        '''
        for s in self.initialState.dfs(set()):
            yield s

    def getStatesNum(self):
        '''
        Return the number of states reported by the register.
        '''
        return self.register.getStatesNum()

    def getTransitionsNum(self):
        '''
        Return the total size of transitionsMap over all traversed states.
        '''
        return sum(len(s.transitionsMap)
                   for s in self.initialState.dfs(set()))

    def _addSorted(self, encodedWord, data):
        '''
        Insert an already encoded word; data must be a bytearray.

        Steps:
        1. follow the existing transitions for as long as they match the word,
        2. run _replaceOrRegister() on the rest of the previous word's path
           that branches off at that point,
        3. append fresh states for the remaining labels and store data in
           the last one.
        '''
        prevWord = self.encodedPrevWord
        # None means "no previous word"; see addEntry() for why the check is
        # explicit.
        assert prevWord is None or prevWord < encodedWord
        assert isinstance(data, bytearray)

        q = self.initialState
        i = 0
        # The bound has to be strict: with "<=" the index below runs past the
        # end of the word (e.g. for an empty word) and raises IndexError.
        while i < len(encodedWord) and q.hasNext(encodedWord[i]):
            q = q.getNext(encodedWord[i])
            i += 1

        if prevWord and i < len(prevWord):
            nextState = q.getNext(prevWord[i])
            q.setTransition(
                prevWord[i],
                self._replaceOrRegister(nextState, prevWord[i + 1:]))

        while i < len(encodedWord):
            q.setTransition(encodedWord[i], state.State())
            q = q.getNext(encodedWord[i])
            i += 1

        assert q.encodedData is None
        q.encodedData = data

    def _replaceOrRegister(self, q, encodedWord):
        '''
        Process the path labelled by encodedWord that starts in state q.

        The path is handled recursively from its far end: every child on it
        is replaced by the result of this method. Returns the equivalent
        state from the register if there is one; otherwise q is added to the
        register and returned.
        '''
        if encodedWord:
            nextState = q.getNext(encodedWord[0])
            q.setTransition(
                encodedWord[0],
                self._replaceOrRegister(nextState, encodedWord[1:]))

        if self.register.containsEquivalentState(q):
            return self.register.getEquivalentState(q)
        self.register.addState(q)
        return q

    def calculateOffsets(self, sizeCounter):
        '''
        Delegate offset calculation to the initial state.

        sizeCounter is passed through unchanged.
        '''
        self.initialState.calculateOffsets(sizeCounter)

    def debug(self):
        '''
        Call debug() on every state produced by the dfs() traversal.
        '''
        for s in self.initialState.dfs(set()):
            s.debug()