Commit c29c36cde4fa9ea17a625fa2b0e8ea5c1192ec25 (1 parent: db88d6e4)

Add prepare data script and other minor improvements.

Showing 9 changed files with 321 additions and 239 deletions.
conf.py
@@ -30,7 +30,6 @@ W2V_MODEL_PATH = os.path.join(MAIN_PATH, 'models', W2V_MODEL_NAME)
 W2V_MODEL = Word2Vec.load(W2V_MODEL_PATH)
 
 NEURAL_MODEL_PATH = os.path.join(MAIN_PATH, 'models', NEURAL_MODEL_NAME)
-NEURAL_MODEL = utils.initialize_neural_model(NEURAL_MODEL_ARCHITECTURE, NUMBER_OF_FEATURES, NEURAL_MODEL_PATH)
 
 FREQ_LIST_PATH = os.path.join(MAIN_PATH, 'freq', FREQ_LIST_NAME)
 FREQ_LIST = utils.load_freq_list(FREQ_LIST_PATH)
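
Note: with this change, importing conf no longer builds and loads the Keras model as an import side effect; callers now initialize it explicitly. A minimal sketch of the new calling pattern (the initialize_neural_model signature is taken from the line removed above and from main.py below):

    import conf
    import utils

    # Build the network for the configured architecture and load its weights
    # from an explicit path instead of a module-level constant.
    model = utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE,
                                          conf.NUMBER_OF_FEATURES,
                                          conf.NEURAL_MODEL_PATH)
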
corneferencer/entities.py
-from corneferencer.resolvers.vectors import get_mention_features
+from corneferencer.resolvers import vectors
 
 
 class Text:
@@ -19,6 +19,9 @@ class Text:
             return mnt
         return None
 
+    def get_mentions(self):
+        return self.mentions
+
     def get_sets(self):
         sets = {}
         for mnt in self.mentions:
@@ -62,4 +65,4 @@ class Mention:
         self.sentence_id = sentence_id
         self.first_in_sentence = first_in_sentence
         self.first_in_paragraph = first_in_paragraph
-        self.features = get_mention_features(self)
+        self.features = vectors.get_mention_features(self)
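
The import change defers the attribute lookup to call time, which is the usual way to break a circular import (here, presumably between entities and resolvers.vectors). A minimal illustration of the pattern, with hypothetical module names:

    # a.py
    import b               # binds only the module object; b.f is looked up later

    def g(x):
        return b.f(x)      # resolved at call time, after both modules are loaded

    # versus "from b import f", which fails if b is still mid-import
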
corneferencer/inout/mmax.py
@@ -3,11 +3,11 @@ import shutil
 
 from lxml import etree
 
-from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST
+import conf
 from corneferencer.entities import Mention, Text
 
 
-def read(inpath):
+def read(inpath, clear_mentions=conf.CLEAR_INPUT):
     textname = os.path.splitext(os.path.basename(inpath))[0]
     textdir = os.path.dirname(inpath)
 
@@ -15,11 +15,11 @@ def read(inpath):
     words_path = os.path.join(textdir, '%s_words.xml' % textname)
 
     text = Text(textname)
-    text.mentions = read_mentions(mentions_path, words_path)
+    text.mentions = read_mentions(mentions_path, words_path, clear_mentions)
     return text
 
 
-def read_mentions(mentions_path, words_path):
+def read_mentions(mentions_path, words_path, clear_mentions=conf.CLEAR_INPUT):
     mentions = []
     mentions_tree = etree.parse(mentions_path)
     markables = mentions_tree.xpath("//ns:markable",
@@ -43,7 +43,7 @@ def read_mentions(mentions_path, words_path):
 
         head = get_head(head_orth, mention_words)
         mention_group = ''
-        if markable.attrib['mention_group'] != 'empty' and not CLEAR_INPUT:
+        if markable.attrib['mention_group'] != 'empty' and not clear_mentions:
            mention_group = markable.attrib['mention_group']
         mention = Mention(mnt_id=markable.attrib['id'],
                           text=span_to_text(span, words, 'orth'),
@@ -189,7 +189,7 @@ def get_prec_context(mention_start, words):
     while context_start >= 0:
         if not word_to_ignore(words[context_start]):
             context.append(words[context_start])
-            if len(context) == CONTEXT:
+            if len(context) == conf.CONTEXT:
                 break
         context_start -= 1
     context.reverse()
@@ -222,7 +222,7 @@ def get_follow_context(mention_end, words):
     while context_end < len(words):
         if not word_to_ignore(words[context_end]):
             context.append(words[context_end])
-            if len(context) == CONTEXT:
+            if len(context) == conf.CONTEXT:
                 break
         context_end += 1
     return context
@@ -349,9 +349,8 @@ def get_rarest_word(words):
     rarest_word = words[0]
     for i, word in enumerate(words):
         word_freq = 0
-        if word['base'] in FREQ_LIST:
-            word_freq = FREQ_LIST[word['base']]
-
+        if word['base'] in conf.FREQ_LIST:
+            word_freq = conf.FREQ_LIST[word['base']]
         if i == 0 or word_freq < min_freq:
             min_freq = word_freq
             rarest_word = word
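
Threading clear_mentions through read() and read_mentions() lets a caller keep or discard existing mention_group annotations per call instead of relying on the global CLEAR_INPUT flag. A short usage sketch (the corpus path is illustrative):

    from inout import mmax

    # keep existing coreference clusters, e.g. when preparing training data
    text = mmax.read('corpus/doc01.mmax', clear_mentions=False)

    # default behaviour still follows conf.CLEAR_INPUT
    text = mmax.read('corpus/doc01.mmax')
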
corneferencer/inout/tei.py
@@ -4,7 +4,7 @@ import shutil
 
 from lxml import etree
 
-from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST
+import conf
 from corneferencer.entities import Mention, Text
 from corneferencer.utils import eprint
 
@@ -18,7 +18,7 @@ NSMAP = {None: TEI_NS,
          'xi': XI_NS}
 
 
-def read(inpath):
+def read(inpath, clear_mentions=conf.CLEAR_INPUT):
     textname = os.path.basename(inpath)
 
     text = Text(textname)
@@ -49,7 +49,7 @@ def read(inpath):
         eprint("Error: missing mentions layer for text %s!" % textname)
         return None
 
-    if os.path.exists(ann_coreference) and not CLEAR_INPUT:
+    if os.path.exists(ann_coreference) and not clear_mentions:
         add_coreference_layer(ann_coreference, text)
 
     return text
@@ -215,6 +215,9 @@ def get_mention(mention, mnt_id, segments, segments_ids, paragraph_id, sentence_
             semh_id = get_fval(f).split('#')[-1]
             semh = segments[semh_id]
 
+    if len(mnt_segments) == 0:
+        mnt_segments.append(semh)
+
     (sent_segments, prec_context, follow_context,
      first_in_sentence, first_in_paragraph) = get_context(mnt_segments, segments, segments_ids)
 
@@ -272,7 +275,7 @@ def get_prec_context(mention_start, segments, segments_ids):
     while context_start >= 0:
         if not word_to_ignore(segments[segments_ids[context_start]]):
             context.append(segments[segments_ids[context_start]])
-            if len(context) == CONTEXT:
+            if len(context) == conf.CONTEXT:
                 break
         context_start -= 1
     context.reverse()
@@ -285,7 +288,7 @@ def get_follow_context(mention_end, segments, segments_ids):
     while context_end < len(segments):
         if not word_to_ignore(segments[segments_ids[context_end]]):
             context.append(segments[segments_ids[context_end]])
-            if len(context) == CONTEXT:
+            if len(context) == conf.CONTEXT:
                 break
         context_end += 1
     return context
@@ -341,8 +344,8 @@ def get_rarest_word(words):
     rarest_word = words[0]
     for i, word in enumerate(words):
         word_freq = 0
-        if word['base'] in FREQ_LIST:
-            word_freq = FREQ_LIST[word['base']]
+        if word['base'] in conf.FREQ_LIST:
+            word_freq = conf.FREQ_LIST[word['base']]
 
         if i == 0 or word_freq < min_freq:
             min_freq = word_freq
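
Besides mirroring the mmax.py changes, tei.py gains a guard in get_mention() for markables that yield no word segments: the mention then falls back to containing just its semantic head, so get_context() always receives a non-empty segment list. Annotated sketch of the added logic:

    # inside get_mention(), after the semantic head (semh) is resolved
    if len(mnt_segments) == 0:       # no words were collected for this mention
        mnt_segments.append(semh)    # fall back to the semantic head segment
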
corneferencer/main.py
@@ -4,9 +4,11 @@ import sys
 from argparse import ArgumentParser
 from natsort import natsorted
 
-sys.path.append(os.path.abspath(os.path.join('..')))
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
 
 import conf
+import utils
 from inout import mmax, tei
 from inout.constants import INPUT_FORMATS
 from resolvers import resolve
@@ -27,22 +29,25 @@ def main():
     if conf.NEURAL_MODEL_ARCHITECTURE == 'siamese':
         resolver = conf.NEURAL_MODEL_ARCHITECTURE
         eprint("Warning: Using %s resolver because of selected neural model architecture!" %
-              conf.NEURAL_MODEL_ARCHITECTURE)
-    process_texts(args.input, args.output, args.format, resolver, args.threshold)
+               conf.NEURAL_MODEL_ARCHITECTURE)
+    process_texts(args.input, args.output, args.format, resolver, args.threshold, args.model)
 
 
 def parse_arguments():
     parser = ArgumentParser(description='Corneferencer: coreference resolver using neural nets.')
+    parser.add_argument('-f', '--format', type=str, action='store',
+                        dest='format', default=INPUT_FORMATS[0],
+                        help='input format; default: %s; possibilities: %s'
+                             % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
     parser.add_argument('-i', '--input', type=str, action='store',
                         dest='input', default='',
                         help='input file or dir path')
+    parser.add_argument('-m', '--model', type=str, action='store',
+                        dest='model', default='',
+                        help='neural model path; default: %s' % conf.NEURAL_MODEL_PATH)
     parser.add_argument('-o', '--output', type=str, action='store',
                         dest='output', default='',
                         help='output path; if not specified writes output to standard output')
-    parser.add_argument('-f', '--format', type=str, action='store',
-                        dest='format', default=INPUT_FORMATS[0],
-                        help='input format; default: %s; possibilities: %s'
-                             % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
     parser.add_argument('-r', '--resolver', type=str, action='store',
                         dest='resolver', default=RESOLVERS[0],
                         help='resolve algorithm; default: %s; possibilities: %s'
@@ -55,16 +60,17 @@ def parse_arguments():
     return args
 
 
-def process_texts(inpath, outpath, informat, resolver, threshold):
+def process_texts(inpath, outpath, informat, resolver, threshold, model_path):
+    model = utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE, conf.NUMBER_OF_FEATURES, model_path)
     if os.path.isdir(inpath):
-        process_directory(inpath, outpath, informat, resolver, threshold)
+        process_directory(inpath, outpath, informat, resolver, threshold, model)
     elif os.path.isfile(inpath):
-        process_text(inpath, outpath, informat, resolver, threshold)
+        process_text(inpath, outpath, informat, resolver, threshold, model)
     else:
         eprint("Error: Specified input does not exist!")
 
 
-def process_directory(inpath, outpath, informat, resolver, threshold):
+def process_directory(inpath, outpath, informat, resolver, threshold, model):
     inpath = os.path.abspath(inpath)
     outpath = os.path.abspath(outpath)
 
@@ -75,38 +81,38 @@ def process_directory(inpath, outpath, informat, resolver, threshold):
         textname = os.path.splitext(os.path.basename(filename))[0]
         textoutput = os.path.join(outpath, textname)
         textinput = os.path.join(inpath, filename)
-        process_text(textinput, textoutput, informat, resolver, threshold)
+        process_text(textinput, textoutput, informat, resolver, threshold, model)
 
 
-def process_text(inpath, outpath, informat, resolver, threshold):
+def process_text(inpath, outpath, informat, resolver, threshold, model):
     basename = os.path.basename(inpath)
     if informat == 'mmax' and basename.endswith('.mmax'):
         print (basename)
         text = mmax.read(inpath)
         if resolver == 'incremental':
-            resolve.incremental(text, threshold)
+            resolve.incremental(text, threshold, model)
         elif resolver == 'entity_based':
-            resolve.entity_based(text, threshold)
+            resolve.entity_based(text, threshold, model)
         elif resolver == 'closest':
-            resolve.closest(text, threshold)
+            resolve.closest(text, threshold, model)
         elif resolver == 'siamese':
-            resolve.siamese(text, threshold)
+            resolve.siamese(text, threshold, model)
         elif resolver == 'all2all':
-            resolve.all2all(text, threshold)
+            resolve.all2all(text, threshold, model)
         mmax.write(inpath, outpath, text)
     elif informat == 'tei':
         print (basename)
         text = tei.read(inpath)
         if resolver == 'incremental':
-            resolve.incremental(text, threshold)
+            resolve.incremental(text, threshold, model)
        elif resolver == 'entity_based':
-            resolve.entity_based(text, threshold)
+            resolve.entity_based(text, threshold, model)
        elif resolver == 'closest':
-            resolve.closest(text, threshold)
+            resolve.closest(text, threshold, model)
        elif resolver == 'siamese':
-            resolve.siamese(text, threshold)
+            resolve.siamese(text, threshold, model)
        elif resolver == 'all2all':
-            resolve.all2all(text, threshold)
+            resolve.all2all(text, threshold, model)
         tei.write(inpath, outpath, text)
 
 
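
With the new -m/--model option the same installation can run different trained models without editing conf.py: the model is loaded once in process_texts() and passed down to every resolver call. The sys.path fix also makes the script runnable from any working directory. An example invocation (paths are illustrative):

    python corneferencer/main.py -f tei -i data/input_dir -o data/output_dir \
        -r incremental -m models/incremental_model.h5
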
corneferencer/prepare_data.py
new file mode 100644
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+
+import codecs
+import os
+import random
+import sys
+
+from itertools import combinations
+from argparse import ArgumentParser
+from natsort import natsorted
+
+sys.path.append(os.path.abspath(os.path.join('..')))
+
+from inout import mmax, tei
+from inout.constants import INPUT_FORMATS
+from utils import eprint
+from corneferencer.resolvers import vectors
+
+
+POS_COUNT = 0
+NEG_COUNT = 0
+
+
+def main():
+    args = parse_arguments()
+    if not args.input:
+        eprint("Error: Input file(s) not specified!")
+    elif args.format not in INPUT_FORMATS:
+        eprint("Error: Unknown input file format!")
+    else:
+        process_texts(args.input, args.output, args.format, args.proportion)
+
+
+def parse_arguments():
+    parser = ArgumentParser(description='Corneferencer: data preparator for neural nets training.')
+    parser.add_argument('-i', '--input', type=str, action='store',
+                        dest='input', default='',
+                        help='input dir path')
+    parser.add_argument('-o', '--output', type=str, action='store',
+                        dest='output', default='',
+                        help='output path; if not specified writes output to standard output')
+    parser.add_argument('-f', '--format', type=str, action='store',
+                        dest='format', default=INPUT_FORMATS[0],
+                        help='input format; default: %s; possibilities: %s'
+                             % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
+    parser.add_argument('-p', '--proportion', type=int, action='store',
+                        dest='proportion', default=5,
+                        help='negative examples proportion; default: 5')
+    args = parser.parse_args()
+    return args
+
+
+def process_texts(inpath, outpath, informat, proportion):
+    if os.path.isdir(inpath):
+        process_directory(inpath, outpath, informat, proportion)
+    else:
+        eprint("Error: Specified input does not exist or is not a directory!")
+
+
+def process_directory(inpath, outpath, informat, proportion):
+    inpath = os.path.abspath(inpath)
+    outpath = os.path.abspath(outpath)
+
+    try:
+        create_data_vectors(inpath, outpath, informat, proportion)
+    finally:
+        print ('Positives: ', POS_COUNT)
+        print ('Negatives: ', NEG_COUNT)
+
+
+def create_data_vectors(inpath, outpath, informat, proportion):
+    features_file = codecs.open(outpath, 'w', 'utf-8')
+
+    files = os.listdir(inpath)
+    files = natsorted(files)
+
+    for filename in files:
+        textname = os.path.splitext(os.path.basename(filename))[0]
+        textinput = os.path.join(inpath, filename)
+
+        print ('Processing text: %s' % textname)
+        text = None
+        if informat == 'mmax' and filename.endswith('.mmax'):
+            text = mmax.read(textinput, False)
+        elif informat == 'tei':
+            text = tei.read(textinput, False)
+
+        positives, negatives = diff_mentions(text, proportion)
+        write_features(features_file, positives, negatives)
+
+
+def diff_mentions(text, proportion):
+    sets = text.get_sets()
+    all_mentions = text.get_mentions()
+    positives = get_positives(sets)
+    positives, negatives = get_negatives_and_update_positives(all_mentions, positives, proportion)
+    return positives, negatives
+
+
+def get_positives(sets):
+    positives = []
+    for set_id in sets:
+        coref_set = sets[set_id]
+        positives.extend(list(combinations(coref_set, 2)))
+    return positives
+
+
+def get_negatives_and_update_positives(all_mentions, positives, proportion):
+    all_pairs = list(combinations(all_mentions, 2))
+
+    all_pairs = set(all_pairs)
+    negatives = [pair for pair in all_pairs if pair not in positives]
+    samples_count = proportion * len(positives)
+    if samples_count > len(negatives):
+        samples_count = len(negatives)
+        if proportion == 1:
+            positives = random.sample(set(positives), samples_count)
+            print (u'Więcej przypadków pozytywnych niż negatywnych!')
+    negatives = random.sample(set(negatives), samples_count)
+    return positives, negatives
+
+
+def write_features(features_file, positives, negatives):
+    global POS_COUNT
+    POS_COUNT += len(positives)
+    for pair in positives:
+        vector = vectors.get_pair_vector(pair[0], pair[1])
+        vector.append(1.0)
+        features_file.write(u'%s\n' % u'\t'.join([str(feature) for feature in vector]))
+
+    global NEG_COUNT
+    NEG_COUNT += len(negatives)
+    for pair in negatives:
+        vector = vectors.get_pair_vector(pair[0], pair[1])
+        vector.append(0.0)
+        features_file.write(u'%s\n' % u'\t'.join([str(feature) for feature in vector]))
+
+
+if __name__ == '__main__':
+    main()
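
The new script walks a directory of annotated texts, builds mention pairs, and writes one tab-separated feature vector per pair with a 1.0/0.0 coreference label in the last column, which is presumably the format the training code consumes. Positive pairs come from gold coreference sets; negatives are sampled at random from the remaining combinations, at most `proportion` negatives per positive. Example run (paths are illustrative):

    python corneferencer/prepare_data.py -f mmax -i data/train_mmax \
        -o data/train_vectors.tsv -p 5
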
corneferencer/resolvers/features.py
@@ -72,97 +72,97 @@ def sentence_vec(mention):
 
 
 def mention_type(mention):
-    type_vec = [0] * 4
+    type_vec = [0.0] * 4
     if mention.head is None:
-        type_vec[3] = 1
+        type_vec[3] = 1.0
     elif mention.head['ctag'] in constants.NOUN_TAGS:
-        type_vec[0] = 1
+        type_vec[0] = 1.0
     elif mention.head['ctag'] in constants.PPRON_TAGS:
-        type_vec[1] = 1
+        type_vec[1] = 1.0
     elif mention.head['ctag'] in constants.ZERO_TAGS:
-        type_vec[2] = 1
+        type_vec[2] = 1.0
     else:
-        type_vec[3] = 1
+        type_vec[3] = 1.0
     return type_vec
 
 
 def is_first_second_person(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if mention.head['person'] in constants.FIRST_SECOND_PERSON:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def is_demonstrative(mention):
     if mention.words[0]['base'].lower() in constants.INDICATIVE_PRONS_BASES:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def is_demonstrative_nominal(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if is_demonstrative(mention) and mention.head['ctag'] in constants.NOUN_TAGS:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def is_demonstrative_pronoun(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if (is_demonstrative(mention) and
             (mention.head['ctag'] in constants.PPRON_TAGS or mention.head['ctag'] in constants.ZERO_TAGS)):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def is_refl_pronoun(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if mention.head['ctag'] in constants.SIEBIE_TAGS:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def is_first_in_sentence(mention):
     if mention.first_in_sentence:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def is_zero_or_pronoun(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if mention.head['ctag'] in constants.PPRON_TAGS or mention.head['ctag'] in constants.ZERO_TAGS:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def head_contains_digit(mention):
     _digits = re.compile('\d')
     if _digits.search(mention.head_orth):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def mention_contains_digit(mention):
     _digits = re.compile('\d')
     if _digits.search(mention.text):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def contains_letter(mention):
     if any(c.isalpha() for c in mention.text):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def post_modified(mention):
     if mention.head_orth != mention.words[-1]['orth']:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 # pair features
@@ -171,20 +171,20 @@ def distances_vec(ante, ana):
 
     mnts_intersect = pair_intersect(ante, ana)
 
-    words_dist = [0] * 11
+    words_dist = [0.0] * 11
     words_bucket = 0
-    if mnts_intersect != 1:
+    if mnts_intersect != 1.0:
         words_bucket = get_distance_bucket(ana.start_in_words - ante.end_in_words)
-    words_dist[words_bucket] = 1
+    words_dist[words_bucket] = 1.0
     vec.extend(words_dist)
 
-    mentions_dist = [0] * 11
+    mentions_dist = [0.0] * 11
     mentions_bucket = 0
-    if mnts_intersect != 1:
+    if mnts_intersect != 1.0:
         mentions_bucket = get_distance_bucket(ana.position_in_mentions - ante.position_in_mentions)
         if words_bucket == 10:
             mentions_bucket = 10
-    mentions_dist[mentions_bucket] = 1
+    mentions_dist[mentions_bucket] = 1.0
     vec.extend(mentions_dist)
 
     vec.append(mnts_intersect)
@@ -196,45 +196,45 @@ def pair_intersect(ante, ana):
     for ante_word in ante.words:
         for ana_word in ana.words:
             if ana_word['id'] == ante_word['id']:
-                return 1
-    return 0
+                return 1.0
+    return 0.0
 
 
 def head_match(ante, ana):
     if ante.head_orth.lower() == ana.head_orth.lower():
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def exact_match(ante, ana):
     if ante.text.lower() == ana.text.lower():
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def base_match(ante, ana):
     if ante.lemmatized_text.lower() == ana.lemmatized_text.lower():
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def ante_contains_rarest_from_ana(ante, ana):
     ana_rarest = ana.rarest
     for word in ante.words:
         if word['base'] == ana_rarest['base']:
-            return 1
-    return 0
+            return 1.0
+    return 0.0
 
 
 def agreement(ante, ana, tag_name):
-    agr_vec = [0] * 3
+    agr_vec = [0.0] * 3
     if (ante.head is None or ana.head is None or
             ante.head[tag_name] == 'unk' or ana.head[tag_name] == 'unk'):
-        agr_vec[2] = 1
+        agr_vec[2] = 1.0
     elif ante.head[tag_name] == ana.head[tag_name]:
-        agr_vec[0] = 1
+        agr_vec[0] = 1.0
     else:
-        agr_vec[1] = 1
+        agr_vec[1] = 1.0
     return agr_vec
 
 
@@ -243,72 +243,72 @@ def is_acronym(ante, ana):
         return check_one_way_acronym(ana.text, ante.text)
     if ante.text.upper() == ante.text:
         return check_one_way_acronym(ante.text, ana.text)
-    return 0
+    return 0.0
 
 
 def same_sentence(ante, ana):
     if ante.sentence_id == ana.sentence_id:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def neighbouring_sentence(ante, ana):
     if ana.sentence_id - ante.sentence_id == 1:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def cousin_sentence(ante, ana):
     if ana.sentence_id - ante.sentence_id == 2:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def distant_sentence(ante, ana):
     if ana.sentence_id - ante.sentence_id > 2:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def same_paragraph(ante, ana):
     if ante.paragraph_id == ana.paragraph_id:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def flat_gender_agreement(ante, ana):
-    agr_vec = [0] * 3
+    agr_vec = [0.0] * 3
     if (ante.head is None or ana.head is None or
             ante.head['gender'] == 'unk' or ana.head['gender'] == 'unk'):
-        agr_vec[2] = 1
+        agr_vec[2] = 1.0
     elif (ante.head['gender'] == ana.head['gender'] or
             (ante.head['gender'] in constants.MASCULINE_TAGS and ana.head['gender'] in constants.MASCULINE_TAGS)):
-        agr_vec[0] = 1
+        agr_vec[0] = 1.0
     else:
-        agr_vec[1] = 1
+        agr_vec[1] = 1.0
     return agr_vec
 
 
 def left_match(ante, ana):
     if (ante.text.lower().startswith(ana.text.lower()) or
             ana.text.lower().startswith(ante.text.lower())):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def right_match(ante, ana):
     if (ante.text.lower().endswith(ana.text.lower()) or
             ana.text.lower().endswith(ante.text.lower())):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def abbrev2(ante, ana):
     ante_abbrev = get_abbrev(ante)
     ana_abbrev = get_abbrev(ana)
     if ante.head_orth == ana_abbrev or ana.head_orth == ante_abbrev:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def string_kernel(ante, ana):
@@ -326,7 +326,7 @@ def head_string_kernel(ante, ana):
 def wordnet_synonyms(ante, ana):
     ante_synonyms = set()
     if ante.head is None or ana.head is None:
-        return 0
+        return 0.0
 
     if ante.head['base'] in conf.LEMMA2SYNONYMS:
         ante_synonyms = conf.LEMMA2SYNONYMS[ante.head['base']]
@@ -336,13 +336,13 @@ def wordnet_synonyms(ante, ana):
         ana_synonyms = conf.LEMMA2SYNONYMS[ana.head['base']]
 
     if ana.head['base'] in ante_synonyms or ante.head['base'] in ana_synonyms:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def wordnet_ana_is_hypernym(ante, ana):
     if ante.head is None or ana.head is None:
-        return 0
+        return 0.0
 
     ante_hypernyms = set()
     if ante.head['base'] in conf.LEMMA2HYPERNYMS:
@@ -353,16 +353,16 @@ def wordnet_ana_is_hypernym(ante, ana):
         ana_hypernyms = conf.LEMMA2HYPERNYMS[ana.head['base']]
 
     if not ante_hypernyms or not ana_hypernyms:
-        return 0
+        return 0.0
 
     if ana.head['base'] in ante_hypernyms:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def wordnet_ante_is_hypernym(ante, ana):
     if ante.head is None or ana.head is None:
-        return 0
+        return 0.0
 
     ana_hypernyms = set()
     if ana.head['base'] in conf.LEMMA2HYPERNYMS:
@@ -373,18 +373,18 @@ def wordnet_ante_is_hypernym(ante, ana):
         ante_hypernyms = conf.LEMMA2HYPERNYMS[ante.head['base']]
 
     if not ante_hypernyms or not ana_hypernyms:
-        return 0
+        return 0.0
 
     if ante.head['base'] in ana_hypernyms:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def wikipedia_link(ante, ana):
     ante_base = ante.lemmatized_text.lower()
     ana_base = ana.lemmatized_text.lower()
     if ante_base == ana_base:
-        return 1
+        return 1.0
 
     ante_links = set()
     if ante_base in conf.TITLE2LINKS:
@@ -395,16 +395,16 @@ def wikipedia_link(ante, ana):
         ana_links = conf.TITLE2LINKS[ana_base]
 
     if ana_base in ante_links or ante_base in ana_links:
-        return 1
+        return 1.0
 
-    return 0
+    return 0.0
 
 
 def wikipedia_mutual_link(ante, ana):
     ante_base = ante.lemmatized_text.lower()
     ana_base = ana.lemmatized_text.lower()
     if ante_base == ana_base:
-        return 1
+        return 1.0
 
     ante_links = set()
     if ante_base in conf.TITLE2LINKS:
@@ -415,52 +415,52 @@ def wikipedia_mutual_link(ante, ana):
         ana_links = conf.TITLE2LINKS[ana_base]
 
     if ana_base in ante_links and ante_base in ana_links:
-        return 1
+        return 1.0
 
-    return 0
+    return 0.0
 
 
 def wikipedia_redirect(ante, ana):
     ante_base = ante.lemmatized_text.lower()
     ana_base = ana.lemmatized_text.lower()
     if ante_base == ana_base:
-        return 1
+        return 1.0
 
     if ante_base in conf.TITLE2REDIRECT and conf.TITLE2REDIRECT[ante_base] == ana_base:
-        return 1
+        return 1.0
 
     if ana_base in conf.TITLE2REDIRECT and conf.TITLE2REDIRECT[ana_base] == ante_base:
-        return 1
+        return 1.0
 
-    return 0
+    return 0.0
 
 
 def samesent_anapron_antefirstinpar(ante, ana):
     if same_sentence(ante, ana) and is_zero_or_pronoun(ana) and ante.first_in_paragraph:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def samesent_antefirstinpar_personnumbermatch(ante, ana):
     if (same_sentence(ante, ana) and ante.first_in_paragraph
             and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def adjsent_anapron_adjmen_personnumbermatch(ante, ana):
     if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana)
             and ana.position_in_mentions - ante.position_in_mentions == 1
             and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def adjsent_anapron_adjmen(ante, ana):
     if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana)
            and ana.position_in_mentions - ante.position_in_mentions == 1):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 # supporting functions
@@ -523,8 +523,8 @@ def check_one_way_acronym(acronym, expression):
         if expr2:
             initials += expr2[0].upper()
     if acronym == initials:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def get_abbrev(mention):
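
Returning 1.0/0.0 instead of 1/0 makes every feature vector homogeneously float before it reaches numpy.asarray(..., dtype=numpy.float32) in resolve.py, and keeps comparisons such as mnts_intersect != 1.0 type-consistent. A small sketch of the convention (bucket index 3 is an arbitrary example):

    import numpy

    # one-hot distance bucket built as floats, matching the updated distances_vec()
    words_dist = [0.0] * 11
    words_dist[3] = 1.0
    sample = numpy.asarray([words_dist], dtype=numpy.float32)  # no implicit int->float cast
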
corneferencer/resolvers/resolve.py
 import numpy
 
-from conf import NEURAL_MODEL
-from corneferencer.resolvers import features
-from corneferencer.resolvers.vectors import get_pair_features, get_pair_vector
+from corneferencer.resolvers import features, vectors
 
 
-def siamese(text, threshold):
+def siamese(text, threshold, neural_model):
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
             for ante in reversed(text.mentions[:i]):
                 if not features.pair_intersect(ante, ana):
-                    pair_features = get_pair_features(ante, ana)
+                    pair_features = vectors.get_pair_features(ante, ana)
 
                     ante_vec = []
                     ante_vec.extend(ante.features)
@@ -23,7 +21,7 @@ def siamese(text, threshold):
                     ana_vec.extend(pair_features)
                     ana_sample = numpy.asarray([ana_vec], dtype=numpy.float32)
 
-                    prediction = NEURAL_MODEL.predict([ante_sample, ana_sample])[0]
+                    prediction = neural_model.predict([ante_sample, ana_sample])[0]
 
                     if prediction < threshold:
                         if ante.set:
@@ -37,7 +35,7 @@ def siamese(text, threshold):
 
 
 # incremental resolve algorithm
-def incremental(text, threshold):
+def incremental(text, threshold, neural_model):
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
@@ -45,9 +43,9 @@ def incremental(text, threshold):
             best_ante = None
             for ante in text.mentions[:i]:
                 if not features.pair_intersect(ante, ana):
-                    pair_vec = get_pair_vector(ante, ana)
+                    pair_vec = vectors.get_pair_vector(ante, ana)
                     sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                    prediction = NEURAL_MODEL.predict(sample)[0]
+                    prediction = neural_model.predict(sample)[0]
                     if prediction > threshold and prediction >= best_prediction:
                         best_prediction = prediction
                         best_ante = ante
@@ -62,86 +60,21 @@ def incremental(text, threshold):
 
 
 # all2all resolve algorithm
-def all2all_debug(text, threshold):
-    last_set_id = 0
-    for pos1, mnt1 in enumerate(text.mentions):
-        best_prediction = 0.0
-        best_link = None
-        for pos2, mnt2 in enumerate(text.mentions):
-            if (mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2):
-                ante = mnt1
-                ana = mnt2
-                if pos2 < pos1:
-                    ante = mnt2
-                    ana = mnt1
-                pair_vec = get_pair_vector(ante, ana)
-                sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                prediction = NEURAL_MODEL.predict(sample)[0]
-                if prediction > threshold and prediction > best_prediction:
-                    best_prediction = prediction
-                    best_link = mnt2
-        if best_link is not None:
-            if best_link.set and not mnt1.set:
-                mnt1.set = best_link.set
-            elif best_link.set and mnt1.set:
-                text.merge_sets(best_link.set, mnt1.set)
-            elif not best_link.set and not mnt1.set:
-                str_set_id = 'set_%d' % last_set_id
-                best_link.set = str_set_id
-                mnt1.set = str_set_id
-                last_set_id += 1
-
-
-def all2all_v1(text, threshold):
-    last_set_id = 0
-    for pos1, mnt1 in enumerate(text.mentions):
-        best_prediction = 0.0
-        best_link = None
-        for pos2, mnt2 in enumerate(text.mentions):
-            if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
-                    and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
-                ante = mnt1
-                ana = mnt2
-                if pos2 < pos1:
-                    ante = mnt2
-                    ana = mnt1
-                pair_vec = get_pair_vector(ante, ana)
-                sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                prediction = NEURAL_MODEL.predict(sample)[0]
-                if prediction > threshold and prediction > best_prediction:
-                    best_prediction = prediction
-                    best_link = mnt2
-        if best_link is not None:
-            if best_link.set and not mnt1.set:
-                mnt1.set = best_link.set
-            elif not best_link.set and mnt1.set:
-                best_link.set = mnt1.set
-            elif best_link.set and mnt1.set:
-                text.merge_sets(best_link.set, mnt1.set)
-            elif not best_link.set and not mnt1.set:
-                str_set_id = 'set_%d' % last_set_id
-                best_link.set = str_set_id
-                mnt1.set = str_set_id
-                last_set_id += 1
-
-
-def all2all(text, threshold):
+def all2all(text, threshold, neural_model):
     last_set_id = 0
     sets = text.get_sets()
     for pos1, mnt1 in enumerate(text.mentions):
         best_prediction = 0.0
         best_link = None
         for pos2, mnt2 in enumerate(text.mentions):
-            if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
-                    and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
+            if (pos2 > pos1 and
+                    (mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
+                    and not features.pair_intersect(mnt1, mnt2)):
                 ante = mnt1
                 ana = mnt2
-                if pos2 < pos1:
-                    ante = mnt2
-                    ana = mnt1
-                pair_vec = get_pair_vector(ante, ana)
+                pair_vec = vectors.get_pair_vector(ante, ana)
                 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                prediction = NEURAL_MODEL.predict(sample)[0]
+                prediction = neural_model.predict(sample)[0]
                 if prediction > threshold and prediction > best_prediction:
                     best_prediction = prediction
                     best_link = mnt2
@@ -163,12 +96,12 @@ def all2all(text, threshold):
 
 
 # entity based resolve algorithm
-def entity_based(text, threshold):
+def entity_based(text, threshold, neural_model):
     sets = []
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
-            best_fit = get_best_set(sets, ana, threshold)
+            best_fit = get_best_set(sets, ana, threshold, neural_model)
             if best_fit is not None:
                 ana.set = best_fit['set_id']
                 best_fit['mentions'].append(ana)
@@ -188,25 +121,25 @@ def entity_based(text, threshold):
     remove_singletons(sets)
 
 
-def get_best_set(sets, ana, threshold):
+def get_best_set(sets, ana, threshold, neural_model):
     best_prediction = 0.0
     best_set = None
     for s in sets:
-        accuracy = predict_set(s['mentions'], ana)
+        accuracy = predict_set(s['mentions'], ana, neural_model)
         if accuracy > threshold and accuracy >= best_prediction:
             best_prediction = accuracy
             best_set = s
     return best_set
 
 
-def predict_set(mentions, ana):
+def predict_set(mentions, ana, neural_model):
     prediction_sum = 0.0
     for mnt in mentions:
         prediction = 0.0
         if not features.pair_intersect(mnt, ana):
-            pair_vec = get_pair_vector(mnt, ana)
+            pair_vec = vectors.get_pair_vector(mnt, ana)
             sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-            prediction = NEURAL_MODEL.predict(sample)[0]
+            prediction = neural_model.predict(sample)[0]
         prediction_sum += prediction
     return prediction_sum / float(len(mentions))
 
@@ -218,15 +151,15 @@ def remove_singletons(sets):
 
 
 # closest resolve algorithm
-def closest(text, threshold):
+def closest(text, threshold, neural_model):
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
             for ante in reversed(text.mentions[:i]):
                 if not features.pair_intersect(ante, ana):
-                    pair_vec = get_pair_vector(ante, ana)
+                    pair_vec = vectors.get_pair_vector(ante, ana)
                     sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                    prediction = NEURAL_MODEL.predict(sample)[0]
+                    prediction = neural_model.predict(sample)[0]
                     if prediction > threshold:
                         if ante.set:
                             ana.set = ante.set
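
Every resolver now receives the Keras model as an argument instead of reading the removed NEURAL_MODEL global, so different models can be used in one process; the obsolete all2all_debug and all2all_v1 variants are dropped, and all2all only scores pairs with pos2 > pos1, halving the candidate pairs. A minimal calling sketch (imports follow main.py's sys.path setup; paths and the 0.5 threshold are example values):

    import conf
    import utils
    from inout import mmax
    from resolvers import resolve

    model = utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE,
                                          conf.NUMBER_OF_FEATURES,
                                          'models/my_model.h5')  # illustrative path
    text = mmax.read('corpus/doc01.mmax')                        # illustrative path
    resolve.incremental(text, 0.5, model)
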
corneferencer/utils.py
@@ -7,7 +7,6 @@ import javaobj
 
 from keras.models import Sequential, Model
 from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization, Lambda
-from keras.optimizers import RMSprop, Adam
 from keras import backend as K
 
 