diff --git a/conf.py b/conf.py
index 648e472..b4c42e8 100644
--- a/conf.py
+++ b/conf.py
@@ -30,7 +30,6 @@
 W2V_MODEL_PATH = os.path.join(MAIN_PATH, 'models', W2V_MODEL_NAME)
 W2V_MODEL = Word2Vec.load(W2V_MODEL_PATH)
 NEURAL_MODEL_PATH = os.path.join(MAIN_PATH, 'models', NEURAL_MODEL_NAME)
-NEURAL_MODEL = utils.initialize_neural_model(NEURAL_MODEL_ARCHITECTURE, NUMBER_OF_FEATURES, NEURAL_MODEL_PATH)

 FREQ_LIST_PATH = os.path.join(MAIN_PATH, 'freq', FREQ_LIST_NAME)
 FREQ_LIST = utils.load_freq_list(FREQ_LIST_PATH)
diff --git a/corneferencer/entities.py b/corneferencer/entities.py
index c1e1509..cc52b52 100644
--- a/corneferencer/entities.py
+++ b/corneferencer/entities.py
@@ -1,4 +1,4 @@
-from corneferencer.resolvers.vectors import get_mention_features
+from corneferencer.resolvers import vectors


 class Text:
@@ -19,6 +19,9 @@ class Text:
             return mnt
         return None

+    def get_mentions(self):
+        return self.mentions
+
     def get_sets(self):
         sets = {}
         for mnt in self.mentions:
@@ -62,4 +65,4 @@ class Mention:
         self.sentence_id = sentence_id
         self.first_in_sentence = first_in_sentence
         self.first_in_paragraph = first_in_paragraph
-        self.features = get_mention_features(self)
+        self.features = vectors.get_mention_features(self)
diff --git a/corneferencer/inout/mmax.py b/corneferencer/inout/mmax.py
index e30f551..62feb49 100644
--- a/corneferencer/inout/mmax.py
+++ b/corneferencer/inout/mmax.py
@@ -3,11 +3,11 @@
 import shutil

 from lxml import etree

-from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST
+import conf
 from corneferencer.entities import Mention, Text


-def read(inpath):
+def read(inpath, clear_mentions=conf.CLEAR_INPUT):
     textname = os.path.splitext(os.path.basename(inpath))[0]
     textdir = os.path.dirname(inpath)
@@ -15,11 +15,11 @@ def read(inpath):
     words_path = os.path.join(textdir, '%s_words.xml' % textname)

     text = Text(textname)
-    text.mentions = read_mentions(mentions_path, words_path)
+    text.mentions = read_mentions(mentions_path, words_path, clear_mentions)
     return text


-def read_mentions(mentions_path, words_path):
+def read_mentions(mentions_path, words_path, clear_mentions=conf.CLEAR_INPUT):
     mentions = []
     mentions_tree = etree.parse(mentions_path)
     markables = mentions_tree.xpath("//ns:markable",
@@ -43,7 +43,7 @@ def read_mentions(mentions_path, words_path):
             head = get_head(head_orth, mention_words)

             mention_group = ''
-            if markable.attrib['mention_group'] != 'empty' and not CLEAR_INPUT:
+            if markable.attrib['mention_group'] != 'empty' and not clear_mentions:
                 mention_group = markable.attrib['mention_group']
             mention = Mention(mnt_id=markable.attrib['id'],
                               text=span_to_text(span, words, 'orth'),
@@ -189,7 +189,7 @@ def get_prec_context(mention_start, words):
     while context_start >= 0:
         if not word_to_ignore(words[context_start]):
             context.append(words[context_start])
-            if len(context) == CONTEXT:
+            if len(context) == conf.CONTEXT:
                 break
         context_start -= 1
     context.reverse()
@@ -222,7 +222,7 @@ def get_follow_context(mention_end, words):
     while context_end < len(words):
         if not word_to_ignore(words[context_end]):
             context.append(words[context_end])
-            if len(context) == CONTEXT:
+            if len(context) == conf.CONTEXT:
                 break
         context_end += 1
     return context
@@ -349,9 +349,8 @@ def get_rarest_word(words):
     rarest_word = words[0]
     for i, word in enumerate(words):
         word_freq = 0
-        if word['base'] in FREQ_LIST:
-            word_freq = FREQ_LIST[word['base']]
-
+        if word['base'] in conf.FREQ_LIST:
+            word_freq = conf.FREQ_LIST[word['base']]
         if i == 0 or word_freq < min_freq:
             min_freq = word_freq
             rarest_word = word
diff --git a/corneferencer/inout/tei.py b/corneferencer/inout/tei.py
index 81d4afc..b166719 100644
--- a/corneferencer/inout/tei.py
+++ b/corneferencer/inout/tei.py
@@ -4,7 +4,7 @@
 import shutil

 from lxml import etree

-from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST
+import conf
 from corneferencer.entities import Mention, Text
 from corneferencer.utils import eprint
@@ -18,7 +18,7 @@
 NSMAP = {None: TEI_NS, 'xi': XI_NS}


-def read(inpath):
+def read(inpath, clear_mentions=conf.CLEAR_INPUT):
     textname = os.path.basename(inpath)
     text = Text(textname)

@@ -49,7 +49,7 @@ def read(inpath):
         eprint("Error: missing mentions layer for text %s!" % textname)
         return None

-    if os.path.exists(ann_coreference) and not CLEAR_INPUT:
+    if os.path.exists(ann_coreference) and not clear_mentions:
         add_coreference_layer(ann_coreference, text)

     return text
@@ -215,6 +215,9 @@ def get_mention(mention, mnt_id, segments, segments_ids, paragraph_id, sentence_
             semh_id = get_fval(f).split('#')[-1]
             semh = segments[semh_id]

+    if len(mnt_segments) == 0:
+        mnt_segments.append(semh)
+
     (sent_segments, prec_context, follow_context,
      first_in_sentence, first_in_paragraph) = get_context(mnt_segments, segments, segments_ids)

@@ -272,7 +275,7 @@ def get_prec_context(mention_start, segments, segments_ids):
     while context_start >= 0:
         if not word_to_ignore(segments[segments_ids[context_start]]):
             context.append(segments[segments_ids[context_start]])
-            if len(context) == CONTEXT:
+            if len(context) == conf.CONTEXT:
                 break
         context_start -= 1
     context.reverse()
@@ -285,7 +288,7 @@ def get_follow_context(mention_end, segments, segments_ids):
     while context_end < len(segments):
         if not word_to_ignore(segments[segments_ids[context_end]]):
             context.append(segments[segments_ids[context_end]])
-            if len(context) == CONTEXT:
+            if len(context) == conf.CONTEXT:
                 break
         context_end += 1
     return context
@@ -341,8 +344,8 @@ def get_rarest_word(words):
     rarest_word = words[0]
     for i, word in enumerate(words):
         word_freq = 0
-        if word['base'] in FREQ_LIST:
-            word_freq = FREQ_LIST[word['base']]
+        if word['base'] in conf.FREQ_LIST:
+            word_freq = conf.FREQ_LIST[word['base']]

         if i == 0 or word_freq < min_freq:
             min_freq = word_freq
diff --git a/corneferencer/main.py b/corneferencer/main.py
index f9aebc7..9f85c5d 100644
--- a/corneferencer/main.py
+++ b/corneferencer/main.py
@@ -4,9 +4,11 @@ import sys
 from argparse import ArgumentParser
 from natsort import natsorted

-sys.path.append(os.path.abspath(os.path.join('..')))
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
 import conf
+import utils
 from inout import mmax, tei
 from inout.constants import INPUT_FORMATS
 from resolvers import resolve
@@ -27,22 +29,25 @@ def main():
         if conf.NEURAL_MODEL_ARCHITECTURE == 'siamese':
             resolver = conf.NEURAL_MODEL_ARCHITECTURE
             eprint("Warning: Using %s resolver because of selected neural model architecture!" %
-                   conf.NEURAL_MODEL_ARCHITECTURE)
-    process_texts(args.input, args.output, args.format, resolver, args.threshold)
+                   conf.NEURAL_MODEL_ARCHITECTURE)
+    process_texts(args.input, args.output, args.format, resolver, args.threshold, args.model)


 def parse_arguments():
     parser = ArgumentParser(description='Corneferencer: coreference resolver using neural nets.')
+    parser.add_argument('-f', '--format', type=str, action='store',
+                        dest='format', default=INPUT_FORMATS[0],
+                        help='input format; default: %s; possibilities: %s'
+                             % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
     parser.add_argument('-i', '--input', type=str, action='store',
                         dest='input', default='',
                         help='input file or dir path')
+    parser.add_argument('-m', '--model', type=str, action='store',
+                        dest='model', default='',
+                        help='neural model path; default: %s' % conf.NEURAL_MODEL_PATH)
     parser.add_argument('-o', '--output', type=str, action='store',
                         dest='output', default='',
                         help='output path; if not specified writes output to standard output')
-    parser.add_argument('-f', '--format', type=str, action='store',
-                        dest='format', default=INPUT_FORMATS[0],
-                        help='input format; default: %s; possibilities: %s'
-                             % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
     parser.add_argument('-r', '--resolver', type=str, action='store',
                         dest='resolver', default=RESOLVERS[0],
                         help='resolve algorithm; default: %s; possibilities: %s'
@@ -55,16 +60,17 @@ def parse_arguments():
     return args


-def process_texts(inpath, outpath, informat, resolver, threshold):
+def process_texts(inpath, outpath, informat, resolver, threshold, model_path):
+    model = utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE, conf.NUMBER_OF_FEATURES, model_path)
     if os.path.isdir(inpath):
-        process_directory(inpath, outpath, informat, resolver, threshold)
+        process_directory(inpath, outpath, informat, resolver, threshold, model)
     elif os.path.isfile(inpath):
-        process_text(inpath, outpath, informat, resolver, threshold)
+        process_text(inpath, outpath, informat, resolver, threshold, model)
     else:
         eprint("Error: Specified input does not exist!")


-def process_directory(inpath, outpath, informat, resolver, threshold):
+def process_directory(inpath, outpath, informat, resolver, threshold, model):
     inpath = os.path.abspath(inpath)
     outpath = os.path.abspath(outpath)

@@ -75,38 +81,38 @@ def process_directory(inpath, outpath, informat, resolver, threshold):
         textname = os.path.splitext(os.path.basename(filename))[0]
         textoutput = os.path.join(outpath, textname)
         textinput = os.path.join(inpath, filename)
-        process_text(textinput, textoutput, informat, resolver, threshold)
+        process_text(textinput, textoutput, informat, resolver, threshold, model)


-def process_text(inpath, outpath, informat, resolver, threshold):
+def process_text(inpath, outpath, informat, resolver, threshold, model):
     basename = os.path.basename(inpath)

     if informat == 'mmax' and basename.endswith('.mmax'):
         print (basename)
         text = mmax.read(inpath)
         if resolver == 'incremental':
-            resolve.incremental(text, threshold)
+            resolve.incremental(text, threshold, model)
         elif resolver == 'entity_based':
-            resolve.entity_based(text, threshold)
+            resolve.entity_based(text, threshold, model)
         elif resolver == 'closest':
-            resolve.closest(text, threshold)
+            resolve.closest(text, threshold, model)
         elif resolver == 'siamese':
-            resolve.siamese(text, threshold)
+            resolve.siamese(text, threshold, model)
         elif resolver == 'all2all':
-            resolve.all2all(text, threshold)
+            resolve.all2all(text, threshold, model)
         mmax.write(inpath, outpath, text)
     elif informat == 'tei':
         print (basename)
         text = tei.read(inpath)
         if resolver == 'incremental':
-            resolve.incremental(text, threshold)
+            resolve.incremental(text, threshold, model)
         elif resolver == 'entity_based':
-            resolve.entity_based(text, threshold)
+            resolve.entity_based(text, threshold, model)
         elif resolver == 'closest':
-            resolve.closest(text, threshold)
+            resolve.closest(text, threshold, model)
         elif resolver == 'siamese':
-            resolve.siamese(text, threshold)
+            resolve.siamese(text, threshold, model)
         elif resolver == 'all2all':
-            resolve.all2all(text, threshold)
+            resolve.all2all(text, threshold, model)
         tei.write(inpath, outpath, text)
diff --git a/corneferencer/prepare_data.py b/corneferencer/prepare_data.py
new file mode 100644
index 0000000..6591929
--- /dev/null
+++ b/corneferencer/prepare_data.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+
+import codecs
+import os
+import random
+import sys
+
+from itertools import combinations
+from argparse import ArgumentParser
+from natsort import natsorted
+
+sys.path.append(os.path.abspath(os.path.join('..')))
+
+from inout import mmax, tei
+from inout.constants import INPUT_FORMATS
+from utils import eprint
+from corneferencer.resolvers import vectors
+
+
+POS_COUNT = 0
+NEG_COUNT = 0
+
+
+def main():
+    args = parse_arguments()
+    if not args.input:
+        eprint("Error: Input file(s) not specified!")
+    elif args.format not in INPUT_FORMATS:
+        eprint("Error: Unknown input file format!")
+    else:
+        process_texts(args.input, args.output, args.format, args.proportion)
+
+
+def parse_arguments():
+    parser = ArgumentParser(description='Corneferencer: data preparator for neural nets training.')
+    parser.add_argument('-i', '--input', type=str, action='store',
+                        dest='input', default='',
+                        help='input dir path')
+    parser.add_argument('-o', '--output', type=str, action='store',
+                        dest='output', default='',
+                        help='output path; if not specified writes output to standard output')
+    parser.add_argument('-f', '--format', type=str, action='store',
+                        dest='format', default=INPUT_FORMATS[0],
+                        help='input format; default: %s; possibilities: %s'
+                             % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
+    parser.add_argument('-p', '--proportion', type=int, action='store',
+                        dest='proportion', default=5,
+                        help='negative examples proportion; default: 5')
+    args = parser.parse_args()
+    return args
+
+
+def process_texts(inpath, outpath, informat, proportion):
+    if os.path.isdir(inpath):
+        process_directory(inpath, outpath, informat, proportion)
+    else:
+        eprint("Error: Specified input does not exist or is not a directory!")
+
+
+def process_directory(inpath, outpath, informat, proportion):
+    inpath = os.path.abspath(inpath)
+    outpath = os.path.abspath(outpath)
+
+    try:
+        create_data_vectors(inpath, outpath, informat, proportion)
+    finally:
+        print ('Positives: ', POS_COUNT)
+        print ('Negatives: ', NEG_COUNT)
+
+
+def create_data_vectors(inpath, outpath, informat, proportion):
+    features_file = codecs.open(outpath, 'w', 'utf-8')
+
+    files = os.listdir(inpath)
+    files = natsorted(files)
+
+    for filename in files:
+        textname = os.path.splitext(os.path.basename(filename))[0]
+        textinput = os.path.join(inpath, filename)
+
+        print ('Processing text: %s' % textname)
+        text = None
+        if informat == 'mmax' and filename.endswith('.mmax'):
+            text = mmax.read(textinput, False)
+        elif informat == 'tei':
+            text = tei.read(textinput, False)
+
+        positives, negatives = diff_mentions(text, proportion)
+        write_features(features_file, positives, negatives)
+
+
+def diff_mentions(text, proportion):
+    sets = text.get_sets()
+    all_mentions = text.get_mentions()
+    positives = get_positives(sets)
+    positives, negatives = get_negatives_and_update_positives(all_mentions, positives, proportion)
+    return positives, negatives
+
+
+def get_positives(sets):
+    positives = []
+    for set_id in sets:
+        coref_set = sets[set_id]
+        positives.extend(list(combinations(coref_set, 2)))
+    return positives
+
+
+def get_negatives_and_update_positives(all_mentions, positives, proportion):
+    all_pairs = list(combinations(all_mentions, 2))
+
+    all_pairs = set(all_pairs)
+    negatives = [pair for pair in all_pairs if pair not in positives]
+    samples_count = proportion * len(positives)
+    if samples_count > len(negatives):
+        samples_count = len(negatives)
+        if proportion == 1:
+            positives = random.sample(set(positives), samples_count)
+        print (u'More positive examples than negative ones!')
+    negatives = random.sample(set(negatives), samples_count)
+    return positives, negatives
+
+
+def write_features(features_file, positives, negatives):
+    global POS_COUNT
+    POS_COUNT += len(positives)
+    for pair in positives:
+        vector = vectors.get_pair_vector(pair[0], pair[1])
+        vector.append(1.0)
+        features_file.write(u'%s\n' % u'\t'.join([str(feature) for feature in vector]))
+
+    global NEG_COUNT
+    NEG_COUNT += len(negatives)
+    for pair in negatives:
+        vector = vectors.get_pair_vector(pair[0], pair[1])
+        vector.append(0.0)
+        features_file.write(u'%s\n' % u'\t'.join([str(feature) for feature in vector]))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/corneferencer/resolvers/features.py b/corneferencer/resolvers/features.py
index 987eae8..b4a83f0 100644
--- a/corneferencer/resolvers/features.py
+++ b/corneferencer/resolvers/features.py
@@ -72,97 +72,97 @@ def sentence_vec(mention):


 def mention_type(mention):
-    type_vec = [0] * 4
+    type_vec = [0.0] * 4
     if mention.head is None:
-        type_vec[3] = 1
+        type_vec[3] = 1.0
     elif mention.head['ctag'] in constants.NOUN_TAGS:
-        type_vec[0] = 1
+        type_vec[0] = 1.0
     elif mention.head['ctag'] in constants.PPRON_TAGS:
-        type_vec[1] = 1
+        type_vec[1] = 1.0
     elif mention.head['ctag'] in constants.ZERO_TAGS:
-        type_vec[2] = 1
+        type_vec[2] = 1.0
     else:
-        type_vec[3] = 1
+        type_vec[3] = 1.0
     return type_vec


 def is_first_second_person(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if mention.head['person'] in constants.FIRST_SECOND_PERSON:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def is_demonstrative(mention):
     if mention.words[0]['base'].lower() in constants.INDICATIVE_PRONS_BASES:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def is_demonstrative_nominal(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if is_demonstrative(mention) and mention.head['ctag'] in constants.NOUN_TAGS:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def is_demonstrative_pronoun(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if (is_demonstrative(mention) and
             (mention.head['ctag'] in constants.PPRON_TAGS or mention.head['ctag'] in constants.ZERO_TAGS)):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def is_refl_pronoun(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if mention.head['ctag'] in constants.SIEBIE_TAGS:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def is_first_in_sentence(mention):
     if mention.first_in_sentence:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def is_zero_or_pronoun(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if mention.head['ctag'] in constants.PPRON_TAGS or mention.head['ctag'] in constants.ZERO_TAGS:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def head_contains_digit(mention):
     _digits = re.compile('\d')
     if _digits.search(mention.head_orth):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def mention_contains_digit(mention):
     _digits = re.compile('\d')
     if _digits.search(mention.text):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def contains_letter(mention):
     if any(c.isalpha() for c in mention.text):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def post_modified(mention):
     if mention.head_orth != mention.words[-1]['orth']:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 # pair features
@@ -171,20 +171,20 @@ def distances_vec(ante, ana):

     mnts_intersect = pair_intersect(ante, ana)

-    words_dist = [0] * 11
+    words_dist = [0.0] * 11
     words_bucket = 0
-    if mnts_intersect != 1:
+    if mnts_intersect != 1.0:
         words_bucket = get_distance_bucket(ana.start_in_words - ante.end_in_words)
-    words_dist[words_bucket] = 1
+    words_dist[words_bucket] = 1.0
     vec.extend(words_dist)

-    mentions_dist = [0] * 11
+    mentions_dist = [0.0] * 11
     mentions_bucket = 0
-    if mnts_intersect != 1:
+    if mnts_intersect != 1.0:
         mentions_bucket = get_distance_bucket(ana.position_in_mentions - ante.position_in_mentions)
     if words_bucket == 10:
         mentions_bucket = 10
-    mentions_dist[mentions_bucket] = 1
+    mentions_dist[mentions_bucket] = 1.0
     vec.extend(mentions_dist)

     vec.append(mnts_intersect)
@@ -196,45 +196,45 @@ def pair_intersect(ante, ana):
     for ante_word in ante.words:
         for ana_word in ana.words:
             if ana_word['id'] == ante_word['id']:
-                return 1
-    return 0
+                return 1.0
+    return 0.0


 def head_match(ante, ana):
     if ante.head_orth.lower() == ana.head_orth.lower():
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def exact_match(ante, ana):
     if ante.text.lower() == ana.text.lower():
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def base_match(ante, ana):
     if ante.lemmatized_text.lower() == ana.lemmatized_text.lower():
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def ante_contains_rarest_from_ana(ante, ana):
     ana_rarest = ana.rarest
     for word in ante.words:
         if word['base'] == ana_rarest['base']:
-            return 1
-    return 0
+            return 1.0
+    return 0.0


 def agreement(ante, ana, tag_name):
-    agr_vec = [0] * 3
+    agr_vec = [0.0] * 3
     if (ante.head is None or ana.head is None or
             ante.head[tag_name] == 'unk' or ana.head[tag_name] == 'unk'):
-        agr_vec[2] = 1
+        agr_vec[2] = 1.0
     elif ante.head[tag_name] == ana.head[tag_name]:
-        agr_vec[0] = 1
+        agr_vec[0] = 1.0
     else:
-        agr_vec[1] = 1
+        agr_vec[1] = 1.0
     return agr_vec
@@ -243,72 +243,72 @@ def is_acronym(ante, ana):
         return check_one_way_acronym(ana.text, ante.text)
     if ante.text.upper() == ante.text:
         return check_one_way_acronym(ante.text, ana.text)
-    return 0
+    return 0.0


 def same_sentence(ante, ana):
     if ante.sentence_id == ana.sentence_id:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def neighbouring_sentence(ante, ana):
     if ana.sentence_id - ante.sentence_id == 1:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def cousin_sentence(ante, ana):
     if ana.sentence_id - ante.sentence_id == 2:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def distant_sentence(ante, ana):
     if ana.sentence_id - ante.sentence_id > 2:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def same_paragraph(ante, ana):
     if ante.paragraph_id == ana.paragraph_id:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def flat_gender_agreement(ante, ana):
-    agr_vec = [0] * 3
+    agr_vec = [0.0] * 3
     if (ante.head is None or ana.head is None or
             ante.head['gender'] == 'unk' or ana.head['gender'] == 'unk'):
-        agr_vec[2] = 1
+        agr_vec[2] = 1.0
     elif (ante.head['gender'] == ana.head['gender'] or
           (ante.head['gender'] in constants.MASCULINE_TAGS and
            ana.head['gender'] in constants.MASCULINE_TAGS)):
-        agr_vec[0] = 1
+        agr_vec[0] = 1.0
     else:
-        agr_vec[1] = 1
+        agr_vec[1] = 1.0
     return agr_vec


 def left_match(ante, ana):
     if (ante.text.lower().startswith(ana.text.lower()) or
             ana.text.lower().startswith(ante.text.lower())):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def right_match(ante, ana):
     if (ante.text.lower().endswith(ana.text.lower()) or
             ana.text.lower().endswith(ante.text.lower())):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def abbrev2(ante, ana):
     ante_abbrev = get_abbrev(ante)
     ana_abbrev = get_abbrev(ana)
     if ante.head_orth == ana_abbrev or ana.head_orth == ante_abbrev:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def string_kernel(ante, ana):
@@ -326,7 +326,7 @@ def head_string_kernel(ante, ana):
 def wordnet_synonyms(ante, ana):
     ante_synonyms = set()
     if ante.head is None or ana.head is None:
-        return 0
+        return 0.0

     if ante.head['base'] in conf.LEMMA2SYNONYMS:
         ante_synonyms = conf.LEMMA2SYNONYMS[ante.head['base']]
@@ -336,13 +336,13 @@ def wordnet_synonyms(ante, ana):
         ana_synonyms = conf.LEMMA2SYNONYMS[ana.head['base']]

     if ana.head['base'] in ante_synonyms or ante.head['base'] in ana_synonyms:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def wordnet_ana_is_hypernym(ante, ana):
     if ante.head is None or ana.head is None:
-        return 0
+        return 0.0

     ante_hypernyms = set()
     if ante.head['base'] in conf.LEMMA2HYPERNYMS:
@@ -353,16 +353,16 @@ def wordnet_ana_is_hypernym(ante, ana):
         ana_hypernyms = conf.LEMMA2HYPERNYMS[ana.head['base']]

     if not ante_hypernyms or not ana_hypernyms:
-        return 0
+        return 0.0

     if ana.head['base'] in ante_hypernyms:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def wordnet_ante_is_hypernym(ante, ana):
     if ante.head is None or ana.head is None:
-        return 0
+        return 0.0

     ana_hypernyms = set()
     if ana.head['base'] in conf.LEMMA2HYPERNYMS:
@@ -373,18 +373,18 @@ def wordnet_ante_is_hypernym(ante, ana):
         ante_hypernyms = conf.LEMMA2HYPERNYMS[ante.head['base']]

     if not ante_hypernyms or not ana_hypernyms:
-        return 0
+        return 0.0

     if ante.head['base'] in ana_hypernyms:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def wikipedia_link(ante, ana):
     ante_base = ante.lemmatized_text.lower()
     ana_base = ana.lemmatized_text.lower()
     if ante_base == ana_base:
-        return 1
+        return 1.0

     ante_links = set()
     if ante_base in conf.TITLE2LINKS:
@@ -395,16 +395,16 @@ def wikipedia_link(ante, ana):
         ana_links = conf.TITLE2LINKS[ana_base]

     if ana_base in ante_links or ante_base in ana_links:
-        return 1
+        return 1.0

-    return 0
+    return 0.0


 def wikipedia_mutual_link(ante, ana):
     ante_base = ante.lemmatized_text.lower()
     ana_base = ana.lemmatized_text.lower()
     if ante_base == ana_base:
-        return 1
+        return 1.0

     ante_links = set()
     if ante_base in conf.TITLE2LINKS:
@@ -415,52 +415,52 @@ def wikipedia_mutual_link(ante, ana):
         ana_links = conf.TITLE2LINKS[ana_base]

     if ana_base in ante_links and ante_base in ana_links:
-        return 1
+        return 1.0

-    return 0
+    return 0.0


 def wikipedia_redirect(ante, ana):
     ante_base = ante.lemmatized_text.lower()
     ana_base = ana.lemmatized_text.lower()
     if ante_base == ana_base:
-        return 1
+        return 1.0

     if ante_base in conf.TITLE2REDIRECT and conf.TITLE2REDIRECT[ante_base] == ana_base:
-        return 1
+        return 1.0

     if ana_base in conf.TITLE2REDIRECT and conf.TITLE2REDIRECT[ana_base] == ante_base:
-        return 1
+        return 1.0

-    return 0
+    return 0.0


 def samesent_anapron_antefirstinpar(ante, ana):
     if same_sentence(ante, ana) and is_zero_or_pronoun(ana) and ante.first_in_paragraph:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def samesent_antefirstinpar_personnumbermatch(ante, ana):
     if (same_sentence(ante, ana) and ante.first_in_paragraph
             and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def adjsent_anapron_adjmen_personnumbermatch(ante, ana):
     if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana)
             and ana.position_in_mentions - ante.position_in_mentions == 1
             and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def adjsent_anapron_adjmen(ante, ana):
     if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana)
             and ana.position_in_mentions - ante.position_in_mentions == 1):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 # supporting functions
@@ -523,8 +523,8 @@ def check_one_way_acronym(acronym, expression):
         if expr2:
             initials += expr2[0].upper()
     if acronym == initials:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def get_abbrev(mention):
diff --git a/corneferencer/resolvers/resolve.py b/corneferencer/resolvers/resolve.py
index c3a1038..bcca721 100644
--- a/corneferencer/resolvers/resolve.py
+++ b/corneferencer/resolvers/resolve.py
@@ -1,17 +1,15 @@
 import numpy

-from conf import NEURAL_MODEL
-from corneferencer.resolvers import features
-from corneferencer.resolvers.vectors import get_pair_features, get_pair_vector
+from corneferencer.resolvers import features, vectors


-def siamese(text, threshold):
+def siamese(text, threshold, neural_model):
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
             for ante in reversed(text.mentions[:i]):
                 if not features.pair_intersect(ante, ana):
-                    pair_features = get_pair_features(ante, ana)
+                    pair_features = vectors.get_pair_features(ante, ana)

                     ante_vec = []
                     ante_vec.extend(ante.features)
@@ -23,7 +21,7 @@ def siamese(text, threshold):
                     ana_vec.extend(pair_features)
                     ana_sample = numpy.asarray([ana_vec], dtype=numpy.float32)

-                    prediction = NEURAL_MODEL.predict([ante_sample, ana_sample])[0]
+                    prediction = neural_model.predict([ante_sample, ana_sample])[0]

                     if prediction < threshold:
                         if ante.set:
@@ -37,7 +35,7 @@ def siamese(text, threshold):


 # incremental resolve algorithm
-def incremental(text, threshold):
+def incremental(text, threshold, neural_model):
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
@@ -45,9 +43,9 @@ def incremental(text, threshold):
             best_ante = None
             for ante in text.mentions[:i]:
                 if not features.pair_intersect(ante, ana):
-                    pair_vec = get_pair_vector(ante, ana)
+                    pair_vec = vectors.get_pair_vector(ante, ana)
                     sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                    prediction = NEURAL_MODEL.predict(sample)[0]
+                    prediction = neural_model.predict(sample)[0]
                     if prediction > threshold and prediction >= best_prediction:
                         best_prediction = prediction
                         best_ante = ante
@@ -62,86 +60,21 @@ def incremental(text, threshold):


 # all2all resolve algorithm
-def all2all_debug(text, threshold):
-    last_set_id = 0
-    for pos1, mnt1 in enumerate(text.mentions):
-        best_prediction = 0.0
-        best_link = None
-        for pos2, mnt2 in enumerate(text.mentions):
-            if (mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2):
-                ante = mnt1
-                ana = mnt2
-                if pos2 < pos1:
-                    ante = mnt2
-                    ana = mnt1
-                pair_vec = get_pair_vector(ante, ana)
-                sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                prediction = NEURAL_MODEL.predict(sample)[0]
-                if prediction > threshold and prediction > best_prediction:
-                    best_prediction = prediction
-                    best_link = mnt2
-        if best_link is not None:
-            if best_link.set and not mnt1.set:
-                mnt1.set = best_link.set
-            elif best_link.set and mnt1.set:
-                text.merge_sets(best_link.set, mnt1.set)
-            elif not best_link.set and not mnt1.set:
-                str_set_id = 'set_%d' % last_set_id
-                best_link.set = str_set_id
-                mnt1.set = str_set_id
-                last_set_id += 1
-
-
-def all2all_v1(text, threshold):
-    last_set_id = 0
-    for pos1, mnt1 in enumerate(text.mentions):
-        best_prediction = 0.0
-        best_link = None
-        for pos2, mnt2 in enumerate(text.mentions):
-            if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
-                    and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
-                ante = mnt1
-                ana = mnt2
-                if pos2 < pos1:
-                    ante = mnt2
-                    ana = mnt1
-                pair_vec = get_pair_vector(ante, ana)
-                sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                prediction = NEURAL_MODEL.predict(sample)[0]
-                if prediction > threshold and prediction > best_prediction:
-                    best_prediction = prediction
-                    best_link = mnt2
-        if best_link is not None:
-            if best_link.set and not mnt1.set:
-                mnt1.set = best_link.set
-            elif not best_link.set and mnt1.set:
-                best_link.set = mnt1.set
-            elif best_link.set and mnt1.set:
-                text.merge_sets(best_link.set, mnt1.set)
-            elif not best_link.set and not mnt1.set:
-                str_set_id = 'set_%d' % last_set_id
-                best_link.set = str_set_id
-                mnt1.set = str_set_id
-                last_set_id += 1
-
-
-def all2all(text, threshold):
+def all2all(text, threshold, neural_model):
     last_set_id = 0
     sets = text.get_sets()
     for pos1, mnt1 in enumerate(text.mentions):
         best_prediction = 0.0
         best_link = None
         for pos2, mnt2 in enumerate(text.mentions):
-            if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
-                    and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
+            if (pos2 > pos1 and
+                    (mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
+                    and not features.pair_intersect(mnt1, mnt2)):
                 ante = mnt1
                 ana = mnt2
-                if pos2 < pos1:
-                    ante = mnt2
-                    ana = mnt1
-                pair_vec = get_pair_vector(ante, ana)
+                pair_vec = vectors.get_pair_vector(ante, ana)
                 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                prediction = NEURAL_MODEL.predict(sample)[0]
+                prediction = neural_model.predict(sample)[0]
                 if prediction > threshold and prediction > best_prediction:
                     best_prediction = prediction
                     best_link = mnt2
@@ -163,12 +96,12 @@ def all2all(text, threshold):


 # entity based resolve algorithm
-def entity_based(text, threshold):
+def entity_based(text, threshold, neural_model):
     sets = []
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
-            best_fit = get_best_set(sets, ana, threshold)
+            best_fit = get_best_set(sets, ana, threshold, neural_model)
             if best_fit is not None:
                 ana.set = best_fit['set_id']
                 best_fit['mentions'].append(ana)
@@ -188,25 +121,25 @@ def entity_based(text, threshold):
     remove_singletons(sets)


-def get_best_set(sets, ana, threshold):
+def get_best_set(sets, ana, threshold, neural_model):
     best_prediction = 0.0
     best_set = None
     for s in sets:
-        accuracy = predict_set(s['mentions'], ana)
+        accuracy = predict_set(s['mentions'], ana, neural_model)
         if accuracy > threshold and accuracy >= best_prediction:
             best_prediction = accuracy
             best_set = s
     return best_set


-def predict_set(mentions, ana):
+def predict_set(mentions, ana, neural_model):
     prediction_sum = 0.0
     for mnt in mentions:
         prediction = 0.0
         if not features.pair_intersect(mnt, ana):
-            pair_vec = get_pair_vector(mnt, ana)
+            pair_vec = vectors.get_pair_vector(mnt, ana)
             sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-            prediction = NEURAL_MODEL.predict(sample)[0]
+            prediction = neural_model.predict(sample)[0]
         prediction_sum += prediction

     return prediction_sum / float(len(mentions))
@@ -218,15 +151,15 @@ def remove_singletons(sets):


 # closest resolve algorithm
-def closest(text, threshold):
+def closest(text, threshold, neural_model):
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
             for ante in reversed(text.mentions[:i]):
                 if not features.pair_intersect(ante, ana):
-                    pair_vec = get_pair_vector(ante, ana)
+                    pair_vec = vectors.get_pair_vector(ante, ana)
                     sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                    prediction = NEURAL_MODEL.predict(sample)[0]
+                    prediction = neural_model.predict(sample)[0]
                     if prediction > threshold:
                         if ante.set:
                             ana.set = ante.set
diff --git a/corneferencer/utils.py b/corneferencer/utils.py
index ea6b125..15c4ca1 100644
--- a/corneferencer/utils.py
+++ b/corneferencer/utils.py
@@ -7,7 +7,6 @@
 import javaobj

 from keras.models import Sequential, Model
 from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization, Lambda
-from keras.optimizers import RMSprop, Adam
 from keras import backend as K