# resolve.py
import numpy

from conf import NEURAL_MODEL
from corneferencer.resolvers import features
from corneferencer.resolvers.vectors import get_pair_features, get_pair_vector


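# siamese resolve algorithm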
def siamese(text, threshold):
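    """Link each anaphor to the closest preceding mention whose siamese
    score falls below the threshold (the score is treated as a distance,
    so lower means more likely coreferent)."""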
    last_set_id = 0
    for i, ana in enumerate(text.mentions):
        if i > 0:
            for ante in reversed(text.mentions[:i]):
                if not features.pair_intersect(ante, ana):
                    pair_features = get_pair_features(ante, ana)

                    ante_vec = []
                    ante_vec.extend(ante.features)
                    ante_vec.extend(pair_features)
                    ante_sample = numpy.asarray([ante_vec], dtype=numpy.float32)

                    ana_vec = []
                    ana_vec.extend(ana.features)
                    ana_vec.extend(pair_features)
                    ana_sample = numpy.asarray([ana_vec], dtype=numpy.float32)

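                    # the siamese model scores both samples jointly;
                    # a value below the threshold links the pair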
                    prediction = NEURAL_MODEL.predict([ante_sample, ana_sample])[0]

                    if prediction < threshold:
                        if ante.set:
                            ana.set = ante.set
                        else:
                            str_set_id = 'set_%d' % last_set_id
                            ante.set = str_set_id
                            ana.set = str_set_id
                            last_set_id += 1
                        break


# incremental resolve algorithm
def incremental(text, threshold):
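    """Link each anaphor to the preceding mention with the highest score
    above the threshold, joining its set or opening a new one."""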
    last_set_id = 0
    for i, ana in enumerate(text.mentions):
        if i > 0:
            best_prediction = 0.0
            best_ante = None
            for ante in text.mentions[:i]:
                if not features.pair_intersect(ante, ana):
                    pair_vec = get_pair_vector(ante, ana)
                    sample = numpy.asarray([pair_vec], dtype=numpy.float32)
                    prediction = NEURAL_MODEL.predict(sample)[0]
                    if prediction > threshold and prediction >= best_prediction:
                        best_prediction = prediction
                        best_ante = ante
            if best_ante is not None:
                if best_ante.set:
                    ana.set = best_ante.set
                else:
                    str_set_id = 'set_%d' % last_set_id
                    best_ante.set = str_set_id
                    ana.set = str_set_id
                    last_set_id += 1


# entity-based resolve algorithm
def entity_based(text, threshold):
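    """Add each mention to the existing set with the highest average
    score above the threshold, or open a new set for it; singleton
    sets are dissolved at the end."""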
    sets = []
    last_set_id = 0
    for i, ana in enumerate(text.mentions):
        # the first mention always opens a new set; later mentions are
        # matched against the sets built so far
        best_fit = get_best_set(sets, ana, threshold) if i > 0 else None
        if best_fit is not None:
            ana.set = best_fit['set_id']
            best_fit['mentions'].append(ana)
        else:
            str_set_id = 'set_%d' % last_set_id
            sets.append({'set_id': str_set_id,
                         'mentions': [ana]})
            ana.set = str_set_id
            last_set_id += 1

    remove_singletons(sets)


def get_best_set(sets, ana, threshold):
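    """Return the set whose mentions give ana the highest average score
    above the threshold, or None if no set qualifies."""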
    best_prediction = 0.0
    best_set = None
    for s in sets:
        accuracy = predict_set(s['mentions'], ana)
        if accuracy > threshold and accuracy >= best_prediction:
            best_prediction = accuracy
            best_set = s
    return best_set


def predict_set(mentions, ana):
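    """Average the model score between ana and every mention of the set;
    pairs rejected by features.pair_intersect contribute 0."""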
    prediction_sum = 0.0
    for mnt in mentions:
        prediction = 0.0
        if not features.pair_intersect(mnt, ana):
            pair_vec = get_pair_vector(mnt, ana)
            sample = numpy.asarray([pair_vec], dtype=numpy.float32)
            prediction = NEURAL_MODEL.predict(sample)[0]
        prediction_sum += prediction
    return prediction_sum / float(len(mentions))


def remove_singletons(sets):
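    """Clear the set id of every mention left alone in its set."""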
    for s in sets:
        if len(s['mentions']) == 1:
            s['mentions'][0].set = ''


# closest resolve algorithm
def closest(text, threshold):
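    """Link each anaphor to the closest preceding mention whose score
    exceeds the threshold and put both mentions in the same set."""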
    last_set_id = 0
    for i, ana in enumerate(text.mentions):
        if i > 0:
            for ante in reversed(text.mentions[:i]):
                if not features.pair_intersect(ante, ana):
                    pair_features = get_pair_features(ante, ana)

                    ante_vec = []
                    ante_vec.extend(ante.features)
                    ante_vec.extend(pair_features)
                    ante_sample = numpy.asarray([ante_vec], dtype=numpy.float32)

                    ana_vec = []
                    ana_vec.extend(ana.features)
                    ana_vec.extend(pair_features)
                    ana_sample = numpy.asarray([ana_vec], dtype=numpy.float32)

                    prediction = NEURAL_MODEL.predict([ante_sample, ana_sample])[0]

                    if prediction > threshold:
                        if ante.set:
                            ana.set = ante.set
                        else:
                            str_set_id = 'set_%d' % last_set_id
                            ante.set = str_set_id
                            ana.set = str_set_id
                            last_set_id += 1
                        break