'''
Created on Oct 8, 2013

@author: mlenart
'''

import state
import register
import logging

class FSA(object):
    '''
    A finite state automaton
    '''


    def __init__(self, encoder, tagset=None):
        self.encodeWord = encoder.encodeWord
        self.encodeData = encoder.encodeData
        self.decodeData = encoder.decodeData
        self.encodedPrevWord = None
        self.tagset = tagset
        self.initialState = state.State()
        self.register = register.Register()
        self.label2Freq = {}
    
    def tryToRecognize(self, word, addFreq=False):
        return self.decodeData(self.initialState.tryToRecognize(self.encodeWord(word), addFreq))
    
    def feed(self, input):
        
#         allWords = []
        for n, (word, data) in enumerate(input, start=1):
            assert data is not None
            encodedWord = self.encodeWord(word)
            assert encodedWord > self.encodedPrevWord
            if encodedWord > self.encodedPrevWord:
                self._addSorted(encodedWord, self.encodeData(data))
                self.encodedPrevWord = encodedWord
#                 assert self.tryToRecognize(word) == data
                if n % 10000 == 0:
                    logging.info(word)
                    logging.info(str(self.register.getStatesNum()))
    #             allWords.append(word)
                for label in encodedWord:
                    self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
        
        self.initialState = self._replaceOrRegister(self.initialState, self.encodeWord(word))
        self.encodedPrevWord = None
        
#         for w in allWords:
#             self.tryToRecognize(w, True)
    
    def train(self, trainData):
        self.label2Freq = {}
        for idx, word in enumerate(trainData):
            self.tryToRecognize(word, addFreq=True)
            for label in self.encodeWord(word):
                self.label2Freq[label] = self.label2Freq.get(label, 0) + 1
            if idx % 100000 == 0:
                logging.info(str(idx))
    
    def dfs(self):
        for state in self.initialState.dfs(set()):
            yield state
    
    def getStatesNum(self):
        return self.register.getStatesNum()
    
    def getTransitionsNum(self):
        res = 0
        for s in self.initialState.dfs(set()):
            res += len(s.transitionsMap)
        return res
        
    def _addSorted(self, encodedWord, data):
        assert self.encodedPrevWord < encodedWord
        q = self.initialState
        i = 0
        while i <= len(encodedWord) and q.hasNext(encodedWord[i]):
            q = q.getNext(encodedWord[i])
            i += 1
        if self.encodedPrevWord and i < len(self.encodedPrevWord):
            nextState = q.getNext(self.encodedPrevWord[i])
            q.setTransition(
                                self.encodedPrevWord[i], 
                                self._replaceOrRegister(nextState, self.encodedPrevWord[i+1:]))
        
        while i < len(encodedWord):
            q.setTransition(encodedWord[i], state.State())
            q = q.getNext(encodedWord[i])
            i += 1
        
        assert q.encodedData is None
#         print q, encodedData
        q.encodedData = data
    
    def _replaceOrRegister(self, q, encodedWord):
        if encodedWord:
            nextState = q.getNext(encodedWord[0])
            q.setTransition(
                                encodedWord[0], 
                                self._replaceOrRegister(nextState, encodedWord[1:]))
        
        if self.register.containsEquivalentState(q):
            return self.register.getEquivalentState(q)
        else:
            self.register.addState(q)
            return q
    
    def calculateOffsets(self, sizeCounter):
        currReverseOffset = 0
        for state in self.initialState.dfs(set()):
            currReverseOffset += sizeCounter(state)
            state.reverseOffset = currReverseOffset
        for state in self.initialState.dfs(set()):
            state.offset = currReverseOffset - state.reverseOffset