Blame view

fsabuilder/morfeuszbuilder/fsa/fsa.py 3.85 KB
Michał Lenart authored
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
'''
Created on Oct 8, 2013

@author: mlenart
'''

import state
import register
import logging

class FSA(object):
    '''
    A finite state automaton built incrementally from (word, data) entries.

    Entries must be added with addEntry() in strictly increasing order of
    their encoded words; close() must then be called once to register the
    states left over from the last added word.
    '''

    def __init__(self, encoder, encodeData=True, encodeWords=True):
        '''
        Create an empty automaton.

        encoder     -- object providing encodeWord, encodeData and decodeData
        encodeData  -- when false, data is stored and returned unchanged
        encodeWords -- when false, words are used as given (no encoding)
        '''
        self.encodeWord = encoder.encodeWord if encodeWords else lambda x: x
        self.encodeData = encoder.encodeData if encodeData else lambda x: x
        self.decodeData = encoder.decodeData if encodeData else lambda x: x
        # encoded form of the most recently added word (None before the
        # first addEntry() and again after close())
        self.encodedPrevWord = None

        self.initialState = state.State()
        self.register = register.Register()
        # label -> number of occurrences in the added (or training) words
        self.label2Freq = {}
        # number of entries added so far
        self.n = 0
        self.closed = False

    def tryToRecognize(self, word, addFreq=False):
        '''
        Encode the word, look it up starting from the initial state and
        return the decoded result of that lookup. addFreq is passed through
        to the state's tryToRecognize.
        '''
        return self.decodeData(self.initialState.tryToRecognize(self.encodeWord(word), addFreq))

    def addEntry(self, word, data):
        '''
        Add a word with its data. The encoded word must be strictly greater
        than the encoded form of the previously added word, the automaton
        must not be closed and data must not be None.
        '''
        assert not self.closed
        assert data is not None
        encodedWord = self.encodeWord(word)
        # explicit None check: comparing against None only works on Python 2
        assert self.encodedPrevWord is None or encodedWord > self.encodedPrevWord
        self._addSorted(encodedWord, self.encodeData(data))
        self.encodedPrevWord = encodedWord

        self.n += 1

        # progress log: first entries, then every 1000th, then every 10000th
        if self.n < 10 or (self.n < 10000 and self.n % 1000 == 0) or self.n % 10000 == 0:
            logging.info(u'%d %s', self.n, word)
        self._countLabels(encodedWord)

    def close(self):
        '''
        Finish the construction: run replace-or-register along the path of
        the last added word and mark the automaton as closed. Requires at
        least one added entry and may be called only once.
        '''
        assert self.n > 0
        assert not self.closed
        self.initialState = self._replaceOrRegister(self.initialState, self.encodedPrevWord)
        self.encodedPrevWord = None
        self.closed = True

    def train(self, trainData):
        '''
        Reset the label frequencies and recompute them from the words in
        trainData; every word is also looked up with addFreq=True.
        '''
        self.label2Freq = {}
        for idx, word in enumerate(trainData):
            self.tryToRecognize(word, addFreq=True)
            self._countLabels(self.encodeWord(word))
            if idx % 100000 == 0:
                logging.info(str(idx))

    def dfs(self):
        '''
        Yield the states produced by the initial state's dfs().
        '''
        for s in self.initialState.dfs(set()):
            yield s

    def getStatesNum(self):
        '''
        Return the number of states as reported by the register.
        '''
        return self.register.getStatesNum()

    def getTransitionsNum(self):
        '''
        Return the total number of transitions over all states produced by
        the initial state's dfs().
        '''
        return sum(len(s.transitionsMap) for s in self.initialState.dfs(set()))

    def _countLabels(self, encodedWord):
        # Increase the frequency counter of every label of the encoded word.
        for label in encodedWord:
            self.label2Freq[label] = self.label2Freq.get(label, 0) + 1

    def _addSorted(self, encodedWord, data):
        '''
        Insert an already encoded word (greater than the previous one) and
        attach the encoded data (a bytearray) to its final state.
        '''
        assert self.encodedPrevWord is None or self.encodedPrevWord < encodedWord
        assert isinstance(data, bytearray)
        # follow the part of the word that is already present in the automaton
        q = self.initialState
        i = 0
        while i < len(encodedWord) and q.hasNext(encodedWord[i]):
            q = q.getNext(encodedWord[i])
            i += 1
        # the previous word continues below q: its remaining suffix will not
        # change any more, so it can go through replace-or-register now
        if self.encodedPrevWord and i < len(self.encodedPrevWord):
            nextState = q.getNext(self.encodedPrevWord[i])
            q.setTransition(
                self.encodedPrevWord[i],
                self._replaceOrRegister(nextState, self.encodedPrevWord[i+1:]))

        # create fresh states for the labels that are not in the automaton yet
        while i < len(encodedWord):
            q.setTransition(encodedWord[i], state.State())
            q = q.getNext(encodedWord[i])
            i += 1

        assert q.encodedData is None
        q.encodedData = data

    def _replaceOrRegister(self, q, encodedWord):
        '''
        Process the states along the path labelled by encodedWord starting
        in q, deepest state first: a state with an equivalent already in the
        register is replaced by that equivalent, any other state is added to
        the register. Returns the state that should be used in place of q.
        '''
        if encodedWord:
            nextState = q.getNext(encodedWord[0])
            q.setTransition(
                encodedWord[0],
                self._replaceOrRegister(nextState, encodedWord[1:]))

        if self.register.containsEquivalentState(q):
            return self.register.getEquivalentState(q)
        self.register.addState(q)
        return q

    def calculateOffsets(self, sizeCounter):
        '''
        Delegate offset calculation to the initial state.
        '''
        self.initialState.calculateOffsets(sizeCounter)

    def debug(self):
        '''
        Call debug() on every state produced by the initial state's dfs().
        '''
        for s in self.initialState.dfs(set()):
            s.debug()