Commit c29c36cde4fa9ea17a625fa2b0e8ea5c1192ec25
1 parent db88d6e4
Add prepare data script and other minor improvements.
Showing 9 changed files with 321 additions and 239 deletions
conf.py
... | ... | @@ -30,7 +30,6 @@ W2V_MODEL_PATH = os.path.join(MAIN_PATH, 'models', W2V_MODEL_NAME) |
30 | 30 | W2V_MODEL = Word2Vec.load(W2V_MODEL_PATH) |
31 | 31 | |
32 | 32 | NEURAL_MODEL_PATH = os.path.join(MAIN_PATH, 'models', NEURAL_MODEL_NAME) |
33 | -NEURAL_MODEL = utils.initialize_neural_model(NEURAL_MODEL_ARCHITECTURE, NUMBER_OF_FEATURES, NEURAL_MODEL_PATH) | |
34 | 33 | |
35 | 34 | FREQ_LIST_PATH = os.path.join(MAIN_PATH, 'freq', FREQ_LIST_NAME) |
36 | 35 | FREQ_LIST = utils.load_freq_list(FREQ_LIST_PATH) |
... | ... |
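Note: with this change conf.py no longer builds the neural network as a side effect of being imported; callers construct it explicitly. A minimal sketch of the replacement flow, assuming the imports resolve the same way main.py sets up sys.path:

    import conf
    import utils

    # The network is now built explicitly by the caller instead of at conf import
    # time; the weights path may also come from the new --model option in main.py.
    model = utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE,
                                          conf.NUMBER_OF_FEATURES,
                                          conf.NEURAL_MODEL_PATH)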
corneferencer/entities.py
1 | -from corneferencer.resolvers.vectors import get_mention_features | |
1 | +from corneferencer.resolvers import vectors | |
2 | 2 | |
3 | 3 | |
4 | 4 | class Text: |
... | ... | @@ -19,6 +19,9 @@ class Text: |
19 | 19 | return mnt |
20 | 20 | return None |
21 | 21 | |
22 | + def get_mentions(self): | |
23 | + return self.mentions | |
24 | + | |
22 | 25 | def get_sets(self): |
23 | 26 | sets = {} |
24 | 27 | for mnt in self.mentions: |
... | ... | @@ -62,4 +65,4 @@ class Mention: |
62 | 65 | self.sentence_id = sentence_id |
63 | 66 | self.first_in_sentence = first_in_sentence |
64 | 67 | self.first_in_paragraph = first_in_paragraph |
65 | - self.features = get_mention_features(self) | |
68 | + self.features = vectors.get_mention_features(self) | |
... | ... |
corneferencer/inout/mmax.py
... | ... | @@ -3,11 +3,11 @@ import shutil |
3 | 3 | |
4 | 4 | from lxml import etree |
5 | 5 | |
6 | -from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST | |
6 | +import conf | |
7 | 7 | from corneferencer.entities import Mention, Text |
8 | 8 | |
9 | 9 | |
10 | -def read(inpath): | |
10 | +def read(inpath, clear_mentions=conf.CLEAR_INPUT): | |
11 | 11 | textname = os.path.splitext(os.path.basename(inpath))[0] |
12 | 12 | textdir = os.path.dirname(inpath) |
13 | 13 | |
... | ... | @@ -15,11 +15,11 @@ def read(inpath): |
15 | 15 | words_path = os.path.join(textdir, '%s_words.xml' % textname) |
16 | 16 | |
17 | 17 | text = Text(textname) |
18 | - text.mentions = read_mentions(mentions_path, words_path) | |
18 | + text.mentions = read_mentions(mentions_path, words_path, clear_mentions) | |
19 | 19 | return text |
20 | 20 | |
21 | 21 | |
22 | -def read_mentions(mentions_path, words_path): | |
22 | +def read_mentions(mentions_path, words_path, clear_mentions=conf.CLEAR_INPUT): | |
23 | 23 | mentions = [] |
24 | 24 | mentions_tree = etree.parse(mentions_path) |
25 | 25 | markables = mentions_tree.xpath("//ns:markable", |
... | ... | @@ -43,7 +43,7 @@ def read_mentions(mentions_path, words_path): |
43 | 43 | |
44 | 44 | head = get_head(head_orth, mention_words) |
45 | 45 | mention_group = '' |
46 | - if markable.attrib['mention_group'] != 'empty' and not CLEAR_INPUT: | |
46 | + if markable.attrib['mention_group'] != 'empty' and not clear_mentions: | |
47 | 47 | mention_group = markable.attrib['mention_group'] |
48 | 48 | mention = Mention(mnt_id=markable.attrib['id'], |
49 | 49 | text=span_to_text(span, words, 'orth'), |
... | ... | @@ -189,7 +189,7 @@ def get_prec_context(mention_start, words): |
189 | 189 | while context_start >= 0: |
190 | 190 | if not word_to_ignore(words[context_start]): |
191 | 191 | context.append(words[context_start]) |
192 | - if len(context) == CONTEXT: | |
192 | + if len(context) == conf.CONTEXT: | |
193 | 193 | break |
194 | 194 | context_start -= 1 |
195 | 195 | context.reverse() |
... | ... | @@ -222,7 +222,7 @@ def get_follow_context(mention_end, words): |
222 | 222 | while context_end < len(words): |
223 | 223 | if not word_to_ignore(words[context_end]): |
224 | 224 | context.append(words[context_end]) |
225 | - if len(context) == CONTEXT: | |
225 | + if len(context) == conf.CONTEXT: | |
226 | 226 | break |
227 | 227 | context_end += 1 |
228 | 228 | return context |
... | ... | @@ -349,9 +349,8 @@ def get_rarest_word(words): |
349 | 349 | rarest_word = words[0] |
350 | 350 | for i, word in enumerate(words): |
351 | 351 | word_freq = 0 |
352 | - if word['base'] in FREQ_LIST: | |
353 | - word_freq = FREQ_LIST[word['base']] | |
354 | - | |
352 | + if word['base'] in conf.FREQ_LIST: | |
353 | + word_freq = conf.FREQ_LIST[word['base']] | |
355 | 354 | if i == 0 or word_freq < min_freq: |
356 | 355 | min_freq = word_freq |
357 | 356 | rarest_word = word |
... | ... |
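read() and read_mentions() now take a clear_mentions flag (defaulting to conf.CLEAR_INPUT) instead of consulting the global directly, so training tooling can keep gold mention_group annotations. A usage sketch with an illustrative corpus path:

    from inout import mmax   # import style as in main.py / prepare_data.py

    # Default behaviour still follows conf.CLEAR_INPUT.
    text = mmax.read('corpus/doc01.mmax')

    # Keep existing mention_group assignments, as prepare_data.py does below.
    gold_text = mmax.read('corpus/doc01.mmax', clear_mentions=False)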
corneferencer/inout/tei.py
... | ... | @@ -4,7 +4,7 @@ import shutil |
4 | 4 | |
5 | 5 | from lxml import etree |
6 | 6 | |
7 | -from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST | |
7 | +import conf | |
8 | 8 | from corneferencer.entities import Mention, Text |
9 | 9 | from corneferencer.utils import eprint |
10 | 10 | |
... | ... | @@ -18,7 +18,7 @@ NSMAP = {None: TEI_NS, |
18 | 18 | 'xi': XI_NS} |
19 | 19 | |
20 | 20 | |
21 | -def read(inpath): | |
21 | +def read(inpath, clear_mentions=conf.CLEAR_INPUT): | |
22 | 22 | textname = os.path.basename(inpath) |
23 | 23 | |
24 | 24 | text = Text(textname) |
... | ... | @@ -49,7 +49,7 @@ def read(inpath): |
49 | 49 | eprint("Error: missing mentions layer for text %s!" % textname) |
50 | 50 | return None |
51 | 51 | |
52 | - if os.path.exists(ann_coreference) and not CLEAR_INPUT: | |
52 | + if os.path.exists(ann_coreference) and not clear_mentions: | |
53 | 53 | add_coreference_layer(ann_coreference, text) |
54 | 54 | |
55 | 55 | return text |
... | ... | @@ -215,6 +215,9 @@ def get_mention(mention, mnt_id, segments, segments_ids, paragraph_id, sentence_ |
215 | 215 | semh_id = get_fval(f).split('#')[-1] |
216 | 216 | semh = segments[semh_id] |
217 | 217 | |
218 | + if len(mnt_segments) == 0: | |
219 | + mnt_segments.append(semh) | |
220 | + | |
218 | 221 | (sent_segments, prec_context, follow_context, |
219 | 222 | first_in_sentence, first_in_paragraph) = get_context(mnt_segments, segments, segments_ids) |
220 | 223 | |
... | ... | @@ -272,7 +275,7 @@ def get_prec_context(mention_start, segments, segments_ids): |
272 | 275 | while context_start >= 0: |
273 | 276 | if not word_to_ignore(segments[segments_ids[context_start]]): |
274 | 277 | context.append(segments[segments_ids[context_start]]) |
275 | - if len(context) == CONTEXT: | |
278 | + if len(context) == conf.CONTEXT: | |
276 | 279 | break |
277 | 280 | context_start -= 1 |
278 | 281 | context.reverse() |
... | ... | @@ -285,7 +288,7 @@ def get_follow_context(mention_end, segments, segments_ids): |
285 | 288 | while context_end < len(segments): |
286 | 289 | if not word_to_ignore(segments[segments_ids[context_end]]): |
287 | 290 | context.append(segments[segments_ids[context_end]]) |
288 | - if len(context) == CONTEXT: | |
291 | + if len(context) == conf.CONTEXT: | |
289 | 292 | break |
290 | 293 | context_end += 1 |
291 | 294 | return context |
... | ... | @@ -341,8 +344,8 @@ def get_rarest_word(words): |
341 | 344 | rarest_word = words[0] |
342 | 345 | for i, word in enumerate(words): |
343 | 346 | word_freq = 0 |
344 | - if word['base'] in FREQ_LIST: | |
345 | - word_freq = FREQ_LIST[word['base']] | |
347 | + if word['base'] in conf.FREQ_LIST: | |
348 | + word_freq = conf.FREQ_LIST[word['base']] | |
346 | 349 | |
347 | 350 | if i == 0 or word_freq < min_freq: |
348 | 351 | min_freq = word_freq |
... | ... |
corneferencer/main.py
... | ... | @@ -4,9 +4,11 @@ import sys |
4 | 4 | from argparse import ArgumentParser |
5 | 5 | from natsort import natsorted |
6 | 6 | |
7 | -sys.path.append(os.path.abspath(os.path.join('..'))) | |
7 | +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
8 | + | |
8 | 9 | |
9 | 10 | import conf |
11 | +import utils | |
10 | 12 | from inout import mmax, tei |
11 | 13 | from inout.constants import INPUT_FORMATS |
12 | 14 | from resolvers import resolve |
... | ... | @@ -27,22 +29,25 @@ def main(): |
27 | 29 | if conf.NEURAL_MODEL_ARCHITECTURE == 'siamese': |
28 | 30 | resolver = conf.NEURAL_MODEL_ARCHITECTURE |
29 | 31 | eprint("Warning: Using %s resolver because of selected neural model architecture!" % |
30 | - conf.NEURAL_MODEL_ARCHITECTURE) | |
31 | - process_texts(args.input, args.output, args.format, resolver, args.threshold) | |
32 | + conf.NEURAL_MODEL_ARCHITECTURE) | |
33 | + process_texts(args.input, args.output, args.format, resolver, args.threshold, args.model) | |
32 | 34 | |
33 | 35 | |
34 | 36 | def parse_arguments(): |
35 | 37 | parser = ArgumentParser(description='Corneferencer: coreference resolver using neural nets.') |
38 | + parser.add_argument('-f', '--format', type=str, action='store', | |
39 | + dest='format', default=INPUT_FORMATS[0], | |
40 | + help='input format; default: %s; possibilities: %s' | |
41 | + % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS))) | |
36 | 42 | parser.add_argument('-i', '--input', type=str, action='store', |
37 | 43 | dest='input', default='', |
38 | 44 | help='input file or dir path') |
45 | + parser.add_argument('-m', '--model', type=str, action='store', | |
46 | + dest='model', default='', | |
47 | + help='neural model path; default: %s' % conf.NEURAL_MODEL_PATH) | |
39 | 48 | parser.add_argument('-o', '--output', type=str, action='store', |
40 | 49 | dest='output', default='', |
41 | 50 | help='output path; if not specified writes output to standard output') |
42 | - parser.add_argument('-f', '--format', type=str, action='store', | |
43 | - dest='format', default=INPUT_FORMATS[0], | |
44 | - help='input format; default: %s; possibilities: %s' | |
45 | - % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS))) | |
46 | 51 | parser.add_argument('-r', '--resolver', type=str, action='store', |
47 | 52 | dest='resolver', default=RESOLVERS[0], |
48 | 53 | help='resolve algorithm; default: %s; possibilities: %s' |
... | ... | @@ -55,16 +60,17 @@ def parse_arguments(): |
55 | 60 | return args |
56 | 61 | |
57 | 62 | |
58 | -def process_texts(inpath, outpath, informat, resolver, threshold): | |
63 | +def process_texts(inpath, outpath, informat, resolver, threshold, model_path): | |
64 | + model = utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE, conf.NUMBER_OF_FEATURES, model_path) | |
59 | 65 | if os.path.isdir(inpath): |
60 | - process_directory(inpath, outpath, informat, resolver, threshold) | |
66 | + process_directory(inpath, outpath, informat, resolver, threshold, model) | |
61 | 67 | elif os.path.isfile(inpath): |
62 | - process_text(inpath, outpath, informat, resolver, threshold) | |
68 | + process_text(inpath, outpath, informat, resolver, threshold, model) | |
63 | 69 | else: |
64 | 70 | eprint("Error: Specified input does not exist!") |
65 | 71 | |
66 | 72 | |
67 | -def process_directory(inpath, outpath, informat, resolver, threshold): | |
73 | +def process_directory(inpath, outpath, informat, resolver, threshold, model): | |
68 | 74 | inpath = os.path.abspath(inpath) |
69 | 75 | outpath = os.path.abspath(outpath) |
70 | 76 | |
... | ... | @@ -75,38 +81,38 @@ def process_directory(inpath, outpath, informat, resolver, threshold): |
75 | 81 | textname = os.path.splitext(os.path.basename(filename))[0] |
76 | 82 | textoutput = os.path.join(outpath, textname) |
77 | 83 | textinput = os.path.join(inpath, filename) |
78 | - process_text(textinput, textoutput, informat, resolver, threshold) | |
84 | + process_text(textinput, textoutput, informat, resolver, threshold, model) | |
79 | 85 | |
80 | 86 | |
81 | -def process_text(inpath, outpath, informat, resolver, threshold): | |
87 | +def process_text(inpath, outpath, informat, resolver, threshold, model): | |
82 | 88 | basename = os.path.basename(inpath) |
83 | 89 | if informat == 'mmax' and basename.endswith('.mmax'): |
84 | 90 | print (basename) |
85 | 91 | text = mmax.read(inpath) |
86 | 92 | if resolver == 'incremental': |
87 | - resolve.incremental(text, threshold) | |
93 | + resolve.incremental(text, threshold, model) | |
88 | 94 | elif resolver == 'entity_based': |
89 | - resolve.entity_based(text, threshold) | |
95 | + resolve.entity_based(text, threshold, model) | |
90 | 96 | elif resolver == 'closest': |
91 | - resolve.closest(text, threshold) | |
97 | + resolve.closest(text, threshold, model) | |
92 | 98 | elif resolver == 'siamese': |
93 | - resolve.siamese(text, threshold) | |
99 | + resolve.siamese(text, threshold, model) | |
94 | 100 | elif resolver == 'all2all': |
95 | - resolve.all2all(text, threshold) | |
101 | + resolve.all2all(text, threshold, model) | |
96 | 102 | mmax.write(inpath, outpath, text) |
97 | 103 | elif informat == 'tei': |
98 | 104 | print (basename) |
99 | 105 | text = tei.read(inpath) |
100 | 106 | if resolver == 'incremental': |
101 | - resolve.incremental(text, threshold) | |
107 | + resolve.incremental(text, threshold, model) | |
102 | 108 | elif resolver == 'entity_based': |
103 | - resolve.entity_based(text, threshold) | |
109 | + resolve.entity_based(text, threshold, model) | |
104 | 110 | elif resolver == 'closest': |
105 | - resolve.closest(text, threshold) | |
111 | + resolve.closest(text, threshold, model) | |
106 | 112 | elif resolver == 'siamese': |
107 | - resolve.siamese(text, threshold) | |
113 | + resolve.siamese(text, threshold, model) | |
108 | 114 | elif resolver == 'all2all': |
109 | - resolve.all2all(text, threshold) | |
115 | + resolve.all2all(text, threshold, model) | |
110 | 116 | tei.write(inpath, outpath, text) |
111 | 117 | |
112 | 118 | |
... | ... |
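The argument parser gains a -m/--model option, and process_texts now builds the network once from that path and threads it down through process_directory and process_text to every resolver call. A hypothetical direct call mirroring the flags (paths, resolver and threshold value are illustrative; the threshold flag itself is defined outside this hunk):

    # Equivalent of: -f mmax -i in_dir -o out_dir -r incremental -m models/model.h5
    # The model is constructed once per run and reused for every processed text.
    process_texts('in_dir', 'out_dir', 'mmax', 'incremental', 0.5, 'models/model.h5')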
corneferencer/prepare_data.py
0 → 100644
1 | +# -*- coding: utf-8 -*- | |
2 | + | |
3 | +import codecs | |
4 | +import os | |
5 | +import random | |
6 | +import sys | |
7 | + | |
8 | +from itertools import combinations | |
9 | +from argparse import ArgumentParser | |
10 | +from natsort import natsorted | |
11 | + | |
12 | +sys.path.append(os.path.abspath(os.path.join('..'))) | |
13 | + | |
14 | +from inout import mmax, tei | |
15 | +from inout.constants import INPUT_FORMATS | |
16 | +from utils import eprint | |
17 | +from corneferencer.resolvers import vectors | |
18 | + | |
19 | + | |
20 | +POS_COUNT = 0 | |
21 | +NEG_COUNT = 0 | |
22 | + | |
23 | + | |
24 | +def main(): | |
25 | + args = parse_arguments() | |
26 | + if not args.input: | |
27 | + eprint("Error: Input file(s) not specified!") | |
28 | + elif args.format not in INPUT_FORMATS: | |
29 | + eprint("Error: Unknown input file format!") | |
30 | + else: | |
31 | + process_texts(args.input, args.output, args.format, args.proportion) | |
32 | + | |
33 | + | |
34 | +def parse_arguments(): | |
35 | + parser = ArgumentParser(description='Corneferencer: data preparator for neural nets training.') | |
36 | + parser.add_argument('-i', '--input', type=str, action='store', | |
37 | + dest='input', default='', | |
38 | + help='input dir path') | |
39 | + parser.add_argument('-o', '--output', type=str, action='store', | |
40 | + dest='output', default='', | |
41 | + help='output path; if not specified writes output to standard output') | |
42 | + parser.add_argument('-f', '--format', type=str, action='store', | |
43 | + dest='format', default=INPUT_FORMATS[0], | |
44 | + help='input format; default: %s; possibilities: %s' | |
45 | + % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS))) | |
46 | + parser.add_argument('-p', '--proportion', type=int, action='store', | |
47 | + dest='proportion', default=5, | |
48 | + help='negative examples proportion; default: 5') | |
49 | + args = parser.parse_args() | |
50 | + return args | |
51 | + | |
52 | + | |
53 | +def process_texts(inpath, outpath, informat, proportion): | |
54 | + if os.path.isdir(inpath): | |
55 | + process_directory(inpath, outpath, informat, proportion) | |
56 | + else: | |
57 | + eprint("Error: Specified input does not exist or is not a directory!") | |
58 | + | |
59 | + | |
60 | +def process_directory(inpath, outpath, informat, proportion): | |
61 | + inpath = os.path.abspath(inpath) | |
62 | + outpath = os.path.abspath(outpath) | |
63 | + | |
64 | + try: | |
65 | + create_data_vectors(inpath, outpath, informat, proportion) | |
66 | + finally: | |
67 | + print ('Positives: ', POS_COUNT) | |
68 | + print ('Negatives: ', NEG_COUNT) | |
69 | + | |
70 | + | |
71 | +def create_data_vectors(inpath, outpath, informat, proportion): | |
72 | + features_file = codecs.open(outpath, 'w', 'utf-8') | |
73 | + | |
74 | + files = os.listdir(inpath) | |
75 | + files = natsorted(files) | |
76 | + | |
77 | + for filename in files: | |
78 | + textname = os.path.splitext(os.path.basename(filename))[0] | |
79 | + textinput = os.path.join(inpath, filename) | |
80 | + | |
81 | + print ('Processing text: %s' % textname) | |
82 | + text = None | |
83 | + if informat == 'mmax' and filename.endswith('.mmax'): | |
84 | + text = mmax.read(textinput, False) | |
85 | + elif informat == 'tei': | |
86 | + text = tei.read(textinput, False) | |
87 | + | |
88 | + positives, negatives = diff_mentions(text, proportion) | |
89 | + write_features(features_file, positives, negatives) | |
90 | + | |
91 | + | |
92 | +def diff_mentions(text, proportion): | |
93 | + sets = text.get_sets() | |
94 | + all_mentions = text.get_mentions() | |
95 | + positives = get_positives(sets) | |
96 | + positives, negatives = get_negatives_and_update_positives(all_mentions, positives, proportion) | |
97 | + return positives, negatives | |
98 | + | |
99 | + | |
100 | +def get_positives(sets): | |
101 | + positives = [] | |
102 | + for set_id in sets: | |
103 | + coref_set = sets[set_id] | |
104 | + positives.extend(list(combinations(coref_set, 2))) | |
105 | + return positives | |
106 | + | |
107 | + | |
108 | +def get_negatives_and_update_positives(all_mentions, positives, proportion): | |
109 | + all_pairs = list(combinations(all_mentions, 2)) | |
110 | + | |
111 | + all_pairs = set(all_pairs) | |
112 | + negatives = [pair for pair in all_pairs if pair not in positives] | |
113 | + samples_count = proportion * len(positives) | |
114 | + if samples_count > len(negatives): | |
115 | + samples_count = len(negatives) | |
116 | + if proportion == 1: | |
117 | + positives = random.sample(set(positives), samples_count) | |
118 | + print (u'Więcej przypadków pozytywnych niż negatywnych!') | |
119 | + negatives = random.sample(set(negatives), samples_count) | |
120 | + return positives, negatives | |
121 | + | |
122 | + | |
123 | +def write_features(features_file, positives, negatives): | |
124 | + global POS_COUNT | |
125 | + POS_COUNT += len(positives) | |
126 | + for pair in positives: | |
127 | + vector = vectors.get_pair_vector(pair[0], pair[1]) | |
128 | + vector.append(1.0) | |
129 | + features_file.write(u'%s\n' % u'\t'.join([str(feature) for feature in vector])) | |
130 | + | |
131 | + global NEG_COUNT | |
132 | + NEG_COUNT += len(negatives) | |
133 | + for pair in negatives: | |
134 | + vector = vectors.get_pair_vector(pair[0], pair[1]) | |
135 | + vector.append(0.0) | |
136 | + features_file.write(u'%s\n' % u'\t'.join([str(feature) for feature in vector])) | |
137 | + | |
138 | + | |
139 | +if __name__ == '__main__': | |
140 | + main() | |
... | ... |
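The new script walks a directory of annotated texts, takes every within-set mention pair from text.get_sets() as a positive example, samples up to proportion negatives per positive, and writes one tab-separated feature vector per pair with a trailing 1.0/0.0 label. A hypothetical run (paths are illustrative):

    # Programmatic equivalent of: -i train_corpus -o train_vectors.tsv -f mmax -p 5
    process_texts('train_corpus', 'train_vectors.tsv', 'mmax', 5)

    # Each line of train_vectors.tsv: f1<TAB>f2<TAB>...<TAB>label, label 1.0 or 0.0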
corneferencer/resolvers/features.py
... | ... | @@ -72,97 +72,97 @@ def sentence_vec(mention): |
72 | 72 | |
73 | 73 | |
74 | 74 | def mention_type(mention): |
75 | - type_vec = [0] * 4 | |
75 | + type_vec = [0.0] * 4 | |
76 | 76 | if mention.head is None: |
77 | - type_vec[3] = 1 | |
77 | + type_vec[3] = 1.0 | |
78 | 78 | elif mention.head['ctag'] in constants.NOUN_TAGS: |
79 | - type_vec[0] = 1 | |
79 | + type_vec[0] = 1.0 | |
80 | 80 | elif mention.head['ctag'] in constants.PPRON_TAGS: |
81 | - type_vec[1] = 1 | |
81 | + type_vec[1] = 1.0 | |
82 | 82 | elif mention.head['ctag'] in constants.ZERO_TAGS: |
83 | - type_vec[2] = 1 | |
83 | + type_vec[2] = 1.0 | |
84 | 84 | else: |
85 | - type_vec[3] = 1 | |
85 | + type_vec[3] = 1.0 | |
86 | 86 | return type_vec |
87 | 87 | |
88 | 88 | |
89 | 89 | def is_first_second_person(mention): |
90 | 90 | if mention.head is None: |
91 | - return 0 | |
91 | + return 0.0 | |
92 | 92 | if mention.head['person'] in constants.FIRST_SECOND_PERSON: |
93 | - return 1 | |
94 | - return 0 | |
93 | + return 1.0 | |
94 | + return 0.0 | |
95 | 95 | |
96 | 96 | |
97 | 97 | def is_demonstrative(mention): |
98 | 98 | if mention.words[0]['base'].lower() in constants.INDICATIVE_PRONS_BASES: |
99 | - return 1 | |
100 | - return 0 | |
99 | + return 1.0 | |
100 | + return 0.0 | |
101 | 101 | |
102 | 102 | |
103 | 103 | def is_demonstrative_nominal(mention): |
104 | 104 | if mention.head is None: |
105 | - return 0 | |
105 | + return 0.0 | |
106 | 106 | if is_demonstrative(mention) and mention.head['ctag'] in constants.NOUN_TAGS: |
107 | - return 1 | |
108 | - return 0 | |
107 | + return 1.0 | |
108 | + return 0.0 | |
109 | 109 | |
110 | 110 | |
111 | 111 | def is_demonstrative_pronoun(mention): |
112 | 112 | if mention.head is None: |
113 | - return 0 | |
113 | + return 0.0 | |
114 | 114 | if (is_demonstrative(mention) and |
115 | 115 | (mention.head['ctag'] in constants.PPRON_TAGS or mention.head['ctag'] in constants.ZERO_TAGS)): |
116 | - return 1 | |
117 | - return 0 | |
116 | + return 1.0 | |
117 | + return 0.0 | |
118 | 118 | |
119 | 119 | |
120 | 120 | def is_refl_pronoun(mention): |
121 | 121 | if mention.head is None: |
122 | - return 0 | |
122 | + return 0.0 | |
123 | 123 | if mention.head['ctag'] in constants.SIEBIE_TAGS: |
124 | - return 1 | |
125 | - return 0 | |
124 | + return 1.0 | |
125 | + return 0.0 | |
126 | 126 | |
127 | 127 | |
128 | 128 | def is_first_in_sentence(mention): |
129 | 129 | if mention.first_in_sentence: |
130 | - return 1 | |
131 | - return 0 | |
130 | + return 1.0 | |
131 | + return 0.0 | |
132 | 132 | |
133 | 133 | |
134 | 134 | def is_zero_or_pronoun(mention): |
135 | 135 | if mention.head is None: |
136 | - return 0 | |
136 | + return 0.0 | |
137 | 137 | if mention.head['ctag'] in constants.PPRON_TAGS or mention.head['ctag'] in constants.ZERO_TAGS: |
138 | - return 1 | |
139 | - return 0 | |
138 | + return 1.0 | |
139 | + return 0.0 | |
140 | 140 | |
141 | 141 | |
142 | 142 | def head_contains_digit(mention): |
143 | 143 | _digits = re.compile('\d') |
144 | 144 | if _digits.search(mention.head_orth): |
145 | - return 1 | |
146 | - return 0 | |
145 | + return 1.0 | |
146 | + return 0.0 | |
147 | 147 | |
148 | 148 | |
149 | 149 | def mention_contains_digit(mention): |
150 | 150 | _digits = re.compile('\d') |
151 | 151 | if _digits.search(mention.text): |
152 | - return 1 | |
153 | - return 0 | |
152 | + return 1.0 | |
153 | + return 0.0 | |
154 | 154 | |
155 | 155 | |
156 | 156 | def contains_letter(mention): |
157 | 157 | if any(c.isalpha() for c in mention.text): |
158 | - return 1 | |
159 | - return 0 | |
158 | + return 1.0 | |
159 | + return 0.0 | |
160 | 160 | |
161 | 161 | |
162 | 162 | def post_modified(mention): |
163 | 163 | if mention.head_orth != mention.words[-1]['orth']: |
164 | - return 1 | |
165 | - return 0 | |
164 | + return 1.0 | |
165 | + return 0.0 | |
166 | 166 | |
167 | 167 | |
168 | 168 | # pair features |
... | ... | @@ -171,20 +171,20 @@ def distances_vec(ante, ana): |
171 | 171 | |
172 | 172 | mnts_intersect = pair_intersect(ante, ana) |
173 | 173 | |
174 | - words_dist = [0] * 11 | |
174 | + words_dist = [0.0] * 11 | |
175 | 175 | words_bucket = 0 |
176 | - if mnts_intersect != 1: | |
176 | + if mnts_intersect != 1.0: | |
177 | 177 | words_bucket = get_distance_bucket(ana.start_in_words - ante.end_in_words) |
178 | - words_dist[words_bucket] = 1 | |
178 | + words_dist[words_bucket] = 1.0 | |
179 | 179 | vec.extend(words_dist) |
180 | 180 | |
181 | - mentions_dist = [0] * 11 | |
181 | + mentions_dist = [0.0] * 11 | |
182 | 182 | mentions_bucket = 0 |
183 | - if mnts_intersect != 1: | |
183 | + if mnts_intersect != 1.0: | |
184 | 184 | mentions_bucket = get_distance_bucket(ana.position_in_mentions - ante.position_in_mentions) |
185 | 185 | if words_bucket == 10: |
186 | 186 | mentions_bucket = 10 |
187 | - mentions_dist[mentions_bucket] = 1 | |
187 | + mentions_dist[mentions_bucket] = 1.0 | |
188 | 188 | vec.extend(mentions_dist) |
189 | 189 | |
190 | 190 | vec.append(mnts_intersect) |
... | ... | @@ -196,45 +196,45 @@ def pair_intersect(ante, ana): |
196 | 196 | for ante_word in ante.words: |
197 | 197 | for ana_word in ana.words: |
198 | 198 | if ana_word['id'] == ante_word['id']: |
199 | - return 1 | |
200 | - return 0 | |
199 | + return 1.0 | |
200 | + return 0.0 | |
201 | 201 | |
202 | 202 | |
203 | 203 | def head_match(ante, ana): |
204 | 204 | if ante.head_orth.lower() == ana.head_orth.lower(): |
205 | - return 1 | |
206 | - return 0 | |
205 | + return 1.0 | |
206 | + return 0.0 | |
207 | 207 | |
208 | 208 | |
209 | 209 | def exact_match(ante, ana): |
210 | 210 | if ante.text.lower() == ana.text.lower(): |
211 | - return 1 | |
212 | - return 0 | |
211 | + return 1.0 | |
212 | + return 0.0 | |
213 | 213 | |
214 | 214 | |
215 | 215 | def base_match(ante, ana): |
216 | 216 | if ante.lemmatized_text.lower() == ana.lemmatized_text.lower(): |
217 | - return 1 | |
218 | - return 0 | |
217 | + return 1.0 | |
218 | + return 0.0 | |
219 | 219 | |
220 | 220 | |
221 | 221 | def ante_contains_rarest_from_ana(ante, ana): |
222 | 222 | ana_rarest = ana.rarest |
223 | 223 | for word in ante.words: |
224 | 224 | if word['base'] == ana_rarest['base']: |
225 | - return 1 | |
226 | - return 0 | |
225 | + return 1.0 | |
226 | + return 0.0 | |
227 | 227 | |
228 | 228 | |
229 | 229 | def agreement(ante, ana, tag_name): |
230 | - agr_vec = [0] * 3 | |
230 | + agr_vec = [0.0] * 3 | |
231 | 231 | if (ante.head is None or ana.head is None or |
232 | 232 | ante.head[tag_name] == 'unk' or ana.head[tag_name] == 'unk'): |
233 | - agr_vec[2] = 1 | |
233 | + agr_vec[2] = 1.0 | |
234 | 234 | elif ante.head[tag_name] == ana.head[tag_name]: |
235 | - agr_vec[0] = 1 | |
235 | + agr_vec[0] = 1.0 | |
236 | 236 | else: |
237 | - agr_vec[1] = 1 | |
237 | + agr_vec[1] = 1.0 | |
238 | 238 | return agr_vec |
239 | 239 | |
240 | 240 | |
... | ... | @@ -243,72 +243,72 @@ def is_acronym(ante, ana): |
243 | 243 | return check_one_way_acronym(ana.text, ante.text) |
244 | 244 | if ante.text.upper() == ante.text: |
245 | 245 | return check_one_way_acronym(ante.text, ana.text) |
246 | - return 0 | |
246 | + return 0.0 | |
247 | 247 | |
248 | 248 | |
249 | 249 | def same_sentence(ante, ana): |
250 | 250 | if ante.sentence_id == ana.sentence_id: |
251 | - return 1 | |
252 | - return 0 | |
251 | + return 1.0 | |
252 | + return 0.0 | |
253 | 253 | |
254 | 254 | |
255 | 255 | def neighbouring_sentence(ante, ana): |
256 | 256 | if ana.sentence_id - ante.sentence_id == 1: |
257 | - return 1 | |
258 | - return 0 | |
257 | + return 1.0 | |
258 | + return 0.0 | |
259 | 259 | |
260 | 260 | |
261 | 261 | def cousin_sentence(ante, ana): |
262 | 262 | if ana.sentence_id - ante.sentence_id == 2: |
263 | - return 1 | |
264 | - return 0 | |
263 | + return 1.0 | |
264 | + return 0.0 | |
265 | 265 | |
266 | 266 | |
267 | 267 | def distant_sentence(ante, ana): |
268 | 268 | if ana.sentence_id - ante.sentence_id > 2: |
269 | - return 1 | |
270 | - return 0 | |
269 | + return 1.0 | |
270 | + return 0.0 | |
271 | 271 | |
272 | 272 | |
273 | 273 | def same_paragraph(ante, ana): |
274 | 274 | if ante.paragraph_id == ana.paragraph_id: |
275 | - return 1 | |
276 | - return 0 | |
275 | + return 1.0 | |
276 | + return 0.0 | |
277 | 277 | |
278 | 278 | |
279 | 279 | def flat_gender_agreement(ante, ana): |
280 | - agr_vec = [0] * 3 | |
280 | + agr_vec = [0.0] * 3 | |
281 | 281 | if (ante.head is None or ana.head is None or |
282 | 282 | ante.head['gender'] == 'unk' or ana.head['gender'] == 'unk'): |
283 | - agr_vec[2] = 1 | |
283 | + agr_vec[2] = 1.0 | |
284 | 284 | elif (ante.head['gender'] == ana.head['gender'] or |
285 | 285 | (ante.head['gender'] in constants.MASCULINE_TAGS and ana.head['gender'] in constants.MASCULINE_TAGS)): |
286 | - agr_vec[0] = 1 | |
286 | + agr_vec[0] = 1.0 | |
287 | 287 | else: |
288 | - agr_vec[1] = 1 | |
288 | + agr_vec[1] = 1.0 | |
289 | 289 | return agr_vec |
290 | 290 | |
291 | 291 | |
292 | 292 | def left_match(ante, ana): |
293 | 293 | if (ante.text.lower().startswith(ana.text.lower()) or |
294 | 294 | ana.text.lower().startswith(ante.text.lower())): |
295 | - return 1 | |
296 | - return 0 | |
295 | + return 1.0 | |
296 | + return 0.0 | |
297 | 297 | |
298 | 298 | |
299 | 299 | def right_match(ante, ana): |
300 | 300 | if (ante.text.lower().endswith(ana.text.lower()) or |
301 | 301 | ana.text.lower().endswith(ante.text.lower())): |
302 | - return 1 | |
303 | - return 0 | |
302 | + return 1.0 | |
303 | + return 0.0 | |
304 | 304 | |
305 | 305 | |
306 | 306 | def abbrev2(ante, ana): |
307 | 307 | ante_abbrev = get_abbrev(ante) |
308 | 308 | ana_abbrev = get_abbrev(ana) |
309 | 309 | if ante.head_orth == ana_abbrev or ana.head_orth == ante_abbrev: |
310 | - return 1 | |
311 | - return 0 | |
310 | + return 1.0 | |
311 | + return 0.0 | |
312 | 312 | |
313 | 313 | |
314 | 314 | def string_kernel(ante, ana): |
... | ... | @@ -326,7 +326,7 @@ def head_string_kernel(ante, ana): |
326 | 326 | def wordnet_synonyms(ante, ana): |
327 | 327 | ante_synonyms = set() |
328 | 328 | if ante.head is None or ana.head is None: |
329 | - return 0 | |
329 | + return 0.0 | |
330 | 330 | |
331 | 331 | if ante.head['base'] in conf.LEMMA2SYNONYMS: |
332 | 332 | ante_synonyms = conf.LEMMA2SYNONYMS[ante.head['base']] |
... | ... | @@ -336,13 +336,13 @@ def wordnet_synonyms(ante, ana): |
336 | 336 | ana_synonyms = conf.LEMMA2SYNONYMS[ana.head['base']] |
337 | 337 | |
338 | 338 | if ana.head['base'] in ante_synonyms or ante.head['base'] in ana_synonyms: |
339 | - return 1 | |
340 | - return 0 | |
339 | + return 1.0 | |
340 | + return 0.0 | |
341 | 341 | |
342 | 342 | |
343 | 343 | def wordnet_ana_is_hypernym(ante, ana): |
344 | 344 | if ante.head is None or ana.head is None: |
345 | - return 0 | |
345 | + return 0.0 | |
346 | 346 | |
347 | 347 | ante_hypernyms = set() |
348 | 348 | if ante.head['base'] in conf.LEMMA2HYPERNYMS: |
... | ... | @@ -353,16 +353,16 @@ def wordnet_ana_is_hypernym(ante, ana): |
353 | 353 | ana_hypernyms = conf.LEMMA2HYPERNYMS[ana.head['base']] |
354 | 354 | |
355 | 355 | if not ante_hypernyms or not ana_hypernyms: |
356 | - return 0 | |
356 | + return 0.0 | |
357 | 357 | |
358 | 358 | if ana.head['base'] in ante_hypernyms: |
359 | - return 1 | |
360 | - return 0 | |
359 | + return 1.0 | |
360 | + return 0.0 | |
361 | 361 | |
362 | 362 | |
363 | 363 | def wordnet_ante_is_hypernym(ante, ana): |
364 | 364 | if ante.head is None or ana.head is None: |
365 | - return 0 | |
365 | + return 0.0 | |
366 | 366 | |
367 | 367 | ana_hypernyms = set() |
368 | 368 | if ana.head['base'] in conf.LEMMA2HYPERNYMS: |
... | ... | @@ -373,18 +373,18 @@ def wordnet_ante_is_hypernym(ante, ana): |
373 | 373 | ante_hypernyms = conf.LEMMA2HYPERNYMS[ante.head['base']] |
374 | 374 | |
375 | 375 | if not ante_hypernyms or not ana_hypernyms: |
376 | - return 0 | |
376 | + return 0.0 | |
377 | 377 | |
378 | 378 | if ante.head['base'] in ana_hypernyms: |
379 | - return 1 | |
380 | - return 0 | |
379 | + return 1.0 | |
380 | + return 0.0 | |
381 | 381 | |
382 | 382 | |
383 | 383 | def wikipedia_link(ante, ana): |
384 | 384 | ante_base = ante.lemmatized_text.lower() |
385 | 385 | ana_base = ana.lemmatized_text.lower() |
386 | 386 | if ante_base == ana_base: |
387 | - return 1 | |
387 | + return 1.0 | |
388 | 388 | |
389 | 389 | ante_links = set() |
390 | 390 | if ante_base in conf.TITLE2LINKS: |
... | ... | @@ -395,16 +395,16 @@ def wikipedia_link(ante, ana): |
395 | 395 | ana_links = conf.TITLE2LINKS[ana_base] |
396 | 396 | |
397 | 397 | if ana_base in ante_links or ante_base in ana_links: |
398 | - return 1 | |
398 | + return 1.0 | |
399 | 399 | |
400 | - return 0 | |
400 | + return 0.0 | |
401 | 401 | |
402 | 402 | |
403 | 403 | def wikipedia_mutual_link(ante, ana): |
404 | 404 | ante_base = ante.lemmatized_text.lower() |
405 | 405 | ana_base = ana.lemmatized_text.lower() |
406 | 406 | if ante_base == ana_base: |
407 | - return 1 | |
407 | + return 1.0 | |
408 | 408 | |
409 | 409 | ante_links = set() |
410 | 410 | if ante_base in conf.TITLE2LINKS: |
... | ... | @@ -415,52 +415,52 @@ def wikipedia_mutual_link(ante, ana): |
415 | 415 | ana_links = conf.TITLE2LINKS[ana_base] |
416 | 416 | |
417 | 417 | if ana_base in ante_links and ante_base in ana_links: |
418 | - return 1 | |
418 | + return 1.0 | |
419 | 419 | |
420 | - return 0 | |
420 | + return 0.0 | |
421 | 421 | |
422 | 422 | |
423 | 423 | def wikipedia_redirect(ante, ana): |
424 | 424 | ante_base = ante.lemmatized_text.lower() |
425 | 425 | ana_base = ana.lemmatized_text.lower() |
426 | 426 | if ante_base == ana_base: |
427 | - return 1 | |
427 | + return 1.0 | |
428 | 428 | |
429 | 429 | if ante_base in conf.TITLE2REDIRECT and conf.TITLE2REDIRECT[ante_base] == ana_base: |
430 | - return 1 | |
430 | + return 1.0 | |
431 | 431 | |
432 | 432 | if ana_base in conf.TITLE2REDIRECT and conf.TITLE2REDIRECT[ana_base] == ante_base: |
433 | - return 1 | |
433 | + return 1.0 | |
434 | 434 | |
435 | - return 0 | |
435 | + return 0.0 | |
436 | 436 | |
437 | 437 | |
438 | 438 | def samesent_anapron_antefirstinpar(ante, ana): |
439 | 439 | if same_sentence(ante, ana) and is_zero_or_pronoun(ana) and ante.first_in_paragraph: |
440 | - return 1 | |
441 | - return 0 | |
440 | + return 1.0 | |
441 | + return 0.0 | |
442 | 442 | |
443 | 443 | |
444 | 444 | def samesent_antefirstinpar_personnumbermatch(ante, ana): |
445 | 445 | if (same_sentence(ante, ana) and ante.first_in_paragraph |
446 | 446 | and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]): |
447 | - return 1 | |
448 | - return 0 | |
447 | + return 1.0 | |
448 | + return 0.0 | |
449 | 449 | |
450 | 450 | |
451 | 451 | def adjsent_anapron_adjmen_personnumbermatch(ante, ana): |
452 | 452 | if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana) |
453 | 453 | and ana.position_in_mentions - ante.position_in_mentions == 1 |
454 | 454 | and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]): |
455 | - return 1 | |
456 | - return 0 | |
455 | + return 1.0 | |
456 | + return 0.0 | |
457 | 457 | |
458 | 458 | |
459 | 459 | def adjsent_anapron_adjmen(ante, ana): |
460 | 460 | if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana) |
461 | 461 | and ana.position_in_mentions - ante.position_in_mentions == 1): |
462 | - return 1 | |
463 | - return 0 | |
462 | + return 1.0 | |
463 | + return 0.0 | |
464 | 464 | |
465 | 465 | |
466 | 466 | # supporting functions |
... | ... | @@ -523,8 +523,8 @@ def check_one_way_acronym(acronym, expression): |
523 | 523 | if expr2: |
524 | 524 | initials += expr2[0].upper() |
525 | 525 | if acronym == initials: |
526 | - return 1 | |
527 | - return 0 | |
526 | + return 1.0 | |
527 | + return 0.0 | |
528 | 528 | |
529 | 529 | |
530 | 530 | def get_abbrev(mention): |
... | ... |
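All boolean features and one-hot vectors in this module now return floats, so the assembled pair vector is homogeneous before the resolvers cast it with numpy.float32. A small illustration of the convention, assuming ante and ana are Mention instances from entities.py:

    from corneferencer.resolvers import features   # import path as in resolve.py

    vec = []
    vec.extend(features.agreement(ante, ana, 'number'))  # [1.0, 0.0, 0.0] when both heads agree
    vec.append(features.head_match(ante, ana))           # 1.0 or 0.0
    vec.append(features.same_sentence(ante, ana))        # 1.0 or 0.0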
corneferencer/resolvers/resolve.py
1 | 1 | import numpy |
2 | 2 | |
3 | -from conf import NEURAL_MODEL | |
4 | -from corneferencer.resolvers import features | |
5 | -from corneferencer.resolvers.vectors import get_pair_features, get_pair_vector | |
3 | +from corneferencer.resolvers import features, vectors | |
6 | 4 | |
7 | 5 | |
8 | -def siamese(text, threshold): | |
6 | +def siamese(text, threshold, neural_model): | |
9 | 7 | last_set_id = 0 |
10 | 8 | for i, ana in enumerate(text.mentions): |
11 | 9 | if i > 0: |
12 | 10 | for ante in reversed(text.mentions[:i]): |
13 | 11 | if not features.pair_intersect(ante, ana): |
14 | - pair_features = get_pair_features(ante, ana) | |
12 | + pair_features = vectors.get_pair_features(ante, ana) | |
15 | 13 | |
16 | 14 | ante_vec = [] |
17 | 15 | ante_vec.extend(ante.features) |
... | ... | @@ -23,7 +21,7 @@ def siamese(text, threshold): |
23 | 21 | ana_vec.extend(pair_features) |
24 | 22 | ana_sample = numpy.asarray([ana_vec], dtype=numpy.float32) |
25 | 23 | |
26 | - prediction = NEURAL_MODEL.predict([ante_sample, ana_sample])[0] | |
24 | + prediction = neural_model.predict([ante_sample, ana_sample])[0] | |
27 | 25 | |
28 | 26 | if prediction < threshold: |
29 | 27 | if ante.set: |
... | ... | @@ -37,7 +35,7 @@ def siamese(text, threshold): |
37 | 35 | |
38 | 36 | |
39 | 37 | # incremental resolve algorithm |
40 | -def incremental(text, threshold): | |
38 | +def incremental(text, threshold, neural_model): | |
41 | 39 | last_set_id = 0 |
42 | 40 | for i, ana in enumerate(text.mentions): |
43 | 41 | if i > 0: |
... | ... | @@ -45,9 +43,9 @@ def incremental(text, threshold): |
45 | 43 | best_ante = None |
46 | 44 | for ante in text.mentions[:i]: |
47 | 45 | if not features.pair_intersect(ante, ana): |
48 | - pair_vec = get_pair_vector(ante, ana) | |
46 | + pair_vec = vectors.get_pair_vector(ante, ana) | |
49 | 47 | sample = numpy.asarray([pair_vec], dtype=numpy.float32) |
50 | - prediction = NEURAL_MODEL.predict(sample)[0] | |
48 | + prediction = neural_model.predict(sample)[0] | |
51 | 49 | if prediction > threshold and prediction >= best_prediction: |
52 | 50 | best_prediction = prediction |
53 | 51 | best_ante = ante |
... | ... | @@ -62,86 +60,21 @@ def incremental(text, threshold): |
62 | 60 | |
63 | 61 | |
64 | 62 | # all2all resolve algorithm |
65 | -def all2all_debug(text, threshold): | |
66 | - last_set_id = 0 | |
67 | - for pos1, mnt1 in enumerate(text.mentions): | |
68 | - best_prediction = 0.0 | |
69 | - best_link = None | |
70 | - for pos2, mnt2 in enumerate(text.mentions): | |
71 | - if (mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2): | |
72 | - ante = mnt1 | |
73 | - ana = mnt2 | |
74 | - if pos2 < pos1: | |
75 | - ante = mnt2 | |
76 | - ana = mnt1 | |
77 | - pair_vec = get_pair_vector(ante, ana) | |
78 | - sample = numpy.asarray([pair_vec], dtype=numpy.float32) | |
79 | - prediction = NEURAL_MODEL.predict(sample)[0] | |
80 | - if prediction > threshold and prediction > best_prediction: | |
81 | - best_prediction = prediction | |
82 | - best_link = mnt2 | |
83 | - if best_link is not None: | |
84 | - if best_link.set and not mnt1.set: | |
85 | - mnt1.set = best_link.set | |
86 | - elif best_link.set and mnt1.set: | |
87 | - text.merge_sets(best_link.set, mnt1.set) | |
88 | - elif not best_link.set and not mnt1.set: | |
89 | - str_set_id = 'set_%d' % last_set_id | |
90 | - best_link.set = str_set_id | |
91 | - mnt1.set = str_set_id | |
92 | - last_set_id += 1 | |
93 | - | |
94 | - | |
95 | -def all2all_v1(text, threshold): | |
96 | - last_set_id = 0 | |
97 | - for pos1, mnt1 in enumerate(text.mentions): | |
98 | - best_prediction = 0.0 | |
99 | - best_link = None | |
100 | - for pos2, mnt2 in enumerate(text.mentions): | |
101 | - if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set) | |
102 | - and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)): | |
103 | - ante = mnt1 | |
104 | - ana = mnt2 | |
105 | - if pos2 < pos1: | |
106 | - ante = mnt2 | |
107 | - ana = mnt1 | |
108 | - pair_vec = get_pair_vector(ante, ana) | |
109 | - sample = numpy.asarray([pair_vec], dtype=numpy.float32) | |
110 | - prediction = NEURAL_MODEL.predict(sample)[0] | |
111 | - if prediction > threshold and prediction > best_prediction: | |
112 | - best_prediction = prediction | |
113 | - best_link = mnt2 | |
114 | - if best_link is not None: | |
115 | - if best_link.set and not mnt1.set: | |
116 | - mnt1.set = best_link.set | |
117 | - elif not best_link.set and mnt1.set: | |
118 | - best_link.set = mnt1.set | |
119 | - elif best_link.set and mnt1.set: | |
120 | - text.merge_sets(best_link.set, mnt1.set) | |
121 | - elif not best_link.set and not mnt1.set: | |
122 | - str_set_id = 'set_%d' % last_set_id | |
123 | - best_link.set = str_set_id | |
124 | - mnt1.set = str_set_id | |
125 | - last_set_id += 1 | |
126 | - | |
127 | - | |
128 | -def all2all(text, threshold): | |
63 | +def all2all(text, threshold, neural_model): | |
129 | 64 | last_set_id = 0 |
130 | 65 | sets = text.get_sets() |
131 | 66 | for pos1, mnt1 in enumerate(text.mentions): |
132 | 67 | best_prediction = 0.0 |
133 | 68 | best_link = None |
134 | 69 | for pos2, mnt2 in enumerate(text.mentions): |
135 | - if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set) | |
136 | - and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)): | |
70 | + if (pos2 > pos1 and | |
71 | + (mnt1.set != mnt2.set or not mnt1.set or not mnt2.set) | |
72 | + and not features.pair_intersect(mnt1, mnt2)): | |
137 | 73 | ante = mnt1 |
138 | 74 | ana = mnt2 |
139 | - if pos2 < pos1: | |
140 | - ante = mnt2 | |
141 | - ana = mnt1 | |
142 | - pair_vec = get_pair_vector(ante, ana) | |
75 | + pair_vec = vectors.get_pair_vector(ante, ana) | |
143 | 76 | sample = numpy.asarray([pair_vec], dtype=numpy.float32) |
144 | - prediction = NEURAL_MODEL.predict(sample)[0] | |
77 | + prediction = neural_model.predict(sample)[0] | |
145 | 78 | if prediction > threshold and prediction > best_prediction: |
146 | 79 | best_prediction = prediction |
147 | 80 | best_link = mnt2 |
... | ... | @@ -163,12 +96,12 @@ def all2all(text, threshold): |
163 | 96 | |
164 | 97 | |
165 | 98 | # entity based resolve algorithm |
166 | -def entity_based(text, threshold): | |
99 | +def entity_based(text, threshold, neural_model): | |
167 | 100 | sets = [] |
168 | 101 | last_set_id = 0 |
169 | 102 | for i, ana in enumerate(text.mentions): |
170 | 103 | if i > 0: |
171 | - best_fit = get_best_set(sets, ana, threshold) | |
104 | + best_fit = get_best_set(sets, ana, threshold, neural_model) | |
172 | 105 | if best_fit is not None: |
173 | 106 | ana.set = best_fit['set_id'] |
174 | 107 | best_fit['mentions'].append(ana) |
... | ... | @@ -188,25 +121,25 @@ def entity_based(text, threshold): |
188 | 121 | remove_singletons(sets) |
189 | 122 | |
190 | 123 | |
191 | -def get_best_set(sets, ana, threshold): | |
124 | +def get_best_set(sets, ana, threshold, neural_model): | |
192 | 125 | best_prediction = 0.0 |
193 | 126 | best_set = None |
194 | 127 | for s in sets: |
195 | - accuracy = predict_set(s['mentions'], ana) | |
128 | + accuracy = predict_set(s['mentions'], ana, neural_model) | |
196 | 129 | if accuracy > threshold and accuracy >= best_prediction: |
197 | 130 | best_prediction = accuracy |
198 | 131 | best_set = s |
199 | 132 | return best_set |
200 | 133 | |
201 | 134 | |
202 | -def predict_set(mentions, ana): | |
135 | +def predict_set(mentions, ana, neural_model): | |
203 | 136 | prediction_sum = 0.0 |
204 | 137 | for mnt in mentions: |
205 | 138 | prediction = 0.0 |
206 | 139 | if not features.pair_intersect(mnt, ana): |
207 | - pair_vec = get_pair_vector(mnt, ana) | |
140 | + pair_vec = vectors.get_pair_vector(mnt, ana) | |
208 | 141 | sample = numpy.asarray([pair_vec], dtype=numpy.float32) |
209 | - prediction = NEURAL_MODEL.predict(sample)[0] | |
142 | + prediction = neural_model.predict(sample)[0] | |
210 | 143 | prediction_sum += prediction |
211 | 144 | return prediction_sum / float(len(mentions)) |
212 | 145 | |
... | ... | @@ -218,15 +151,15 @@ def remove_singletons(sets): |
218 | 151 | |
219 | 152 | |
220 | 153 | # closest resolve algorithm |
221 | -def closest(text, threshold): | |
154 | +def closest(text, threshold, neural_model): | |
222 | 155 | last_set_id = 0 |
223 | 156 | for i, ana in enumerate(text.mentions): |
224 | 157 | if i > 0: |
225 | 158 | for ante in reversed(text.mentions[:i]): |
226 | 159 | if not features.pair_intersect(ante, ana): |
227 | - pair_vec = get_pair_vector(ante, ana) | |
160 | + pair_vec = vectors.get_pair_vector(ante, ana) | |
228 | 161 | sample = numpy.asarray([pair_vec], dtype=numpy.float32) |
229 | - prediction = NEURAL_MODEL.predict(sample)[0] | |
162 | + prediction = neural_model.predict(sample)[0] | |
230 | 163 | if prediction > threshold: |
231 | 164 | if ante.set: |
232 | 165 | ana.set = ante.set |
... | ... |
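Every resolver now receives the neural model as an argument instead of reading the NEURAL_MODEL global from conf, which also makes it straightforward to score the same input with different weights in one process. A sketch with hypothetical model paths and an illustrative threshold:

    import conf
    import utils                                   # imports assume main.py's sys.path setup
    from inout import mmax
    from corneferencer.resolvers import resolve

    for weights_path in ('models/model_a.h5', 'models/model_b.h5'):   # hypothetical paths
        model = utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE,
                                              conf.NUMBER_OF_FEATURES, weights_path)
        text = mmax.read('corpus/doc01.mmax')      # illustrative input
        resolve.incremental(text, 0.5, model)      # threshold value is illustrative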
corneferencer/utils.py