diff --git a/conf.py b/conf.py
new file mode 100644
index 0000000..c95d924
--- /dev/null
+++ b/conf.py
@@ -0,0 +1,24 @@
+import os
+
+from gensim.models.word2vec import Word2Vec
+
+from corneferencer.utils import initialize_neural_model
+
+
+CONTEXT = 5
+THRESHOLD = 0.5
+RANDOM_WORD_VECTORS = True
+W2V_SIZE = 50
+W2V_MODEL_NAME = 'w2v_allwiki_nkjpfull_50.model'
+
+NUMBER_OF_FEATURES = 1126
+NEURAL_MODEL_NAME = 'weights_2017_05_10.h5'
+
+
+# do not change the paths below
+W2V_MODEL_PATH = os.path.join(os.path.dirname(__file__), 'models', W2V_MODEL_NAME)
+W2V_MODEL = Word2Vec.load(W2V_MODEL_PATH)
+
+NEURAL_MODEL_PATH = os.path.join(os.path.dirname(__file__), 'models', NEURAL_MODEL_NAME)
+NEURAL_MODEL = initialize_neural_model(NUMBER_OF_FEATURES)
+NEURAL_MODEL.load_weights(NEURAL_MODEL_PATH)
diff --git a/corneferencer/entities.py b/corneferencer/entities.py
new file mode 100644
index 0000000..130133d
--- /dev/null
+++ b/corneferencer/entities.py
@@ -0,0 +1,39 @@
+from corneferencer.resolvers.vectors import get_mention_features
+
+
+class Text:
+
+    def __init__(self, text_id):
+        self.__id = text_id
+        self.mentions = []
+
+    def get_mention_set(self, mnt_id):
+        for mnt in self.mentions:
+            if mnt.id == mnt_id:
+                return mnt.set
+        return None
+
+
+class Mention:
+
+    def __init__(self, mnt_id, text, lemmatized_text, words, span,
+                 head_orth, head_base, dominant, node, prec_context,
+                 follow_context, sentence, position_in_mentions,
+                 start_in_words, end_in_words):
+        self.id = mnt_id
+        self.set = ''
+        self.text = text
+        self.lemmatized_text = lemmatized_text
+        self.words = words
+        self.span = span
+        self.head_orth = head_orth
+        self.head_base = head_base
+        self.dominant = dominant
+        self.node = node
+        self.prec_context = prec_context
+        self.follow_context = follow_context
+        self.sentence = sentence
+        self.position_in_mentions = position_in_mentions
+        self.start_in_words = start_in_words
+        self.end_in_words = end_in_words
+        self.features = get_mention_features(self)
diff --git a/corneferencer/core.py b/corneferencer/inout/__init__.py
index e69de29..e69de29 100644
--- a/corneferencer/core.py
+++ b/corneferencer/inout/__init__.py
diff --git a/corneferencer/inout/constants.py b/corneferencer/inout/constants.py
new file mode 100644
index 0000000..73ff001
--- /dev/null
+++ b/corneferencer/inout/constants.py
@@ -0,0 +1 @@
+INPUT_FORMATS = ['mmax']
diff --git a/corneferencer/inout/mmax.py b/corneferencer/inout/mmax.py
new file mode 100644
index 0000000..5607b54
--- /dev/null
+++ b/corneferencer/inout/mmax.py
@@ -0,0 +1,315 @@
+import os
+import shutil
+
+from lxml import etree
+
+from conf import CONTEXT
+from corneferencer.entities import Mention, Text
+
+
+def read(inpath):
+    textname = os.path.splitext(os.path.basename(inpath))[0]
+    textdir = os.path.dirname(inpath)
+
+    mentions_path = os.path.join(textdir, '%s_mentions.xml' % textname)
+    words_path = os.path.join(textdir, '%s_words.xml' % textname)
+
+    text = Text(textname)
+    mentions = read_mentions(mentions_path, words_path)
+    text.mentions = mentions
+    return text
+
+
+def read_mentions(mentions_path, words_path):
+    mentions = []
+    mentions_tree = etree.parse(mentions_path)
+    markables = mentions_tree.xpath("//ns:markable",
+                                    namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
+    words = get_words(words_path)
+
+    for idx, markable in enumerate(markables):
+        span = markable.attrib['span']
+
+        dominant = ''
+        if 'dominant' in markable.attrib:
+            dominant = markable.attrib['dominant']
+
+        head_orth = markable.attrib['mention_head']
+        mention_words = span_to_words(span, words)
+
+        (prec_context, follow_context, sentence,
+         mnt_start_position, mnt_end_position) = get_context(mention_words, words)
+
+        head_base = get_head_base(head_orth, mention_words)
+        mention = Mention(mnt_id=markable.attrib['id'],
+                          text=span_to_text(span, words, 'orth'),
+                          lemmatized_text=span_to_text(span, words, 'base'),
+                          words=mention_words,
+                          span=span,
+                          head_orth=head_orth,
+                          head_base=head_base,
+                          dominant=dominant,
+                          node=markable,
+                          prec_context=prec_context,
+                          follow_context=follow_context,
+                          sentence=sentence,
+                          position_in_mentions=idx,
+                          start_in_words=mnt_start_position,
+                          end_in_words=mnt_end_position)
+        mentions.append(mention)
+
+    return mentions
+
+
+def get_words(filepath):
+    tree = etree.parse(filepath)
+    words = []
+    for word in tree.xpath("//word"):
+        hasnps = False
+        if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
+            hasnps = True
+        lastinsent = False
+        if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
+            lastinsent = True
+        words.append({'id': word.attrib['id'],
+                      'orth': word.text,
+                      'base': word.attrib['base'],
+                      'hasnps': hasnps,
+                      'lastinsent': lastinsent,
+                      'ctag': word.attrib['ctag']})
+    return words
+
+
+def span_to_words(span, words):
+    fragments = span.split(',')
+    mention_parts = []
+    for fragment in fragments:
+        mention_parts.extend(fragment_to_words(fragment, words))
+    return mention_parts
+
+
+def fragment_to_words(fragment, words):
+    mention_parts = []
+    if '..' in fragment:
+        mention_parts.extend(get_multiword(fragment, words))
+    else:
+        mention_parts.extend(get_word(fragment, words))
+    return mention_parts
+
+
+def get_multiword(fragment, words):
+    mention_parts = []
+    boundaries = fragment.split('..')
+    start_id = boundaries[0]
+    end_id = boundaries[1]
+    in_string = False
+    for word in words:
+        if word['id'] == start_id:
+            in_string = True
+        if in_string and not word_to_ignore(word):
+            mention_parts.append(word)
+        if word['id'] == end_id:
+            break
+    return mention_parts
+
+
+def get_word(word_id, words):
+    for word in words:
+        if word['id'] == word_id:
+            if not word_to_ignore(word):
+                return [word]
+            else:
+                return []
+    return []
+
+
+def word_to_ignore(word):
+    if word['ctag'] == 'interp':
+        return True
+    return False
+
+
+def get_context(mention_words, words):
+    prec_context = []
+    follow_context = []
+    sentence = []
+    mnt_start_position = -1
+    mnt_end_position = -1
+    first_word = mention_words[0]
+    last_word = mention_words[-1]
+    for idx, word in enumerate(words):
+        if word['id'] == first_word['id']:
+            prec_context = get_prec_context(idx, words)
+            mnt_start_position = get_mention_start(first_word, words)
+        if word['id'] == last_word['id']:
+            follow_context = get_follow_context(idx, words)
+            sentence = get_sentence(idx, words)
+            mnt_end_position = get_mention_end(last_word, words)
+            break
+    return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position
+
+
+def get_prec_context(mention_start, words):
+    context = []
+    context_start = mention_start - 1
+    while context_start >= 0:
+        if not word_to_ignore(words[context_start]):
+            context.append(words[context_start])
+        if len(context) == CONTEXT:
+            break
+        context_start -= 1
+    context.reverse()
+    return context
+
+
+def get_mention_start(first_word, words):
+    start = 0
+    for word in words:
+        if not word_to_ignore(word):
+            start += 1
+        if word['id'] == first_word['id']:
+            break
+    return start
+
+
+def get_mention_end(last_word, words):
+    end = 0
+    for word in words:
+        if not word_to_ignore(word):
+            end += 1
+        if word['id'] == last_word['id']:
+            break
+    return end
+
+
+def get_follow_context(mention_end, words):
+    context = []
+    context_end = mention_end + 1
+    while context_end < len(words):
+        if not word_to_ignore(words[context_end]):
+            context.append(words[context_end])
+        if len(context) == CONTEXT:
+            break
+        context_end += 1
+    return context
+
+
+def get_sentence(word_idx, words):
+    sentence_start = get_sentence_start(words, word_idx)
+    sentence_end = get_sentence_end(words, word_idx)
+    sentence = [word for word in words[sentence_start:sentence_end + 1] if not word_to_ignore(word)]
+    return sentence
+
+
+def get_sentence_start(words, word_idx):
+    search_start = word_idx
+    while word_idx >= 0:
+        if words[word_idx]['lastinsent'] and search_start != word_idx:
+            return word_idx + 1
+        word_idx -= 1
+    return 0
+
+
+def get_sentence_end(words, word_idx):
+    while word_idx < len(words):
+        if words[word_idx]['lastinsent']:
+            return word_idx
+        word_idx += 1
+    return len(words) - 1
+
+
+def get_head_base(head_orth, words):
+    for word in words:
+        if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
+            return word['base']
+    return None
+
+
+def span_to_text(span, words, form):
+    fragments = span.split(',')
+    mention_parts = []
+    for fragment in fragments:
+        mention_parts.append(fragment_to_text(fragment, words, form))
+    return u' [...] '.join(mention_parts)
+
+
+def fragment_to_text(fragment, words, form):
+    if '..' in fragment:
+        text = get_multiword_text(fragment, words, form)
+    else:
+        text = get_one_word_text(fragment, words, form)
+    return text
+
+
+def get_multiword_text(fragment, words, form):
+    mention_parts = []
+    boundaries = fragment.split('..')
+    start_id = boundaries[0]
+    end_id = boundaries[1]
+    in_string = False
+    for word in words:
+        if word['id'] == start_id:
+            in_string = True
+        if in_string and not word_to_ignore(word):
+            mention_parts.append(word)
+        if word['id'] == end_id:
+            break
+    return to_text(mention_parts, form)
+
+
+def to_text(words, form):
+    text = ''
+    for idx, word in enumerate(words):
+        if word['hasnps'] or idx == 0:
+            text += word[form]
+        else:
+            text += u' %s' % word[form]
+    return text
+
+
+def get_one_word_text(word_id, words, form):
+    this_word = next(word for word in words if word['id'] == word_id)
+    return this_word[form]
+
+
+def write(inpath, outpath, text):
+    textname = os.path.splitext(os.path.basename(inpath))[0]
+    intextdir = os.path.dirname(inpath)
+    outtextdir = os.path.dirname(outpath)
+
+    in_mmax_path = os.path.join(intextdir, '%s.mmax' % textname)
+    out_mmax_path = os.path.join(outtextdir, '%s.mmax' % textname)
+    copy_mmax(in_mmax_path, out_mmax_path)
+
+    in_words_path = os.path.join(intextdir, '%s_words.xml' % textname)
+    out_words_path = os.path.join(outtextdir, '%s_words.xml' % textname)
+    copy_words(in_words_path, out_words_path)
+
+    in_mentions_path = os.path.join(intextdir, '%s_mentions.xml' % textname)
+    out_mentions_path = os.path.join(outtextdir, '%s_mentions.xml' % textname)
+    write_mentions(in_mentions_path, out_mentions_path, text)
+
+
+def copy_mmax(src, dest):
+    shutil.copyfile(src, dest)
+
+
+def copy_words(src, dest):
+    shutil.copyfile(src, dest)
+
+
+def write_mentions(inpath, outpath, text):
+    tree = etree.parse(inpath)
+    mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
+
+    for mnt in mentions:
+        mnt_set = text.get_mention_set(mnt.attrib['id'])
+        if mnt_set:
+            mnt.attrib['mention_group'] = mnt_set
+        else:
+            mnt.attrib['mention_group'] = 'empty'
+
+    with open(outpath, 'wb') as output_file:
+        output_file.write(etree.tostring(tree, pretty_print=True,
+                                         xml_declaration=True, encoding='UTF-8',
+                                         doctype=u'<!DOCTYPE markables SYSTEM "markables.dtd">'))
diff --git a/corneferencer/main.py b/corneferencer/main.py
new file mode 100644
index 0000000..c3325d1
--- /dev/null
+++ b/corneferencer/main.py
@@ -0,0 +1,84 @@
+import os
+import sys
+
+from argparse import ArgumentParser
+from natsort import natsorted
+
+sys.path.append(os.path.abspath(os.path.join('..')))
+
+from inout import mmax
+from inout.constants import INPUT_FORMATS
+from resolvers import resolve
+from resolvers.constants import RESOLVERS
+from utils import eprint
+
+
+def main():
+    args = parse_arguments()
+    if not args.input:
+        eprint("Error: Input file(s) not specified!")
+    elif args.resolver not in RESOLVERS:
+        eprint("Error: Unknown resolve algorithm!")
+    elif args.format not in INPUT_FORMATS:
+        eprint("Error: Unknown input file format!")
+    else:
+        process_texts(args.input, args.output, args.format, args.resolver)
+
+
+def parse_arguments():
+    parser = ArgumentParser(description='Corneferencer: coreference resolver using neural nets.')
+    parser.add_argument('-i', '--input', type=str, action='store',
+                        dest='input', default='',
+                        help='input file or dir path')
+    parser.add_argument('-o', '--output', type=str, action='store',
+                        dest='output', default='',
+                        help='output path; if not specified writes output to standard output')
+    parser.add_argument('-f', '--format', type=str, action='store',
+                        dest='format', default='mmax',
+                        help='input format; default: mmax')
+    parser.add_argument('-r', '--resolver', type=str, action='store',
+                        dest='resolver', default='incremental',
+                        help='resolve algorithm; default: incremental; possibilities: %s'
+                             % ', '.join(RESOLVERS))
+
+    args = parser.parse_args()
+    return args
+
+
+def process_texts(inpath, outpath, informat, resolver):
+    if os.path.isdir(inpath):
+        process_directory(inpath, outpath, informat, resolver)
+    elif os.path.isfile(inpath):
+        process_file(inpath, outpath, informat, resolver)
+    else:
+        eprint("Error: Specified input does not exist!")
+
+
+def process_directory(inpath, outpath, informat, resolver):
+    inpath = os.path.abspath(inpath)
+    outpath = os.path.abspath(outpath)
+
+    files = os.listdir(inpath)
+    files = natsorted(files)
+
+    for filename in files:
+        textname = os.path.splitext(os.path.basename(filename))[0]
+        textoutput = os.path.join(outpath, textname)
+        textinput = os.path.join(inpath, filename)
+        process_file(textinput, textoutput, informat, resolver)
+
+
+def process_file(inpath, outpath, informat, resolver):
+    basename = os.path.basename(inpath)
+    if informat == 'mmax' and basename.endswith('.mmax'):
+        print(basename)
+        text = mmax.read(inpath)
+        if resolver == 'incremental':
+            resolve.incremental(text)
+        elif resolver == 'entity_based':
+            resolve.entity_based(text)
+        mmax.write(inpath, outpath, text)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/corneferencer/readers/__init__.py b/corneferencer/readers/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/corneferencer/readers/__init__.py
+++ /dev/null
diff --git a/corneferencer/entities/__init__.py b/corneferencer/resolvers/__init__.py
index e69de29..e69de29 100644
--- a/corneferencer/entities/__init__.py
+++ b/corneferencer/resolvers/__init__.py
diff --git a/corneferencer/resolvers/constants.py b/corneferencer/resolvers/constants.py
new file mode 100644
index 0000000..92c7106
--- /dev/null
+++ b/corneferencer/resolvers/constants.py
@@ -0,0 +1 @@
+RESOLVERS = ['entity_based', 'incremental']
diff --git a/corneferencer/resolvers/features.py b/corneferencer/resolvers/features.py
new file mode 100644
index 0000000..8d593d0
--- /dev/null
+++ b/corneferencer/resolvers/features.py
@@ -0,0 +1,170 @@
+import numpy
+import random
+
+from conf import RANDOM_WORD_VECTORS, W2V_MODEL, W2V_SIZE
+
+
+# mention features
+def head_vec(mention):
+    return list(get_wv(W2V_MODEL, mention.head_base))
+
+
+def first_word_vec(mention):
+    return list(get_wv(W2V_MODEL, mention.words[0]['base']))
+
+
+def last_word_vec(mention):
+    return list(get_wv(W2V_MODEL, mention.words[-1]['base']))
+
+
+def first_after_vec(mention):
+    if len(mention.follow_context) > 0:
+        vec = list(get_wv(W2V_MODEL, mention.follow_context[0]['base']))
+    else:
+        vec = [0.0] * W2V_SIZE
+    return vec
+
+
+def second_after_vec(mention):
+    if len(mention.follow_context) > 1:
+        vec = list(get_wv(W2V_MODEL, mention.follow_context[1]['base']))
+    else:
+        vec = [0.0] * W2V_SIZE
+    return vec
+
+
+def first_before_vec(mention):
+    if len(mention.prec_context) > 0:
+        vec = list(get_wv(W2V_MODEL, mention.prec_context[-1]['base']))
+    else:
+        vec = [0.0] * W2V_SIZE
+    return vec
+
+
+def second_before_vec(mention):
+    if len(mention.prec_context) > 1:
+        vec = list(get_wv(W2V_MODEL, mention.prec_context[-2]['base']))
+    else:
+        vec = [0.0] * W2V_SIZE
+    return vec
+
+
+def preceding_context_vec(mention):
+    return list(get_context_vec(mention.prec_context, W2V_MODEL))
+
+
+def following_context_vec(mention):
+    return list(get_context_vec(mention.follow_context, W2V_MODEL))
+
+
+def mention_vec(mention):
+    return list(get_context_vec(mention.words, W2V_MODEL))
+
+
+def sentence_vec(mention):
+    return list(get_context_vec(mention.sentence, W2V_MODEL))
+
+
+# pair features
+def distances_vec(ante, ana):
+    vec = []
+
+    mnts_intersect = pair_intersect(ante, ana)
+
+    words_dist = [0] * 11
+    words_bucket = 0
+    if mnts_intersect != 1:
+        words_bucket = get_distance_bucket(ana.start_in_words - ante.end_in_words - 1)
+    words_dist[words_bucket] = 1
+    vec.extend(words_dist)
+
+    mentions_dist = [0] * 11
+    mentions_bucket = 0
+    if mnts_intersect != 1:
+        mentions_bucket = get_distance_bucket(ana.position_in_mentions - ante.position_in_mentions - 1)
+    if words_bucket == 10:
+        mentions_bucket = 10
+    mentions_dist[mentions_bucket] = 1
+    vec.extend(mentions_dist)
+
+    vec.append(mnts_intersect)
+
+    return vec
+
+
+def pair_intersect(ante, ana):
+    for ante_word in ante.words:
+        for ana_word in ana.words:
+            if ana_word['id'] == ante_word['id']:
+                return 1
+    return 0
+
+
+def head_match(ante, ana):
+    if ante.head_orth.lower() == ana.head_orth.lower():
+        return 1
+    return 0
+
+
+def exact_match(ante, ana):
+    if ante.text.lower() == ana.text.lower():
+        return 1
+    return 0
+
+
+def base_match(ante, ana):
+    if ante.lemmatized_text.lower() == ana.lemmatized_text.lower():
+        return 1
+    return 0
+
+
+# supporting functions
+def get_wv(model, lemma, use_random_vec=True):
+    vec = None
+    if use_random_vec:
+        vec = random_vec()
+    try:
+        vec = model.wv[lemma]
+    except KeyError:
+        pass
+    except TypeError:
+        pass
+    return vec
+
+
+def random_vec():
+    return numpy.asarray([random.uniform(-0.25, 0.25) for i in range(0, W2V_SIZE)], dtype=numpy.float32)
+
+
+def get_context_vec(words, model):
+    vec = numpy.zeros(W2V_SIZE, dtype=numpy.float32)
+    unknown_count = 0
+    if len(words) != 0:
+        for word in words:
+            word_vec = get_wv(model, word['base'], RANDOM_WORD_VECTORS)
+            if word_vec is None:
+                unknown_count += 1
+            else:
+                vec += word_vec
+        significant_words = len(words) - unknown_count
+        if significant_words != 0:
+            vec = vec / float(significant_words)
+        else:
+            vec = random_vec()
+    return vec
+
+
+def get_distance_bucket(distance):
+    if 0 <= distance <= 4:
+        return distance
+    elif 5 <= distance <= 7:
+        return 5
+    elif 8 <= distance <= 15:
+        return 6
+    elif 16 <= distance <= 31:
+        return 7
+    elif 32 <= distance <= 63:
+        return 8
+    elif distance >= 64:
+        return 9
+    return 10
diff --git a/corneferencer/resolvers/resolve.py b/corneferencer/resolvers/resolve.py
new file mode 100644
index 0000000..18abce9
--- /dev/null
+++ b/corneferencer/resolvers/resolve.py
@@ -0,0 +1,79 @@
+from conf import NEURAL_MODEL, THRESHOLD
+from corneferencer.resolvers.vectors import create_pair_vector
+
+
+# incremental resolution: for each mention, pick the best-scoring preceding mention as its antecedent
+def incremental(text):
+    last_set_id = 1
+    for i, ana in enumerate(text.mentions):
+        if i > 0:
+            best_prediction = 0.0
+            best_ante = None
+            for ante in text.mentions[i-1::-1]:
+                pair_vec = create_pair_vector(ante, ana)
+                prediction = NEURAL_MODEL.predict(pair_vec)
+                accuracy = prediction[0]
+                if accuracy > THRESHOLD and accuracy > best_prediction:
+                    best_prediction = accuracy
+                    best_ante = ante
+            if best_ante is not None:
+                if best_ante.set:
+                    ana.set = best_ante.set
+                else:
+                    str_set_id = 'set_%d' % last_set_id
+                    best_ante.set = str_set_id
+                    ana.set = str_set_id
+                    last_set_id += 1
+
+
+# entity-based resolution: assign each mention to the best-scoring existing entity set or start a new one
+def entity_based(text):
+    sets = []
+    last_set_id = 1
+    for i, ana in enumerate(text.mentions):
+        if i > 0:
+            best_fit = get_best_set(sets, ana)
+            if best_fit is not None:
+                ana.set = best_fit['set_id']
+                best_fit['mentions'].append(ana)
+            else:
+                str_set_id = 'set_%d' % last_set_id
+                sets.append({'set_id': str_set_id,
+                             'mentions': [ana]})
+                ana.set = str_set_id
+                last_set_id += 1
+        else:
+            str_set_id = 'set_%d' % last_set_id
+            sets.append({'set_id': str_set_id,
+                         'mentions': [ana]})
+            ana.set = str_set_id
+            last_set_id += 1
+
+    remove_singletons(sets)
+
+
+def get_best_set(sets, ana):
+    best_prediction = 0.0
+    best_set = None
+    for s in sets:
+        accuracy = predict_set(s['mentions'], ana)
+        if accuracy > THRESHOLD and accuracy >= best_prediction:
+            best_prediction = accuracy
+            best_set = s
+    return best_set
+
+
+def predict_set(mentions, ana):
+    accuracy_sum = 0.0
+    for mnt in mentions:
+        pair_vec = create_pair_vector(mnt, ana)
+        prediction = NEURAL_MODEL.predict(pair_vec)
+        accuracy = prediction[0]
+        accuracy_sum += accuracy
+    return accuracy_sum / float(len(mentions))
+
+
+def remove_singletons(sets):
+    for s in sets:
+        if len(s['mentions']) == 1:
+            s['mentions'][0].set = ''
diff --git a/corneferencer/resolvers/vectors.py b/corneferencer/resolvers/vectors.py
new file mode 100644
index 0000000..ad54932
--- /dev/null
+++ b/corneferencer/resolvers/vectors.py
@@ -0,0 +1,41 @@
+import numpy
+
+from corneferencer.resolvers import features
+
+# the neural model expects input of shape (None, 1126), hence the single-row 2D array returned below
+def create_pair_vector(ante, ana):
+    vec = []
+    # ante_features = get_mention_features(ante)
+    # vec.extend(ante_features)
+    # ana_features = get_mention_features(ana)
+    # vec.extend(ana_features)
+    vec.extend(ante.features)
+    vec.extend(ana.features)
+    pair_features = get_pair_features(ante, ana)
+    vec.extend(pair_features)
+    return numpy.asarray([vec], dtype=numpy.float32)
+
+
+def get_mention_features(mention):
+    vec = []
+    vec.extend(features.head_vec(mention))
+    vec.extend(features.first_word_vec(mention))
+    vec.extend(features.last_word_vec(mention))
+    vec.extend(features.first_after_vec(mention))
+    vec.extend(features.second_after_vec(mention))
+    vec.extend(features.first_before_vec(mention))
+    vec.extend(features.second_before_vec(mention))
+    vec.extend(features.preceding_context_vec(mention))
+    vec.extend(features.following_context_vec(mention))
+    vec.extend(features.mention_vec(mention))
+    vec.extend(features.sentence_vec(mention))
+    return vec
+
+
+def get_pair_features(ante, ana):
+    vec = []
+    vec.extend(features.distances_vec(ante, ana))
+    vec.append(features.head_match(ante, ana))
+    vec.append(features.exact_match(ante, ana))
+    vec.append(features.base_match(ante, ana))
+    return vec
diff --git a/corneferencer/utils.py b/corneferencer/utils.py
new file mode 100644
index 0000000..b0bcda1
--- /dev/null
+++ b/corneferencer/utils.py
@@ -0,0 +1,25 @@
+from __future__ import print_function
+
+import sys
+
+from keras.models import Model
+from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization
+
+
+def eprint(*args, **kwargs):
+    print(*args, file=sys.stderr, **kwargs)
+
+
+def initialize_neural_model(number_of_features):
+    inputs = Input(shape=(number_of_features,))
+    output_from_1st_layer = Dense(1000, activation='relu')(inputs)
+    output_from_1st_layer = Dropout(0.5)(output_from_1st_layer)
+    output_from_1st_layer = BatchNormalization()(output_from_1st_layer)
+    output_from_2nd_layer = Dense(500, activation='relu')(output_from_1st_layer)
+    output_from_2nd_layer = Dropout(0.5)(output_from_2nd_layer)
+    output_from_2nd_layer = BatchNormalization()(output_from_2nd_layer)
+    output = Dense(1, activation='sigmoid')(output_from_2nd_layer)
+
+    model = Model(inputs, output)
+    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
+    return model
diff --git a/requirements.txt b/requirements.txt
index e69de29..a77a83a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+lxml
+natsort
+gensim
+numpy
+keras
diff --git a/setup.py b/setup.py
deleted file mode 100644
index e69de29..0000000
--- a/setup.py
+++ /dev/null