Blame view

fsabuilder/morfeuszbuilder/fsa/fsa.py 3.96 KB
Michał Lenart authored
1
2
3
4
5
6
7
8
9
'''
Created on Oct 8, 2013

@author: mlenart
'''

import state
import register
import logging
Michał Lenart authored
10
from morfeuszbuilder.utils import exceptions
Michał Lenart authored
11
12
13
14
15
16
17

class FSA(object):
    '''
    A finite state automaton
    '''
Michał Lenart authored
18
    def __init__(self, encoder, encodeData=True, encodeWords=True):
        '''
        Set up an empty, open automaton.

        encoder -- object providing encodeWord, encodeData and decodeData;
                   each of them is only looked up when its flag is true
        encodeData -- when false, data is stored and returned unchanged
        encodeWords -- when false, words are used unchanged as label sequences
        '''
        def identity(x):
            return x

        if encodeWords:
            self.encodeWord = encoder.encodeWord
        else:
            self.encodeWord = identity

        # both data converters are governed by the single encodeData flag
        if encodeData:
            self.encodeData = encoder.encodeData
            self.decodeData = encoder.decodeData
        else:
            self.encodeData = identity
            self.decodeData = identity

        # encoded form of the most recently added word; None before the first one
        self.encodedPrevWord = None

        self.initialState = state.State()
        self.register = register.Register()

        # label -> number of times the label occurred in added/training words
        self.label2Freq = {}
        # number of entries added so far
        self.n = 0
        self.closed = False
Michał Lenart authored
30
Michał Lenart authored
31
32
    def tryToRecognize(self, word, addFreq=False):
        return self.decodeData(self.initialState.tryToRecognize(self.encodeWord(word), addFreq))
Michał Lenart authored
33
Michał Lenart authored
34
35
36
37
38
39
40
    def addEntry(self, word, data):
        assert not self.closed
        assert data is not None
        encodedWord = self.encodeWord(word)
        assert encodedWord > self.encodedPrevWord
        self._addSorted(encodedWord, self.encodeData(data))
        self.encodedPrevWord = encodedWord
Michał Lenart authored
41
Michał Lenart authored
42
        self.n += 1
Michał Lenart authored
43
Michał Lenart authored
44
        # debug
Michał Lenart authored
45
        if self.n < 10 or (self.n < 10000 and self.n % 1000 == 0) or self.n % 10000 == 0:
Michał Lenart authored
46
            logging.info(u'%d %s' % (self.n, word))
Michał Lenart authored
47
48
49
50
        for label in encodedWord:
            self.label2Freq[label] = self.label2Freq.get(label, 0) + 1

    def close(self):
Michał Lenart authored
51
52
        if self.n == 0:
            raise exceptions.FSABuilderException('empty input')
Michał Lenart authored
53
54
        assert not self.closed
        self.initialState = self._replaceOrRegister(self.initialState, self.encodedPrevWord)
Michał Lenart authored
55
        self.encodedPrevWord = None
Michał Lenart authored
56
57
        self.closed = True
Michał Lenart authored
58
    def train(self, trainData):
Michał Lenart authored
59
        self.label2Freq = {}
Michał Lenart authored
60
        for idx, word in enumerate(trainData):
Michał Lenart authored
61
62
63
            self.tryToRecognize(word, addFreq=True)
            for label in self.encodeWord(word):
                self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
Michał Lenart authored
64
65
66
67
68
69
            if idx % 100000 == 0:
                logging.info(str(idx))

    def dfs(self):
        for state in self.initialState.dfs(set()):
            yield state
Michał Lenart authored
70
71
72

    def getStatesNum(self):
        return self.register.getStatesNum()
Michał Lenart authored
73
74
75
76
77
78

    def getTransitionsNum(self):
        res = 0
        for s in self.initialState.dfs(set()):
            res += len(s.transitionsMap)
        return res
Michał Lenart authored
79
80
81

    def _addSorted(self, encodedWord, data):
        assert self.encodedPrevWord < encodedWord
Michał Lenart authored
82
        assert type(data) == bytearray
Michał Lenart authored
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
        q = self.initialState
        i = 0
        while i <= len(encodedWord) and q.hasNext(encodedWord[i]):
            q = q.getNext(encodedWord[i])
            i += 1
        if self.encodedPrevWord and i < len(self.encodedPrevWord):
            nextState = q.getNext(self.encodedPrevWord[i])
            q.setTransition(
                                self.encodedPrevWord[i], 
                                self._replaceOrRegister(nextState, self.encodedPrevWord[i+1:]))

        while i < len(encodedWord):
            q.setTransition(encodedWord[i], state.State())
            q = q.getNext(encodedWord[i])
            i += 1

        assert q.encodedData is None
        q.encodedData = data

    def _replaceOrRegister(self, q, encodedWord):
        if encodedWord:
            nextState = q.getNext(encodedWord[0])
            q.setTransition(
                                encodedWord[0], 
                                self._replaceOrRegister(nextState, encodedWord[1:]))

        if self.register.containsEquivalentState(q):
            return self.register.getEquivalentState(q)
        else:
            self.register.addState(q)
            return q

    def calculateOffsets(self, sizeCounter):
Michał Lenart authored
116
        self.initialState.calculateOffsets(sizeCounter)
Michał Lenart authored
117
118
119
120

    def debug(self):
        for state in self.initialState.dfs(set()):
            state.debug()