Commit db88d6e4b4617b7bf5e3ef3596b91c45b0d8521a

Authored by Bartłomiej Nitoń
1 parent 5be49adf

Add TEI format support.

conf.py
... ... @@ -7,6 +7,7 @@ from gensim.models.word2vec import Word2Vec
7 7  
8 8 CONTEXT = 5
9 9 RANDOM_WORD_VECTORS = True
  10 +CLEAR_INPUT = False
10 11 W2V_SIZE = 50
11 12 W2V_MODEL_NAME = 'w2v_allwiki_nkjpfull_50.model'
12 13  
... ...
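The new CLEAR_INPUT flag lets the resolver ignore any coreference annotation already present in the input and start from a clean slate; both readers consult it, as the mmax.py and tei.py hunks below show. A minimal sketch of the intended effect, with a hypothetical helper that mirrors the MMAX check added below:

    from conf import CLEAR_INPUT

    def initial_set_for(markable):
        # Keep a pre-existing cluster id only when the input is not being cleared.
        group = markable.attrib.get('mention_group', 'empty')
        if group != 'empty' and not CLEAR_INPUT:
            return group
        return ''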
corneferencer/entities.py
... ... @@ -13,6 +13,12 @@ class Text:
13 13 return mnt.set
14 14 return None
15 15  
  16 + def get_mention(self, mnt_id):
  17 + for mnt in self.mentions:
  18 + if mnt.id == mnt_id:
  19 + return mnt
  20 + return None
  21 +
16 22 def get_sets(self):
17 23 sets = {}
18 24 for mnt in self.mentions:
... ... @@ -22,7 +28,6 @@ class Text:
22 28 sets[mnt.set] = [mnt]
23 29 return sets
24 30  
25   -
26 31 def merge_sets(self, set1, set2):
27 32 for mnt in self.mentions:
28 33 if mnt.set == set1:
... ... @@ -38,7 +43,6 @@ class Mention:
38 43 first_in_sentence, first_in_paragraph, set_id=''):
39 44 self.id = mnt_id
40 45 self.set = set_id
41   - self.old_set = ''
42 46 self.text = text
43 47 self.lemmatized_text = lemmatized_text
44 48 self.words = words
... ...
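Text.get_mention is a plain id lookup over the mention list, complementing the existing get_mention_set; the new TEI reader uses it when applying an ann_coreference layer to mentions it has already read. A hypothetical usage sketch (the ids are made up):

    mnt = text.get_mention('mention_42')       # a Mention instance, or None if the id is unknown
    if mnt is not None:
        mnt.set = 'coreference_7'              # attach the mention to a cluster,
        mnt.dominant = 'the dominant phrase'   # as add_coreference() in tei.py does below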
corneferencer/inout/constants.py
1   -INPUT_FORMATS = ['mmax']
  1 +INPUT_FORMATS = ['mmax', 'tei']
... ...
corneferencer/inout/mmax.py
... ... @@ -3,7 +3,7 @@ import shutil
3 3  
4 4 from lxml import etree
5 5  
6   -from conf import CONTEXT, FREQ_LIST
  6 +from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST
7 7 from corneferencer.entities import Mention, Text
8 8  
9 9  
... ... @@ -43,7 +43,7 @@ def read_mentions(mentions_path, words_path):
43 43  
44 44 head = get_head(head_orth, mention_words)
45 45 mention_group = ''
46   - if markable.attrib['mention_group'] != 'empty':
  46 + if markable.attrib['mention_group'] != 'empty' and not CLEAR_INPUT:
47 47 mention_group = markable.attrib['mention_group']
48 48 mention = Mention(mnt_id=markable.attrib['id'],
49 49 text=span_to_text(span, words, 'orth'),
... ... @@ -77,15 +77,15 @@ def get_words(filepath):
77 77 for word in tree.xpath("//word"):
78 78 hasnps = False
79 79 if (('hasnps' in word.attrib and word.attrib['hasnps'] == 'true') or
80   - ('hasNps' in word.attrib and word.attrib['hasNps'] == 'true')):
  80 + ('hasNps' in word.attrib and word.attrib['hasNps'] == 'true')):
81 81 hasnps = True
82 82 lastinsent = False
83 83 if (('lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true') or
84   - ('lastInSent' in word.attrib and word.attrib['lastInSent'] == 'true')):
  84 + ('lastInSent' in word.attrib and word.attrib['lastInSent'] == 'true')):
85 85 lastinsent = True
86 86 lastinpar = False
87 87 if (('lastinpar' in word.attrib and word.attrib['lastinpar'] == 'true') or
88   - ('lastInPar' in word.attrib and word.attrib['lastInPar'] == 'true')):
  88 + ('lastInPar' in word.attrib and word.attrib['lastInPar'] == 'true')):
89 89 lastinpar = True
90 90 words.append({'id': word.attrib['id'],
91 91 'orth': word.text,
... ... @@ -388,10 +388,13 @@ def write_mentions(inpath, outpath, text):
388 388 tree = etree.parse(inpath)
389 389 mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
390 390  
  391 + sets = text.get_sets()
  392 +
391 393 for mnt in mentions:
392 394 mnt_set = text.get_mention_set(mnt.attrib['id'])
393 395 if mnt_set:
394 396 mnt.attrib['mention_group'] = mnt_set
  397 + mnt.attrib['dominant'] = get_dominant(sets[mnt_set])
395 398 else:
396 399 mnt.attrib['mention_group'] = 'empty'
397 400  
... ... @@ -399,3 +402,11 @@ def write_mentions(inpath, outpath, text):
399 402 output_file.write(etree.tostring(tree, pretty_print=True,
400 403 xml_declaration=True, encoding='UTF-8',
401 404 doctype=u'<!DOCTYPE markables SYSTEM "markables.dtd">'))
  405 +
  406 +
  407 +def get_dominant(mentions):
  408 + longest_mention = mentions[0]
  409 + for mnt in mentions:
  410 + if len(mnt.words) > len(longest_mention.words):
  411 + longest_mention = mnt
  412 + return longest_mention.text
... ...
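Besides mention_group, write_mentions now also stores a dominant expression on every markable that belongs to a cluster; get_dominant simply returns the surface text of the longest mention in the set, measured in words, with ties going to the earliest mention. A toy illustration with stubbed mention objects:

    class FakeMention:
        def __init__(self, text):
            self.text = text
            self.words = text.split()

    cluster = [FakeMention('prezydent'), FakeMention('prezydent Warszawy'), FakeMention('on')]
    # get_dominant(cluster) -> 'prezydent Warszawy', the mention with the most words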
corneferencer/inout/tei.py 0 → 100644
  1 +import gzip
  2 +import os
  3 +import shutil
  4 +
  5 +from lxml import etree
  6 +
  7 +from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST
  8 +from corneferencer.entities import Mention, Text
  9 +from corneferencer.utils import eprint
  10 +
  11 +
  12 +NKJP_NS = 'http://www.nkjp.pl/ns/1.0'
  13 +TEI_NS = 'http://www.tei-c.org/ns/1.0'
  14 +XI_NS = 'http://www.w3.org/2001/XInclude'
  15 +XML_NS = 'http://www.w3.org/XML/1998/namespace'
  16 +NSMAP = {None: TEI_NS,
  17 + 'nkjp': NKJP_NS,
  18 + 'xi': XI_NS}
  19 +
  20 +
  21 +def read(inpath):
  22 + textname = os.path.basename(inpath)
  23 +
  24 + text = Text(textname)
  25 +
  26 + # essential layers
  27 + ann_segmentation = os.path.join(inpath, 'ann_segmentation.xml.gz')
  28 + ann_morphosyntax = os.path.join(inpath, 'ann_morphosyntax.xml.gz')
  29 + ann_mentions = os.path.join(inpath, 'ann_mentions.xml.gz')
  30 +
  31 + # additional layers
  32 + ann_coreference = os.path.join(inpath, 'ann_coreference.xml.gz')
  33 +
  34 + if os.path.exists(ann_segmentation):
  35 + pass
  36 + else:
  37 + eprint("Error: missing segmentation layer for text %s!" % textname)
  38 + return None
  39 +
  40 + if os.path.exists(ann_morphosyntax):
  41 + (segments, segments_ids) = read_morphosyntax(ann_morphosyntax)
  42 + else:
  43 + eprint("Error: missing morphosyntax layer for text %s!" % textname)
  44 + return None
  45 +
  46 + if os.path.exists(ann_mentions):
  47 + text.mentions = read_mentions(ann_mentions, segments, segments_ids)
  48 + else:
  49 + eprint("Error: missing mentions layer for text %s!" % textname)
  50 + return None
  51 +
  52 + if os.path.exists(ann_coreference) and not CLEAR_INPUT:
  53 + add_coreference_layer(ann_coreference, text)
  54 +
  55 + return text
  56 +
  57 +
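read() expects each TEI text to be a directory of gzipped NKJP-style layers: segmentation, morphosyntax and mentions are mandatory (note that the segmentation layer is only checked for existence, never parsed), while the coreference layer is optional and is skipped entirely when CLEAR_INPUT is set. A hypothetical pre-flight check along the same lines:

    import os

    REQUIRED_LAYERS = ('ann_segmentation.xml.gz', 'ann_morphosyntax.xml.gz', 'ann_mentions.xml.gz')

    def has_required_layers(text_dir):
        # True when tei.read() would not bail out with a missing-layer error.
        return all(os.path.exists(os.path.join(text_dir, name)) for name in REQUIRED_LAYERS)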
  58 +# morphosyntax
  59 +def read_morphosyntax(ann_archive):
  60 + segments_dict = {}
  61 + segments_ids = []
  62 + ann_file = gzip.open(ann_archive, 'rb')
  63 + parser = etree.XMLParser(encoding="utf-8")
  64 + tree = etree.parse(ann_file, parser)
  65 + body = tree.xpath('//xmlns:body', namespaces={'xmlns': TEI_NS})[0]
  66 +
  67 + paragraphs = body.xpath(".//xmlns:p", namespaces={'xmlns': TEI_NS})
  68 + for par in paragraphs:
  69 + sentences = par.xpath(".//xmlns:s", namespaces={'xmlns': TEI_NS})
  70 + for sent_id, sent in enumerate(sentences):
  71 + segments = sent.xpath(".//xmlns:seg", namespaces={'xmlns': TEI_NS})
  72 + for seg_id, seg in enumerate(segments):
  73 + lastinsent = False
  74 + lastinpar = False
  75 + if seg_id == len(segments) - 1:
  76 + lastinsent = True
  77 + if sent_id == len(sentences) - 1:
  78 + lastinpar = True
  79 + segment = read_segment(seg, lastinsent, lastinpar)
  80 + segments_dict[segment['id']] = segment
  81 + segments_ids.append(segment['id'])
  82 +
  83 + return segments_dict, segments_ids
  84 +
  85 +
  86 +def read_segment(seg, lastinsent, lastinpar):
  87 + hasnps = False
  88 + base = ''
  89 + ctag = ''
  90 + msd = ''
  91 + orth = ''
  92 + idx = seg.attrib['{%s}id' % XML_NS]
  93 + for f in seg.xpath(".//xmlns:f", namespaces={'xmlns': TEI_NS}):
  94 + if f.attrib['name'] == 'orth':
  95 + orth = get_f_string(f)
  96 + elif f.attrib['name'] == 'nps':
  97 + hasnps = get_f_bin_value(f)
  98 + elif f.attrib['name'] == 'interpretation':
  99 + interpretation = get_f_string(f)
  100 + (base, ctag, msd) = parse_interpretation(interpretation)
  101 + return {'id': idx,
  102 + 'orth': orth,
  103 + 'base': base,
  104 + 'hasnps': hasnps,
  105 + 'lastinsent': lastinsent,
  106 + 'lastinpar': lastinpar,
  107 + 'ctag': ctag,
  108 + 'msd': msd,
  109 + 'number': get_number(msd),
  110 + 'person': get_person(msd),
  111 + 'gender': get_gender(msd)}
  112 +
  113 +
  114 +def get_f_string(f):
  115 + return f.getchildren()[0].text
  116 +
  117 +
  118 +def get_f_bin_value(f):
  119 + value = False
  120 + if f.getchildren()[0].attrib['value'] == 'true':
  121 + value = True
  122 + return value
  123 +
  124 +
  125 +def parse_interpretation(interpretation):
  126 + split = interpretation.split(':')
  127 + if interpretation.startswith(':'):
  128 + base = ':'
  129 + ctag = 'interp'
  130 + msd = ''
  131 + elif len(split) > 2:
  132 + base = split[0]
  133 + ctag = split[1]
  134 + msd = ':'.join(split[2:])
  135 + else:
  136 + base = split[0]
  137 + ctag = split[1]
  138 + msd = ''
  139 + return base, ctag, msd
  140 +
  141 +
  142 +def get_gender(msd):
  143 + tags = msd.split(':')
  144 + if 'm1' in tags:
  145 + return 'm1'
  146 + elif 'm2' in tags:
  147 + return 'm2'
  148 + elif 'm3' in tags:
  149 + return 'm3'
  150 + elif 'f' in tags:
  151 + return 'f'
  152 + elif 'n' in tags:
  153 + return 'n'
  154 + else:
  155 + return 'unk'
  156 +
  157 +
  158 +def get_person(msd):
  159 + tags = msd.split(':')
  160 + if 'pri' in tags:
  161 + return 'pri'
  162 + elif 'sec' in tags:
  163 + return 'sec'
  164 + elif 'ter' in tags:
  165 + return 'ter'
  166 + else:
  167 + return 'unk'
  168 +
  169 +
  170 +def get_number(msd):
  171 + tags = msd.split(':')
  172 + if 'sg' in tags:
  173 + return 'sg'
  174 + elif 'pl' in tags:
  175 + return 'pl'
  176 + else:
  177 + return 'unk'
  178 +
  179 +
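parse_interpretation splits an NKJP interpretation string of the form base:ctag:msd, with a special case for entries whose base form is the colon itself; get_gender, get_person and get_number then pick single tags out of the msd part and fall back to 'unk'. Illustrative calls (the strings are typical NKJP values, not taken from this commit):

    parse_interpretation('kot:subst:sg:nom:m2')   # -> ('kot', 'subst', 'sg:nom:m2')
    parse_interpretation('i:conj')                # -> ('i', 'conj', '')
    parse_interpretation('::interp')              # leading ':' -> (':', 'interp', '')
    get_gender('sg:nom:m2')                       # -> 'm2'
    get_number('sg:nom:m2')                       # -> 'sg'
    get_person('sg:nom:m2')                       # -> 'unk' (no person tag present)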
  180 +# mentions
  181 +def read_mentions(ann_archive, segments, segments_ids):
  182 + mentions = []
  183 +
  184 + ann_file = gzip.open(ann_archive, 'rb')
  185 + parser = etree.XMLParser(encoding="utf-8")
  186 + tree = etree.parse(ann_file, parser)
  187 + body = tree.xpath('//xmlns:body', namespaces={'xmlns': TEI_NS})[0]
  188 +
  189 + paragraphs = body.xpath(".//xmlns:p", namespaces={'xmlns': TEI_NS})
  190 + mnt_id = 0
  191 + for par_id, par in enumerate(paragraphs):
  192 + sentences = par.xpath(".//xmlns:s", namespaces={'xmlns': TEI_NS})
  193 + for sent_id, sent in enumerate(sentences):
  194 + mention_nodes = sent.xpath(".//xmlns:seg", namespaces={'xmlns': TEI_NS})
  195 + for mnt in mention_nodes:
  196 + mnt_id += 1
  197 + mention = get_mention(mnt, mnt_id, segments, segments_ids, par_id, sent_id)
  198 + mentions.append(mention)
  199 +
  200 + return mentions
  201 +
  202 +
  203 +def get_mention(mention, mnt_id, segments, segments_ids, paragraph_id, sentence_id):
  204 + idx = mention.attrib['{%s}id' % XML_NS]
  205 +
  206 + mnt_segments = []
  207 + for ptr in mention.xpath(".//xmlns:ptr", namespaces={'xmlns': TEI_NS}):
  208 + seg_id = ptr.attrib['target'].split('#')[-1]
  209 + if not word_to_ignore(segments[seg_id]):
  210 + mnt_segments.append(segments[seg_id])
  211 +
  212 + semh = None
  213 + for f in mention.xpath(".//xmlns:f", namespaces={'xmlns': TEI_NS}):
  214 + if f.attrib['name'] == 'semh':
  215 + semh_id = get_fval(f).split('#')[-1]
  216 + semh = segments[semh_id]
  217 +
  218 + (sent_segments, prec_context, follow_context,
  219 + first_in_sentence, first_in_paragraph) = get_context(mnt_segments, segments, segments_ids)
  220 +
  221 + mention = Mention(mnt_id=idx,
  222 + text=to_text(mnt_segments, 'orth'),
  223 + lemmatized_text=to_text(mnt_segments, 'base'),
  224 + words=mnt_segments,
  225 + span=None,
  226 + head_orth=semh['orth'],
  227 + head=semh,
  228 + node=mention,
  229 + prec_context=prec_context,
  230 + follow_context=follow_context,
  231 + sentence=sent_segments,
  232 + sentence_id=sentence_id,
  233 + paragraph_id=paragraph_id,
  234 + position_in_mentions=mnt_id,
  235 + start_in_words=segments_ids.index(mnt_segments[0]['id']),
  236 + end_in_words=segments_ids.index(mnt_segments[-1]['id']),
  237 + rarest=get_rarest_word(mnt_segments),
  238 + first_in_sentence=first_in_sentence,
  239 + first_in_paragraph=first_in_paragraph,
  240 + set_id=None,
  241 + dominant=None,)
  242 +
  243 + return mention
  244 +
  245 +
  246 +def get_context(mention_words, segments, segments_ids):
  247 + prec_context = []
  248 + follow_context = []
  249 + sentence = []
  250 + first_word = mention_words[0]
  251 + last_word = mention_words[-1]
  252 + first_in_sentence = False
  253 + first_in_paragraph = False
  254 + for idx, morph_id in enumerate(segments_ids):
  255 + word = segments[morph_id]
  256 + if word['id'] == first_word['id']:
  257 + prec_context = get_prec_context(idx, segments, segments_ids)
  258 + if idx == 0 or segments[segments_ids[idx-1]]['lastinsent']:
  259 + first_in_sentence = True
  260 + if idx == 0 or segments[segments_ids[idx-1]]['lastinpar']:
  261 + first_in_paragraph = True
  262 + if word['id'] == last_word['id']:
  263 + follow_context = get_follow_context(idx, segments, segments_ids)
  264 + sentence = get_sentence(idx, segments, segments_ids)
  265 + break
  266 + return (sentence, prec_context, follow_context, first_in_sentence, first_in_paragraph)
  267 +
  268 +
  269 +def get_prec_context(mention_start, segments, segments_ids):
  270 + context = []
  271 + context_start = mention_start - 1
  272 + while context_start >= 0:
  273 + if not word_to_ignore(segments[segments_ids[context_start]]):
  274 + context.append(segments[segments_ids[context_start]])
  275 + if len(context) == CONTEXT:
  276 + break
  277 + context_start -= 1
  278 + context.reverse()
  279 + return context
  280 +
  281 +
  282 +def get_follow_context(mention_end, segments, segments_ids):
  283 + context = []
  284 + context_end = mention_end + 1
  285 + while context_end < len(segments):
  286 + if not word_to_ignore(segments[segments_ids[context_end]]):
  287 + context.append(segments[segments_ids[context_end]])
  288 + if len(context) == CONTEXT:
  289 + break
  290 + context_end += 1
  291 + return context
  292 +
  293 +
  294 +def get_sentence(word_idx, segments, segments_ids):
  295 + sentence_start = get_sentence_start(segments, segments_ids, word_idx)
  296 + sentence_end = get_sentence_end(segments, segments_ids, word_idx)
  297 + sentence = [segments[morph_id] for morph_id in segments_ids[sentence_start:sentence_end + 1]
  298 + if not word_to_ignore(segments[morph_id])]
  299 + return sentence
  300 +
  301 +
  302 +def get_sentence_start(segments, segments_ids, word_idx):
  303 + search_start = word_idx
  304 + while word_idx >= 0:
  305 + if segments[segments_ids[word_idx]]['lastinsent'] and search_start != word_idx:
  306 + return word_idx + 1
  307 + word_idx -= 1
  308 + return 0
  309 +
  310 +
  311 +def get_sentence_end(segments, segments_ids, word_idx):
  312 + while word_idx < len(segments):
  313 + if segments[segments_ids[word_idx]]['lastinsent']:
  314 + return word_idx
  315 + word_idx += 1
  316 + return len(segments) - 1
  317 +
  318 +
  319 +def word_to_ignore(word):
  320 + if word['ctag'] == 'interp':
  321 + return True
  322 + return False
  323 +
  324 +
  325 +def to_text(words, form):
  326 + text = ''
  327 + for idx, word in enumerate(words):
  328 + if word['hasnps'] or idx == 0:
  329 + text += word[form]
  330 + else:
  331 + text += u' %s' % word[form]
  332 + return text
  333 +
  334 +
  335 +def get_fval(f):
  336 + return f.attrib['fVal']
  337 +
  338 +
  339 +def get_rarest_word(words):
  340 + min_freq = 0
  341 + rarest_word = words[0]
  342 + for i, word in enumerate(words):
  343 + word_freq = 0
  344 + if word['base'] in FREQ_LIST:
  345 + word_freq = FREQ_LIST[word['base']]
  346 +
  347 + if i == 0 or word_freq < min_freq:
  348 + min_freq = word_freq
  349 + rarest_word = word
  350 + return rarest_word
  351 +
  352 +
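get_context and its helpers rebuild the CONTEXT-word neighbourhood and the containing sentence for each mention, skipping punctuation via word_to_ignore (ctag == 'interp'), while to_text glues segment forms back together and honours the no-preceding-space flag. A toy example of the joining rule, with segment dicts trimmed to the fields to_text actually reads:

    words = [{'orth': 'Warszawa', 'hasnps': False},
             {'orth': ',',        'hasnps': True},
             {'orth': 'stolica',  'hasnps': False}]
    # to_text(words, 'orth') -> 'Warszawa, stolica'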
  353 +# coreference
  354 +def add_coreference_layer(ann_archive, text):
  355 + ann_file = gzip.open(ann_archive, 'rb')
  356 + parser = etree.XMLParser(encoding="utf-8")
  357 + tree = etree.parse(ann_file, parser)
  358 + body = tree.xpath('//xmlns:body', namespaces={'xmlns': TEI_NS})[0]
  359 +
  360 + parts = body.xpath(".//xmlns:p", namespaces={'xmlns': TEI_NS})
  361 + for par in parts:
  362 + coreferences = par.xpath(".//xmlns:seg", namespaces={'xmlns': TEI_NS})
  363 + for cor in coreferences:
  364 + add_coreference(cor, text)
  365 +
  366 +
  367 +def add_coreference(coref, text):
  368 + idx = coref.attrib['{%s}id' % XML_NS]
  369 +
  370 + coref_type = None
  371 + dominant = None
  372 + for f in coref.xpath(".//xmlns:f", namespaces={'xmlns': TEI_NS}):
  373 + if f.attrib['name'] == 'type':
  374 + coref_type = get_fval(f)
  375 + elif f.attrib['name'] == 'dominant':
  376 + dominant = get_fval(f)
  377 +
  378 + if coref_type == 'ident':
  379 + for ptr in coref.xpath(".//xmlns:ptr", namespaces={'xmlns': TEI_NS}):
  380 + mnt_id = ptr.attrib['target'].split('#')[-1]
  381 + mention = text.get_mention(mnt_id)
  382 + mention.set = idx
  383 + mention.dominant = dominant
  384 +
  385 +
  386 +# write
  387 +def write(inpath, outpath, text):
  388 +
  389 + if not os.path.exists(outpath):
  390 + os.mkdir(outpath)
  391 +
  392 + for filename in os.listdir(inpath):
  393 + if not filename.startswith('ann_coreference'):
  394 + layer_inpath = os.path.join(inpath, filename)
  395 + layer_outpath = os.path.join(outpath, filename)
  396 + copy_layer(layer_inpath, layer_outpath)
  397 +
  398 + coref_outpath = os.path.join(outpath, 'ann_coreference.xml.gz')
  399 + write_coreference(coref_outpath, text)
  400 +
  401 +
  402 +def copy_layer(src, dest):
  403 + shutil.copyfile(src, dest)
  404 +
  405 +
  406 +def write_coreference(outpath, text):
  407 + root, tei = write_header()
  408 + write_body(tei, text)
  409 +
  410 + with gzip.open(outpath, 'wb') as output_file:
  411 + output_file.write(etree.tostring(root, pretty_print=True,
  412 + xml_declaration=True, encoding='UTF-8'))
  413 +
  414 +
  415 +def write_header():
  416 + root = etree.Element('teiCorpus', nsmap=NSMAP)
  417 +
  418 + corpus_xinclude = etree.SubElement(root, etree.QName(XI_NS, 'include'))
  419 + corpus_xinclude.attrib['href'] = 'PCC_header.xml'
  420 +
  421 + tei = etree.SubElement(root, 'TEI')
  422 + tei_xinclude = etree.SubElement(tei, etree.QName(XI_NS, 'include'))
  423 + tei_xinclude.attrib['href'] = 'header.xml'
  424 +
  425 + return root, tei
  426 +
  427 +
  428 +def write_body(tei, text):
  429 + text_node = etree.SubElement(tei, 'text')
  430 + body = etree.SubElement(text_node, 'body')
  431 + p = etree.SubElement(body, 'p')
  432 +
  433 + sets = text.get_sets()
  434 + for set_id in sets:
  435 + comment_text = create_set_comment(sets[set_id])
  436 + p.append(etree.Comment(comment_text))
  437 +
  438 + seg = etree.SubElement(p, 'seg')
  439 + seg.attrib[etree.QName(XML_NS, 'id')] = set_id.replace('set', 'coreference')
  440 +
  441 + fs = etree.SubElement(seg, 'fs')
  442 + fs.attrib['type'] = 'coreference'
  443 +
  444 + f_type = etree.SubElement(fs, 'f')
  445 + f_type.attrib['name'] = 'type'
  446 + f_type.attrib['fVal'] = 'ident'
  447 +
  448 + dominant = get_dominant(sets[set_id])
  449 + f_dominant = etree.SubElement(fs, 'f')
  450 + f_dominant.attrib['name'] = 'dominant'
  451 + f_dominant.attrib['fVal'] = dominant
  452 +
  453 + for mnt in sets[set_id]:
  454 + ptr = etree.SubElement(seg, 'ptr')
  455 + ptr.attrib['target'] = 'ann_mentions.xml#%s' % mnt.id
  456 +
  457 +
  458 +def create_set_comment(mentions):
  459 + mentions_orths = [mnt.text for mnt in mentions]
  460 + return ' %s ' % '; '.join(mentions_orths)
  461 +
  462 +
  463 +def get_dominant(mentions):
  464 + longest_mention = mentions[0]
  465 + for mnt in mentions:
  466 + if len(mnt.words) > len(longest_mention.words):
  467 + longest_mention = mnt
  468 + return longest_mention.text
... ...
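On the output side, write() copies every layer except the coreference one verbatim and regenerates ann_coreference.xml.gz from the resolved clusters: for each set, write_body emits a comment listing the member mentions, followed by a seg holding the type/dominant feature structure and one ptr per mention. The emitted body looks roughly like this (ids and orths are illustrative):

    <!-- prezydent Warszawy; prezydent; on -->
    <seg xml:id="coreference_12">
      <fs type="coreference">
        <f name="type" fVal="ident"/>
        <f name="dominant" fVal="prezydent Warszawy"/>
      </fs>
      <ptr target="ann_mentions.xml#mention_3"/>
      <ptr target="ann_mentions.xml#mention_8"/>
    </seg>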
corneferencer/main.py
... ... @@ -7,7 +7,7 @@ from natsort import natsorted
7 7 sys.path.append(os.path.abspath(os.path.join('..')))
8 8  
9 9 import conf
10   -from inout import mmax
  10 +from inout import mmax, tei
11 11 from inout.constants import INPUT_FORMATS
12 12 from resolvers import resolve
13 13 from resolvers.constants import RESOLVERS
... ... @@ -26,7 +26,8 @@ def main():
26 26 resolver = args.resolver
27 27 if conf.NEURAL_MODEL_ARCHITECTURE == 'siamese':
28 28 resolver = conf.NEURAL_MODEL_ARCHITECTURE
29   - eprint ("Warning: Using %s resolver because of selected neural model architecture!" % conf.NEURAL_MODEL_ARCHITECTURE)
  29 + eprint("Warning: Using %s resolver because of selected neural model architecture!" %
  30 + conf.NEURAL_MODEL_ARCHITECTURE)
30 31 process_texts(args.input, args.output, args.format, resolver, args.threshold)
31 32  
32 33  
... ... @@ -39,15 +40,16 @@ def parse_arguments():
39 40 dest='output', default='',
40 41 help='output path; if not specified writes output to standard output')
41 42 parser.add_argument('-f', '--format', type=str, action='store',
42   - dest='format', default='mmax',
43   - help='input format; default: mmax')
  43 + dest='format', default=INPUT_FORMATS[0],
  44 + help='input format; default: %s; possibilities: %s'
  45 + % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
44 46 parser.add_argument('-r', '--resolver', type=str, action='store',
45   - dest='resolver', default='incremental',
46   - help='resolve algorithm; default: incremental; possibilities: %s'
47   - % ', '.join(RESOLVERS))
  47 + dest='resolver', default=RESOLVERS[0],
  48 + help='resolve algorithm; default: %s; possibilities: %s'
  49 + % (RESOLVERS[0], ', '.join(RESOLVERS)))
48 50 parser.add_argument('-t', '--threshold', type=float, action='store',
49   - dest='threshold', default=0.001,
50   - help='threshold; default: 0.001')
  51 + dest='threshold', default=0.85,
  52 + help='threshold; default: 0.85')
51 53  
52 54 args = parser.parse_args()
53 55 return args
... ... @@ -57,7 +59,7 @@ def process_texts(inpath, outpath, informat, resolver, threshold):
57 59 if os.path.isdir(inpath):
58 60 process_directory(inpath, outpath, informat, resolver, threshold)
59 61 elif os.path.isfile(inpath):
60   - process_file(inpath, outpath, informat, resolver, threshold)
  62 + process_text(inpath, outpath, informat, resolver, threshold)
61 63 else:
62 64 eprint("Error: Specified input does not exist!")
63 65  
... ... @@ -73,10 +75,10 @@ def process_directory(inpath, outpath, informat, resolver, threshold):
73 75 textname = os.path.splitext(os.path.basename(filename))[0]
74 76 textoutput = os.path.join(outpath, textname)
75 77 textinput = os.path.join(inpath, filename)
76   - process_file(textinput, textoutput, informat, resolver, threshold)
  78 + process_text(textinput, textoutput, informat, resolver, threshold)
77 79  
78 80  
79   -def process_file(inpath, outpath, informat, resolver, threshold):
  81 +def process_text(inpath, outpath, informat, resolver, threshold):
80 82 basename = os.path.basename(inpath)
81 83 if informat == 'mmax' and basename.endswith('.mmax'):
82 84 print (basename)
... ... @@ -92,6 +94,20 @@ def process_file(inpath, outpath, informat, resolver, threshold):
92 94 elif resolver == 'all2all':
93 95 resolve.all2all(text, threshold)
94 96 mmax.write(inpath, outpath, text)
  97 + elif informat == 'tei':
  98 + print (basename)
  99 + text = tei.read(inpath)
  100 + if resolver == 'incremental':
  101 + resolve.incremental(text, threshold)
  102 + elif resolver == 'entity_based':
  103 + resolve.entity_based(text, threshold)
  104 + elif resolver == 'closest':
  105 + resolve.closest(text, threshold)
  106 + elif resolver == 'siamese':
  107 + resolve.siamese(text, threshold)
  108 + elif resolver == 'all2all':
  109 + resolve.all2all(text, threshold)
  110 + tei.write(inpath, outpath, text)
95 111  
96 112  
97 113 if __name__ == '__main__':
... ...
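main.py now dispatches on the format, sending 'tei' inputs through tei.read / tei.write, and takes its argparse defaults from the constants lists, so the default resolver becomes RESOLVERS[0] ('all2all' after the reordering in resolvers/constants.py below) and the default threshold rises from 0.001 to 0.85. TEI processing is selected with '--format tei'; a hedged programmatic equivalent of process_text for a single TEI text, with placeholder paths:

    from inout import tei
    from resolvers import resolve

    text = tei.read('corpus/some_text')          # directory holding the gzipped layers; None if a layer is missing
    resolve.incremental(text, 0.85)              # or all2all / entity_based / closest / siamese
    tei.write('corpus/some_text', 'out/some_text', text)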
corneferencer/resolvers/constants.py
1 1 # -*- coding: utf-8 -*-
2 2  
3   -RESOLVERS = ['entity_based', 'incremental', 'closest', 'siamese', 'all2all']
  3 +RESOLVERS = ['all2all', 'entity_based', 'incremental', 'closest', 'siamese']
4 4  
5 5 NOUN_TAGS = ['subst', 'ger', 'depr']
6 6 PPRON_TAGS = ['ppron12', 'ppron3']
... ...
corneferencer/resolvers/features.py
... ... @@ -229,7 +229,7 @@ def ante_contains_rarest_from_ana(ante, ana):
229 229 def agreement(ante, ana, tag_name):
230 230 agr_vec = [0] * 3
231 231 if (ante.head is None or ana.head is None or
232   - ante.head[tag_name] == 'unk' or ana.head[tag_name] == 'unk'):
  232 + ante.head[tag_name] == 'unk' or ana.head[tag_name] == 'unk'):
233 233 agr_vec[2] = 1
234 234 elif ante.head[tag_name] == ana.head[tag_name]:
235 235 agr_vec[0] = 1
... ... @@ -279,10 +279,10 @@ def same_paragraph(ante, ana):
279 279 def flat_gender_agreement(ante, ana):
280 280 agr_vec = [0] * 3
281 281 if (ante.head is None or ana.head is None or
282   - ante.head['gender'] == 'unk' or ana.head['gender'] == 'unk'):
  282 + ante.head['gender'] == 'unk' or ana.head['gender'] == 'unk'):
283 283 agr_vec[2] = 1
284 284 elif (ante.head['gender'] == ana.head['gender'] or
285   - (ante.head['gender'] in constants.MASCULINE_TAGS and ana.head['gender'] in constants.MASCULINE_TAGS)):
  285 + (ante.head['gender'] in constants.MASCULINE_TAGS and ana.head['gender'] in constants.MASCULINE_TAGS)):
286 286 agr_vec[0] = 1
287 287 else:
288 288 agr_vec[1] = 1
... ... @@ -314,13 +314,13 @@ def abbrev2(ante, ana):
314 314 def string_kernel(ante, ana):
315 315 s1 = ante.text
316 316 s2 = ana.text
317   - return SK(s1, s2) / (math.sqrt(SK(s1, s1) * SK(s2, s2)))
  317 + return sk(s1, s2) / (math.sqrt(sk(s1, s1) * sk(s2, s2)))
318 318  
319 319  
320 320 def head_string_kernel(ante, ana):
321 321 s1 = ante.head_orth
322 322 s2 = ana.head_orth
323   - return SK(s1, s2) / (math.sqrt(SK(s1, s1) * SK(s2, s2)))
  323 + return sk(s1, s2) / (math.sqrt(sk(s1, s1) * sk(s2, s2)))
324 324  
325 325  
326 326 def wordnet_synonyms(ante, ana):
... ... @@ -443,22 +443,22 @@ def samesent_anapron_antefirstinpar(ante, ana):
443 443  
444 444 def samesent_antefirstinpar_personnumbermatch(ante, ana):
445 445 if (same_sentence(ante, ana) and ante.first_in_paragraph
446   - and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
  446 + and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
447 447 return 1
448 448 return 0
449 449  
450 450  
451 451 def adjsent_anapron_adjmen_personnumbermatch(ante, ana):
452 452 if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana)
453   - and ana.position_in_mentions - ante.position_in_mentions == 1
454   - and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
  453 + and ana.position_in_mentions - ante.position_in_mentions == 1
  454 + and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
455 455 return 1
456 456 return 0
457 457  
458 458  
459 459 def adjsent_anapron_adjmen(ante, ana):
460 460 if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana)
461   - and ana.position_in_mentions - ante.position_in_mentions == 1):
  461 + and ana.position_in_mentions - ante.position_in_mentions == 1):
462 462 return 1
463 463 return 0
464 464  
... ... @@ -535,16 +535,16 @@ def get_abbrev(mention):
535 535 return abbrev
536 536  
537 537  
538   -def SK(s1, s2):
539   - LAMBDA = 0.4
  538 +def sk(s1, s2):
  539 + lam = 0.4
540 540  
541 541 p = len(s1)
542 542 if len(s2) < len(s1):
543 543 p = len(s2)
544 544  
545 545 h, w = len(s1)+1, len(s2)+1
546   - DPS = [[0.0] * w for i in range(h)]
547   - DP = [[0.0] * w for i in range(h)]
  546 + dps = [[0.0] * w for i in range(h)]
  547 + dp = [[0.0] * w for i in range(h)]
548 548  
549 549 kernel_mat = [0.0] * (len(s1) + 1)
550 550  
... ... @@ -555,35 +555,35 @@ def SK(s1, s2):
555 555 if j == 0:
556 556 continue
557 557 if s1[i-1] == s2[j-1]:
558   - DPS[i][j] = LAMBDA * LAMBDA
559   - kernel_mat[0] += DPS[i][j]
  558 + dps[i][j] = lam * lam
  559 + kernel_mat[0] += dps[i][j]
560 560 else:
561   - DPS[i][j] = 0.0
  561 + dps[i][j] = 0.0
562 562  
563   - for l in range(p):
564   - if l == 0:
  563 + for m in range(p):
  564 + if m == 0:
565 565 continue
566 566  
567   - kernel_mat[l] = 0.0
  567 + kernel_mat[m] = 0.0
568 568 for j in range(len(s2)+1):
569   - DP[l-1][j] = 0.0
  569 + dp[m-1][j] = 0.0
570 570  
571 571 for i in range(len(s1)+1):
572   - DP[i][l-1] = 0.0
  572 + dp[i][m-1] = 0.0
573 573  
574 574 for i in range(len(s1)+1):
575   - if i < l:
  575 + if i < m:
576 576 continue
577 577 for j in range(len(s2)+1):
578   - if j < l:
  578 + if j < m:
579 579 continue
580   - DP[i][j] = DPS[i][j] + LAMBDA * DP[i - 1][j] + LAMBDA * DP[i][j - 1] - LAMBDA * LAMBDA * DP[i - 1][j - 1]
  580 + dp[i][j] = dps[i][j] + lam * dp[i - 1][j] + lam * dp[i][j - 1] - lam * lam * dp[i - 1][j - 1]
581 581  
582 582 if s1[i-1] == s2[j-1]:
583   - DPS[i][j] = LAMBDA * LAMBDA * DP[i - 1][j - 1]
584   - kernel_mat[l] += DPS[i][j]
  583 + dps[i][j] = lam * lam * dp[i - 1][j - 1]
  584 + kernel_mat[m] += dps[i][j]
585 585  
586   - K = 0.0
587   - for l in range(p):
588   - K += kernel_mat[l]
589   - return K
  586 + k = 0.0
  587 + for i in range(p):
  588 + k += kernel_mat[i]
  589 + return k
... ...
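The string-kernel helper is renamed from SK to sk and its upper-case locals are lowercased, with no change in behaviour; string_kernel and head_string_kernel still divide by the geometric mean of the self-kernels. A quick sketch of that normalisation (lambda stays 0.4 inside sk):

    import math

    def normalized_sk(s1, s2):
        # Mirrors string_kernel()/head_string_kernel(): identical strings score 1.0,
        # less similar strings score lower.
        return sk(s1, s2) / math.sqrt(sk(s1, s1) * sk(s2, s2))

    normalized_sk('kot', 'kot')    # == 1.0
    normalized_sk('kot', 'kotek')  # < 1.0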
corneferencer/resolvers/resolve.py
... ... @@ -65,11 +65,10 @@ def incremental(text, threshold):
65 65 def all2all_debug(text, threshold):
66 66 last_set_id = 0
67 67 for pos1, mnt1 in enumerate(text.mentions):
68   - print ('!!!!!!!!!!%s!!!!!!!!!!!' % mnt1.text)
69 68 best_prediction = 0.0
70 69 best_link = None
71 70 for pos2, mnt2 in enumerate(text.mentions):
72   - if ((mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
  71 + if (mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2):
73 72 ante = mnt1
74 73 ana = mnt2
75 74 if pos2 < pos1:
... ... @@ -78,12 +77,10 @@ def all2all_debug(text, threshold):
78 77 pair_vec = get_pair_vector(ante, ana)
79 78 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
80 79 prediction = NEURAL_MODEL.predict(sample)[0]
81   - print (u'mnt2: %s | %s == %s >> %f' % (mnt2.text, ante.text, ana.text, prediction))
82 80 if prediction > threshold and prediction > best_prediction:
83 81 best_prediction = prediction
84 82 best_link = mnt2
85 83 if best_link is not None:
86   - print (u'best: %s >> %f, best set: %s, mnt1_set: %s' % (best_link.text, best_prediction, best_link.set, mnt1.set))
87 84 if best_link.set and not mnt1.set:
88 85 mnt1.set = best_link.set
89 86 elif best_link.set and mnt1.set:
... ... @@ -93,7 +90,6 @@ def all2all_debug(text, threshold):
93 90 best_link.set = str_set_id
94 91 mnt1.set = str_set_id
95 92 last_set_id += 1
96   - print (u'best set: %s, mnt1_set: %s' % (best_link.set, mnt1.set))
97 93  
98 94  
99 95 def all2all_v1(text, threshold):
... ... @@ -103,7 +99,7 @@ def all2all_v1(text, threshold):
103 99 best_link = None
104 100 for pos2, mnt2 in enumerate(text.mentions):
105 101 if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
106   - and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
  102 + and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
107 103 ante = mnt1
108 104 ana = mnt2
109 105 if pos2 < pos1:
... ... @@ -137,7 +133,7 @@ def all2all(text, threshold):
137 133 best_link = None
138 134 for pos2, mnt2 in enumerate(text.mentions):
139 135 if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
140   - and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
  136 + and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
141 137 ante = mnt1
142 138 ana = mnt2
143 139 if pos2 < pos1:
... ... @@ -166,7 +162,6 @@ def all2all(text, threshold):
166 162 sets[str_set_id] = [best_link, mnt1]
167 163  
168 164  
169   -
170 165 # entity based resolve algorithm
171 166 def entity_based(text, threshold):
172 167 sets = []
... ...
corneferencer/resolvers/vectors.py
... ... @@ -24,10 +24,10 @@ def get_mention_features(mention):
24 24 vec.extend(features.mention_vec(mention))
25 25 vec.extend(features.sentence_vec(mention))
26 26  
27   - # cechy uzupelniajace
  27 + # complementary features
28 28 vec.extend(features.mention_type(mention))
29 29  
30   - # cechy uzupelniajace 2
  30 + # complementary features 2
31 31 vec.append(features.is_first_second_person(mention))
32 32 vec.append(features.is_demonstrative(mention))
33 33 vec.append(features.is_demonstrative_nominal(mention))
... ... @@ -50,7 +50,7 @@ def get_pair_features(ante, ana):
50 50 vec.append(features.exact_match(ante, ana))
51 51 vec.append(features.base_match(ante, ana))
52 52  
53   - # cechy uzupelniajace
  53 + # complementary features
54 54 vec.append(features.ante_contains_rarest_from_ana(ante, ana))
55 55 vec.extend(features.agreement(ante, ana, 'gender'))
56 56 vec.extend(features.agreement(ante, ana, 'number'))
... ... @@ -59,7 +59,7 @@ def get_pair_features(ante, ana):
59 59 vec.append(features.same_sentence(ante, ana))
60 60 vec.append(features.same_paragraph(ante, ana))
61 61  
62   - # cechy uzupelniajace 2
  62 + # complementary features 2
63 63 vec.append(features.neighbouring_sentence(ante, ana))
64 64 vec.append(features.cousin_sentence(ante, ana))
65 65 vec.append(features.distant_sentence(ante, ana))
... ... @@ -79,7 +79,7 @@ def get_pair_features(ante, ana):
79 79 vec.append(features.wikipedia_mutual_link(ante, ana))
80 80 vec.append(features.wikipedia_redirect(ante, ana))
81 81  
82   - # combined
  82 + # combined features
83 83 vec.append(features.samesent_anapron_antefirstinpar(ante, ana))
84 84 vec.append(features.samesent_antefirstinpar_personnumbermatch(ante, ana))
85 85 vec.append(features.adjsent_anapron_adjmen_personnumbermatch(ante, ana))
... ...
corneferencer/utils.py
... ... @@ -72,7 +72,6 @@ def initialize_siamese_model(number_of_features, path_to_model):
72 72  
73 73  
74 74 def create_base_network(input_dim):
75   - '''Base network to be shared'''
76 75 seq = Sequential()
77 76  
78 77 seq.add(Dense(1000, input_shape=(input_dim,), activation='relu'))
... ... @@ -94,13 +93,10 @@ def euclidean_distance(vects):
94 93  
95 94 def eucl_dist_output_shape(shapes):
96 95 shape1, shape2 = shapes
97   - return (shape1[0], 1)
  96 + return shape1[0], 1
98 97  
99 98  
100 99 def contrastive_loss(y_true, y_pred):
101   - '''Contrastive loss from Hadsell-et-al.'06
102   - http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
103   - '''
104 100 margin = 1
105 101 return K.mean(y_true * K.square(y_pred) + (1 - y_true) * K.square(K.maximum(margin - y_pred, 0)))
106 102  
... ... @@ -125,9 +121,9 @@ def load_one2many_map(map_path):
125 121 jmap_annotations = pobj.__dict__['annotations']
126 122 jmap_annotations_count = len(jmap_annotations)
127 123 for i in range(jmap_annotations_count):
128   - if i%2 == 1:
129   - mapped_elements = set(jmap_annotations[i+1].__dict__['annotations'])
130   - this_map[jmap_annotations[i]] = mapped_elements
  124 + if i % 2 == 1:
  125 + mapped_elements = set(jmap_annotations[i+1].__dict__['annotations'])
  126 + this_map[jmap_annotations[i]] = mapped_elements
131 127 return this_map
132 128  
133 129  
... ... @@ -138,7 +134,7 @@ def load_one2one_map(map_path):
138 134 jmap_annotations = pobj.__dict__['annotations']
139 135 jmap_annotations_count = len(jmap_annotations)
140 136 for i in range(jmap_annotations_count):
141   - if i%2 == 1:
142   - element = jmap_annotations[i+1]
143   - this_map[jmap_annotations[i]] = element
  137 + if i % 2 == 1:
  138 + element = jmap_annotations[i+1]
  139 + this_map[jmap_annotations[i]] = element
144 140 return this_map
... ...
requirements.txt
1 1 lxml
2 2 natsort
3 3 gensim
  4 +keras
  5 +tensorflow
4 6 numpy
  7 +javaobj-py3
... ...
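keras and tensorflow back the neural scoring models built in corneferencer/utils.py, and javaobj-py3 appears to be what the load_one2many_map / load_one2one_map helpers use to read serialized Java maps; listing them here makes a fresh setup a single step:

    pip install -r requirements.txt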