Commit 2637e9621ff1b6d4dc1782bda90ff3a3a1499ddd
1 parent
db32de75
Added all2all resolve algorithm.
Showing
4 changed files
with
106 additions
and
1 deletion
corneferencer/entities.py
corneferencer/main.py
... | ... | @@ -89,6 +89,8 @@ def process_file(inpath, outpath, informat, resolver, threshold): |
89 | 89 | resolve.closest(text, threshold) |
90 | 90 | elif resolver == 'siamese': |
91 | 91 | resolve.siamese(text, threshold) |
92 | + elif resolver == 'all2all': | |
93 | + resolve.all2all(text, threshold) | |
92 | 94 | mmax.write(inpath, outpath, text) |
93 | 95 | |
94 | 96 | |
... | ... |
corneferencer/resolvers/constants.py
corneferencer/resolvers/resolve.py
... | ... | @@ -61,6 +61,104 @@ def incremental(text, threshold): |
61 | 61 | last_set_id += 1 |
62 | 62 | |
63 | 63 | |
64 | +# all2all resolve algorithm | |
65 | +def all2all_v1(text, threshold): | |
66 | + last_set_id = 0 | |
67 | + for pos1, mnt1 in enumerate(text.mentions): | |
68 | + print ('!!!!!!!!!!%s!!!!!!!!!!!' % mnt1.text) | |
69 | + best_prediction = 0.0 | |
70 | + best_link = None | |
71 | + if mnt1.set: | |
72 | + continue | |
73 | + for pos2, mnt2 in enumerate(text.mentions): | |
74 | + if (pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)): | |
75 | + ante = mnt1 | |
76 | + ana = mnt2 | |
77 | + if pos2 < pos1: | |
78 | + ante = mnt2 | |
79 | + ana = mnt1 | |
80 | + pair_vec = get_pair_vector(ante, ana) | |
81 | + sample = numpy.asarray([pair_vec], dtype=numpy.float32) | |
82 | + prediction = NEURAL_MODEL.predict(sample)[0] | |
83 | + print (u'%s >> %f' % (mnt2.text, prediction)) | |
84 | + if prediction > threshold and prediction > best_prediction: | |
85 | + best_prediction = prediction | |
86 | + best_link = mnt2 | |
87 | + if best_link is not None: | |
88 | + print (u'best: %s' % best_link.text) | |
89 | + if best_link.set: | |
90 | + mnt1.set = best_link.set | |
91 | + else: | |
92 | + str_set_id = 'set_%d' % last_set_id | |
93 | + best_link.set = str_set_id | |
94 | + mnt1.set = str_set_id | |
95 | + last_set_id += 1 | |
96 | + | |
97 | + | |
98 | +def all2all_debug(text, threshold): | |
99 | + last_set_id = 0 | |
100 | + for pos1, mnt1 in enumerate(text.mentions): | |
101 | + print ('!!!!!!!!!!%s!!!!!!!!!!!' % mnt1.text) | |
102 | + best_prediction = 0.0 | |
103 | + best_link = None | |
104 | + for pos2, mnt2 in enumerate(text.mentions): | |
105 | + if ((mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)): | |
106 | + ante = mnt1 | |
107 | + ana = mnt2 | |
108 | + if pos2 < pos1: | |
109 | + ante = mnt2 | |
110 | + ana = mnt1 | |
111 | + pair_vec = get_pair_vector(ante, ana) | |
112 | + sample = numpy.asarray([pair_vec], dtype=numpy.float32) | |
113 | + prediction = NEURAL_MODEL.predict(sample)[0] | |
114 | + print (u'mnt2: %s | %s == %s >> %f' % (mnt2.text, ante.text, ana.text, prediction)) | |
115 | + if prediction > threshold and prediction > best_prediction: | |
116 | + best_prediction = prediction | |
117 | + best_link = mnt2 | |
118 | + if best_link is not None: | |
119 | + print (u'best: %s >> %f, best set: %s, mnt1_set: %s' % (best_link.text, best_prediction, best_link.set, mnt1.set)) | |
120 | + if best_link.set and not mnt1.set: | |
121 | + mnt1.set = best_link.set | |
122 | + elif best_link.set and mnt1.set: | |
123 | + text.merge_sets(best_link.set, mnt1.set) | |
124 | + elif not best_link.set and not mnt1.set: | |
125 | + str_set_id = 'set_%d' % last_set_id | |
126 | + best_link.set = str_set_id | |
127 | + mnt1.set = str_set_id | |
128 | + last_set_id += 1 | |
129 | + print (u'best set: %s, mnt1_set: %s' % (best_link.set, mnt1.set)) | |
130 | + | |
131 | + | |
132 | +def all2all(text, threshold): | |
133 | + last_set_id = 0 | |
134 | + for pos1, mnt1 in enumerate(text.mentions): | |
135 | + best_prediction = 0.0 | |
136 | + best_link = None | |
137 | + for pos2, mnt2 in enumerate(text.mentions): | |
138 | + if ((mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)): | |
139 | + ante = mnt1 | |
140 | + ana = mnt2 | |
141 | + if pos2 < pos1: | |
142 | + ante = mnt2 | |
143 | + ana = mnt1 | |
144 | + pair_vec = get_pair_vector(ante, ana) | |
145 | + sample = numpy.asarray([pair_vec], dtype=numpy.float32) | |
146 | + prediction = NEURAL_MODEL.predict(sample)[0] | |
147 | + if prediction > threshold and prediction > best_prediction: | |
148 | + best_prediction = prediction | |
149 | + best_link = mnt2 | |
150 | + if best_link is not None: | |
151 | + if best_link.set and not mnt1.set: | |
152 | + mnt1.set = best_link.set | |
153 | + elif best_link.set and mnt1.set: | |
154 | + text.merge_sets(best_link.set, mnt1.set) | |
155 | + elif not best_link.set and not mnt1.set: | |
156 | + str_set_id = 'set_%d' % last_set_id | |
157 | + best_link.set = str_set_id | |
158 | + mnt1.set = str_set_id | |
159 | + last_set_id += 1 | |
160 | + | |
161 | + | |
64 | 162 | # entity based resolve algorithm |
65 | 163 | def entity_based(text, threshold): |
66 | 164 | sets = [] |
... | ... |