# resolve.py (2.46 KB)
from conf import NEURAL_MODEL, THRESHOLD
from corneferencer.resolvers.vectors import create_pair_vector


# incremental resolve algorithm
def incremental(text):
    """Link each mention to its best-scoring preceding mention.

    Every mention after the first is paired with each mention that comes
    before it in ``text.mentions``. Each pair is turned into a vector with
    ``create_pair_vector`` and scored with ``NEURAL_MODEL.predict``; the
    first element of the prediction is used as the score. The antecedent
    with the highest score above ``THRESHOLD`` is selected. The anaphor then
    joins the antecedent's set, or a new ``set_<n>`` label is created and
    assigned to both when the antecedent has no set yet.

    Args:
        text: object exposing ``mentions``, a sliceable sequence of mention
            objects that carry a readable and writable ``set`` attribute.

    Returns:
        None. Results are written to the ``set`` attribute of the mentions.
    """
    last_set_id = 1
    for i, ana in enumerate(text.mentions):
        if i == 0:
            # The first mention has nothing before it to link to.
            continue

        best_prediction = 0.0
        best_ante = None
        # Candidates are the mentions that precede the anaphor, nearest
        # first. Combined with the strict ">" comparison below, an exact tie
        # keeps the closest antecedent.
        for ante in reversed(text.mentions[:i]):
            pair_vec = create_pair_vector(ante, ana)
            prediction = NEURAL_MODEL.predict(pair_vec)
            accuracy = prediction[0]
            if accuracy > THRESHOLD and accuracy > best_prediction:
                best_prediction = accuracy
                best_ante = ante

        if best_ante is None:
            continue

        if best_ante.set:
            # The antecedent already belongs to a set: join it.
            ana.set = best_ante.set
        else:
            # Open a new set shared by the antecedent and the anaphor.
            str_set_id = 'set_%d' % last_set_id
            best_ante.set = str_set_id
            ana.set = str_set_id
            last_set_id += 1


# entity based resolve algorithm
def entity_based(text):
    """Cluster mentions by comparing each one against the sets built so far.

    Mentions are visited in the order of ``text.mentions``. The first
    mention always opens a new set. Every later mention is handed to
    ``get_best_set``; if a set is returned the mention joins it, otherwise
    the mention opens a new set labelled ``set_<n>``. Once all mentions are
    placed, ``remove_singletons`` is called on the collected sets.

    Args:
        text: object exposing ``mentions``, an iterable of mention objects
            with a writable ``set`` attribute.

    Returns:
        None. Results are written to the ``set`` attribute of the mentions.
    """
    clusters = []
    next_id = 1
    for position, mention in enumerate(text.mentions):
        # The very first mention is never compared against existing sets.
        target = get_best_set(clusters, mention) if position > 0 else None

        if target is None:
            # No suitable set: open a new one holding just this mention.
            label = 'set_%d' % next_id
            clusters.append({'set_id': label,
                             'mentions': [mention]})
            mention.set = label
            next_id += 1
        else:
            mention.set = target['set_id']
            target['mentions'].append(mention)

    remove_singletons(clusters)


def get_best_set(sets, ana):
    """Pick the set whose score for ``ana`` is highest and above THRESHOLD.

    Each set is scored with ``predict_set`` on its ``'mentions'`` list. A
    set qualifies when its score is greater than ``THRESHOLD`` and not lower
    than the best score seen so far, so on an exact tie the later set wins.

    Args:
        sets: iterable of dicts with a ``'mentions'`` entry.
        ana: the mention to place.

    Returns:
        The winning set dict, or None when no set qualifies.
    """
    winner = None
    top_score = 0.0
    for candidate in sets:
        score = predict_set(candidate['mentions'], ana)
        if not (score > THRESHOLD and score >= top_score):
            continue
        top_score = score
        winner = candidate
    return winner


def predict_set(mentions, ana):
    """Return the mean pairwise score between ``ana`` and ``mentions``.

    Each mention is paired with ``ana`` via ``create_pair_vector`` and
    scored with ``NEURAL_MODEL.predict``; the first element of every
    prediction is summed and the total is divided by the number of
    mentions.

    Args:
        mentions: sized iterable of mention objects forming one set.
        ana: the mention being compared against the set.

    Returns:
        The average score, or 0.0 when ``mentions`` is empty.
    """
    count = len(mentions)
    if count == 0:
        # An empty set gives no evidence; avoid dividing by zero.
        return 0.0

    accuracy_sum = 0.0
    for mnt in mentions:
        pair_vec = create_pair_vector(mnt, ana)
        prediction = NEURAL_MODEL.predict(pair_vec)
        accuracy_sum += prediction[0]
    return accuracy_sum / float(count)


def remove_singletons(sets):
    """Clear the set label of every mention that is alone in its set.

    For each set whose ``'mentions'`` list has exactly one element, that
    mention's ``set`` attribute is overwritten with an empty string. The
    ``sets`` list itself is left unchanged.

    Args:
        sets: iterable of dicts with a ``'mentions'`` list.

    Returns:
        None.
    """
    for cluster in sets:
        members = cluster['mentions']
        if len(members) != 1:
            continue
        members[0].set = ''