Commit 445955b9f02279ffd510804a8c6d92231529b004

Authored by Bartłomiej Nitoń
1 parent 08bfdfc4

Added singletons to negative pool in preparator script.

Showing 2 changed files with 480 additions and 8 deletions
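
In short: negative training pairs were previously drawn only from mentions that belong to some coreference cluster; with this change, singleton mentions can take part in negative pairs as well. A minimal sketch of the idea (function and variable names here are illustrative, not the script's API):

from itertools import combinations

def negative_pairs(all_mentions, clustered_mentions, positives, use_singletons=True):
    # Singletons widen the candidate pool: pair up every mention,
    # not just the mentions that appear in some coreference cluster.
    pool = all_mentions if use_singletons else clustered_mentions
    return [pair for pair in combinations(pool, 2) if pair not in set(positives)]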
count_dist.py 0 → 100644
# -*- coding: utf-8 -*-

import os

from lxml import etree
from natsort import natsorted

MAIN_PATH = os.path.dirname(__file__)
TEST_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'test-prepared'))
TRAIN_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'train-prepared'))

ANNO_PATH = TRAIN_PATH

CONTEXT = 5
POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-']


def main():
    max_mnt_dist = count_max_mnt_dist()
    print('Max mention distance (positive pairs): %d' % max_mnt_dist)


def count_max_mnt_dist():
    global_max_mnt_dist = 0
    anno_files = natsorted(os.listdir(ANNO_PATH))
    for filename in anno_files:
        if filename.endswith('.mmax'):
            print('=======> ', filename)
            textname = filename.replace('.mmax', '')

            mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
            tree = etree.parse(mentions_path)
            mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})

            words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
            mentions_dict = markables_level_2_dict(mentions_path, words_path)

            file_max_mnt_dist = get_max_file_dist(mentions, mentions_dict)
            if file_max_mnt_dist > global_max_mnt_dist:
                global_max_mnt_dist = file_max_mnt_dist

    return global_max_mnt_dist


def get_max_file_dist(mentions, mentions_dict):
    max_file_dist = 0
    sets, all_mentions, clustered_mentions = get_sets(mentions)
    for set_id in sets:
        set_dist = get_max_set_dist(sets[set_id], mentions_dict)
        if set_dist > max_file_dist:
            max_file_dist = set_dist
    print('Max mention distance: %d' % max_file_dist)
    return max_file_dist


def get_sets(mentions):
    sets = {}
    all_mentions = []
    clustered_mentions = []
    for mention in mentions:
        all_mentions.append(mention.attrib['span'])
        set_id = mention.attrib['mention_group']
        if set_id == 'empty' or set_id == '':
            continue
        if set_id not in sets:
            sets[set_id] = []
        sets[set_id].append(mention.attrib['span'])
        clustered_mentions.append(mention.attrib['span'])

    # A cluster with a single mention is a singleton: drop the cluster and
    # take its mention back out of the clustered pool.
    sets_to_remove = []
    for set_id in sets:
        if len(sets[set_id]) == 1:
            print('Removing clustered mention: ', sets[set_id][0])
            clustered_mentions.remove(sets[set_id][0])
            sets_to_remove.append(set_id)

    for set_id in sets_to_remove:
        print('Removing set: ', set_id)
        sets.pop(set_id)

    return sets, all_mentions, clustered_mentions


def get_max_set_dist(mnt_set, mentions_dict):
    max_set_dist = 0
    for idx, mnt2_span in enumerate(mnt_set):
        mnt2 = get_mention_by_attr(mentions_dict, 'span', mnt2_span)
        dist = None
        dist1 = None
        if idx - 1 >= 0:
            mnt1 = get_mention_by_attr(mentions_dict, 'span', mnt_set[idx - 1])
            dist1 = get_pair_dist(mnt1, mnt2)
            dist = dist1
        if idx + 1 < len(mnt_set):
            mnt3 = get_mention_by_attr(mentions_dict, 'span', mnt_set[idx + 1])
            dist2 = get_pair_dist(mnt2, mnt3)
            # Keep the distance to the nearer neighbour; without the `is None`
            # check the first mention of a set would leave `dist` unset.
            if dist1 is None or dist2 < dist1:
                dist = dist2

        if dist is not None and dist > max_set_dist:
            max_set_dist = dist

    return max_set_dist


def get_pair_dist(ante, ana):
    dist = 0
    mnts_intersect = pair_intersect(ante, ana)
    if mnts_intersect != 1:
        dist = ana['position_in_mentions'] - ante['position_in_mentions']
    return dist


def pair_intersect(ante, ana):
    for ante_word in ante['words']:
        for ana_word in ana['words']:
            if ana_word['id'] == ante_word['id']:
                return 1
    return 0


def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/NameSpaces/mention'):
    markables_dicts = []
    markables_tree = etree.parse(markables_path)
    markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace})

    words = get_words(words_path)

    for idx, markable in enumerate(markables):
        span = markable.attrib['span']
        if not get_mention_by_attr(markables_dicts, 'span', span):
            dominant = ''
            if 'dominant' in markable.attrib:
                dominant = markable.attrib['dominant']

            head_orth = markable.attrib['mention_head']
            mention_words = span_to_words(span, words)

            (prec_context, follow_context, sentence, mnt_start_position, mnt_end_position,
             paragraph_id, sentence_id, first_in_sentence, first_in_paragraph) = get_context(mention_words, words)

            head = get_head(head_orth, mention_words)
            markables_dicts.append({'id': markable.attrib['id'],
                                    'set': markable.attrib['mention_group'],
                                    'text': span_to_text(span, words, 'orth'),
                                    'lemmatized_text': span_to_text(span, words, 'base'),
                                    'words': mention_words,
                                    'span': span,
                                    'head_orth': head_orth,
                                    'head': head,
                                    'dominant': dominant,
                                    'node': markable,
                                    'prec_context': prec_context,
                                    'follow_context': follow_context,
                                    'sentence': sentence,
                                    'position_in_mentions': idx,
                                    'start_in_words': mnt_start_position,
                                    'end_in_words': mnt_end_position,
                                    'paragraph_id': paragraph_id,
                                    'sentence_id': sentence_id,
                                    'first_in_sentence': first_in_sentence,
                                    'first_in_paragraph': first_in_paragraph})
        else:
            print('Duplicated mention: %s' % span)

    return markables_dicts


def get_context(mention_words, words):
    paragraph_id = 0
    sentence_id = 0
    prec_context = []
    follow_context = []
    sentence = []
    mnt_start_position = -1
    mnt_end_position = -1
    first_word = mention_words[0]
    last_word = mention_words[-1]
    first_in_sentence = False
    first_in_paragraph = False
    for idx, word in enumerate(words):
        if word['id'] == first_word['id']:
            prec_context = get_prec_context(idx, words)
            mnt_start_position = get_mention_start(first_word, words)
            if idx == 0 or words[idx - 1]['lastinsent']:
                first_in_sentence = True
            if idx == 0 or words[idx - 1]['lastinpar']:
                first_in_paragraph = True
        if word['id'] == last_word['id']:
            follow_context = get_follow_context(idx, words)
            sentence = get_sentence(idx, words)
            mnt_end_position = get_mention_end(last_word, words)
            break
        if word['lastinsent']:
            sentence_id += 1
        if word['lastinpar']:
            paragraph_id += 1
    return (prec_context, follow_context, sentence, mnt_start_position, mnt_end_position,
            paragraph_id, sentence_id, first_in_sentence, first_in_paragraph)


def get_prec_context(mention_start, words):
    context = []
    context_start = mention_start - 1
    while context_start >= 0:
        if not word_to_ignore(words[context_start]):
            context.append(words[context_start])
        if len(context) == CONTEXT:
            break
        context_start -= 1
    context.reverse()
    return context


def get_mention_start(first_word, words):
    start = 0
    for word in words:
        if not word_to_ignore(word):
            start += 1
        if word['id'] == first_word['id']:
            break
    return start


def get_mention_end(last_word, words):
    end = 0
    for word in words:
        if not word_to_ignore(word):
            end += 1
        if word['id'] == last_word['id']:
            break
    return end


def get_follow_context(mention_end, words):
    context = []
    context_end = mention_end + 1
    while context_end < len(words):
        if not word_to_ignore(words[context_end]):
            context.append(words[context_end])
        if len(context) == CONTEXT:
            break
        context_end += 1
    return context


def get_sentence(word_idx, words):
    sentence_start = get_sentence_start(words, word_idx)
    sentence_end = get_sentence_end(words, word_idx)
    sentence = [word for word in words[sentence_start:sentence_end + 1] if not word_to_ignore(word)]
    return sentence


def get_sentence_start(words, word_idx):
    search_start = word_idx
    while word_idx >= 0:
        if words[word_idx]['lastinsent'] and search_start != word_idx:
            return word_idx + 1
        word_idx -= 1
    return 0


def get_sentence_end(words, word_idx):
    while word_idx < len(words):
        if words[word_idx]['lastinsent']:
            return word_idx
        word_idx += 1
    return len(words) - 1


def get_head(head_orth, words):
    for word in words:
        if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
            return word
    return None


def get_words(filepath):
    tree = etree.parse(filepath)
    words = []
    for word in tree.xpath("//word"):
        hasnps = False
        if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
            hasnps = True
        lastinsent = False
        if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
            lastinsent = True
        lastinpar = False
        if 'lastinpar' in word.attrib and word.attrib['lastinpar'] == 'true':
            lastinpar = True
        words.append({'id': word.attrib['id'],
                      'orth': word.text,
                      'base': word.attrib['base'],
                      'hasnps': hasnps,
                      'lastinsent': lastinsent,
                      'lastinpar': lastinpar,
                      'ctag': word.attrib['ctag'],
                      'msd': word.attrib['msd'],
                      'gender': get_gender(word.attrib['msd']),
                      'person': get_person(word.attrib['msd']),
                      'number': get_number(word.attrib['msd'])})
    return words


def get_gender(msd):
    tags = msd.split(':')
    if 'm1' in tags:
        return 'm1'
    elif 'm2' in tags:
        return 'm2'
    elif 'm3' in tags:
        return 'm3'
    elif 'f' in tags:
        return 'f'
    elif 'n' in tags:
        return 'n'
    else:
        return 'unk'


def get_person(msd):
    tags = msd.split(':')
    if 'pri' in tags:
        return 'pri'
    elif 'sec' in tags:
        return 'sec'
    elif 'ter' in tags:
        return 'ter'
    else:
        return 'unk'


def get_number(msd):
    tags = msd.split(':')
    if 'sg' in tags:
        return 'sg'
    elif 'pl' in tags:
        return 'pl'
    else:
        return 'unk'


def get_mention_by_attr(mentions, attr_name, value):
    for mention in mentions:
        if mention[attr_name] == value:
            return mention
    return None


def get_mention_index_by_attr(mentions, attr_name, value):
    for idx, mention in enumerate(mentions):
        if mention[attr_name] == value:
            return idx
    return None


def span_to_text(span, words, form):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.append(fragment_to_text(fragment, words, form))
    return u' [...] '.join(mention_parts)


def fragment_to_text(fragment, words, form):
    if '..' in fragment:
        text = get_multiword_text(fragment, words, form)
    else:
        text = get_one_word_text(fragment, words, form)
    return text


def get_multiword_text(fragment, words, form):
    mention_parts = []
    boundaries = fragment.split('..')
    start_id = boundaries[0]
    end_id = boundaries[1]
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return to_text(mention_parts, form)


def to_text(words, form):
    text = ''
    for idx, word in enumerate(words):
        if word['hasnps'] or idx == 0:
            text += word[form]
        else:
            text += u' %s' % word[form]
    return text


def get_one_word_text(word_id, words, form):
    this_word = next(word for word in words if word['id'] == word_id)
    if word_to_ignore(this_word):
        # A single-word span should never point at an ignorable token.
        print(this_word)
    return this_word[form]


def span_to_words(span, words):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.extend(fragment_to_words(fragment, words))
    return mention_parts


def fragment_to_words(fragment, words):
    mention_parts = []
    if '..' in fragment:
        mention_parts.extend(get_multiword(fragment, words))
    else:
        mention_parts.extend(get_word(fragment, words))
    return mention_parts


def get_multiword(fragment, words):
    mention_parts = []
    boundaries = fragment.split('..')
    start_id = boundaries[0]
    end_id = boundaries[1]
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return mention_parts


def get_word(word_id, words):
    for word in words:
        if word['id'] == word_id:
            if not word_to_ignore(word):
                return [word]
            return []
    return []


def word_to_ignore(word):
    # Placeholder: no token is ignored for now, but the hook is threaded
    # through all span/context helpers so a filter can be added in one place.
    return False


if __name__ == '__main__':
    main()
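
Note that the per-set statistic computed above is the largest nearest-neighbour distance, not the widest consecutive gap: each mention contributes its distance to the closer of its two neighbours in the set. A standalone sketch of that logic (hypothetical helper, plain integers standing in for position_in_mentions):

def max_nearest_neighbour_dist(positions):
    # positions: sorted mention positions of one coreference set
    # (assumes at least two mentions, as get_sets guarantees).
    best = 0
    for idx, pos in enumerate(positions):
        neighbour_dists = []
        if idx > 0:
            neighbour_dists.append(pos - positions[idx - 1])
        if idx + 1 < len(positions):
            neighbour_dists.append(positions[idx + 1] - pos)
        best = max(best, min(neighbour_dists))
    return best

print(max_nearest_neighbour_dist([0, 5, 15, 16]))  # -> 5, although the widest gap is 10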
... ...
preparator.py
... ... @@ -29,7 +29,7 @@ TITLE2REDIRECT_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'wikipedia
29 29  
30 30 ANNO_PATH = TEST_PATH
31 31 OUT_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data',
32   -                                        'test-1to5-20170720.csv'))
  32 +                                        'test-1to5-singletons-20170720.csv'))
33 33 EACH_TEXT_SEPARATELLY = False
34 34  
35 35 CONTEXT = 5
... ... @@ -53,6 +53,7 @@ HYPHEN_SIGNS = ['-', '#']
53 53  
54 54 NEG_PROPORTION = 5
55 55 RANDOM_VECTORS = True
  56 +USE_SINGLETONS = True
56 57  
57 58 DEBUG = False
58 59 POS_COUNT = 0
... ... @@ -154,9 +155,9 @@ def create_data_vectors(model, freq_list, lemma2synonyms,
154 155  
155 156  
156 157 def diff_mentions(mentions):
157   -    sets, clustered_mensions = get_sets(mentions)
  158 +    sets, all_mentions, clustered_mensions = get_sets(mentions)
158 159     positives = get_positives(sets)
159   -    positives, negatives = get_negatives_and_update_positives(clustered_mensions, positives)
  160 +    positives, negatives = get_negatives_and_update_positives(all_mentions, clustered_mensions, positives)
160 161     if len(negatives) != len(positives) and NEG_PROPORTION == 1:
161 162         print (u'Mismatched number of positive and negative examples!')
162 163     return positives, negatives
... ... @@ -164,8 +165,10 @@ def diff_mentions(mentions):
164 165  
165 166 def get_sets(mentions):
166 167     sets = {}
  168 +    all_mentions = []
167 169     clustered_mensions = []
168 170     for mention in mentions:
  171 +        all_mentions.append(mention.attrib['span'])
169 172         set_id = mention.attrib['mention_group']
170 173         if set_id == 'empty' or set_id == '' or mention.attrib['mention_head'] in POSSIBLE_HEADS:
171 174             pass
... ... @@ -190,7 +193,7 @@ def get_sets(mentions):
190 193         print (u'Removing set: ', set_id)
191 194         sets.pop(set_id)
192 195  
193   -    return sets, clustered_mensions
  196 +    return sets, all_mentions, clustered_mensions
194 197  
195 198  
196 199 def get_positives(sets):
196 199 def get_positives(sets):
... ... @@ -201,8 +204,12 @@ def get_positives(sets):
201 204 return positives
202 205  
203 206  
204   -def get_negatives_and_update_positives(clustered_mensions, positives):
205   -    all_pairs = list(combinations(clustered_mensions, 2))
  207 +def get_negatives_and_update_positives(all_mentions, clustered_mentions, positives):
  208 +    all_pairs = []
  209 +    if USE_SINGLETONS:
  210 +        all_pairs = list(combinations(all_mentions, 2))
  211 +    else:
  212 +        all_pairs = list(combinations(clustered_mentions, 2))
206 213     all_pairs = set(all_pairs)
207 214     negatives = [pair for pair in all_pairs if pair not in positives]
208 215     samples_count = NEG_PROPORTION * len(positives)
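
The tail of this function falls outside the hunk; presumably the negatives are then downsampled to samples_count. A hedged sketch of that step (random.sample is a guess at the mechanism, not confirmed by this diff):

import random

def downsample_negatives(negatives, positives, neg_proportion=5):
    # Cap negatives at NEG_PROPORTION times the number of positives.
    samples_count = neg_proportion * len(positives)
    if len(negatives) > samples_count:
        return random.sample(negatives, samples_count)
    return negatives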
... ... @@ -474,7 +481,7 @@ def get_pair_features(pair, mentions_dict, lemma2synonyms,
474 481     words_dist = [0] * 11
475 482     words_bucket = 0
476 483     if mnts_intersect != 1:
477   -        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'] - 1)
  484 +        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'])
478 485     if DEBUG:
479 486         features.append('Bucket %d' % words_bucket)
480 487     words_dist[words_bucket] = 1
... ... @@ -483,7 +490,7 @@ def get_pair_features(pair, mentions_dict, lemma2synonyms,
483 490     mentions_dist = [0] * 11
484 491     mentions_bucket = 0
485 492     if mnts_intersect != 1:
486   -        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'] - 1)
  493 +        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'])
487 494     if words_bucket == 10:
488 495         mentions_bucket = 10
489 496     if DEBUG:
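
Both hunks drop a `- 1`, so the value fed into the bucketing now matches the raw gap used elsewhere (e.g. in get_pair_dist). get_distance_bucket itself is not shown in this diff; a plausible shape, consistent with the 11-slot one-hot vectors and the words_bucket == 10 catch-all above (the cut-offs are assumptions):

def get_distance_bucket(dist):
    # Assumed cut-offs; bucket 10 is the catch-all for very long distances,
    # which the `words_bucket == 10` check above relies on.
    limits = [1, 2, 3, 4, 5, 8, 16, 32, 64, 128]
    for bucket, limit in enumerate(limits):
        if dist < limit:
            return bucket
    return 10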
... ...