Commit 2637e9621ff1b6d4dc1782bda90ff3a3a1499ddd

Authored by Bartłomiej Nitoń
1 parent db32de75

Added all2all resolve algorithm.

corneferencer/entities.py
... ... @@ -13,6 +13,11 @@ class Text:
13 13 return mnt.set
14 14 return None
15 15  
  16 + def merge_sets(self, set1, set2):
  17 + for mnt in self.mentions:
  18 + if mnt.set == set1:
  19 + mnt.set = set2
  20 +
16 21  
17 22 class Mention:
18 23  
... ...
corneferencer/main.py
... ... @@ -89,6 +89,8 @@ def process_file(inpath, outpath, informat, resolver, threshold):
89 89 resolve.closest(text, threshold)
90 90 elif resolver == 'siamese':
91 91 resolve.siamese(text, threshold)
  92 + elif resolver == 'all2all':
  93 + resolve.all2all(text, threshold)
92 94 mmax.write(inpath, outpath, text)
93 95  
94 96  
... ...
corneferencer/resolvers/constants.py
1 1 # -*- coding: utf-8 -*-
2 2  
3   -RESOLVERS = ['entity_based', 'incremental', 'closest', 'siamese']
  3 +RESOLVERS = ['entity_based', 'incremental', 'closest', 'siamese', 'all2all']
4 4  
5 5 NOUN_TAGS = ['subst', 'ger', 'depr']
6 6 PPRON_TAGS = ['ppron12', 'ppron3']
... ...
corneferencer/resolvers/resolve.py
... ... @@ -61,6 +61,104 @@ def incremental(text, threshold):
61 61 last_set_id += 1
62 62  
63 63  
  64 +# all2all resolve algorithm
  65 +def all2all_v1(text, threshold):
  66 + last_set_id = 0
  67 + for pos1, mnt1 in enumerate(text.mentions):
  68 + print ('!!!!!!!!!!%s!!!!!!!!!!!' % mnt1.text)
  69 + best_prediction = 0.0
  70 + best_link = None
  71 + if mnt1.set:
  72 + continue
  73 + for pos2, mnt2 in enumerate(text.mentions):
  74 + if (pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
  75 + ante = mnt1
  76 + ana = mnt2
  77 + if pos2 < pos1:
  78 + ante = mnt2
  79 + ana = mnt1
  80 + pair_vec = get_pair_vector(ante, ana)
  81 + sample = numpy.asarray([pair_vec], dtype=numpy.float32)
  82 + prediction = NEURAL_MODEL.predict(sample)[0]
  83 + print (u'%s >> %f' % (mnt2.text, prediction))
  84 + if prediction > threshold and prediction > best_prediction:
  85 + best_prediction = prediction
  86 + best_link = mnt2
  87 + if best_link is not None:
  88 + print (u'best: %s' % best_link.text)
  89 + if best_link.set:
  90 + mnt1.set = best_link.set
  91 + else:
  92 + str_set_id = 'set_%d' % last_set_id
  93 + best_link.set = str_set_id
  94 + mnt1.set = str_set_id
  95 + last_set_id += 1
  96 +
  97 +
  98 +def all2all_debug(text, threshold):
  99 + last_set_id = 0
  100 + for pos1, mnt1 in enumerate(text.mentions):
  101 + print ('!!!!!!!!!!%s!!!!!!!!!!!' % mnt1.text)
  102 + best_prediction = 0.0
  103 + best_link = None
  104 + for pos2, mnt2 in enumerate(text.mentions):
  105 + if ((mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
  106 + ante = mnt1
  107 + ana = mnt2
  108 + if pos2 < pos1:
  109 + ante = mnt2
  110 + ana = mnt1
  111 + pair_vec = get_pair_vector(ante, ana)
  112 + sample = numpy.asarray([pair_vec], dtype=numpy.float32)
  113 + prediction = NEURAL_MODEL.predict(sample)[0]
  114 + print (u'mnt2: %s | %s == %s >> %f' % (mnt2.text, ante.text, ana.text, prediction))
  115 + if prediction > threshold and prediction > best_prediction:
  116 + best_prediction = prediction
  117 + best_link = mnt2
  118 + if best_link is not None:
  119 + print (u'best: %s >> %f, best set: %s, mnt1_set: %s' % (best_link.text, best_prediction, best_link.set, mnt1.set))
  120 + if best_link.set and not mnt1.set:
  121 + mnt1.set = best_link.set
  122 + elif best_link.set and mnt1.set:
  123 + text.merge_sets(best_link.set, mnt1.set)
  124 + elif not best_link.set and not mnt1.set:
  125 + str_set_id = 'set_%d' % last_set_id
  126 + best_link.set = str_set_id
  127 + mnt1.set = str_set_id
  128 + last_set_id += 1
  129 + print (u'best set: %s, mnt1_set: %s' % (best_link.set, mnt1.set))
  130 +
  131 +
  132 +def all2all(text, threshold):
  133 + last_set_id = 0
  134 + for pos1, mnt1 in enumerate(text.mentions):
  135 + best_prediction = 0.0
  136 + best_link = None
  137 + for pos2, mnt2 in enumerate(text.mentions):
  138 + if ((mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
  139 + ante = mnt1
  140 + ana = mnt2
  141 + if pos2 < pos1:
  142 + ante = mnt2
  143 + ana = mnt1
  144 + pair_vec = get_pair_vector(ante, ana)
  145 + sample = numpy.asarray([pair_vec], dtype=numpy.float32)
  146 + prediction = NEURAL_MODEL.predict(sample)[0]
  147 + if prediction > threshold and prediction > best_prediction:
  148 + best_prediction = prediction
  149 + best_link = mnt2
  150 + if best_link is not None:
  151 + if best_link.set and not mnt1.set:
  152 + mnt1.set = best_link.set
  153 + elif best_link.set and mnt1.set:
  154 + text.merge_sets(best_link.set, mnt1.set)
  155 + elif not best_link.set and not mnt1.set:
  156 + str_set_id = 'set_%d' % last_set_id
  157 + best_link.set = str_set_id
  158 + mnt1.set = str_set_id
  159 + last_set_id += 1
  160 +
  161 +
64 162 # entity based resolve algorithm
65 163 def entity_based(text, threshold):
66 164 sets = []
... ...