Commit 04c45e2d8290995034f03db5126026ea08041da0
1 parent: c2871e0d
Added new features to feature vector.
Showing 1 changed file (preparator.py) with 179 additions and 19 deletions.
@@ -14,10 +14,11 @@ from gensim.models.word2vec import Word2Vec | @@ -14,10 +14,11 @@ from gensim.models.word2vec import Word2Vec | ||
14 | 14 | ||
15 | TEST_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'test-prepared')) | 15 | TEST_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'test-prepared')) |
16 | TRAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'train-prepared')) | 16 | TRAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'train-prepared')) |
17 | +FREQ_300M_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'freq', 'base.lst')) | ||
17 | 18 | ||
18 | ANNO_PATH = TEST_PATH | 19 | ANNO_PATH = TEST_PATH |
19 | OUT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', | 20 | OUT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', |
20 | - 'test.csv')) | 21 | + 'test-20170627.csv')) |
21 | EACH_TEXT_SEPARATELLY = False | 22 | EACH_TEXT_SEPARATELLY = False |
22 | 23 | ||
23 | CONTEXT = 5 | 24 | CONTEXT = 5 |
@@ -25,7 +26,12 @@ W2V_SIZE = 50 | @@ -25,7 +26,12 @@ W2V_SIZE = 50 | ||
25 | MODEL = os.path.abspath(os.path.join(os.path.dirname(__file__), 'models', | 26 | MODEL = os.path.abspath(os.path.join(os.path.dirname(__file__), 'models', |
26 | '%d' % W2V_SIZE, | 27 | '%d' % W2V_SIZE, |
27 | 'w2v_allwiki_nkjpfull_%d.model' % W2V_SIZE)) | 28 | 'w2v_allwiki_nkjpfull_%d.model' % W2V_SIZE)) |
29 | + | ||
30 | +NOUN_TAGS = ['subst', 'ger', 'depr'] | ||
31 | +PPRON_TAGS = ['ppron12', 'ppron3'] | ||
32 | +ZERO_TAGS = ['fin', 'praet', 'bedzie', 'impt', 'winien', 'aglt'] | ||
28 | POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-'] | 33 | POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-'] |
34 | + | ||
29 | NEG_PROPORTION = 1 | 35 | NEG_PROPORTION = 1 |
30 | RANDOM_VECTORS = True | 36 | RANDOM_VECTORS = True |
31 | 37 | ||
@@ -38,8 +44,9 @@ UNKNONW_WORDS = 0 | @@ -38,8 +44,9 @@ UNKNONW_WORDS = 0 | ||
38 | 44 | ||
39 | def main(): | 45 | def main(): |
40 | model = Word2Vec.load(MODEL) | 46 | model = Word2Vec.load(MODEL) |
47 | + freq_list = load_freq_list(FREQ_300M_PATH) | ||
41 | try: | 48 | try: |
42 | - create_data_vectors(model) | 49 | + create_data_vectors(model, freq_list) |
43 | finally: | 50 | finally: |
44 | print 'Unknown words: ', UNKNONW_WORDS | 51 | print 'Unknown words: ', UNKNONW_WORDS |
45 | print 'All words: ', ALL_WORDS | 52 | print 'All words: ', ALL_WORDS |
@@ -47,7 +54,20 @@ def main(): | @@ -47,7 +54,20 @@ def main(): | ||
47 | print 'Negatives: ', NEG_COUNT | 54 | print 'Negatives: ', NEG_COUNT |
48 | 55 | ||
49 | 56 | ||
50 | -def create_data_vectors(model): | 57 | +def load_freq_list(freq_path): |
58 | + freq_list = {} | ||
59 | + with codecs.open(freq_path, 'r', 'utf-8') as freq_file: | ||
60 | + lines = freq_file.readlines() | ||
61 | + for line in lines: | ||
62 | + line_parts = line.split() | ||
63 | + freq = int(line_parts[0]) | ||
64 | + base = line_parts[1] | ||
65 | + if base not in freq_list: | ||
66 | + freq_list[base] = freq | ||
67 | + return freq_list | ||
68 | + | ||
69 | + | ||
70 | +def create_data_vectors(model, freq_list): | ||
51 | features_file = None | 71 | features_file = None |
52 | if not EACH_TEXT_SEPARATELLY: | 72 | if not EACH_TEXT_SEPARATELLY: |
53 | features_file = codecs.open(OUT_PATH, 'wt', 'utf-8') | 73 | features_file = codecs.open(OUT_PATH, 'wt', 'utf-8') |
@@ -72,7 +92,7 @@ def create_data_vectors(model): | @@ -72,7 +92,7 @@ def create_data_vectors(model): | ||
72 | print len(negatives) | 92 | print len(negatives) |
73 | 93 | ||
74 | words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname) | 94 | words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname) |
75 | - mentions_dict = markables_level_2_dict(mentions_path, words_path) | 95 | + mentions_dict = markables_level_2_dict(mentions_path, words_path, freq_list) |
76 | 96 | ||
77 | if EACH_TEXT_SEPARATELLY: | 97 | if EACH_TEXT_SEPARATELLY: |
78 | text_features_path = os.path.join(OUT_PATH, '%s.csv' % textname) | 98 | text_features_path = os.path.join(OUT_PATH, '%s.csv' % textname) |
@@ -185,8 +205,8 @@ def get_mention_features(mention_span, mentions_dict, model): | @@ -185,8 +205,8 @@ def get_mention_features(mention_span, mentions_dict, model): | ||
185 | mention = get_mention_by_attr(mentions_dict, 'span', mention_span) | 205 | mention = get_mention_by_attr(mentions_dict, 'span', mention_span) |
186 | 206 | ||
187 | if DEBUG: | 207 | if DEBUG: |
188 | - features.append(mention['head_base']) | ||
189 | - head_vec = get_wv(model, mention['head_base']) | 208 | + features.append(mention['head']['base']) |
209 | + head_vec = get_wv(model, mention['head']['base']) | ||
190 | features.extend(list(head_vec)) | 210 | features.extend(list(head_vec)) |
191 | 211 | ||
192 | if DEBUG: | 212 | if DEBUG: |
@@ -257,9 +277,25 @@ def get_mention_features(mention_span, mentions_dict, model): | @@ -257,9 +277,25 @@ def get_mention_features(mention_span, mentions_dict, model): | ||
257 | sentence_vec = get_context_vec(mention['sentence'], model) | 277 | sentence_vec = get_context_vec(mention['sentence'], model) |
258 | features.extend(list(sentence_vec)) | 278 | features.extend(list(sentence_vec)) |
259 | 279 | ||
280 | + # cechy uzupelniajace | ||
281 | + features.extend(mention_type(mention)) | ||
282 | + | ||
260 | return features | 283 | return features |
261 | 284 | ||
262 | 285 | ||
286 | +def mention_type(mention): | ||
287 | + type_vec = [0] * 4 | ||
288 | + if mention['head']['ctag'] in NOUN_TAGS: | ||
289 | + type_vec[0] = 1 | ||
290 | + elif mention['head']['ctag'] in PPRON_TAGS: | ||
291 | + type_vec[1] = 1 | ||
292 | + elif mention['head']['ctag'] in ZERO_TAGS: | ||
293 | + type_vec[2] = 1 | ||
294 | + else: | ||
295 | + type_vec[3] = 1 | ||
296 | + return type_vec | ||
297 | + | ||
298 | + | ||
263 | def get_wv(model, lemma, random=True): | 299 | def get_wv(model, lemma, random=True): |
264 | global ALL_WORDS | 300 | global ALL_WORDS |
265 | global UNKNONW_WORDS | 301 | global UNKNONW_WORDS |
@@ -330,10 +366,14 @@ def get_pair_features(pair, mentions_dict): | @@ -330,10 +366,14 @@ def get_pair_features(pair, mentions_dict): | ||
330 | features.append(exact_match(ante, ana)) | 366 | features.append(exact_match(ante, ana)) |
331 | features.append(base_match(ante, ana)) | 367 | features.append(base_match(ante, ana)) |
332 | 368 | ||
333 | - if len(mentions_dict) > 100: | ||
334 | - features.append(1) | ||
335 | - else: | ||
336 | - features.append(0) | 369 | + # cechy uzupelniajace |
370 | + features.append(ante_contains_rarest_from_ana(ante, ana)) | ||
371 | + features.extend(agreement(ante, ana, 'gender')) | ||
372 | + features.extend(agreement(ante, ana, 'number')) | ||
373 | + features.extend(agreement(ante, ana, 'person')) | ||
374 | + features.append(is_acronym(ante, ana)) | ||
375 | + features.append(same_sentence(ante, ana)) | ||
376 | + features.append(same_paragraph(ante, ana)) | ||
337 | 377 | ||
338 | return features | 378 | return features |
339 | 379 | ||
@@ -382,7 +422,58 @@ def base_match(ante, ana): | @@ -382,7 +422,58 @@ def base_match(ante, ana): | ||
382 | return 0 | 422 | return 0 |
383 | 423 | ||
384 | 424 | ||
385 | -def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/NameSpaces/mention'): | 425 | +def ante_contains_rarest_from_ana(ante, ana): |
426 | + ana_rarest = ana['rarest'] | ||
427 | + for word in ante['words']: | ||
428 | + if word['base'] == ana_rarest['base']: | ||
429 | + return 1 | ||
430 | + return 0 | ||
431 | + | ||
432 | + | ||
433 | +def agreement(ante, ana, tag_name): | ||
434 | + agr_vec = [0] * 3 | ||
435 | + if ante['head'][tag_name] == 'unk' or ana['head'][tag_name] == 'unk': | ||
436 | + agr_vec[2] = 1 | ||
437 | + elif ante['head'][tag_name] == ana['head'][tag_name]: | ||
438 | + agr_vec[0] = 1 | ||
439 | + else: | ||
440 | + agr_vec[1] = 1 | ||
441 | + return agr_vec | ||
442 | + | ||
443 | + | ||
444 | +def is_acronym(ante, ana): | ||
445 | + if ana['text'].upper() == ana['text']: | ||
446 | + return check_one_way_acronym(ana['text'], ante['text']) | ||
447 | + if ante['text'].upper() == ante['text']: | ||
448 | + return check_one_way_acronym(ante['text'], ana['text']); | ||
449 | + return 0; | ||
450 | + | ||
451 | + | ||
452 | +def check_one_way_acronym(acronym, expression): | ||
453 | + initials = u'' | ||
454 | + for expr1 in expression.split('-'): | ||
455 | + for expr2 in expr1.split(): | ||
456 | + expr2 = expr2.strip() | ||
457 | + if expr2: | ||
458 | + initials += unicode(expr2[0]).upper() | ||
459 | + if acronym == initials: | ||
460 | + return 1; | ||
461 | + return 0; | ||
462 | + | ||
463 | + | ||
464 | +def same_sentence(ante, ana): | ||
465 | + if ante['sentence_id'] == ana['sentence_id']: | ||
466 | + return 1 | ||
467 | + return 0 | ||
468 | + | ||
469 | + | ||
470 | +def same_paragraph(ante, ana): | ||
471 | + if ante['paragraph_id'] == ana['paragraph_id']: | ||
472 | + return 1 | ||
473 | + return 0 | ||
474 | + | ||
475 | + | ||
476 | +def markables_level_2_dict(markables_path, words_path, freq_list, namespace='www.eml.org/NameSpaces/mention'): | ||
386 | markables_dicts = [] | 477 | markables_dicts = [] |
387 | markables_tree = etree.parse(markables_path) | 478 | markables_tree = etree.parse(markables_path) |
388 | markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace}) | 479 | markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace}) |
@@ -401,9 +492,9 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na | @@ -401,9 +492,9 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na | ||
401 | if head_orth not in POSSIBLE_HEADS: | 492 | if head_orth not in POSSIBLE_HEADS: |
402 | mention_words = span_to_words(span, words) | 493 | mention_words = span_to_words(span, words) |
403 | 494 | ||
404 | - prec_context, follow_context, sentence, mnt_start_position, mnt_end_position = get_context(mention_words, words) | 495 | + prec_context, follow_context, sentence, mnt_start_position, mnt_end_position, paragraph_id, sentence_id = get_context(mention_words, words) |
405 | 496 | ||
406 | - head_base = get_head_base(head_orth, mention_words) | 497 | + head = get_head(head_orth, mention_words) |
407 | markables_dicts.append({'id': markable.attrib['id'], | 498 | markables_dicts.append({'id': markable.attrib['id'], |
408 | 'set': markable.attrib['mention_group'], | 499 | 'set': markable.attrib['mention_group'], |
409 | 'text': span_to_text(span, words, 'orth'), | 500 | 'text': span_to_text(span, words, 'orth'), |
@@ -411,7 +502,7 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na | @@ -411,7 +502,7 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na | ||
411 | 'words': mention_words, | 502 | 'words': mention_words, |
412 | 'span': span, | 503 | 'span': span, |
413 | 'head_orth': head_orth, | 504 | 'head_orth': head_orth, |
414 | - 'head_base': head_base, | 505 | + 'head': head, |
415 | 'dominant': dominant, | 506 | 'dominant': dominant, |
416 | 'node': markable, | 507 | 'node': markable, |
417 | 'prec_context': prec_context, | 508 | 'prec_context': prec_context, |
@@ -419,7 +510,10 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na | @@ -419,7 +510,10 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na | ||
419 | 'sentence': sentence, | 510 | 'sentence': sentence, |
420 | 'position_in_mentions': idx, | 511 | 'position_in_mentions': idx, |
421 | 'start_in_words': mnt_start_position, | 512 | 'start_in_words': mnt_start_position, |
422 | - 'end_in_words': mnt_end_position}) | 513 | + 'end_in_words': mnt_end_position, |
514 | + 'rarest': get_rarest_word(mention_words, freq_list), | ||
515 | + 'paragraph_id': paragraph_id, | ||
516 | + 'sentence_id': sentence_id}) | ||
423 | else: | 517 | else: |
424 | print 'Zduplikowana wzmianka: %s' % span | 518 | print 'Zduplikowana wzmianka: %s' % span |
425 | 519 | ||
@@ -427,6 +521,8 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na | @@ -427,6 +521,8 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na | ||
427 | 521 | ||
428 | 522 | ||
429 | def get_context(mention_words, words): | 523 | def get_context(mention_words, words): |
524 | + paragraph_id = 0 | ||
525 | + sentence_id = 0 | ||
430 | prec_context = [] | 526 | prec_context = [] |
431 | follow_context = [] | 527 | follow_context = [] |
432 | sentence = [] | 528 | sentence = [] |
@@ -442,7 +538,11 @@ def get_context(mention_words, words): | @@ -442,7 +538,11 @@ def get_context(mention_words, words): | ||
442 | sentence = get_sentence(idx, words) | 538 | sentence = get_sentence(idx, words) |
443 | mnt_end_position = get_mention_end(last_word, words) | 539 | mnt_end_position = get_mention_end(last_word, words) |
444 | break | 540 | break |
445 | - return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position | 541 | + if word['lastinsent']: |
542 | + sentence_id += 1 | ||
543 | + if word['lastinpar']: | ||
544 | + paragraph_id += 1 | ||
545 | + return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position, paragraph_id, sentence_id | ||
446 | 546 | ||
447 | 547 | ||
448 | def get_prec_context(mention_start, words): | 548 | def get_prec_context(mention_start, words): |
@@ -514,10 +614,10 @@ def get_sentence_end(words, word_idx): | @@ -514,10 +614,10 @@ def get_sentence_end(words, word_idx): | ||
514 | return len(words) - 1 | 614 | return len(words) - 1 |
515 | 615 | ||
516 | 616 | ||
517 | -def get_head_base(head_orth, words): | 617 | +def get_head(head_orth, words): |
518 | for word in words: | 618 | for word in words: |
519 | if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth: | 619 | if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth: |
520 | - return word['base'] | 620 | + return word |
521 | return None | 621 | return None |
522 | 622 | ||
523 | 623 | ||
@@ -531,15 +631,61 @@ def get_words(filepath): | @@ -531,15 +631,61 @@ def get_words(filepath): | ||
531 | lastinsent = False | 631 | lastinsent = False |
532 | if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true': | 632 | if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true': |
533 | lastinsent = True | 633 | lastinsent = True |
634 | + lastinpar = False | ||
635 | + if 'lastinpar' in word.attrib and word.attrib['lastinpar'] == 'true': | ||
636 | + lastinpar = True | ||
534 | words.append({'id': word.attrib['id'], | 637 | words.append({'id': word.attrib['id'], |
535 | 'orth': word.text, | 638 | 'orth': word.text, |
536 | 'base': word.attrib['base'], | 639 | 'base': word.attrib['base'], |
537 | 'hasnps': hasnps, | 640 | 'hasnps': hasnps, |
538 | 'lastinsent': lastinsent, | 641 | 'lastinsent': lastinsent, |
539 | - 'ctag': word.attrib['ctag']}) | 642 | + 'lastinpar': lastinpar, |
643 | + 'ctag': word.attrib['ctag'], | ||
644 | + 'msd': word.attrib['msd'], | ||
645 | + 'gender': get_gender(word.attrib['msd']), | ||
646 | + 'person': get_person(word.attrib['msd']), | ||
647 | + 'number': get_number(word.attrib['msd'])}) | ||
540 | return words | 648 | return words |
541 | 649 | ||
542 | 650 | ||
651 | +def get_gender(msd): | ||
652 | + tags = msd.split(':') | ||
653 | + if 'm1' in tags: | ||
654 | + return 'm1' | ||
655 | + elif 'm2' in tags: | ||
656 | + return 'm2' | ||
657 | + elif 'm3' in tags: | ||
658 | + return 'm3' | ||
659 | + elif 'f' in tags: | ||
660 | + return 'f' | ||
661 | + elif 'n' in tags: | ||
662 | + return 'n' | ||
663 | + else: | ||
664 | + return 'unk' | ||
665 | + | ||
666 | + | ||
667 | +def get_person(msd): | ||
668 | + tags = msd.split(':') | ||
669 | + if 'pri' in tags: | ||
670 | + return 'pri' | ||
671 | + elif 'sec' in tags: | ||
672 | + return 'sec' | ||
673 | + elif 'ter' in tags: | ||
674 | + return 'ter' | ||
675 | + else: | ||
676 | + return 'unk' | ||
677 | + | ||
678 | + | ||
679 | +def get_number(msd): | ||
680 | + tags = msd.split(':') | ||
681 | + if 'sg' in tags: | ||
682 | + return 'sg' | ||
683 | + elif 'pl' in tags: | ||
684 | + return 'pl' | ||
685 | + else: | ||
686 | + return 'unk' | ||
687 | + | ||
688 | + | ||
543 | def get_mention_by_attr(mentions, attr_name, value): | 689 | def get_mention_by_attr(mentions, attr_name, value): |
544 | for mention in mentions: | 690 | for mention in mentions: |
545 | if mention[attr_name] == value: | 691 | if mention[attr_name] == value: |
@@ -652,5 +798,19 @@ def word_to_ignore(word): | @@ -652,5 +798,19 @@ def word_to_ignore(word): | ||
652 | return False | 798 | return False |
653 | 799 | ||
654 | 800 | ||
801 | +def get_rarest_word(words, freq_list): | ||
802 | + min_freq = 0 | ||
803 | + rarest_word = words[0] | ||
804 | + for i, word in enumerate(words): | ||
805 | + word_freq = 0 | ||
806 | + if word['base'] in freq_list: | ||
807 | + word_freq = freq_list[word['base']] | ||
808 | + | ||
809 | + if i == 0 or word_freq < min_freq: | ||
810 | + min_freq = word_freq | ||
811 | + rarest_word = word | ||
812 | + return rarest_word | ||
813 | + | ||
814 | + | ||
655 | if __name__ == '__main__': | 815 | if __name__ == '__main__': |
656 | main() | 816 | main() |