Commit 5be49adfc0657b73f843aca5e8bc769550baf8f5
1 parent: 3d6332cc

Preparations for experiments using sieve system output as Corneferencer input.

Showing 3 changed files with 44 additions and 23 deletions.
corneferencer/entities.py

@@ -13,6 +13,16 @@ class Text:
                 return mnt.set
         return None
 
+    def get_sets(self):
+        sets = {}
+        for mnt in self.mentions:
+            if mnt.set and mnt.set in sets:
+                sets[mnt.set].append(mnt)
+            elif mnt.set:
+                sets[mnt.set] = [mnt]
+        return sets
+
+
     def merge_sets(self, set1, set2):
         for mnt in self.mentions:
             if mnt.set == set1:
@@ -25,9 +35,9 @@ class Mention:
                  head_orth, head, dominant, node, prec_context,
                  follow_context, sentence, position_in_mentions,
                  start_in_words, end_in_words, rarest, paragraph_id, sentence_id,
-                 first_in_sentence, first_in_paragraph):
+                 first_in_sentence, first_in_paragraph, set_id=''):
         self.id = mnt_id
-        self.set = ''
+        self.set = set_id
         self.old_set = ''
         self.text = text
         self.lemmatized_text = lemmatized_text
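A note on the entities.py change: get_sets() builds a set-id-to-mentions index over mentions that already carry a cluster id (for instance one pre-assigned by the sieve system via the new set_id constructor argument); unclustered mentions, whose set is the empty string, are skipped. A minimal sketch of the intended behaviour, using hypothetical SimpleMention/SimpleText stubs in place of the full Mention and Text classes:

# Sketch only: SimpleMention and SimpleText stand in for the real classes,
# whose constructors take many more arguments (see the diff above).
class SimpleMention:
    def __init__(self, set_id=''):
        self.set = set_id  # cluster id; '' means unclustered

class SimpleText:
    def __init__(self, mentions):
        self.mentions = mentions

    def get_sets(self):  # same logic as the added Text.get_sets
        sets = {}
        for mnt in self.mentions:
            if mnt.set and mnt.set in sets:
                sets[mnt.set].append(mnt)
            elif mnt.set:
                sets[mnt.set] = [mnt]
        return sets

text = SimpleText([SimpleMention('set_0'), SimpleMention(), SimpleMention('set_0')])
print(sorted(text.get_sets()))  # ['set_0']; the unclustered mention is ignored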
corneferencer/inout/mmax.py

@@ -15,8 +15,7 @@ def read(inpath):
     words_path = os.path.join(textdir, '%s_words.xml' % textname)
 
     text = Text(textname)
-    mentions = read_mentions(mentions_path, words_path)
-    text.mentions = mentions
+    text.mentions = read_mentions(mentions_path, words_path)
     return text
 
 
@@ -43,6 +42,9 @@ def read_mentions(mentions_path, words_path):
          first_in_sentence, first_in_paragraph) = get_context(mention_words, words)
 
         head = get_head(head_orth, mention_words)
+        mention_group = ''
+        if markable.attrib['mention_group'] != 'empty':
+            mention_group = markable.attrib['mention_group']
         mention = Mention(mnt_id=markable.attrib['id'],
                           text=span_to_text(span, words, 'orth'),
                           lemmatized_text=span_to_text(span, words, 'base'),
@@ -62,7 +64,8 @@ def read_mentions(mentions_path, words_path):
                           paragraph_id=paragraph_id,
                           sentence_id=sentence_id,
                           first_in_sentence=first_in_sentence,
-                          first_in_paragraph=first_in_paragraph)
+                          first_in_paragraph=first_in_paragraph,
+                          set_id=mention_group)
         mentions.append(mention)
 
     return mentions
@@ -73,13 +76,16 @@ def get_words(filepath):
     words = []
     for word in tree.xpath("//word"):
         hasnps = False
-        if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
+        if (('hasnps' in word.attrib and word.attrib['hasnps'] == 'true') or
+                ('hasNps' in word.attrib and word.attrib['hasNps'] == 'true')):
            hasnps = True
        lastinsent = False
-        if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
+        if (('lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true') or
+                ('lastInSent' in word.attrib and word.attrib['lastInSent'] == 'true')):
            lastinsent = True
        lastinpar = False
-        if 'lastinpar' in word.attrib and word.attrib['lastinpar'] == 'true':
+        if (('lastinpar' in word.attrib and word.attrib['lastinpar'] == 'true') or
+                ('lastInPar' in word.attrib and word.attrib['lastInPar'] == 'true')):
            lastinpar = True
        words.append({'id': word.attrib['id'],
                      'orth': word.text,
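A note on the mmax.py change: the camelCase fallbacks (hasNps, lastInSent, lastInPar) presumably match attribute spellings produced by the sieve system, while the lowercase variants remain for the original MMAX files. Since the same present-and-'true' test now appears three times, it could be folded into one small helper; a sketch (the helper name attr_flag is hypothetical, not part of the patch):

def attr_flag(attrib, *names):
    # True if any of the given attribute spellings is present with value 'true'.
    return any(attrib.get(name) == 'true' for name in names)

# Possible usage inside get_words(), with word.attrib as in the diff above:
# hasnps = attr_flag(word.attrib, 'hasnps', 'hasNps')
# lastinsent = attr_flag(word.attrib, 'lastinsent', 'lastInSent')
# lastinpar = attr_flag(word.attrib, 'lastinpar', 'lastInPar')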
corneferencer/resolvers/resolve.py

@@ -62,16 +62,14 @@ def incremental(text, threshold):
 
 
 # all2all resolve algorithm
-def all2all_v1(text, threshold):
+def all2all_debug(text, threshold):
     last_set_id = 0
     for pos1, mnt1 in enumerate(text.mentions):
         print ('!!!!!!!!!!%s!!!!!!!!!!!' % mnt1.text)
         best_prediction = 0.0
         best_link = None
-        if mnt1.set:
-            continue
         for pos2, mnt2 in enumerate(text.mentions):
-            if (pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
+            if ((mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
                 ante = mnt1
                 ana = mnt2
                 if pos2 < pos1:
@@ -80,29 +78,32 @@ def all2all_v1(text, threshold):
                 pair_vec = get_pair_vector(ante, ana)
                 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
                 prediction = NEURAL_MODEL.predict(sample)[0]
-                print (u'%s >> %f' % (mnt2.text, prediction))
+                print (u'mnt2: %s | %s == %s >> %f' % (mnt2.text, ante.text, ana.text, prediction))
                 if prediction > threshold and prediction > best_prediction:
                     best_prediction = prediction
                     best_link = mnt2
         if best_link is not None:
-            print (u'best: %s' % best_link.text)
-            if best_link.set:
+            print (u'best: %s >> %f, best set: %s, mnt1_set: %s' % (best_link.text, best_prediction, best_link.set, mnt1.set))
+            if best_link.set and not mnt1.set:
                 mnt1.set = best_link.set
-            else:
+            elif best_link.set and mnt1.set:
+                text.merge_sets(best_link.set, mnt1.set)
+            elif not best_link.set and not mnt1.set:
                 str_set_id = 'set_%d' % last_set_id
                 best_link.set = str_set_id
                 mnt1.set = str_set_id
                 last_set_id += 1
+            print (u'best set: %s, mnt1_set: %s' % (best_link.set, mnt1.set))
 
 
-def all2all_debug(text, threshold):
+def all2all_v1(text, threshold):
     last_set_id = 0
     for pos1, mnt1 in enumerate(text.mentions):
-        print ('!!!!!!!!!!%s!!!!!!!!!!!' % mnt1.text)
         best_prediction = 0.0
         best_link = None
         for pos2, mnt2 in enumerate(text.mentions):
-            if ((mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
+            if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
+                    and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
                 ante = mnt1
                 ana = mnt2
                 if pos2 < pos1:
@@ -111,14 +112,14 @@ def all2all_debug(text, threshold):
                 pair_vec = get_pair_vector(ante, ana)
                 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
                 prediction = NEURAL_MODEL.predict(sample)[0]
-                print (u'mnt2: %s | %s == %s >> %f' % (mnt2.text, ante.text, ana.text, prediction))
                 if prediction > threshold and prediction > best_prediction:
                     best_prediction = prediction
                     best_link = mnt2
         if best_link is not None:
-            print (u'best: %s >> %f, best set: %s, mnt1_set: %s' % (best_link.text, best_prediction, best_link.set, mnt1.set))
             if best_link.set and not mnt1.set:
                 mnt1.set = best_link.set
+            elif not best_link.set and mnt1.set:
+                best_link.set = mnt1.set
             elif best_link.set and mnt1.set:
                 text.merge_sets(best_link.set, mnt1.set)
             elif not best_link.set and not mnt1.set:
@@ -126,11 +127,11 @@ def all2all_debug(text, threshold):
                 best_link.set = str_set_id
                 mnt1.set = str_set_id
                 last_set_id += 1
-            print (u'best set: %s, mnt1_set: %s' % (best_link.set, mnt1.set))
 
 
 def all2all(text, threshold):
     last_set_id = 0
+    sets = text.get_sets()
     for pos1, mnt1 in enumerate(text.mentions):
         best_prediction = 0.0
         best_link = None
@@ -157,9 +158,13 @@ def all2all(text, threshold):
                 text.merge_sets(best_link.set, mnt1.set)
             elif not best_link.set and not mnt1.set:
                 str_set_id = 'set_%d' % last_set_id
+                while str_set_id in sets:
+                    last_set_id += 1
+                    str_set_id = 'set_%d' % last_set_id
                 best_link.set = str_set_id
                 mnt1.set = str_set_id
-                last_set_id += 1
+                sets[str_set_id] = [best_link, mnt1]
+
 
 
 # entity based resolve algorithm
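A note on the resolve.py change: besides swapping the all2all_v1 and all2all_debug names, the patch extends the cluster-assignment step so that mentions may already belong to a set before resolution starts (sieve output), giving four cases instead of two, and all2all now seeds its set registry from text.get_sets() and skips over set ids already taken when minting new ones. A standalone sketch of that four-way branch (assign_to_set is a hypothetical helper extracted for illustration; text, sets, mnt1, best_link, and merge_sets behave as in the diffs above):

def assign_to_set(text, sets, mnt1, best_link, last_set_id):
    if best_link.set and not mnt1.set:
        mnt1.set = best_link.set                  # join best_link's cluster
    elif not best_link.set and mnt1.set:
        best_link.set = mnt1.set                  # pull best_link into mnt1's cluster
    elif best_link.set and mnt1.set:
        text.merge_sets(best_link.set, mnt1.set)  # both clustered already: merge
    else:
        # Neither mention is clustered yet: mint a fresh id, skipping ids the
        # sieve system already assigned so its clusters are not clobbered.
        str_set_id = 'set_%d' % last_set_id
        while str_set_id in sets:
            last_set_id += 1
            str_set_id = 'set_%d' % last_set_id
        best_link.set = mnt1.set = str_set_id
        sets[str_set_id] = [best_link, mnt1]
    return last_set_id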