Commit c2871e0ded5bfb23380ab7c041d4237e1a0c8481
1 parent: 01a04337
Basic evaluation and data preparation scripts.
Showing 3 changed files with 790 additions and 0 deletions
counter.py 0 → 100644
# -*- coding: utf-8 -*-

import os

from lxml import etree
from natsort import natsorted

from preparator import ANNO_PATH


def count_words():
    # Count tokens per text, skipping punctuation (ctag == 'interp').
    anno_files = os.listdir(ANNO_PATH)
    anno_files = natsorted(anno_files)
    for filename in anno_files:
        if filename.endswith('.mmax'):
            words_count = 0
            textname = filename.replace('.mmax', '')
            words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
            tree = etree.parse(words_path)
            for word in tree.xpath("//word"):
                if word.attrib['ctag'] != 'interp':
                    words_count += 1
            print textname, words_count


def count_mentions():
    # Count markables on the mention level of each MMAX annotation.
    anno_files = os.listdir(ANNO_PATH)
    anno_files = natsorted(anno_files)
    for filename in anno_files:
        if filename.endswith('.mmax'):
            textname = filename.replace('.mmax', '')
            mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
            tree = etree.parse(mentions_path)
            mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
            print textname, len(mentions)
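Note: counter.py defines the two counters but has no entry point of its own, so running the file does nothing. A minimal driver (an illustration, not part of the commit) would be:

if __name__ == '__main__':
    count_words()
    count_mentions()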
preparator.py 0 → 100644
# -*- coding: utf-8 -*-

import codecs
import numpy
import os
import random

from lxml import etree
from itertools import combinations
from natsort import natsorted

from gensim.models.word2vec import Word2Vec


TEST_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'test-prepared'))
TRAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'train-prepared'))

ANNO_PATH = TEST_PATH
# A single CSV when EACH_TEXT_SEPARATELY is False; otherwise treated as a
# directory that receives one CSV per text.
OUT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
                                        'test.csv'))
EACH_TEXT_SEPARATELY = False

CONTEXT = 5  # number of preceding/following context words used as features
W2V_SIZE = 50
MODEL = os.path.abspath(os.path.join(os.path.dirname(__file__), 'models',
                                     '%d' % W2V_SIZE,
                                     'w2v_allwiki_nkjpfull_%d.model' % W2V_SIZE))
# Mention heads that mark markables to be skipped.
POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-']
NEG_PROPORTION = 1  # negative pairs sampled per positive pair
RANDOM_VECTORS = True

DEBUG = False
POS_COUNT = 0
NEG_COUNT = 0
ALL_WORDS = 0
UNKNOWN_WORDS = 0


def main():
    model = Word2Vec.load(MODEL)
    try:
        create_data_vectors(model)
    finally:
        print 'Unknown words: ', UNKNOWN_WORDS
        print 'All words: ', ALL_WORDS
        print 'Positives: ', POS_COUNT
        print 'Negatives: ', NEG_COUNT


def create_data_vectors(model):
    features_file = None
    if not EACH_TEXT_SEPARATELY:
        features_file = codecs.open(OUT_PATH, 'wt', 'utf-8')

    anno_files = os.listdir(ANNO_PATH)
    anno_files = natsorted(anno_files)
    for filename in anno_files:
        if filename.endswith('.mmax'):
            print '=======> ', filename
            textname = filename.replace('.mmax', '')

            mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
            tree = etree.parse(mentions_path)
            mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
            positives, negatives = diff_mentions(mentions)

            if DEBUG:
                print 'Positives:'
                print len(positives)

                print 'Negatives:'
                print len(negatives)

            words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
            mentions_dict = markables_level_2_dict(mentions_path, words_path)

            if EACH_TEXT_SEPARATELY:
                text_features_path = os.path.join(OUT_PATH, '%s.csv' % textname)
                features_file = codecs.open(text_features_path, 'wt', 'utf-8')
            write_features(features_file, positives, negatives, mentions_dict, model, textname)
            if EACH_TEXT_SEPARATELY:
                # Close each per-text file as soon as it is written.
                features_file.close()

    if not EACH_TEXT_SEPARATELY:
        features_file.close()


def diff_mentions(mentions):
    sets, clustered_mentions = get_sets(mentions)
    positives = get_positives(sets)
    positives, negatives = get_negatives_and_update_positives(clustered_mentions, positives)
    if len(negatives) != len(positives) and NEG_PROPORTION == 1:
        print u'Mismatched numbers of positive and negative examples!'
    return positives, negatives


def get_sets(mentions):
    sets = {}
    clustered_mentions = []
    for mention in mentions:
        set_id = mention.attrib['mention_group']
        if set_id == 'empty' or set_id == '' or mention.attrib['mention_head'] in POSSIBLE_HEADS:
            pass
        elif set_id not in sets:
            sets[set_id] = [mention.attrib['span']]
            clustered_mentions.append(mention.attrib['span'])
        elif set_id in sets:
            sets[set_id].append(mention.attrib['span'])
            clustered_mentions.append(mention.attrib['span'])
        else:
            print u'Something went wrong while collecting clusters!'

    # Singleton clusters yield no coreferent pairs, so drop them.
    sets_to_remove = []
    for set_id in sets:
        if len(sets[set_id]) < 2:
            sets_to_remove.append(set_id)
        if len(sets[set_id]) == 1:
            print u'Removing clustered mention: ', sets[set_id][0]
            clustered_mentions.remove(sets[set_id][0])

    for set_id in sets_to_remove:
        print u'Removing set: ', set_id
        sets.pop(set_id)

    return sets, clustered_mentions


def get_positives(sets):
    # Every unordered pair within a coreference cluster is a positive example.
    positives = []
    for set_id in sets:
        coref_set = sets[set_id]
        positives.extend(list(combinations(coref_set, 2)))
    return positives


def get_negatives_and_update_positives(clustered_mentions, positives):
    all_pairs = list(combinations(clustered_mentions, 2))
    all_pairs = set(all_pairs)
    negatives = [pair for pair in all_pairs if pair not in positives]
    samples_count = NEG_PROPORTION * len(positives)
    if samples_count > len(negatives):
        samples_count = len(negatives)
        if NEG_PROPORTION == 1:
            # Fewer negatives than positives: downsample the positives to match.
            positives = random.sample(set(positives), samples_count)
            print u'More positive cases than negative ones!'
    negatives = random.sample(set(negatives), samples_count)
    return positives, negatives
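For intuition about the pair construction above: a cluster of three mentions yields three positive pairs, and negatives are drawn from the remaining pairs over all clustered mentions, then sampled down. A tiny illustration with hypothetical spans:

from itertools import combinations

cluster = ['word_1', 'word_3..word_4', 'word_9']  # one coreference cluster
positives = list(combinations(cluster, 2))
# -> [('word_1', 'word_3..word_4'), ('word_1', 'word_9'),
#     ('word_3..word_4', 'word_9')]
# Negatives pair these spans with clustered mentions from other clusters,
# sampled down to NEG_PROPORTION * len(positives).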
def write_features(features_file, positives, negatives, mentions_dict, model, textname):
    # Each row is tab-separated: mention and pair features, then the label
    # (1 = coreferent, 0 = not) as the last column.
    global POS_COUNT
    POS_COUNT += len(positives)
    for pair in positives:
        pair_features = []
        if DEBUG:
            pair_features = ['%s>%s:%s' % (textname, pair[0], pair[1])]
        pair_features.extend(get_features(pair, mentions_dict, model))
        pair_features.append(1)
        features_file.write(u'%s\n' % u'\t'.join([unicode(feature) for feature in pair_features]))

    global NEG_COUNT
    NEG_COUNT += len(negatives)
    for pair in negatives:
        pair_features = []
        if DEBUG:
            pair_features = ['%s>%s:%s' % (textname, pair[0], pair[1])]
        pair_features.extend(get_features(pair, mentions_dict, model))
        pair_features.append(0)
        features_file.write(u'%s\n' % u'\t'.join([unicode(feature) for feature in pair_features]))


def get_features(pair, mentions_dict, model):
    features = []
    ante = pair[0]
    ana = pair[1]
    ante_features = get_mention_features(ante, mentions_dict, model)
    features.extend(ante_features)
    ana_features = get_mention_features(ana, mentions_dict, model)
    features.extend(ana_features)
    pair_features = get_pair_features(pair, mentions_dict)
    features.extend(pair_features)
    return features


def get_mention_features(mention_span, mentions_dict, model):
    # Eleven W2V_SIZE-dimensional blocks per mention: head, first word,
    # last word, two following words, two preceding words, and averaged
    # vectors for the preceding context, following context, the mention
    # itself and its sentence.
    features = []
    mention = get_mention_by_attr(mentions_dict, 'span', mention_span)

    if DEBUG:
        features.append(mention['head_base'])
    head_vec = get_wv(model, mention['head_base'])
    features.extend(list(head_vec))

    if DEBUG:
        features.append(mention['words'][0]['base'])
    first_vec = get_wv(model, mention['words'][0]['base'])
    features.extend(list(first_vec))

    if DEBUG:
        features.append(mention['words'][-1]['base'])
    last_vec = get_wv(model, mention['words'][-1]['base'])
    features.extend(list(last_vec))

    if len(mention['follow_context']) > 0:
        if DEBUG:
            features.append(mention['follow_context'][0]['base'])
        after_1_vec = get_wv(model, mention['follow_context'][0]['base'])
        features.extend(list(after_1_vec))
    else:
        if DEBUG:
            features.append('None')
        features.extend([0.0] * W2V_SIZE)
    if len(mention['follow_context']) > 1:
        if DEBUG:
            features.append(mention['follow_context'][1]['base'])
        after_2_vec = get_wv(model, mention['follow_context'][1]['base'])
        features.extend(list(after_2_vec))
    else:
        if DEBUG:
            features.append('None')
        features.extend([0.0] * W2V_SIZE)

    if len(mention['prec_context']) > 0:
        if DEBUG:
            features.append(mention['prec_context'][-1]['base'])
        prec_1_vec = get_wv(model, mention['prec_context'][-1]['base'])
        features.extend(list(prec_1_vec))
    else:
        if DEBUG:
            features.append('None')
        features.extend([0.0] * W2V_SIZE)
    if len(mention['prec_context']) > 1:
        if DEBUG:
            features.append(mention['prec_context'][-2]['base'])
        prec_2_vec = get_wv(model, mention['prec_context'][-2]['base'])
        features.extend(list(prec_2_vec))
    else:
        if DEBUG:
            features.append('None')
        features.extend([0.0] * W2V_SIZE)

    if DEBUG:
        features.append(u' '.join([word['orth'] for word in mention['prec_context']]))
    prec_vec = get_context_vec(mention['prec_context'], model)
    features.extend(list(prec_vec))

    if DEBUG:
        features.append(u' '.join([word['orth'] for word in mention['follow_context']]))
    follow_vec = get_context_vec(mention['follow_context'], model)
    features.extend(list(follow_vec))

    if DEBUG:
        features.append(u' '.join([word['orth'] for word in mention['words']]))
    mention_vec = get_context_vec(mention['words'], model)
    features.extend(list(mention_vec))

    if DEBUG:
        features.append(u' '.join([word['orth'] for word in mention['sentence']]))
    sentence_vec = get_context_vec(mention['sentence'], model)
    features.extend(list(sentence_vec))

    return features


def get_wv(model, lemma, use_random=True):
    # Look the lemma up in the word2vec model; for unknown lemmas fall back
    # to a random vector when use_random is set, otherwise return None.
    # (The parameter name use_random avoids shadowing the random module.)
    global ALL_WORDS
    global UNKNOWN_WORDS
    vec = None
    if use_random:
        vec = random_vec()
    ALL_WORDS += 1
    try:
        vec = model.wv[lemma]
    except KeyError:
        UNKNOWN_WORDS += 1
    return vec


def random_vec():
    return numpy.asarray([random.uniform(-0.25, 0.25) for i in range(0, W2V_SIZE)], dtype=numpy.float32)


def get_context_vec(words, model):
    # Average vector of the known words; an all-unknown context gets a
    # random vector, an empty context stays a zero vector.
    vec = numpy.zeros(W2V_SIZE, dtype=numpy.float32)
    unknown_count = 0
    if len(words) != 0:
        for word in words:
            word_vec = get_wv(model, word['base'], RANDOM_VECTORS)
            if word_vec is None:
                unknown_count += 1
            else:
                vec += word_vec
        significant_words = len(words) - unknown_count
        if significant_words != 0:
            vec = vec/float(significant_words)
        else:
            vec = random_vec()
    return vec


def get_pair_features(pair, mentions_dict):
    ante = get_mention_by_attr(mentions_dict, 'span', pair[0])
    ana = get_mention_by_attr(mentions_dict, 'span', pair[1])

    features = []
    mnts_intersect = pair_intersect(ante, ana)

    # One-hot bucketed distance between the mentions, counted in words.
    words_dist = [0] * 11
    words_bucket = 0
    if mnts_intersect != 1:
        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'] - 1)
    if DEBUG:
        features.append('Bucket %d' % words_bucket)
    words_dist[words_bucket] = 1
    features.extend(words_dist)

    # One-hot bucketed distance counted in mentions.
    mentions_dist = [0] * 11
    mentions_bucket = 0
    if mnts_intersect != 1:
        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'] - 1)
    if words_bucket == 10:
        mentions_bucket = 10
    if DEBUG:
        features.append('Bucket %d' % mentions_bucket)
    mentions_dist[mentions_bucket] = 1
    features.extend(mentions_dist)

    if DEBUG:
        features.append('Other features')
    features.append(mnts_intersect)
    features.append(head_match(ante, ana))
    features.append(exact_match(ante, ana))
    features.append(base_match(ante, ana))

    # Flag for long texts (more than 100 mentions).
    if len(mentions_dict) > 100:
        features.append(1)
    else:
        features.append(0)

    return features


def get_distance_bucket(distance):
    # Distances 0-4 map to buckets 0-4, then widening ranges map to 5-9;
    # bucket 10 flags an invalid (negative) distance.
    if distance >= 0 and distance <= 4:
        return distance
    elif distance >= 5 and distance <= 7:
        return 5
    elif distance >= 8 and distance <= 15:
        return 6
    elif distance >= 16 and distance <= 31:
        return 7
    elif distance >= 32 and distance <= 63:
        return 8
    elif distance >= 64:
        return 9
    else:
        print u'Something went wrong during bucketing!!'
        return 10
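A few sanity checks of the bucketing scheme (illustrative only, not part of the commit):

assert get_distance_bucket(0) == 0    # 0-4 map to themselves
assert get_distance_bucket(6) == 5    # 5-7
assert get_distance_bucket(12) == 6   # 8-15
assert get_distance_bucket(40) == 8   # 32-63
assert get_distance_bucket(500) == 9  # 64 and above
assert get_distance_bucket(-1) == 10  # invalid distance, flagged separately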
def pair_intersect(ante, ana):
    # 1 if the two mentions share at least one word, 0 otherwise.
    for ante_word in ante['words']:
        for ana_word in ana['words']:
            if ana_word['id'] == ante_word['id']:
                return 1
    return 0


def head_match(ante, ana):
    if ante['head_orth'].lower() == ana['head_orth'].lower():
        return 1
    return 0


def exact_match(ante, ana):
    if ante['text'].lower() == ana['text'].lower():
        return 1
    return 0


def base_match(ante, ana):
    if ante['lemmatized_text'].lower() == ana['lemmatized_text'].lower():
        return 1
    return 0


def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/NameSpaces/mention'):
    # Join the markable annotations with the words layer, producing one dict
    # per mention with its text, context, sentence and position information.
    markables_dicts = []
    markables_tree = etree.parse(markables_path)
    markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace})

    words = get_words(words_path)

    for idx, markable in enumerate(markables):
        span = markable.attrib['span']
        if not get_mention_by_attr(markables_dicts, 'span', span):

            dominant = ''
            if 'dominant' in markable.attrib:
                dominant = markable.attrib['dominant']

            head_orth = markable.attrib['mention_head']
            if head_orth not in POSSIBLE_HEADS:
                mention_words = span_to_words(span, words)

                prec_context, follow_context, sentence, mnt_start_position, mnt_end_position = get_context(mention_words, words)

                head_base = get_head_base(head_orth, mention_words)
                markables_dicts.append({'id': markable.attrib['id'],
                                        'set': markable.attrib['mention_group'],
                                        'text': span_to_text(span, words, 'orth'),
                                        'lemmatized_text': span_to_text(span, words, 'base'),
                                        'words': mention_words,
                                        'span': span,
                                        'head_orth': head_orth,
                                        'head_base': head_base,
                                        'dominant': dominant,
                                        'node': markable,
                                        'prec_context': prec_context,
                                        'follow_context': follow_context,
                                        'sentence': sentence,
                                        'position_in_mentions': idx,
                                        'start_in_words': mnt_start_position,
                                        'end_in_words': mnt_end_position})
        else:
            print 'Duplicate mention: %s' % span

    return markables_dicts


def get_context(mention_words, words):
    prec_context = []
    follow_context = []
    sentence = []
    mnt_start_position = -1
    mnt_end_position = -1
    first_word = mention_words[0]
    last_word = mention_words[-1]
    for idx, word in enumerate(words):
        if word['id'] == first_word['id']:
            prec_context = get_prec_context(idx, words)
            mnt_start_position = get_mention_start(first_word, words)
        if word['id'] == last_word['id']:
            follow_context = get_follow_context(idx, words)
            sentence = get_sentence(idx, words)
            mnt_end_position = get_mention_end(last_word, words)
            break
    return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position


def get_prec_context(mention_start, words):
    # Up to CONTEXT non-punctuation words preceding the mention.
    context = []
    context_start = mention_start - 1
    while context_start >= 0:
        if not word_to_ignore(words[context_start]):
            context.append(words[context_start])
        if len(context) == CONTEXT:
            break
        context_start -= 1
    context.reverse()
    return context


def get_mention_start(first_word, words):
    start = 0
    for word in words:
        if not word_to_ignore(word):
            start += 1
        if word['id'] == first_word['id']:
            break
    return start


def get_mention_end(last_word, words):
    end = 0
    for word in words:
        if not word_to_ignore(word):
            end += 1
        if word['id'] == last_word['id']:
            break
    return end


def get_follow_context(mention_end, words):
    # Up to CONTEXT non-punctuation words following the mention.
    context = []
    context_end = mention_end + 1
    while context_end < len(words):
        if not word_to_ignore(words[context_end]):
            context.append(words[context_end])
        if len(context) == CONTEXT:
            break
        context_end += 1
    return context


def get_sentence(word_idx, words):
    sentence_start = get_sentence_start(words, word_idx)
    sentence_end = get_sentence_end(words, word_idx)
    sentence = [word for word in words[sentence_start:sentence_end+1] if not word_to_ignore(word)]
    return sentence


def get_sentence_start(words, word_idx):
    search_start = word_idx
    while word_idx >= 0:
        if words[word_idx]['lastinsent'] and search_start != word_idx:
            return word_idx+1
        word_idx -= 1
    return 0


def get_sentence_end(words, word_idx):
    while word_idx < len(words):
        if words[word_idx]['lastinsent']:
            return word_idx
        word_idx += 1
    return len(words) - 1


def get_head_base(head_orth, words):
    for word in words:
        if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
            return word['base']
    return None


def get_words(filepath):
    # Read the words layer; 'hasnps' marks words not preceded by a space
    # (see to_text) and 'lastinsent' marks sentence-final words.
    tree = etree.parse(filepath)
    words = []
    for word in tree.xpath("//word"):
        hasnps = False
        if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
            hasnps = True
        lastinsent = False
        if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
            lastinsent = True
        words.append({'id': word.attrib['id'],
                      'orth': word.text,
                      'base': word.attrib['base'],
                      'hasnps': hasnps,
                      'lastinsent': lastinsent,
                      'ctag': word.attrib['ctag']})
    return words
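A quick way to sanity-check get_words is a throwaway XML file containing only the attributes the parser reads (a hypothetical two-word document, not corpus data):

import tempfile

xml = (u'<words>\n'
       u'<word id="word_1" base="kot" ctag="subst">Kot</word>\n'
       u'<word id="word_2" base="." ctag="interp" lastinsent="true">.</word>\n'
       u'</words>\n')
tmp = tempfile.NamedTemporaryFile(suffix='.xml', delete=False)
tmp.write(xml.encode('utf-8'))
tmp.close()
print get_words(tmp.name)
# -> two dicts; the second has ctag 'interp', so word_to_ignore will skip it
#    in all downstream processing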
def get_mention_by_attr(mentions, attr_name, value):
    for mention in mentions:
        if mention[attr_name] == value:
            return mention
    return None


def get_mention_index_by_attr(mentions, attr_name, value):
    for idx, mention in enumerate(mentions):
        if mention[attr_name] == value:
            return idx
    return None


def span_to_text(span, words, form):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.append(fragment_to_text(fragment, words, form))
    return u' [...] '.join(mention_parts)


def fragment_to_text(fragment, words, form):
    if '..' in fragment:
        text = get_multiword_text(fragment, words, form)
    else:
        text = get_one_word_text(fragment, words, form)
    return text


def get_multiword_text(fragment, words, form):
    mention_parts = []
    boundaries = fragment.split('..')
    start_id = boundaries[0]
    end_id = boundaries[1]
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return to_text(mention_parts, form)


def to_text(words, form):
    text = ''
    for idx, word in enumerate(words):
        if word['hasnps'] or idx == 0:
            text += word[form]
        else:
            text += u' %s' % word[form]
    return text


def get_one_word_text(word_id, words, form):
    this_word = (word for word in words if word['id'] == word_id).next()
    if word_to_ignore(this_word):
        print this_word
    return this_word[form]
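The MMAX span syntax handled above and below is a comma-separated list of fragments, each either a single word id or a range like 'word_3..word_5'. An illustrative call on hypothetical data:

words = [{'id': 'word_%d' % i, 'orth': 'w%d' % i, 'base': 'w%d' % i,
          'hasnps': False, 'lastinsent': False, 'ctag': 'subst'}
         for i in range(1, 9)]
print span_to_text('word_3..word_5,word_7', words, 'orth')
# -> u'w3 w4 w5 [...] w7' (discontinuous fragments joined with ' [...] ')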
def span_to_words(span, words):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.extend(fragment_to_words(fragment, words))
    return mention_parts


def fragment_to_words(fragment, words):
    mention_parts = []
    if '..' in fragment:
        mention_parts.extend(get_multiword(fragment, words))
    else:
        mention_parts.extend(get_word(fragment, words))
    return mention_parts


def get_multiword(fragment, words):
    mention_parts = []
    boundaries = fragment.split('..')
    start_id = boundaries[0]
    end_id = boundaries[1]
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return mention_parts


def get_word(word_id, words):
    for word in words:
        if word['id'] == word_id:
            if not word_to_ignore(word):
                return [word]
            else:
                return []
    return []


def word_to_ignore(word):
    if word['ctag'] == 'interp':
        return True
    return False


if __name__ == '__main__':
    main()
resolver.py 0 → 100644
# -*- coding: utf-8 -*-

import codecs
import os

import numpy as np

from natsort import natsorted

from keras.models import Model
from keras.layers import Input, Dense, Dropout, BatchNormalization

IN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
                                       'prepared_text_files'))
OUT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
                                        'metrics.csv'))

MODEL = os.path.abspath(os.path.join(os.path.dirname(__file__), 'weights_2017_05_10.h5'))


# Width of a feature row written by preparator.py; the label is the one
# extra column after these.
NUMBER_OF_FEATURES = 1126
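Nothing ties NUMBER_OF_FEATURES to what preparator.py actually wrote, and a mismatch silently shifts the label column. A cheap guard (an illustration, not part of the commit; ndmin=2 keeps single-row files two-dimensional):

data = np.loadtxt(open(text_data_path, 'rt'), delimiter='\t', ndmin=2)
assert data.shape[1] == NUMBER_OF_FEATURES + 1, 'unexpected column count: %d' % data.shape[1]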
def main():
    resolve_files()


def resolve_files():
    metrics_file = codecs.open(OUT_PATH, 'w', 'utf-8')
    write_labels(metrics_file)

    anno_files = os.listdir(IN_PATH)
    anno_files = natsorted(anno_files)
    for filename in anno_files:
        print filename
        textname = filename.replace('.csv', '')
        text_data_path = os.path.join(IN_PATH, filename)
        resolve(textname, text_data_path, metrics_file)

    metrics_file.close()


def write_labels(metrics_file):
    metrics_file.write('Text\tAccuracy\tPrecision\tRecall\tF1\tPairs\n')
def resolve(textname, text_data_path, metrics_file):
    raw_data = open(text_data_path, 'rt')
    test_data = np.loadtxt(raw_data, delimiter='\t')
    test_set = test_data[:, 0:NUMBER_OF_FEATURES]
    test_labels = test_data[:, NUMBER_OF_FEATURES]  # the last column holds the labels

    # The architecture must match the one the saved weights were trained
    # with: NUMBER_OF_FEATURES -> 1000 -> 500 -> 1, with dropout and batch
    # normalization after each hidden layer.
    inputs = Input(shape=(NUMBER_OF_FEATURES,))
    output_from_1st_layer = Dense(1000, activation='relu')(inputs)
    output_from_1st_layer = Dropout(0.5)(output_from_1st_layer)
    output_from_1st_layer = BatchNormalization()(output_from_1st_layer)
    output_from_2nd_layer = Dense(500, activation='relu')(output_from_1st_layer)
    output_from_2nd_layer = Dropout(0.5)(output_from_2nd_layer)
    output_from_2nd_layer = BatchNormalization()(output_from_2nd_layer)
    output = Dense(1, activation='sigmoid')(output_from_2nd_layer)

    model = Model(inputs, output)
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.load_weights(MODEL)

    predictions = model.predict(test_set)

    calc_metrics(textname, test_set, test_labels, predictions, metrics_file)
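resolve rebuilds the network and reloads the weights for every input file, which is wasteful. A possible refactor (a sketch, not part of the commit) builds it once in main() and passes it down:

def build_network():
    # Hypothetical helper: the same architecture as in resolve(), constructed
    # a single time and reused for all files.
    inputs = Input(shape=(NUMBER_OF_FEATURES,))
    x = Dense(1000, activation='relu')(inputs)
    x = Dropout(0.5)(x)
    x = BatchNormalization()(x)
    x = Dense(500, activation='relu')(x)
    x = Dropout(0.5)(x)
    x = BatchNormalization()(x)
    output = Dense(1, activation='sigmoid')(x)
    model = Model(inputs, output)
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.load_weights(MODEL)
    return model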
def calc_metrics(textname, test_set, test_labels, predictions, metrics_file):
    # Threshold predictions at 0.5 and count the confusion-matrix cells.
    # Note: this assumes each file contains both predicted and actual
    # positives; otherwise precision or recall divides by zero.
    true_positives = 0.0
    false_positives = 0.0
    true_negatives = 0.0
    false_negatives = 0.0

    for i in range(len(test_set)):
        if predictions[i] < 0.5 and test_labels[i] == 0: true_negatives += 1
        if predictions[i] < 0.5 and test_labels[i] == 1: false_negatives += 1
        if predictions[i] >= 0.5 and test_labels[i] == 1: true_positives += 1
        if predictions[i] >= 0.5 and test_labels[i] == 0: false_positives += 1

    accuracy = (true_positives + true_negatives) / len(test_set)
    precision = true_positives / (true_positives + false_positives)
    recall = true_positives / (true_positives + false_negatives)
    f1 = 2 * (precision * recall) / (precision + recall)

    metrics_file.write('%s\t%s\t%s\t%s\t%s\t%s\n' % (textname,
                                                     repr(accuracy),
                                                     repr(precision),
                                                     repr(recall),
                                                     repr(f1),
                                                     repr(len(test_set))))
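A quick numeric check of the metric formulas (hypothetical counts, not real results):

tp, fp, fn, tn = 6.0, 2.0, 3.0, 9.0
precision = tp / (tp + fp)                            # 0.75
recall = tp / (tp + fn)                               # 0.666...
f1 = 2 * (precision * recall) / (precision + recall)  # ~0.706
accuracy = (tp + tn) / (tp + fp + fn + tn)            # 0.75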
if __name__ == '__main__':
    main()