Commit c2871e0ded5bfb23380ab7c041d4237e1a0c8481
1 parent 01a04337
Basic evaluation and data preparation scripts.
Showing 3 changed files with 790 additions and 0 deletions
counter.py
0 → 100644
# -*- coding: utf-8 -*-

import os

from lxml import etree
from natsort import natsorted

from preparator import ANNO_PATH


def count_words():
    anno_files = os.listdir(ANNO_PATH)
    anno_files = natsorted(anno_files)
    for filename in anno_files:
        if filename.endswith('.mmax'):
            words_count = 0
            textname = filename.replace('.mmax', '')
            words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
            tree = etree.parse(words_path)
            for word in tree.xpath("//word"):
                if word.attrib['ctag'] != 'interp':
                    words_count += 1
            print textname, words_count


def count_mentions():
    anno_files = os.listdir(ANNO_PATH)
    anno_files = natsorted(anno_files)
    for filename in anno_files:
        if filename.endswith('.mmax'):
            textname = filename.replace('.mmax', '')

            mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
            tree = etree.parse(mentions_path)
            mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
            print textname, len(mentions)
...
preparator.py
0 → 100644
# -*- coding: utf-8 -*-

import codecs
import numpy
import os
import random

from lxml import etree
from itertools import combinations
from natsort import natsorted

from gensim.models.word2vec import Word2Vec


TEST_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'test-prepared'))
TRAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'train-prepared'))

ANNO_PATH = TEST_PATH
OUT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
                                        'test.csv'))
EACH_TEXT_SEPARATELY = False

CONTEXT = 5
W2V_SIZE = 50
MODEL = os.path.abspath(os.path.join(os.path.dirname(__file__), 'models',
                                     '%d' % W2V_SIZE,
                                     'w2v_allwiki_nkjpfull_%d.model' % W2V_SIZE))
POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-']
NEG_PROPORTION = 1
RANDOM_VECTORS = True

DEBUG = False
POS_COUNT = 0
NEG_COUNT = 0
ALL_WORDS = 0
UNKNOWN_WORDS = 0

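# Data preparation: for every MMAX2-style text (<name>.mmax with its
# <name>_words.xml and <name>_mentions.xml), mention pairs inside a
# coreference cluster become positive examples, an equal-sized random sample
# of other clustered-mention pairs becomes the negatives, and the
# word2vec-based feature rows are written tab-separated to OUT_PATH
# (or to one CSV per text when EACH_TEXT_SEPARATELY is set).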
def main():
    model = Word2Vec.load(MODEL)
    try:
        create_data_vectors(model)
    finally:
        print 'Unknown words: ', UNKNOWN_WORDS
        print 'All words: ', ALL_WORDS
        print 'Positives: ', POS_COUNT
        print 'Negatives: ', NEG_COUNT


def create_data_vectors(model):
    features_file = None
    if not EACH_TEXT_SEPARATELY:
        features_file = codecs.open(OUT_PATH, 'wt', 'utf-8')

    anno_files = os.listdir(ANNO_PATH)
    anno_files = natsorted(anno_files)
    for filename in anno_files:
        if filename.endswith('.mmax'):
            print '=======> ', filename
            textname = filename.replace('.mmax', '')

            mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
            tree = etree.parse(mentions_path)
            mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
            positives, negatives = diff_mentions(mentions)

            if DEBUG:
                print 'Positives:'
                print len(positives)

                print 'Negatives:'
                print len(negatives)

            words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
            mentions_dict = markables_level_2_dict(mentions_path, words_path)

            if EACH_TEXT_SEPARATELY:
                text_features_path = os.path.join(OUT_PATH, '%s.csv' % textname)
                features_file = codecs.open(text_features_path, 'wt', 'utf-8')
            write_features(features_file, positives, negatives, mentions_dict, model, textname)

    if not EACH_TEXT_SEPARATELY:
        features_file.close()


def diff_mentions(mentions):
    sets, clustered_mentions = get_sets(mentions)
    positives = get_positives(sets)
    positives, negatives = get_negatives_and_update_positives(clustered_mentions, positives)
    if len(negatives) != len(positives) and NEG_PROPORTION == 1:
        print u'Mismatched number of positive and negative examples!'
    return positives, negatives


def get_sets(mentions):
    sets = {}
    clustered_mentions = []
    for mention in mentions:
        set_id = mention.attrib['mention_group']
        if set_id == 'empty' or set_id == '' or mention.attrib['mention_head'] in POSSIBLE_HEADS:
            pass
        elif set_id not in sets:
            sets[set_id] = [mention.attrib['span']]
            clustered_mentions.append(mention.attrib['span'])
        elif set_id in sets:
            sets[set_id].append(mention.attrib['span'])
            clustered_mentions.append(mention.attrib['span'])
        else:
            print u'Something went wrong while collecting coreference clusters!'

    sets_to_remove = []
    for set_id in sets:
        if len(sets[set_id]) < 2:
            sets_to_remove.append(set_id)
            if len(sets[set_id]) == 1:
                print u'Removing clustered mention: ', sets[set_id][0]
                clustered_mentions.remove(sets[set_id][0])

    for set_id in sets_to_remove:
        print u'Removing set: ', set_id
        sets.pop(set_id)

    return sets, clustered_mentions


def get_positives(sets):
    positives = []
    for set_id in sets:
        coref_set = sets[set_id]
        positives.extend(list(combinations(coref_set, 2)))
    return positives


def get_negatives_and_update_positives(clustered_mentions, positives):
    all_pairs = list(combinations(clustered_mentions, 2))
    all_pairs = set(all_pairs)
    negatives = [pair for pair in all_pairs if pair not in positives]
    samples_count = NEG_PROPORTION * len(positives)
    if samples_count > len(negatives):
        samples_count = len(negatives)
        if NEG_PROPORTION == 1:
            positives = random.sample(set(positives), samples_count)
        print u'More positive than negative examples!'
    negatives = random.sample(set(negatives), samples_count)
    return positives, negatives

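# Each row written below is: antecedent mention features, anaphor mention
# features, pair features, then a 0/1 coreference label, joined with tabs.
# With DEBUG on, a human-readable pair identifier is prepended.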
def write_features(features_file, positives, negatives, mentions_dict, model, textname):
    global POS_COUNT
    POS_COUNT += len(positives)
    for pair in positives:
        pair_features = []
        if DEBUG:
            pair_features = ['%s>%s:%s' % (textname, pair[0], pair[1])]
        pair_features.extend(get_features(pair, mentions_dict, model))
        pair_features.append(1)
        features_file.write(u'%s\n' % u'\t'.join([unicode(feature) for feature in pair_features]))

    global NEG_COUNT
    NEG_COUNT += len(negatives)
    for pair in negatives:
        pair_features = []
        if DEBUG:
            pair_features = ['%s>%s:%s' % (textname, pair[0], pair[1])]
        pair_features.extend(get_features(pair, mentions_dict, model))
        pair_features.append(0)
        features_file.write(u'%s\n' % u'\t'.join([unicode(feature) for feature in pair_features]))


def get_features(pair, mentions_dict, model):
    features = []
    ante = pair[0]
    ana = pair[1]
    ante_features = get_mention_features(ante, mentions_dict, model)
    features.extend(ante_features)
    ana_features = get_mention_features(ana, mentions_dict, model)
    features.extend(ana_features)
    pair_features = get_pair_features(pair, mentions_dict)
    features.extend(pair_features)
    return features

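# Mention-level features are eleven W2V_SIZE-dimensional vectors laid out
# flat: the head base form, the first and last word of the mention, the two
# following and two preceding context words (zero vectors when missing), and
# averaged vectors for the preceding context, following context, the mention
# itself and its sentence.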
def get_mention_features(mention_span, mentions_dict, model):
    features = []
    mention = get_mention_by_attr(mentions_dict, 'span', mention_span)

    if DEBUG:
        features.append(mention['head_base'])
    head_vec = get_wv(model, mention['head_base'])
    features.extend(list(head_vec))

    if DEBUG:
        features.append(mention['words'][0]['base'])
    first_vec = get_wv(model, mention['words'][0]['base'])
    features.extend(list(first_vec))

    if DEBUG:
        features.append(mention['words'][-1]['base'])
    last_vec = get_wv(model, mention['words'][-1]['base'])
    features.extend(list(last_vec))

    if len(mention['follow_context']) > 0:
        if DEBUG:
            features.append(mention['follow_context'][0]['base'])
        after_1_vec = get_wv(model, mention['follow_context'][0]['base'])
        features.extend(list(after_1_vec))
    else:
        if DEBUG:
            features.append('None')
        features.extend([0.0] * W2V_SIZE)
    if len(mention['follow_context']) > 1:
        if DEBUG:
            features.append(mention['follow_context'][1]['base'])
        after_2_vec = get_wv(model, mention['follow_context'][1]['base'])
        features.extend(list(after_2_vec))
    else:
        if DEBUG:
            features.append('None')
        features.extend([0.0] * W2V_SIZE)

    if len(mention['prec_context']) > 0:
        if DEBUG:
            features.append(mention['prec_context'][-1]['base'])
        prec_1_vec = get_wv(model, mention['prec_context'][-1]['base'])
        features.extend(list(prec_1_vec))
    else:
        if DEBUG:
            features.append('None')
        features.extend([0.0] * W2V_SIZE)
    if len(mention['prec_context']) > 1:
        if DEBUG:
            features.append(mention['prec_context'][-2]['base'])
        prec_2_vec = get_wv(model, mention['prec_context'][-2]['base'])
        features.extend(list(prec_2_vec))
    else:
        if DEBUG:
            features.append('None')
        features.extend([0.0] * W2V_SIZE)

    if DEBUG:
        features.append(u' '.join([word['orth'] for word in mention['prec_context']]))
    prec_vec = get_context_vec(mention['prec_context'], model)
    features.extend(list(prec_vec))

    if DEBUG:
        features.append(u' '.join([word['orth'] for word in mention['follow_context']]))
    follow_vec = get_context_vec(mention['follow_context'], model)
    features.extend(list(follow_vec))

    if DEBUG:
        features.append(u' '.join([word['orth'] for word in mention['words']]))
    mention_vec = get_context_vec(mention['words'], model)
    features.extend(list(mention_vec))

    if DEBUG:
        features.append(u' '.join([word['orth'] for word in mention['sentence']]))
    sentence_vec = get_context_vec(mention['sentence'], model)
    features.extend(list(sentence_vec))

    return features


def get_wv(model, lemma, use_random=True):
    global ALL_WORDS
    global UNKNOWN_WORDS
    vec = None
    if use_random:
        vec = random_vec()
    ALL_WORDS += 1
    try:
        vec = model.wv[lemma]
    except KeyError:
        UNKNOWN_WORDS += 1
    return vec


def random_vec():
    return numpy.asarray([random.uniform(-0.25, 0.25) for i in range(0, W2V_SIZE)], dtype=numpy.float32)


def get_context_vec(words, model):
    vec = numpy.zeros(W2V_SIZE, dtype=numpy.float32)
    unknown_count = 0
    if len(words) != 0:
        for word in words:
            word_vec = get_wv(model, word['base'], RANDOM_VECTORS)
            if word_vec is None:
                unknown_count += 1
            else:
                vec += word_vec
        significant_words = len(words) - unknown_count
        if significant_words != 0:
            vec = vec / float(significant_words)
        else:
            vec = random_vec()
    return vec

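# Pair-level features: two 11-bin one-hot distance encodings (distance in
# words and in mentions, bucketed below), a mention-overlap flag, head /
# exact-text / base-form match flags, and a flag marking texts with more
# than 100 mentions.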
def get_pair_features(pair, mentions_dict):
    ante = get_mention_by_attr(mentions_dict, 'span', pair[0])
    ana = get_mention_by_attr(mentions_dict, 'span', pair[1])

    features = []
    mnts_intersect = pair_intersect(ante, ana)

    words_dist = [0] * 11
    words_bucket = 0
    if mnts_intersect != 1:
        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'] - 1)
    if DEBUG:
        features.append('Bucket %d' % words_bucket)
    words_dist[words_bucket] = 1
    features.extend(words_dist)

    mentions_dist = [0] * 11
    mentions_bucket = 0
    if mnts_intersect != 1:
        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'] - 1)
    if words_bucket == 10:
        mentions_bucket = 10
    if DEBUG:
        features.append('Bucket %d' % mentions_bucket)
    mentions_dist[mentions_bucket] = 1
    features.extend(mentions_dist)

    if DEBUG:
        features.append('Other features')
    features.append(mnts_intersect)
    features.append(head_match(ante, ana))
    features.append(exact_match(ante, ana))
    features.append(base_match(ante, ana))

    if len(mentions_dict) > 100:
        features.append(1)
    else:
        features.append(0)

    return features

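# Distance buckets: 0-4 map to themselves, 5-7 -> 5, 8-15 -> 6, 16-31 -> 7,
# 32-63 -> 8, 64 and more -> 9; bucket 10 is the error/overflow marker used
# when a distance comes out negative.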
def get_distance_bucket(distance):
    if distance >= 0 and distance <= 4:
        return distance
    elif distance >= 5 and distance <= 7:
        return 5
    elif distance >= 8 and distance <= 15:
        return 6
    elif distance >= 16 and distance <= 31:
        return 7
    elif distance >= 32 and distance <= 63:
        return 8
    elif distance >= 64:
        return 9
    else:
        print u'Something went wrong while bucketing distances!'
        return 10


def pair_intersect(ante, ana):
    for ante_word in ante['words']:
        for ana_word in ana['words']:
            if ana_word['id'] == ante_word['id']:
                return 1
    return 0


def head_match(ante, ana):
    if ante['head_orth'].lower() == ana['head_orth'].lower():
        return 1
    return 0


def exact_match(ante, ana):
    if ante['text'].lower() == ana['text'].lower():
        return 1
    return 0


def base_match(ante, ana):
    if ante['lemmatized_text'].lower() == ana['lemmatized_text'].lower():
        return 1
    return 0

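# Mention spans come from the _mentions.xml markables: a span is a
# comma-separated list of fragments, each either a single word id or an
# "id1..id2" range over the ids in the corresponding _words.xml file.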
def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/NameSpaces/mention'):
    markables_dicts = []
    markables_tree = etree.parse(markables_path)
    markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace})

    words = get_words(words_path)

    for idx, markable in enumerate(markables):
        span = markable.attrib['span']
        if not get_mention_by_attr(markables_dicts, 'span', span):

            dominant = ''
            if 'dominant' in markable.attrib:
                dominant = markable.attrib['dominant']

            head_orth = markable.attrib['mention_head']
            if head_orth not in POSSIBLE_HEADS:
                mention_words = span_to_words(span, words)

                prec_context, follow_context, sentence, mnt_start_position, mnt_end_position = get_context(mention_words, words)

                head_base = get_head_base(head_orth, mention_words)
                markables_dicts.append({'id': markable.attrib['id'],
                                        'set': markable.attrib['mention_group'],
                                        'text': span_to_text(span, words, 'orth'),
                                        'lemmatized_text': span_to_text(span, words, 'base'),
                                        'words': mention_words,
                                        'span': span,
                                        'head_orth': head_orth,
                                        'head_base': head_base,
                                        'dominant': dominant,
                                        'node': markable,
                                        'prec_context': prec_context,
                                        'follow_context': follow_context,
                                        'sentence': sentence,
                                        'position_in_mentions': idx,
                                        'start_in_words': mnt_start_position,
                                        'end_in_words': mnt_end_position})
        else:
            print 'Duplicated mention: %s' % span

    return markables_dicts


def get_context(mention_words, words):
    prec_context = []
    follow_context = []
    sentence = []
    mnt_start_position = -1
    mnt_end_position = -1
    first_word = mention_words[0]
    last_word = mention_words[-1]
    for idx, word in enumerate(words):
        if word['id'] == first_word['id']:
            prec_context = get_prec_context(idx, words)
            mnt_start_position = get_mention_start(first_word, words)
        if word['id'] == last_word['id']:
            follow_context = get_follow_context(idx, words)
            sentence = get_sentence(idx, words)
            mnt_end_position = get_mention_end(last_word, words)
            break
    return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position


def get_prec_context(mention_start, words):
    context = []
    context_start = mention_start - 1
    while context_start >= 0:
        if not word_to_ignore(words[context_start]):
            context.append(words[context_start])
        if len(context) == CONTEXT:
            break
        context_start -= 1
    context.reverse()
    return context


def get_mention_start(first_word, words):
    start = 0
    for word in words:
        if not word_to_ignore(word):
            start += 1
        if word['id'] == first_word['id']:
            break
    return start


def get_mention_end(last_word, words):
    end = 0
    for word in words:
        if not word_to_ignore(word):
            end += 1
        if word['id'] == last_word['id']:
            break
    return end


def get_follow_context(mention_end, words):
    context = []
    context_end = mention_end + 1
    while context_end < len(words):
        if not word_to_ignore(words[context_end]):
            context.append(words[context_end])
        if len(context) == CONTEXT:
            break
        context_end += 1
    return context


def get_sentence(word_idx, words):
    sentence_start = get_sentence_start(words, word_idx)
    sentence_end = get_sentence_end(words, word_idx)
    sentence = [word for word in words[sentence_start:sentence_end+1] if not word_to_ignore(word)]
    return sentence


def get_sentence_start(words, word_idx):
    search_start = word_idx
    while word_idx >= 0:
        if words[word_idx]['lastinsent'] and search_start != word_idx:
            return word_idx + 1
        word_idx -= 1
    return 0


def get_sentence_end(words, word_idx):
    while word_idx < len(words):
        if words[word_idx]['lastinsent']:
            return word_idx
        word_idx += 1
    return len(words) - 1


def get_head_base(head_orth, words):
    for word in words:
        if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
            return word['base']
    return None


def get_words(filepath):
    tree = etree.parse(filepath)
    words = []
    for word in tree.xpath("//word"):
        hasnps = False
        if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
            hasnps = True
        lastinsent = False
        if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
            lastinsent = True
        words.append({'id': word.attrib['id'],
                      'orth': word.text,
                      'base': word.attrib['base'],
                      'hasnps': hasnps,
                      'lastinsent': lastinsent,
                      'ctag': word.attrib['ctag']})
    return words


def get_mention_by_attr(mentions, attr_name, value):
    for mention in mentions:
        if mention[attr_name] == value:
            return mention
    return None


def get_mention_index_by_attr(mentions, attr_name, value):
    for idx, mention in enumerate(mentions):
        if mention[attr_name] == value:
            return idx
    return None


def span_to_text(span, words, form):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.append(fragment_to_text(fragment, words, form))
    return u' [...] '.join(mention_parts)


def fragment_to_text(fragment, words, form):
    if '..' in fragment:
        text = get_multiword_text(fragment, words, form)
    else:
        text = get_one_word_text(fragment, words, form)
    return text


def get_multiword_text(fragment, words, form):
    mention_parts = []
    boundaries = fragment.split('..')
    start_id = boundaries[0]
    end_id = boundaries[1]
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return to_text(mention_parts, form)


def to_text(words, form):
    text = ''
    for idx, word in enumerate(words):
        if word['hasnps'] or idx == 0:
            text += word[form]
        else:
            text += u' %s' % word[form]
    return text


def get_one_word_text(word_id, words, form):
    this_word = (word for word in words if word['id'] == word_id).next()
    if word_to_ignore(this_word):
        print this_word
    return this_word[form]


def span_to_words(span, words):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.extend(fragment_to_words(fragment, words))
    return mention_parts


def fragment_to_words(fragment, words):
    mention_parts = []
    if '..' in fragment:
        mention_parts.extend(get_multiword(fragment, words))
    else:
        mention_parts.extend(get_word(fragment, words))
    return mention_parts


def get_multiword(fragment, words):
    mention_parts = []
    boundaries = fragment.split('..')
    start_id = boundaries[0]
    end_id = boundaries[1]
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return mention_parts


def get_word(word_id, words):
    for word in words:
        if word['id'] == word_id:
            if not word_to_ignore(word):
                return [word]
            else:
                return []
    return []


def word_to_ignore(word):
    if word['ctag'] == 'interp':
        return True
    return False


if __name__ == '__main__':
    main()
...
resolver.py
0 → 100644
# -*- coding: utf-8 -*-

import codecs
import os

import numpy as np

from natsort import natsorted

from keras.models import Model
from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization
from keras.optimizers import SGD, Adam

IN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
                                       'prepared_text_files'))
OUT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
                                        'metrics.csv'))

MODEL = os.path.abspath(os.path.join(os.path.dirname(__file__), 'weights_2017_05_10.h5'))


NUMBER_OF_FEATURES = 1126


def main():
    resolve_files()


def resolve_files():
    metrics_file = codecs.open(OUT_PATH, 'w', 'utf-8')
    write_labels(metrics_file)

    anno_files = os.listdir(IN_PATH)
    anno_files = natsorted(anno_files)
    for filename in anno_files:
        print(filename)
        textname = filename.replace('.csv', '')
        text_data_path = os.path.join(IN_PATH, filename)
        resolve(textname, text_data_path, metrics_file)

    metrics_file.close()


def write_labels(metrics_file):
    metrics_file.write('Text\tAccuracy\tPrecision\tRecall\tF1\tPairs\n')

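# resolve() rebuilds the feed-forward scorer the saved weights belong to
# (NUMBER_OF_FEATURES inputs -> Dense 1000 ReLU -> Dropout -> BatchNorm ->
# Dense 500 ReLU -> Dropout -> BatchNorm -> Dense 1 sigmoid), loads the
# weights, scores every mention pair in the file and treats predictions
# >= 0.5 as coreferent when computing the metrics.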
def resolve(textname, text_data_path, metrics_file):
    raw_data = open(text_data_path, 'rt')
    test_data = np.loadtxt(raw_data, delimiter='\t')
    test_set = test_data[:, 0:NUMBER_OF_FEATURES]
    test_labels = test_data[:, NUMBER_OF_FEATURES]  # last column consists of labels

    inputs = Input(shape=(NUMBER_OF_FEATURES,))
    output_from_1st_layer = Dense(1000, activation='relu')(inputs)
    output_from_1st_layer = Dropout(0.5)(output_from_1st_layer)
    output_from_1st_layer = BatchNormalization()(output_from_1st_layer)
    output_from_2nd_layer = Dense(500, activation='relu')(output_from_1st_layer)
    output_from_2nd_layer = Dropout(0.5)(output_from_2nd_layer)
    output_from_2nd_layer = BatchNormalization()(output_from_2nd_layer)
    output = Dense(1, activation='sigmoid')(output_from_2nd_layer)

    model = Model(inputs, output)
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.load_weights(MODEL)

    predictions = model.predict(test_set)

    calc_metrics(textname, test_set, test_labels, predictions, metrics_file)


def calc_metrics(textname, test_set, test_labels, predictions, metrics_file):
    true_positives = 0.0
    false_positives = 0.0
    true_negatives = 0.0
    false_negatives = 0.0

    for i in range(len(test_set)):
        if predictions[i] < 0.5 and test_labels[i] == 0: true_negatives += 1
        if predictions[i] < 0.5 and test_labels[i] == 1: false_negatives += 1
        if predictions[i] >= 0.5 and test_labels[i] == 1: true_positives += 1
        if predictions[i] >= 0.5 and test_labels[i] == 0: false_positives += 1

    accuracy = (true_positives + true_negatives) / len(test_set)
    precision = true_positives / (true_positives + false_positives)
    recall = true_positives / (true_positives + false_negatives)
    f1 = 2 * (precision * recall) / (precision + recall)

    metrics_file.write('%s\t%s\t%s\t%s\t%s\t%s\n' % (textname,
                                                     repr(accuracy),
                                                     repr(precision),
                                                     repr(recall),
                                                     repr(f1),
                                                     repr(len(test_set))))


if __name__ == '__main__':
    main()
...