''' Created on Oct 8, 2013 @author: mlenart ''' import state import register import logging class FSA(object): ''' A finite state automaton ''' def __init__(self, encoder, tagset=None, encodeData=True, encodeWords=True): self.encodeWord = encoder.encodeWord if encodeWords else lambda x: x self.encodeData = encoder.encodeData if encodeData else lambda x: x self.decodeData = encoder.decodeData if encodeData else lambda x: x self.encodedPrevWord = None self.tagset = tagset self.initialState = state.State() self.register = register.Register() self.label2Freq = {} self.n = 0 self.closed = False def tryToRecognize(self, word, addFreq=False): return self.decodeData(self.initialState.tryToRecognize(self.encodeWord(word), addFreq)) def addEntry(self, word, data): assert not self.closed assert data is not None encodedWord = self.encodeWord(word) assert encodedWord > self.encodedPrevWord self._addSorted(encodedWord, self.encodeData(data)) self.encodedPrevWord = encodedWord self.n += 1 # debug if self.n % 10000 == 0: logging.info(word) logging.info(str(self.register.getStatesNum())) # allWords.append(word) for label in encodedWord: self.label2Freq[label] = self.label2Freq.get(label, 0) + 1 def close(self): assert self.n > 0 assert not self.closed self.initialState = self._replaceOrRegister(self.initialState, self.encodedPrevWord) self.encodedPrevWord = None self.closed = True def train(self, trainData): self.label2Freq = {} for idx, word in enumerate(trainData): self.tryToRecognize(word, addFreq=True) for label in self.encodeWord(word): self.label2Freq[label] = self.label2Freq.get(label, 0) + 1 if idx % 100000 == 0: logging.info(str(idx)) def dfs(self): for state in self.initialState.dfs(set()): yield state def getStatesNum(self): return self.register.getStatesNum() def getTransitionsNum(self): res = 0 for s in self.initialState.dfs(set()): res += len(s.transitionsMap) return res def _addSorted(self, encodedWord, data): assert self.encodedPrevWord < encodedWord assert type(data) == bytearray q = self.initialState i = 0 while i <= len(encodedWord) and q.hasNext(encodedWord[i]): q = q.getNext(encodedWord[i]) i += 1 if self.encodedPrevWord and i < len(self.encodedPrevWord): nextState = q.getNext(self.encodedPrevWord[i]) q.setTransition( self.encodedPrevWord[i], self._replaceOrRegister(nextState, self.encodedPrevWord[i+1:])) while i < len(encodedWord): q.setTransition(encodedWord[i], state.State()) q = q.getNext(encodedWord[i]) i += 1 assert q.encodedData is None q.encodedData = data def _replaceOrRegister(self, q, encodedWord): if encodedWord: nextState = q.getNext(encodedWord[0]) q.setTransition( encodedWord[0], self._replaceOrRegister(nextState, encodedWord[1:])) if self.register.containsEquivalentState(q): return self.register.getEquivalentState(q) else: self.register.addState(q) return q def calculateOffsets(self, sizeCounter): self.initialState.calculateOffsets(sizeCounter) # currReverseOffset = 0 # for state in self.initialState.dfs(set()): # currReverseOffset += sizeCounter(state) # state.reverseOffset = currReverseOffset # for state in self.initialState.dfs(set()): # state.offset = currReverseOffset - state.reverseOffset def debug(self): for state in self.initialState.dfs(set()): state.debug()