Commit c2871e0ded5bfb23380ab7c041d4237e1a0c8481

Authored by Bartłomiej Nitoń
1 parent 01a04337

Basic evaluation and data preparation scripts.
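
The intended flow: preparator.py turns MMAX coreference annotations into tab-separated mention-pair feature vectors built from a word2vec model, and resolver.py scores such prepared files with a trained Keras model, writing per-text metrics. A minimal sketch of that flow (hypothetical glue code, not part of this commit; it assumes OUT_PATH in preparator.py and IN_PATH in resolver.py point at the same prepared data):

    import preparator
    import resolver

    preparator.main()   # MMAX annotations -> tab-separated mention-pair vectors
    resolver.main()     # prepared vectors -> accuracy/precision/recall/F1 per text in metrics.csv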

counter.py 0 → 100644
  1 +# -*- coding: utf-8 -*-
  2 +
  3 +import os
  4 +
  5 +from lxml import etree
  6 +from natsort import natsorted
  7 +
  8 +from preparator import ANNO_PATH
  9 +
  10 +
  11 +def count_words():
  12 + anno_files = os.listdir(ANNO_PATH)
  13 + anno_files = natsorted(anno_files)
  14 + for filename in anno_files:
  15 + if filename.endswith('.mmax'):
  16 + words_count = 0
  17 + textname = filename.replace('.mmax', '')
  18 + words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
  19 + tree = etree.parse(words_path)
  20 + for word in tree.xpath("//word"):
  21 + if word.attrib['ctag'] != 'interp':
  22 + words_count += 1
  23 + print textname, words_count
  24 +
  25 +
  26 +def count_mentions():
  27 + anno_files = os.listdir(ANNO_PATH)
  28 + anno_files = natsorted(anno_files)
  29 + for filename in anno_files:
  30 + if filename.endswith('.mmax'):
  31 + textname = filename.replace('.mmax', '')
  32 +
  33 + mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
  34 + tree = etree.parse(mentions_path)
  35 + mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
  36 + print textname, len(mentions)
... ...
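
counter.py only defines the two counting helpers and has no __main__ block; a minimal driver sketch (hypothetical, assuming the MMAX files live under the ANNO_PATH imported from preparator.py):

    from counter import count_words, count_mentions

    count_words()      # prints "<textname> <word count>" per text, skipping tokens tagged 'interp'
    count_mentions()   # prints "<textname> <number of markables>" per text
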
preparator.py 0 → 100644
  1 +# -*- coding: utf-8 -*-
  2 +
  3 +import codecs
  4 +import numpy
  5 +import os
  6 +import random
  7 +
  8 +from lxml import etree
  9 +from itertools import combinations
  10 +from natsort import natsorted
  11 +
  12 +from gensim.models.word2vec import Word2Vec
  13 +
  14 +
  15 +TEST_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'test-prepared'))
  16 +TRAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data', 'train-prepared'))
  17 +
  18 +ANNO_PATH = TEST_PATH
  19 +OUT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
  20 + 'test.csv'))
  21 +EACH_TEXT_SEPARATELY = False  # if True, OUT_PATH is treated as a directory and one CSV is written per text
  22 +
  23 +CONTEXT = 5  # max number of non-punctuation context words taken on each side of a mention
  24 +W2V_SIZE = 50  # dimensionality of the word2vec vectors (also used for the zero/random fallback vectors)
  25 +MODEL = os.path.abspath(os.path.join(os.path.dirname(__file__), 'models',
  26 + '%d' % W2V_SIZE,
  27 + 'w2v_allwiki_nkjpfull_%d.model' % W2V_SIZE))
  28 +POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-']  # punctuation-like mention heads; such mentions are skipped
  29 +NEG_PROPORTION = 1  # number of negative pairs sampled per positive pair
  30 +RANDOM_VECTORS = True  # if True, out-of-vocabulary words get random vectors instead of being skipped
  31 +
  32 +DEBUG = False
  33 +POS_COUNT = 0
  34 +NEG_COUNT = 0
  35 +ALL_WORDS = 0
  36 +UNKNOWN_WORDS = 0
  37 +
  38 +
  39 +def main():
  40 + model = Word2Vec.load(MODEL)
  41 + try:
  42 + create_data_vectors(model)
  43 + finally:
  44 + print 'Unknown words: ', UNKNOWN_WORDS
  45 + print 'All words: ', ALL_WORDS
  46 + print 'Positives: ', POS_COUNT
  47 + print 'Negatives: ', NEG_COUNT
  48 +
  49 +
  50 +def create_data_vectors(model):
  51 + features_file = None
  52 + if not EACH_TEXT_SEPARATELY:
  53 + features_file = codecs.open(OUT_PATH, 'wt', 'utf-8')
  54 +
  55 + anno_files = os.listdir(ANNO_PATH)
  56 + anno_files = natsorted(anno_files)
  57 + for filename in anno_files:
  58 + if filename.endswith('.mmax'):
  59 + print '=======> ', filename
  60 + textname = filename.replace('.mmax', '')
  61 +
  62 + mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
  63 + tree = etree.parse(mentions_path)
  64 + mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
  65 + positives, negatives = diff_mentions(mentions)
  66 +
  67 + if DEBUG:
  68 + print 'Positives:'
  69 + print len(positives)
  70 +
  71 + print 'Negatives:'
  72 + print len(negatives)
  73 +
  74 + words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
  75 + mentions_dict = markables_level_2_dict(mentions_path, words_path)
  76 +
  77 + if EACH_TEXT_SEPARATELY:
  78 + text_features_path = os.path.join(OUT_PATH, '%s.csv' % textname)
  79 + features_file = codecs.open(text_features_path, 'wt', 'utf-8')
  80 + write_features(features_file, positives, negatives, mentions_dict, model, textname)
  81 +
  82 + if not EACH_TEXT_SEPARATELY:
  83 + features_file.close()
  84 +
  85 +
  86 +def diff_mentions(mentions):
  87 + sets, clustered_mentions = get_sets(mentions)
  88 + positives = get_positives(sets)
  89 + positives, negatives = get_negatives_and_update_positives(clustered_mentions, positives)
  90 + if len(negatives) != len(positives) and NEG_PROPORTION == 1:
  91 + print u'Mismatched number of positive and negative examples!'
  92 + return positives, negatives
  93 +
  94 +
  95 +def get_sets(mentions):
  96 + sets = {}
  97 + clustered_mentions = []
  98 + for mention in mentions:
  99 + set_id = mention.attrib['mention_group']
  100 + if set_id == 'empty' or set_id == '' or mention.attrib['mention_head'] in POSSIBLE_HEADS:
  101 + pass
  102 + elif set_id not in sets:
  103 + sets[set_id] = [mention.attrib['span']]
  104 + clustered_mentions.append(mention.attrib['span'])
  105 + elif set_id in sets:
  106 + sets[set_id].append(mention.attrib['span'])
  107 + clustered_mentions.append(mention.attrib['span'])
  108 + else:
  109 + print u'Something went wrong while searching for clusters!'
  110 +
  111 + sets_to_remove = []
  112 + for set_id in sets:
  113 + if len(sets[set_id]) < 2:
  114 + sets_to_remove.append(set_id)
  115 + if len(sets[set_id]) == 1:
  116 + print u'Removing clustered mention: ', sets[set_id][0]
  117 + clustered_mentions.remove(sets[set_id][0])
  118 +
  119 + for set_id in sets_to_remove:
  120 + print u'Removing set: ', set_id
  121 + sets.pop(set_id)
  122 +
  123 + return sets, clustered_mentions
  124 +
  125 +
  126 +def get_positives(sets):
  127 + positives = []
  128 + for set_id in sets:
  129 + coref_set = sets[set_id]
  130 + positives.extend(list(combinations(coref_set, 2)))
  131 + return positives
  132 +
  133 +
  134 +def get_negatives_and_update_positives(clustered_mentions, positives):
  135 + all_pairs = list(combinations(clustered_mentions, 2))
  136 + all_pairs = set(all_pairs)
  137 + negatives = [pair for pair in all_pairs if pair not in positives]
  138 + samples_count = NEG_PROPORTION * len(positives)
  139 + if samples_count > len(negatives):
  140 + samples_count = len(negatives)
  141 + if NEG_PROPORTION == 1:
  142 + positives = random.sample(set(positives), samples_count)
  143 + print u'More positive examples than negative ones!'
  144 + negatives = random.sample(set(negatives), samples_count)
  145 + return positives, negatives
  146 +
  147 +
  148 +def write_features(features_file, positives, negatives, mentions_dict, model, textname):
  149 + global POS_COUNT
  150 + POS_COUNT += len(positives)
  151 + for pair in positives:
  152 + pair_features = []
  153 + if DEBUG:
  154 + pair_features = ['%s>%s:%s' % (textname, pair[0], pair[1])]
  155 + pair_features.extend(get_features(pair, mentions_dict, model))
  156 + pair_features.append(1)
  157 + features_file.write(u'%s\n' % u'\t'.join([unicode(feature) for feature in pair_features]))
  158 +
  159 + global NEG_COUNT
  160 + NEG_COUNT += len(negatives)
  161 + for pair in negatives:
  162 + pair_features = []
  163 + if DEBUG:
  164 + pair_features = ['%s>%s:%s' % (textname, pair[0], pair[1])]
  165 + pair_features.extend(get_features(pair, mentions_dict, model))
  166 + pair_features.append(0)
  167 + features_file.write(u'%s\n' % u'\t'.join([unicode(feature) for feature in pair_features]))
  168 +
  169 +
  170 +def get_features(pair, mentions_dict, model):
  171 + features = []
  172 + ante = pair[0]
  173 + ana = pair[1]
  174 + ante_features = get_mention_features(ante, mentions_dict, model)
  175 + features.extend(ante_features)
  176 + ana_features = get_mention_features(ana, mentions_dict, model)
  177 + features.extend(ana_features)
  178 + pair_features = get_pair_features(pair, mentions_dict)
  179 + features.extend(pair_features)
  180 + return features
  181 +
  182 +
  183 +def get_mention_features(mention_span, mentions_dict, model):
  184 + features = []
  185 + mention = get_mention_by_attr(mentions_dict, 'span', mention_span)
  186 +
  187 + if DEBUG:
  188 + features.append(mention['head_base'])
  189 + head_vec = get_wv(model, mention['head_base'])
  190 + features.extend(list(head_vec))
  191 +
  192 + if DEBUG:
  193 + features.append(mention['words'][0]['base'])
  194 + first_vec = get_wv(model, mention['words'][0]['base'])
  195 + features.extend(list(first_vec))
  196 +
  197 + if DEBUG:
  198 + features.append(mention['words'][-1]['base'])
  199 + last_vec = get_wv(model, mention['words'][-1]['base'])
  200 + features.extend(list(last_vec))
  201 +
  202 + if len(mention['follow_context']) > 0:
  203 + if DEBUG:
  204 + features.append(mention['follow_context'][0]['base'])
  205 + after_1_vec = get_wv(model, mention['follow_context'][0]['base'])
  206 + features.extend(list(after_1_vec))
  207 + else:
  208 + if DEBUG:
  209 + features.append('None')
  210 + features.extend([0.0] * W2V_SIZE)
  211 + if len(mention['follow_context']) > 1:
  212 + if DEBUG:
  213 + features.append(mention['follow_context'][1]['base'])
  214 + after_2_vec = get_wv(model, mention['follow_context'][1]['base'])
  215 + features.extend(list(after_2_vec))
  216 + else:
  217 + if DEBUG:
  218 + features.append('None')
  219 + features.extend([0.0] * W2V_SIZE)
  220 +
  221 + if len(mention['prec_context']) > 0:
  222 + if DEBUG:
  223 + features.append(mention['prec_context'][-1]['base'])
  224 + prec_1_vec = get_wv(model, mention['prec_context'][-1]['base'])
  225 + features.extend(list(prec_1_vec))
  226 + else:
  227 + if DEBUG:
  228 + features.append('None')
  229 + features.extend([0.0] * W2V_SIZE)
  230 + if len(mention['prec_context']) > 1:
  231 + if DEBUG:
  232 + features.append(mention['prec_context'][-2]['base'])
  233 + prec_2_vec = get_wv(model, mention['prec_context'][-2]['base'])
  234 + features.extend(list(prec_2_vec))
  235 + else:
  236 + if DEBUG:
  237 + features.append('None')
  238 + features.extend([0.0] * W2V_SIZE)
  239 +
  240 + if DEBUG:
  241 + features.append(u' '.join([word['orth'] for word in mention['prec_context']]))
  242 + prec_vec = get_context_vec(mention['prec_context'], model)
  243 + features.extend(list(prec_vec))
  244 +
  245 + if DEBUG:
  246 + features.append(u' '.join([word['orth'] for word in mention['follow_context']]))
  247 + follow_vec = get_context_vec(mention['follow_context'], model)
  248 + features.extend(list(follow_vec))
  249 +
  250 + if DEBUG:
  251 + features.append(u' '.join([word['orth'] for word in mention['words']]))
  252 + mention_vec = get_context_vec(mention['words'], model)
  253 + features.extend(list(mention_vec))
  254 +
  255 + if DEBUG:
  256 + features.append(u' '.join([word['orth'] for word in mention['sentence']]))
  257 + sentence_vec = get_context_vec(mention['sentence'], model)
  258 + features.extend(list(sentence_vec))
  259 +
  260 + return features
  261 +
  262 +
  263 +def get_wv(model, lemma, use_random=True):  # returns lemma's word2vec vector; random vector (or None) for OOV lemmas
  264 + global ALL_WORDS
  265 + global UNKNOWN_WORDS
  266 + vec = None
  267 + if use_random:
  268 + vec = random_vec()
  269 + ALL_WORDS += 1
  270 + try:
  271 + vec = model.wv[lemma]
  272 + except KeyError:
  273 + UNKNOWN_WORDS += 1
  274 + return vec
  275 +
  276 +
  277 +def random_vec():
  278 + return numpy.asarray([random.uniform(-0.25, 0.25) for i in range(0, W2V_SIZE)], dtype=numpy.float32)
  279 +
  280 +
  281 +def get_context_vec(words, model):
  282 + vec = numpy.zeros(W2V_SIZE, dtype=numpy.float32)
  283 + unknown_count = 0
  284 + if len(words) != 0:
  285 + for word in words:
  286 + word_vec = get_wv(model, word['base'], RANDOM_VECTORS)
  287 + if word_vec is None:
  288 + unknown_count += 1
  289 + else:
  290 + vec += word_vec
  291 + significant_words = len(words) - unknown_count
  292 + if significant_words != 0:
  293 + vec = vec/float(significant_words)
  294 + else:
  295 + vec = random_vec()
  296 + return vec
  297 +
  298 +
  299 +def get_pair_features(pair, mentions_dict):
  300 + ante = get_mention_by_attr(mentions_dict, 'span', pair[0])
  301 + ana = get_mention_by_attr(mentions_dict, 'span', pair[1])
  302 +
  303 + features = []
  304 + mnts_intersect = pair_intersect(ante, ana)
  305 +
  306 + words_dist = [0] * 11
  307 + words_bucket = 0
  308 + if mnts_intersect != 1:
  309 + words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'] - 1)
  310 + if DEBUG:
  311 + features.append('Bucket %d' % words_bucket)
  312 + words_dist[words_bucket] = 1
  313 + features.extend(words_dist)
  314 +
  315 + mentions_dist = [0] * 11
  316 + mentions_bucket = 0
  317 + if mnts_intersect != 1:
  318 + mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'] - 1)
  319 + if words_bucket == 10:
  320 + mentions_bucket = 10
  321 + if DEBUG:
  322 + features.append('Bucket %d' % mentions_bucket)
  323 + mentions_dist[mentions_bucket] = 1
  324 + features.extend(mentions_dist)
  325 +
  326 + if DEBUG:
  327 + features.append('Other features')
  328 + features.append(mnts_intersect)
  329 + features.append(head_match(ante, ana))
  330 + features.append(exact_match(ante, ana))
  331 + features.append(base_match(ante, ana))
  332 +
  333 + if len(mentions_dict) > 100:
  334 + features.append(1)
  335 + else:
  336 + features.append(0)
  337 +
  338 + return features
  339 +
  340 +
  341 +def get_distance_bucket(distance):
  342 + if distance >= 0 and distance <= 4:
  343 + return distance
  344 + elif distance >= 5 and distance <= 7:
  345 + return 5
  346 + elif distance >= 8 and distance <= 15:
  347 + return 6
  348 + elif distance >= 16 and distance <= 31:
  349 + return 7
  350 + elif distance >= 32 and distance <= 63:
  351 + return 8
  352 + elif distance >= 64:
  353 + return 9
  354 + else:
  355 + print u'Something went wrong during distance bucketing!'
  356 + return 10
  357 +
  358 +
  359 +def pair_intersect(ante, ana):
  360 + for ante_word in ante['words']:
  361 + for ana_word in ana['words']:
  362 + if ana_word['id'] == ante_word['id']:
  363 + return 1
  364 + return 0
  365 +
  366 +
  367 +def head_match(ante, ana):
  368 + if ante['head_orth'].lower() == ana['head_orth'].lower():
  369 + return 1
  370 + return 0
  371 +
  372 +
  373 +def exact_match(ante, ana):
  374 + if ante['text'].lower() == ana['text'].lower():
  375 + return 1
  376 + return 0
  377 +
  378 +
  379 +def base_match(ante, ana):
  380 + if ante['lemmatized_text'].lower() == ana['lemmatized_text'].lower():
  381 + return 1
  382 + return 0
  383 +
  384 +
  385 +def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/NameSpaces/mention'):
  386 + markables_dicts = []
  387 + markables_tree = etree.parse(markables_path)
  388 + markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace})
  389 +
  390 + words = get_words(words_path)
  391 +
  392 + for idx, markable in enumerate(markables):
  393 + span = markable.attrib['span']
  394 + if not get_mention_by_attr(markables_dicts, 'span', span):
  395 +
  396 + dominant = ''
  397 + if 'dominant' in markable.attrib:
  398 + dominant = markable.attrib['dominant']
  399 +
  400 + head_orth = markable.attrib['mention_head']
  401 + if head_orth not in POSSIBLE_HEADS:
  402 + mention_words = span_to_words(span, words)
  403 +
  404 + prec_context, follow_context, sentence, mnt_start_position, mnt_end_position = get_context(mention_words, words)
  405 +
  406 + head_base = get_head_base(head_orth, mention_words)
  407 + markables_dicts.append({'id': markable.attrib['id'],
  408 + 'set': markable.attrib['mention_group'],
  409 + 'text': span_to_text(span, words, 'orth'),
  410 + 'lemmatized_text': span_to_text(span, words, 'base'),
  411 + 'words': mention_words,
  412 + 'span': span,
  413 + 'head_orth': head_orth,
  414 + 'head_base': head_base,
  415 + 'dominant': dominant,
  416 + 'node': markable,
  417 + 'prec_context': prec_context,
  418 + 'follow_context': follow_context,
  419 + 'sentence': sentence,
  420 + 'position_in_mentions': idx,
  421 + 'start_in_words': mnt_start_position,
  422 + 'end_in_words': mnt_end_position})
  423 + else:
  424 + print 'Duplicate mention: %s' % span
  425 +
  426 + return markables_dicts
  427 +
  428 +
  429 +def get_context(mention_words, words):
  430 + prec_context = []
  431 + follow_context = []
  432 + sentence = []
  433 + mnt_start_position = mnt_end_position = -1
  434 + first_word = mention_words[0]
  435 + last_word = mention_words[-1]
  436 + for idx, word in enumerate(words):
  437 + if word['id'] == first_word['id']:
  438 + prec_context = get_prec_context(idx, words)
  439 + mnt_start_position = get_mention_start(first_word, words)
  440 + if word['id'] == last_word['id']:
  441 + follow_context = get_follow_context(idx, words)
  442 + sentence = get_sentence(idx, words)
  443 + mnt_end_position = get_mention_end(last_word, words)
  444 + break
  445 + return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position
  446 +
  447 +
  448 +def get_prec_context(mention_start, words):
  449 + context = []
  450 + context_start = mention_start - 1
  451 + while context_start >= 0:
  452 + if not word_to_ignore(words[context_start]):
  453 + context.append(words[context_start])
  454 + if len(context) == CONTEXT:
  455 + break
  456 + context_start -= 1
  457 + context.reverse()
  458 + return context
  459 +
  460 +
  461 +def get_mention_start(first_word, words):
  462 + start = 0
  463 + for word in words:
  464 + if not word_to_ignore(word):
  465 + start += 1
  466 + if word['id'] == first_word['id']:
  467 + break
  468 + return start
  469 +
  470 +
  471 +def get_mention_end(last_word, words):
  472 + end = 0
  473 + for word in words:
  474 + if not word_to_ignore(word):
  475 + end += 1
  476 + if word['id'] == last_word['id']:
  477 + break
  478 + return end
  479 +
  480 +
  481 +def get_follow_context(mention_end, words):
  482 + context = []
  483 + context_end = mention_end + 1
  484 + while context_end < len(words):
  485 + if not word_to_ignore(words[context_end]):
  486 + context.append(words[context_end])
  487 + if len(context) == CONTEXT:
  488 + break
  489 + context_end += 1
  490 + return context
  491 +
  492 +
  493 +def get_sentence(word_idx, words):
  494 + sentence_start = get_sentence_start(words, word_idx)
  495 + sentence_end = get_sentence_end(words, word_idx)
  496 + sentence = [word for word in words[sentence_start:sentence_end+1] if not word_to_ignore(word)]
  497 + return sentence
  498 +
  499 +
  500 +def get_sentence_start(words, word_idx):
  501 + search_start = word_idx
  502 + while word_idx >= 0:
  503 + if words[word_idx]['lastinsent'] and search_start != word_idx:
  504 + return word_idx+1
  505 + word_idx -= 1
  506 + return 0
  507 +
  508 +
  509 +def get_sentence_end(words, word_idx):
  510 + while word_idx < len(words):
  511 + if words[word_idx]['lastinsent']:
  512 + return word_idx
  513 + word_idx += 1
  514 + return len(words) - 1
  515 +
  516 +
  517 +def get_head_base(head_orth, words):
  518 + for word in words:
  519 + if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
  520 + return word['base']
  521 + return None
  522 +
  523 +
  524 +def get_words(filepath):
  525 + tree = etree.parse(filepath)
  526 + words = []
  527 + for word in tree.xpath("//word"):
  528 + hasnps = False
  529 + if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
  530 + hasnps = True
  531 + lastinsent = False
  532 + if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
  533 + lastinsent = True
  534 + words.append({'id': word.attrib['id'],
  535 + 'orth': word.text,
  536 + 'base': word.attrib['base'],
  537 + 'hasnps': hasnps,
  538 + 'lastinsent': lastinsent,
  539 + 'ctag': word.attrib['ctag']})
  540 + return words
  541 +
  542 +
  543 +def get_mention_by_attr(mentions, attr_name, value):
  544 + for mention in mentions:
  545 + if mention[attr_name] == value:
  546 + return mention
  547 + return None
  548 +
  549 +
  550 +def get_mention_index_by_attr(mentions, attr_name, value):
  551 + for idx, mention in enumerate(mentions):
  552 + if mention[attr_name] == value:
  553 + return idx
  554 + return None
  555 +
  556 +
  557 +def span_to_text(span, words, form):
  558 + fragments = span.split(',')
  559 + mention_parts = []
  560 + for fragment in fragments:
  561 + mention_parts.append(fragment_to_text(fragment, words, form))
  562 + return u' [...] '.join(mention_parts)
  563 +
  564 +
  565 +def fragment_to_text(fragment, words, form):
  566 + if '..' in fragment:
  567 + text = get_multiword_text(fragment, words, form)
  568 + else:
  569 + text = get_one_word_text(fragment, words, form)
  570 + return text
  571 +
  572 +
  573 +def get_multiword_text(fragment, words, form):
  574 + mention_parts = []
  575 + boundaries = fragment.split('..')
  576 + start_id = boundaries[0]
  577 + end_id = boundaries[1]
  578 + in_string = False
  579 + for word in words:
  580 + if word['id'] == start_id:
  581 + in_string = True
  582 + if in_string and not word_to_ignore(word):
  583 + mention_parts.append(word)
  584 + if word['id'] == end_id:
  585 + break
  586 + return to_text(mention_parts, form)
  587 +
  588 +
  589 +def to_text(words, form):
  590 + text = ''
  591 + for idx, word in enumerate(words):
  592 + if word['hasnps'] or idx == 0:
  593 + text += word[form]
  594 + else:
  595 + text += u' %s' % word[form]
  596 + return text
  597 +
  598 +
  599 +def get_one_word_text(word_id, words, form):
  600 + this_word = (word for word in words if word['id'] == word_id).next()
  601 + if word_to_ignore(this_word):
  602 + print this_word
  603 + return this_word[form]
  604 +
  605 +
  606 +def span_to_words(span, words):
  607 + fragments = span.split(',')
  608 + mention_parts = []
  609 + for fragment in fragments:
  610 + mention_parts.extend(fragment_to_words(fragment, words))
  611 + return mention_parts
  612 +
  613 +
  614 +def fragment_to_words(fragment, words):
  615 + mention_parts = []
  616 + if '..' in fragment:
  617 + mention_parts.extend(get_multiword(fragment, words))
  618 + else:
  619 + mention_parts.extend(get_word(fragment, words))
  620 + return mention_parts
  621 +
  622 +
  623 +def get_multiword(fragment, words):
  624 + mention_parts = []
  625 + boundaries = fragment.split('..')
  626 + start_id = boundaries[0]
  627 + end_id = boundaries[1]
  628 + in_string = False
  629 + for word in words:
  630 + if word['id'] == start_id:
  631 + in_string = True
  632 + if in_string and not word_to_ignore(word):
  633 + mention_parts.append(word)
  634 + if word['id'] == end_id:
  635 + break
  636 + return mention_parts
  637 +
  638 +
  639 +def get_word(word_id, words):
  640 + for word in words:
  641 + if word['id'] == word_id:
  642 + if not word_to_ignore(word):
  643 + return [word]
  644 + else:
  645 + return []
  646 + return []
  647 +
  648 +
  649 +def word_to_ignore(word):
  650 + if word['ctag'] == 'interp':
  651 + return True
  652 + return False
  653 +
  654 +
  655 +if __name__ == '__main__':
  656 + main()
... ...
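
Each line written by write_features() above is one mention pair: the concatenated feature values separated by tabs, followed by the 0/1 coreference label in the last column; resolver.py below reads such files with np.loadtxt. A quick sanity-check sketch (the path is only an example; use whatever file preparator.py actually produced):

    import numpy as np

    data = np.loadtxt('data/test.csv', delimiter='\t')
    print 'pairs:', data.shape[0]
    print 'columns per pair (last one is the label):', data.shape[1]
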
resolver.py 0 → 100644
  1 +# -*- coding: utf-8 -*-
  2 +
  3 +import codecs
  4 +import os
  5 +
  6 +import numpy as np
  7 +
  8 +from natsort import natsorted
  9 +
  10 +from keras.models import Model
  11 +from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization
  12 +from keras.optimizers import SGD, Adam
  13 +
  14 +IN_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
  15 + 'prepared_text_files'))
  16 +OUT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'data',
  17 + 'metrics.csv'))
  18 +
  19 +MODEL = os.path.abspath(os.path.join(os.path.dirname(__file__), 'weights_2017_05_10.h5'))
  20 +
  21 +
  22 +NUMBER_OF_FEATURES = 1126  # feature columns per row; the 0/1 label sits in the last, extra column
  23 +
  24 +
  25 +def main():
  26 + resolve_files()
  27 +
  28 +
  29 +def resolve_files():
  30 + metrics_file = codecs.open(OUT_PATH, 'w', 'utf-8')
  31 + write_labels(metrics_file)
  32 +
  33 + anno_files = os.listdir(IN_PATH)
  34 + anno_files = natsorted(anno_files)
  35 + for filename in anno_files:
  36 + print (filename)
  37 + textname = filename.replace('.csv', '')
  38 + text_data_path = os.path.join(IN_PATH, filename)
  39 + resolve(textname, text_data_path, metrics_file)
  40 +
  41 + metrics_file.close()
  42 +
  43 +
  44 +def write_labels(metrics_file):
  45 + metrics_file.write('Text\tAccuracy\tPrecision\tRecall\tF1\tPairs\n')
  46 +
  47 +
  48 +def resolve(textname, text_data_path, metrics_file):
  49 + raw_data = open(text_data_path, 'rt')
  50 + test_data = np.loadtxt(raw_data, delimiter='\t')
  51 + test_set = test_data[:, 0:NUMBER_OF_FEATURES]
  52 + test_labels = test_data[:, NUMBER_OF_FEATURES] # last column consists of labels
  53 +
  54 + inputs = Input(shape=(NUMBER_OF_FEATURES,))
  55 + output_from_1st_layer = Dense(1000, activation='relu')(inputs)
  56 + output_from_1st_layer = Dropout(0.5)(output_from_1st_layer)
  57 + output_from_1st_layer = BatchNormalization()(output_from_1st_layer)
  58 + output_from_2nd_layer = Dense(500, activation='relu')(output_from_1st_layer)
  59 + output_from_2nd_layer = Dropout(0.5)(output_from_2nd_layer)
  60 + output_from_2nd_layer = BatchNormalization()(output_from_2nd_layer)
  61 + output = Dense(1, activation='sigmoid')(output_from_2nd_layer)
  62 +
  63 + model = Model(inputs, output)
  64 + model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
  65 + model.load_weights(MODEL)
  66 +
  67 + predictions = model.predict(test_set)
  68 +
  69 + calc_metrics(textname, test_set, test_labels, predictions, metrics_file)
  70 +
  71 +
  72 +def calc_metrics(textname, test_set, test_labels, predictions, metrics_file):
  73 + true_positives = 0.0
  74 + false_positives = 0.0
  75 + true_negatives = 0.0
  76 + false_negatives = 0.0
  77 +
  78 + for i in range(len(test_set)):
  79 + if (predictions[i] < 0.5 and test_labels[i] == 0): true_negatives += 1
  80 + if (predictions[i] < 0.5 and test_labels[i] == 1): false_negatives += 1
  81 + if (predictions[i] >= 0.5 and test_labels[i] == 1): true_positives += 1
  82 + if (predictions[i] >= 0.5 and test_labels[i] == 0): false_positives += 1
  83 +
  84 + accuracy = (true_positives + true_negatives) / len(test_set)
  85 + precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0.0
  86 + recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0.0
  87 + f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
  88 +
  89 + metrics_file.write('%s\t%s\t%s\t%s\t%s\t%s\n' % (textname,
  90 + repr(accuracy),
  91 + repr(precision),
  92 + repr(recall),
  93 + repr(f1),
  94 + repr(len(test_set))))
  95 +
  96 +
  97 +if __name__ == '__main__':
  98 + main()
... ...