Commit 04c45e2d8290995034f03db5126026ea08041da0 (1 parent: c2871e0d)
Added new features to feature vector.

Showing 1 changed file with 179 additions and 19 deletions.

preparator.py
... | ... | @@ -14,10 +14,11 @@ from gensim.models.word2vec import Word2Vec |
14 | 14 | |
15 | 15 | TEST_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'test-prepared')) |
16 | 16 | TRAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'train-prepared')) |
17 | +FREQ_300M_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'freq', 'base.lst')) | |
17 | 18 | |
18 | 19 | ANNO_PATH = TEST_PATH |
19 | 20 | OUT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', |
20 | - 'test.csv')) | |
21 | + 'test-20170627.csv')) | |
21 | 22 | EACH_TEXT_SEPARATELLY = False |
22 | 23 | |
23 | 24 | CONTEXT = 5 |
... | ... | @@ -25,7 +26,12 @@ W2V_SIZE = 50 |
25 | 26 | MODEL = os.path.abspath(os.path.join(os.path.dirname(__file__), 'models', |
26 | 27 | '%d' % W2V_SIZE, |
27 | 28 | 'w2v_allwiki_nkjpfull_%d.model' % W2V_SIZE)) |
29 | + | |
30 | +NOUN_TAGS = ['subst', 'ger', 'depr'] | |
31 | +PPRON_TAGS = ['ppron12', 'ppron3'] | |
32 | +ZERO_TAGS = ['fin', 'praet', 'bedzie', 'impt', 'winien', 'aglt'] | |
28 | 33 | POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-'] |
34 | + | |
29 | 35 | NEG_PROPORTION = 1 |
30 | 36 | RANDOM_VECTORS = True |
31 | 37 | |
... | ... | @@ -38,8 +44,9 @@ UNKNONW_WORDS = 0 |
38 | 44 | |
39 | 45 | def main(): |
40 | 46 | model = Word2Vec.load(MODEL) |
47 | + freq_list = load_freq_list(FREQ_300M_PATH) | |
41 | 48 | try: |
42 | - create_data_vectors(model) | |
49 | + create_data_vectors(model, freq_list) | |
43 | 50 | finally: |
44 | 51 | print 'Unknown words: ', UNKNONW_WORDS |
45 | 52 | print 'All words: ', ALL_WORDS |
... | ... | @@ -47,7 +54,20 @@ def main(): |
47 | 54 | print 'Negatives: ', NEG_COUNT |
48 | 55 | |
49 | 56 | |
50 | -def create_data_vectors(model): | |
57 | +def load_freq_list(freq_path): | |
58 | + freq_list = {} | |
59 | + with codecs.open(freq_path, 'r', 'utf-8') as freq_file: | |
60 | + lines = freq_file.readlines() | |
61 | + for line in lines: | |
62 | + line_parts = line.split() | |
63 | + freq = int(line_parts[0]) | |
64 | + base = line_parts[1] | |
65 | + if base not in freq_list: | |
66 | + freq_list[base] = freq | |
67 | + return freq_list | |
68 | + | |
69 | + | |
70 | +def create_data_vectors(model, freq_list): | |
51 | 71 | features_file = None |
52 | 72 | if not EACH_TEXT_SEPARATELLY: |
53 | 73 | features_file = codecs.open(OUT_PATH, 'wt', 'utf-8') |
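Note: load_freq_list expects one whitespace-separated "<frequency> <base>" pair per line and keeps only the first frequency seen for a given base form. A minimal usage sketch (the words and counts below are made up for illustration; the real 300M list will differ):

    # Hypothetical excerpt of data/freq/base.lst:
    #   4821504 byc
    #   1203311 rok
    #        17 cyklotron
    freq_list = load_freq_list(FREQ_300M_PATH)
    print freq_list[u'rok']  # 1203311; bases absent from the file are simply missing keys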
... | ... | @@ -72,7 +92,7 @@ def create_data_vectors(model): |
72 | 92 | print len(negatives) |
73 | 93 | |
74 | 94 | words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname) |
75 | - mentions_dict = markables_level_2_dict(mentions_path, words_path) | |
95 | + mentions_dict = markables_level_2_dict(mentions_path, words_path, freq_list) | |
76 | 96 | |
77 | 97 | if EACH_TEXT_SEPARATELLY: |
78 | 98 | text_features_path = os.path.join(OUT_PATH, '%s.csv' % textname) |
... | ... | @@ -185,8 +205,8 @@ def get_mention_features(mention_span, mentions_dict, model): |
185 | 205 | mention = get_mention_by_attr(mentions_dict, 'span', mention_span) |
186 | 206 | |
187 | 207 | if DEBUG: |
188 | - features.append(mention['head_base']) | |
189 | - head_vec = get_wv(model, mention['head_base']) | |
208 | + features.append(mention['head']['base']) | |
209 | + head_vec = get_wv(model, mention['head']['base']) | |
190 | 210 | features.extend(list(head_vec)) |
191 | 211 | |
192 | 212 | if DEBUG: |
... | ... | @@ -257,9 +277,25 @@ def get_mention_features(mention_span, mentions_dict, model): |
257 | 277 | sentence_vec = get_context_vec(mention['sentence'], model) |
258 | 278 | features.extend(list(sentence_vec)) |
259 | 279 | |
 280 | + # supplementary features | |
281 | + features.extend(mention_type(mention)) | |
282 | + | |
260 | 283 | return features |
261 | 284 | |
262 | 285 | |
286 | +def mention_type(mention): | |
287 | + type_vec = [0] * 4 | |
288 | + if mention['head']['ctag'] in NOUN_TAGS: | |
289 | + type_vec[0] = 1 | |
290 | + elif mention['head']['ctag'] in PPRON_TAGS: | |
291 | + type_vec[1] = 1 | |
292 | + elif mention['head']['ctag'] in ZERO_TAGS: | |
293 | + type_vec[2] = 1 | |
294 | + else: | |
295 | + type_vec[3] = 1 | |
296 | + return type_vec | |
297 | + | |
298 | + | |
263 | 299 | def get_wv(model, lemma, random=True): |
264 | 300 | global ALL_WORDS |
265 | 301 | global UNKNONW_WORDS |
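Note: mention_type encodes the mention head's part of speech as a 4-way one-hot vector: noun (NOUN_TAGS), personal pronoun (PPRON_TAGS), verb form that can host a zero subject (ZERO_TAGS), or other. A sketch with a minimal hand-built mention dict containing only the field mention_type reads:

    mention = {'head': {'ctag': 'ppron3'}}
    print mention_type(mention)  # [0, 1, 0, 0] -- personal pronoun slot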
... | ... | @@ -330,10 +366,14 @@ def get_pair_features(pair, mentions_dict): |
330 | 366 | features.append(exact_match(ante, ana)) |
331 | 367 | features.append(base_match(ante, ana)) |
332 | 368 | |
333 | - if len(mentions_dict) > 100: | |
334 | - features.append(1) | |
335 | - else: | |
336 | - features.append(0) | |
 369 | + # cechy uzupelniajace (supplementary features) | |
370 | + features.append(ante_contains_rarest_from_ana(ante, ana)) | |
371 | + features.extend(agreement(ante, ana, 'gender')) | |
372 | + features.extend(agreement(ante, ana, 'number')) | |
373 | + features.extend(agreement(ante, ana, 'person')) | |
374 | + features.append(is_acronym(ante, ana)) | |
375 | + features.append(same_sentence(ante, ana)) | |
376 | + features.append(same_paragraph(ante, ana)) | |
337 | 377 | |
338 | 378 | return features |
339 | 379 | |
... | ... | @@ -382,7 +422,58 @@ def base_match(ante, ana): |
382 | 422 | return 0 |
383 | 423 | |
384 | 424 | |
385 | -def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/NameSpaces/mention'): | |
425 | +def ante_contains_rarest_from_ana(ante, ana): | |
426 | + ana_rarest = ana['rarest'] | |
427 | + for word in ante['words']: | |
428 | + if word['base'] == ana_rarest['base']: | |
429 | + return 1 | |
430 | + return 0 | |
431 | + | |
432 | + | |
433 | +def agreement(ante, ana, tag_name): | |
434 | + agr_vec = [0] * 3 | |
435 | + if ante['head'][tag_name] == 'unk' or ana['head'][tag_name] == 'unk': | |
436 | + agr_vec[2] = 1 | |
437 | + elif ante['head'][tag_name] == ana['head'][tag_name]: | |
438 | + agr_vec[0] = 1 | |
439 | + else: | |
440 | + agr_vec[1] = 1 | |
441 | + return agr_vec | |
442 | + | |
443 | + | |
444 | +def is_acronym(ante, ana): | |
445 | + if ana['text'].upper() == ana['text']: | |
446 | + return check_one_way_acronym(ana['text'], ante['text']) | |
447 | + if ante['text'].upper() == ante['text']: | |
 448 | + return check_one_way_acronym(ante['text'], ana['text']) | |
 449 | + return 0 | |
450 | + | |
451 | + | |
452 | +def check_one_way_acronym(acronym, expression): | |
453 | + initials = u'' | |
454 | + for expr1 in expression.split('-'): | |
455 | + for expr2 in expr1.split(): | |
456 | + expr2 = expr2.strip() | |
457 | + if expr2: | |
458 | + initials += unicode(expr2[0]).upper() | |
459 | + if acronym == initials: | |
 460 | + return 1 | |
 461 | + return 0 | |
462 | + | |
463 | + | |
464 | +def same_sentence(ante, ana): | |
465 | + if ante['sentence_id'] == ana['sentence_id']: | |
466 | + return 1 | |
467 | + return 0 | |
468 | + | |
469 | + | |
470 | +def same_paragraph(ante, ana): | |
471 | + if ante['paragraph_id'] == ana['paragraph_id']: | |
472 | + return 1 | |
473 | + return 0 | |
474 | + | |
475 | + | |
476 | +def markables_level_2_dict(markables_path, words_path, freq_list, namespace='www.eml.org/NameSpaces/mention'): | |
386 | 477 | markables_dicts = [] |
387 | 478 | markables_tree = etree.parse(markables_path) |
388 | 479 | markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace}) |
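Note: agreement returns a 3-way one-hot vector, [values match, values differ, at least one side is 'unk'], and is_acronym compares an all-uppercase mention against the initials of the other mention's words (split on whitespace and hyphens). A sketch with hand-built mention dicts carrying only the fields these helpers actually read:

    ante = {'head': {'gender': 'f'}, 'text': u'Polska Akademia Nauk'}
    ana = {'head': {'gender': 'f'}, 'text': u'PAN'}
    print agreement(ante, ana, 'gender')  # [1, 0, 0] -- values agree
    print is_acronym(ante, ana)           # 1 -- 'PAN' matches the initials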
... | ... | @@ -401,9 +492,9 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na |
401 | 492 | if head_orth not in POSSIBLE_HEADS: |
402 | 493 | mention_words = span_to_words(span, words) |
403 | 494 | |
404 | - prec_context, follow_context, sentence, mnt_start_position, mnt_end_position = get_context(mention_words, words) | |
495 | + prec_context, follow_context, sentence, mnt_start_position, mnt_end_position, paragraph_id, sentence_id = get_context(mention_words, words) | |
405 | 496 | |
406 | - head_base = get_head_base(head_orth, mention_words) | |
497 | + head = get_head(head_orth, mention_words) | |
407 | 498 | markables_dicts.append({'id': markable.attrib['id'], |
408 | 499 | 'set': markable.attrib['mention_group'], |
409 | 500 | 'text': span_to_text(span, words, 'orth'), |
... | ... | @@ -411,7 +502,7 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na |
411 | 502 | 'words': mention_words, |
412 | 503 | 'span': span, |
413 | 504 | 'head_orth': head_orth, |
414 | - 'head_base': head_base, | |
505 | + 'head': head, | |
415 | 506 | 'dominant': dominant, |
416 | 507 | 'node': markable, |
417 | 508 | 'prec_context': prec_context, |
... | ... | @@ -419,7 +510,10 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na |
419 | 510 | 'sentence': sentence, |
420 | 511 | 'position_in_mentions': idx, |
421 | 512 | 'start_in_words': mnt_start_position, |
422 | - 'end_in_words': mnt_end_position}) | |
513 | + 'end_in_words': mnt_end_position, | |
514 | + 'rarest': get_rarest_word(mention_words, freq_list), | |
515 | + 'paragraph_id': paragraph_id, | |
516 | + 'sentence_id': sentence_id}) | |
423 | 517 | else: |
424 | 518 | print 'Zduplikowana wzmianka: %s' % span |
425 | 519 | |
... | ... | @@ -427,6 +521,8 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na |
427 | 521 | |
428 | 522 | |
429 | 523 | def get_context(mention_words, words): |
524 | + paragraph_id = 0 | |
525 | + sentence_id = 0 | |
430 | 526 | prec_context = [] |
431 | 527 | follow_context = [] |
432 | 528 | sentence = [] |
... | ... | @@ -442,7 +538,11 @@ def get_context(mention_words, words): |
442 | 538 | sentence = get_sentence(idx, words) |
443 | 539 | mnt_end_position = get_mention_end(last_word, words) |
444 | 540 | break |
445 | - return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position | |
541 | + if word['lastinsent']: | |
542 | + sentence_id += 1 | |
543 | + if word['lastinpar']: | |
544 | + paragraph_id += 1 | |
545 | + return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position, paragraph_id, sentence_id | |
446 | 546 | |
447 | 547 | |
448 | 548 | def get_prec_context(mention_start, words): |
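Note: get_context now additionally returns zero-based paragraph and sentence indices for the mention, computed by counting lastinpar/lastinsent flags on the words preceding the mention's first word; same_sentence and same_paragraph above compare exactly these ids. The extended call looks like:

    prec, follow, sent, start, end, par_id, sent_id = get_context(mention_words, words)
    # par_id/sent_id are zero-based: each word flagged lastinpar/lastinsent
    # before the mention bumps the respective counter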
... | ... | @@ -514,10 +614,10 @@ def get_sentence_end(words, word_idx): |
514 | 614 | return len(words) - 1 |
515 | 615 | |
516 | 616 | |
517 | -def get_head_base(head_orth, words): | |
617 | +def get_head(head_orth, words): | |
518 | 618 | for word in words: |
519 | 619 | if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth: |
520 | - return word['base'] | |
620 | + return word | |
521 | 621 | return None |
522 | 622 | |
523 | 623 | |
... | ... | @@ -531,15 +631,61 @@ def get_words(filepath): |
531 | 631 | lastinsent = False |
532 | 632 | if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true': |
533 | 633 | lastinsent = True |
634 | + lastinpar = False | |
635 | + if 'lastinpar' in word.attrib and word.attrib['lastinpar'] == 'true': | |
636 | + lastinpar = True | |
534 | 637 | words.append({'id': word.attrib['id'], |
535 | 638 | 'orth': word.text, |
536 | 639 | 'base': word.attrib['base'], |
537 | 640 | 'hasnps': hasnps, |
538 | 641 | 'lastinsent': lastinsent, |
539 | - 'ctag': word.attrib['ctag']}) | |
642 | + 'lastinpar': lastinpar, | |
643 | + 'ctag': word.attrib['ctag'], | |
644 | + 'msd': word.attrib['msd'], | |
645 | + 'gender': get_gender(word.attrib['msd']), | |
646 | + 'person': get_person(word.attrib['msd']), | |
647 | + 'number': get_number(word.attrib['msd'])}) | |
540 | 648 | return words |
541 | 649 | |
542 | 650 | |
651 | +def get_gender(msd): | |
652 | + tags = msd.split(':') | |
653 | + if 'm1' in tags: | |
654 | + return 'm1' | |
655 | + elif 'm2' in tags: | |
656 | + return 'm2' | |
657 | + elif 'm3' in tags: | |
658 | + return 'm3' | |
659 | + elif 'f' in tags: | |
660 | + return 'f' | |
661 | + elif 'n' in tags: | |
662 | + return 'n' | |
663 | + else: | |
664 | + return 'unk' | |
665 | + | |
666 | + | |
667 | +def get_person(msd): | |
668 | + tags = msd.split(':') | |
669 | + if 'pri' in tags: | |
670 | + return 'pri' | |
671 | + elif 'sec' in tags: | |
672 | + return 'sec' | |
673 | + elif 'ter' in tags: | |
674 | + return 'ter' | |
675 | + else: | |
676 | + return 'unk' | |
677 | + | |
678 | + | |
679 | +def get_number(msd): | |
680 | + tags = msd.split(':') | |
681 | + if 'sg' in tags: | |
682 | + return 'sg' | |
683 | + elif 'pl' in tags: | |
684 | + return 'pl' | |
685 | + else: | |
686 | + return 'unk' | |
687 | + | |
688 | + | |
543 | 689 | def get_mention_by_attr(mentions, attr_name, value): |
544 | 690 | for mention in mentions: |
545 | 691 | if mention[attr_name] == value: |
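Note: the new msd helpers pull gender, person and number out of the colon-separated morphosyntactic tag, defaulting to 'unk' when the component is absent. For example, with an illustrative NKJP-style noun tag:

    print get_gender(u'sg:nom:m1')  # 'm1'
    print get_number(u'sg:nom:m1')  # 'sg'
    print get_person(u'sg:nom:m1')  # 'unk' -- noun tags carry no person component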
... | ... | @@ -652,5 +798,19 @@ def word_to_ignore(word): |
652 | 798 | return False |
653 | 799 | |
654 | 800 | |
801 | +def get_rarest_word(words, freq_list): | |
802 | + min_freq = 0 | |
803 | + rarest_word = words[0] | |
804 | + for i, word in enumerate(words): | |
805 | + word_freq = 0 | |
806 | + if word['base'] in freq_list: | |
807 | + word_freq = freq_list[word['base']] | |
808 | + | |
809 | + if i == 0 or word_freq < min_freq: | |
810 | + min_freq = word_freq | |
811 | + rarest_word = word | |
812 | + return rarest_word | |
813 | + | |
814 | + | |
655 | 815 | if __name__ == '__main__': |
656 | 816 | main() |
... | ... |
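Note: get_rarest_word picks the mention word whose base form has the lowest count on the frequency list; bases missing from the list get frequency 0, so unknown words are treated as rarest. A sketch with hypothetical data:

    words = [{'base': u'rok'}, {'base': u'cyklotron'}]
    freq_list = {u'rok': 1203311, u'cyklotron': 17}
    print get_rarest_word(words, freq_list)  # {'base': u'cyklotron'}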