Commit 5be49adfc0657b73f843aca5e8bc769550baf8f5

Authored by Bartłomiej Nitoń
1 parent 3d6332cc

Preparations for experiments using the sieve system's output as Corneferencer input.

corneferencer/entities.py
... ... @@ -13,6 +13,16 @@ class Text:
13 13 return mnt.set
14 14 return None
15 15  
  16 + def get_sets(self):
  17 + sets = {}
  18 + for mnt in self.mentions:
  19 + if mnt.set and mnt.set in sets:
  20 + sets[mnt.set].append(mnt)
  21 + elif mnt.set:
  22 + sets[mnt.set] = [mnt]
  23 + return sets
  24 +
  25 +
16 26 def merge_sets(self, set1, set2):
17 27 for mnt in self.mentions:
18 28 if mnt.set == set1:
... ... @@ -25,9 +35,9 @@ class Mention:
25 35 head_orth, head, dominant, node, prec_context,
26 36 follow_context, sentence, position_in_mentions,
27 37 start_in_words, end_in_words, rarest, paragraph_id, sentence_id,
28   - first_in_sentence, first_in_paragraph):
  38 + first_in_sentence, first_in_paragraph, set_id=''):
29 39 self.id = mnt_id
30   - self.set = ''
  40 + self.set = set_id
31 41 self.old_set = ''
32 42 self.text = text
33 43 self.lemmatized_text = lemmatized_text
... ...
corneferencer/inout/mmax.py
... ... @@ -15,8 +15,7 @@ def read(inpath):
15 15 words_path = os.path.join(textdir, '%s_words.xml' % textname)
16 16  
17 17 text = Text(textname)
18   - mentions = read_mentions(mentions_path, words_path)
19   - text.mentions = mentions
  18 + text.mentions = read_mentions(mentions_path, words_path)
20 19 return text
21 20  
22 21  
... ... @@ -43,6 +42,9 @@ def read_mentions(mentions_path, words_path):
43 42 first_in_sentence, first_in_paragraph) = get_context(mention_words, words)
44 43  
45 44 head = get_head(head_orth, mention_words)
  45 + mention_group = ''
  46 + if markable.attrib['mention_group'] != 'empty':
  47 + mention_group = markable.attrib['mention_group']
46 48 mention = Mention(mnt_id=markable.attrib['id'],
47 49 text=span_to_text(span, words, 'orth'),
48 50 lemmatized_text=span_to_text(span, words, 'base'),
... ... @@ -62,7 +64,8 @@ def read_mentions(mentions_path, words_path):
62 64 paragraph_id=paragraph_id,
63 65 sentence_id=sentence_id,
64 66 first_in_sentence=first_in_sentence,
65   - first_in_paragraph=first_in_paragraph)
  67 + first_in_paragraph=first_in_paragraph,
  68 + set_id=mention_group)
66 69 mentions.append(mention)
67 70  
68 71 return mentions
... ... @@ -73,13 +76,16 @@ def get_words(filepath):
73 76 words = []
74 77 for word in tree.xpath("//word"):
75 78 hasnps = False
76   - if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
  79 + if (('hasnps' in word.attrib and word.attrib['hasnps'] == 'true') or
  80 + ('hasNps' in word.attrib and word.attrib['hasNps'] == 'true')):
77 81 hasnps = True
78 82 lastinsent = False
79   - if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
  83 + if (('lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true') or
  84 + ('lastInSent' in word.attrib and word.attrib['lastInSent'] == 'true')):
80 85 lastinsent = True
81 86 lastinpar = False
82   - if 'lastinpar' in word.attrib and word.attrib['lastinpar'] == 'true':
  87 + if (('lastinpar' in word.attrib and word.attrib['lastinpar'] == 'true') or
  88 + ('lastInPar' in word.attrib and word.attrib['lastInPar'] == 'true')):
83 89 lastinpar = True
84 90 words.append({'id': word.attrib['id'],
85 91 'orth': word.text,
... ...
corneferencer/resolvers/resolve.py
... ... @@ -62,16 +62,14 @@ def incremental(text, threshold):
62 62  
63 63  
64 64 # all2all resolve algorithm
65   -def all2all_v1(text, threshold):
  65 +def all2all_debug(text, threshold):
66 66 last_set_id = 0
67 67 for pos1, mnt1 in enumerate(text.mentions):
68 68 print ('!!!!!!!!!!%s!!!!!!!!!!!' % mnt1.text)
69 69 best_prediction = 0.0
70 70 best_link = None
71   - if mnt1.set:
72   - continue
73 71 for pos2, mnt2 in enumerate(text.mentions):
74   - if (pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
  72 + if ((mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
75 73 ante = mnt1
76 74 ana = mnt2
77 75 if pos2 < pos1:
... ... @@ -80,29 +78,32 @@ def all2all_v1(text, threshold):
80 78 pair_vec = get_pair_vector(ante, ana)
81 79 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
82 80 prediction = NEURAL_MODEL.predict(sample)[0]
83   - print (u'%s >> %f' % (mnt2.text, prediction))
  81 + print (u'mnt2: %s | %s == %s >> %f' % (mnt2.text, ante.text, ana.text, prediction))
84 82 if prediction > threshold and prediction > best_prediction:
85 83 best_prediction = prediction
86 84 best_link = mnt2
87 85 if best_link is not None:
88   - print (u'best: %s' % best_link.text)
89   - if best_link.set:
  86 + print (u'best: %s >> %f, best set: %s, mnt1_set: %s' % (best_link.text, best_prediction, best_link.set, mnt1.set))
  87 + if best_link.set and not mnt1.set:
90 88 mnt1.set = best_link.set
91   - else:
  89 + elif best_link.set and mnt1.set:
  90 + text.merge_sets(best_link.set, mnt1.set)
  91 + elif not best_link.set and not mnt1.set:
92 92 str_set_id = 'set_%d' % last_set_id
93 93 best_link.set = str_set_id
94 94 mnt1.set = str_set_id
95 95 last_set_id += 1
  96 + print (u'best set: %s, mnt1_set: %s' % (best_link.set, mnt1.set))
96 97  
97 98  
98   -def all2all_debug(text, threshold):
  99 +def all2all_v1(text, threshold):
99 100 last_set_id = 0
100 101 for pos1, mnt1 in enumerate(text.mentions):
101   - print ('!!!!!!!!!!%s!!!!!!!!!!!' % mnt1.text)
102 102 best_prediction = 0.0
103 103 best_link = None
104 104 for pos2, mnt2 in enumerate(text.mentions):
105   - if ((mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
  105 + if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
  106 + and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
106 107 ante = mnt1
107 108 ana = mnt2
108 109 if pos2 < pos1:
... ... @@ -111,14 +112,14 @@ def all2all_debug(text, threshold):
111 112 pair_vec = get_pair_vector(ante, ana)
112 113 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
113 114 prediction = NEURAL_MODEL.predict(sample)[0]
114   - print (u'mnt2: %s | %s == %s >> %f' % (mnt2.text, ante.text, ana.text, prediction))
115 115 if prediction > threshold and prediction > best_prediction:
116 116 best_prediction = prediction
117 117 best_link = mnt2
118 118 if best_link is not None:
119   - print (u'best: %s >> %f, best set: %s, mnt1_set: %s' % (best_link.text, best_prediction, best_link.set, mnt1.set))
120 119 if best_link.set and not mnt1.set:
121 120 mnt1.set = best_link.set
  121 + elif not best_link.set and mnt1.set:
  122 + best_link.set = mnt1.set
122 123 elif best_link.set and mnt1.set:
123 124 text.merge_sets(best_link.set, mnt1.set)
124 125 elif not best_link.set and not mnt1.set:
... ... @@ -126,11 +127,11 @@ def all2all_debug(text, threshold):
126 127 best_link.set = str_set_id
127 128 mnt1.set = str_set_id
128 129 last_set_id += 1
129   - print (u'best set: %s, mnt1_set: %s' % (best_link.set, mnt1.set))
130 130  
131 131  
132 132 def all2all(text, threshold):
133 133 last_set_id = 0
  134 + sets = text.get_sets()
134 135 for pos1, mnt1 in enumerate(text.mentions):
135 136 best_prediction = 0.0
136 137 best_link = None
... ... @@ -157,9 +158,13 @@ def all2all(text, threshold):
157 158 text.merge_sets(best_link.set, mnt1.set)
158 159 elif not best_link.set and not mnt1.set:
159 160 str_set_id = 'set_%d' % last_set_id
  161 + while str_set_id in sets:
  162 + last_set_id += 1
  163 + str_set_id = 'set_%d' % last_set_id
160 164 best_link.set = str_set_id
161 165 mnt1.set = str_set_id
162   - last_set_id += 1
  166 + sets[str_set_id] = [best_link, mnt1]
  167 +
163 168  
164 169  
165 170 # entity based resolve algorithm
... ...