Commit 445955b9f02279ffd510804a8c6d92231529b004
1 parent 08bfdfc4
Added singletons to negative pool in preparator script.
Showing 2 changed files with 480 additions and 8 deletions.
count_dist.py (new file, mode 100644)
# -*- coding: utf-8 -*-

import os

from lxml import etree
from natsort import natsorted


MAIN_PATH = os.path.dirname(__file__)
TEST_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'test-prepared'))
TRAIN_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'train-prepared'))

ANNO_PATH = TRAIN_PATH

CONTEXT = 5
POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-']

def main():
    max_mnt_dist = count_max_mnt_dist()
    print('Max mention distance (positive pairs): %d' % max_mnt_dist)


def count_max_mnt_dist():
    global_max_mnt_dist = 0
    anno_files = os.listdir(ANNO_PATH)
    anno_files = natsorted(anno_files)
    for filename in anno_files:
        if filename.endswith('.mmax'):
            print('=======> %s' % filename)
            textname = filename.replace('.mmax', '')

            mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
            tree = etree.parse(mentions_path)
            mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})

            words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
            mentions_dict = markables_level_2_dict(mentions_path, words_path)

            file_max_mnt_dist = get_max_file_dist(mentions, mentions_dict)
            if file_max_mnt_dist > global_max_mnt_dist:
                global_max_mnt_dist = file_max_mnt_dist

    return global_max_mnt_dist

def get_max_file_dist(mentions, mentions_dict):
    max_file_dist = 0
    sets, all_mentions, clustered_mensions = get_sets(mentions)
    for set_id in sets:
        set_dist = get_max_set_dist(sets[set_id], mentions_dict)
        if set_dist > max_file_dist:
            max_file_dist = set_dist
    print('Max mention distance: %d' % max_file_dist)
    return max_file_dist


def get_sets(mentions):
    sets = {}
    all_mentions = []
    clustered_mensions = []
    for mention in mentions:
        all_mentions.append(mention.attrib['span'])
        set_id = mention.attrib['mention_group']
        if set_id == 'empty' or set_id == '':
            pass
        elif set_id not in sets:
            sets[set_id] = [mention.attrib['span']]
            clustered_mensions.append(mention.attrib['span'])
        elif set_id in sets:
            sets[set_id].append(mention.attrib['span'])
            clustered_mensions.append(mention.attrib['span'])
        else:
            print(u'Something went wrong while searching for clusters!')

    sets_to_remove = []
    for set_id in sets:
        if len(sets[set_id]) < 2:
            sets_to_remove.append(set_id)
        if len(sets[set_id]) == 1:
            print(u'Removing clustered mention: %s' % sets[set_id][0])
            clustered_mensions.remove(sets[set_id][0])

    for set_id in sets_to_remove:
        print(u'Removing set: %s' % set_id)
        sets.pop(set_id)

    return sets, all_mentions, clustered_mensions

def get_max_set_dist(mnt_set, mentions_dict):
    # For each mention in the cluster, take the distance to its nearest
    # neighbour in the chain; the cluster's score is the largest such gap.
    max_set_dist = 0
    for idx, mnt2_span in enumerate(mnt_set):
        mnt2 = get_mention_by_attr(mentions_dict, 'span', mnt2_span)
        dist = None
        dist1 = None
        if idx - 1 >= 0:
            mnt1_span = mnt_set[idx - 1]
            mnt1 = get_mention_by_attr(mentions_dict, 'span', mnt1_span)
            dist1 = get_pair_dist(mnt1, mnt2)
            dist = dist1
        if idx + 1 < len(mnt_set):
            mnt3_span = mnt_set[idx + 1]
            mnt3 = get_mention_by_attr(mentions_dict, 'span', mnt3_span)
            dist2 = get_pair_dist(mnt2, mnt3)
            if dist1 is None or dist2 < dist1:
                dist = dist2

        if dist is not None and dist > max_set_dist:
            max_set_dist = dist

    return max_set_dist

def get_pair_dist(ante, ana):
    dist = 0
    mnts_intersect = pair_intersect(ante, ana)
    if mnts_intersect != 1:
        dist = ana['position_in_mentions'] - ante['position_in_mentions']
    return dist


def pair_intersect(ante, ana):
    for ante_word in ante['words']:
        for ana_word in ana['words']:
            if ana_word['id'] == ante_word['id']:
                return 1
    return 0

def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/NameSpaces/mention'):
    markables_dicts = []
    markables_tree = etree.parse(markables_path)
    markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace})

    words = get_words(words_path)

    for idx, markable in enumerate(markables):
        span = markable.attrib['span']
        if not get_mention_by_attr(markables_dicts, 'span', span):

            dominant = ''
            if 'dominant' in markable.attrib:
                dominant = markable.attrib['dominant']

            head_orth = markable.attrib['mention_head']
            mention_words = span_to_words(span, words)

            (prec_context, follow_context, sentence, mnt_start_position, mnt_end_position,
             paragraph_id, sentence_id, first_in_sentence, first_in_paragraph) = get_context(mention_words, words)

            head = get_head(head_orth, mention_words)
            markables_dicts.append({'id': markable.attrib['id'],
                                    'set': markable.attrib['mention_group'],
                                    'text': span_to_text(span, words, 'orth'),
                                    'lemmatized_text': span_to_text(span, words, 'base'),
                                    'words': mention_words,
                                    'span': span,
                                    'head_orth': head_orth,
                                    'head': head,
                                    'dominant': dominant,
                                    'node': markable,
                                    'prec_context': prec_context,
                                    'follow_context': follow_context,
                                    'sentence': sentence,
                                    'position_in_mentions': idx,
                                    'start_in_words': mnt_start_position,
                                    'end_in_words': mnt_end_position,
                                    'paragraph_id': paragraph_id,
                                    'sentence_id': sentence_id,
                                    'first_in_sentence': first_in_sentence,
                                    'first_in_paragraph': first_in_paragraph})
        else:
            print('Duplicate mention: %s' % span)

    return markables_dicts

def get_context(mention_words, words):
    paragraph_id = 0
    sentence_id = 0
    prec_context = []
    follow_context = []
    sentence = []
    mnt_start_position = -1
    mnt_end_position = -1
    first_word = mention_words[0]
    last_word = mention_words[-1]
    first_in_sentence = False
    first_in_paragraph = False
    for idx, word in enumerate(words):
        if word['id'] == first_word['id']:
            prec_context = get_prec_context(idx, words)
            mnt_start_position = get_mention_start(first_word, words)
            if idx == 0 or words[idx-1]['lastinsent']:
                first_in_sentence = True
            if idx == 0 or words[idx-1]['lastinpar']:
                first_in_paragraph = True
        if word['id'] == last_word['id']:
            follow_context = get_follow_context(idx, words)
            sentence = get_sentence(idx, words)
            mnt_end_position = get_mention_end(last_word, words)
            break
        if word['lastinsent']:
            sentence_id += 1
        if word['lastinpar']:
            paragraph_id += 1
    return (prec_context, follow_context, sentence, mnt_start_position, mnt_end_position,
            paragraph_id, sentence_id, first_in_sentence, first_in_paragraph)

def get_prec_context(mention_start, words):
    context = []
    context_start = mention_start - 1
    while context_start >= 0:
        if not word_to_ignore(words[context_start]):
            context.append(words[context_start])
        if len(context) == CONTEXT:
            break
        context_start -= 1
    context.reverse()
    return context


def get_mention_start(first_word, words):
    start = 0
    for word in words:
        if not word_to_ignore(word):
            start += 1
        if word['id'] == first_word['id']:
            break
    return start


def get_mention_end(last_word, words):
    end = 0
    for word in words:
        if not word_to_ignore(word):
            end += 1
        if word['id'] == last_word['id']:
            break
    return end


def get_follow_context(mention_end, words):
    context = []
    context_end = mention_end + 1
    while context_end < len(words):
        if not word_to_ignore(words[context_end]):
            context.append(words[context_end])
        if len(context) == CONTEXT:
            break
        context_end += 1
    return context


def get_sentence(word_idx, words):
    sentence_start = get_sentence_start(words, word_idx)
    sentence_end = get_sentence_end(words, word_idx)
    sentence = [word for word in words[sentence_start:sentence_end+1]
                if not word_to_ignore(word)]
    return sentence


def get_sentence_start(words, word_idx):
    search_start = word_idx
    while word_idx >= 0:
        if words[word_idx]['lastinsent'] and search_start != word_idx:
            return word_idx + 1
        word_idx -= 1
    return 0


def get_sentence_end(words, word_idx):
    while word_idx < len(words):
        if words[word_idx]['lastinsent']:
            return word_idx
        word_idx += 1
    return len(words) - 1


def get_head(head_orth, words):
    for word in words:
        if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
            return word
    return None

def get_words(filepath):
    tree = etree.parse(filepath)
    words = []
    for word in tree.xpath("//word"):
        hasnps = False
        if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
            hasnps = True
        lastinsent = False
        if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
            lastinsent = True
        lastinpar = False
        if 'lastinpar' in word.attrib and word.attrib['lastinpar'] == 'true':
            lastinpar = True
        words.append({'id': word.attrib['id'],
                      'orth': word.text,
                      'base': word.attrib['base'],
                      'hasnps': hasnps,
                      'lastinsent': lastinsent,
                      'lastinpar': lastinpar,
                      'ctag': word.attrib['ctag'],
                      'msd': word.attrib['msd'],
                      'gender': get_gender(word.attrib['msd']),
                      'person': get_person(word.attrib['msd']),
                      'number': get_number(word.attrib['msd'])})
    return words

def get_gender(msd):
    tags = msd.split(':')
    if 'm1' in tags:
        return 'm1'
    elif 'm2' in tags:
        return 'm2'
    elif 'm3' in tags:
        return 'm3'
    elif 'f' in tags:
        return 'f'
    elif 'n' in tags:
        return 'n'
    else:
        return 'unk'


def get_person(msd):
    tags = msd.split(':')
    if 'pri' in tags:
        return 'pri'
    elif 'sec' in tags:
        return 'sec'
    elif 'ter' in tags:
        return 'ter'
    else:
        return 'unk'


def get_number(msd):
    tags = msd.split(':')
    if 'sg' in tags:
        return 'sg'
    elif 'pl' in tags:
        return 'pl'
    else:
        return 'unk'

def get_mention_by_attr(mentions, attr_name, value):
    for mention in mentions:
        if mention[attr_name] == value:
            return mention
    return None


def get_mention_index_by_attr(mentions, attr_name, value):
    for idx, mention in enumerate(mentions):
        if mention[attr_name] == value:
            return idx
    return None

def span_to_text(span, words, form):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.append(fragment_to_text(fragment, words, form))
    return u' [...] '.join(mention_parts)


def fragment_to_text(fragment, words, form):
    if '..' in fragment:
        text = get_multiword_text(fragment, words, form)
    else:
        text = get_one_word_text(fragment, words, form)
    return text


def get_multiword_text(fragment, words, form):
    mention_parts = []
    boundaries = fragment.split('..')
    start_id = boundaries[0]
    end_id = boundaries[1]
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return to_text(mention_parts, form)


def to_text(words, form):
    text = ''
    for idx, word in enumerate(words):
        if word['hasnps'] or idx == 0:
            text += word[form]
        else:
            text += u' %s' % word[form]
    return text


def get_one_word_text(word_id, words, form):
    this_word = next(word for word in words if word['id'] == word_id)
    if word_to_ignore(this_word):
        print(this_word)
    return this_word[form]

def span_to_words(span, words):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.extend(fragment_to_words(fragment, words))
    return mention_parts


def fragment_to_words(fragment, words):
    mention_parts = []
    if '..' in fragment:
        mention_parts.extend(get_multiword(fragment, words))
    else:
        mention_parts.extend(get_word(fragment, words))
    return mention_parts


def get_multiword(fragment, words):
    mention_parts = []
    boundaries = fragment.split('..')
    start_id = boundaries[0]
    end_id = boundaries[1]
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return mention_parts


def get_word(word_id, words):
    for word in words:
        if word['id'] == word_id:
            if not word_to_ignore(word):
                return [word]
            else:
                return []
    return []


def word_to_ignore(word):
    # Placeholder hook: no tokens are filtered out for this corpus yet.
    return False


if __name__ == '__main__':
    main()
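For reference, the span attribute that span_to_words and span_to_text parse above is a comma-separated list of fragments, each either a single word id or a start..end range of word ids. A tiny runnable illustration with made-up ids (not taken from any corpus file):

# Hypothetical MMAX-style span values, for illustration only:
#   'word_3'                 -> the single token word_3
#   'word_3..word_5'         -> the contiguous tokens word_3, word_4, word_5
#   'word_3..word_5,word_9'  -> a discontinuous mention; span_to_text joins
#                               its parts with u' [...] '
span = 'word_3..word_5,word_9'
fragments = span.split(',')                   # ['word_3..word_5', 'word_9']
start_id, end_id = fragments[0].split('..')   # 'word_3', 'word_5'
print(fragments, start_id, end_id)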
preparator.py
@@ -29,7 +29,7 @@ TITLE2REDIRECT_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'wikipedia
 
 ANNO_PATH = TEST_PATH
 OUT_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data',
-                                        'test-1to5-20170720.csv'))
+                                        'test-1to5-singletons-20170720.csv'))
 EACH_TEXT_SEPARATELLY = False
 
 CONTEXT = 5
@@ -53,6 +53,7 @@ HYPHEN_SIGNS = ['-', '#']
 
 NEG_PROPORTION = 5
 RANDOM_VECTORS = True
+USE_SINGLETONS = True
 
 DEBUG = False
 POS_COUNT = 0
@@ -154,9 +155,9 @@ def create_data_vectors(model, freq_list, lemma2synonyms,
 
 
 def diff_mentions(mentions):
-    sets, clustered_mensions = get_sets(mentions)
+    sets, all_mentions, clustered_mensions = get_sets(mentions)
     positives = get_positives(sets)
-    positives, negatives = get_negatives_and_update_positives(clustered_mensions, positives)
+    positives, negatives = get_negatives_and_update_positives(all_mentions, clustered_mensions, positives)
     if len(negatives) != len(positives) and NEG_PROPORTION == 1:
         print (u'Niezgodna liczba przypadków pozytywnych i negatywnych!')
     return positives, negatives
@@ -164,8 +165,10 @@ def diff_mentions(mentions):
 
 def get_sets(mentions):
     sets = {}
+    all_mentions = []
     clustered_mensions = []
     for mention in mentions:
+        all_mentions.append(mention.attrib['span'])
         set_id = mention.attrib['mention_group']
         if set_id == 'empty' or set_id == '' or mention.attrib['mention_head'] in POSSIBLE_HEADS:
             pass
@@ -190,7 +193,7 @@ def get_sets(mentions):
         print (u'Removing set: ', set_id)
         sets.pop(set_id)
 
-    return sets, clustered_mensions
+    return sets, all_mentions, clustered_mensions
 
 
 def get_positives(sets):
@@ -201,8 +204,12 @@ def get_positives(sets):
     return positives
 
 
-def get_negatives_and_update_positives(clustered_mensions, positives):
-    all_pairs = list(combinations(clustered_mensions, 2))
+def get_negatives_and_update_positives(all_mentions, clustered_mentions, positives):
+    all_pairs = []
+    if USE_SINGLETONS:
+        all_pairs = list(combinations(all_mentions, 2))
+    else:
+        all_pairs = list(combinations(clustered_mentions, 2))
     all_pairs = set(all_pairs)
     negatives = [pair for pair in all_pairs if pair not in positives]
     samples_count = NEG_PROPORTION * len(positives)
@@ -474,7 +481,7 @@ def get_pair_features(pair, mentions_dict, lemma2synonyms,
     words_dist = [0] * 11
     words_bucket = 0
     if mnts_intersect != 1:
-        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'] - 1)
+        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'])
     if DEBUG:
         features.append('Bucket %d' % words_bucket)
     words_dist[words_bucket] = 1
@@ -483,7 +490,7 @@ def get_pair_features(pair, mentions_dict, lemma2synonyms,
     mentions_dist = [0] * 11
     mentions_bucket = 0
    if mnts_intersect != 1:
-        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'] - 1)
+        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'])
     if words_bucket == 10:
         mentions_bucket = 10
     if DEBUG:
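Taken together, the preparator.py hunks (a) draw negative candidate pairs from all mentions, singletons included, rather than only from clustered mentions, and (b) stop subtracting 1 in the two distance-bucket computations. A self-contained sketch of the new negative-pool logic; the mention spans below are invented, and the final slice merely stands in for the sampling step that follows in the script (not shown in these hunks):

from itertools import combinations

USE_SINGLETONS = True
NEG_PROPORTION = 5

def sample_negatives(all_mentions, clustered_mentions, positives):
    # Candidate pool: every unordered mention pair; with USE_SINGLETONS on,
    # the pool also covers mentions that belong to no coreference set.
    pool = all_mentions if USE_SINGLETONS else clustered_mentions
    all_pairs = set(combinations(pool, 2))
    negatives = sorted(pair for pair in all_pairs if pair not in positives)
    return negatives[:NEG_PROPORTION * len(positives)]

# 'word_5' is a singleton here, so the two pairs containing it become
# negative candidates only because USE_SINGLETONS is enabled.
positives = [('word_1', 'word_3')]
print(sample_negatives(['word_1', 'word_3', 'word_5'],
                       ['word_1', 'word_3'], positives))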