diff --git a/count_dist.py b/count_dist.py
new file mode 100644
index 0000000..091c006
--- /dev/null
+++ b/count_dist.py
@@ -0,0 +1,465 @@
+# -*- coding: utf-8 -*-
+# Counts the maximum distance (in mention positions) between coreferent
+# mentions in the MMAX-annotated corpus.
+
+import os
+
+from lxml import etree
+from natsort import natsorted
+
+
+MAIN_PATH = os.path.dirname(__file__)
+TEST_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'test-prepared'))
+TRAIN_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'train-prepared'))
+
+ANNO_PATH = TRAIN_PATH
+
+CONTEXT = 5
+POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-']
+
+
+def main():
+    max_mnt_dist = count_max_mnt_dist()
+    print('Max mention distance (positive pairs): %d' % max_mnt_dist)
+
+
+def count_max_mnt_dist():
+    global_max_mnt_dist = 0
+    anno_files = natsorted(os.listdir(ANNO_PATH))
+    for filename in anno_files:
+        if filename.endswith('.mmax'):
+            print('=======> ', filename)
+            textname = filename.replace('.mmax', '')
+
+            mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
+            tree = etree.parse(mentions_path)
+            mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
+
+            words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
+            mentions_dict = markables_level_2_dict(mentions_path, words_path)
+
+            file_max_mnt_dist = get_max_file_dist(mentions, mentions_dict)
+            if file_max_mnt_dist > global_max_mnt_dist:
+                global_max_mnt_dist = file_max_mnt_dist
+
+    return global_max_mnt_dist
+
+
+def get_max_file_dist(mentions, mentions_dict):
+    max_file_dist = 0
+    sets, all_mentions, clustered_mentions = get_sets(mentions)
+    for set_id in sets:
+        set_dist = get_max_set_dist(sets[set_id], mentions_dict)
+        if set_dist > max_file_dist:
+            max_file_dist = set_dist
+    print('Max mention distance: %d' % max_file_dist)
+    return max_file_dist
+
+
+def get_sets(mentions):
+    sets = {}
+    all_mentions = []
+    clustered_mentions = []
+    for mention in mentions:
+        all_mentions.append(mention.attrib['span'])
+        set_id = mention.attrib['mention_group']
+        if set_id == 'empty' or set_id == '':
+            pass
+        elif set_id not in sets:
+            sets[set_id] = [mention.attrib['span']]
+            clustered_mentions.append(mention.attrib['span'])
+        else:
+            sets[set_id].append(mention.attrib['span'])
+            clustered_mentions.append(mention.attrib['span'])
+
+    # Drop degenerate clusters with fewer than two mentions.
+    sets_to_remove = []
+    for set_id in sets:
+        if len(sets[set_id]) < 2:
+            sets_to_remove.append(set_id)
+            if len(sets[set_id]) == 1:
+                print(u'Removing clustered mention: ', sets[set_id][0])
+                clustered_mentions.remove(sets[set_id][0])
+
+    for set_id in sets_to_remove:
+        print(u'Removing set: ', set_id)
+        sets.pop(set_id)
+
+    return sets, all_mentions, clustered_mentions
+
+
+def get_max_set_dist(mnt_set, mentions_dict):
+    # For every mention, take the distance to its nearest neighbour within
+    # the cluster; the set distance is the maximum of those values.
+    max_set_dist = 0
+    for idx, mnt2_span in enumerate(mnt_set):
+        mnt2 = get_mention_by_attr(mentions_dict, 'span', mnt2_span)
+        dist = None
+        dist1 = None
+        if idx - 1 >= 0:
+            mnt1_span = mnt_set[idx - 1]
+            mnt1 = get_mention_by_attr(mentions_dict, 'span', mnt1_span)
+            dist1 = get_pair_dist(mnt1, mnt2)
+            dist = dist1
+        if idx + 1 < len(mnt_set):
+            mnt3_span = mnt_set[idx + 1]
+            mnt3 = get_mention_by_attr(mentions_dict, 'span', mnt3_span)
+            dist2 = get_pair_dist(mnt2, mnt3)
+            if dist1 is None or dist2 < dist1:
+                dist = dist2
+
+        if dist is not None and dist > max_set_dist:
+            max_set_dist = dist
+
+    return max_set_dist
+
+
+def get_pair_dist(ante, ana):
+    # Distance in mention positions; overlapping mentions count as 0.
+    dist = 0
+    mnts_intersect = pair_intersect(ante, ana)
+    if mnts_intersect != 1:
+        dist = ana['position_in_mentions'] - ante['position_in_mentions']
+    return dist
+
+
+def pair_intersect(ante, ana):
+    for ante_word in ante['words']:
+        for ana_word in ana['words']:
+            if ana_word['id'] == ante_word['id']:
+                return 1
+    return 0
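+
+# Worked example (hypothetical values): for an antecedent that is the 3rd
+# markable in the text and an anaphor that is the 7th, with no shared word
+# ids, get_pair_dist() returns 7 - 3 = 4; if the two mentions shared a word
+# id, pair_intersect() would return 1 and the distance would stay 0.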
+
+
+def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/NameSpaces/mention'):
+    markables_dicts = []
+    markables_tree = etree.parse(markables_path)
+    markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace})
+
+    words = get_words(words_path)
+
+    for idx, markable in enumerate(markables):
+        span = markable.attrib['span']
+        if not get_mention_by_attr(markables_dicts, 'span', span):
+            dominant = ''
+            if 'dominant' in markable.attrib:
+                dominant = markable.attrib['dominant']
+
+            head_orth = markable.attrib['mention_head']
+            mention_words = span_to_words(span, words)
+
+            (prec_context, follow_context, sentence, mnt_start_position, mnt_end_position,
+             paragraph_id, sentence_id, first_in_sentence, first_in_paragraph) = get_context(mention_words, words)
+
+            head = get_head(head_orth, mention_words)
+            markables_dicts.append({'id': markable.attrib['id'],
+                                    'set': markable.attrib['mention_group'],
+                                    'text': span_to_text(span, words, 'orth'),
+                                    'lemmatized_text': span_to_text(span, words, 'base'),
+                                    'words': mention_words,
+                                    'span': span,
+                                    'head_orth': head_orth,
+                                    'head': head,
+                                    'dominant': dominant,
+                                    'node': markable,
+                                    'prec_context': prec_context,
+                                    'follow_context': follow_context,
+                                    'sentence': sentence,
+                                    'position_in_mentions': idx,
+                                    'start_in_words': mnt_start_position,
+                                    'end_in_words': mnt_end_position,
+                                    'paragraph_id': paragraph_id,
+                                    'sentence_id': sentence_id,
+                                    'first_in_sentence': first_in_sentence,
+                                    'first_in_paragraph': first_in_paragraph})
+        else:
+            print('Duplicated mention: %s' % span)
+
+    return markables_dicts
+
+
+def get_context(mention_words, words):
+    paragraph_id = 0
+    sentence_id = 0
+    prec_context = []
+    follow_context = []
+    sentence = []
+    mnt_start_position = -1
+    mnt_end_position = -1  # guard: stays -1 if the mention's last word is missing
+    first_word = mention_words[0]
+    last_word = mention_words[-1]
+    first_in_sentence = False
+    first_in_paragraph = False
+    for idx, word in enumerate(words):
+        if word['id'] == first_word['id']:
+            prec_context = get_prec_context(idx, words)
+            mnt_start_position = get_mention_start(first_word, words)
+            if idx == 0 or words[idx-1]['lastinsent']:
+                first_in_sentence = True
+            if idx == 0 or words[idx-1]['lastinpar']:
+                first_in_paragraph = True
+        if word['id'] == last_word['id']:
+            follow_context = get_follow_context(idx, words)
+            sentence = get_sentence(idx, words)
+            mnt_end_position = get_mention_end(last_word, words)
+            break
+        if word['lastinsent']:
+            sentence_id += 1
+        if word['lastinpar']:
+            paragraph_id += 1
+    return (prec_context, follow_context, sentence, mnt_start_position, mnt_end_position,
+            paragraph_id, sentence_id, first_in_sentence, first_in_paragraph)
+
+
+def get_prec_context(mention_start, words):
+    context = []
+    context_start = mention_start - 1
+    while context_start >= 0:
+        if not word_to_ignore(words[context_start]):
+            context.append(words[context_start])
+        if len(context) == CONTEXT:
+            break
+        context_start -= 1
+    context.reverse()
+    return context
+
+
+def get_mention_start(first_word, words):
+    start = 0
+    for word in words:
+        if not word_to_ignore(word):
+            start += 1
+        if word['id'] == first_word['id']:
+            break
+    return start
+
+
+def get_mention_end(last_word, words):
+    end = 0
+    for word in words:
+        if not word_to_ignore(word):
+            end += 1
+        if word['id'] == last_word['id']:
+            break
+    return end
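+
+# Example (hypothetical positions): with CONTEXT = 5, a mention starting at
+# the 12th word gets up to five preceding non-ignored words as prec_context
+# and up to five following ones as follow_context; the windows shrink near
+# document boundaries.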
+
+
+def get_follow_context(mention_end, words):
+    context = []
+    context_end = mention_end + 1
+    while context_end < len(words):
+        if not word_to_ignore(words[context_end]):
+            context.append(words[context_end])
+        if len(context) == CONTEXT:
+            break
+        context_end += 1
+    return context
+
+
+def get_sentence(word_idx, words):
+    sentence_start = get_sentence_start(words, word_idx)
+    sentence_end = get_sentence_end(words, word_idx)
+    sentence = [word for word in words[sentence_start:sentence_end+1] if not word_to_ignore(word)]
+    return sentence
+
+
+def get_sentence_start(words, word_idx):
+    search_start = word_idx
+    while word_idx >= 0:
+        if words[word_idx]['lastinsent'] and search_start != word_idx:
+            return word_idx + 1
+        word_idx -= 1
+    return 0
+
+
+def get_sentence_end(words, word_idx):
+    while word_idx < len(words):
+        if words[word_idx]['lastinsent']:
+            return word_idx
+        word_idx += 1
+    return len(words) - 1
+
+
+def get_head(head_orth, words):
+    for word in words:
+        if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
+            return word
+    return None
+
+
+def get_words(filepath):
+    tree = etree.parse(filepath)
+    words = []
+    for word in tree.xpath("//word"):
+        hasnps = False
+        if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
+            hasnps = True
+        lastinsent = False
+        if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
+            lastinsent = True
+        lastinpar = False
+        if 'lastinpar' in word.attrib and word.attrib['lastinpar'] == 'true':
+            lastinpar = True
+        words.append({'id': word.attrib['id'],
+                      'orth': word.text,
+                      'base': word.attrib['base'],
+                      'hasnps': hasnps,
+                      'lastinsent': lastinsent,
+                      'lastinpar': lastinpar,
+                      'ctag': word.attrib['ctag'],
+                      'msd': word.attrib['msd'],
+                      'gender': get_gender(word.attrib['msd']),
+                      'person': get_person(word.attrib['msd']),
+                      'number': get_number(word.attrib['msd'])})
+    return words
+
+
+def get_gender(msd):
+    tags = msd.split(':')
+    if 'm1' in tags:
+        return 'm1'
+    elif 'm2' in tags:
+        return 'm2'
+    elif 'm3' in tags:
+        return 'm3'
+    elif 'f' in tags:
+        return 'f'
+    elif 'n' in tags:
+        return 'n'
+    else:
+        return 'unk'
+
+
+def get_person(msd):
+    tags = msd.split(':')
+    if 'pri' in tags:
+        return 'pri'
+    elif 'sec' in tags:
+        return 'sec'
+    elif 'ter' in tags:
+        return 'ter'
+    else:
+        return 'unk'
+
+
+def get_number(msd):
+    tags = msd.split(':')
+    if 'sg' in tags:
+        return 'sg'
+    elif 'pl' in tags:
+        return 'pl'
+    else:
+        return 'unk'
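+
+# Example (assumed NKJP-style positional tag): for msd = 'sg:nom:m1',
+# get_gender() returns 'm1', get_number() returns 'sg', and get_person(),
+# finding no person value among the tags, returns 'unk'.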
+
+
+def get_mention_by_attr(mentions, attr_name, value):
+    for mention in mentions:
+        if mention[attr_name] == value:
+            return mention
+    return None
+
+
+def get_mention_index_by_attr(mentions, attr_name, value):
+    for idx, mention in enumerate(mentions):
+        if mention[attr_name] == value:
+            return idx
+    return None
+
+
+def span_to_text(span, words, form):
+    fragments = span.split(',')
+    mention_parts = []
+    for fragment in fragments:
+        mention_parts.append(fragment_to_text(fragment, words, form))
+    return u' [...] '.join(mention_parts)
+
+
+def fragment_to_text(fragment, words, form):
+    if '..' in fragment:
+        text = get_multiword_text(fragment, words, form)
+    else:
+        text = get_one_word_text(fragment, words, form)
+    return text
+
+
+def get_multiword_text(fragment, words, form):
+    mention_parts = []
+    boundaries = fragment.split('..')
+    start_id = boundaries[0]
+    end_id = boundaries[1]
+    in_string = False
+    for word in words:
+        if word['id'] == start_id:
+            in_string = True
+        if in_string and not word_to_ignore(word):
+            mention_parts.append(word)
+        if word['id'] == end_id:
+            break
+    return to_text(mention_parts, form)
+
+
+def to_text(words, form):
+    text = ''
+    for idx, word in enumerate(words):
+        if word['hasnps'] or idx == 0:
+            text += word[form]
+        else:
+            text += u' %s' % word[form]
+    return text
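+
+# Example (hypothetical word ids): for the discontinuous span
+# 'word_3..word_5,word_8', span_to_text() renders words 3-5 and word 8
+# joined with ' [...] '; within a fragment, to_text() omits the space
+# before any word flagged hasnps (no preceding space).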
+
+
+def get_one_word_text(word_id, words, form):
+    this_word = next(word for word in words if word['id'] == word_id)
+    if word_to_ignore(this_word):
+        print(this_word)
+    return this_word[form]
+
+
+def span_to_words(span, words):
+    fragments = span.split(',')
+    mention_parts = []
+    for fragment in fragments:
+        mention_parts.extend(fragment_to_words(fragment, words))
+    return mention_parts
+
+
+def fragment_to_words(fragment, words):
+    mention_parts = []
+    if '..' in fragment:
+        mention_parts.extend(get_multiword(fragment, words))
+    else:
+        mention_parts.extend(get_word(fragment, words))
+    return mention_parts
+
+
+def get_multiword(fragment, words):
+    mention_parts = []
+    boundaries = fragment.split('..')
+    start_id = boundaries[0]
+    end_id = boundaries[1]
+    in_string = False
+    for word in words:
+        if word['id'] == start_id:
+            in_string = True
+        if in_string and not word_to_ignore(word):
+            mention_parts.append(word)
+        if word['id'] == end_id:
+            break
+    return mention_parts
+
+
+def get_word(word_id, words):
+    for word in words:
+        if word['id'] == word_id:
+            if not word_to_ignore(word):
+                return [word]
+            else:
+                return []
+    return []
+
+
+def word_to_ignore(word):
+    # Placeholder: no word is ignored for now.
+    return False
+
+
+if __name__ == '__main__':
+    main()
diff --git a/preparator.py b/preparator.py
index 02365ec..e407aee 100644
--- a/preparator.py
+++ b/preparator.py
@@ -29,7 +29,7 @@ TITLE2REDIRECT_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'wikipedia
 ANNO_PATH = TEST_PATH
 
 OUT_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data',
-                                        'test-1to5-20170720.csv'))
+                                        'test-1to5-singletons-20170720.csv'))
 EACH_TEXT_SEPARATELLY = False
 
 CONTEXT = 5
@@ -53,6 +53,7 @@ HYPHEN_SIGNS = ['-', '#']
 NEG_PROPORTION = 5
 
 RANDOM_VECTORS = True
+USE_SINGLETONS = True
 
 DEBUG = False
 POS_COUNT = 0
@@ -154,9 +155,9 @@ def create_data_vectors(model, freq_list, lemma2synonyms,
 
 
 def diff_mentions(mentions):
-    sets, clustered_mensions = get_sets(mentions)
+    sets, all_mentions, clustered_mensions = get_sets(mentions)
     positives = get_positives(sets)
-    positives, negatives = get_negatives_and_update_positives(clustered_mensions, positives)
+    positives, negatives = get_negatives_and_update_positives(all_mentions, clustered_mensions, positives)
     if len(negatives) != len(positives) and NEG_PROPORTION == 1:
         print (u'Niezgodna liczba przypadków pozytywnych i negatywnych!')
     return positives, negatives
@@ -164,8 +165,10 @@ def diff_mentions(mentions):
 
 def get_sets(mentions):
     sets = {}
+    all_mentions = []
     clustered_mensions = []
     for mention in mentions:
+        all_mentions.append(mention.attrib['span'])
         set_id = mention.attrib['mention_group']
         if set_id == 'empty' or set_id == '' or mention.attrib['mention_head'] in POSSIBLE_HEADS:
             pass
@@ -190,7 +193,7 @@
         print (u'Removing set: ', set_id)
         sets.pop(set_id)
 
-    return sets, clustered_mensions
+    return sets, all_mentions, clustered_mensions
 
 
 def get_positives(sets):
@@ -201,8 +204,12 @@
     return positives
 
 
-def get_negatives_and_update_positives(clustered_mensions, positives):
-    all_pairs = list(combinations(clustered_mensions, 2))
+def get_negatives_and_update_positives(all_mentions, clustered_mentions, positives):
+    all_pairs = []
+    if USE_SINGLETONS:
+        all_pairs = list(combinations(all_mentions, 2))
+    else:
+        all_pairs = list(combinations(clustered_mentions, 2))
     all_pairs = set(all_pairs)
     negatives = [pair for pair in all_pairs if pair not in positives]
     samples_count = NEG_PROPORTION * len(positives)
@@ -474,7 +481,7 @@ def get_pair_features(pair, mentions_dict, lemma2synonyms,
     words_dist = [0] * 11
     words_bucket = 0
     if mnts_intersect != 1:
-        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'] - 1)
+        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'])
     if DEBUG:
         features.append('Bucket %d' % words_bucket)
     words_dist[words_bucket] = 1
@@ -483,7 +490,7 @@
     mentions_dist = [0] * 11
     mentions_bucket = 0
     if mnts_intersect != 1:
-        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'] - 1)
+        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'])
     if words_bucket == 10:
         mentions_bucket = 10
     if DEBUG: