Commit 04c45e2d8290995034f03db5126026ea08041da0

Authored by Bartłomiej Nitoń
1 parent c2871e0d

Added new features to the feature vector.

Showing 1 changed file with 179 additions and 19 deletions
preparator.py
... ... @@ -14,10 +14,11 @@ from gensim.models.word2vec import Word2Vec
14 14  
15 15 TEST_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'test-prepared'))
16 16 TRAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'train-prepared'))
  17 +FREQ_300M_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'freq', 'base.lst'))
17 18  
18 19 ANNO_PATH = TEST_PATH
19 20 OUT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
20   - 'test.csv'))
  21 + 'test-20170627.csv'))
21 22 EACH_TEXT_SEPARATELLY = False
22 23  
23 24 CONTEXT = 5
... ... @@ -25,7 +26,12 @@ W2V_SIZE = 50
25 26 MODEL = os.path.abspath(os.path.join(os.path.dirname(__file__), 'models',
26 27 '%d' % W2V_SIZE,
27 28 'w2v_allwiki_nkjpfull_%d.model' % W2V_SIZE))
  29 +
  30 +NOUN_TAGS = ['subst', 'ger', 'depr']
  31 +PPRON_TAGS = ['ppron12', 'ppron3']
  32 +ZERO_TAGS = ['fin', 'praet', 'bedzie', 'impt', 'winien', 'aglt']
28 33 POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-']
  34 +
29 35 NEG_PROPORTION = 1
30 36 RANDOM_VECTORS = True
31 37  
... ... @@ -38,8 +44,9 @@ UNKNONW_WORDS = 0
38 44  
39 45 def main():
40 46 model = Word2Vec.load(MODEL)
  47 + freq_list = load_freq_list(FREQ_300M_PATH)
41 48 try:
42   - create_data_vectors(model)
  49 + create_data_vectors(model, freq_list)
43 50 finally:
44 51 print 'Unknown words: ', UNKNONW_WORDS
45 52 print 'All words: ', ALL_WORDS
... ... @@ -47,7 +54,20 @@ def main():
47 54 print 'Negatives: ', NEG_COUNT
48 55  
49 56  
50   -def create_data_vectors(model):
  57 +def load_freq_list(freq_path):
  58 + freq_list = {}
  59 + with codecs.open(freq_path, 'r', 'utf-8') as freq_file:
  60 + lines = freq_file.readlines()
  61 + for line in lines:
  62 + line_parts = line.split()
  63 + freq = int(line_parts[0])
  64 + base = line_parts[1]
  65 + if base not in freq_list:
  66 + freq_list[base] = freq
  67 + return freq_list
  68 +
  69 +
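The new loader assumes base.lst holds one whitespace-separated "<frequency> <base>" pair per line and keeps only the first frequency seen for each base. A minimal sketch of the expected input and result (counts hypothetical):

    # base.lst excerpt (hypothetical counts):
    #   1523445 i
    #   51 wzmianka
    freq_list = load_freq_list(FREQ_300M_PATH)
    # freq_list == {u'i': 1523445, u'wzmianka': 51}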
  70 +def create_data_vectors(model, freq_list):
51 71 features_file = None
52 72 if not EACH_TEXT_SEPARATELLY:
53 73 features_file = codecs.open(OUT_PATH, 'wt', 'utf-8')
... ... @@ -72,7 +92,7 @@ def create_data_vectors(model):
72 92 print len(negatives)
73 93  
74 94 words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
75   - mentions_dict = markables_level_2_dict(mentions_path, words_path)
  95 + mentions_dict = markables_level_2_dict(mentions_path, words_path, freq_list)
76 96  
77 97 if EACH_TEXT_SEPARATELLY:
78 98 text_features_path = os.path.join(OUT_PATH, '%s.csv' % textname)
... ... @@ -185,8 +205,8 @@ def get_mention_features(mention_span, mentions_dict, model):
185 205 mention = get_mention_by_attr(mentions_dict, 'span', mention_span)
186 206  
187 207 if DEBUG:
188   - features.append(mention['head_base'])
189   - head_vec = get_wv(model, mention['head_base'])
  208 + features.append(mention['head']['base'])
  209 + head_vec = get_wv(model, mention['head']['base'])
190 210 features.extend(list(head_vec))
191 211  
192 212 if DEBUG:
... ... @@ -257,9 +277,25 @@ def get_mention_features(mention_span, mentions_dict, model):
257 277 sentence_vec = get_context_vec(mention['sentence'], model)
258 278 features.extend(list(sentence_vec))
259 279  
  280 + # supplementary features
  281 + features.extend(mention_type(mention))
  282 +
260 283 return features
261 284  
262 285  
  286 +def mention_type(mention):
  287 + type_vec = [0] * 4
  288 + if mention['head']['ctag'] in NOUN_TAGS:
  289 + type_vec[0] = 1
  290 + elif mention['head']['ctag'] in PPRON_TAGS:
  291 + type_vec[1] = 1
  292 + elif mention['head']['ctag'] in ZERO_TAGS:
  293 + type_vec[2] = 1
  294 + else:
  295 + type_vec[3] = 1
  296 + return type_vec
  297 +
  298 +
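mention_type one-hot encodes the mention head into four classes: noun (subst/ger/depr), personal pronoun (ppron12/ppron3), verbal form marking a zero-subject candidate (fin/praet/...), or anything else. Illustrative calls, assuming only the 'head' key the function actually reads:

    mention_type({'head': {'ctag': 'subst'}})   # -> [1, 0, 0, 0]
    mention_type({'head': {'ctag': 'ppron3'}})  # -> [0, 1, 0, 0]
    mention_type({'head': {'ctag': 'adj'}})     # -> [0, 0, 0, 1]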
263 299 def get_wv(model, lemma, random=True):
264 300 global ALL_WORDS
265 301 global UNKNONW_WORDS
... ... @@ -330,10 +366,14 @@ def get_pair_features(pair, mentions_dict):
330 366 features.append(exact_match(ante, ana))
331 367 features.append(base_match(ante, ana))
332 368  
333   - if len(mentions_dict) > 100:
334   - features.append(1)
335   - else:
336   - features.append(0)
  369 + # supplementary features
  370 + features.append(ante_contains_rarest_from_ana(ante, ana))
  371 + features.extend(agreement(ante, ana, 'gender'))
  372 + features.extend(agreement(ante, ana, 'number'))
  373 + features.extend(agreement(ante, ana, 'person'))
  374 + features.append(is_acronym(ante, ana))
  375 + features.append(same_sentence(ante, ana))
  376 + features.append(same_paragraph(ante, ana))
337 377  
338 378 return features
339 379  
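The old placeholder flag (1 whenever the text had more than 100 mentions) is replaced by 13 supplementary slots: a rarest-word containment flag, three 3-way agreement vectors, and acronym, same-sentence and same-paragraph flags. For a pair that agrees on all three categories and sits in one sentence and paragraph, the appended segment would look like this (hypothetical):

    # [rarest] [gender agr] [number agr] [person agr] [acronym] [same sent] [same par]
    #   [1]  +  [1, 0, 0] +  [1, 0, 0] +  [1, 0, 0] +   [0]   +    [1]    +    [1]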
... ... @@ -382,7 +422,58 @@ def base_match(ante, ana):
382 422 return 0
383 423  
384 424  
385   -def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/NameSpaces/mention'):
  425 +def ante_contains_rarest_from_ana(ante, ana):
  426 + ana_rarest = ana['rarest']
  427 + for word in ante['words']:
  428 + if word['base'] == ana_rarest['base']:
  429 + return 1
  430 + return 0
  431 +
  432 +
  433 +def agreement(ante, ana, tag_name):
  434 + agr_vec = [0] * 3
  435 + if ante['head'][tag_name] == 'unk' or ana['head'][tag_name] == 'unk':
  436 + agr_vec[2] = 1
  437 + elif ante['head'][tag_name] == ana['head'][tag_name]:
  438 + agr_vec[0] = 1
  439 + else:
  440 + agr_vec[1] = 1
  441 + return agr_vec
  442 +
  443 +
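agreement produces a 3-way indicator per grammatical category: slot 0 for a match, slot 1 for a mismatch, slot 2 when either side is 'unk'. For example, with head dicts shaped as built in markables_level_2_dict:

    agreement({'head': {'gender': 'f'}}, {'head': {'gender': 'f'}}, 'gender')    # -> [1, 0, 0]
    agreement({'head': {'gender': 'f'}}, {'head': {'gender': 'm1'}}, 'gender')   # -> [0, 1, 0]
    agreement({'head': {'gender': 'f'}}, {'head': {'gender': 'unk'}}, 'gender')  # -> [0, 0, 1]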
  444 +def is_acronym(ante, ana):
  445 + if ana['text'].upper() == ana['text']:
  446 + return check_one_way_acronym(ana['text'], ante['text'])
  447 + if ante['text'].upper() == ante['text']:
  448 + return check_one_way_acronym(ante['text'], ana['text'])
  449 + return 0
  450 +
  451 +
  452 +def check_one_way_acronym(acronym, expression):
  453 + initials = u''
  454 + for expr1 in expression.split('-'):
  455 + for expr2 in expr1.split():
  456 + expr2 = expr2.strip()
  457 + if expr2:
  458 + initials += unicode(expr2[0]).upper()
  459 + if acronym == initials:
  460 + return 1
  461 + return 0
  462 +
  463 +
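check_one_way_acronym concatenates the uppercased initials of the expression's hyphen- and whitespace-separated tokens and compares them with the all-caps candidate:

    check_one_way_acronym(u'PKP', u'Polskie Koleje Państwowe')  # -> 1
    check_one_way_acronym(u'PKP', u'Polska Akademia Nauk')      # -> 0 (initials are PAN)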
  464 +def same_sentence(ante, ana):
  465 + if ante['sentence_id'] == ana['sentence_id']:
  466 + return 1
  467 + return 0
  468 +
  469 +
  470 +def same_paragraph(ante, ana):
  471 + if ante['paragraph_id'] == ana['paragraph_id']:
  472 + return 1
  473 + return 0
  474 +
  475 +
  476 +def markables_level_2_dict(markables_path, words_path, freq_list, namespace='www.eml.org/NameSpaces/mention'):
386 477 markables_dicts = []
387 478 markables_tree = etree.parse(markables_path)
388 479 markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace})
... ... @@ -401,9 +492,9 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na
401 492 if head_orth not in POSSIBLE_HEADS:
402 493 mention_words = span_to_words(span, words)
403 494  
404   - prec_context, follow_context, sentence, mnt_start_position, mnt_end_position = get_context(mention_words, words)
  495 + prec_context, follow_context, sentence, mnt_start_position, mnt_end_position, paragraph_id, sentence_id = get_context(mention_words, words)
405 496  
406   - head_base = get_head_base(head_orth, mention_words)
  497 + head = get_head(head_orth, mention_words)
407 498 markables_dicts.append({'id': markable.attrib['id'],
408 499 'set': markable.attrib['mention_group'],
409 500 'text': span_to_text(span, words, 'orth'),
... ... @@ -411,7 +502,7 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na
411 502 'words': mention_words,
412 503 'span': span,
413 504 'head_orth': head_orth,
414   - 'head_base': head_base,
  505 + 'head': head,
415 506 'dominant': dominant,
416 507 'node': markable,
417 508 'prec_context': prec_context,
... ... @@ -419,7 +510,10 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na
419 510 'sentence': sentence,
420 511 'position_in_mentions': idx,
421 512 'start_in_words': mnt_start_position,
422   - 'end_in_words': mnt_end_position})
  513 + 'end_in_words': mnt_end_position,
  514 + 'rarest': get_rarest_word(mention_words, freq_list),
  515 + 'paragraph_id': paragraph_id,
  516 + 'sentence_id': sentence_id})
423 517 else:
424 518 print 'Duplicated mention: %s' % span
425 519  
... ... @@ -427,6 +521,8 @@ def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/Na
427 521  
428 522  
429 523 def get_context(mention_words, words):
  524 + paragraph_id = 0
  525 + sentence_id = 0
430 526 prec_context = []
431 527 follow_context = []
432 528 sentence = []
... ... @@ -442,7 +538,11 @@ def get_context(mention_words, words):
442 538 sentence = get_sentence(idx, words)
443 539 mnt_end_position = get_mention_end(last_word, words)
444 540 break
445   - return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position
  541 + if word['lastinsent']:
  542 + sentence_id += 1
  543 + if word['lastinpar']:
  544 + paragraph_id += 1
  545 + return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position, paragraph_id, sentence_id
446 546  
447 547  
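sentence_id and paragraph_id are obtained by counting lastinsent/lastinpar markers on the words preceding the mention, so two mentions receive equal ids exactly when no boundary separates them; same_sentence and same_paragraph above then reduce to simple id comparisons:

    # one sentence boundary, no paragraph boundary, between the mentions:
    # ante['sentence_id'] == 0, ana['sentence_id'] == 1
    # same_sentence(ante, ana) -> 0, same_paragraph(ante, ana) -> 1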
448 548 def get_prec_context(mention_start, words):
... ... @@ -514,10 +614,10 @@ def get_sentence_end(words, word_idx):
514 614 return len(words) - 1
515 615  
516 616  
517   -def get_head_base(head_orth, words):
  617 +def get_head(head_orth, words):
518 618 for word in words:
519 619 if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
520   - return word['base']
  620 + return word
521 621 return None
522 622  
523 623  
... ... @@ -531,15 +631,61 @@ def get_words(filepath):
531 631 lastinsent = False
532 632 if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
533 633 lastinsent = True
  634 + lastinpar = False
  635 + if 'lastinpar' in word.attrib and word.attrib['lastinpar'] == 'true':
  636 + lastinpar = True
534 637 words.append({'id': word.attrib['id'],
535 638 'orth': word.text,
536 639 'base': word.attrib['base'],
537 640 'hasnps': hasnps,
538 641 'lastinsent': lastinsent,
539   - 'ctag': word.attrib['ctag']})
  642 + 'lastinpar': lastinpar,
  643 + 'ctag': word.attrib['ctag'],
  644 + 'msd': word.attrib['msd'],
  645 + 'gender': get_gender(word.attrib['msd']),
  646 + 'person': get_person(word.attrib['msd']),
  647 + 'number': get_number(word.attrib['msd'])})
540 648 return words
541 649  
542 650  
  651 +def get_gender(msd):
  652 + tags = msd.split(':')
  653 + if 'm1' in tags:
  654 + return 'm1'
  655 + elif 'm2' in tags:
  656 + return 'm2'
  657 + elif 'm3' in tags:
  658 + return 'm3'
  659 + elif 'f' in tags:
  660 + return 'f'
  661 + elif 'n' in tags:
  662 + return 'n'
  663 + else:
  664 + return 'unk'
  665 +
  666 +
  667 +def get_person(msd):
  668 + tags = msd.split(':')
  669 + if 'pri' in tags:
  670 + return 'pri'
  671 + elif 'sec' in tags:
  672 + return 'sec'
  673 + elif 'ter' in tags:
  674 + return 'ter'
  675 + else:
  676 + return 'unk'
  677 +
  678 +
  679 +def get_number(msd):
  680 + tags = msd.split(':')
  681 + if 'sg' in tags:
  682 + return 'sg'
  683 + elif 'pl' in tags:
  684 + return 'pl'
  685 + else:
  686 + return 'unk'
  687 +
  688 +
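The three getters read grammatical categories off the colon-separated msd attribute of a word, falling back to 'unk' when the category is absent. With a hypothetical NKJP-style tag:

    get_gender(u'sg:nom:f')  # -> 'f'
    get_number(u'sg:nom:f')  # -> 'sg'
    get_person(u'sg:nom:f')  # -> 'unk' (no pri/sec/ter segment)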
543 689 def get_mention_by_attr(mentions, attr_name, value):
544 690 for mention in mentions:
545 691 if mention[attr_name] == value:
... ... @@ -652,5 +798,19 @@ def word_to_ignore(word):
652 798 return False
653 799  
654 800  
  801 +def get_rarest_word(words, freq_list):
  802 + min_freq = 0
  803 + rarest_word = words[0]
  804 + for i, word in enumerate(words):
  805 + word_freq = 0
  806 + if word['base'] in freq_list:
  807 + word_freq = freq_list[word['base']]
  808 +
  809 + if i == 0 or word_freq < min_freq:
  810 + min_freq = word_freq
  811 + rarest_word = word
  812 + return rarest_word
  813 +
  814 +
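get_rarest_word keeps the mention word with the lowest corpus frequency; bases missing from the list count as frequency 0, so out-of-vocabulary words always win. A sketch with hypothetical counts:

    words = [{'base': u'i'}, {'base': u'wzmianka'}]
    get_rarest_word(words, {u'i': 1523445, u'wzmianka': 51})  # -> {'base': u'wzmianka'}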
655 815 if __name__ == '__main__':
656 816 main()