Commit 0f6eeffb07f8f4d854097a0a7c952234c94a26a6
1 parent
a9a48a44
Added mention intersection rule to resolving algorithms.
Showing
1 changed file
with
13 additions
and
9 deletions
corneferencer/resolvers/resolve.py
1 | 1 | import numpy |
2 | 2 | |
3 | 3 | from conf import NEURAL_MODEL, THRESHOLD |
4 | +from corneferencer.resolvers import features | |
4 | 5 | from corneferencer.resolvers.vectors import get_pair_vector |
5 | 6 | |
6 | 7 | |
... | ... | @@ -12,12 +13,13 @@ def incremental(text): |
12 | 13 | best_prediction = 0.0 |
13 | 14 | best_ante = None |
14 | 15 | for ante in text.mentions[:i]: |
15 | - pair_vec = get_pair_vector(ante, ana) | |
16 | - sample = numpy.asarray([pair_vec], dtype=numpy.float32) | |
17 | - prediction = NEURAL_MODEL.predict(sample)[0] | |
18 | - if prediction > THRESHOLD and prediction >= best_prediction: | |
19 | - best_prediction = prediction | |
20 | - best_ante = ante | |
16 | + if not features.pair_intersect(ante, ana): | |
17 | + pair_vec = get_pair_vector(ante, ana) | |
18 | + sample = numpy.asarray([pair_vec], dtype=numpy.float32) | |
19 | + prediction = NEURAL_MODEL.predict(sample)[0] | |
20 | + if prediction > THRESHOLD and prediction >= best_prediction: | |
21 | + best_prediction = prediction | |
22 | + best_ante = ante | |
21 | 23 | if best_ante is not None: |
22 | 24 | # print ('wynik') |
23 | 25 | # print(best_ante.text, best_prediction, ana.text) |
... | ... | @@ -78,9 +80,11 @@ def get_best_set(sets, ana): |
78 | 80 | def predict_set(mentions, ana): |
79 | 81 | prediction_sum = 0.0 |
80 | 82 | for mnt in mentions: |
81 | - pair_vec = get_pair_vector(mnt, ana) | |
82 | - sample = numpy.asarray([pair_vec], dtype=numpy.float32) | |
83 | - prediction = NEURAL_MODEL.predict(sample)[0] | |
83 | + prediction = 0.0 | |
84 | + if not features.pair_intersect(mnt, ana): | |
85 | + pair_vec = get_pair_vector(mnt, ana) | |
86 | + sample = numpy.asarray([pair_vec], dtype=numpy.float32) | |
87 | + prediction = NEURAL_MODEL.predict(sample)[0] | |
84 | 88 | prediction_sum += prediction |
85 | 89 | # print(mnt.text, prediction, ana.text) |
86 | 90 | return prediction_sum / float(len(mentions)) |
... | ... |