Commit 4a097c4f123cc34ee07d870361bad7a817a3faeb (1 parent: a5beabe2)

Corneferencer alpha version

Showing 15 changed files with 782 additions and 0 deletions
conf.py
0 → 100644
+import os
+
+from gensim.models.word2vec import Word2Vec
+
+from corneferencer.utils import initialize_neural_model
+
+
+CONTEXT = 5
+THRESHOLD = 0.5
+RANDOM_WORD_VECTORS = True
+W2V_SIZE = 50
+W2V_MODEL_NAME = 'w2v_allwiki_nkjpfull_50.model'
+
+NUMBER_OF_FEATURES = 1126
+NEURAL_MODEL_NAME = 'weights_2017_05_10.h5'
+
+
+# do not change that
+W2V_MODEL_PATH = os.path.join(os.path.dirname(__file__), 'models', W2V_MODEL_NAME)
+W2V_MODEL = Word2Vec.load(W2V_MODEL_PATH)
+
+NEURAL_MODEL_PATH = os.path.join(os.path.dirname(__file__), 'models', NEURAL_MODEL_NAME)
+NEURAL_MODEL = initialize_neural_model(NUMBER_OF_FEATURES)
+# assumed fix: load the trained weights from the path defined above, which is otherwise unused
+NEURAL_MODEL.load_weights(NEURAL_MODEL_PATH)
corneferencer/entities.py
0 → 100644
+from corneferencer.resolvers.vectors import get_mention_features
+
+
+class Text:
+
+    def __init__(self, text_id):
+        self.__id = text_id
+        self.mentions = []
+
+    def get_mention_set(self, mnt_id):
+        for mnt in self.mentions:
+            if mnt.id == mnt_id:
+                return mnt.set
+        return None
+
+
+class Mention:
+
+    def __init__(self, mnt_id, text, lemmatized_text, words, span,
+                 head_orth, head_base, dominant, node, prec_context,
+                 follow_context, sentence, position_in_mentions,
+                 start_in_words, end_in_words):
+        self.id = mnt_id
+        self.set = ''
+        self.text = text
+        self.lemmatized_text = lemmatized_text
+        self.words = words
+        self.span = span
+        self.head_orth = head_orth
+        self.head_base = head_base
+        self.dominant = dominant
+        self.node = node
+        self.prec_context = prec_context
+        self.follow_context = follow_context
+        self.sentence = sentence
+        self.position_in_mentions = position_in_mentions
+        self.start_in_words = start_in_words
+        self.end_in_words = end_in_words
+        self.features = get_mention_features(self)
corneferencer/core.py renamed to corneferencer/inout/__init__.py
corneferencer/inout/constants.py
0 → 100644
+INPUT_FORMATS = ['mmax']
corneferencer/inout/mmax.py
0 → 100644
+import os
+import shutil
+
+from lxml import etree
+
+from conf import CONTEXT
+from corneferencer.entities import Mention, Text
+
+
+def read(inpath):
+    textname = os.path.splitext(os.path.basename(inpath))[0]
+    textdir = os.path.dirname(inpath)
+
+    mentions_path = os.path.join(textdir, '%s_mentions.xml' % textname)
+    words_path = os.path.join(textdir, '%s_words.xml' % textname)
+
+    text = Text(textname)
+    mentions = read_mentions(mentions_path, words_path)
+    text.mentions = mentions
+    return text
+
+
+def read_mentions(mentions_path, words_path):
+    mentions = []
+    mentions_tree = etree.parse(mentions_path)
+    markables = mentions_tree.xpath("//ns:markable",
+                                    namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
+    words = get_words(words_path)
+
+    for idx, markable in enumerate(markables):
+        span = markable.attrib['span']
+
+        dominant = ''
+        if 'dominant' in markable.attrib:
+            dominant = markable.attrib['dominant']
+
+        head_orth = markable.attrib['mention_head']
+        mention_words = span_to_words(span, words)
+
+        (prec_context, follow_context, sentence,
+         mnt_start_position, mnt_end_position) = get_context(mention_words, words)
+
+        head_base = get_head_base(head_orth, mention_words)
+        mention = Mention(mnt_id=markable.attrib['id'],
+                          text=span_to_text(span, words, 'orth'),
+                          lemmatized_text=span_to_text(span, words, 'base'),
+                          words=mention_words,
+                          span=span,
+                          head_orth=head_orth,
+                          head_base=head_base,
+                          dominant=dominant,
+                          node=markable,
+                          prec_context=prec_context,
+                          follow_context=follow_context,
+                          sentence=sentence,
+                          position_in_mentions=idx,
+                          start_in_words=mnt_start_position,
+                          end_in_words=mnt_end_position)
+        mentions.append(mention)
+
+    return mentions
+
+
+def get_words(filepath):
+    tree = etree.parse(filepath)
+    words = []
+    for word in tree.xpath("//word"):
+        hasnps = False
+        if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
+            hasnps = True
+        lastinsent = False
+        if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
+            lastinsent = True
+        words.append({'id': word.attrib['id'],
+                      'orth': word.text,
+                      'base': word.attrib['base'],
+                      'hasnps': hasnps,
+                      'lastinsent': lastinsent,
+                      'ctag': word.attrib['ctag']})
+    return words
+
+
+def span_to_words(span, words):
+    fragments = span.split(',')
+    mention_parts = []
+    for fragment in fragments:
+        mention_parts.extend(fragment_to_words(fragment, words))
+    return mention_parts
+
+
+def fragment_to_words(fragment, words):
+    mention_parts = []
+    if '..' in fragment:
+        mention_parts.extend(get_multiword(fragment, words))
+    else:
+        mention_parts.extend(get_word(fragment, words))
+    return mention_parts
+
+
+def get_multiword(fragment, words):
+    mention_parts = []
+    boundaries = fragment.split('..')
+    start_id = boundaries[0]
+    end_id = boundaries[1]
+    in_string = False
+    for word in words:
+        if word['id'] == start_id:
+            in_string = True
+        if in_string and not word_to_ignore(word):
+            mention_parts.append(word)
+        if word['id'] == end_id:
+            break
+    return mention_parts
+
+
+def get_word(word_id, words):
+    for word in words:
+        if word['id'] == word_id:
+            if not word_to_ignore(word):
+                return [word]
+            else:
+                return []
+    return []
+
+
+def word_to_ignore(word):
+    if word['ctag'] == 'interp':
+        return True
+    return False
+
+
+def get_context(mention_words, words):
+    prec_context = []
+    follow_context = []
+    sentence = []
+    mnt_start_position = -1
+    mnt_end_position = -1
+    first_word = mention_words[0]
+    last_word = mention_words[-1]
+    for idx, word in enumerate(words):
+        if word['id'] == first_word['id']:
+            prec_context = get_prec_context(idx, words)
+            mnt_start_position = get_mention_start(first_word, words)
+        if word['id'] == last_word['id']:
+            follow_context = get_follow_context(idx, words)
+            sentence = get_sentence(idx, words)
+            mnt_end_position = get_mention_end(last_word, words)
+            break
+    return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position
+
+
+def get_prec_context(mention_start, words):
+    context = []
+    context_start = mention_start - 1
+    while context_start >= 0:
+        if not word_to_ignore(words[context_start]):
+            context.append(words[context_start])
+        if len(context) == CONTEXT:
+            break
+        context_start -= 1
+    context.reverse()
+    return context
+
+
+def get_mention_start(first_word, words):
+    start = 0
+    for word in words:
+        if not word_to_ignore(word):
+            start += 1
+        if word['id'] == first_word['id']:
+            break
+    return start
+
+
+def get_mention_end(last_word, words):
+    end = 0
+    for word in words:
+        if not word_to_ignore(word):
+            end += 1
+        if word['id'] == last_word['id']:
+            break
+    return end
+
+
+def get_follow_context(mention_end, words):
+    context = []
+    context_end = mention_end + 1
+    while context_end < len(words):
+        if not word_to_ignore(words[context_end]):
+            context.append(words[context_end])
+        if len(context) == CONTEXT:
+            break
+        context_end += 1
+    return context
+
+
+def get_sentence(word_idx, words):
+    sentence_start = get_sentence_start(words, word_idx)
+    sentence_end = get_sentence_end(words, word_idx)
+    sentence = [word for word in words[sentence_start:sentence_end + 1] if not word_to_ignore(word)]
+    return sentence
+
+
+def get_sentence_start(words, word_idx):
+    search_start = word_idx
+    while word_idx >= 0:
+        if words[word_idx]['lastinsent'] and search_start != word_idx:
+            return word_idx + 1
+        word_idx -= 1
+    return 0
+
+
+def get_sentence_end(words, word_idx):
+    while word_idx < len(words):
+        if words[word_idx]['lastinsent']:
+            return word_idx
+        word_idx += 1
+    return len(words) - 1
+
+
+def get_head_base(head_orth, words):
+    for word in words:
+        if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
+            return word['base']
+    return None
+
+
+def span_to_text(span, words, form):
+    fragments = span.split(',')
+    mention_parts = []
+    for fragment in fragments:
+        mention_parts.append(fragment_to_text(fragment, words, form))
+    return u' [...] '.join(mention_parts)
+
+
+def fragment_to_text(fragment, words, form):
+    if '..' in fragment:
+        text = get_multiword_text(fragment, words, form)
+    else:
+        text = get_one_word_text(fragment, words, form)
+    return text
+
+
+def get_multiword_text(fragment, words, form):
+    mention_parts = []
+    boundaries = fragment.split('..')
+    start_id = boundaries[0]
+    end_id = boundaries[1]
+    in_string = False
+    for word in words:
+        if word['id'] == start_id:
+            in_string = True
+        if in_string and not word_to_ignore(word):
+            mention_parts.append(word)
+        if word['id'] == end_id:
+            break
+    return to_text(mention_parts, form)
+
+
+def to_text(words, form):
+    text = ''
+    for idx, word in enumerate(words):
+        if word['hasnps'] or idx == 0:
+            text += word[form]
+        else:
+            text += u' %s' % word[form]
+    return text
+
+
+def get_one_word_text(word_id, words, form):
+    this_word = next(word for word in words if word['id'] == word_id)
+    return this_word[form]
+
+
+def write(inpath, outpath, text):
+    textname = os.path.splitext(os.path.basename(inpath))[0]
+    intextdir = os.path.dirname(inpath)
+    outtextdir = os.path.dirname(outpath)
+
+    in_mmax_path = os.path.join(intextdir, '%s.mmax' % textname)
+    out_mmax_path = os.path.join(outtextdir, '%s.mmax' % textname)
+    copy_mmax(in_mmax_path, out_mmax_path)
+
+    in_words_path = os.path.join(intextdir, '%s_words.xml' % textname)
+    out_words_path = os.path.join(outtextdir, '%s_words.xml' % textname)
+    copy_words(in_words_path, out_words_path)
+
+    in_mentions_path = os.path.join(intextdir, '%s_mentions.xml' % textname)
+    out_mentions_path = os.path.join(outtextdir, '%s_mentions.xml' % textname)
+    write_mentions(in_mentions_path, out_mentions_path, text)
+
+
+def copy_mmax(src, dest):
+    shutil.copyfile(src, dest)
+
+
+def copy_words(src, dest):
+    shutil.copyfile(src, dest)
+
+
+def write_mentions(inpath, outpath, text):
+    tree = etree.parse(inpath)
+    mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
+
+    for mnt in mentions:
+        mnt_set = text.get_mention_set(mnt.attrib['id'])
+        if mnt_set:
+            mnt.attrib['mention_group'] = mnt_set
+        else:
+            mnt.attrib['mention_group'] = 'empty'
+
+    with open(outpath, 'wb') as output_file:
+        output_file.write(etree.tostring(tree, pretty_print=True,
+                                         xml_declaration=True, encoding='UTF-8',
+                                         doctype=u'<!DOCTYPE markables SYSTEM "markables.dtd">'))
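A note on the span attribute read above: fragments are comma-separated, a contiguous range uses '..' between word ids, and punctuation ('interp' tokens) is dropped. Below is a minimal sketch of span_to_words on hand-made data, with hypothetical word ids and values, assuming the package (and the model files referenced by conf.py) is importable:

from corneferencer.inout.mmax import span_to_words

# hypothetical word dicts in the shape produced by get_words()
words = [
    {'id': 'word_1', 'orth': 'Jan', 'base': 'Jan', 'hasnps': False, 'lastinsent': False, 'ctag': 'subst'},
    {'id': 'word_2', 'orth': ',', 'base': ',', 'hasnps': True, 'lastinsent': False, 'ctag': 'interp'},
    {'id': 'word_3', 'orth': 'nasz', 'base': 'nasz', 'hasnps': False, 'lastinsent': False, 'ctag': 'adj'},
    {'id': 'word_4', 'orth': 'kolega', 'base': 'kolega', 'hasnps': False, 'lastinsent': False, 'ctag': 'subst'},
    {'id': 'word_6', 'orth': 'on', 'base': 'on', 'hasnps': False, 'lastinsent': True, 'ctag': 'ppron3'},
]

# one '..' range plus one single-word fragment; the 'interp' comma is skipped
mention_words = span_to_words('word_1..word_4,word_6', words)
print([w['id'] for w in mention_words])  # ['word_1', 'word_3', 'word_4', 'word_6']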
corneferencer/main.py
0 → 100644
+import os
+import sys
+
+from argparse import ArgumentParser
+from natsort import natsorted
+
+sys.path.append(os.path.abspath(os.path.join('..')))
+
+from inout import mmax
+from inout.constants import INPUT_FORMATS
+from resolvers import resolve
+from resolvers.constants import RESOLVERS
+from utils import eprint
+
+
+def main():
+    args = parse_arguments()
+    if not args.input:
+        eprint("Error: Input file(s) not specified!")
+    elif args.resolver not in RESOLVERS:
+        eprint("Error: Unknown resolve algorithm!")
+    elif args.format not in INPUT_FORMATS:
+        eprint("Error: Unknown input file format!")
+    else:
+        process_texts(args.input, args.output, args.format, args.resolver)
+
+
+def parse_arguments():
+    parser = ArgumentParser(description='Corneferencer: coreference resolver using neural nets.')
+    parser.add_argument('-i', '--input', type=str, action='store',
+                        dest='input', default='',
+                        help='input file or dir path')
+    parser.add_argument('-o', '--output', type=str, action='store',
+                        dest='output', default='',
+                        help='output path; if not specified writes output to standard output')
+    parser.add_argument('-f', '--format', type=str, action='store',
+                        dest='format', default='mmax',
+                        help='input format; default: mmax')
+    parser.add_argument('-r', '--resolver', type=str, action='store',
+                        dest='resolver', default='incremental',
+                        help='resolve algorithm; default: incremental; possibilities: %s'
+                             % ', '.join(RESOLVERS))
+
+    args = parser.parse_args()
+    return args
+
+
+def process_texts(inpath, outpath, informat, resolver):
+    if os.path.isdir(inpath):
+        process_directory(inpath, outpath, informat, resolver)
+    elif os.path.isfile(inpath):
+        process_file(inpath, outpath, informat, resolver)
+    else:
+        eprint("Error: Specified input does not exist!")
+
+
+def process_directory(inpath, outpath, informat, resolver):
+    inpath = os.path.abspath(inpath)
+    outpath = os.path.abspath(outpath)
+
+    files = os.listdir(inpath)
+    files = natsorted(files)
+
+    for filename in files:
+        textname = os.path.splitext(os.path.basename(filename))[0]
+        textoutput = os.path.join(outpath, textname)
+        textinput = os.path.join(inpath, filename)
+        process_file(textinput, textoutput, informat, resolver)
+
+
+def process_file(inpath, outpath, informat, resolver):
+    basename = os.path.basename(inpath)
+    if informat == 'mmax' and basename.endswith('.mmax'):
+        print(basename)
+        text = mmax.read(inpath)
+        if resolver == 'incremental':
+            resolve.incremental(text)
+        elif resolver == 'entity_based':
+            resolve.entity_based(text)
+        mmax.write(inpath, outpath, text)
+
+
+if __name__ == '__main__':
+    main()
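For a single document, the command-line flow above reduces to read → resolve → write. A hedged sketch of the same pipeline driven from Python (hypothetical paths; assumes conf.py is importable and its model files are in place):

from corneferencer.inout import mmax
from corneferencer.resolvers import resolve

in_mmax = 'data/in/doc01.mmax'    # hypothetical input document
out_mmax = 'data/out/doc01.mmax'  # hypothetical output location

text = mmax.read(in_mmax)            # Text with Mention objects and their feature vectors
resolve.incremental(text)            # assign coreference sets with the neural pair model
mmax.write(in_mmax, out_mmax, text)  # copy .mmax/_words.xml, rewrite _mentions.xml groups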
corneferencer/readers/__init__.py deleted
corneferencer/entities/__init__.py renamed to corneferencer/resolvers/__init__.py
corneferencer/resolvers/constants.py
0 → 100644
+RESOLVERS = ['entity_based', 'incremental']
corneferencer/resolvers/features.py
0 → 100644
+import numpy
+import random
+
+from conf import RANDOM_WORD_VECTORS, W2V_MODEL, W2V_SIZE
+
+
+# mention features
+def head_vec(mention):
+    return list(get_wv(W2V_MODEL, mention.head_base))
+
+
+def first_word_vec(mention):
+    return list(get_wv(W2V_MODEL, mention.words[0]['base']))
+
+
+def last_word_vec(mention):
+    return list(get_wv(W2V_MODEL, mention.words[-1]['base']))
+
+
+def first_after_vec(mention):
+    if len(mention.follow_context) > 0:
+        vec = list(get_wv(W2V_MODEL, mention.follow_context[0]['base']))
+    else:
+        vec = [0.0] * W2V_SIZE
+    return vec
+
+
+def second_after_vec(mention):
+    if len(mention.follow_context) > 1:
+        vec = list(get_wv(W2V_MODEL, mention.follow_context[1]['base']))
+    else:
+        vec = [0.0] * W2V_SIZE
+    return vec
+
+
+def first_before_vec(mention):
+    if len(mention.prec_context) > 0:
+        vec = list(get_wv(W2V_MODEL, mention.prec_context[-1]['base']))
+    else:
+        vec = [0.0] * W2V_SIZE
+    return vec
+
+
+def second_before_vec(mention):
+    if len(mention.prec_context) > 1:
+        vec = list(get_wv(W2V_MODEL, mention.prec_context[-2]['base']))
+    else:
+        vec = [0.0] * W2V_SIZE
+    return vec
+
+
+def preceding_context_vec(mention):
+    return list(get_context_vec(mention.prec_context, W2V_MODEL))
+
+
+def following_context_vec(mention):
+    return list(get_context_vec(mention.follow_context, W2V_MODEL))
+
+
+def mention_vec(mention):
+    return list(get_context_vec(mention.words, W2V_MODEL))
+
+
+def sentence_vec(mention):
+    return list(get_context_vec(mention.sentence, W2V_MODEL))
+
+
+# pair features
+def distances_vec(ante, ana):
+    vec = []
+
+    mnts_intersect = pair_intersect(ante, ana)
+
+    words_dist = [0] * 11
+    words_bucket = 0
+    if mnts_intersect != 1:
+        words_bucket = get_distance_bucket(ana.start_in_words - ante.end_in_words - 1)
+    words_dist[words_bucket] = 1
+    vec.extend(words_dist)
+
+    mentions_dist = [0] * 11
+    mentions_bucket = 0
+    if mnts_intersect != 1:
+        mentions_bucket = get_distance_bucket(ana.position_in_mentions - ante.position_in_mentions - 1)
+    if words_bucket == 10:
+        mentions_bucket = 10
+    mentions_dist[mentions_bucket] = 1
+    vec.extend(mentions_dist)
+
+    vec.append(mnts_intersect)
+
+    return vec
+
+
+def pair_intersect(ante, ana):
+    for ante_word in ante.words:
+        for ana_word in ana.words:
+            if ana_word['id'] == ante_word['id']:
+                return 1
+    return 0
+
+
+def head_match(ante, ana):
+    if ante.head_orth.lower() == ana.head_orth.lower():
+        return 1
+    return 0
+
+
+def exact_match(ante, ana):
+    if ante.text.lower() == ana.text.lower():
+        return 1
+    return 0
+
+
+def base_match(ante, ana):
+    if ante.lemmatized_text.lower() == ana.lemmatized_text.lower():
+        return 1
+    return 0
+
+
+# supporting functions
+def get_wv(model, lemma, use_random_vec=True):
+    vec = None
+    if use_random_vec:
+        vec = random_vec()
+    try:
+        vec = model.wv[lemma]
+    except KeyError:
+        pass
+    except TypeError:
+        pass
+    return vec
+
+
+def random_vec():
+    return numpy.asarray([random.uniform(-0.25, 0.25) for i in range(0, W2V_SIZE)], dtype=numpy.float32)
+
+
+def get_context_vec(words, model):
+    vec = numpy.zeros(W2V_SIZE, dtype=numpy.float32)
+    unknown_count = 0
+    if len(words) != 0:
+        for word in words:
+            word_vec = get_wv(model, word['base'], RANDOM_WORD_VECTORS)
+            if word_vec is None:
+                unknown_count += 1
+            else:
+                vec += word_vec
+        significant_words = len(words) - unknown_count
+        if significant_words != 0:
+            vec = vec / float(significant_words)
+        else:
+            vec = random_vec()
+    return vec
+
+
+def get_distance_bucket(distance):
+    if 0 <= distance <= 4:
+        return distance
+    elif 5 <= distance <= 7:
+        return 5
+    elif 8 <= distance <= 15:
+        return 6
+    elif 16 <= distance <= 31:
+        return 7
+    elif 32 <= distance <= 63:
+        return 8
+    elif distance >= 64:
+        return 9
+    return 10
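The two 11-slot one-hot distance features above correspond to the buckets returned by get_distance_bucket: distances 0-4 map to themselves, the ranges 5-7, 8-15, 16-31, 32-63, and 64+ each get one bucket, and bucket 10 is the fall-through for negative distances (overlapping or out-of-order mentions). A few spot checks, assuming the models referenced by conf.py are available so the module imports:

from corneferencer.resolvers.features import get_distance_bucket

assert get_distance_bucket(3) == 3      # 0-4: identity
assert get_distance_bucket(6) == 5      # 5-7
assert get_distance_bucket(12) == 6     # 8-15
assert get_distance_bucket(40) == 8     # 32-63
assert get_distance_bucket(200) == 9    # 64 and above
assert get_distance_bucket(-1) == 10    # fall-through bucket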
corneferencer/resolvers/resolve.py
0 → 100644
+from conf import NEURAL_MODEL, THRESHOLD
+from corneferencer.resolvers.vectors import create_pair_vector
+
+
+# incremental resolve algorithm
+def incremental(text):
+    last_set_id = 1
+    for i, ana in enumerate(text.mentions):
+        if i > 0:
+            best_prediction = 0.0
+            best_ante = None
+            # assumed fix: scan the mentions preceding the anaphor (nearest first);
+            # the original slice [:i:-1] walked the following mentions instead
+            for ante in text.mentions[i-1::-1]:
+                pair_vec = create_pair_vector(ante, ana)
+                prediction = NEURAL_MODEL.predict(pair_vec)
+                accuracy = prediction[0]
+                if accuracy > THRESHOLD and accuracy > best_prediction:
+                    best_prediction = accuracy
+                    best_ante = ante
+            if best_ante is not None:
+                if best_ante.set:
+                    ana.set = best_ante.set
+                else:
+                    str_set_id = 'set_%d' % last_set_id
+                    best_ante.set = str_set_id
+                    ana.set = str_set_id
+                    last_set_id += 1
+
+
+# entity based resolve algorithm
+def entity_based(text):
+    sets = []
+    last_set_id = 1
+    for i, ana in enumerate(text.mentions):
+        if i > 0:
+            best_fit = get_best_set(sets, ana)
+            if best_fit is not None:
+                ana.set = best_fit['set_id']
+                best_fit['mentions'].append(ana)
+            else:
+                str_set_id = 'set_%d' % last_set_id
+                sets.append({'set_id': str_set_id,
+                             'mentions': [ana]})
+                ana.set = str_set_id
+                last_set_id += 1
+        else:
+            str_set_id = 'set_%d' % last_set_id
+            sets.append({'set_id': str_set_id,
+                         'mentions': [ana]})
+            ana.set = str_set_id
+            last_set_id += 1
+
+    remove_singletons(sets)
+
+
+def get_best_set(sets, ana):
+    best_prediction = 0.0
+    best_set = None
+    for s in sets:
+        accuracy = predict_set(s['mentions'], ana)
+        if accuracy > THRESHOLD and accuracy >= best_prediction:
+            best_prediction = accuracy
+            best_set = s
+    return best_set
+
+
+def predict_set(mentions, ana):
+    accuracy_sum = 0.0
+    for mnt in mentions:
+        pair_vec = create_pair_vector(mnt, ana)
+        prediction = NEURAL_MODEL.predict(pair_vec)
+        accuracy = prediction[0]
+        accuracy_sum += accuracy
+    return accuracy_sum / float(len(mentions))
+
+
+def remove_singletons(sets):
+    for s in sets:
+        if len(s['mentions']) == 1:
+            s['mentions'][0].set = ''
corneferencer/resolvers/vectors.py
0 → 100644
+import numpy
+
+from corneferencer.resolvers import features
+
+
+# the Keras model expects input of shape (None, NUMBER_OF_FEATURES),
+# so the pair vector is returned as a 2D array of shape (1, n_features)
+def create_pair_vector(ante, ana):
+    vec = []
+    # ante_features = get_mention_features(ante)
+    # vec.extend(ante_features)
+    # ana_features = get_mention_features(ana)
+    # vec.extend(ana_features)
+    vec.extend(ante.features)
+    vec.extend(ana.features)
+    pair_features = get_pair_features(ante, ana)
+    vec.extend(pair_features)
+    return numpy.asarray([vec], dtype=numpy.float32)
+
+
+def get_mention_features(mention):
+    vec = []
+    vec.extend(features.head_vec(mention))
+    vec.extend(features.first_word_vec(mention))
+    vec.extend(features.last_word_vec(mention))
+    vec.extend(features.first_after_vec(mention))
+    vec.extend(features.second_after_vec(mention))
+    vec.extend(features.first_before_vec(mention))
+    vec.extend(features.second_before_vec(mention))
+    vec.extend(features.preceding_context_vec(mention))
+    vec.extend(features.following_context_vec(mention))
+    vec.extend(features.mention_vec(mention))
+    vec.extend(features.sentence_vec(mention))
+    return vec
+
+
+def get_pair_features(ante, ana):
+    vec = []
+    vec.extend(features.distances_vec(ante, ana))
+    vec.append(features.head_match(ante, ana))
+    vec.append(features.exact_match(ante, ana))
+    vec.append(features.base_match(ante, ana))
+    return vec
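The NUMBER_OF_FEATURES = 1126 in conf.py follows directly from the vectors assembled here: each of the two mentions contributes eleven W2V_SIZE-dimensional blocks, and the pair features add two 11-slot distance vectors plus four binary flags. A quick check of that arithmetic:

W2V_SIZE = 50                          # from conf.py
mention_features = 11 * W2V_SIZE       # head, first, last, 2 before, 2 after,
                                       # preceding/following context, mention, sentence vectors
pair_features = 11 + 11 + 1 + 3        # word/mention distance buckets, intersection flag,
                                       # head/exact/base match
assert 2 * mention_features + pair_features == 1126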
corneferencer/utils.py
0 → 100644
+from __future__ import print_function
+
+import sys
+
+from keras.models import Model
+from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization
+
+
+def eprint(*args, **kwargs):
+    print(*args, file=sys.stderr, **kwargs)
+
+
+def initialize_neural_model(number_of_features):
+    inputs = Input(shape=(number_of_features,))
+    output_from_1st_layer = Dense(1000, activation='relu')(inputs)
+    output_from_1st_layer = Dropout(0.5)(output_from_1st_layer)
+    output_from_1st_layer = BatchNormalization()(output_from_1st_layer)
+    output_from_2nd_layer = Dense(500, activation='relu')(output_from_1st_layer)
+    output_from_2nd_layer = Dropout(0.5)(output_from_2nd_layer)
+    output_from_2nd_layer = BatchNormalization()(output_from_2nd_layer)
+    output = Dense(1, activation='sigmoid')(output_from_2nd_layer)
+
+    model = Model(inputs, output)
+    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
+    return model
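The network takes one flat pair vector and emits a single sigmoid score, which resolve.py compares against THRESHOLD. A hedged smoke test with random dummy input (assumes keras and numpy are installed; with untrained weights the score itself is meaningless):

import numpy
from corneferencer.utils import initialize_neural_model

model = initialize_neural_model(1126)          # NUMBER_OF_FEATURES from conf.py
dummy_pair = numpy.random.rand(1, 1126).astype(numpy.float32)
score = model.predict(dummy_pair)              # shape (1, 1), value in [0, 1]
print(score[0][0])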
requirements.txt
setup.py deleted