# -*- coding: utf-8 -*-

from dictionary.models import sort_arguments, sort_positions, sortatributes
from settings import MORFEUSZ2
from copy import deepcopy

def lexicalisation(argument, subj, base, negativity, reference=None):
    """Generate surface forms for a lexicalised argument.

    :param argument: argument object; ``argument.type == 'fixed'`` returns the
        fixed words, otherwise the first attribute's value names the phrase
        type (an ``xp(...)`` wrapper is unwrapped to its inner phrase).
    :param subj: passed to the case/number helpers and to ``get_verb``
        (structural case becomes nominative, ``_`` number becomes ``sg``).
    :param base: verb lemma used to build verb forms.
    :param negativity: passed to ``get_case`` (structural case becomes
        genitive when not a subject); prepositional branches pass ``False``.
    :param reference: optional ``(orth, tag)`` head used to resolve ``agr``.
    :returns: ``(phrases, verbs)``. ``phrases`` is a list of strings.
        ``verbs`` depends on the branch: the ``get_verb`` result (a list, or
        ``None`` when ``subj`` is false), ``[base]`` for adverbs, or ``[]``.
    """
    if argument.type == 'fixed':
        return (get_words(sortatributes(argument)[-1]), [])
    attributes = sortatributes(argument)
    # Query the phrase-type argument once instead of re-evaluating it.
    phrase_argument = attributes[0].values.all()[0].argument
    if phrase_argument.type == 'xp':  # xp(...)[np/prepnp], ...
        # Realise the phrase wrapped inside xp(...).
        phrase_argument = sortatributes(phrase_argument)[0].values.all()[0].argument
    lexicalisation_type = phrase_argument.type
    lexicalisation_parameters = sortatributes(phrase_argument)

    if lexicalisation_type == 'np':  # np(case), number, nouns, atr
        nps = get_nps(get_case(lexicalisation_parameters[0], subj, negativity), get_number(attributes[1], subj), get_words(attributes[2]), attributes[3])
        return (nps, get_verb(base, get_number(attributes[1], subj), subj))
    elif lexicalisation_type == 'prepnp':  # prepnp(prep, case), number, nouns, atr
        prepnps = get_prepnps(get_preposition(lexicalisation_parameters[0]), get_case(lexicalisation_parameters[1], subj, negativity), get_number(attributes[1], subj), get_words(attributes[2]), attributes[3])
        return (prepnps, [])
    elif lexicalisation_type == 'adjp':  # adjp(case), number, gender, degree, adjectives, atr
        adjps = get_adjps(get_case(lexicalisation_parameters[0], subj, negativity, reference), get_number(attributes[1], subj, reference), get_gender(attributes[2], reference), get_degree(attributes[3]), get_words(attributes[4]), attributes[5])
        # The verb agrees with the unreferenced number, as before.
        return (adjps, get_verb(base, get_number(attributes[1], subj), subj))
    elif lexicalisation_type == 'prepadjp':  # prepadjp(prep, case), number, gender, degree, adjectives, atr
        prepadjps = get_prepadjps(get_preposition(lexicalisation_parameters[0]), get_case(lexicalisation_parameters[1], subj, False, reference), get_number(attributes[1], subj, reference), get_gender(attributes[2], reference), get_degree(attributes[3]), get_words(attributes[4]), attributes[5])
        return (prepadjps, [])
    elif lexicalisation_type == 'infp':
        infps = get_infps(get_aspect(lexicalisation_parameters[0]), get_words(attributes[2]), attributes[4])
        return (infps, [])
    elif lexicalisation_type == 'advp':  # advp(type), degree, adverb, atr
        advps = get_advps(get_degree(attributes[1]), get_words(attributes[2]), attributes[3])
        return (advps, [base])
    elif lexicalisation_type == 'nump':  # nump(case), num, noun, atr
        numps = get_numps(get_case(lexicalisation_parameters[0], subj, negativity, reference), get_words(attributes[1]), get_words(attributes[2]), attributes[3])
        return (numps, get_verb(base, 'pl', subj))
    elif lexicalisation_type == 'prepnump':  # prepnump(prep,case), num, noun, atr
        numps = get_prepnumps(get_preposition(lexicalisation_parameters[0]), get_case(lexicalisation_parameters[1], subj, False, reference), get_words(attributes[1]), get_words(attributes[2]), attributes[3])
        return (numps, get_verb(base, 'pl', subj))
    # Unsupported phrase type: nothing to generate.
    return ([], [])

def is_subj(categories):
    """Return True if any of ``categories`` has ``category == u'subj'``."""
    return any(cat.category == u'subj' for cat in categories)

def get_preposition(attribute):
    """Return the parameter-type name of the attribute's first value (the preposition)."""
    first_value = attribute.values.all()[0]
    return first_value.parameter.type.name

def get_numerals(attribute):
    """Return the numeral words stored in ``attribute`` (same as ``get_words``)."""
    numerals = get_words(attribute)
    return numerals

def get_words(attribute):
    """Return the text of every value of ``attribute`` with its first and last character removed.

    The stored texts are sliced ``[1:-1]``; presumably they are quoted
    literals — verify against the data source.
    """
    words = []
    for value in attribute.values.all():
        words.append(value.text[1:-1])
    return words

def get_aspect(attribute):
    """Return the parameter-type name of the attribute's first value (the aspect)."""
    first_value = attribute.values.all()[0]
    return first_value.parameter.type.name

def get_case(attribute, is_subj, negativity, reference=None):
    """Resolve the case value of ``attribute`` to a list of concrete cases.

    * ``str`` becomes ``[u'nom']`` for subjects, otherwise ``[u'gen']`` if
      ``negativity`` is set and ``[u'acc']`` if not.
    * ``part`` becomes ``[u'gen', u'acc']``.
    * ``agr`` with a ``reference`` ``(orth, tag)`` copies the case field of the
      tag: index 1 for ``siebie`` tags, index 2 otherwise.
    * Anything else (including ``agr`` without a reference) is wrapped as-is.
    """
    value = attribute.values.all()[0].parameter.type.name
    if value == u'str':
        if is_subj:
            return [u'nom']
        return [u'gen'] if negativity else [u'acc']
    if value == u'part':
        return [u'gen', u'acc']
    if value == u'agr' and reference is not None:
        _, tag = reference
        fields = tag.split(':')
        return [fields[1] if fields[0] == u'siebie' else fields[2]]
    return [value]

def get_number(attribute, is_subj, reference=None):
    """Resolve the grammatical-number value of ``attribute``.

    * ``_`` becomes ``u'sg'`` for subjects and stays ``_`` otherwise.
    * ``agr`` without a reference becomes ``u'sg'``. With a reference
      ``(orth, tag)`` it becomes ``_`` for ``siebie`` tags, otherwise field 1
      of the tag.
    * Any other value is returned unchanged.
    """
    value = attribute.values.all()[0].parameter.type.name
    if value == u'_':
        return u'sg' if is_subj else value
    if value == u'agr':
        if reference is None:
            return u'sg'
        _, tag = reference
        fields = tag.split(':')
        return u'_' if fields[0] == u'siebie' else fields[1]
    return value

def get_gender(attribute, reference=None):
    """Resolve the gender value of ``attribute``.

    * ``_`` becomes ``u'n'`` and ``m`` becomes ``u'm1'``.
    * ``agr`` without a reference becomes ``'m1'``. With a reference
      ``(orth, tag)`` it becomes ``_`` for ``siebie`` tags, otherwise field 3
      of the tag.
    * Any other value is returned unchanged.
    """
    value = attribute.values.all()[0].parameter.type.name
    direct = {u'_': u'n', u'm': u'm1'}
    if value in direct:
        return direct[value]
    if value == u'agr':
        if reference is None:
            return 'm1'
        _, tag = reference
        fields = tag.split(':')
        return u'_' if fields[0] == u'siebie' else fields[3]
    return value

def get_degree(attribute):
    """Return the degree value of ``attribute``, with ``_`` resolved to ``u'pos'``."""
    value = attribute.values.all()[0].parameter.type.name
    return u'pos' if value == u'_' else value

def get_nps(cases, number, nouns, atr):
    """Generate noun forms of ``nouns`` in the requested cases and number.

    :param cases: list of case values; ``u'_'`` accepts every generated form.
    :param number: number value to require, or ``u'_'`` for any.
    :param nouns: noun lemmas to inflect with Morfeusz2.
    :param atr: attribute controlling dependents; passed to ``dependents``.
    :returns: the list produced by ``dependents(atr, forms)``.
    """
    result = []
    for noun in nouns:
        options = [(interp.orth, interp.getTag(MORFEUSZ2)) for interp in MORFEUSZ2.generate(noun.encode('utf8'))]
        by_case = []
        for case in cases:
            if case == u'_':
                # Any case: keep every generated form. (Previously this added a
                # stale/unbound ``filtered`` list instead.)
                by_case += options
            else:
                by_case += [(orth, tag) for (orth, tag) in options
                            if u':' + case in tag or u'.' + case in tag]
        options = by_case
        if number != u'_':
            options = [(orth, tag) for (orth, tag) in options
                       if u':' + number + u':' in tag]
        result += options
    return dependents(atr, result)

def get_prepnps(prep, cases, number, nouns, _atr):
    """Return each noun phrase from ``get_nps`` prefixed by ``prep`` and a space."""
    head = prep + ' '
    return [head + noun_phrase for noun_phrase in get_nps(cases, number, nouns, _atr)]

def get_infps(aspect, verbs, atr):
    """Generate infinitive forms of ``verbs``, optionally restricted by aspect.

    :param aspect: aspect value to require, or ``u'_'`` for any.
    :param verbs: verb lemmas to inflect with Morfeusz2.
    :param atr: attribute controlling dependents; passed to ``dependents``.
    :returns: the list produced by ``dependents(atr, forms)``.
    """
    result = []
    for verb in verbs:
        options = [(interp.orth, interp.getTag(MORFEUSZ2)) for interp in MORFEUSZ2.generate(verb.encode('utf8'))]
        options = [(orth, tag) for (orth, tag) in options if u'inf:' in tag]
        if aspect != u'_':
            # Compare against whole tag fields (':'-separated, with '.'-separated
            # alternatives) so a trailing aspect such as 'inf:perf' matches.
            # The previous version appended to the list it was iterating,
            # which loops forever on any match.
            options = [(orth, tag) for (orth, tag) in options
                       if aspect in tag.replace(u'.', u':').split(u':')]
        result += options
    return dependents(atr, result)

def get_adjps(cases, number, gender, degree, adjectives, atr):
    """Generate adjective forms matching case, number, gender and degree.

    :param cases: list of case values; ``u'_'`` accepts every form.
    :param number: number value to require, or ``u'_'`` for any.
    :param gender: gender value to require, or ``u'_'`` for any.
    :param degree: degree value to require, or ``u'_'`` for any.
    :param adjectives: adjective lemmas to inflect with Morfeusz2.
    :param atr: attribute controlling dependents; passed to ``dependents``.
    :returns: the list produced by ``dependents(atr, forms)``.
    """
    def has_value(tag, value):
        # A tag field may list '.'-separated alternatives, so accept the value
        # bounded by ':' or '.' on either side. The gender filter always did
        # this; the case filter previously accepted only ':case:'.
        return any(left + value + right in tag
                   for left in (u':', u'.') for right in (u':', u'.'))

    result = []
    for adjective in adjectives:
        options = [(interp.orth, interp.getTag(MORFEUSZ2)) for interp in MORFEUSZ2.generate(adjective.encode('utf8'))]
        options = [(orth, tag) for (orth, tag) in options if u'adj:' in tag]
        by_case = []
        for case in cases:
            if case == u'_':
                by_case += options
            else:
                by_case += [(orth, tag) for (orth, tag) in options if has_value(tag, case)]
        options = by_case
        if number != u'_':
            options = [(orth, tag) for (orth, tag) in options
                       if u':' + number + u':' in tag]
        if gender != u'_':
            options = [(orth, tag) for (orth, tag) in options if has_value(tag, gender)]
        if degree != u'_':
            options = [(orth, tag) for (orth, tag) in options
                       if u':' + degree in tag]
        result += options
    return dependents(atr, result)

def get_prepadjps(prep, case, number, gender, degree, adjectives, _atr):
    """Return each adjective phrase from ``get_adjps`` prefixed by ``prep`` and a space."""
    head = prep + ' '
    adjective_phrases = get_adjps(case, number, gender, degree, adjectives, _atr)
    return [head + adjective_phrase for adjective_phrase in adjective_phrases]

def get_advps(degree, adverbs, atr):
    """Generate adverb forms of ``adverbs``, optionally restricted by degree.

    :param degree: degree value to require, or ``u'_'`` for any.
    :param adverbs: adverb lemmas to inflect with Morfeusz2.
    :param atr: attribute controlling dependents; passed to ``dependents``.
    :returns: the list produced by ``dependents(atr, forms)``.
    """
    result = []
    for adverb in adverbs:
        options = [(interp.orth, interp.getTag(MORFEUSZ2)) for interp in MORFEUSZ2.generate(adverb.encode('utf8'))]
        options = [(orth, tag) for (orth, tag) in options if u'adv' in tag]
        if degree != u'_':
            # Decide per form: a tag with no ':' field has nothing to filter on,
            # so it is kept. The previous version tested a ``tag`` variable
            # leaked from the preceding loop (UnboundLocalError when Morfeusz
            # returned nothing, and only the last form's tag was consulted).
            options = [(orth, tag) for (orth, tag) in options
                       if u':' not in tag or u':' + degree in tag]
        result += options
    return dependents(atr, result)

def get_numps(cases, numerals, nouns, atr):
    """Generate "numeral noun" strings whose parts agree in case and gender.

    :param cases: case values for the numeral; ``u'_'`` accepts any form.
    :param numerals: numeral lemmas to inflect with Morfeusz2.
    :param nouns: noun lemmas; only plural ``subst`` forms are used.
    :param atr: accepted for interface symmetry but ignored (ambiguous
        dependents are not expanded for numeral phrases).
    :returns: list of ``u"<numeral> <noun>"`` strings.
    """
    def has_value(tag, value):
        # Tag fields may list '.'-separated alternatives; accept all four
        # delimiter combinations (the old check missed '.value:').
        return any(left + value + right in tag
                   for left in (u':', u'.') for right in (u':', u'.'))

    def filter_cases(options, wanted):
        # Concatenate the matches for each wanted case, in order; '_' keeps all.
        selected = []
        for case in wanted:
            if case == u'_':
                selected += options
            else:
                selected += [(orth, tag) for (orth, tag) in options if has_value(tag, case)]
        return selected

    # Plural noun forms do not depend on the numeral, so generate them once.
    plural_nouns = []
    for noun in nouns:
        forms = [(interp.orth, interp.getTag(MORFEUSZ2)) for interp in MORFEUSZ2.generate(noun.encode('utf8'))]
        plural_nouns.append([(orth, tag) for (orth, tag) in forms
                             if 'subst:' in tag and u':pl:' in tag])

    results = []
    for numeral in numerals:
        forms = [(interp.orth, interp.getTag(MORFEUSZ2)) for interp in MORFEUSZ2.generate(numeral.encode('utf8'))]
        forms = [(orth, tag) for (orth, tag) in forms if u'num:' in tag]
        # A numeral with no matching forms just contributes nothing; it no
        # longer discards the results collected for earlier numerals.
        for (num_orth, num_tag) in filter_cases(forms, cases):
            # When the fifth tag field is 'rec' the noun is taken in the
            # genitive; otherwise it shares the requested cases.
            noun_cases = ['gen'] if num_tag.split(':')[4] == 'rec' else cases
            for noun_forms in plural_nouns:
                for (orth, tag) in filter_cases(noun_forms, noun_cases):
                    gender = tag.split(':')[3]
                    if has_value(num_tag, gender):
                        results.append(num_orth + ' ' + orth)
    return results  # ignoring ambiguous atr for numps

def get_prepnumps(prep, cases, numerals, nouns, atr):
    """Return each numeral phrase from ``get_numps`` prefixed by ``prep`` and a space."""
    head = prep + ' '
    numeral_phrases = get_numps(cases, numerals, nouns, atr)
    return [head + numeral_phrase for numeral_phrase in numeral_phrases]

 
def get_verb(inf, number, is_subj):
    """Return third-person finite forms of verb ``inf`` in ``number``.

    Returns ``None`` when ``is_subj`` is false; otherwise a (possibly empty)
    list of orths whose Morfeusz2 tag contains ``fin:``, the number, and ``:ter:``.
    """
    if is_subj:
        forms = [(interp.orth, interp.getTag(MORFEUSZ2)) for interp in MORFEUSZ2.generate(inf.encode('utf8'))]
        number_marker = u':' + number + ':'
        return [orth for (orth, tag) in forms
                if u'fin:' in tag and number_marker in tag and u':ter:' in tag]
    return None

def dependents(atr, options):
    """Turn ``(orth, tag)`` options into strings, expanding modifiers if required.

    When ``atr.selection_mode.name`` is ``ratr`` or ``ratr1``, each option is
    expanded with ``phrase`` using ``atr``'s values. Otherwise only the orths
    are returned.
    """
    if atr.selection_mode.name not in (u'ratr', u'ratr1'):
        return [orth for orth, _ in options]
    expanded = []
    for option in options:
        expanded.extend(phrase(option, atr.values.all()))
    return expanded

def phrase(head, dependents):
    """Combine a head form with its lexicalised modifiers.

    :param head: ``(orth, tag)`` pair of the head word; also passed to
        ``lexicalisation`` as the agreement reference.
    :param dependents: iterable of dependents whose
        ``position.arguments.all()`` are inspected. (NOTE: this parameter
        shadows the module-level ``dependents`` function inside this body.)
    :returns: list of strings. Each dependent becomes one modifier slot:
        ``adjp`` slots go before the head, all others after it. Every slot
        ordering (``permutations``) and every choice of at most one value per
        slot (``cartesian``) is produced. Duplicates are removed via ``set``,
        so output order is not deterministic.
    """
    modifiers = {'pre': [], 'post': []}
    for dependent in dependents:
        values = []
        # Phrase type of the slot; the last argument seen decides it.
        type = None
        for argument in dependent.position.arguments.all():
            if argument.type == u'fixed':
                # NOTE(review): fixed arguments set the type but add no words to
                # ``values``, so their text never appears in the output — confirm
                # this is intended.
                type = argument.type
            elif argument.type == u'lex':
                type = sortatributes(argument)[0].values.all()[0].argument.type
                value, _ = lexicalisation(argument, False, '', False, head)
                values += value
        if type == 'adjp':
            modifiers['pre'].append(values)
        else:
            modifiers['post'].append(values)
    pre = []
    for permutation in permutations(modifiers['pre']):
        pre += cartesian(permutation)
    pre = [' '.join(words) for words in pre]
    pre = list(set(pre))
    post = []
    for permutation in permutations(modifiers['post']):
        post += cartesian(permutation)
    post = [' '.join(words) for words in post]
    post = list(set(post))
    orth, _ = head
    result = []
    # NOTE(review): appears unreachable — cartesian() always yields the empty
    # combination, so ``pre`` and ``post`` each contain at least ''.
    if len(pre) == 0 and len(post) == 0:
        result.append(orth)
    for prefix in pre:
        for suffix in post:
            if prefix == '' and suffix == '':
                # The bare head is emitted only when neither side produced any
                # non-empty modifier string.
                if len(pre) == 1 and len(post) == 1:
                    result.append(orth)
            elif prefix == '':
                result.append(orth + ' ' + suffix)
            elif suffix == '':
                result.append(prefix + ' ' + orth)
            else:
                result.append(prefix + ' ' + orth + ' ' + suffix)
    return result
        

def cartesian(llist):
    """Return every combination picking at most one element from each sub-list.

    Unlike a plain Cartesian product, each sub-list may also be skipped, so the
    empty combination ``[]`` is always present (last). Chosen elements keep
    the order of their sub-lists, e.g.
    ``cartesian([[1, 2], [3]]) == [[1, 3], [1], [2, 3], [2], [3], []]``.
    """
    if not llist:
        return [[]]
    tail = cartesian(llist[1:])
    # Build fresh lists with concatenation instead of deep-copying the whole
    # tail once per element.
    with_head = [[element] + rest for element in llist[0] for rest in tail]
    return with_head + tail

def permutations(llist):
    """Return all orderings of ``llist`` as new lists.

    The first element is inserted at every position of each permutation of the
    rest, e.g. ``permutations([1, 2]) == [[1, 2], [2, 1]]``. Elements are
    shared, not copied.
    """
    if not llist:
        return [[]]
    first = llist[0]
    result = []
    for perm in permutations(llist[1:]):
        for i in range(len(perm) + 1):
            # Slicing builds a new list; no deep copy of the elements is needed.
            result.append(perm[:i] + [first] + perm[i:])
    return result