Commit 445955b9f02279ffd510804a8c6d92231529b004
1 parent 08bfdfc4

Added singletons to negative pool in preparator script.

Showing 2 changed files with 480 additions and 8 deletions.

count_dist.py (new file, mode 100644)
# -*- coding: utf-8 -*-

import os

from lxml import etree
from natsort import natsorted


MAIN_PATH = os.path.dirname(__file__)
TEST_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'test-prepared'))
TRAIN_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'train-prepared'))

ANNO_PATH = TRAIN_PATH

CONTEXT = 5
# Punctuation-like head forms; preparator.py excludes mentions with these
# heads, but this script does not use the list yet.
POSSIBLE_HEADS = [u'§', u'%', u'*', u'"', u'„', u'&', u'-']


def main():
    max_mnt_dist = count_max_mnt_dist()
    print('Max mention distance (positive pairs): %d' % max_mnt_dist)


def count_max_mnt_dist():
    global_max_mnt_dist = 0
    anno_files = natsorted(os.listdir(ANNO_PATH))
    for filename in anno_files:
        if filename.endswith('.mmax'):
            print('=======> ', filename)
            textname = filename.replace('.mmax', '')

            mentions_path = os.path.join(ANNO_PATH, '%s_mentions.xml' % textname)
            tree = etree.parse(mentions_path)
            mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})

            words_path = os.path.join(ANNO_PATH, '%s_words.xml' % textname)
            mentions_dict = markables_level_2_dict(mentions_path, words_path)

            file_max_mnt_dist = get_max_file_dist(mentions, mentions_dict)
            if file_max_mnt_dist > global_max_mnt_dist:
                global_max_mnt_dist = file_max_mnt_dist

    return global_max_mnt_dist


def get_max_file_dist(mentions, mentions_dict):
    max_file_dist = 0
    sets, all_mentions, clustered_mentions = get_sets(mentions)
    for set_id in sets:
        set_dist = get_max_set_dist(sets[set_id], mentions_dict)
        if set_dist > max_file_dist:
            max_file_dist = set_dist
    print('Max mention distance: %d' % max_file_dist)
    return max_file_dist


def get_sets(mentions):
    sets = {}
    all_mentions = []
    clustered_mentions = []
    for mention in mentions:
        all_mentions.append(mention.attrib['span'])
        set_id = mention.attrib['mention_group']
        if set_id == 'empty' or set_id == '':
            pass
        elif set_id not in sets:
            sets[set_id] = [mention.attrib['span']]
            clustered_mentions.append(mention.attrib['span'])
        else:
            sets[set_id].append(mention.attrib['span'])
            clustered_mentions.append(mention.attrib['span'])

    # Drop singleton sets; they contribute no positive pairs.
    sets_to_remove = []
    for set_id in sets:
        if len(sets[set_id]) < 2:
            sets_to_remove.append(set_id)
        if len(sets[set_id]) == 1:
            print(u'Removing clustered mention: ', sets[set_id][0])
            clustered_mentions.remove(sets[set_id][0])

    for set_id in sets_to_remove:
        print(u'Removing set: ', set_id)
        sets.pop(set_id)

    return sets, all_mentions, clustered_mentions


def get_max_set_dist(mnt_set, mentions_dict):
    # For each mention, take the distance to its nearest neighbour within the
    # set; the maximum of those nearest-neighbour distances is returned.
    max_set_dist = 0
    for idx, mnt2_span in enumerate(mnt_set):
        mnt2 = get_mention_by_attr(mentions_dict, 'span', mnt2_span)
        dist = None
        if idx - 1 >= 0:
            mnt1 = get_mention_by_attr(mentions_dict, 'span', mnt_set[idx - 1])
            dist = get_pair_dist(mnt1, mnt2)
        if idx + 1 < len(mnt_set):
            mnt3 = get_mention_by_attr(mentions_dict, 'span', mnt_set[idx + 1])
            dist2 = get_pair_dist(mnt2, mnt3)
            if dist is None or dist2 < dist:
                dist = dist2

        if dist is not None and dist > max_set_dist:
            max_set_dist = dist

    return max_set_dist


def get_pair_dist(ante, ana):
    # Distance in mentions; overlapping (nested) mentions count as distance 0.
    dist = 0
    mnts_intersect = pair_intersect(ante, ana)
    if mnts_intersect != 1:
        dist = ana['position_in_mentions'] - ante['position_in_mentions']
    return dist


def pair_intersect(ante, ana):
    # Two mentions intersect if they share at least one word id.
    for ante_word in ante['words']:
        for ana_word in ana['words']:
            if ana_word['id'] == ante_word['id']:
                return 1
    return 0


def markables_level_2_dict(markables_path, words_path, namespace='www.eml.org/NameSpaces/mention'):
    markables_dicts = []
    markables_tree = etree.parse(markables_path)
    markables = markables_tree.xpath("//ns:markable", namespaces={'ns': namespace})

    words = get_words(words_path)

    for idx, markable in enumerate(markables):
        span = markable.attrib['span']
        if not get_mention_by_attr(markables_dicts, 'span', span):

            dominant = ''
            if 'dominant' in markable.attrib:
                dominant = markable.attrib['dominant']

            head_orth = markable.attrib['mention_head']
            mention_words = span_to_words(span, words)

            (prec_context, follow_context, sentence, mnt_start_position, mnt_end_position,
             paragraph_id, sentence_id, first_in_sentence, first_in_paragraph) = get_context(mention_words, words)

            head = get_head(head_orth, mention_words)
            markables_dicts.append({'id': markable.attrib['id'],
                                    'set': markable.attrib['mention_group'],
                                    'text': span_to_text(span, words, 'orth'),
                                    'lemmatized_text': span_to_text(span, words, 'base'),
                                    'words': mention_words,
                                    'span': span,
                                    'head_orth': head_orth,
                                    'head': head,
                                    'dominant': dominant,
                                    'node': markable,
                                    'prec_context': prec_context,
                                    'follow_context': follow_context,
                                    'sentence': sentence,
                                    'position_in_mentions': idx,
                                    'start_in_words': mnt_start_position,
                                    'end_in_words': mnt_end_position,
                                    'paragraph_id': paragraph_id,
                                    'sentence_id': sentence_id,
                                    'first_in_sentence': first_in_sentence,
                                    'first_in_paragraph': first_in_paragraph})
        else:
            print('Duplicate mention: %s' % span)

    return markables_dicts


def get_context(mention_words, words):
    paragraph_id = 0
    sentence_id = 0
    prec_context = []
    follow_context = []
    sentence = []
    mnt_start_position = -1
    mnt_end_position = -1
    first_word = mention_words[0]
    last_word = mention_words[-1]
    first_in_sentence = False
    first_in_paragraph = False
    for idx, word in enumerate(words):
        if word['id'] == first_word['id']:
            prec_context = get_prec_context(idx, words)
            mnt_start_position = get_mention_start(first_word, words)
            if idx == 0 or words[idx - 1]['lastinsent']:
                first_in_sentence = True
            if idx == 0 or words[idx - 1]['lastinpar']:
                first_in_paragraph = True
        if word['id'] == last_word['id']:
            follow_context = get_follow_context(idx, words)
            sentence = get_sentence(idx, words)
            mnt_end_position = get_mention_end(last_word, words)
            break
        if word['lastinsent']:
            sentence_id += 1
        if word['lastinpar']:
            paragraph_id += 1
    return (prec_context, follow_context, sentence, mnt_start_position, mnt_end_position,
            paragraph_id, sentence_id, first_in_sentence, first_in_paragraph)


def get_prec_context(mention_start, words):
    context = []
    context_start = mention_start - 1
    while context_start >= 0:
        if not word_to_ignore(words[context_start]):
            context.append(words[context_start])
        if len(context) == CONTEXT:
            break
        context_start -= 1
    context.reverse()
    return context


def get_mention_start(first_word, words):
    start = 0
    for word in words:
        if not word_to_ignore(word):
            start += 1
        if word['id'] == first_word['id']:
            break
    return start


def get_mention_end(last_word, words):
    end = 0
    for word in words:
        if not word_to_ignore(word):
            end += 1
        if word['id'] == last_word['id']:
            break
    return end


def get_follow_context(mention_end, words):
    context = []
    context_end = mention_end + 1
    while context_end < len(words):
        if not word_to_ignore(words[context_end]):
            context.append(words[context_end])
        if len(context) == CONTEXT:
            break
        context_end += 1
    return context


def get_sentence(word_idx, words):
    sentence_start = get_sentence_start(words, word_idx)
    sentence_end = get_sentence_end(words, word_idx)
    return [word for word in words[sentence_start:sentence_end + 1] if not word_to_ignore(word)]


def get_sentence_start(words, word_idx):
    search_start = word_idx
    while word_idx >= 0:
        if words[word_idx]['lastinsent'] and search_start != word_idx:
            return word_idx + 1
        word_idx -= 1
    return 0


def get_sentence_end(words, word_idx):
    while word_idx < len(words):
        if words[word_idx]['lastinsent']:
            return word_idx
        word_idx += 1
    return len(words) - 1


def get_head(head_orth, words):
    for word in words:
        if word['orth'].lower() == head_orth.lower() or word['orth'] == head_orth:
            return word
    return None


def get_words(filepath):
    tree = etree.parse(filepath)
    words = []
    for word in tree.xpath("//word"):
        words.append({'id': word.attrib['id'],
                      'orth': word.text,
                      'base': word.attrib['base'],
                      'hasnps': word.attrib.get('hasnps') == 'true',
                      'lastinsent': word.attrib.get('lastinsent') == 'true',
                      'lastinpar': word.attrib.get('lastinpar') == 'true',
                      'ctag': word.attrib['ctag'],
                      'msd': word.attrib['msd'],
                      'gender': get_gender(word.attrib['msd']),
                      'person': get_person(word.attrib['msd']),
                      'number': get_number(word.attrib['msd'])})
    return words


def get_gender(msd):
    tags = msd.split(':')
    if 'm1' in tags:
        return 'm1'
    elif 'm2' in tags:
        return 'm2'
    elif 'm3' in tags:
        return 'm3'
    elif 'f' in tags:
        return 'f'
    elif 'n' in tags:
        return 'n'
    else:
        return 'unk'


def get_person(msd):
    tags = msd.split(':')
    if 'pri' in tags:
        return 'pri'
    elif 'sec' in tags:
        return 'sec'
    elif 'ter' in tags:
        return 'ter'
    else:
        return 'unk'


def get_number(msd):
    tags = msd.split(':')
    if 'sg' in tags:
        return 'sg'
    elif 'pl' in tags:
        return 'pl'
    else:
        return 'unk'


def get_mention_by_attr(mentions, attr_name, value):
    for mention in mentions:
        if mention[attr_name] == value:
            return mention
    return None


def get_mention_index_by_attr(mentions, attr_name, value):
    for idx, mention in enumerate(mentions):
        if mention[attr_name] == value:
            return idx
    return None


def span_to_text(span, words, form):
    fragments = span.split(',')
    mention_parts = [fragment_to_text(fragment, words, form) for fragment in fragments]
    return u' [...] '.join(mention_parts)


def fragment_to_text(fragment, words, form):
    if '..' in fragment:
        return get_multiword_text(fragment, words, form)
    return get_one_word_text(fragment, words, form)


def get_multiword_text(fragment, words, form):
    mention_parts = []
    start_id, end_id = fragment.split('..')
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return to_text(mention_parts, form)


def to_text(words, form):
    text = ''
    for idx, word in enumerate(words):
        if word['hasnps'] or idx == 0:
            text += word[form]
        else:
            text += u' %s' % word[form]
    return text


def get_one_word_text(word_id, words, form):
    this_word = next(word for word in words if word['id'] == word_id)
    if word_to_ignore(this_word):
        print(this_word)
    return this_word[form]


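# Span syntax handled below (as used in the MMAX markables): a span is a
# comma-separated list of fragments, each fragment either a single word id
# ("word_3") or an inclusive range ("word_3..word_5"); a span with several
# fragments denotes a discontinuous mention.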
def span_to_words(span, words):
    fragments = span.split(',')
    mention_parts = []
    for fragment in fragments:
        mention_parts.extend(fragment_to_words(fragment, words))
    return mention_parts


def fragment_to_words(fragment, words):
    mention_parts = []
    if '..' in fragment:
        mention_parts.extend(get_multiword(fragment, words))
    else:
        mention_parts.extend(get_word(fragment, words))
    return mention_parts


def get_multiword(fragment, words):
    mention_parts = []
    start_id, end_id = fragment.split('..')
    in_string = False
    for word in words:
        if word['id'] == start_id:
            in_string = True
        if in_string and not word_to_ignore(word):
            mention_parts.append(word)
        if word['id'] == end_id:
            break
    return mention_parts


def get_word(word_id, words):
    for word in words:
        if word['id'] == word_id:
            if not word_to_ignore(word):
                return [word]
            return []
    return []


def word_to_ignore(word):
    # Hook for filtering out words (e.g. punctuation); nothing is ignored yet.
    return False


if __name__ == '__main__':
    main()
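
For reference, a toy check of the distance convention this script reports, with hypothetical mention dicts carrying only the fields get_pair_dist needs (runnable e.g. from a REPL after importing count_dist): non-overlapping mentions are separated by the difference of their positions in the mention sequence, while mentions sharing a word id count as distance 0.

    ante = {'position_in_mentions': 3, 'words': [{'id': 'word_10'}, {'id': 'word_11'}]}
    ana = {'position_in_mentions': 7, 'words': [{'id': 'word_25'}]}
    nested = {'position_in_mentions': 4, 'words': [{'id': 'word_11'}]}
    assert get_pair_dist(ante, ana) == 4     # disjoint words: positional distance
    assert get_pair_dist(ante, nested) == 0  # shared word id: treated as overlap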
preparator.py
@@ -29,7 +29,7 @@ TITLE2REDIRECT_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data', 'wikipedia
 
 ANNO_PATH = TEST_PATH
 OUT_PATH = os.path.abspath(os.path.join(MAIN_PATH, 'data',
-                                        'test-1to5-20170720.csv'))
+                                        'test-1to5-singletons-20170720.csv'))
 EACH_TEXT_SEPARATELLY = False
 
 CONTEXT = 5
@@ -53,6 +53,7 @@ HYPHEN_SIGNS = ['-', '#']
 
 NEG_PROPORTION = 5
 RANDOM_VECTORS = True
+USE_SINGLETONS = True
 
 DEBUG = False
 POS_COUNT = 0
@@ -154,9 +155,9 @@ def create_data_vectors(model, freq_list, lemma2synonyms,
 
 
 def diff_mentions(mentions):
-    sets, clustered_mentions = get_sets(mentions)
+    sets, all_mentions, clustered_mentions = get_sets(mentions)
     positives = get_positives(sets)
-    positives, negatives = get_negatives_and_update_positives(clustered_mentions, positives)
+    positives, negatives = get_negatives_and_update_positives(all_mentions, clustered_mentions, positives)
     if len(negatives) != len(positives) and NEG_PROPORTION == 1:
         print(u'Mismatched number of positive and negative examples!')
     return positives, negatives
@@ -164,8 +165,10 @@ def diff_mentions(mentions):
 
 def get_sets(mentions):
     sets = {}
+    all_mentions = []
     clustered_mentions = []
     for mention in mentions:
+        all_mentions.append(mention.attrib['span'])
         set_id = mention.attrib['mention_group']
         if set_id == 'empty' or set_id == '' or mention.attrib['mention_head'] in POSSIBLE_HEADS:
             pass
@@ -190,7 +193,7 @@ def get_sets(mentions):
         print(u'Removing set: ', set_id)
         sets.pop(set_id)
 
-    return sets, clustered_mentions
+    return sets, all_mentions, clustered_mentions
 
 
 def get_positives(sets):
@@ -201,8 +204,12 @@ def get_positives(sets):
     return positives
 
 
-def get_negatives_and_update_positives(clustered_mentions, positives):
-    all_pairs = list(combinations(clustered_mentions, 2))
+def get_negatives_and_update_positives(all_mentions, clustered_mentions, positives):
+    all_pairs = []
+    if USE_SINGLETONS:
+        all_pairs = list(combinations(all_mentions, 2))
+    else:
+        all_pairs = list(combinations(clustered_mentions, 2))
     all_pairs = set(all_pairs)
     negatives = [pair for pair in all_pairs if pair not in positives]
     samples_count = NEG_PROPORTION * len(positives)
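
To make the behaviour change concrete, a minimal sketch (hypothetical spans, not corpus data) of how including singletons enlarges the pool that negatives are drawn from:

    from itertools import combinations

    all_mentions = ['word_1', 'word_2', 'word_3..word_4', 'word_5']
    clustered_mentions = ['word_1', 'word_3..word_4']   # mentions inside coreference sets
    positives = {('word_1', 'word_3..word_4')}

    with_singletons = set(combinations(all_mentions, 2)) - positives
    without_singletons = set(combinations(clustered_mentions, 2)) - positives
    print(len(with_singletons), len(without_singletons))  # 5 candidate negatives vs. 0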
@@ -474,7 +481,7 @@ def get_pair_features(pair, mentions_dict, lemma2synonyms,
     words_dist = [0] * 11
     words_bucket = 0
     if mnts_intersect != 1:
-        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'] - 1)
+        words_bucket = get_distance_bucket(ana['start_in_words'] - ante['end_in_words'])
     if DEBUG:
         features.append('Bucket %d' % words_bucket)
     words_dist[words_bucket] = 1
@@ -483,7 +490,7 @@ def get_pair_features(pair, mentions_dict, lemma2synonyms,
     mentions_dist = [0] * 11
     mentions_bucket = 0
     if mnts_intersect != 1:
-        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'] - 1)
+        mentions_bucket = get_distance_bucket(ana['position_in_mentions'] - ante['position_in_mentions'])
     if words_bucket == 10:
         mentions_bucket = 10
     if DEBUG:
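
Dropping the "- 1" makes the bucketed distances agree with count_dist.py above, where adjacent mentions are 1 apart rather than 0. A toy check with a stand-in bucketing function (get_distance_bucket itself is not shown in this diff, so the thresholds below are assumptions for illustration only):

    def get_distance_bucket(dist):  # hypothetical stand-in, not the committed function
        thresholds = [1, 2, 3, 4, 5, 8, 16, 32, 64, 128]
        for bucket, limit in enumerate(thresholds):
            if dist < limit:
                return bucket
        return 10

    ante = {'position_in_mentions': 5}
    ana = {'position_in_mentions': 6}
    # Before this commit: bucket of (6 - 5 - 1) = bucket 0; after: bucket of (6 - 5) = bucket 1.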