Commit 3fa58087b5af247ae25bfb29e3c8e813c6fd7dca
1 parent
cf3852f0
Added closest resolve algorithm.
Showing
4 changed files
with
39 additions
and
43 deletions
conf.py
... | ... | @@ -12,9 +12,9 @@ W2V_SIZE = 50 |
12 | 12 | W2V_MODEL_NAME = 'w2v_allwiki_nkjpfull_50.model' |
13 | 13 | |
14 | 14 | # simple or siamese |
15 | -NEURAL_MODEL_ARCHITECTURE = 'siamese' | |
16 | -NUMBER_OF_FEATURES = 625 | |
17 | -NEURAL_MODEL_NAME = 'weights_siamese_model.h5' | |
15 | +NEURAL_MODEL_ARCHITECTURE = 'simple' | |
16 | +NUMBER_OF_FEATURES = 1190 | |
17 | +NEURAL_MODEL_NAME = 'model_1190_features.h5' | |
18 | 18 | |
19 | 19 | FREQ_LIST_NAME = 'base.lst' |
20 | 20 | LEMMA2SYNONYMS_NAME = 'lemma2synonyms.map' |
... | ... |
corneferencer/main.py
... | ... | @@ -85,6 +85,8 @@ def process_file(inpath, outpath, informat, resolver, threshold): |
85 | 85 | resolve.incremental(text, threshold) |
86 | 86 | elif resolver == 'entity_based': |
87 | 87 | resolve.entity_based(text, threshold) |
88 | + elif resolver == 'closest': | |
89 | + resolve.closest(text, threshold) | |
88 | 90 | elif resolver == 'siamese': |
89 | 91 | resolve.siamese(text, threshold) |
90 | 92 | mmax.write(inpath, outpath, text) |
... | ... |
corneferencer/resolvers/constants.py
corneferencer/resolvers/resolve.py
1 | 1 | import numpy |
2 | 2 | |
3 | -from conf import NEURAL_MODEL#, THRESHOLD | |
3 | +from conf import NEURAL_MODEL | |
4 | 4 | from corneferencer.resolvers import features |
5 | 5 | from corneferencer.resolvers.vectors import get_pair_features, get_pair_vector |
6 | 6 | |
7 | 7 | |
8 | -# siamese resolve algorithm | |
9 | -# def siamese(text): | |
10 | -# last_set_id = 0 | |
11 | -# for i, ana in enumerate(text.mentions): | |
12 | -# if i > 0: | |
13 | -# best_prediction = 20.0 | |
14 | -# best_ante = None | |
15 | -# for ante in text.mentions[:i]: | |
16 | -# if not features.pair_intersect(ante, ana): | |
17 | -# pair_features = get_pair_features(ante, ana) | |
18 | -# | |
19 | -# ante_vec = [] | |
20 | -# ante_vec.extend(ante.features) | |
21 | -# ante_vec.extend(pair_features) | |
22 | -# ante_sample = numpy.asarray([ante_vec], dtype=numpy.float32) | |
23 | -# | |
24 | -# ana_vec = [] | |
25 | -# ana_vec.extend(ana.features) | |
26 | -# ana_vec.extend(pair_features) | |
27 | -# ana_sample = numpy.asarray([ana_vec], dtype=numpy.float32) | |
28 | -# | |
29 | -# prediction = NEURAL_MODEL.predict([ante_sample, ana_sample])[0] | |
30 | -# | |
31 | -# print (ante.text, '--->', ana.text, '>>', prediction) | |
32 | -# | |
33 | -# if prediction < THRESHOLD and prediction < best_prediction: | |
34 | -# best_prediction = prediction | |
35 | -# best_ante = ante | |
36 | -# if best_ante is not None: | |
37 | -# if best_ante.set: | |
38 | -# ana.set = best_ante.set | |
39 | -# else: | |
40 | -# str_set_id = 'set_%d' % last_set_id | |
41 | -# best_ante.set = str_set_id | |
42 | -# ana.set = str_set_id | |
43 | -# last_set_id += 1 | |
44 | - | |
45 | - | |
46 | 8 | def siamese(text, threshold): |
47 | 9 | last_set_id = 0 |
48 | 10 | for i, ana in enumerate(text.mentions): |
... | ... | @@ -152,3 +114,35 @@ def remove_singletons(sets): |
152 | 114 | for s in sets: |
153 | 115 | if len(s['mentions']) == 1: |
154 | 116 | s['mentions'][0].set = '' |
117 | + | |
118 | + | |
119 | +# closest resolve algorithm | |
120 | +def closest(text, threshold): | |
121 | + last_set_id = 0 | |
122 | + for i, ana in enumerate(text.mentions): | |
123 | + if i > 0: | |
124 | + for ante in reversed(text.mentions[:i]): | |
125 | + if not features.pair_intersect(ante, ana): | |
126 | + pair_features = get_pair_features(ante, ana) | |
127 | + | |
128 | + ante_vec = [] | |
129 | + ante_vec.extend(ante.features) | |
130 | + ante_vec.extend(pair_features) | |
131 | + ante_sample = numpy.asarray([ante_vec], dtype=numpy.float32) | |
132 | + | |
133 | + ana_vec = [] | |
134 | + ana_vec.extend(ana.features) | |
135 | + ana_vec.extend(pair_features) | |
136 | + ana_sample = numpy.asarray([ana_vec], dtype=numpy.float32) | |
137 | + | |
138 | + prediction = NEURAL_MODEL.predict([ante_sample, ana_sample])[0] | |
139 | + | |
140 | + if prediction > threshold: | |
141 | + if ante.set: | |
142 | + ana.set = ante.set | |
143 | + else: | |
144 | + str_set_id = 'set_%d' % last_set_id | |
145 | + ante.set = str_set_id | |
146 | + ana.set = str_set_id | |
147 | + last_set_id += 1 | |
148 | + break | |
... | ... |