Commit 3fa58087b5af247ae25bfb29e3c8e813c6fd7dca

Authored by Bartłomiej Nitoń
1 parent cf3852f0

Added closest resolve algorithm.

... ... @@ -12,9 +12,9 @@ W2V_SIZE = 50
12 12 W2V_MODEL_NAME = 'w2v_allwiki_nkjpfull_50.model'
13 13  
14 14 # simple or siamese
15   -NEURAL_MODEL_ARCHITECTURE = 'siamese'
16   -NUMBER_OF_FEATURES = 625
17   -NEURAL_MODEL_NAME = 'weights_siamese_model.h5'
  15 +NEURAL_MODEL_ARCHITECTURE = 'simple'
  16 +NUMBER_OF_FEATURES = 1190
  17 +NEURAL_MODEL_NAME = 'model_1190_features.h5'
18 18  
19 19 FREQ_LIST_NAME = 'base.lst'
20 20 LEMMA2SYNONYMS_NAME = 'lemma2synonyms.map'
... ...
corneferencer/main.py
... ... @@ -85,6 +85,8 @@ def process_file(inpath, outpath, informat, resolver, threshold):
85 85 resolve.incremental(text, threshold)
86 86 elif resolver == 'entity_based':
87 87 resolve.entity_based(text, threshold)
  88 + elif resolver == 'closest':
  89 + resolve.closest(text, threshold)
88 90 elif resolver == 'siamese':
89 91 resolve.siamese(text, threshold)
90 92 mmax.write(inpath, outpath, text)
... ...
corneferencer/resolvers/constants.py
1 1 # -*- coding: utf-8 -*-
2 2  
3   -RESOLVERS = ['entity_based', 'incremental', 'siamese']
  3 +RESOLVERS = ['entity_based', 'incremental', 'closest', 'siamese']
4 4  
5 5 NOUN_TAGS = ['subst', 'ger', 'depr']
6 6 PPRON_TAGS = ['ppron12', 'ppron3']
... ...
corneferencer/resolvers/resolve.py
1 1 import numpy
2 2  
3   -from conf import NEURAL_MODEL#, THRESHOLD
  3 +from conf import NEURAL_MODEL
4 4 from corneferencer.resolvers import features
5 5 from corneferencer.resolvers.vectors import get_pair_features, get_pair_vector
6 6  
7 7  
8   -# siamese resolve algorithm
9   -# def siamese(text):
10   -# last_set_id = 0
11   -# for i, ana in enumerate(text.mentions):
12   -# if i > 0:
13   -# best_prediction = 20.0
14   -# best_ante = None
15   -# for ante in text.mentions[:i]:
16   -# if not features.pair_intersect(ante, ana):
17   -# pair_features = get_pair_features(ante, ana)
18   -#
19   -# ante_vec = []
20   -# ante_vec.extend(ante.features)
21   -# ante_vec.extend(pair_features)
22   -# ante_sample = numpy.asarray([ante_vec], dtype=numpy.float32)
23   -#
24   -# ana_vec = []
25   -# ana_vec.extend(ana.features)
26   -# ana_vec.extend(pair_features)
27   -# ana_sample = numpy.asarray([ana_vec], dtype=numpy.float32)
28   -#
29   -# prediction = NEURAL_MODEL.predict([ante_sample, ana_sample])[0]
30   -#
31   -# print (ante.text, '--->', ana.text, '>>', prediction)
32   -#
33   -# if prediction < THRESHOLD and prediction < best_prediction:
34   -# best_prediction = prediction
35   -# best_ante = ante
36   -# if best_ante is not None:
37   -# if best_ante.set:
38   -# ana.set = best_ante.set
39   -# else:
40   -# str_set_id = 'set_%d' % last_set_id
41   -# best_ante.set = str_set_id
42   -# ana.set = str_set_id
43   -# last_set_id += 1
44   -
45   -
46 8 def siamese(text, threshold):
47 9 last_set_id = 0
48 10 for i, ana in enumerate(text.mentions):
... ... @@ -152,3 +114,35 @@ def remove_singletons(sets):
152 114 for s in sets:
153 115 if len(s['mentions']) == 1:
154 116 s['mentions'][0].set = ''
  117 +
  118 +
  119 +# closest resolve algorithm
  120 +def closest(text, threshold):
  121 + last_set_id = 0
  122 + for i, ana in enumerate(text.mentions):
  123 + if i > 0:
  124 + for ante in reversed(text.mentions[:i]):
  125 + if not features.pair_intersect(ante, ana):
  126 + pair_features = get_pair_features(ante, ana)
  127 +
  128 + ante_vec = []
  129 + ante_vec.extend(ante.features)
  130 + ante_vec.extend(pair_features)
  131 + ante_sample = numpy.asarray([ante_vec], dtype=numpy.float32)
  132 +
  133 + ana_vec = []
  134 + ana_vec.extend(ana.features)
  135 + ana_vec.extend(pair_features)
  136 + ana_sample = numpy.asarray([ana_vec], dtype=numpy.float32)
  137 +
  138 + prediction = NEURAL_MODEL.predict([ante_sample, ana_sample])[0]
  139 +
  140 + if prediction > threshold:
  141 + if ante.set:
  142 + ana.set = ante.set
  143 + else:
  144 + str_set_id = 'set_%d' % last_set_id
  145 + ante.set = str_set_id
  146 + ana.set = str_set_id
  147 + last_set_id += 1
  148 + break
... ...