Commit 445955b9f02279ffd510804a8c6d92231529b004
Parent: 08bfdfc4
Added singletons to negative pool in preparator script.

Showing 2 changed files with 480 additions and 8 deletions.

count_dist.py (new file, mode 100644)
# -*- coding: utf-8 -*-

import os


from lxml import etree
from natsort import natsorted


MAIN_PATH = os.path.dirname(__file__)
TEST_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'test-prepared'))
TRAIN_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'train-prepared'))

ANNO_PATH = TRAIN_PATH

CONTEXT = 5
POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-']


def main():
    max_mnt_dist = count_max_mnt_dist()
    print('Max mention distance (positive pairs): %d' % max_mnt_dist)


def count_max_mnt_dist():
    global_max_mnt_dist = 0
    anno_files = os.listdir(ANNO_PATH)
    anno_files = natsorted(anno_files)
    for filename in anno_files:
        if filename.endswith('.mmax'):
            print('=======> ', filename)
            textname = filename.replace('.mmax', '')

            mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
            tree = etree.parse(mentions_path)
            mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})

            words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
            mentions_dict = markables_level_2_dict(mentions_path, words_path)

            file_max_mnt_dist = get_max_file_dist(mentions, mentions_dict)
            if file_max_mnt_dist > global_max_mnt_dist:
                global_max_mnt_dist = file_max_mnt_dist

    return global_max_mnt_dist

def get_max_file_dist(mentions, mentions_dict):
    max_file_dist = 0
    sets, all_mentions, clustered_mensions = get_sets(mentions)
    for set_id in sets:
        set_dist = get_max_set_dist(sets[set_id], mentions_dict)
        if set_dist > max_file_dist:
            max_file_dist = set_dist
    print('Max mention distance: %d' % max_file_dist)
    return max_file_dist

def get_sets(mentions):
    sets = {}
    all_mentions = []
    clustered_mensions = []
    for mention in mentions:
        all_mentions.append(mention.attrib['span'])
        set_id = mention.attrib['mention_group']
        if set_id == 'empty' or set_id == '':
            pass
        elif set_id not in sets:
            sets[set_id] = [mention.attrib['span']]
            clustered_mensions.append(mention.attrib['span'])
        elif set_id in sets:
            sets[set_id].append(mention.attrib['span'])
            clustered_mensions.append(mention.attrib['span'])
        else:
            print(u'Something went wrong while searching for clusters!')

    sets_to_remove = []
    for set_id in sets:
        if len(sets[set_id]) < 2:
            sets_to_remove.append(set_id)
        if len(sets[set_id]) == 1:
            print(u'Removing clustered mention: ', sets[set_id][0])
            clustered_mensions.remove(sets[set_id][0])

    for set_id in sets_to_remove:
        print(u'Removing set: ', set_id)
        sets.pop(set_id)

    return sets, all_mentions, clustered_mensions

def get_max_set_dist(mnt_set, mentions_dict):
    # For each mention, take the distance to its nearest neighbour in the
    # cluster (previous or next) and return the maximum over the set.
    max_set_dist = 0
    for id, mnt2_span in enumerate(mnt_set):
        mnt2 = get_mention_by_attr(mentions_dict, 'span', mnt2_span)
        dist = None
        dist1 = None
        if id - 1 >= 0:
            mnt1_span = mnt_set[id - 1]
            mnt1 = get_mention_by_attr(mentions_dict, 'span', mnt1_span)
            dist1 = get_pair_dist(mnt1, mnt2)
            dist = dist1
        if id + 1 < len(mnt_set):
            mnt3_span = mnt_set[id + 1]
            mnt3 = get_mention_by_attr(mentions_dict, 'span', mnt3_span)
            dist2 = get_pair_dist(mnt2, mnt3)
            if dist1 is None or dist2 < dist1:
                dist = dist2

        if dist is not None and dist > max_set_dist:
            max_set_dist = dist

    return max_set_dist


def get_pair_dist(ante, ana):
    dist = 0
    mnts_intersect = pair_intersect(ante, ana)
    if mnts_intersect != 1:
        dist = ana['position_in_mentions'] - ante['position_in_mentions']
    return dist


def pair_intersect(ante, ana):
    for ante_word in ante['words']:
        for ana_word in ana['words']:
            if ana_word['id'] == ante_word['id']:
                return 1
    return 0

def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/NameSpaces/mention'):
    markables_dicts = []
    markables_tree = etree.parse(markables_path)
    markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace})

    words = get_words(words_path)

    for idx, markable in enumerate(markables):
        span = markable.attrib['span']
        if not get_mention_by_attr(markables_dicts, 'span', span):

            dominant = ''
            if 'dominant' in markable.attrib:
                dominant = markable.attrib['dominant']

            head_orth = markable.attrib['mention_head']
            mention_words = span_to_words(span, words)

            (prec_context, follow_context, sentence, mnt_start_position, mnt_end_position,
             paragraph_id, sentence_id, first_in_sentence, first_in_paragraph) = get_context(mention_words, words)

            head = get_head(head_orth, mention_words)
            markables_dicts.append({'id': markable.attrib['id'],
                                    'set': markable.attrib['mention_group'],
                                    'text': span_to_text(span, words, 'orth'),
                                    'lemmatized_text': span_to_text(span, words, 'base'),
                                    'words': mention_words,
                                    'span': span,
                                    'head_orth': head_orth,
                                    'head': head,
                                    'dominant': dominant,
                                    'node': markable,
                                    'prec_context': prec_context,
                                    'follow_context': follow_context,
                                    'sentence': sentence,
                                    'position_in_mentions': idx,
                                    'start_in_words': mnt_start_position,
                                    'end_in_words': mnt_end_position,
                                    'paragraph_id': paragraph_id,
                                    'sentence_id': sentence_id,
                                    'first_in_sentence': first_in_sentence,
                                    'first_in_paragraph': first_in_paragraph})
        else:
            print('Duplicate mention: %s' % span)

    return markables_dicts


def get_context(mention_words, words):
    paragraph_id = 0
    sentence_id = 0
    prec_context = []
    follow_context = []
    sentence = []
    mnt_start_position = -1
    first_word = mention_words[0]
    last_word = mention_words[-1]
    first_in_sentence = False
    first_in_paragraph = False
    for idx, word in enumerate(words):
        if word['id'] == first_word['id']:
            prec_context = get_prec_context(idx, words)
            mnt_start_position = get_mention_start(first_word, words)
            if idx == 0 or words[idx-1]['lastinsent']:
                first_in_sentence = True
            if idx == 0 or words[idx-1]['lastinpar']:
                first_in_paragraph = True
        if word['id'] == last_word['id']:
            follow_context = get_follow_context(idx, words)
            sentence = get_sentence(idx, words)
            mnt_end_position = get_mention_end(last_word, words)
            break
        if word['lastinsent']:
            sentence_id += 1
        if word['lastinpar']:
            paragraph_id += 1
    return (prec_context, follow_context, sentence, mnt_start_position, mnt_end_position,
            paragraph_id, sentence_id, first_in_sentence, first_in_paragraph)


def get_prec_context(mention_start, words):
    context = []
    context_start = mention_start - 1
    while context_start >= 0:
        if not word_to_ignore(words[context_start]):
            context.append(words[context_start])
        if len(context) == CONTEXT:
            break
        context_start -= 1
    context.reverse()
    return context


def get_mention_start(first_word, words):
    start = 0
    for word in words:
        if not word_to_ignore(word):
            start += 1
        if word['id'] == first_word['id']:
            break
    return start


def get_mention_end(last_word, words):
    end = 0
    for word in words:
        if not word_to_ignore(word):
            end += 1
        if word['id'] == last_word['id']:
            break
    return end


def get_follow_context(mention_end, words):
    context = []
    context_end = mention_end + 1
    while context_end < len(words):
        if not word_to_ignore(words[context_end]):
            context.append(words[context_end])
        if len(context) == CONTEXT:
            break
        context_end += 1
    return context


def get_sentence(word_idx, words):
    sentence_start = get_sentence_start(words, word_idx)
    sentence_end = get_sentence_end(words, word_idx)
    sentence = [word for word in words[sentence_start:sentence_end+1] if not word_to_ignore(word)]
    return sentence


def get_sentence_start(words, word_idx):
    search_start = word_idx
    while word_idx >= 0:
        if words[word_idx]['lastinsent'] and search_start != word_idx:
            return word_idx+1
        word_idx -= 1
    return 0


def get_sentence_end(words, word_idx):
    while word_idx < len(words):
        if words[word_idx]['lastinsent']:
            return word_idx
        word_idx += 1
    return len(words) - 1


def get_head(head_orth, words):
    for word in words:
        if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
            return word
    return None


def get_words(filepath):
    tree = etree.parse(filepath)
    words = []
    for word in tree.xpath("//word"):
        hasnps = False
        if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
            hasnps = True
        lastinsent = False
        if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
            lastinsent = True
        lastinpar = False
        if 'lastinpar' in word.attrib and word.attrib['lastinpar'] == 'true':
            lastinpar = True
        words.append({'id': word.attrib['id'],
                      'orth': word.text,
                      'base': word.attrib['base'],
                      'hasnps': hasnps,
                      'lastinsent': lastinsent,
                      'lastinpar': lastinpar,
                      'ctag': word.attrib['ctag'],
                      'msd': word.attrib['msd'],
                      'gender': get_gender(word.attrib['msd']),
                      'person': get_person(word.attrib['msd']),
                      'number': get_number(word.attrib['msd'])})
    return words


def get_gender(msd):
    tags = msd.split(':')
    if 'm1' in tags:
        return 'm1'
    elif 'm2' in tags:
        return 'm2'
    elif 'm3' in tags:
        return 'm3'
    elif 'f' in tags:
        return 'f'
    elif 'n' in tags:
        return 'n'
    else:
        return 'unk'


def get_person(msd):
    tags = msd.split(':')
    if 'pri' in tags:
        return 'pri'
    elif 'sec' in tags:
        return 'sec'
    elif 'ter' in tags:
        return 'ter'
    else:
        return 'unk'


def get_number(msd):
    tags = msd.split(':')
    if 'sg' in tags:
        return 'sg'
    elif 'pl' in tags:
        return 'pl'
    else:
        return 'unk'


def get_mention_by_attr(mentions, attr_name, value):
    for mention in mentions:
        if mention[attr_name] == value:
            return mention
    return None


def get_mention_index_by_attr(mentions, attr_name, value):
    for idx, mention in enumerate(mentions):
        if mention[attr_name] == value:
            return idx
    return None


def span_to_text(span, words, form):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.append(fragment_to_text(fragment, words, form))
    return u' [...] '.join(mention_parts)


def fragment_to_text(fragment, words, form):
    if '..' in fragment:
        text = get_multiword_text(fragment, words, form)
    else:
        text = get_one_word_text(fragment, words, form)
    return text


def get_multiword_text(fragment, words, form):
    mention_parts = []
    boundaries = fragment.split('..')
    start_id = boundaries[0]
    end_id = boundaries[1]
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return to_text(mention_parts, form)


def to_text(words, form):
    text = ''
    for idx, word in enumerate(words):
        if word['hasnps'] or idx == 0:
            text += word[form]
        else:
            text += u' %s' % word[form]
    return text


def get_one_word_text(word_id, words, form):
    this_word = next(word for word in words if word['id'] == word_id)
    if word_to_ignore(this_word):
        print(this_word)
    return this_word[form]


def span_to_words(span, words):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.extend(fragment_to_words(fragment, words))
    return mention_parts


def fragment_to_words(fragment, words):
    mention_parts = []
    if '..' in fragment:
        mention_parts.extend(get_multiword(fragment, words))
    else:
        mention_parts.extend(get_word(fragment, words))
    return mention_parts


def get_multiword(fragment, words):
    mention_parts = []
    boundaries = fragment.split('..')
    start_id = boundaries[0]
    end_id = boundaries[1]
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return mention_parts


def get_word(word_id, words):
    for word in words:
        if word['id'] == word_id:
            if not word_to_ignore(word):
                return [word]
            else:
                return []
    return []

def word_to_ignore(word):
    # Placeholder filter: this script does not skip any tokens.
    return False


if __name__ == '__main__':
    main()
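
For reference, a minimal sketch of the span format this script parses: continuous fragments use '..' between word ids, discontinuous fragments are comma-separated, and span_to_text joins fragments with ' [...] '. The word ids and attribute values below are hypothetical, made up for illustration; real ones come from the *_words.xml files.

    # Hypothetical token list; only the keys read by span_to_words/span_to_text matter here.
    words = [
        {'id': 'word_%d' % i, 'orth': orth, 'base': orth.lower(), 'hasnps': False,
         'lastinsent': False, 'lastinpar': False, 'ctag': 'subst', 'msd': '',
         'gender': 'unk', 'person': 'unk', 'number': 'unk'}
        for i, orth in enumerate(['Anna', 'z', 'Krakowa', 'lubi', 'koty'], start=1)
    ]

    print(span_to_text('word_1..word_3', words, 'orth'))  # Anna z Krakowa
    print(span_to_text('word_1,word_4', words, 'orth'))   # Anna [...] lubi
    print([w['id'] for w in span_to_words('word_1..word_3', words)])
    # ['word_1', 'word_2', 'word_3']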
preparator.py

@@ -29,7 +29,7 @@ TITLE2REDIRECT_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'wikipedia
 
 ANNO_PATH = TEST_PATH
 OUT_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data',
-                                        'test-1to5-20170720.csv'))
+                                        'test-1to5-singletons-20170720.csv'))
 EACH_TEXT_SEPARATELLY = False
 
 CONTEXT = 5
@@ -53,6 +53,7 @@ HYPHEN_SIGNS = ['-', '#']
 
 NEG_PROPORTION = 5
 RANDOM_VECTORS = True
+USE_SINGLETONS = True
 
 DEBUG = False
 POS_COUNT = 0
@@ -154,9 +155,9 @@ def create_data_vectors(model, freq_list, lemma2synonyms,
 
 
 def diff_mentions(mentions):
-    sets, clustered_mensions = get_sets(mentions)
+    sets, all_mentions, clustered_mensions = get_sets(mentions)
     positives = get_positives(sets)
-    positives, negatives = get_negatives_and_update_positives(clustered_mensions, positives)
+    positives, negatives = get_negatives_and_update_positives(all_mentions, clustered_mensions, positives)
     if len(negatives) != len(positives) and NEG_PROPORTION == 1:
         print(u'Mismatched number of positive and negative examples!')
     return positives, negatives
@@ -164,8 +165,10 @@ def diff_mentions(mentions):
 
 
 def get_sets(mentions):
     sets = {}
+    all_mentions = []
     clustered_mensions = []
     for mention in mentions:
+        all_mentions.append(mention.attrib['span'])
         set_id = mention.attrib['mention_group']
         if set_id == 'empty' or set_id == '' or mention.attrib['mention_head'] in POSSIBLE_HEADS:
             pass
@@ -190,7 +193,7 @@ def get_sets(mentions):
         print(u'Removing set: ', set_id)
         sets.pop(set_id)
 
-    return sets, clustered_mensions
+    return sets, all_mentions, clustered_mensions
 
 
 def get_positives(sets):
@@ -201,8 +204,12 @@ def get_positives(sets):
     return positives
 
 
-def get_negatives_and_update_positives(clustered_mensions, positives):
-    all_pairs = list(combinations(clustered_mensions, 2))
+def get_negatives_and_update_positives(all_mentions, clustered_mentions, positives):
+    all_pairs = []
+    if USE_SINGLETONS:
+        all_pairs = list(combinations(all_mentions, 2))
+    else:
+        all_pairs = list(combinations(clustered_mentions, 2))
     all_pairs = set(all_pairs)
     negatives = [pair for pair in all_pairs if pair not in positives]
     samples_count = NEG_PROPORTION * len(positives)
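
The effect of USE_SINGLETONS in a minimal, self-contained sketch (the spans and clusters below are hypothetical): with the flag on, candidate pairs are drawn from all mentions, so pairs involving unclustered singleton mentions can now enter the negative pool.

    from itertools import combinations

    all_mentions = ['A', 'B', 'C', 'D']   # C and D are singletons
    clustered_mentions = ['A', 'B']       # only A and B corefer
    positives = {('A', 'B')}

    with_singletons = set(combinations(all_mentions, 2)) - positives
    without_singletons = set(combinations(clustered_mentions, 2)) - positives

    print(sorted(with_singletons))
    # [('A', 'C'), ('A', 'D'), ('B', 'C'), ('B', 'D'), ('C', 'D')]
    print(sorted(without_singletons))  # []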
@@ -474,7 +481,7 @@ def get_pair_features(pair, mentions_dict, lemma2synonyms,
     words_dist = [0] * 11
     words_bucket = 0
     if mnts_intersect != 1:
-        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'] - 1)
+        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'])
     if DEBUG:
         features.append('Bucket %d' % words_bucket)
     words_dist[words_bucket] = 1
@@ -483,7 +490,7 @@ def get_pair_features(pair, mentions_dict, lemma2synonyms,
     mentions_dist = [0] * 11
     mentions_bucket = 0
     if mnts_intersect != 1:
-        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'] - 1)
+        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'])
     if words_bucket == 10:
         mentions_bucket = 10
     if DEBUG:
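
The last two hunks drop the '- 1' from both distance features. One reading: since get_mention_start and get_mention_end (see count_dist.py above) both count words inclusively, two adjacent, non-overlapping mentions differ by exactly 1, and the old formula fed 0 into get_distance_bucket — the same value that intersecting pairs keep by default. A worked sketch of the arithmetic with hypothetical positions:

    ante = {'start_in_words': 3, 'end_in_words': 4}  # hypothetical mention over words 3-4
    ana = {'start_in_words': 5, 'end_in_words': 5}   # hypothetical mention on the next word

    print(ana['start_in_words'] - ante['end_in_words'] - 1)  # old feature value: 0
    print(ana['start_in_words'] - ante['end_in_words'])      # new feature value: 1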