Commit c29c36cde4fa9ea17a625fa2b0e8ea5c1192ec25

Authored by Bartłomiej Nitoń
1 parent db88d6e4

Add a data preparation script and other minor improvements: load the neural model on demand and pass it to the resolvers explicitly, add a --model CLI option, make clearing of input mentions configurable in the MMAX and TEI readers, and switch feature values to floats.

... ... @@ -30,7 +30,6 @@ W2V_MODEL_PATH = os.path.join(MAIN_PATH, 'models', W2V_MODEL_NAME)
30 30 W2V_MODEL = Word2Vec.load(W2V_MODEL_PATH)
31 31  
32 32 NEURAL_MODEL_PATH = os.path.join(MAIN_PATH, 'models', NEURAL_MODEL_NAME)
33   -NEURAL_MODEL = utils.initialize_neural_model(NEURAL_MODEL_ARCHITECTURE, NUMBER_OF_FEATURES, NEURAL_MODEL_PATH)
34 33  
35 34 FREQ_LIST_PATH = os.path.join(MAIN_PATH, 'freq', FREQ_LIST_NAME)
36 35 FREQ_LIST = utils.load_freq_list(FREQ_LIST_PATH)
... ...
corneferencer/entities.py
1   -from corneferencer.resolvers.vectors import get_mention_features
  1 +from corneferencer.resolvers import vectors
2 2  
3 3  
4 4 class Text:
... ... @@ -19,6 +19,9 @@ class Text:
19 19 return mnt
20 20 return None
21 21  
  22 + def get_mentions(self):
  23 + return self.mentions
  24 +
22 25 def get_sets(self):
23 26 sets = {}
24 27 for mnt in self.mentions:
... ... @@ -62,4 +65,4 @@ class Mention:
62 65 self.sentence_id = sentence_id
63 66 self.first_in_sentence = first_in_sentence
64 67 self.first_in_paragraph = first_in_paragraph
65   - self.features = get_mention_features(self)
  68 + self.features = vectors.get_mention_features(self)
... ...
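
The new Text.get_mentions() accessor exists mainly for the data preparation script added in this commit, which pairs mentions from get_sets() (positives) against arbitrary mention pairs (negatives). A minimal sketch of that pairing, assuming `text` is a corneferencer.entities.Text read from a gold-annotated corpus:

    from itertools import combinations

    def coreferent_pairs(text):
        # Positive examples: every 2-combination of mentions inside one coreference set.
        positives = []
        for coref_set in text.get_sets().values():
            positives.extend(combinations(coref_set, 2))
        return positives

    def candidate_pairs(text):
        # All mention pairs, from which the non-coreferent (negative) examples are sampled.
        return list(combinations(text.get_mentions(), 2))
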
corneferencer/inout/mmax.py
... ... @@ -3,11 +3,11 @@ import shutil
3 3  
4 4 from lxml import etree
5 5  
6   -from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST
  6 +import conf
7 7 from corneferencer.entities import Mention, Text
8 8  
9 9  
10   -def read(inpath):
  10 +def read(inpath, clear_mentions=conf.CLEAR_INPUT):
11 11 textname = os.path.splitext(os.path.basename(inpath))[0]
12 12 textdir = os.path.dirname(inpath)
13 13  
... ... @@ -15,11 +15,11 @@ def read(inpath):
15 15 words_path = os.path.join(textdir, '%s_words.xml' % textname)
16 16  
17 17 text = Text(textname)
18   - text.mentions = read_mentions(mentions_path, words_path)
  18 + text.mentions = read_mentions(mentions_path, words_path, clear_mentions)
19 19 return text
20 20  
21 21  
22   -def read_mentions(mentions_path, words_path):
  22 +def read_mentions(mentions_path, words_path, clear_mentions=conf.CLEAR_INPUT):
23 23 mentions = []
24 24 mentions_tree = etree.parse(mentions_path)
25 25 markables = mentions_tree.xpath("//ns:markable",
... ... @@ -43,7 +43,7 @@ def read_mentions(mentions_path, words_path):
43 43  
44 44 head = get_head(head_orth, mention_words)
45 45 mention_group = ''
46   - if markable.attrib['mention_group'] != 'empty' and not CLEAR_INPUT:
  46 + if markable.attrib['mention_group'] != 'empty' and not clear_mentions:
47 47 mention_group = markable.attrib['mention_group']
48 48 mention = Mention(mnt_id=markable.attrib['id'],
49 49 text=span_to_text(span, words, 'orth'),
... ... @@ -189,7 +189,7 @@ def get_prec_context(mention_start, words):
189 189 while context_start >= 0:
190 190 if not word_to_ignore(words[context_start]):
191 191 context.append(words[context_start])
192   - if len(context) == CONTEXT:
  192 + if len(context) == conf.CONTEXT:
193 193 break
194 194 context_start -= 1
195 195 context.reverse()
... ... @@ -222,7 +222,7 @@ def get_follow_context(mention_end, words):
222 222 while context_end < len(words):
223 223 if not word_to_ignore(words[context_end]):
224 224 context.append(words[context_end])
225   - if len(context) == CONTEXT:
  225 + if len(context) == conf.CONTEXT:
226 226 break
227 227 context_end += 1
228 228 return context
... ... @@ -349,9 +349,8 @@ def get_rarest_word(words):
349 349 rarest_word = words[0]
350 350 for i, word in enumerate(words):
351 351 word_freq = 0
352   - if word['base'] in FREQ_LIST:
353   - word_freq = FREQ_LIST[word['base']]
354   -
  352 + if word['base'] in conf.FREQ_LIST:
  353 + word_freq = conf.FREQ_LIST[word['base']]
355 354 if i == 0 or word_freq < min_freq:
356 355 min_freq = word_freq
357 356 rarest_word = word
... ...
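
read() and read_mentions() now take an explicit clear_mentions flag instead of always consulting the module-level CLEAR_INPUT constant, so training-data preparation can keep the gold mention_group annotations while the resolver path keeps the old default. A usage sketch (the .mmax path is hypothetical; assumes the corneferencer directory is on sys.path, as main.py arranges):

    from inout import mmax

    # Keep the gold coreference clusters (mention_group attributes) -- what prepare_data.py does.
    gold_text = mmax.read('corpus/sample.mmax', clear_mentions=False)

    # Discard any existing clusters before resolving; omitting the flag falls back to conf.CLEAR_INPUT.
    fresh_text = mmax.read('corpus/sample.mmax', clear_mentions=True)
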
corneferencer/inout/tei.py
... ... @@ -4,7 +4,7 @@ import shutil
4 4  
5 5 from lxml import etree
6 6  
7   -from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST
  7 +import conf
8 8 from corneferencer.entities import Mention, Text
9 9 from corneferencer.utils import eprint
10 10  
... ... @@ -18,7 +18,7 @@ NSMAP = {None: TEI_NS,
18 18 'xi': XI_NS}
19 19  
20 20  
21   -def read(inpath):
  21 +def read(inpath, clear_mentions=conf.CLEAR_INPUT):
22 22 textname = os.path.basename(inpath)
23 23  
24 24 text = Text(textname)
... ... @@ -49,7 +49,7 @@ def read(inpath):
49 49 eprint("Error: missing mentions layer for text %s!" % textname)
50 50 return None
51 51  
52   - if os.path.exists(ann_coreference) and not CLEAR_INPUT:
  52 + if os.path.exists(ann_coreference) and not clear_mentions:
53 53 add_coreference_layer(ann_coreference, text)
54 54  
55 55 return text
... ... @@ -215,6 +215,9 @@ def get_mention(mention, mnt_id, segments, segments_ids, paragraph_id, sentence_
215 215 semh_id = get_fval(f).split('#')[-1]
216 216 semh = segments[semh_id]
217 217  
  218 + if len(mnt_segments) == 0:
  219 + mnt_segments.append(semh)
  220 +
218 221 (sent_segments, prec_context, follow_context,
219 222 first_in_sentence, first_in_paragraph) = get_context(mnt_segments, segments, segments_ids)
220 223  
... ... @@ -272,7 +275,7 @@ def get_prec_context(mention_start, segments, segments_ids):
272 275 while context_start >= 0:
273 276 if not word_to_ignore(segments[segments_ids[context_start]]):
274 277 context.append(segments[segments_ids[context_start]])
275   - if len(context) == CONTEXT:
  278 + if len(context) == conf.CONTEXT:
276 279 break
277 280 context_start -= 1
278 281 context.reverse()
... ... @@ -285,7 +288,7 @@ def get_follow_context(mention_end, segments, segments_ids):
285 288 while context_end < len(segments):
286 289 if not word_to_ignore(segments[segments_ids[context_end]]):
287 290 context.append(segments[segments_ids[context_end]])
288   - if len(context) == CONTEXT:
  291 + if len(context) == conf.CONTEXT:
289 292 break
290 293 context_end += 1
291 294 return context
... ... @@ -341,8 +344,8 @@ def get_rarest_word(words):
341 344 rarest_word = words[0]
342 345 for i, word in enumerate(words):
343 346 word_freq = 0
344   - if word['base'] in FREQ_LIST:
345   - word_freq = FREQ_LIST[word['base']]
  347 + if word['base'] in conf.FREQ_LIST:
  348 + word_freq = conf.FREQ_LIST[word['base']]
346 349  
347 350 if i == 0 or word_freq < min_freq:
348 351 min_freq = word_freq
... ...
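
Besides the same clear_mentions plumbing as in the MMAX reader, get_mention() now guards against markables that contribute no segments of their own: the semantic head segment is appended so that context and feature extraction never see an empty mention. An illustrative, self-contained restatement of that guard (the segment dict is hypothetical; only its keys follow the diff):

    mnt_segments = []                                         # nothing collected from the markable itself
    semh = {'orth': 'Jan', 'base': 'jan', 'ctag': 'subst'}    # hypothetical semantic head segment

    if len(mnt_segments) == 0:                                # the new guard from this commit
        mnt_segments.append(semh)

    assert mnt_segments                                       # the mention keeps at least its head word
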
corneferencer/main.py
... ... @@ -4,9 +4,11 @@ import sys
4 4 from argparse import ArgumentParser
5 5 from natsort import natsorted
6 6  
7   -sys.path.append(os.path.abspath(os.path.join('..')))
  7 +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
  8 +
8 9  
9 10 import conf
  11 +import utils
10 12 from inout import mmax, tei
11 13 from inout.constants import INPUT_FORMATS
12 14 from resolvers import resolve
... ... @@ -27,22 +29,25 @@ def main():
27 29 if conf.NEURAL_MODEL_ARCHITECTURE == 'siamese':
28 30 resolver = conf.NEURAL_MODEL_ARCHITECTURE
29 31 eprint("Warning: Using %s resolver because of selected neural model architecture!" %
30   - conf.NEURAL_MODEL_ARCHITECTURE)
31   - process_texts(args.input, args.output, args.format, resolver, args.threshold)
  32 + conf.NEURAL_MODEL_ARCHITECTURE)
  33 + process_texts(args.input, args.output, args.format, resolver, args.threshold, args.model)
32 34  
33 35  
34 36 def parse_arguments():
35 37 parser = ArgumentParser(description='Corneferencer: coreference resolver using neural nets.')
  38 + parser.add_argument('-f', '--format', type=str, action='store',
  39 + dest='format', default=INPUT_FORMATS[0],
  40 + help='input format; default: %s; possibilities: %s'
  41 + % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
36 42 parser.add_argument('-i', '--input', type=str, action='store',
37 43 dest='input', default='',
38 44 help='input file or dir path')
  45 + parser.add_argument('-m', '--model', type=str, action='store',
  46 + dest='model', default='',
  47 + help='neural model path; default: %s' % conf.NEURAL_MODEL_PATH)
39 48 parser.add_argument('-o', '--output', type=str, action='store',
40 49 dest='output', default='',
41 50 help='output path; if not specified writes output to standard output')
42   - parser.add_argument('-f', '--format', type=str, action='store',
43   - dest='format', default=INPUT_FORMATS[0],
44   - help='input format; default: %s; possibilities: %s'
45   - % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
46 51 parser.add_argument('-r', '--resolver', type=str, action='store',
47 52 dest='resolver', default=RESOLVERS[0],
48 53 help='resolve algorithm; default: %s; possibilities: %s'
... ... @@ -55,16 +60,17 @@ def parse_arguments():
55 60 return args
56 61  
57 62  
58   -def process_texts(inpath, outpath, informat, resolver, threshold):
  63 +def process_texts(inpath, outpath, informat, resolver, threshold, model_path):
  64 + model = utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE, conf.NUMBER_OF_FEATURES, model_path)
59 65 if os.path.isdir(inpath):
60   - process_directory(inpath, outpath, informat, resolver, threshold)
  66 + process_directory(inpath, outpath, informat, resolver, threshold, model)
61 67 elif os.path.isfile(inpath):
62   - process_text(inpath, outpath, informat, resolver, threshold)
  68 + process_text(inpath, outpath, informat, resolver, threshold, model)
63 69 else:
64 70 eprint("Error: Specified input does not exist!")
65 71  
66 72  
67   -def process_directory(inpath, outpath, informat, resolver, threshold):
  73 +def process_directory(inpath, outpath, informat, resolver, threshold, model):
68 74 inpath = os.path.abspath(inpath)
69 75 outpath = os.path.abspath(outpath)
70 76  
... ... @@ -75,38 +81,38 @@ def process_directory(inpath, outpath, informat, resolver, threshold):
75 81 textname = os.path.splitext(os.path.basename(filename))[0]
76 82 textoutput = os.path.join(outpath, textname)
77 83 textinput = os.path.join(inpath, filename)
78   - process_text(textinput, textoutput, informat, resolver, threshold)
  84 + process_text(textinput, textoutput, informat, resolver, threshold, model)
79 85  
80 86  
81   -def process_text(inpath, outpath, informat, resolver, threshold):
  87 +def process_text(inpath, outpath, informat, resolver, threshold, model):
82 88 basename = os.path.basename(inpath)
83 89 if informat == 'mmax' and basename.endswith('.mmax'):
84 90 print (basename)
85 91 text = mmax.read(inpath)
86 92 if resolver == 'incremental':
87   - resolve.incremental(text, threshold)
  93 + resolve.incremental(text, threshold, model)
88 94 elif resolver == 'entity_based':
89   - resolve.entity_based(text, threshold)
  95 + resolve.entity_based(text, threshold, model)
90 96 elif resolver == 'closest':
91   - resolve.closest(text, threshold)
  97 + resolve.closest(text, threshold, model)
92 98 elif resolver == 'siamese':
93   - resolve.siamese(text, threshold)
  99 + resolve.siamese(text, threshold, model)
94 100 elif resolver == 'all2all':
95   - resolve.all2all(text, threshold)
  101 + resolve.all2all(text, threshold, model)
96 102 mmax.write(inpath, outpath, text)
97 103 elif informat == 'tei':
98 104 print (basename)
99 105 text = tei.read(inpath)
100 106 if resolver == 'incremental':
101   - resolve.incremental(text, threshold)
  107 + resolve.incremental(text, threshold, model)
102 108 elif resolver == 'entity_based':
103   - resolve.entity_based(text, threshold)
  109 + resolve.entity_based(text, threshold, model)
104 110 elif resolver == 'closest':
105   - resolve.closest(text, threshold)
  111 + resolve.closest(text, threshold, model)
106 112 elif resolver == 'siamese':
107   - resolve.siamese(text, threshold)
  113 + resolve.siamese(text, threshold, model)
108 114 elif resolver == 'all2all':
109   - resolve.all2all(text, threshold)
  115 + resolve.all2all(text, threshold, model)
110 116 tei.write(inpath, outpath, text)
111 117  
112 118  
... ...
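
Note that the new --model option defaults to an empty string while its help text advertises conf.NEURAL_MODEL_PATH, and process_texts() passes the value straight to utils.initialize_neural_model. A hypothetical guard, not part of this commit, that would make the fallback explicit:

    import conf
    import utils

    def load_model(model_arg):
        # Hypothetical helper: honour --model when given, otherwise fall back to the configured path.
        model_path = model_arg if model_arg else conf.NEURAL_MODEL_PATH
        return utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE,
                                              conf.NUMBER_OF_FEATURES,
                                              model_path)
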
corneferencer/prepare_data.py 0 → 100644
  1 +# -*- coding: utf-8 -*-
  2 +
  3 +import codecs
  4 +import os
  5 +import random
  6 +import sys
  7 +
  8 +from itertools import combinations
  9 +from argparse import ArgumentParser
  10 +from natsort import natsorted
  11 +
  12 +sys.path.append(os.path.abspath(os.path.join('..')))
  13 +
  14 +from inout import mmax, tei
  15 +from inout.constants import INPUT_FORMATS
  16 +from utils import eprint
  17 +from corneferencer.resolvers import vectors
  18 +
  19 +
  20 +POS_COUNT = 0
  21 +NEG_COUNT = 0
  22 +
  23 +
  24 +def main():
  25 + args = parse_arguments()
  26 + if not args.input:
  27 + eprint("Error: Input file(s) not specified!")
  28 + elif args.format not in INPUT_FORMATS:
  29 + eprint("Error: Unknown input file format!")
  30 + else:
  31 + process_texts(args.input, args.output, args.format, args.proportion)
  32 +
  33 +
  34 +def parse_arguments():
  35 + parser = ArgumentParser(description='Corneferencer: data preparator for neural nets training.')
  36 + parser.add_argument('-i', '--input', type=str, action='store',
  37 + dest='input', default='',
  38 + help='input dir path')
  39 + parser.add_argument('-o', '--output', type=str, action='store',
  40 + dest='output', default='',
  41 + help='output path; if not specified writes output to standard output')
  42 + parser.add_argument('-f', '--format', type=str, action='store',
  43 + dest='format', default=INPUT_FORMATS[0],
  44 + help='input format; default: %s; possibilities: %s'
  45 + % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
  46 + parser.add_argument('-p', '--proportion', type=int, action='store',
  47 + dest='proportion', default=5,
  48 + help='negative examples proportion; default: 5')
  49 + args = parser.parse_args()
  50 + return args
  51 +
  52 +
  53 +def process_texts(inpath, outpath, informat, proportion):
  54 + if os.path.isdir(inpath):
  55 + process_directory(inpath, outpath, informat, proportion)
  56 + else:
  57 + eprint("Error: Specified input does not exist or is not a directory!")
  58 +
  59 +
  60 +def process_directory(inpath, outpath, informat, proportion):
  61 + inpath = os.path.abspath(inpath)
  62 + outpath = os.path.abspath(outpath)
  63 +
  64 + try:
  65 + create_data_vectors(inpath, outpath, informat, proportion)
  66 + finally:
  67 + print ('Positives: ', POS_COUNT)
  68 + print ('Negatives: ', NEG_COUNT)
  69 +
  70 +
  71 +def create_data_vectors(inpath, outpath, informat, proportion):
  72 + features_file = codecs.open(outpath, 'w', 'utf-8')
  73 +
  74 + files = os.listdir(inpath)
  75 + files = natsorted(files)
  76 +
  77 + for filename in files:
  78 + textname = os.path.splitext(os.path.basename(filename))[0]
  79 + textinput = os.path.join(inpath, filename)
  80 +
  81 + print ('Processing text: %s' % textname)
  82 + text = None
  83 + if informat == 'mmax' and filename.endswith('.mmax'):
  84 + text = mmax.read(textinput, False)
  85 + elif informat == 'tei':
  86 + text = tei.read(textinput, False)
  87 +
  88 + positives, negatives = diff_mentions(text, proportion)
  89 + write_features(features_file, positives, negatives)
  90 +
  91 +
  92 +def diff_mentions(text, proportion):
  93 + sets = text.get_sets()
  94 + all_mentions = text.get_mentions()
  95 + positives = get_positives(sets)
  96 + positives, negatives = get_negatives_and_update_positives(all_mentions, positives, proportion)
  97 + return positives, negatives
  98 +
  99 +
  100 +def get_positives(sets):
  101 + positives = []
  102 + for set_id in sets:
  103 + coref_set = sets[set_id]
  104 + positives.extend(list(combinations(coref_set, 2)))
  105 + return positives
  106 +
  107 +
  108 +def get_negatives_and_update_positives(all_mentions, positives, proportion):
  109 + all_pairs = list(combinations(all_mentions, 2))
  110 +
  111 + all_pairs = set(all_pairs)
  112 + negatives = [pair for pair in all_pairs if pair not in positives]
  113 + samples_count = proportion * len(positives)
  114 + if samples_count > len(negatives):
  115 + samples_count = len(negatives)
  116 + if proportion == 1:
  117 + positives = random.sample(set(positives), samples_count)
  118 + print (u'Więcej przypadków pozytywnych niż negatywnych!')
  119 + negatives = random.sample(set(negatives), samples_count)
  120 + return positives, negatives
  121 +
  122 +
  123 +def write_features(features_file, positives, negatives):
  124 + global POS_COUNT
  125 + POS_COUNT += len(positives)
  126 + for pair in positives:
  127 + vector = vectors.get_pair_vector(pair[0], pair[1])
  128 + vector.append(1.0)
  129 + features_file.write(u'%s\n' % u'\t'.join([str(feature) for feature in vector]))
  130 +
  131 + global NEG_COUNT
  132 + NEG_COUNT += len(negatives)
  133 + for pair in negatives:
  134 + vector = vectors.get_pair_vector(pair[0], pair[1])
  135 + vector.append(0.0)
  136 + features_file.write(u'%s\n' % u'\t'.join([str(feature) for feature in vector]))
  137 +
  138 +
  139 +if __name__ == '__main__':
  140 + main()
... ...
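
Each line written by write_features() is a tab-separated pair vector followed by a 1.0 (coreferent) or 0.0 (non-coreferent) label; the Polish message printed when the positive pairs outnumber the negatives means "More positive than negative examples!". A sketch of loading such a file for training, under the assumption that the output is numeric and rectangular (the file name is hypothetical):

    import numpy

    def load_training_data(path):
        # Rows: feature values produced by vectors.get_pair_vector, with the label in the last column.
        data = numpy.loadtxt(path, delimiter='\t', dtype=numpy.float32)
        return data[:, :-1], data[:, -1]

    X, y = load_training_data('train.tsv')   # e.g. a file produced with: -o train.tsv
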
corneferencer/resolvers/features.py
... ... @@ -72,97 +72,97 @@ def sentence_vec(mention):
72 72  
73 73  
74 74 def mention_type(mention):
75   - type_vec = [0] * 4
  75 + type_vec = [0.0] * 4
76 76 if mention.head is None:
77   - type_vec[3] = 1
  77 + type_vec[3] = 1.0
78 78 elif mention.head['ctag'] in constants.NOUN_TAGS:
79   - type_vec[0] = 1
  79 + type_vec[0] = 1.0
80 80 elif mention.head['ctag'] in constants.PPRON_TAGS:
81   - type_vec[1] = 1
  81 + type_vec[1] = 1.0
82 82 elif mention.head['ctag'] in constants.ZERO_TAGS:
83   - type_vec[2] = 1
  83 + type_vec[2] = 1.0
84 84 else:
85   - type_vec[3] = 1
  85 + type_vec[3] = 1.0
86 86 return type_vec
87 87  
88 88  
89 89 def is_first_second_person(mention):
90 90 if mention.head is None:
91   - return 0
  91 + return 0.0
92 92 if mention.head['person'] in constants.FIRST_SECOND_PERSON:
93   - return 1
94   - return 0
  93 + return 1.0
  94 + return 0.0
95 95  
96 96  
97 97 def is_demonstrative(mention):
98 98 if mention.words[0]['base'].lower() in constants.INDICATIVE_PRONS_BASES:
99   - return 1
100   - return 0
  99 + return 1.0
  100 + return 0.0
101 101  
102 102  
103 103 def is_demonstrative_nominal(mention):
104 104 if mention.head is None:
105   - return 0
  105 + return 0.0
106 106 if is_demonstrative(mention) and mention.head['ctag'] in constants.NOUN_TAGS:
107   - return 1
108   - return 0
  107 + return 1.0
  108 + return 0.0
109 109  
110 110  
111 111 def is_demonstrative_pronoun(mention):
112 112 if mention.head is None:
113   - return 0
  113 + return 0.0
114 114 if (is_demonstrative(mention) and
115 115 (mention.head['ctag'] in constants.PPRON_TAGS or mention.head['ctag'] in constants.ZERO_TAGS)):
116   - return 1
117   - return 0
  116 + return 1.0
  117 + return 0.0
118 118  
119 119  
120 120 def is_refl_pronoun(mention):
121 121 if mention.head is None:
122   - return 0
  122 + return 0.0
123 123 if mention.head['ctag'] in constants.SIEBIE_TAGS:
124   - return 1
125   - return 0
  124 + return 1.0
  125 + return 0.0
126 126  
127 127  
128 128 def is_first_in_sentence(mention):
129 129 if mention.first_in_sentence:
130   - return 1
131   - return 0
  130 + return 1.0
  131 + return 0.0
132 132  
133 133  
134 134 def is_zero_or_pronoun(mention):
135 135 if mention.head is None:
136   - return 0
  136 + return 0.0
137 137 if mention.head['ctag'] in constants.PPRON_TAGS or mention.head['ctag'] in constants.ZERO_TAGS:
138   - return 1
139   - return 0
  138 + return 1.0
  139 + return 0.0
140 140  
141 141  
142 142 def head_contains_digit(mention):
143 143 _digits = re.compile('\d')
144 144 if _digits.search(mention.head_orth):
145   - return 1
146   - return 0
  145 + return 1.0
  146 + return 0.0
147 147  
148 148  
149 149 def mention_contains_digit(mention):
150 150 _digits = re.compile('\d')
151 151 if _digits.search(mention.text):
152   - return 1
153   - return 0
  152 + return 1.0
  153 + return 0.0
154 154  
155 155  
156 156 def contains_letter(mention):
157 157 if any(c.isalpha() for c in mention.text):
158   - return 1
159   - return 0
  158 + return 1.0
  159 + return 0.0
160 160  
161 161  
162 162 def post_modified(mention):
163 163 if mention.head_orth != mention.words[-1]['orth']:
164   - return 1
165   - return 0
  164 + return 1.0
  165 + return 0.0
166 166  
167 167  
168 168 # pair features
... ... @@ -171,20 +171,20 @@ def distances_vec(ante, ana):
171 171  
172 172 mnts_intersect = pair_intersect(ante, ana)
173 173  
174   - words_dist = [0] * 11
  174 + words_dist = [0.0] * 11
175 175 words_bucket = 0
176   - if mnts_intersect != 1:
  176 + if mnts_intersect != 1.0:
177 177 words_bucket = get_distance_bucket(ana.start_in_words - ante.end_in_words)
178   - words_dist[words_bucket] = 1
  178 + words_dist[words_bucket] = 1.0
179 179 vec.extend(words_dist)
180 180  
181   - mentions_dist = [0] * 11
  181 + mentions_dist = [0.0] * 11
182 182 mentions_bucket = 0
183   - if mnts_intersect != 1:
  183 + if mnts_intersect != 1.0:
184 184 mentions_bucket = get_distance_bucket(ana.position_in_mentions - ante.position_in_mentions)
185 185 if words_bucket == 10:
186 186 mentions_bucket = 10
187   - mentions_dist[mentions_bucket] = 1
  187 + mentions_dist[mentions_bucket] = 1.0
188 188 vec.extend(mentions_dist)
189 189  
190 190 vec.append(mnts_intersect)
... ... @@ -196,45 +196,45 @@ def pair_intersect(ante, ana):
196 196 for ante_word in ante.words:
197 197 for ana_word in ana.words:
198 198 if ana_word['id'] == ante_word['id']:
199   - return 1
200   - return 0
  199 + return 1.0
  200 + return 0.0
201 201  
202 202  
203 203 def head_match(ante, ana):
204 204 if ante.head_orth.lower() == ana.head_orth.lower():
205   - return 1
206   - return 0
  205 + return 1.0
  206 + return 0.0
207 207  
208 208  
209 209 def exact_match(ante, ana):
210 210 if ante.text.lower() == ana.text.lower():
211   - return 1
212   - return 0
  211 + return 1.0
  212 + return 0.0
213 213  
214 214  
215 215 def base_match(ante, ana):
216 216 if ante.lemmatized_text.lower() == ana.lemmatized_text.lower():
217   - return 1
218   - return 0
  217 + return 1.0
  218 + return 0.0
219 219  
220 220  
221 221 def ante_contains_rarest_from_ana(ante, ana):
222 222 ana_rarest = ana.rarest
223 223 for word in ante.words:
224 224 if word['base'] == ana_rarest['base']:
225   - return 1
226   - return 0
  225 + return 1.0
  226 + return 0.0
227 227  
228 228  
229 229 def agreement(ante, ana, tag_name):
230   - agr_vec = [0] * 3
  230 + agr_vec = [0.0] * 3
231 231 if (ante.head is None or ana.head is None or
232 232 ante.head[tag_name] == 'unk' or ana.head[tag_name] == 'unk'):
233   - agr_vec[2] = 1
  233 + agr_vec[2] = 1.0
234 234 elif ante.head[tag_name] == ana.head[tag_name]:
235   - agr_vec[0] = 1
  235 + agr_vec[0] = 1.0
236 236 else:
237   - agr_vec[1] = 1
  237 + agr_vec[1] = 1.0
238 238 return agr_vec
239 239  
240 240  
... ... @@ -243,72 +243,72 @@ def is_acronym(ante, ana):
243 243 return check_one_way_acronym(ana.text, ante.text)
244 244 if ante.text.upper() == ante.text:
245 245 return check_one_way_acronym(ante.text, ana.text)
246   - return 0
  246 + return 0.0
247 247  
248 248  
249 249 def same_sentence(ante, ana):
250 250 if ante.sentence_id == ana.sentence_id:
251   - return 1
252   - return 0
  251 + return 1.0
  252 + return 0.0
253 253  
254 254  
255 255 def neighbouring_sentence(ante, ana):
256 256 if ana.sentence_id - ante.sentence_id == 1:
257   - return 1
258   - return 0
  257 + return 1.0
  258 + return 0.0
259 259  
260 260  
261 261 def cousin_sentence(ante, ana):
262 262 if ana.sentence_id - ante.sentence_id == 2:
263   - return 1
264   - return 0
  263 + return 1.0
  264 + return 0.0
265 265  
266 266  
267 267 def distant_sentence(ante, ana):
268 268 if ana.sentence_id - ante.sentence_id > 2:
269   - return 1
270   - return 0
  269 + return 1.0
  270 + return 0.0
271 271  
272 272  
273 273 def same_paragraph(ante, ana):
274 274 if ante.paragraph_id == ana.paragraph_id:
275   - return 1
276   - return 0
  275 + return 1.0
  276 + return 0.0
277 277  
278 278  
279 279 def flat_gender_agreement(ante, ana):
280   - agr_vec = [0] * 3
  280 + agr_vec = [0.0] * 3
281 281 if (ante.head is None or ana.head is None or
282 282 ante.head['gender'] == 'unk' or ana.head['gender'] == 'unk'):
283   - agr_vec[2] = 1
  283 + agr_vec[2] = 1.0
284 284 elif (ante.head['gender'] == ana.head['gender'] or
285 285 (ante.head['gender'] in constants.MASCULINE_TAGS and ana.head['gender'] in constants.MASCULINE_TAGS)):
286   - agr_vec[0] = 1
  286 + agr_vec[0] = 1.0
287 287 else:
288   - agr_vec[1] = 1
  288 + agr_vec[1] = 1.0
289 289 return agr_vec
290 290  
291 291  
292 292 def left_match(ante, ana):
293 293 if (ante.text.lower().startswith(ana.text.lower()) or
294 294 ana.text.lower().startswith(ante.text.lower())):
295   - return 1
296   - return 0
  295 + return 1.0
  296 + return 0.0
297 297  
298 298  
299 299 def right_match(ante, ana):
300 300 if (ante.text.lower().endswith(ana.text.lower()) or
301 301 ana.text.lower().endswith(ante.text.lower())):
302   - return 1
303   - return 0
  302 + return 1.0
  303 + return 0.0
304 304  
305 305  
306 306 def abbrev2(ante, ana):
307 307 ante_abbrev = get_abbrev(ante)
308 308 ana_abbrev = get_abbrev(ana)
309 309 if ante.head_orth == ana_abbrev or ana.head_orth == ante_abbrev:
310   - return 1
311   - return 0
  310 + return 1.0
  311 + return 0.0
312 312  
313 313  
314 314 def string_kernel(ante, ana):
... ... @@ -326,7 +326,7 @@ def head_string_kernel(ante, ana):
326 326 def wordnet_synonyms(ante, ana):
327 327 ante_synonyms = set()
328 328 if ante.head is None or ana.head is None:
329   - return 0
  329 + return 0.0
330 330  
331 331 if ante.head['base'] in conf.LEMMA2SYNONYMS:
332 332 ante_synonyms = conf.LEMMA2SYNONYMS[ante.head['base']]
... ... @@ -336,13 +336,13 @@ def wordnet_synonyms(ante, ana):
336 336 ana_synonyms = conf.LEMMA2SYNONYMS[ana.head['base']]
337 337  
338 338 if ana.head['base'] in ante_synonyms or ante.head['base'] in ana_synonyms:
339   - return 1
340   - return 0
  339 + return 1.0
  340 + return 0.0
341 341  
342 342  
343 343 def wordnet_ana_is_hypernym(ante, ana):
344 344 if ante.head is None or ana.head is None:
345   - return 0
  345 + return 0.0
346 346  
347 347 ante_hypernyms = set()
348 348 if ante.head['base'] in conf.LEMMA2HYPERNYMS:
... ... @@ -353,16 +353,16 @@ def wordnet_ana_is_hypernym(ante, ana):
353 353 ana_hypernyms = conf.LEMMA2HYPERNYMS[ana.head['base']]
354 354  
355 355 if not ante_hypernyms or not ana_hypernyms:
356   - return 0
  356 + return 0.0
357 357  
358 358 if ana.head['base'] in ante_hypernyms:
359   - return 1
360   - return 0
  359 + return 1.0
  360 + return 0.0
361 361  
362 362  
363 363 def wordnet_ante_is_hypernym(ante, ana):
364 364 if ante.head is None or ana.head is None:
365   - return 0
  365 + return 0.0
366 366  
367 367 ana_hypernyms = set()
368 368 if ana.head['base'] in conf.LEMMA2HYPERNYMS:
... ... @@ -373,18 +373,18 @@ def wordnet_ante_is_hypernym(ante, ana):
373 373 ante_hypernyms = conf.LEMMA2HYPERNYMS[ante.head['base']]
374 374  
375 375 if not ante_hypernyms or not ana_hypernyms:
376   - return 0
  376 + return 0.0
377 377  
378 378 if ante.head['base'] in ana_hypernyms:
379   - return 1
380   - return 0
  379 + return 1.0
  380 + return 0.0
381 381  
382 382  
383 383 def wikipedia_link(ante, ana):
384 384 ante_base = ante.lemmatized_text.lower()
385 385 ana_base = ana.lemmatized_text.lower()
386 386 if ante_base == ana_base:
387   - return 1
  387 + return 1.0
388 388  
389 389 ante_links = set()
390 390 if ante_base in conf.TITLE2LINKS:
... ... @@ -395,16 +395,16 @@ def wikipedia_link(ante, ana):
395 395 ana_links = conf.TITLE2LINKS[ana_base]
396 396  
397 397 if ana_base in ante_links or ante_base in ana_links:
398   - return 1
  398 + return 1.0
399 399  
400   - return 0
  400 + return 0.0
401 401  
402 402  
403 403 def wikipedia_mutual_link(ante, ana):
404 404 ante_base = ante.lemmatized_text.lower()
405 405 ana_base = ana.lemmatized_text.lower()
406 406 if ante_base == ana_base:
407   - return 1
  407 + return 1.0
408 408  
409 409 ante_links = set()
410 410 if ante_base in conf.TITLE2LINKS:
... ... @@ -415,52 +415,52 @@ def wikipedia_mutual_link(ante, ana):
415 415 ana_links = conf.TITLE2LINKS[ana_base]
416 416  
417 417 if ana_base in ante_links and ante_base in ana_links:
418   - return 1
  418 + return 1.0
419 419  
420   - return 0
  420 + return 0.0
421 421  
422 422  
423 423 def wikipedia_redirect(ante, ana):
424 424 ante_base = ante.lemmatized_text.lower()
425 425 ana_base = ana.lemmatized_text.lower()
426 426 if ante_base == ana_base:
427   - return 1
  427 + return 1.0
428 428  
429 429 if ante_base in conf.TITLE2REDIRECT and conf.TITLE2REDIRECT[ante_base] == ana_base:
430   - return 1
  430 + return 1.0
431 431  
432 432 if ana_base in conf.TITLE2REDIRECT and conf.TITLE2REDIRECT[ana_base] == ante_base:
433   - return 1
  433 + return 1.0
434 434  
435   - return 0
  435 + return 0.0
436 436  
437 437  
438 438 def samesent_anapron_antefirstinpar(ante, ana):
439 439 if same_sentence(ante, ana) and is_zero_or_pronoun(ana) and ante.first_in_paragraph:
440   - return 1
441   - return 0
  440 + return 1.0
  441 + return 0.0
442 442  
443 443  
444 444 def samesent_antefirstinpar_personnumbermatch(ante, ana):
445 445 if (same_sentence(ante, ana) and ante.first_in_paragraph
446 446 and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
447   - return 1
448   - return 0
  447 + return 1.0
  448 + return 0.0
449 449  
450 450  
451 451 def adjsent_anapron_adjmen_personnumbermatch(ante, ana):
452 452 if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana)
453 453 and ana.position_in_mentions - ante.position_in_mentions == 1
454 454 and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
455   - return 1
456   - return 0
  455 + return 1.0
  456 + return 0.0
457 457  
458 458  
459 459 def adjsent_anapron_adjmen(ante, ana):
460 460 if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana)
461 461 and ana.position_in_mentions - ante.position_in_mentions == 1):
462   - return 1
463   - return 0
  462 + return 1.0
  463 + return 0.0
464 464  
465 465  
466 466 # supporting functions
... ... @@ -523,8 +523,8 @@ def check_one_way_acronym(acronym, expression):
523 523 if expr2:
524 524 initials += expr2[0].upper()
525 525 if acronym == initials:
526   - return 1
527   - return 0
  526 + return 1.0
  527 + return 0.0
528 528  
529 529  
530 530 def get_abbrev(mention):
... ...
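
Switching every feature from int 0/1 to float 0.0/1.0 keeps the vectors uniformly typed before they reach numpy.asarray(..., dtype=numpy.float32) in the resolvers. A rough illustration of how a few of these features combine into a pair-vector fragment (assumes the corneferencer package is importable); the actual composition lives in the vectors module, which is not part of this diff:

    import numpy

    from corneferencer.resolvers import features

    def pair_fragment(ante, ana):
        # Illustration only -- not the real vectors.get_pair_features implementation.
        vec = []
        vec.append(features.head_match(ante, ana))            # 1.0 / 0.0
        vec.append(features.exact_match(ante, ana))           # 1.0 / 0.0
        vec.extend(features.agreement(ante, ana, 'gender'))   # 3-element one-hot of floats
        vec.extend(features.distances_vec(ante, ana))         # bucketed distance features
        return numpy.asarray([vec], dtype=numpy.float32)
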
corneferencer/resolvers/resolve.py
1 1 import numpy
2 2  
3   -from conf import NEURAL_MODEL
4   -from corneferencer.resolvers import features
5   -from corneferencer.resolvers.vectors import get_pair_features, get_pair_vector
  3 +from corneferencer.resolvers import features, vectors
6 4  
7 5  
8   -def siamese(text, threshold):
  6 +def siamese(text, threshold, neural_model):
9 7 last_set_id = 0
10 8 for i, ana in enumerate(text.mentions):
11 9 if i > 0:
12 10 for ante in reversed(text.mentions[:i]):
13 11 if not features.pair_intersect(ante, ana):
14   - pair_features = get_pair_features(ante, ana)
  12 + pair_features = vectors.get_pair_features(ante, ana)
15 13  
16 14 ante_vec = []
17 15 ante_vec.extend(ante.features)
... ... @@ -23,7 +21,7 @@ def siamese(text, threshold):
23 21 ana_vec.extend(pair_features)
24 22 ana_sample = numpy.asarray([ana_vec], dtype=numpy.float32)
25 23  
26   - prediction = NEURAL_MODEL.predict([ante_sample, ana_sample])[0]
  24 + prediction = neural_model.predict([ante_sample, ana_sample])[0]
27 25  
28 26 if prediction < threshold:
29 27 if ante.set:
... ... @@ -37,7 +35,7 @@ def siamese(text, threshold):
37 35  
38 36  
39 37 # incremental resolve algorithm
40   -def incremental(text, threshold):
  38 +def incremental(text, threshold, neural_model):
41 39 last_set_id = 0
42 40 for i, ana in enumerate(text.mentions):
43 41 if i > 0:
... ... @@ -45,9 +43,9 @@ def incremental(text, threshold):
45 43 best_ante = None
46 44 for ante in text.mentions[:i]:
47 45 if not features.pair_intersect(ante, ana):
48   - pair_vec = get_pair_vector(ante, ana)
  46 + pair_vec = vectors.get_pair_vector(ante, ana)
49 47 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
50   - prediction = NEURAL_MODEL.predict(sample)[0]
  48 + prediction = neural_model.predict(sample)[0]
51 49 if prediction > threshold and prediction >= best_prediction:
52 50 best_prediction = prediction
53 51 best_ante = ante
... ... @@ -62,86 +60,21 @@ def incremental(text, threshold):
62 60  
63 61  
64 62 # all2all resolve algorithm
65   -def all2all_debug(text, threshold):
66   - last_set_id = 0
67   - for pos1, mnt1 in enumerate(text.mentions):
68   - best_prediction = 0.0
69   - best_link = None
70   - for pos2, mnt2 in enumerate(text.mentions):
71   - if (mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2):
72   - ante = mnt1
73   - ana = mnt2
74   - if pos2 < pos1:
75   - ante = mnt2
76   - ana = mnt1
77   - pair_vec = get_pair_vector(ante, ana)
78   - sample = numpy.asarray([pair_vec], dtype=numpy.float32)
79   - prediction = NEURAL_MODEL.predict(sample)[0]
80   - if prediction > threshold and prediction > best_prediction:
81   - best_prediction = prediction
82   - best_link = mnt2
83   - if best_link is not None:
84   - if best_link.set and not mnt1.set:
85   - mnt1.set = best_link.set
86   - elif best_link.set and mnt1.set:
87   - text.merge_sets(best_link.set, mnt1.set)
88   - elif not best_link.set and not mnt1.set:
89   - str_set_id = 'set_%d' % last_set_id
90   - best_link.set = str_set_id
91   - mnt1.set = str_set_id
92   - last_set_id += 1
93   -
94   -
95   -def all2all_v1(text, threshold):
96   - last_set_id = 0
97   - for pos1, mnt1 in enumerate(text.mentions):
98   - best_prediction = 0.0
99   - best_link = None
100   - for pos2, mnt2 in enumerate(text.mentions):
101   - if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
102   - and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
103   - ante = mnt1
104   - ana = mnt2
105   - if pos2 < pos1:
106   - ante = mnt2
107   - ana = mnt1
108   - pair_vec = get_pair_vector(ante, ana)
109   - sample = numpy.asarray([pair_vec], dtype=numpy.float32)
110   - prediction = NEURAL_MODEL.predict(sample)[0]
111   - if prediction > threshold and prediction > best_prediction:
112   - best_prediction = prediction
113   - best_link = mnt2
114   - if best_link is not None:
115   - if best_link.set and not mnt1.set:
116   - mnt1.set = best_link.set
117   - elif not best_link.set and mnt1.set:
118   - best_link.set = mnt1.set
119   - elif best_link.set and mnt1.set:
120   - text.merge_sets(best_link.set, mnt1.set)
121   - elif not best_link.set and not mnt1.set:
122   - str_set_id = 'set_%d' % last_set_id
123   - best_link.set = str_set_id
124   - mnt1.set = str_set_id
125   - last_set_id += 1
126   -
127   -
128   -def all2all(text, threshold):
  63 +def all2all(text, threshold, neural_model):
129 64 last_set_id = 0
130 65 sets = text.get_sets()
131 66 for pos1, mnt1 in enumerate(text.mentions):
132 67 best_prediction = 0.0
133 68 best_link = None
134 69 for pos2, mnt2 in enumerate(text.mentions):
135   - if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
136   - and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
  70 + if (pos2 > pos1 and
  71 + (mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
  72 + and not features.pair_intersect(mnt1, mnt2)):
137 73 ante = mnt1
138 74 ana = mnt2
139   - if pos2 < pos1:
140   - ante = mnt2
141   - ana = mnt1
142   - pair_vec = get_pair_vector(ante, ana)
  75 + pair_vec = vectors.get_pair_vector(ante, ana)
143 76 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
144   - prediction = NEURAL_MODEL.predict(sample)[0]
  77 + prediction = neural_model.predict(sample)[0]
145 78 if prediction > threshold and prediction > best_prediction:
146 79 best_prediction = prediction
147 80 best_link = mnt2
... ... @@ -163,12 +96,12 @@ def all2all(text, threshold):
163 96  
164 97  
165 98 # entity based resolve algorithm
166   -def entity_based(text, threshold):
  99 +def entity_based(text, threshold, neural_model):
167 100 sets = []
168 101 last_set_id = 0
169 102 for i, ana in enumerate(text.mentions):
170 103 if i > 0:
171   - best_fit = get_best_set(sets, ana, threshold)
  104 + best_fit = get_best_set(sets, ana, threshold, neural_model)
172 105 if best_fit is not None:
173 106 ana.set = best_fit['set_id']
174 107 best_fit['mentions'].append(ana)
... ... @@ -188,25 +121,25 @@ def entity_based(text, threshold):
188 121 remove_singletons(sets)
189 122  
190 123  
191   -def get_best_set(sets, ana, threshold):
  124 +def get_best_set(sets, ana, threshold, neural_model):
192 125 best_prediction = 0.0
193 126 best_set = None
194 127 for s in sets:
195   - accuracy = predict_set(s['mentions'], ana)
  128 + accuracy = predict_set(s['mentions'], ana, neural_model)
196 129 if accuracy > threshold and accuracy >= best_prediction:
197 130 best_prediction = accuracy
198 131 best_set = s
199 132 return best_set
200 133  
201 134  
202   -def predict_set(mentions, ana):
  135 +def predict_set(mentions, ana, neural_model):
203 136 prediction_sum = 0.0
204 137 for mnt in mentions:
205 138 prediction = 0.0
206 139 if not features.pair_intersect(mnt, ana):
207   - pair_vec = get_pair_vector(mnt, ana)
  140 + pair_vec = vectors.get_pair_vector(mnt, ana)
208 141 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
209   - prediction = NEURAL_MODEL.predict(sample)[0]
  142 + prediction = neural_model.predict(sample)[0]
210 143 prediction_sum += prediction
211 144 return prediction_sum / float(len(mentions))
212 145  
... ... @@ -218,15 +151,15 @@ def remove_singletons(sets):
218 151  
219 152  
220 153 # closest resolve algorithm
221   -def closest(text, threshold):
  154 +def closest(text, threshold, neural_model):
222 155 last_set_id = 0
223 156 for i, ana in enumerate(text.mentions):
224 157 if i > 0:
225 158 for ante in reversed(text.mentions[:i]):
226 159 if not features.pair_intersect(ante, ana):
227   - pair_vec = get_pair_vector(ante, ana)
  160 + pair_vec = vectors.get_pair_vector(ante, ana)
228 161 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
229   - prediction = NEURAL_MODEL.predict(sample)[0]
  162 + prediction = neural_model.predict(sample)[0]
230 163 if prediction > threshold:
231 164 if ante.set:
232 165 ana.set = ante.set
... ...
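
All resolvers now receive the Keras model as an argument instead of importing a global NEURAL_MODEL from conf, and all2all only scores pairs with pos2 > pos1, so each unordered pair is evaluated once with the earlier mention as the antecedent. A usage sketch of the new call pattern (the paths and the 0.5 threshold are illustrative; assumes the corneferencer directory is on sys.path, as main.py arranges):

    import conf
    import utils
    from inout import mmax
    from resolvers import resolve

    model = utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE,
                                          conf.NUMBER_OF_FEATURES,
                                          conf.NEURAL_MODEL_PATH)

    text = mmax.read('corpus/sample.mmax')
    resolve.all2all(text, 0.5, model)
    mmax.write('corpus/sample.mmax', 'out/sample', text)
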
corneferencer/utils.py
... ... @@ -7,7 +7,6 @@ import javaobj
7 7  
8 8 from keras.models import Sequential, Model
9 9 from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization, Lambda
10   -from keras.optimizers import RMSprop, Adam
11 10 from keras import backend as K
12 11  
13 12  
... ...