resolve.py
2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from conf import NEURAL_MODEL, THRESHOLD
from corneferencer.resolvers.vectors import create_pair_vector
# incremental resolve algorithm
def incremental(text):
    """Link every mention to its best-scoring preceding mention.

    Mentions are visited in the order given by ``text.mentions``. For each
    one (the anaphor), every earlier mention is paired with it through
    ``create_pair_vector`` and scored with ``NEURAL_MODEL.predict``; the
    first element of the prediction is used as the score. The candidate
    with the highest score strictly above ``THRESHOLD`` becomes the
    antecedent. If the antecedent already has a truthy ``set`` value the
    anaphor receives the same value; otherwise a new ``'set_<n>'`` id is
    assigned to both.

    Args:
        text: object exposing ``mentions``, a sliceable sequence of mention
            objects with a readable and writable ``set`` attribute.

    Returns:
        None. Mentions are updated in place through their ``set`` attribute.
    """
    last_set_id = 1
    for i, ana in enumerate(text.mentions):
        best_prediction = 0.0
        best_ante = None
        # Antecedent candidates are the mentions *before* the anaphor, so
        # take the prefix [:i] and walk it backwards (closest candidate
        # first). A `[:i:-1]` slice would instead yield the mentions that
        # come after position i. For i == 0 the prefix is empty and the
        # first mention is simply skipped.
        for ante in reversed(text.mentions[:i]):
            pair_vec = create_pair_vector(ante, ana)
            prediction = NEURAL_MODEL.predict(pair_vec)
            accuracy = prediction[0]
            # Strict '>' keeps the closest candidate when scores tie.
            if accuracy > THRESHOLD and accuracy > best_prediction:
                best_prediction = accuracy
                best_ante = ante

        if best_ante is None:
            continue

        if best_ante.set:
            # Antecedent already belongs to a set: join it.
            ana.set = best_ante.set
        else:
            # Open a new set shared by the antecedent and the anaphor.
            str_set_id = 'set_%d' % last_set_id
            best_ante.set = str_set_id
            ana.set = str_set_id
            last_set_id += 1
# entity based resolve algorithm
def entity_based(text):
    """Group mentions by scoring each one against whole sets.

    Mentions are visited in the order given by ``text.mentions``. The first
    mention always opens a new set. Every later mention is handed to
    ``get_best_set``; it joins the returned set, or opens a new
    ``'set_<n>'`` set when ``None`` comes back. Each mention's ``set``
    attribute is assigned the id of the set it ends up in. Finally
    ``remove_singletons`` is called on the collected sets.

    Args:
        text: object exposing ``mentions``, an iterable of mention objects
            with a writable ``set`` attribute.

    Returns:
        None. Mentions are updated in place.
    """
    clusters = []
    next_id = 1
    for ana in text.mentions:
        # Only the very first mention finds no cluster yet; it is never
        # scored and goes straight to a new set.
        chosen = get_best_set(clusters, ana) if clusters else None
        if chosen is None:
            label = 'set_%d' % next_id
            next_id += 1
            clusters.append({'set_id': label, 'mentions': [ana]})
        else:
            label = chosen['set_id']
            chosen['mentions'].append(ana)
        ana.set = label

    remove_singletons(clusters)
def get_best_set(sets, ana):
    """Pick the set that best matches the mention *ana*.

    Each entry of ``sets`` is scored with ``predict_set`` on its
    ``'mentions'`` list. An entry is eligible when its score is strictly
    above ``THRESHOLD`` and not lower than the best score seen so far
    (which starts at 0.0), so on equal scores the later entry wins.

    Args:
        sets: iterable of dicts holding a ``'mentions'`` list.
        ana: the mention to place.

    Returns:
        The winning dict from ``sets``, or ``None`` when no entry is
        eligible (including when ``sets`` is empty).
    """
    winner = None
    top_score = 0.0
    for candidate in sets:
        score = predict_set(candidate['mentions'], ana)
        if not (score > THRESHOLD and score >= top_score):
            continue
        winner, top_score = candidate, score
    return winner
def predict_set(mentions, ana):
    """Return the mean pair score between *ana* and every given mention.

    Each mention is paired with ``ana`` through ``create_pair_vector`` and
    scored with ``NEURAL_MODEL.predict``; the first element of every
    prediction is added up and the total is divided by the number of
    mentions.

    Args:
        mentions: sized collection of mentions forming one set.
        ana: the mention being compared against the set.

    Returns:
        The summed scores divided by ``len(mentions)``.

    Raises:
        ZeroDivisionError: if ``mentions`` is empty.
    """
    total = sum(
        (NEURAL_MODEL.predict(create_pair_vector(member, ana))[0]
         for member in mentions),
        0.0,  # float start value, same as the running total it replaces
    )
    return total / float(len(mentions))
def remove_singletons(sets):
    """Clear the set id of every mention that is alone in its set.

    For each entry of ``sets`` whose ``'mentions'`` list holds exactly one
    mention, that mention's ``set`` attribute is reset to the empty string.
    The ``sets`` list itself is left unchanged.

    Args:
        sets: iterable of dicts holding a ``'mentions'`` list.

    Returns:
        None.
    """
    for group in sets:
        members = group['mentions']
        if len(members) != 1:
            continue
        members[0].set = ''