Blame view

fsabuilder/fsa/fsa.py 3.99 KB
Michał Lenart authored
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
'''
Created on Oct 8, 2013

@author: mlenart
'''

import state
import register
import logging

class FSA(object):
    '''
    A finite state automaton built incrementally from (word, data) pairs.

    Words must be fed in strictly increasing order of their encoded form.
    While words are being added, the states that can no longer change are
    passed through the register so that equivalent states are shared.
    '''

    def __init__(self, encoder, tagset=None):
        '''
        Create an automaton consisting of a single, unregistered initial state.

        encoder -- object providing encodeWord, encodeData and decodeData
        tagset -- stored on the instance as-is; not used by this class
        '''
        self.encodeWord = encoder.encodeWord
        self.encodeData = encoder.encodeData
        self.decodeData = encoder.decodeData
        # Encoded form of the most recently added word (None when no word
        # has been added since construction or since the last feed()).
        self.encodedPrevWord = None
        self.tagset = tagset
        self.initialState = state.State()
        self.register = register.Register()
        # Maps a label (one element of an encoded word) to its occurrence count.
        self.label2Freq = {}

    def tryToRecognize(self, word, addFreq=False):
        '''
        Look the word up in the automaton and return its decoded data.

        The word is encoded, passed to the initial state's tryToRecognize
        together with addFreq, and the result is run through decodeData.
        '''
        encodedWord = self.encodeWord(word)
        encodedData = self.initialState.tryToRecognize(encodedWord, addFreq)
        return self.decodeData(encodedData)

    def feed(self, input):
        '''
        Add every (word, data) pair from input to the automaton.

        input -- iterable of (word, data) pairs; data must not be None and
                 the encoded words must be strictly increasing. (The
                 parameter name shadows the builtin but is kept so that
                 existing keyword callers continue to work.)

        After the last pair the path of the last added word is passed
        through the register and encodedPrevWord is reset to None. An empty
        input registers the bare initial state instead of failing.
        '''
        for n, (word, data) in enumerate(input, start=1):
            assert data is not None
            encodedWord = self.encodeWord(word)
            # The first word has nothing to be compared with; comparing it
            # against None would raise TypeError on Python 3.
            inOrder = (self.encodedPrevWord is None
                       or encodedWord > self.encodedPrevWord)
            assert inOrder
            if not inOrder:
                # Only reachable when asserts are disabled (-O): skip the
                # out-of-order word, exactly as the guarded block used to.
                continue

            self._addSorted(encodedWord, self.encodeData(data))
            self.encodedPrevWord = encodedWord
            if n % 10000 == 0:
                logging.info(word)
                logging.info(str(self.register.getStatesNum()))
            for label in encodedWord:
                self.label2Freq[label] = self.label2Freq.get(label, 0) + 1

        # Finalize along the last word that was actually added. Using the
        # loop variable here would fail on empty input and could name a
        # skipped word whose path does not exist in the automaton.
        lastEncodedWord = self.encodedPrevWord
        if lastEncodedWord is None:
            lastEncodedWord = ()
        self.initialState = self._replaceOrRegister(
            self.initialState, lastEncodedWord)
        self.encodedPrevWord = None

    def train(self, trainData):
        '''
        Recompute frequency information from an iterable of words.

        label2Freq is reset; every word is looked up with addFreq=True and
        each label of its encoded form is counted.
        '''
        self.label2Freq = {}
        for idx, word in enumerate(trainData):
            self.tryToRecognize(word, addFreq=True)
            for label in self.encodeWord(word):
                self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
            if idx % 100000 == 0:
                logging.info(str(idx))

    def dfs(self):
        '''
        Yield the states produced by the initial state's dfs traversal.
        '''
        for visitedState in self.initialState.dfs(set()):
            yield visitedState

    def getStatesNum(self):
        '''
        Return the number of states reported by the register.
        '''
        return self.register.getStatesNum()

    def getTransitionsNum(self):
        '''
        Return the total number of transitions over all states yielded by dfs.
        '''
        return sum(len(s.transitionsMap)
                   for s in self.initialState.dfs(set()))

    def _addSorted(self, encodedWord, data):
        '''
        Add one encoded word, which must be greater than encodedPrevWord.

        encodedWord -- indexable sequence of labels
        data -- already encoded data stored on the word's final state
        '''
        assert (self.encodedPrevWord is None
                or self.encodedPrevWord < encodedWord)
        q = self.initialState
        i = 0
        # Follow the part of the word that is already in the automaton.
        # The bound is strict: i == len(encodedWord) must not be indexed
        # (this happens for the empty word).
        while i < len(encodedWord) and q.hasNext(encodedWord[i]):
            q = q.getNext(encodedWord[i])
            i += 1

        # The rest of the previous word's path is not extended any more,
        # so it can be passed through the register.
        if self.encodedPrevWord and i < len(self.encodedPrevWord):
            prevLabel = self.encodedPrevWord[i]
            nextState = q.getNext(prevLabel)
            q.setTransition(
                prevLabel,
                self._replaceOrRegister(
                    nextState, self.encodedPrevWord[i + 1:]))

        # Append fresh states for the remaining labels of the new word.
        while i < len(encodedWord):
            q.setTransition(encodedWord[i], state.State())
            q = q.getNext(encodedWord[i])
            i += 1

        assert q.encodedData is None
        q.encodedData = data

    def _replaceOrRegister(self, q, encodedWord):
        '''
        Pass the path labelled by encodedWord, starting at q, through the
        register, deepest state first.

        Returns the state to use in place of q: the equivalent state from
        the register if it has one, otherwise q itself (which is then added
        to the register).
        '''
        if encodedWord:
            # Handle the successor first so that q is compared with the
            # register only after its transition points at the final target.
            nextState = q.getNext(encodedWord[0])
            q.setTransition(
                encodedWord[0],
                self._replaceOrRegister(nextState, encodedWord[1:]))

        if self.register.containsEquivalentState(q):
            return self.register.getEquivalentState(q)
        self.register.addState(q)
        return q

    def calculateOffsets(self, sizeCounter):
        '''
        Assign reverseOffset and offset to every state yielded by dfs.

        sizeCounter -- callable returning the size of a single state

        reverseOffset is the running total of sizes up to and including the
        state in dfs order; offset is the grand total minus reverseOffset.
        '''
        currReverseOffset = 0
        for visitedState in self.initialState.dfs(set()):
            currReverseOffset += sizeCounter(visitedState)
            visitedState.reverseOffset = currReverseOffset
        # Second pass: the grand total is only known after the first one.
        for visitedState in self.initialState.dfs(set()):
            visitedState.offset = currReverseOffset - visitedState.reverseOffset