# resolve.py
import numpy

from conf import NEURAL_MODEL, THRESHOLD
from corneferencer.resolvers.vectors import get_pair_vector


# incremental resolve algorithm
def incremental(text):
    """Cluster ``text.mentions`` by linking each mention to one antecedent.

    Mentions are processed in list order. Each mention (the anaphor) is scored
    by ``NEURAL_MODEL`` against every mention that precedes it. The
    highest-scoring antecedent whose score exceeds ``THRESHOLD`` is chosen; on
    ties the antecedent appearing later in ``text.mentions`` wins.

    The anaphor then joins that antecedent's set. If the antecedent has no set
    yet, a new id of the form ``'set_<n>'`` is assigned to both.

    Works in place by assigning ``mention.set``; returns None.
    """
    last_set_id = 0
    for i, ana in enumerate(text.mentions):
        antecedents = text.mentions[:i]
        if not antecedents:
            # The first mention has nothing to link to.
            continue

        # Score all candidate pairs in a single model call rather than one
        # call per pair; predict() takes a batch as its first dimension.
        samples = numpy.asarray(
            [get_pair_vector(ante, ana) for ante in antecedents],
            dtype=numpy.float32)
        predictions = NEURAL_MODEL.predict(samples)

        best_prediction = 0.0
        best_ante = None
        for ante, prediction in zip(antecedents, predictions):
            # ``>=`` makes the later antecedent win on equal scores.
            if prediction > THRESHOLD and prediction >= best_prediction:
                best_prediction = prediction
                best_ante = ante

        if best_ante is None:
            continue
        if best_ante.set:
            ana.set = best_ante.set
        else:
            str_set_id = 'set_%d' % last_set_id
            best_ante.set = str_set_id
            ana.set = str_set_id
            last_set_id += 1


# entity based resolve algorithm
def entity_based(text):
    """Cluster ``text.mentions`` by comparing each mention with whole entities.

    Mentions are processed in list order. Each one either joins the existing
    entity chosen by ``get_best_set`` or starts a new entity with id
    ``'set_<n>'``. Entities are dicts with ``'set_id'`` and ``'mentions'`` keys.

    When all mentions have been placed, ``remove_singletons`` clears the set id
    of every entity that ended up with a single mention.

    Works in place by assigning ``mention.set``; returns None.
    """
    entities = []
    next_number = 0
    for position, ana in enumerate(text.mentions):
        # The very first mention has no entity to compare against.
        best_fit = get_best_set(entities, ana) if position > 0 else None
        if best_fit is not None:
            ana.set = best_fit['set_id']
            best_fit['mentions'].append(ana)
            continue

        # No suitable entity: open a new one seeded with this mention.
        new_id = 'set_%d' % next_number
        next_number += 1
        entities.append({'set_id': new_id, 'mentions': [ana]})
        ana.set = new_id

    remove_singletons(entities)


def get_best_set(sets, ana):
    """Return the entity in ``sets`` that best matches mention ``ana``.

    Each entity is scored with ``predict_set`` over its ``'mentions'``. Only
    scores above ``THRESHOLD`` qualify; among qualifying entities the highest
    score wins, and on ties the entity later in ``sets`` is preferred.

    Returns:
        The winning entity dict, or None if no entity qualifies.
    """
    best_score = 0.0
    best_set = None
    for candidate in sets:
        score = predict_set(candidate['mentions'], ana)
        is_match = score > THRESHOLD and score >= best_score
        if is_match:
            best_score, best_set = score, candidate
    return best_set


def predict_set(mentions, ana):
    """Return the mean model score for linking ``ana`` to each of ``mentions``.

    Args:
        mentions: mentions already grouped into one entity.
        ana: the candidate mention.

    Returns:
        The average of the ``NEURAL_MODEL.predict`` outputs over all pairs
        ``(mention, ana)``, in the same shape as a single prediction row.
        Returns 0.0 when ``mentions`` is empty (previously this raised
        ZeroDivisionError).
    """
    if not mentions:
        return 0.0
    # One batched model call instead of one call per mention.
    samples = numpy.asarray(
        [get_pair_vector(mnt, ana) for mnt in mentions],
        dtype=numpy.float32)
    predictions = NEURAL_MODEL.predict(samples)
    # Summing over the batch axis matches the original per-row accumulation.
    return numpy.sum(predictions, axis=0) / float(len(mentions))


def remove_singletons(sets):
    """Clear the ``set`` id of every mention that is alone in its entity.

    ``sets`` is a list of entity dicts with a ``'mentions'`` list. The dicts
    themselves are not modified; only the lone mention's ``set`` attribute is
    reset to the empty string.
    """
    lone_mentions = (entity['mentions'][0]
                     for entity in sets
                     if len(entity['mentions']) == 1)
    for mention in lone_mentions:
        mention.set = ''