Commit c29c36cde4fa9ea17a625fa2b0e8ea5c1192ec25

Authored by Bartłomiej Nitoń
Parent: db88d6e4

Add prepare data script and other minor improvements.

@@ -30,7 +30,6 @@ W2V_MODEL_PATH = os.path.join(MAIN_PATH, 'models', W2V_MODEL_NAME)
 W2V_MODEL = Word2Vec.load(W2V_MODEL_PATH)

 NEURAL_MODEL_PATH = os.path.join(MAIN_PATH, 'models', NEURAL_MODEL_NAME)
-NEURAL_MODEL = utils.initialize_neural_model(NEURAL_MODEL_ARCHITECTURE, NUMBER_OF_FEATURES, NEURAL_MODEL_PATH)

 FREQ_LIST_PATH = os.path.join(MAIN_PATH, 'freq', FREQ_LIST_NAME)
 FREQ_LIST = utils.load_freq_list(FREQ_LIST_PATH)
corneferencer/entities.py
-from corneferencer.resolvers.vectors import get_mention_features
+from corneferencer.resolvers import vectors


 class Text:
@@ -19,6 +19,9 @@ class Text:
                 return mnt
         return None

+    def get_mentions(self):
+        return self.mentions
+
     def get_sets(self):
         sets = {}
         for mnt in self.mentions:
@@ -62,4 +65,4 @@ class Mention:
         self.sentence_id = sentence_id
         self.first_in_sentence = first_in_sentence
         self.first_in_paragraph = first_in_paragraph
-        self.features = get_mention_features(self)
+        self.features = vectors.get_mention_features(self)
corneferencer/inout/mmax.py
@@ -3,11 +3,11 @@ import shutil

 from lxml import etree

-from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST
+import conf
 from corneferencer.entities import Mention, Text


-def read(inpath):
+def read(inpath, clear_mentions=conf.CLEAR_INPUT):
     textname = os.path.splitext(os.path.basename(inpath))[0]
     textdir = os.path.dirname(inpath)

@@ -15,11 +15,11 @@ def read(inpath):
     words_path = os.path.join(textdir, '%s_words.xml' % textname)

     text = Text(textname)
-    text.mentions = read_mentions(mentions_path, words_path)
+    text.mentions = read_mentions(mentions_path, words_path, clear_mentions)
     return text


-def read_mentions(mentions_path, words_path):
+def read_mentions(mentions_path, words_path, clear_mentions=conf.CLEAR_INPUT):
     mentions = []
     mentions_tree = etree.parse(mentions_path)
     markables = mentions_tree.xpath("//ns:markable",
@@ -43,7 +43,7 @@ def read_mentions(mentions_path, words_path):

         head = get_head(head_orth, mention_words)
         mention_group = ''
-        if markable.attrib['mention_group'] != 'empty' and not CLEAR_INPUT:
+        if markable.attrib['mention_group'] != 'empty' and not clear_mentions:
             mention_group = markable.attrib['mention_group']
         mention = Mention(mnt_id=markable.attrib['id'],
                           text=span_to_text(span, words, 'orth'),
@@ -189,7 +189,7 @@ def get_prec_context(mention_start, words):
     while context_start >= 0:
         if not word_to_ignore(words[context_start]):
             context.append(words[context_start])
-            if len(context) == CONTEXT:
+            if len(context) == conf.CONTEXT:
                 break
         context_start -= 1
     context.reverse()
@@ -222,7 +222,7 @@ def get_follow_context(mention_end, words):
     while context_end < len(words):
         if not word_to_ignore(words[context_end]):
             context.append(words[context_end])
-            if len(context) == CONTEXT:
+            if len(context) == conf.CONTEXT:
                 break
         context_end += 1
     return context
@@ -349,9 +349,8 @@ def get_rarest_word(words):
     rarest_word = words[0]
     for i, word in enumerate(words):
         word_freq = 0
-        if word['base'] in FREQ_LIST:
-            word_freq = FREQ_LIST[word['base']]
-
+        if word['base'] in conf.FREQ_LIST:
+            word_freq = conf.FREQ_LIST[word['base']]
         if i == 0 or word_freq < min_freq:
            min_freq = word_freq
             rarest_word = word
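Note: with this change mmax.read() (and read_mentions()) takes a clear_mentions flag that defaults to conf.CLEAR_INPUT, so callers can decide whether gold mention groups are kept. A minimal usage sketch (the corpus path below is illustrative, not part of the commit):

    from inout import mmax

    # Keep the gold coreference groups, as the new prepare_data.py does:
    train_text = mmax.read('corpus/doc01.mmax', clear_mentions=False)
    # Default behaviour still follows conf.CLEAR_INPUT, as in the resolver pipeline:
    eval_text = mmax.read('corpus/doc01.mmax')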
corneferencer/inout/tei.py
@@ -4,7 +4,7 @@ import shutil

 from lxml import etree

-from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST
+import conf
 from corneferencer.entities import Mention, Text
 from corneferencer.utils import eprint

@@ -18,7 +18,7 @@ NSMAP = {None: TEI_NS,
          'xi': XI_NS}


-def read(inpath):
+def read(inpath, clear_mentions=conf.CLEAR_INPUT):
     textname = os.path.basename(inpath)

     text = Text(textname)
@@ -49,7 +49,7 @@ def read(inpath):
         eprint("Error: missing mentions layer for text %s!" % textname)
         return None

-    if os.path.exists(ann_coreference) and not CLEAR_INPUT:
+    if os.path.exists(ann_coreference) and not clear_mentions:
         add_coreference_layer(ann_coreference, text)

     return text
@@ -215,6 +215,9 @@ def get_mention(mention, mnt_id, segments, segments_ids, paragraph_id, sentence_
             semh_id = get_fval(f).split('#')[-1]
             semh = segments[semh_id]

+    if len(mnt_segments) == 0:
+        mnt_segments.append(semh)
+
     (sent_segments, prec_context, follow_context,
      first_in_sentence, first_in_paragraph) = get_context(mnt_segments, segments, segments_ids)

@@ -272,7 +275,7 @@ def get_prec_context(mention_start, segments, segments_ids):
     while context_start >= 0:
         if not word_to_ignore(segments[segments_ids[context_start]]):
             context.append(segments[segments_ids[context_start]])
-            if len(context) == CONTEXT:
+            if len(context) == conf.CONTEXT:
                 break
         context_start -= 1
     context.reverse()
@@ -285,7 +288,7 @@ def get_follow_context(mention_end, segments, segments_ids):
     while context_end < len(segments):
         if not word_to_ignore(segments[segments_ids[context_end]]):
             context.append(segments[segments_ids[context_end]])
-            if len(context) == CONTEXT:
+            if len(context) == conf.CONTEXT:
                 break
         context_end += 1
     return context
@@ -341,8 +344,8 @@ def get_rarest_word(words):
     rarest_word = words[0]
     for i, word in enumerate(words):
         word_freq = 0
-        if word['base'] in FREQ_LIST:
-            word_freq = FREQ_LIST[word['base']]
+        if word['base'] in conf.FREQ_LIST:
+            word_freq = conf.FREQ_LIST[word['base']]

         if i == 0 or word_freq < min_freq:
             min_freq = word_freq
corneferencer/main.py
@@ -4,9 +4,11 @@ import sys
 from argparse import ArgumentParser
 from natsort import natsorted

-sys.path.append(os.path.abspath(os.path.join('..')))
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+

 import conf
+import utils
 from inout import mmax, tei
 from inout.constants import INPUT_FORMATS
 from resolvers import resolve
@@ -27,22 +29,25 @@ def main():
     if conf.NEURAL_MODEL_ARCHITECTURE == 'siamese':
         resolver = conf.NEURAL_MODEL_ARCHITECTURE
         eprint("Warning: Using %s resolver because of selected neural model architecture!" %
-               conf.NEURAL_MODEL_ARCHITECTURE)
-    process_texts(args.input, args.output, args.format, resolver, args.threshold)
+               conf.NEURAL_MODEL_ARCHITECTURE)
+    process_texts(args.input, args.output, args.format, resolver, args.threshold, args.model)


 def parse_arguments():
     parser = ArgumentParser(description='Corneferencer: coreference resolver using neural nets.')
+    parser.add_argument('-f', '--format', type=str, action='store',
+                        dest='format', default=INPUT_FORMATS[0],
+                        help='input format; default: %s; possibilities: %s'
+                             % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
     parser.add_argument('-i', '--input', type=str, action='store',
                         dest='input', default='',
                         help='input file or dir path')
+    parser.add_argument('-m', '--model', type=str, action='store',
+                        dest='model', default='',
+                        help='neural model path; default: %s' % conf.NEURAL_MODEL_PATH)
     parser.add_argument('-o', '--output', type=str, action='store',
                         dest='output', default='',
                         help='output path; if not specified writes output to standard output')
-    parser.add_argument('-f', '--format', type=str, action='store',
-                        dest='format', default=INPUT_FORMATS[0],
-                        help='input format; default: %s; possibilities: %s'
-                             % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
     parser.add_argument('-r', '--resolver', type=str, action='store',
                         dest='resolver', default=RESOLVERS[0],
                         help='resolve algorithm; default: %s; possibilities: %s'
@@ -55,16 +60,17 @@ def parse_arguments():
     return args


-def process_texts(inpath, outpath, informat, resolver, threshold):
+def process_texts(inpath, outpath, informat, resolver, threshold, model_path):
+    model = utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE, conf.NUMBER_OF_FEATURES, model_path)
     if os.path.isdir(inpath):
-        process_directory(inpath, outpath, informat, resolver, threshold)
+        process_directory(inpath, outpath, informat, resolver, threshold, model)
     elif os.path.isfile(inpath):
-        process_text(inpath, outpath, informat, resolver, threshold)
+        process_text(inpath, outpath, informat, resolver, threshold, model)
     else:
         eprint("Error: Specified input does not exist!")


-def process_directory(inpath, outpath, informat, resolver, threshold):
+def process_directory(inpath, outpath, informat, resolver, threshold, model):
     inpath = os.path.abspath(inpath)
     outpath = os.path.abspath(outpath)

@@ -75,38 +81,38 @@ def process_directory(inpath, outpath, informat, resolver, threshold):
         textname = os.path.splitext(os.path.basename(filename))[0]
         textoutput = os.path.join(outpath, textname)
         textinput = os.path.join(inpath, filename)
-        process_text(textinput, textoutput, informat, resolver, threshold)
+        process_text(textinput, textoutput, informat, resolver, threshold, model)


-def process_text(inpath, outpath, informat, resolver, threshold):
+def process_text(inpath, outpath, informat, resolver, threshold, model):
     basename = os.path.basename(inpath)
     if informat == 'mmax' and basename.endswith('.mmax'):
         print (basename)
         text = mmax.read(inpath)
         if resolver == 'incremental':
-            resolve.incremental(text, threshold)
+            resolve.incremental(text, threshold, model)
         elif resolver == 'entity_based':
-            resolve.entity_based(text, threshold)
+            resolve.entity_based(text, threshold, model)
         elif resolver == 'closest':
-            resolve.closest(text, threshold)
+            resolve.closest(text, threshold, model)
         elif resolver == 'siamese':
-            resolve.siamese(text, threshold)
+            resolve.siamese(text, threshold, model)
         elif resolver == 'all2all':
-            resolve.all2all(text, threshold)
+            resolve.all2all(text, threshold, model)
         mmax.write(inpath, outpath, text)
     elif informat == 'tei':
         print (basename)
         text = tei.read(inpath)
         if resolver == 'incremental':
-            resolve.incremental(text, threshold)
+            resolve.incremental(text, threshold, model)
         elif resolver == 'entity_based':
-            resolve.entity_based(text, threshold)
+            resolve.entity_based(text, threshold, model)
         elif resolver == 'closest':
-            resolve.closest(text, threshold)
+            resolve.closest(text, threshold, model)
         elif resolver == 'siamese':
-            resolve.siamese(text, threshold)
+            resolve.siamese(text, threshold, model)
         elif resolver == 'all2all':
-            resolve.all2all(text, threshold)
+            resolve.all2all(text, threshold, model)
         tei.write(inpath, outpath, text)


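Note: the neural model is no longer a module-level constant; main.py loads it from the new -m/--model argument and threads it through every resolver call. The same flow can be reproduced programmatically; a minimal sketch mirroring what process_text() does for TEI input (the model path, directories and threshold value below are illustrative, not taken from the commit):

    import conf
    import utils
    from inout import tei
    from resolvers import resolve

    # Load the Keras model once, then pass it explicitly to the resolver.
    model = utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE,
                                          conf.NUMBER_OF_FEATURES,
                                          'models/example_model.h5')  # hypothetical path
    text = tei.read('sample_text_tei_dir')  # hypothetical input
    resolve.incremental(text, 0.85, model)  # threshold value is illustrative
    tei.write('sample_text_tei_dir', 'output_dir', text)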
corneferencer/prepare_data.py (new file, mode 100644)
+# -*- coding: utf-8 -*-
+
+import codecs
+import os
+import random
+import sys
+
+from itertools import combinations
+from argparse import ArgumentParser
+from natsort import natsorted
+
+sys.path.append(os.path.abspath(os.path.join('..')))
+
+from inout import mmax, tei
+from inout.constants import INPUT_FORMATS
+from utils import eprint
+from corneferencer.resolvers import vectors
+
+
+POS_COUNT = 0
+NEG_COUNT = 0
+
+
+def main():
+    args = parse_arguments()
+    if not args.input:
+        eprint("Error: Input file(s) not specified!")
+    elif args.format not in INPUT_FORMATS:
+        eprint("Error: Unknown input file format!")
+    else:
+        process_texts(args.input, args.output, args.format, args.proportion)
+
+
+def parse_arguments():
+    parser = ArgumentParser(description='Corneferencer: data preparator for neural nets training.')
+    parser.add_argument('-i', '--input', type=str, action='store',
+                        dest='input', default='',
+                        help='input dir path')
+    parser.add_argument('-o', '--output', type=str, action='store',
+                        dest='output', default='',
+                        help='output path; if not specified writes output to standard output')
+    parser.add_argument('-f', '--format', type=str, action='store',
+                        dest='format', default=INPUT_FORMATS[0],
+                        help='input format; default: %s; possibilities: %s'
+                             % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
+    parser.add_argument('-p', '--proportion', type=int, action='store',
+                        dest='proportion', default=5,
+                        help='negative examples proportion; default: 5')
+    args = parser.parse_args()
+    return args
+
+
+def process_texts(inpath, outpath, informat, proportion):
+    if os.path.isdir(inpath):
+        process_directory(inpath, outpath, informat, proportion)
+    else:
+        eprint("Error: Specified input does not exist or is not a directory!")
+
+
+def process_directory(inpath, outpath, informat, proportion):
+    inpath = os.path.abspath(inpath)
+    outpath = os.path.abspath(outpath)
+
+    try:
+        create_data_vectors(inpath, outpath, informat, proportion)
+    finally:
+        print ('Positives: ', POS_COUNT)
+        print ('Negatives: ', NEG_COUNT)
+
+
+def create_data_vectors(inpath, outpath, informat, proportion):
+    features_file = codecs.open(outpath, 'w', 'utf-8')
+
+    files = os.listdir(inpath)
+    files = natsorted(files)
+
+    for filename in files:
+        textname = os.path.splitext(os.path.basename(filename))[0]
+        textinput = os.path.join(inpath, filename)
+
+        print ('Processing text: %s' % textname)
+        text = None
+        if informat == 'mmax' and filename.endswith('.mmax'):
+            text = mmax.read(textinput, False)
+        elif informat == 'tei':
+            text = tei.read(textinput, False)
+
+        positives, negatives = diff_mentions(text, proportion)
+        write_features(features_file, positives, negatives)
+
+
+def diff_mentions(text, proportion):
+    sets = text.get_sets()
+    all_mentions = text.get_mentions()
+    positives = get_positives(sets)
+    positives, negatives = get_negatives_and_update_positives(all_mentions, positives, proportion)
+    return positives, negatives
+
+
+def get_positives(sets):
+    positives = []
+    for set_id in sets:
+        coref_set = sets[set_id]
+        positives.extend(list(combinations(coref_set, 2)))
+    return positives
+
+
+def get_negatives_and_update_positives(all_mentions, positives, proportion):
+    all_pairs = list(combinations(all_mentions, 2))
+
+    all_pairs = set(all_pairs)
+    negatives = [pair for pair in all_pairs if pair not in positives]
+    samples_count = proportion * len(positives)
+    if samples_count > len(negatives):
+        samples_count = len(negatives)
+        if proportion == 1:
+            positives = random.sample(set(positives), samples_count)
+            print (u'Więcej przypadków pozytywnych niż negatywnych!')
+    negatives = random.sample(set(negatives), samples_count)
+    return positives, negatives
+
+
+def write_features(features_file, positives, negatives):
+    global POS_COUNT
+    POS_COUNT += len(positives)
+    for pair in positives:
+        vector = vectors.get_pair_vector(pair[0], pair[1])
+        vector.append(1.0)
+        features_file.write(u'%s\n' % u'\t'.join([str(feature) for feature in vector]))
+
+    global NEG_COUNT
+    NEG_COUNT += len(negatives)
+    for pair in negatives:
+        vector = vectors.get_pair_vector(pair[0], pair[1])
+        vector.append(0.0)
+        features_file.write(u'%s\n' % u'\t'.join([str(feature) for feature in vector]))
+
+
+if __name__ == '__main__':
+    main()
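Note: prepare_data.py reads a whole directory (-i), keeps the gold mention groups while reading, and writes one tab-separated line per mention pair: the pair feature vector followed by a 1.0 (coreferent) or 0.0 (sampled negative) label, with the negative-to-positive ratio controlled by -p. A minimal sketch of how such a file might be loaded for training (the file name and the numpy-based loading are assumptions, not part of the commit):

    import numpy

    # Hypothetical loader for the TSV written by prepare_data.py.
    data = numpy.loadtxt('train_features.tsv', delimiter='\t', dtype=numpy.float32)
    X = data[:, :-1]  # pair feature vectors
    y = data[:, -1]   # 1.0 for positive pairs, 0.0 for sampled negatives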
corneferencer/resolvers/features.py
@@ -72,97 +72,97 @@ def sentence_vec(mention):


 def mention_type(mention):
-    type_vec = [0] * 4
+    type_vec = [0.0] * 4
     if mention.head is None:
-        type_vec[3] = 1
+        type_vec[3] = 1.0
     elif mention.head['ctag'] in constants.NOUN_TAGS:
-        type_vec[0] = 1
+        type_vec[0] = 1.0
     elif mention.head['ctag'] in constants.PPRON_TAGS:
-        type_vec[1] = 1
+        type_vec[1] = 1.0
     elif mention.head['ctag'] in constants.ZERO_TAGS:
-        type_vec[2] = 1
+        type_vec[2] = 1.0
     else:
-        type_vec[3] = 1
+        type_vec[3] = 1.0
     return type_vec


 def is_first_second_person(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if mention.head['person'] in constants.FIRST_SECOND_PERSON:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def is_demonstrative(mention):
     if mention.words[0]['base'].lower() in constants.INDICATIVE_PRONS_BASES:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def is_demonstrative_nominal(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if is_demonstrative(mention) and mention.head['ctag'] in constants.NOUN_TAGS:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def is_demonstrative_pronoun(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if (is_demonstrative(mention) and
             (mention.head['ctag'] in constants.PPRON_TAGS or mention.head['ctag'] in constants.ZERO_TAGS)):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def is_refl_pronoun(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if mention.head['ctag'] in constants.SIEBIE_TAGS:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def is_first_in_sentence(mention):
     if mention.first_in_sentence:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def is_zero_or_pronoun(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if mention.head['ctag'] in constants.PPRON_TAGS or mention.head['ctag'] in constants.ZERO_TAGS:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def head_contains_digit(mention):
     _digits = re.compile('\d')
     if _digits.search(mention.head_orth):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def mention_contains_digit(mention):
     _digits = re.compile('\d')
     if _digits.search(mention.text):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def contains_letter(mention):
     if any(c.isalpha() for c in mention.text):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def post_modified(mention):
     if mention.head_orth != mention.words[-1]['orth']:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 # pair features
@@ -171,20 +171,20 @@ def distances_vec(ante, ana):

     mnts_intersect = pair_intersect(ante, ana)

-    words_dist = [0] * 11
+    words_dist = [0.0] * 11
     words_bucket = 0
-    if mnts_intersect != 1:
+    if mnts_intersect != 1.0:
         words_bucket = get_distance_bucket(ana.start_in_words - ante.end_in_words)
-    words_dist[words_bucket] = 1
+    words_dist[words_bucket] = 1.0
     vec.extend(words_dist)

-    mentions_dist = [0] * 11
+    mentions_dist = [0.0] * 11
     mentions_bucket = 0
-    if mnts_intersect != 1:
+    if mnts_intersect != 1.0:
         mentions_bucket = get_distance_bucket(ana.position_in_mentions - ante.position_in_mentions)
         if words_bucket == 10:
             mentions_bucket = 10
-    mentions_dist[mentions_bucket] = 1
+    mentions_dist[mentions_bucket] = 1.0
     vec.extend(mentions_dist)

     vec.append(mnts_intersect)
@@ -196,45 +196,45 @@ def pair_intersect(ante, ana):
     for ante_word in ante.words:
         for ana_word in ana.words:
             if ana_word['id'] == ante_word['id']:
-                return 1
-    return 0
+                return 1.0
+    return 0.0


 def head_match(ante, ana):
     if ante.head_orth.lower() == ana.head_orth.lower():
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def exact_match(ante, ana):
     if ante.text.lower() == ana.text.lower():
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def base_match(ante, ana):
     if ante.lemmatized_text.lower() == ana.lemmatized_text.lower():
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def ante_contains_rarest_from_ana(ante, ana):
     ana_rarest = ana.rarest
     for word in ante.words:
         if word['base'] == ana_rarest['base']:
-            return 1
-    return 0
+            return 1.0
+    return 0.0


 def agreement(ante, ana, tag_name):
-    agr_vec = [0] * 3
+    agr_vec = [0.0] * 3
     if (ante.head is None or ana.head is None or
             ante.head[tag_name] == 'unk' or ana.head[tag_name] == 'unk'):
-        agr_vec[2] = 1
+        agr_vec[2] = 1.0
     elif ante.head[tag_name] == ana.head[tag_name]:
-        agr_vec[0] = 1
+        agr_vec[0] = 1.0
     else:
-        agr_vec[1] = 1
+        agr_vec[1] = 1.0
     return agr_vec


@@ -243,72 +243,72 @@ def is_acronym(ante, ana):
         return check_one_way_acronym(ana.text, ante.text)
     if ante.text.upper() == ante.text:
         return check_one_way_acronym(ante.text, ana.text)
-    return 0
+    return 0.0


 def same_sentence(ante, ana):
     if ante.sentence_id == ana.sentence_id:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def neighbouring_sentence(ante, ana):
     if ana.sentence_id - ante.sentence_id == 1:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def cousin_sentence(ante, ana):
     if ana.sentence_id - ante.sentence_id == 2:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def distant_sentence(ante, ana):
     if ana.sentence_id - ante.sentence_id > 2:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def same_paragraph(ante, ana):
     if ante.paragraph_id == ana.paragraph_id:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def flat_gender_agreement(ante, ana):
-    agr_vec = [0] * 3
+    agr_vec = [0.0] * 3
     if (ante.head is None or ana.head is None or
             ante.head['gender'] == 'unk' or ana.head['gender'] == 'unk'):
-        agr_vec[2] = 1
+        agr_vec[2] = 1.0
     elif (ante.head['gender'] == ana.head['gender'] or
             (ante.head['gender'] in constants.MASCULINE_TAGS and ana.head['gender'] in constants.MASCULINE_TAGS)):
-        agr_vec[0] = 1
+        agr_vec[0] = 1.0
     else:
-        agr_vec[1] = 1
+        agr_vec[1] = 1.0
     return agr_vec


 def left_match(ante, ana):
     if (ante.text.lower().startswith(ana.text.lower()) or
             ana.text.lower().startswith(ante.text.lower())):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def right_match(ante, ana):
     if (ante.text.lower().endswith(ana.text.lower()) or
             ana.text.lower().endswith(ante.text.lower())):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def abbrev2(ante, ana):
     ante_abbrev = get_abbrev(ante)
     ana_abbrev = get_abbrev(ana)
     if ante.head_orth == ana_abbrev or ana.head_orth == ante_abbrev:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def string_kernel(ante, ana):
@@ -326,7 +326,7 @@ def head_string_kernel(ante, ana):
 def wordnet_synonyms(ante, ana):
     ante_synonyms = set()
     if ante.head is None or ana.head is None:
-        return 0
+        return 0.0

     if ante.head['base'] in conf.LEMMA2SYNONYMS:
         ante_synonyms = conf.LEMMA2SYNONYMS[ante.head['base']]
@@ -336,13 +336,13 @@ def wordnet_synonyms(ante, ana):
         ana_synonyms = conf.LEMMA2SYNONYMS[ana.head['base']]

     if ana.head['base'] in ante_synonyms or ante.head['base'] in ana_synonyms:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def wordnet_ana_is_hypernym(ante, ana):
     if ante.head is None or ana.head is None:
-        return 0
+        return 0.0

     ante_hypernyms = set()
     if ante.head['base'] in conf.LEMMA2HYPERNYMS:
@@ -353,16 +353,16 @@ def wordnet_ana_is_hypernym(ante, ana):
         ana_hypernyms = conf.LEMMA2HYPERNYMS[ana.head['base']]

     if not ante_hypernyms or not ana_hypernyms:
-        return 0
+        return 0.0

     if ana.head['base'] in ante_hypernyms:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def wordnet_ante_is_hypernym(ante, ana):
     if ante.head is None or ana.head is None:
-        return 0
+        return 0.0

     ana_hypernyms = set()
     if ana.head['base'] in conf.LEMMA2HYPERNYMS:
@@ -373,18 +373,18 @@ def wordnet_ante_is_hypernym(ante, ana):
         ante_hypernyms = conf.LEMMA2HYPERNYMS[ante.head['base']]

     if not ante_hypernyms or not ana_hypernyms:
-        return 0
+        return 0.0

     if ante.head['base'] in ana_hypernyms:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def wikipedia_link(ante, ana):
     ante_base = ante.lemmatized_text.lower()
     ana_base = ana.lemmatized_text.lower()
     if ante_base == ana_base:
-        return 1
+        return 1.0

     ante_links = set()
     if ante_base in conf.TITLE2LINKS:
@@ -395,16 +395,16 @@ def wikipedia_link(ante, ana):
         ana_links = conf.TITLE2LINKS[ana_base]

     if ana_base in ante_links or ante_base in ana_links:
-        return 1
+        return 1.0

-    return 0
+    return 0.0


 def wikipedia_mutual_link(ante, ana):
     ante_base = ante.lemmatized_text.lower()
     ana_base = ana.lemmatized_text.lower()
     if ante_base == ana_base:
-        return 1
+        return 1.0

     ante_links = set()
     if ante_base in conf.TITLE2LINKS:
@@ -415,52 +415,52 @@ def wikipedia_mutual_link(ante, ana):
         ana_links = conf.TITLE2LINKS[ana_base]

     if ana_base in ante_links and ante_base in ana_links:
-        return 1
+        return 1.0

-    return 0
+    return 0.0


 def wikipedia_redirect(ante, ana):
     ante_base = ante.lemmatized_text.lower()
     ana_base = ana.lemmatized_text.lower()
     if ante_base == ana_base:
-        return 1
+        return 1.0

     if ante_base in conf.TITLE2REDIRECT and conf.TITLE2REDIRECT[ante_base] == ana_base:
-        return 1
+        return 1.0

     if ana_base in conf.TITLE2REDIRECT and conf.TITLE2REDIRECT[ana_base] == ante_base:
-        return 1
+        return 1.0

-    return 0
+    return 0.0


 def samesent_anapron_antefirstinpar(ante, ana):
     if same_sentence(ante, ana) and is_zero_or_pronoun(ana) and ante.first_in_paragraph:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def samesent_antefirstinpar_personnumbermatch(ante, ana):
     if (same_sentence(ante, ana) and ante.first_in_paragraph
             and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def adjsent_anapron_adjmen_personnumbermatch(ante, ana):
     if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana)
             and ana.position_in_mentions - ante.position_in_mentions == 1
             and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def adjsent_anapron_adjmen(ante, ana):
     if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana)
             and ana.position_in_mentions - ante.position_in_mentions == 1):
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 # supporting functions
@@ -523,8 +523,8 @@ def check_one_way_acronym(acronym, expression):
        if expr2:
            initials += expr2[0].upper()
    if acronym == initials:
-        return 1
-    return 0
+        return 1.0
+    return 0.0


 def get_abbrev(mention):
corneferencer/resolvers/resolve.py
 import numpy

-from conf import NEURAL_MODEL
-from corneferencer.resolvers import features
-from corneferencer.resolvers.vectors import get_pair_features, get_pair_vector
+from corneferencer.resolvers import features, vectors


-def siamese(text, threshold):
+def siamese(text, threshold, neural_model):
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
             for ante in reversed(text.mentions[:i]):
                 if not features.pair_intersect(ante, ana):
-                    pair_features = get_pair_features(ante, ana)
+                    pair_features = vectors.get_pair_features(ante, ana)

                     ante_vec = []
                     ante_vec.extend(ante.features)
@@ -23,7 +21,7 @@ def siamese(text, threshold):
                     ana_vec.extend(pair_features)
                     ana_sample = numpy.asarray([ana_vec], dtype=numpy.float32)

-                    prediction = NEURAL_MODEL.predict([ante_sample, ana_sample])[0]
+                    prediction = neural_model.predict([ante_sample, ana_sample])[0]

                     if prediction < threshold:
                         if ante.set:
@@ -37,7 +35,7 @@ def siamese(text, threshold):


 # incremental resolve algorithm
-def incremental(text, threshold):
+def incremental(text, threshold, neural_model):
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
@@ -45,9 +43,9 @@ def incremental(text, threshold):
             best_ante = None
             for ante in text.mentions[:i]:
                 if not features.pair_intersect(ante, ana):
-                    pair_vec = get_pair_vector(ante, ana)
+                    pair_vec = vectors.get_pair_vector(ante, ana)
                     sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                    prediction = NEURAL_MODEL.predict(sample)[0]
+                    prediction = neural_model.predict(sample)[0]
                     if prediction > threshold and prediction >= best_prediction:
                         best_prediction = prediction
                         best_ante = ante
@@ -62,86 +60,21 @@ def incremental(text, threshold):


 # all2all resolve algorithm
-def all2all_debug(text, threshold):
-    last_set_id = 0
-    for pos1, mnt1 in enumerate(text.mentions):
-        best_prediction = 0.0
-        best_link = None
-        for pos2, mnt2 in enumerate(text.mentions):
-            if (mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2):
-                ante = mnt1
-                ana = mnt2
-                if pos2 < pos1:
-                    ante = mnt2
-                    ana = mnt1
-                pair_vec = get_pair_vector(ante, ana)
-                sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                prediction = NEURAL_MODEL.predict(sample)[0]
-                if prediction > threshold and prediction > best_prediction:
-                    best_prediction = prediction
-                    best_link = mnt2
-        if best_link is not None:
-            if best_link.set and not mnt1.set:
-                mnt1.set = best_link.set
-            elif best_link.set and mnt1.set:
-                text.merge_sets(best_link.set, mnt1.set)
-            elif not best_link.set and not mnt1.set:
-                str_set_id = 'set_%d' % last_set_id
-                best_link.set = str_set_id
-                mnt1.set = str_set_id
-                last_set_id += 1
-
-
-def all2all_v1(text, threshold):
-    last_set_id = 0
-    for pos1, mnt1 in enumerate(text.mentions):
-        best_prediction = 0.0
-        best_link = None
-        for pos2, mnt2 in enumerate(text.mentions):
-            if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
-                    and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
-                ante = mnt1
-                ana = mnt2
-                if pos2 < pos1:
-                    ante = mnt2
-                    ana = mnt1
-                pair_vec = get_pair_vector(ante, ana)
-                sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                prediction = NEURAL_MODEL.predict(sample)[0]
-                if prediction > threshold and prediction > best_prediction:
-                    best_prediction = prediction
-                    best_link = mnt2
-        if best_link is not None:
-            if best_link.set and not mnt1.set:
-                mnt1.set = best_link.set
-            elif not best_link.set and mnt1.set:
-                best_link.set = mnt1.set
-            elif best_link.set and mnt1.set:
-                text.merge_sets(best_link.set, mnt1.set)
-            elif not best_link.set and not mnt1.set:
-                str_set_id = 'set_%d' % last_set_id
-                best_link.set = str_set_id
-                mnt1.set = str_set_id
-                last_set_id += 1
-
-
-def all2all(text, threshold):
+def all2all(text, threshold, neural_model):
     last_set_id = 0
     sets = text.get_sets()
     for pos1, mnt1 in enumerate(text.mentions):
         best_prediction = 0.0
         best_link = None
         for pos2, mnt2 in enumerate(text.mentions):
-            if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
-                    and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
+            if (pos2 > pos1 and
+                    (mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
+                    and not features.pair_intersect(mnt1, mnt2)):
                 ante = mnt1
                 ana = mnt2
-                if pos2 < pos1:
-                    ante = mnt2
-                    ana = mnt1
-                pair_vec = get_pair_vector(ante, ana)
+                pair_vec = vectors.get_pair_vector(ante, ana)
                 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                prediction = NEURAL_MODEL.predict(sample)[0]
+                prediction = neural_model.predict(sample)[0]
                 if prediction > threshold and prediction > best_prediction:
                     best_prediction = prediction
                     best_link = mnt2
@@ -163,12 +96,12 @@ def all2all(text, threshold):


 # entity based resolve algorithm
-def entity_based(text, threshold):
+def entity_based(text, threshold, neural_model):
     sets = []
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
-            best_fit = get_best_set(sets, ana, threshold)
+            best_fit = get_best_set(sets, ana, threshold, neural_model)
             if best_fit is not None:
                 ana.set = best_fit['set_id']
                 best_fit['mentions'].append(ana)
@@ -188,25 +121,25 @@ def entity_based(text, threshold):
     remove_singletons(sets)


-def get_best_set(sets, ana, threshold):
+def get_best_set(sets, ana, threshold, neural_model):
     best_prediction = 0.0
     best_set = None
     for s in sets:
-        accuracy = predict_set(s['mentions'], ana)
+        accuracy = predict_set(s['mentions'], ana, neural_model)
         if accuracy > threshold and accuracy >= best_prediction:
             best_prediction = accuracy
             best_set = s
     return best_set


-def predict_set(mentions, ana):
+def predict_set(mentions, ana, neural_model):
     prediction_sum = 0.0
     for mnt in mentions:
         prediction = 0.0
         if not features.pair_intersect(mnt, ana):
-            pair_vec = get_pair_vector(mnt, ana)
+            pair_vec = vectors.get_pair_vector(mnt, ana)
             sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-            prediction = NEURAL_MODEL.predict(sample)[0]
+            prediction = neural_model.predict(sample)[0]
         prediction_sum += prediction
     return prediction_sum / float(len(mentions))

@@ -218,15 +151,15 @@ def remove_singletons(sets):


 # closest resolve algorithm
-def closest(text, threshold):
+def closest(text, threshold, neural_model):
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
             for ante in reversed(text.mentions[:i]):
                 if not features.pair_intersect(ante, ana):
-                    pair_vec = get_pair_vector(ante, ana)
+                    pair_vec = vectors.get_pair_vector(ante, ana)
                     sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                    prediction = NEURAL_MODEL.predict(sample)[0]
+                    prediction = neural_model.predict(sample)[0]
                     if prediction > threshold:
                         if ante.set:
                             ana.set = ante.set
corneferencer/utils.py
@@ -7,7 +7,6 @@ import javaobj

 from keras.models import Sequential, Model
 from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization, Lambda
-from keras.optimizers import RMSprop, Adam
 from keras import backend as K

