Commit 4a097c4f123cc34ee07d870361bad7a817a3faeb
1 parent a5beabe2
Corneferencer alpha version
Showing 15 changed files with 782 additions and 0 deletions
conf.py
0 → 100644
import os

from gensim.models.word2vec import Word2Vec

from corneferencer.utils import initialize_neural_model


CONTEXT = 5
THRESHOLD = 0.5
RANDOM_WORD_VECTORS = True
W2V_SIZE = 50
W2V_MODEL_NAME = 'w2v_allwiki_nkjpfull_50.model'

NUMBER_OF_FEATURES = 1126
NEURAL_MODEL_NAME = 'weights_2017_05_10.h5'


# do not change this
W2V_MODEL_PATH = os.path.join(os.path.dirname(__file__), 'models', W2V_MODEL_NAME)
W2V_MODEL = Word2Vec.load(W2V_MODEL_PATH)

NEURAL_MODEL_PATH = os.path.join(os.path.dirname(__file__), 'models', NEURAL_MODEL_NAME)
NEURAL_MODEL = initialize_neural_model(NUMBER_OF_FEATURES)
...
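One thing worth flagging in this config: NEURAL_MODEL_PATH points at the trained weights file (weights_2017_05_10.h5) but is never used, so NEURAL_MODEL keeps the randomly initialized weights produced by initialize_neural_model. Presumably the trained weights are meant to be loaded right after the model is built; a minimal sketch of the assumed missing step, using the standard Keras load_weights call:

    NEURAL_MODEL = initialize_neural_model(NUMBER_OF_FEATURES)
    NEURAL_MODEL.load_weights(NEURAL_MODEL_PATH)  # assumed missing step, not present in this commit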
corneferencer/entities.py
0 → 100644
from corneferencer.resolvers.vectors import get_mention_features


class Text:

    def __init__(self, text_id):
        self.__id = text_id
        self.mentions = []

    def get_mention_set(self, mnt_id):
        for mnt in self.mentions:
            if mnt.id == mnt_id:
                return mnt.set
        return None


class Mention:

    def __init__(self, mnt_id, text, lemmatized_text, words, span,
                 head_orth, head_base, dominant, node, prec_context,
                 follow_context, sentence, position_in_mentions,
                 start_in_words, end_in_words):
        self.id = mnt_id
        self.set = ''
        self.text = text
        self.lemmatized_text = lemmatized_text
        self.words = words
        self.span = span
        self.head_orth = head_orth
        self.head_base = head_base
        self.dominant = dominant
        self.node = node
        self.prec_context = prec_context
        self.follow_context = follow_context
        self.sentence = sentence
        self.position_in_mentions = position_in_mentions
        self.start_in_words = start_in_words
        self.end_in_words = end_in_words
        self.features = get_mention_features(self)
...
corneferencer/core.py renamed to corneferencer/inout/__init__.py
corneferencer/inout/constants.py
0 → 100644
INPUT_FORMATS = ['mmax']
...
corneferencer/inout/mmax.py
0 → 100644
import os
import shutil

from lxml import etree

from conf import CONTEXT
from corneferencer.entities import Mention, Text


def read(inpath):
    textname = os.path.splitext(os.path.basename(inpath))[0]
    textdir = os.path.dirname(inpath)

    mentions_path = os.path.join(textdir, '%s_mentions.xml' % textname)
    words_path = os.path.join(textdir, '%s_words.xml' % textname)

    text = Text(textname)
    mentions = read_mentions(mentions_path, words_path)
    text.mentions = mentions
    return text


def read_mentions(mentions_path, words_path):
    mentions = []
    mentions_tree = etree.parse(mentions_path)
    markables = mentions_tree.xpath("//ns:markable",
                                    namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
    words = get_words(words_path)

    for idx, markable in enumerate(markables):
        span = markable.attrib['span']

        dominant = ''
        if 'dominant' in markable.attrib:
            dominant = markable.attrib['dominant']

        head_orth = markable.attrib['mention_head']
        mention_words = span_to_words(span, words)

        (prec_context, follow_context, sentence,
         mnt_start_position, mnt_end_position) = get_context(mention_words, words)

        head_base = get_head_base(head_orth, mention_words)
        mention = Mention(mnt_id=markable.attrib['id'],
                          text=span_to_text(span, words, 'orth'),
                          lemmatized_text=span_to_text(span, words, 'base'),
                          words=mention_words,
                          span=span,
                          head_orth=head_orth,
                          head_base=head_base,
                          dominant=dominant,
                          node=markable,
                          prec_context=prec_context,
                          follow_context=follow_context,
                          sentence=sentence,
                          position_in_mentions=idx,
                          start_in_words=mnt_start_position,
                          end_in_words=mnt_end_position)
        mentions.append(mention)

    return mentions


def get_words(filepath):
    tree = etree.parse(filepath)
    words = []
    for word in tree.xpath("//word"):
        hasnps = False
        if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
            hasnps = True
        lastinsent = False
        if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
            lastinsent = True
        words.append({'id': word.attrib['id'],
                      'orth': word.text,
                      'base': word.attrib['base'],
                      'hasnps': hasnps,
                      'lastinsent': lastinsent,
                      'ctag': word.attrib['ctag']})
    return words


def span_to_words(span, words):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.extend(fragment_to_words(fragment, words))
    return mention_parts


def fragment_to_words(fragment, words):
    mention_parts = []
    if '..' in fragment:
        mention_parts.extend(get_multiword(fragment, words))
    else:
        mention_parts.extend(get_word(fragment, words))
    return mention_parts


def get_multiword(fragment, words):
    mention_parts = []
    boundaries = fragment.split('..')
    start_id = boundaries[0]
    end_id = boundaries[1]
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return mention_parts


def get_word(word_id, words):
    for word in words:
        if word['id'] == word_id:
            if not word_to_ignore(word):
                return [word]
            else:
                return []
    return []


def word_to_ignore(word):
    if word['ctag'] == 'interp':
        return True
    return False


def get_context(mention_words, words):
    prec_context = []
    follow_context = []
    sentence = []
    mnt_start_position = -1
    mnt_end_position = -1
    first_word = mention_words[0]
    last_word = mention_words[-1]
    for idx, word in enumerate(words):
        if word['id'] == first_word['id']:
            prec_context = get_prec_context(idx, words)
            mnt_start_position = get_mention_start(first_word, words)
        if word['id'] == last_word['id']:
            follow_context = get_follow_context(idx, words)
            sentence = get_sentence(idx, words)
            mnt_end_position = get_mention_end(last_word, words)
            break
    return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position


def get_prec_context(mention_start, words):
    context = []
    context_start = mention_start - 1
    while context_start >= 0:
        if not word_to_ignore(words[context_start]):
            context.append(words[context_start])
        if len(context) == CONTEXT:
            break
        context_start -= 1
    context.reverse()
    return context


def get_mention_start(first_word, words):
    start = 0
    for word in words:
        if not word_to_ignore(word):
            start += 1
        if word['id'] == first_word['id']:
            break
    return start


def get_mention_end(last_word, words):
    end = 0
    for word in words:
        if not word_to_ignore(word):
            end += 1
        if word['id'] == last_word['id']:
            break
    return end


def get_follow_context(mention_end, words):
    context = []
    context_end = mention_end + 1
    while context_end < len(words):
        if not word_to_ignore(words[context_end]):
            context.append(words[context_end])
        if len(context) == CONTEXT:
            break
        context_end += 1
    return context


def get_sentence(word_idx, words):
    sentence_start = get_sentence_start(words, word_idx)
    sentence_end = get_sentence_end(words, word_idx)
    sentence = [word for word in words[sentence_start:sentence_end + 1] if not word_to_ignore(word)]
    return sentence


def get_sentence_start(words, word_idx):
    search_start = word_idx
    while word_idx >= 0:
        if words[word_idx]['lastinsent'] and search_start != word_idx:
            return word_idx + 1
        word_idx -= 1
    return 0


def get_sentence_end(words, word_idx):
    while word_idx < len(words):
        if words[word_idx]['lastinsent']:
            return word_idx
        word_idx += 1
    return len(words) - 1


def get_head_base(head_orth, words):
    for word in words:
        if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
            return word['base']
    return None


def span_to_text(span, words, form):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.append(fragment_to_text(fragment, words, form))
    return u' [...] '.join(mention_parts)


def fragment_to_text(fragment, words, form):
    if '..' in fragment:
        text = get_multiword_text(fragment, words, form)
    else:
        text = get_one_word_text(fragment, words, form)
    return text


def get_multiword_text(fragment, words, form):
    mention_parts = []
    boundaries = fragment.split('..')
    start_id = boundaries[0]
    end_id = boundaries[1]
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return to_text(mention_parts, form)


def to_text(words, form):
    text = ''
    for idx, word in enumerate(words):
        if word['hasnps'] or idx == 0:
            text += word[form]
        else:
            text += u' %s' % word[form]
    return text


def get_one_word_text(word_id, words, form):
    this_word = next(word for word in words if word['id'] == word_id)
    return this_word[form]


def write(inpath, outpath, text):
    textname = os.path.splitext(os.path.basename(inpath))[0]
    intextdir = os.path.dirname(inpath)
    outtextdir = os.path.dirname(outpath)

    in_mmax_path = os.path.join(intextdir, '%s.mmax' % textname)
    out_mmax_path = os.path.join(outtextdir, '%s.mmax' % textname)
    copy_mmax(in_mmax_path, out_mmax_path)

    in_words_path = os.path.join(intextdir, '%s_words.xml' % textname)
    out_words_path = os.path.join(outtextdir, '%s_words.xml' % textname)
    copy_words(in_words_path, out_words_path)

    in_mentions_path = os.path.join(intextdir, '%s_mentions.xml' % textname)
    out_mentions_path = os.path.join(outtextdir, '%s_mentions.xml' % textname)
    write_mentions(in_mentions_path, out_mentions_path, text)


def copy_mmax(src, dest):
    shutil.copyfile(src, dest)


def copy_words(src, dest):
    shutil.copyfile(src, dest)


def write_mentions(inpath, outpath, text):
    tree = etree.parse(inpath)
    mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})

    for mnt in mentions:
        mnt_set = text.get_mention_set(mnt.attrib['id'])
        if mnt_set:
            mnt.attrib['mention_group'] = mnt_set
        else:
            mnt.attrib['mention_group'] = 'empty'

    with open(outpath, 'wb') as output_file:
        output_file.write(etree.tostring(tree, pretty_print=True,
                                         xml_declaration=True, encoding='UTF-8',
                                         doctype=u'<!DOCTYPE markables SYSTEM "markables.dtd">'))
...
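For reference, a markable's span attribute is a comma-separated list of fragments, each fragment being either a single word id or a start..end range of ids; punctuation (ctag 'interp') is skipped, and non-adjacent fragments are joined with ' [...] ' when rendered as text. A minimal sketch with a made-up word list and ids (not taken from any corpus):

    words = [
        {'id': 'word_1', 'orth': 'Ala', 'base': 'Ala', 'hasnps': False, 'lastinsent': False, 'ctag': 'subst'},
        {'id': 'word_2', 'orth': 'ma', 'base': 'mieć', 'hasnps': False, 'lastinsent': False, 'ctag': 'fin'},
        {'id': 'word_3', 'orth': 'kota', 'base': 'kot', 'hasnps': False, 'lastinsent': False, 'ctag': 'subst'},
        {'id': 'word_4', 'orth': '.', 'base': '.', 'hasnps': True, 'lastinsent': True, 'ctag': 'interp'},
    ]
    span_to_words('word_1..word_2,word_3', words)         # the three non-punctuation word dicts
    span_to_text('word_1..word_2,word_3', words, 'orth')  # 'Ala ma [...] kota'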
corneferencer/main.py
0 → 100644
import os
import sys

from argparse import ArgumentParser
from natsort import natsorted

sys.path.append(os.path.abspath(os.path.join('..')))

from inout import mmax
from inout.constants import INPUT_FORMATS
from resolvers import resolve
from resolvers.constants import RESOLVERS
from utils import eprint


def main():
    args = parse_arguments()
    if not args.input:
        eprint("Error: Input file(s) not specified!")
    elif args.resolver not in RESOLVERS:
        eprint("Error: Unknown resolve algorithm!")
    elif args.format not in INPUT_FORMATS:
        eprint("Error: Unknown input file format!")
    else:
        process_texts(args.input, args.output, args.format, args.resolver)


def parse_arguments():
    parser = ArgumentParser(description='Corneferencer: coreference resolver using neural nets.')
    parser.add_argument('-i', '--input', type=str, action='store',
                        dest='input', default='',
                        help='input file or dir path')
    parser.add_argument('-o', '--output', type=str, action='store',
                        dest='output', default='',
                        help='output path; if not specified writes output to standard output')
    parser.add_argument('-f', '--format', type=str, action='store',
                        dest='format', default='mmax',
                        help='input format; default: mmax')
    parser.add_argument('-r', '--resolver', type=str, action='store',
                        dest='resolver', default='incremental',
                        help='resolve algorithm; default: incremental; possibilities: %s'
                             % ', '.join(RESOLVERS))

    args = parser.parse_args()
    return args


def process_texts(inpath, outpath, informat, resolver):
    if os.path.isdir(inpath):
        process_directory(inpath, outpath, informat, resolver)
    elif os.path.isfile(inpath):
        process_file(inpath, outpath, informat, resolver)
    else:
        eprint("Error: Specified input does not exist!")


def process_directory(inpath, outpath, informat, resolver):
    inpath = os.path.abspath(inpath)
    outpath = os.path.abspath(outpath)

    files = os.listdir(inpath)
    files = natsorted(files)

    for filename in files:
        textname = os.path.splitext(os.path.basename(filename))[0]
        textoutput = os.path.join(outpath, textname)
        textinput = os.path.join(inpath, filename)
        process_file(textinput, textoutput, informat, resolver)


def process_file(inpath, outpath, informat, resolver):
    basename = os.path.basename(inpath)
    if informat == 'mmax' and basename.endswith('.mmax'):
        print(basename)
        text = mmax.read(inpath)
        if resolver == 'incremental':
            resolve.incremental(text)
        elif resolver == 'entity_based':
            resolve.entity_based(text)
        mmax.write(inpath, outpath, text)


if __name__ == '__main__':
    main()
...
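Given the argparse setup above, a typical invocation (paths are placeholders) would be:

    python corneferencer/main.py -i path/to/mmax/corpus -o path/to/output -f mmax -r incremental

With a directory as input, each *.mmax file is processed and written under the output directory; note that, despite the -o help text, the current code always writes output files rather than to standard output.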
corneferencer/readers/__init__.py deleted
corneferencer/entities/__init__.py renamed to corneferencer/resolvers/__init__.py
corneferencer/resolvers/constants.py
0 → 100644
RESOLVERS = ['entity_based', 'incremental']
...
corneferencer/resolvers/features.py
0 → 100644
import numpy
import random

from conf import RANDOM_WORD_VECTORS, W2V_MODEL, W2V_SIZE


# mention features
def head_vec(mention):
    return list(get_wv(W2V_MODEL, mention.head_base))


def first_word_vec(mention):
    return list(get_wv(W2V_MODEL, mention.words[0]['base']))


def last_word_vec(mention):
    return list(get_wv(W2V_MODEL, mention.words[-1]['base']))


def first_after_vec(mention):
    if len(mention.follow_context) > 0:
        vec = list(get_wv(W2V_MODEL, mention.follow_context[0]['base']))
    else:
        vec = [0.0] * W2V_SIZE
    return vec


def second_after_vec(mention):
    if len(mention.follow_context) > 1:
        vec = list(get_wv(W2V_MODEL, mention.follow_context[1]['base']))
    else:
        vec = [0.0] * W2V_SIZE
    return vec


def first_before_vec(mention):
    if len(mention.prec_context) > 0:
        vec = list(get_wv(W2V_MODEL, mention.prec_context[-1]['base']))
    else:
        vec = [0.0] * W2V_SIZE
    return vec


def second_before_vec(mention):
    if len(mention.prec_context) > 1:
        vec = list(get_wv(W2V_MODEL, mention.prec_context[-2]['base']))
    else:
        vec = [0.0] * W2V_SIZE
    return vec


def preceding_context_vec(mention):
    return list(get_context_vec(mention.prec_context, W2V_MODEL))


def following_context_vec(mention):
    return list(get_context_vec(mention.follow_context, W2V_MODEL))


def mention_vec(mention):
    return list(get_context_vec(mention.words, W2V_MODEL))


def sentence_vec(mention):
    return list(get_context_vec(mention.sentence, W2V_MODEL))


# pair features
def distances_vec(ante, ana):
    vec = []

    mnts_intersect = pair_intersect(ante, ana)

    words_dist = [0] * 11
    words_bucket = 0
    if mnts_intersect != 1:
        words_bucket = get_distance_bucket(ana.start_in_words - ante.end_in_words - 1)
        words_dist[words_bucket] = 1
    vec.extend(words_dist)

    mentions_dist = [0] * 11
    mentions_bucket = 0
    if mnts_intersect != 1:
        mentions_bucket = get_distance_bucket(ana.position_in_mentions - ante.position_in_mentions - 1)
        if words_bucket == 10:
            mentions_bucket = 10
        mentions_dist[mentions_bucket] = 1
    vec.extend(mentions_dist)

    vec.append(mnts_intersect)

    return vec


def pair_intersect(ante, ana):
    for ante_word in ante.words:
        for ana_word in ana.words:
            if ana_word['id'] == ante_word['id']:
                return 1
    return 0


def head_match(ante, ana):
    if ante.head_orth.lower() == ana.head_orth.lower():
        return 1
    return 0


def exact_match(ante, ana):
    if ante.text.lower() == ana.text.lower():
        return 1
    return 0


def base_match(ante, ana):
    if ante.lemmatized_text.lower() == ana.lemmatized_text.lower():
        return 1
    return 0


# supporting functions
def get_wv(model, lemma, use_random_vec=True):
    vec = None
    if use_random_vec:
        vec = random_vec()
    try:
        vec = model.wv[lemma]
    except KeyError:
        pass
    except TypeError:
        pass
    return vec


def random_vec():
    return numpy.asarray([random.uniform(-0.25, 0.25) for i in range(0, W2V_SIZE)], dtype=numpy.float32)


def get_context_vec(words, model):
    vec = numpy.zeros(W2V_SIZE, dtype=numpy.float32)
    unknown_count = 0
    if len(words) != 0:
        for word in words:
            word_vec = get_wv(model, word['base'], RANDOM_WORD_VECTORS)
            if word_vec is None:
                unknown_count += 1
            else:
                vec += word_vec
        significant_words = len(words) - unknown_count
        if significant_words != 0:
            vec = vec / float(significant_words)
        else:
            vec = random_vec()
    return vec


def get_distance_bucket(distance):
    if 0 <= distance <= 4:
        return distance
    elif 5 <= distance <= 7:
        return 5
    elif 8 <= distance <= 15:
        return 6
    elif 16 <= distance <= 31:
        return 7
    elif 32 <= distance <= 63:
        return 8
    elif distance >= 64:
        return 9
    return 10
...
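get_distance_bucket turns a raw distance (in words or in mentions) into one of 11 one-hot positions: distances 0-4 keep their own bucket, the ranges 5-7, 8-15, 16-31, 32-63 and 64+ map to buckets 5-9, and anything else (e.g. a negative distance) falls into bucket 10. A few worked values:

    get_distance_bucket(3)    # 3  (0-4 keep their own bucket)
    get_distance_bucket(20)   # 7  (16-31 range)
    get_distance_bucket(100)  # 9  (64 and above)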
corneferencer/resolvers/resolve.py
0 → 100644
from conf import NEURAL_MODEL, THRESHOLD
from corneferencer.resolvers.vectors import create_pair_vector


# incremental resolve algorithm
def incremental(text):
    last_set_id = 1
    for i, ana in enumerate(text.mentions):
        if i > 0:
            best_prediction = 0.0
            best_ante = None
            # consider the preceding mentions as antecedent candidates, closest first
            for ante in text.mentions[i-1::-1]:
                pair_vec = create_pair_vector(ante, ana)
                prediction = NEURAL_MODEL.predict(pair_vec)
                accuracy = prediction[0]
                if accuracy > THRESHOLD and accuracy > best_prediction:
                    best_prediction = accuracy
                    best_ante = ante
            if best_ante is not None:
                if best_ante.set:
                    ana.set = best_ante.set
                else:
                    str_set_id = 'set_%d' % last_set_id
                    best_ante.set = str_set_id
                    ana.set = str_set_id
                    last_set_id += 1


# entity based resolve algorithm
def entity_based(text):
    sets = []
    last_set_id = 1
    for i, ana in enumerate(text.mentions):
        if i > 0:
            best_fit = get_best_set(sets, ana)
            if best_fit is not None:
                ana.set = best_fit['set_id']
                best_fit['mentions'].append(ana)
            else:
                str_set_id = 'set_%d' % last_set_id
                sets.append({'set_id': str_set_id,
                             'mentions': [ana]})
                ana.set = str_set_id
                last_set_id += 1
        else:
            str_set_id = 'set_%d' % last_set_id
            sets.append({'set_id': str_set_id,
                         'mentions': [ana]})
            ana.set = str_set_id
            last_set_id += 1

    remove_singletons(sets)


def get_best_set(sets, ana):
    best_prediction = 0.0
    best_set = None
    for s in sets:
        accuracy = predict_set(s['mentions'], ana)
        if accuracy > THRESHOLD and accuracy >= best_prediction:
            best_prediction = accuracy
            best_set = s
    return best_set


def predict_set(mentions, ana):
    accuracy_sum = 0.0
    for mnt in mentions:
        pair_vec = create_pair_vector(mnt, ana)
        prediction = NEURAL_MODEL.predict(pair_vec)
        accuracy = prediction[0]
        accuracy_sum += accuracy
    return accuracy_sum / float(len(mentions))


def remove_singletons(sets):
    for s in sets:
        if len(s['mentions']) == 1:
            s['mentions'][0].set = ''
...
corneferencer/resolvers/vectors.py
0 → 100644
import numpy

from corneferencer.resolvers import features


def create_pair_vector(ante, ana):
    # the network expects input of shape (None, 1126), so wrap the flat
    # feature list in an outer list before converting to a numpy array
    vec = []
    vec.extend(ante.features)
    vec.extend(ana.features)
    pair_features = get_pair_features(ante, ana)
    vec.extend(pair_features)
    return numpy.asarray([vec], dtype=numpy.float32)


def get_mention_features(mention):
    vec = []
    vec.extend(features.head_vec(mention))
    vec.extend(features.first_word_vec(mention))
    vec.extend(features.last_word_vec(mention))
    vec.extend(features.first_after_vec(mention))
    vec.extend(features.second_after_vec(mention))
    vec.extend(features.first_before_vec(mention))
    vec.extend(features.second_before_vec(mention))
    vec.extend(features.preceding_context_vec(mention))
    vec.extend(features.following_context_vec(mention))
    vec.extend(features.mention_vec(mention))
    vec.extend(features.sentence_vec(mention))
    return vec


def get_pair_features(ante, ana):
    vec = []
    vec.extend(features.distances_vec(ante, ana))
    vec.append(features.head_match(ante, ana))
    vec.append(features.exact_match(ante, ana))
    vec.append(features.base_match(ante, ana))
    return vec
...
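A quick sanity check of NUMBER_OF_FEATURES = 1126 from conf.py: get_mention_features concatenates 11 word2vec-based vectors of W2V_SIZE = 50, i.e. 550 values per mention, and get_pair_features adds the two 11-way distance buckets, the intersection flag and the three match flags, i.e. 26 values, so a pair vector has 2 × 550 + 26 = 1126 components.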
corneferencer/utils.py
0 → 100644
from __future__ import print_function

import sys

from keras.models import Model
from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization


def eprint(*args, **kwargs):
    print(*args, file=sys.stderr, **kwargs)


def initialize_neural_model(number_of_features):
    inputs = Input(shape=(number_of_features,))
    output_from_1st_layer = Dense(1000, activation='relu')(inputs)
    output_from_1st_layer = Dropout(0.5)(output_from_1st_layer)
    output_from_1st_layer = BatchNormalization()(output_from_1st_layer)
    output_from_2nd_layer = Dense(500, activation='relu')(output_from_1st_layer)
    output_from_2nd_layer = Dropout(0.5)(output_from_2nd_layer)
    output_from_2nd_layer = BatchNormalization()(output_from_2nd_layer)
    output = Dense(1, activation='sigmoid')(output_from_2nd_layer)

    model = Model(inputs, output)
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model
...
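The pair classifier is a plain feed-forward binary classifier: 1126 inputs, Dense(1000, relu) with Dropout and BatchNormalization, Dense(500, relu) with Dropout and BatchNormalization, then Dense(1, sigmoid). A minimal shape-only smoke test (untrained weights, not the shipped weights_2017_05_10.h5):

    import numpy
    from corneferencer.utils import initialize_neural_model

    model = initialize_neural_model(1126)
    prediction = model.predict(numpy.zeros((1, 1126), dtype=numpy.float32))  # shape (1, 1): a coreference probability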
requirements.txt
setup.py deleted