diff --git a/counter.py b/counter.py
new file mode 100644
index 0000000..21f6e7a
--- /dev/null
+++ b/counter.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+
+import os
+
+from lxml import etree
+from natsort import natsorted
+
+from preparator import ANNO_PATH
+
+
+def count_words():
+    anno_files = os.listdir(ANNO_PATH)
+    anno_files = natsorted(anno_files)
+    for filename in anno_files:
+        if filename.endswith('.mmax'):
+            words_count = 0
+            textname = filename.replace('.mmax', '')
+            words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
+            tree = etree.parse(words_path)
+            for word in tree.xpath("//word"):
+                if word.attrib['ctag'] != 'interp':
+                    words_count += 1
+            print textname, words_count
+
+
+def count_mentions():
+    anno_files = os.listdir(ANNO_PATH)
+    anno_files = natsorted(anno_files)
+    for filename in anno_files:
+        if filename.endswith('.mmax'):
+            textname = filename.replace('.mmax', '')
+
+            mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
+            tree = etree.parse(mentions_path)
+            mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
+            print textname, len(mentions)
diff --git a/preparator.py b/preparator.py
new file mode 100644
index 0000000..1c170b4
--- /dev/null
+++ b/preparator.py
@@ -0,0 +1,656 @@
+# -*- coding: utf-8 -*-
+
+import codecs
+import numpy
+import os
+import random
+
+from lxml import etree
+from itertools import combinations
+from natsort import natsorted
+
+from gensim.models.word2vec import Word2Vec
+
+
+TEST_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'test-prepared'))
+TRAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'train-prepared'))
+
+ANNO_PATH = TEST_PATH
+OUT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
+                                        'test.csv'))
+EACH_TEXT_SEPARATELY = False
+
+CONTEXT = 5
+W2V_SIZE = 50
+MODEL = os.path.abspath(os.path.join(os.path.dirname(__file__), 'models',
+                                     '%d' % W2V_SIZE,
+                                     'w2v_allwiki_nkjpfull_%d.model' % W2V_SIZE))
+POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-']
+NEG_PROPORTION = 1
+RANDOM_VECTORS = True
+
+DEBUG = False
+POS_COUNT = 0
+NEG_COUNT = 0
+ALL_WORDS = 0
+UNKNOWN_WORDS = 0
+
+
+def main():
+    model = Word2Vec.load(MODEL)
+    try:
+        create_data_vectors(model)
+    finally:
+        print 'Unknown words: ', UNKNOWN_WORDS
+        print 'All words: ', ALL_WORDS
+        print 'Positives: ', POS_COUNT
+        print 'Negatives: ', NEG_COUNT
+
+
+def create_data_vectors(model):
+    features_file = None
+    if not EACH_TEXT_SEPARATELY:
+        features_file = codecs.open(OUT_PATH, 'wt', 'utf-8')
+
+    anno_files = os.listdir(ANNO_PATH)
+    anno_files = natsorted(anno_files)
+    for filename in anno_files:
+        if filename.endswith('.mmax'):
+            print '=======> ', filename
+            textname = filename.replace('.mmax', '')
+
+            mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
+            tree = etree.parse(mentions_path)
+            mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
+            positives, negatives = diff_mentions(mentions)
+
+            if DEBUG:
+                print 'Positives:'
+                print len(positives)
+
+                print 'Negatives:'
+                print len(negatives)
+
+            words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
+            mentions_dict = markables_level_2_dict(mentions_path, words_path)
+
+            if EACH_TEXT_SEPARATELY:
+                text_features_path = os.path.join(OUT_PATH, '%s.csv' % textname)
+                features_file = codecs.open(text_features_path, 'wt', 'utf-8')
+            write_features(features_file, positives, negatives, mentions_dict, model, textname)
+
+    if not EACH_TEXT_SEPARATELY:
+        features_file.close()
+
+
+def diff_mentions(mentions):
+    sets, clustered_mentions = get_sets(mentions)
+    positives = get_positives(sets)
+    positives, negatives = get_negatives_and_update_positives(clustered_mentions, positives)
+    if len(negatives) != len(positives) and NEG_PROPORTION == 1:
+        print u'Niezgodna liczba przypadków pozytywnych i negatywnych!'
+    return positives, negatives
+
+
+def get_sets(mentions):
+    sets = {}
+    clustered_mentions = []
+    for mention in mentions:
+        set_id = mention.attrib['mention_group']
+        if set_id == 'empty' or set_id == '' or mention.attrib['mention_head'] in POSSIBLE_HEADS:
+            pass
+        elif set_id not in sets:
+            sets[set_id] = [mention.attrib['span']]
+            clustered_mentions.append(mention.attrib['span'])
+        elif set_id in sets:
+            sets[set_id].append(mention.attrib['span'])
+            clustered_mentions.append(mention.attrib['span'])
+        else:
+            print u'Coś poszło nie tak przy wyszukiwaniu klastrów!'
+
+    sets_to_remove = []
+    for set_id in sets:
+        if len(sets[set_id]) < 2:
+            sets_to_remove.append(set_id)
+            if len(sets[set_id]) == 1:
+                print u'Removing clustered mention: ', sets[set_id][0]
+                clustered_mentions.remove(sets[set_id][0])
+
+    for set_id in sets_to_remove:
+        print u'Removing set: ', set_id
+        sets.pop(set_id)
+
+    return sets, clustered_mentions
+
+
+def get_positives(sets):
+    positives = []
+    for set_id in sets:
+        coref_set = sets[set_id]
+        positives.extend(list(combinations(coref_set, 2)))
+    return positives
+
+
+def get_negatives_and_update_positives(clustered_mentions, positives):
+    all_pairs = list(combinations(clustered_mentions, 2))
+    all_pairs = set(all_pairs)
+    negatives = [pair for pair in all_pairs if pair not in positives]
+    samples_count = NEG_PROPORTION * len(positives)
+    if samples_count > len(negatives):
+        samples_count = len(negatives)
+        if NEG_PROPORTION == 1:
+            positives = random.sample(set(positives), samples_count)
+            print u'Więcej przypadków pozytywnych niż negatywnych!'
+    negatives = random.sample(set(negatives), samples_count)
+    return positives, negatives
+
+
+def write_features(features_file, positives, negatives, mentions_dict, model, textname):
+    global POS_COUNT
+    POS_COUNT += len(positives)
+    for pair in positives:
+        pair_features = []
+        if DEBUG:
+            pair_features = ['%s>%s:%s' % (textname, pair[0], pair[1])]
+        pair_features.extend(get_features(pair, mentions_dict, model))
+        pair_features.append(1)
+        features_file.write(u'%s\n' % u'\t'.join([unicode(feature) for feature in pair_features]))
+
+    global NEG_COUNT
+    NEG_COUNT += len(negatives)
+    for pair in negatives:
+        pair_features = []
+        if DEBUG:
+            pair_features = ['%s>%s:%s' % (textname, pair[0], pair[1])]
+        pair_features.extend(get_features(pair, mentions_dict, model))
+        pair_features.append(0)
+        features_file.write(u'%s\n' % u'\t'.join([unicode(feature) for feature in pair_features]))
+
+
+def get_features(pair, mentions_dict, model):
+    features = []
+    ante = pair[0]
+    ana = pair[1]
+    ante_features = get_mention_features(ante, mentions_dict, model)
+    features.extend(ante_features)
+    ana_features = get_mention_features(ana, mentions_dict, model)
+    features.extend(ana_features)
+    pair_features = get_pair_features(pair, mentions_dict)
+    features.extend(pair_features)
+    return features
+
+
+def get_mention_features(mention_span, mentions_dict, model):
+    features = []
+    mention = get_mention_by_attr(mentions_dict, 'span', mention_span)
+
+    if DEBUG:
+        features.append(mention['head_base'])
+    head_vec = get_wv(model, mention['head_base'])
+    features.extend(list(head_vec))
+
+    if DEBUG:
+        features.append(mention['words'][0]['base'])
+    first_vec = get_wv(model, mention['words'][0]['base'])
+    features.extend(list(first_vec))
+
+    if DEBUG:
+        features.append(mention['words'][-1]['base'])
+    last_vec = get_wv(model, mention['words'][-1]['base'])
+    features.extend(list(last_vec))
+
+    if len(mention['follow_context']) > 0:
+        if DEBUG:
+            features.append(mention['follow_context'][0]['base'])
+        after_1_vec = get_wv(model, mention['follow_context'][0]['base'])
+        features.extend(list(after_1_vec))
+    else:
+        if DEBUG:
+            features.append('None')
+        features.extend([0.0] * W2V_SIZE)
+    if len(mention['follow_context']) > 1:
+        if DEBUG:
+            features.append(mention['follow_context'][1]['base'])
+        after_2_vec = get_wv(model, mention['follow_context'][1]['base'])
+        features.extend(list(after_2_vec))
+    else:
+        if DEBUG:
+            features.append('None')
+        features.extend([0.0] * W2V_SIZE)
+
+    if len(mention['prec_context']) > 0:
+        if DEBUG:
+            features.append(mention['prec_context'][-1]['base'])
+        prec_1_vec = get_wv(model, mention['prec_context'][-1]['base'])
+        features.extend(list(prec_1_vec))
+    else:
+        if DEBUG:
+            features.append('None')
+        features.extend([0.0] * W2V_SIZE)
+    if len(mention['prec_context']) > 1:
+        if DEBUG:
+            features.append(mention['prec_context'][-2]['base'])
+        prec_2_vec = get_wv(model, mention['prec_context'][-2]['base'])
+        features.extend(list(prec_2_vec))
+    else:
+        if DEBUG:
+            features.append('None')
+        features.extend([0.0] * W2V_SIZE)
+
+    if DEBUG:
+        features.append(u' '.join([word['orth'] for word in mention['prec_context']]))
+    prec_vec = get_context_vec(mention['prec_context'], model)
+    features.extend(list(prec_vec))
+
+    if DEBUG:
+        features.append(u' '.join([word['orth'] for word in mention['follow_context']]))
+    follow_vec = get_context_vec(mention['follow_context'], model)
+    features.extend(list(follow_vec))
+
+    if DEBUG:
+        features.append(u' '.join([word['orth'] for word in mention['words']]))
+    mention_vec = get_context_vec(mention['words'], model)
+    features.extend(list(mention_vec))
+
+    if DEBUG:
+        features.append(u' '.join([word['orth'] for word in mention['sentence']]))
+    sentence_vec = get_context_vec(mention['sentence'], model)
+    features.extend(list(sentence_vec))
+
+    return features
+
+
+def get_wv(model, lemma, random=True):
+    global ALL_WORDS
+    global UNKNOWN_WORDS
+    vec = None
+    if random:
+        vec = random_vec()
+    ALL_WORDS += 1
+    try:
+        vec = model.wv[lemma]
+    except KeyError:
+        UNKNOWN_WORDS += 1
+    return vec
+
+
+def random_vec():
+    return numpy.asarray([random.uniform(-0.25, 0.25) for i in range(0, W2V_SIZE)], dtype=numpy.float32)
+
+
+def get_context_vec(words, model):
+    vec = numpy.zeros(W2V_SIZE, dtype=numpy.float32)
+    unknown_count = 0
+    if len(words) != 0:
+        for word in words:
+            word_vec = get_wv(model, word['base'], RANDOM_VECTORS)
+            if word_vec is None:
+                unknown_count += 1
+            else:
+                vec += word_vec
+        significant_words = len(words) - unknown_count
+        if significant_words != 0:
+            vec = vec/float(significant_words)
+        else:
+            vec = random_vec()
+    return vec
+
+
+def get_pair_features(pair, mentions_dict):
+    ante = get_mention_by_attr(mentions_dict, 'span', pair[0])
+    ana = get_mention_by_attr(mentions_dict, 'span', pair[1])
+
+    features = []
+    mnts_intersect = pair_intersect(ante, ana)
+
+    words_dist = [0] * 11
+    words_bucket = 0
+    if mnts_intersect != 1:
+        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'] - 1)
+    if DEBUG:
+        features.append('Bucket %d' % words_bucket)
+    words_dist[words_bucket] = 1
+    features.extend(words_dist)
+
+    mentions_dist = [0] * 11
+    mentions_bucket = 0
+    if mnts_intersect != 1:
+        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'] - 1)
+    if words_bucket == 10:
+        mentions_bucket = 10
+    if DEBUG:
+        features.append('Bucket %d' % mentions_bucket)
+    mentions_dist[mentions_bucket] = 1
+    features.extend(mentions_dist)
+
+    if DEBUG:
+        features.append('Other features')
+    features.append(mnts_intersect)
+    features.append(head_match(ante, ana))
+    features.append(exact_match(ante, ana))
+    features.append(base_match(ante, ana))
+
+    if len(mentions_dict) > 100:
+        features.append(1)
+    else:
+        features.append(0)
+
+    return features
+
+
+def get_distance_bucket(distance):
+    if distance >= 0 and distance <= 4:
+        return distance
+    elif distance >= 5 and distance <= 7:
+        return 5
+    elif distance >= 8 and distance <= 15:
+        return 6
+    elif distance >= 16 and distance <= 31:
+        return 7
+    elif distance >= 32 and distance <= 63:
+        return 8
+    elif distance >= 64:
+        return 9
+    else:
+        print u'Coś poszło nie tak przy kubełkowaniu!!'
+        return 10
+
+
+def pair_intersect(ante, ana):
+    for ante_word in ante['words']:
+        for ana_word in ana['words']:
+            if ana_word['id'] == ante_word['id']:
+                return 1
+    return 0
+
+
+def head_match(ante, ana):
+    if ante['head_orth'].lower() == ana['head_orth'].lower():
+        return 1
+    return 0
+
+
+def exact_match(ante, ana):
+    if ante['text'].lower() == ana['text'].lower():
+        return 1
+    return 0
+
+
+def base_match(ante, ana):
+    if ante['lemmatized_text'].lower() == ana['lemmatized_text'].lower():
+        return 1
+    return 0
+
+
+def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/NameSpaces/mention'):
+    markables_dicts = []
+    markables_tree = etree.parse(markables_path)
+    markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace})
+
+    words = get_words(words_path)
+
+    for idx, markable in enumerate(markables):
+        span = markable.attrib['span']
+        if not get_mention_by_attr(markables_dicts, 'span', span):
+
+            dominant = ''
+            if 'dominant' in markable.attrib:
+                dominant = markable.attrib['dominant']
+
+            head_orth = markable.attrib['mention_head']
+            if head_orth not in POSSIBLE_HEADS:
+                mention_words = span_to_words(span, words)
+
+                prec_context, follow_context, sentence, mnt_start_position, mnt_end_position = get_context(mention_words, words)
+
+                head_base = get_head_base(head_orth, mention_words)
+                markables_dicts.append({'id': markable.attrib['id'],
+                                        'set': markable.attrib['mention_group'],
+                                        'text': span_to_text(span, words, 'orth'),
+                                        'lemmatized_text': span_to_text(span, words, 'base'),
+                                        'words': mention_words,
+                                        'span': span,
+                                        'head_orth': head_orth,
+                                        'head_base': head_base,
+                                        'dominant': dominant,
+                                        'node': markable,
+                                        'prec_context': prec_context,
+                                        'follow_context': follow_context,
+                                        'sentence': sentence,
+                                        'position_in_mentions': idx,
+                                        'start_in_words': mnt_start_position,
+                                        'end_in_words': mnt_end_position})
+        else:
+            print 'Zduplikowana wzmianka: %s' % span
+
+    return markables_dicts
+
+
+def get_context(mention_words, words):
+    prec_context = []
+    follow_context = []
+    sentence = []
+    mnt_start_position = -1
+    first_word = mention_words[0]
+    last_word = mention_words[-1]
+    for idx, word in enumerate(words):
+        if word['id'] == first_word['id']:
+            prec_context = get_prec_context(idx, words)
+            mnt_start_position = get_mention_start(first_word, words)
+        if word['id'] == last_word['id']:
+            follow_context = get_follow_context(idx, words)
+            sentence = get_sentence(idx, words)
+            mnt_end_position = get_mention_end(last_word, words)
+            break
+    return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position
+
+
+def get_prec_context(mention_start, words):
+    context = []
+    context_start = mention_start - 1
+    while context_start >= 0:
+        if not word_to_ignore(words[context_start]):
+            context.append(words[context_start])
+        if len(context) == CONTEXT:
+            break
+        context_start -= 1
+    context.reverse()
+    return context
+
+
+def get_mention_start(first_word, words):
+    start = 0
+    for word in words:
+        if not word_to_ignore(word):
+            start += 1
+        if word['id'] == first_word['id']:
+            break
+    return start
+
+
+def get_mention_end(last_word, words):
+    end = 0
+    for word in words:
+        if not word_to_ignore(word):
+            end += 1
+        if word['id'] == last_word['id']:
+            break
+    return end
+
+
+def get_follow_context(mention_end, words):
+    context = []
+    context_end = mention_end + 1
+    while context_end < len(words):
+        if not word_to_ignore(words[context_end]):
+            context.append(words[context_end])
+        if len(context) == CONTEXT:
+            break
+        context_end += 1
+    return context
+
+
+def get_sentence(word_idx, words):
+    sentence_start = get_sentence_start(words, word_idx)
+    sentence_end = get_sentence_end(words, word_idx)
+    sentence = [word for word in words[sentence_start:sentence_end+1] if not word_to_ignore(word)]
+    return sentence
+
+
+def get_sentence_start(words, word_idx):
+    search_start = word_idx
+    while word_idx >= 0:
+        if words[word_idx]['lastinsent'] and search_start != word_idx:
+            return word_idx+1
+        word_idx -= 1
+    return 0
+
+
+def get_sentence_end(words, word_idx):
+    while word_idx < len(words):
+        if words[word_idx]['lastinsent']:
+            return word_idx
+        word_idx += 1
+    return len(words) - 1
+
+
+def get_head_base(head_orth, words):
+    for word in words:
+        if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
+            return word['base']
+    return None
+
+
+def get_words(filepath):
+    tree = etree.parse(filepath)
+    words = []
+    for word in tree.xpath("//word"):
+        hasnps = False
+        if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
+            hasnps = True
+        lastinsent = False
+        if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
+            lastinsent = True
+        words.append({'id': word.attrib['id'],
+                      'orth': word.text,
+                      'base': word.attrib['base'],
+                      'hasnps': hasnps,
+                      'lastinsent': lastinsent,
+                      'ctag': word.attrib['ctag']})
+    return words
+
+
+def get_mention_by_attr(mentions, attr_name, value):
+    for mention in mentions:
+        if mention[attr_name] == value:
+            return mention
+    return None
+
+
+def get_mention_index_by_attr(mentions, attr_name, value):
+    for idx, mention in enumerate(mentions):
+        if mention[attr_name] == value:
+            return idx
+    return None
+
+
+def span_to_text(span, words, form):
+    fragments = span.split(',')
+    mention_parts = []
+    for fragment in fragments:
+        mention_parts.append(fragment_to_text(fragment, words, form))
+    return u' [...] '.join(mention_parts)
+
+
+def fragment_to_text(fragment, words, form):
+    if '..' in fragment:
+        text = get_multiword_text(fragment, words, form)
+    else:
+        text = get_one_word_text(fragment, words, form)
+    return text
+
+
+def get_multiword_text(fragment, words, form):
+    mention_parts = []
+    boundaries = fragment.split('..')
+    start_id = boundaries[0]
+    end_id = boundaries[1]
+    in_string = False
+    for word in words:
+        if word['id'] == start_id:
+            in_string = True
+        if in_string and not word_to_ignore(word):
+            mention_parts.append(word)
+        if word['id'] == end_id:
+            break
+    return to_text(mention_parts, form)
+
+
+def to_text(words, form):
+    text = ''
+    for idx, word in enumerate(words):
+        if word['hasnps'] or idx == 0:
+            text += word[form]
+        else:
+            text += u' %s' % word[form]
+    return text
+
+
+def get_one_word_text(word_id, words, form):
+    this_word = (word for word in words if word['id'] == word_id).next()
+    if word_to_ignore(this_word):
+        print this_word
+    return this_word[form]
+
+
+def span_to_words(span, words):
+    fragments = span.split(',')
+    mention_parts = []
+    for fragment in fragments:
+        mention_parts.extend(fragment_to_words(fragment, words))
+    return mention_parts
+
+
+def fragment_to_words(fragment, words):
+    mention_parts = []
+    if '..' in fragment:
+        mention_parts.extend(get_multiword(fragment, words))
+    else:
+        mention_parts.extend(get_word(fragment, words))
+    return mention_parts
+
+
+def get_multiword(fragment, words):
+    mention_parts = []
+    boundaries = fragment.split('..')
+    start_id = boundaries[0]
+    end_id = boundaries[1]
+    in_string = False
+    for word in words:
+        if word['id'] == start_id:
+            in_string = True
+        if in_string and not word_to_ignore(word):
+            mention_parts.append(word)
+        if word['id'] == end_id:
+            break
+    return mention_parts
+
+
+def get_word(word_id, words):
+    for word in words:
+        if word['id'] == word_id:
+            if not word_to_ignore(word):
+                return [word]
+            else:
+                return []
+    return []
+
+
+def word_to_ignore(word):
+    if word['ctag'] == 'interp':
+        return True
+    return False
+
+
+if __name__ == '__main__':
+    main()
diff --git a/resolver.py b/resolver.py
new file mode 100644
index 0000000..0536074
--- /dev/null
+++ b/resolver.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+
+import codecs
+import os
+
+import numpy as np
+
+from natsort import natsorted
+
+from keras.models import Model
+from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization
+from keras.optimizers import SGD, Adam
+
+IN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
+                                       'prepared_text_files'))
+OUT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
+                                        'metrics.csv'))
+
+MODEL = os.path.abspath(os.path.join(os.path.dirname(__file__), 'weights_2017_05_10.h5'))
+
+
+NUMBER_OF_FEATURES = 1126
+
+
+def main():
+    resolve_files()
+
+
+def resolve_files():
+    metrics_file = codecs.open(OUT_PATH, 'w', 'utf-8')
+    write_labels(metrics_file)
+
+    anno_files = os.listdir(IN_PATH)
+    anno_files = natsorted(anno_files)
+    for filename in anno_files:
+        print (filename)
+        textname = filename.replace('.csv', '')
+        text_data_path = os.path.join(IN_PATH, filename)
+        resolve(textname, text_data_path, metrics_file)
+
+    metrics_file.close()
+
+
+def write_labels(metrics_file):
+    metrics_file.write('Text\tAccuracy\tPrecision\tRecall\tF1\tPairs\n')
+
+
+def resolve(textname, text_data_path, metrics_file):
+    raw_data = open(text_data_path, 'rt')
+    test_data = np.loadtxt(raw_data, delimiter='\t')
+    test_set = test_data[:, 0:NUMBER_OF_FEATURES]
+    test_labels = test_data[:, NUMBER_OF_FEATURES]  # last column consists of labels
+
+    inputs = Input(shape=(NUMBER_OF_FEATURES,))
+    output_from_1st_layer = Dense(1000, activation='relu')(inputs)
+    output_from_1st_layer = Dropout(0.5)(output_from_1st_layer)
+    output_from_1st_layer = BatchNormalization()(output_from_1st_layer)
+    output_from_2nd_layer = Dense(500, activation='relu')(output_from_1st_layer)
+    output_from_2nd_layer = Dropout(0.5)(output_from_2nd_layer)
+    output_from_2nd_layer = BatchNormalization()(output_from_2nd_layer)
+    output = Dense(1, activation='sigmoid')(output_from_2nd_layer)
+
+    model = Model(inputs, output)
+    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
+    model.load_weights(MODEL)
+
+    predictions = model.predict(test_set)
+
+    calc_metrics(textname, test_set, test_labels, predictions, metrics_file)
+
+
+def calc_metrics(textname, test_set, test_labels, predictions, metrics_file):
+    true_positives = 0.0
+    false_positives = 0.0
+    true_negatives = 0.0
+    false_negatives = 0.0
+
+    for i in range(len(test_set)):
+        if (predictions[i] < 0.5 and test_labels[i] == 0): true_negatives += 1
+        if (predictions[i] < 0.5 and test_labels[i] == 1): false_negatives += 1
+        if (predictions[i] >= 0.5 and test_labels[i] == 1): true_positives += 1
+        if (predictions[i] >= 0.5 and test_labels[i] == 0): false_positives += 1
+
+    accuracy = (true_positives + true_negatives) / len(test_set)
+    precision = true_positives / (true_positives + false_positives)
+    recall = true_positives / (true_positives + false_negatives)
+    f1 = 2 * (precision * recall) / (precision + recall)
+
+    metrics_file.write('%s\t%s\t%s\t%s\t%s\t%s\n' % (textname,
+                                                     repr(accuracy),
+                                                     repr(precision),
+                                                     repr(recall),
+                                                     repr(f1),
+                                                     repr(len(test_set))))
+
+
+if __name__ == '__main__':
+    main()