Commit 4a097c4f123cc34ee07d870361bad7a817a3faeb

Authored by Bartłomiej Nitoń
1 parent a5beabe2

Corneferencer alpha version

conf.py 0 → 100644
  1 +import os
  2 +
  3 +from gensim.models.word2vec import Word2Vec
  4 +
  5 +from corneferencer.utils import initialize_neural_model
  6 +
  7 +
  8 +CONTEXT = 5
  9 +THRESHOLD = 0.5
  10 +RANDOM_WORD_VECTORS = True
  11 +W2V_SIZE = 50
  12 +W2V_MODEL_NAME = 'w2v_allwiki_nkjpfull_50.model'
  13 +
  14 +NUMBER_OF_FEATURES = 1126
  15 +NEURAL_MODEL_NAME = 'weights_2017_05_10.h5'
  16 +
  17 +
  18 +# do not change the following
  19 +W2V_MODEL_PATH = os.path.join(os.path.dirname(__file__), 'models', W2V_MODEL_NAME)
  20 +W2V_MODEL = Word2Vec.load(W2V_MODEL_PATH)
  21 +
  22 +NEURAL_MODEL_PATH = os.path.join(os.path.dirname(__file__), 'models', NEURAL_MODEL_NAME)
  23 +NEURAL_MODEL = initialize_neural_model(NUMBER_OF_FEATURES)
  24 +NEURAL_MODEL.load_weights(NEURAL_MODEL_PATH)  # load the pretrained weights listed above
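Both model paths in conf.py are resolved relative to the file itself, so the repository is expected to contain a models/ directory next to conf.py holding the word2vec model and the network weights. That directory is not part of this commit; the layout below is an assumption, shown only for reference:

    .
    ├── conf.py
    └── models/
        ├── w2v_allwiki_nkjpfull_50.model
        └── weights_2017_05_10.h5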
... ...
corneferencer/entities.py 0 → 100644
  1 +from corneferencer.resolvers.vectors import get_mention_features
  2 +
  3 +
  4 +class Text:
  5 +
  6 + def __init__(self, text_id):
  7 + self.__id = text_id
  8 + self.mentions = []
  9 +
  10 + def get_mention_set(self, mnt_id):
  11 + for mnt in self.mentions:
  12 + if mnt.id == mnt_id:
  13 + return mnt.set
  14 + return None
  15 +
  16 +
  17 +class Mention:
  18 +
  19 + def __init__(self, mnt_id, text, lemmatized_text, words, span,
  20 + head_orth, head_base, dominant, node, prec_context,
  21 + follow_context, sentence, position_in_mentions,
  22 + start_in_words, end_in_words):
  23 + self.id = mnt_id
  24 + self.set = ''
  25 + self.text = text
  26 + self.lemmatized_text = lemmatized_text
  27 + self.words = words
  28 + self.span = span
  29 + self.head_orth = head_orth
  30 + self.head_base = head_base
  31 + self.dominant = dominant
  32 + self.node = node
  33 + self.prec_context = prec_context
  34 + self.follow_context = follow_context
  35 + self.sentence = sentence
  36 + self.position_in_mentions = position_in_mentions
  37 + self.start_in_words = start_in_words
  38 + self.end_in_words = end_in_words
  39 + self.features = get_mention_features(self)
... ...
corneferencer/core.py renamed to corneferencer/inout/__init__.py
corneferencer/inout/constants.py 0 → 100644
  1 +INPUT_FORMATS = ['mmax']
... ...
corneferencer/inout/mmax.py 0 → 100644
  1 +import os
  2 +import shutil
  3 +
  4 +from lxml import etree
  5 +
  6 +from conf import CONTEXT
  7 +from corneferencer.entities import Mention, Text
  8 +
  9 +
  10 +def read(inpath):
  11 + textname = os.path.splitext(os.path.basename(inpath))[0]
  12 + textdir = os.path.dirname(inpath)
  13 +
  14 + mentions_path = os.path.join(textdir, '%s_mentions.xml' % textname)
  15 + words_path = os.path.join(textdir, '%s_words.xml' % textname)
  16 +
  17 + text = Text(textname)
  18 + mentions = read_mentions(mentions_path, words_path)
  19 + text.mentions = mentions
  20 + return text
  21 +
  22 +
  23 +def read_mentions(mentions_path, words_path):
  24 + mentions = []
  25 + mentions_tree = etree.parse(mentions_path)
  26 + markables = mentions_tree.xpath("//ns:markable",
  27 + namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
  28 + words = get_words(words_path)
  29 +
  30 + for idx, markable in enumerate(markables):
  31 + span = markable.attrib['span']
  32 +
  33 + dominant = ''
  34 + if 'dominant' in markable.attrib:
  35 + dominant = markable.attrib['dominant']
  36 +
  37 + head_orth = markable.attrib['mention_head']
  38 + mention_words = span_to_words(span, words)
  39 +
  40 + (prec_context, follow_context, sentence,
  41 + mnt_start_position, mnt_end_position) = get_context(mention_words, words)
  42 +
  43 + head_base = get_head_base(head_orth, mention_words)
  44 + mention = Mention(mnt_id=markable.attrib['id'],
  45 + text=span_to_text(span, words, 'orth'),
  46 + lemmatized_text=span_to_text(span, words, 'base'),
  47 + words=mention_words,
  48 + span=span,
  49 + head_orth=head_orth,
  50 + head_base=head_base,
  51 + dominant=dominant,
  52 + node=markable,
  53 + prec_context=prec_context,
  54 + follow_context=follow_context,
  55 + sentence=sentence,
  56 + position_in_mentions=idx,
  57 + start_in_words=mnt_start_position,
  58 + end_in_words=mnt_end_position)
  59 + mentions.append(mention)
  60 +
  61 + return mentions
  62 +
  63 +
  64 +def get_words(filepath):
  65 + tree = etree.parse(filepath)
  66 + words = []
  67 + for word in tree.xpath("//word"):
  68 + hasnps = False
  69 + if 'hasnps' in word.attrib and word.attrib['hasnps'] == 'true':
  70 + hasnps = True
  71 + lastinsent = False
  72 + if 'lastinsent' in word.attrib and word.attrib['lastinsent'] == 'true':
  73 + lastinsent = True
  74 + words.append({'id': word.attrib['id'],
  75 + 'orth': word.text,
  76 + 'base': word.attrib['base'],
  77 + 'hasnps': hasnps,
  78 + 'lastinsent': lastinsent,
  79 + 'ctag': word.attrib['ctag']})
  80 + return words
  81 +
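get_words expects each word element to provide id, base and ctag attributes, with optional hasnps and lastinsent flags, and takes the surface form from the element text. A minimal, purely illustrative *_words.xml fragment consistent with that code (not taken from any corpus) could look like:

    <?xml version="1.0" encoding="UTF-8"?>
    <words>
      <word id="word_1" base="Jan" ctag="subst" lastinsent="false">Jan</word>
      <word id="word_2" base="spać" ctag="fin" lastinsent="false">śpi</word>
      <word id="word_3" base="." ctag="interp" hasnps="true" lastinsent="true">.</word>
    </words>

The ctag="interp" token would be skipped by word_to_ignore when mentions and contexts are built.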
  82 +
  83 +def span_to_words(span, words):
  84 + fragments = span.split(',')
  85 + mention_parts = []
  86 + for fragment in fragments:
  87 + mention_parts.extend(fragment_to_words(fragment, words))
  88 + return mention_parts
  89 +
  90 +
  91 +def fragment_to_words(fragment, words):
  92 + mention_parts = []
  93 + if '..' in fragment:
  94 + mention_parts.extend(get_multiword(fragment, words))
  95 + else:
  96 + mention_parts.extend(get_word(fragment, words))
  97 + return mention_parts
  98 +
  99 +
  100 +def get_multiword(fragment, words):
  101 + mention_parts = []
  102 + boundaries = fragment.split('..')
  103 + start_id = boundaries[0]
  104 + end_id = boundaries[1]
  105 + in_string = False
  106 + for word in words:
  107 + if word['id'] == start_id:
  108 + in_string = True
  109 + if in_string and not word_to_ignore(word):
  110 + mention_parts.append(word)
  111 + if word['id'] == end_id:
  112 + break
  113 + return mention_parts
  114 +
  115 +
  116 +def get_word(word_id, words):
  117 + for word in words:
  118 + if word['id'] == word_id:
  119 + if not word_to_ignore(word):
  120 + return [word]
  121 + else:
  122 + return []
  123 + return []
  124 +
  125 +
  126 +def word_to_ignore(word):
  127 + if word['ctag'] == 'interp':
  128 + return True
  129 + return False
  130 +
  131 +
  132 +def get_context(mention_words, words):
  133 + prec_context = []
  134 + follow_context = []
  135 + sentence = []
  136 + mnt_start_position = -1
  137 + mnt_end_position = -1
  138 + first_word = mention_words[0]
  139 + last_word = mention_words[-1]
  140 + for idx, word in enumerate(words):
  141 + if word['id'] == first_word['id']:
  142 + prec_context = get_prec_context(idx, words)
  143 + mnt_start_position = get_mention_start(first_word, words)
  144 + if word['id'] == last_word['id']:
  145 + follow_context = get_follow_context(idx, words)
  146 + sentence = get_sentence(idx, words)
  147 + mnt_end_position = get_mention_end(last_word, words)
  148 + break
  149 + return prec_context, follow_context, sentence, mnt_start_position, mnt_end_position
  150 +
  151 +
  152 +def get_prec_context(mention_start, words):
  153 + context = []
  154 + context_start = mention_start - 1
  155 + while context_start >= 0:
  156 + if not word_to_ignore(words[context_start]):
  157 + context.append(words[context_start])
  158 + if len(context) == CONTEXT:
  159 + break
  160 + context_start -= 1
  161 + context.reverse()
  162 + return context
  163 +
  164 +
  165 +def get_mention_start(first_word, words):
  166 + start = 0
  167 + for word in words:
  168 + if not word_to_ignore(word):
  169 + start += 1
  170 + if word['id'] == first_word['id']:
  171 + break
  172 + return start
  173 +
  174 +
  175 +def get_mention_end(last_word, words):
  176 + end = 0
  177 + for word in words:
  178 + if not word_to_ignore(word):
  179 + end += 1
  180 + if word['id'] == last_word['id']:
  181 + break
  182 + return end
  183 +
  184 +
  185 +def get_follow_context(mention_end, words):
  186 + context = []
  187 + context_end = mention_end + 1
  188 + while context_end < len(words):
  189 + if not word_to_ignore(words[context_end]):
  190 + context.append(words[context_end])
  191 + if len(context) == CONTEXT:
  192 + break
  193 + context_end += 1
  194 + return context
  195 +
  196 +
  197 +def get_sentence(word_idx, words):
  198 + sentence_start = get_sentence_start(words, word_idx)
  199 + sentence_end = get_sentence_end(words, word_idx)
  200 + sentence = [word for word in words[sentence_start:sentence_end + 1] if not word_to_ignore(word)]
  201 + return sentence
  202 +
  203 +
  204 +def get_sentence_start(words, word_idx):
  205 + search_start = word_idx
  206 + while word_idx >= 0:
  207 + if words[word_idx]['lastinsent'] and search_start != word_idx:
  208 + return word_idx + 1
  209 + word_idx -= 1
  210 + return 0
  211 +
  212 +
  213 +def get_sentence_end(words, word_idx):
  214 + while word_idx < len(words):
  215 + if words[word_idx]['lastinsent']:
  216 + return word_idx
  217 + word_idx += 1
  218 + return len(words) - 1
  219 +
  220 +
  221 +def get_head_base(head_orth, words):
  222 + for word in words:
  223 + if word['orth'].lower() == head_orth.lower():
  224 + return word['base']
  225 + return None
  226 +
  227 +
  228 +def span_to_text(span, words, form):
  229 + fragments = span.split(',')
  230 + mention_parts = []
  231 + for fragment in fragments:
  232 + mention_parts.append(fragment_to_text(fragment, words, form))
  233 + return u' [...] '.join(mention_parts)
  234 +
  235 +
  236 +def fragment_to_text(fragment, words, form):
  237 + if '..' in fragment:
  238 + text = get_multiword_text(fragment, words, form)
  239 + else:
  240 + text = get_one_word_text(fragment, words, form)
  241 + return text
  242 +
  243 +
  244 +def get_multiword_text(fragment, words, form):
  245 + mention_parts = []
  246 + boundaries = fragment.split('..')
  247 + start_id = boundaries[0]
  248 + end_id = boundaries[1]
  249 + in_string = False
  250 + for word in words:
  251 + if word['id'] == start_id:
  252 + in_string = True
  253 + if in_string and not word_to_ignore(word):
  254 + mention_parts.append(word)
  255 + if word['id'] == end_id:
  256 + break
  257 + return to_text(mention_parts, form)
  258 +
  259 +
  260 +def to_text(words, form):
  261 + text = ''
  262 + for idx, word in enumerate(words):
  263 + if word['hasnps'] or idx == 0:
  264 + text += word[form]
  265 + else:
  266 + text += u' %s' % word[form]
  267 + return text
  268 +
  269 +
  270 +def get_one_word_text(word_id, words, form):
  271 + this_word = next(word for word in words if word['id'] == word_id)
  272 + return this_word[form]
  273 +
  274 +
  275 +def write(inpath, outpath, text):
  276 + textname = os.path.splitext(os.path.basename(inpath))[0]
  277 + intextdir = os.path.dirname(inpath)
  278 + outtextdir = os.path.dirname(outpath)
  279 +
  280 + in_mmax_path = os.path.join(intextdir, '%s.mmax' % textname)
  281 + out_mmax_path = os.path.join(outtextdir, '%s.mmax' % textname)
  282 + copy_mmax(in_mmax_path, out_mmax_path)
  283 +
  284 + in_words_path = os.path.join(intextdir, '%s_words.xml' % textname)
  285 + out_words_path = os.path.join(outtextdir, '%s_words.xml' % textname)
  286 + copy_words(in_words_path, out_words_path)
  287 +
  288 + in_mentions_path = os.path.join(intextdir, '%s_mentions.xml' % textname)
  289 + out_mentions_path = os.path.join(outtextdir, '%s_mentions.xml' % textname)
  290 + write_mentions(in_mentions_path, out_mentions_path, text)
  291 +
  292 +
  293 +def copy_mmax(src, dest):
  294 + shutil.copyfile(src, dest)
  295 +
  296 +
  297 +def copy_words(src, dest):
  298 + shutil.copyfile(src, dest)
  299 +
  300 +
  301 +def write_mentions(inpath, outpath, text):
  302 + tree = etree.parse(inpath)
  303 + mentions = tree.xpath("//ns:markable", namespaces={'ns': 'www.eml.org/NameSpaces/mention'})
  304 +
  305 + for mnt in mentions:
  306 + mnt_set = text.get_mention_set(mnt.attrib['id'])
  307 + if mnt_set:
  308 + mnt.attrib['mention_group'] = mnt_set
  309 + else:
  310 + mnt.attrib['mention_group'] = 'empty'
  311 +
  312 + with open(outpath, 'wb') as output_file:
  313 + output_file.write(etree.tostring(tree, pretty_print=True,
  314 + xml_declaration=True, encoding='UTF-8',
  315 + doctype=u'<!DOCTYPE markables SYSTEM "markables.dtd">'))
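write_mentions only rewrites the mention_group attribute of each markable: mentions assigned to a coreference set get its id, singletons get 'empty'. An illustrative output fragment (hypothetical values) would be:

    <?xml version="1.0" encoding="UTF-8"?>
    <!DOCTYPE markables SYSTEM "markables.dtd">
    <markables xmlns="www.eml.org/NameSpaces/mention">
      <markable id="markable_1" span="word_1..word_2" mention_head="Jan" mention_group="set_1"/>
      <markable id="markable_2" span="word_5" mention_head="on" mention_group="empty"/>
    </markables>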
... ...
corneferencer/main.py 0 → 100644
  1 +import os
  2 +import sys
  3 +
  4 +from argparse import ArgumentParser
  5 +from natsort import natsorted
  6 +
  7 +sys.path.append(os.path.abspath(os.path.join('..')))
  8 +
  9 +from inout import mmax
  10 +from inout.constants import INPUT_FORMATS
  11 +from resolvers import resolve
  12 +from resolvers.constants import RESOLVERS
  13 +from utils import eprint
  14 +
  15 +
  16 +def main():
  17 + args = parse_arguments()
  18 + if not args.input:
  19 + eprint("Error: Input file(s) not specified!")
  20 + elif args.resolver not in RESOLVERS:
  21 + eprint("Error: Unknown resolve algorithm!")
  22 + elif args.format not in INPUT_FORMATS:
  23 + eprint("Error: Unknown input file format!")
  24 + else:
  25 + process_texts(args.input, args.output, args.format, args.resolver)
  26 +
  27 +
  28 +def parse_arguments():
  29 + parser = ArgumentParser(description='Corneferencer: coreference resolver using neural nets.')
  30 + parser.add_argument('-i', '--input', type=str, action='store',
  31 + dest='input', default='',
  32 + help='input file or dir path')
  33 + parser.add_argument('-o', '--output', type=str, action='store',
  34 + dest='output', default='',
  35 + help='output path; if not specified writes output to standard output')
  36 + parser.add_argument('-f', '--format', type=str, action='store',
  37 + dest='format', default='mmax',
  38 + help='input format; default: mmax')
  39 + parser.add_argument('-r', '--resolver', type=str, action='store',
  40 + dest='resolver', default='incremental',
  41 + help='resolve algorithm; default: incremental; possibilities: %s'
  42 + % ', '.join(RESOLVERS))
  43 +
  44 + args = parser.parse_args()
  45 + return args
  46 +
  47 +
  48 +def process_texts(inpath, outpath, informat, resolver):
  49 + if os.path.isdir(inpath):
  50 + process_directory(inpath, outpath, informat, resolver)
  51 + elif os.path.isfile(inpath):
  52 + process_file(inpath, outpath, informat, resolver)
  53 + else:
  54 + eprint("Error: Specified input does not exist!")
  55 +
  56 +
  57 +def process_directory(inpath, outpath, informat, resolver):
  58 + inpath = os.path.abspath(inpath)
  59 + outpath = os.path.abspath(outpath)
  60 +
  61 + files = os.listdir(inpath)
  62 + files = natsorted(files)
  63 +
  64 + for filename in files:
  65 + textname = os.path.splitext(os.path.basename(filename))[0]
  66 + textoutput = os.path.join(outpath, textname)
  67 + textinput = os.path.join(inpath, filename)
  68 + process_file(textinput, textoutput, informat, resolver)
  69 +
  70 +
  71 +def process_file(inpath, outpath, informat, resolver):
  72 + basename = os.path.basename(inpath)
  73 + if informat == 'mmax' and basename.endswith('.mmax'):
  74 + print(basename)
  75 + text = mmax.read(inpath)
  76 + if resolver == 'incremental':
  77 + resolve.incremental(text)
  78 + elif resolver == 'entity_based':
  79 + resolve.entity_based(text)
  80 + mmax.write(inpath, outpath, text)
  81 +
  82 +
  83 +if __name__ == '__main__':
  84 + main()
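Given the sys.path.append('..') line above, main.py appears intended to be run from inside the corneferencer/ directory so that conf.py at the repository root becomes importable. An illustrative invocation (paths are hypothetical):

    cd corneferencer
    python main.py -i ../texts_mmax -o ../texts_out -f mmax -r incremental

With -i pointing at a directory, every *.mmax file found there is processed and the corresponding .mmax, _words.xml and _mentions.xml files are written to the output directory under the same base name.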
... ...
corneferencer/readers/__init__.py deleted
corneferencer/entities/__init__.py renamed to corneferencer/resolvers/__init__.py
corneferencer/resolvers/constants.py 0 → 100644
  1 +RESOLVERS = ['entity_based', 'incremental']
... ...
corneferencer/resolvers/features.py 0 → 100644
  1 +import numpy
  2 +import random
  3 +
  4 +from conf import RANDOM_WORD_VECTORS, W2V_MODEL, W2V_SIZE
  5 +
  6 +
  7 +# mention features
  8 +def head_vec(mention):
  9 + return list(get_wv(W2V_MODEL, mention.head_base))
  10 +
  11 +
  12 +def first_word_vec(mention):
  13 + return list(get_wv(W2V_MODEL, mention.words[0]['base']))
  14 +
  15 +
  16 +def last_word_vec(mention):
  17 + return list(get_wv(W2V_MODEL, mention.words[-1]['base']))
  18 +
  19 +
  20 +def first_after_vec(mention):
  21 + if len(mention.follow_context) > 0:
  22 + vec = list(get_wv(W2V_MODEL, mention.follow_context[0]['base']))
  23 + else:
  24 + vec = [0.0] * W2V_SIZE
  25 + return vec
  26 +
  27 +
  28 +def second_after_vec(mention):
  29 + if len(mention.follow_context) > 1:
  30 + vec = list(get_wv(W2V_MODEL, mention.follow_context[1]['base']))
  31 + else:
  32 + vec = [0.0] * W2V_SIZE
  33 + return vec
  34 +
  35 +
  36 +def first_before_vec(mention):
  37 + if len(mention.prec_context) > 0:
  38 + vec = list(get_wv(W2V_MODEL, mention.prec_context[-1]['base']))
  39 + else:
  40 + vec = [0.0] * W2V_SIZE
  41 + return vec
  42 +
  43 +
  44 +def second_before_vec(mention):
  45 + if len(mention.prec_context) > 1:
  46 + vec = list(get_wv(W2V_MODEL, mention.prec_context[-2]['base']))
  47 + else:
  48 + vec = [0.0] * W2V_SIZE
  49 + return vec
  50 +
  51 +
  52 +def preceding_context_vec(mention):
  53 + return list(get_context_vec(mention.prec_context, W2V_MODEL))
  54 +
  55 +
  56 +def following_context_vec(mention):
  57 + return list(get_context_vec(mention.follow_context, W2V_MODEL))
  58 +
  59 +
  60 +def mention_vec(mention):
  61 + return list(get_context_vec(mention.words, W2V_MODEL))
  62 +
  63 +
  64 +def sentence_vec(mention):
  65 + return list(get_context_vec(mention.sentence, W2V_MODEL))
  66 +
  67 +
  68 +# pair features
  69 +def distances_vec(ante, ana):
  70 + vec = []
  71 +
  72 + mnts_intersect = pair_intersect(ante, ana)
  73 +
  74 + words_dist = [0] * 11
  75 + words_bucket = 0
  76 + if mnts_intersect != 1:
  77 + words_bucket = get_distance_bucket(ana.start_in_words - ante.end_in_words - 1)
  78 + words_dist[words_bucket] = 1
  79 + vec.extend(words_dist)
  80 +
  81 + mentions_dist = [0] * 11
  82 + mentions_bucket = 0
  83 + if mnts_intersect != 1:
  84 + mentions_bucket = get_distance_bucket(ana.position_in_mentions - ante.position_in_mentions - 1)
  85 + if words_bucket == 10:
  86 + mentions_bucket = 10
  87 + mentions_dist[mentions_bucket] = 1
  88 + vec.extend(mentions_dist)
  89 +
  90 + vec.append(mnts_intersect)
  91 +
  92 + return vec
  93 +
  94 +
  95 +def pair_intersect(ante, ana):
  96 + for ante_word in ante.words:
  97 + for ana_word in ana.words:
  98 + if ana_word['id'] == ante_word['id']:
  99 + return 1
  100 + return 0
  101 +
  102 +
  103 +def head_match(ante, ana):
  104 + if ante.head_orth.lower() == ana.head_orth.lower():
  105 + return 1
  106 + return 0
  107 +
  108 +
  109 +def exact_match(ante, ana):
  110 + if ante.text.lower() == ana.text.lower():
  111 + return 1
  112 + return 0
  113 +
  114 +
  115 +def base_match(ante, ana):
  116 + if ante.lemmatized_text.lower() == ana.lemmatized_text.lower():
  117 + return 1
  118 + return 0
  119 +
  120 +
  121 +# supporting functions
  122 +def get_wv(model, lemma, use_random_vec=True):
  123 + vec = None
  124 + if use_random_vec:
  125 + vec = random_vec()
  126 + try:
  127 + vec = model.wv[lemma]
  128 + except KeyError:
  129 + pass
  130 + except TypeError:
  131 + pass
  132 + return vec
  133 +
  134 +
  135 +def random_vec():
  136 + return numpy.asarray([random.uniform(-0.25, 0.25) for i in range(0, W2V_SIZE)], dtype=numpy.float32)
  137 +
  138 +
  139 +def get_context_vec(words, model):
  140 + vec = numpy.zeros(W2V_SIZE, dtype=numpy.float32)
  141 + unknown_count = 0
  142 + if len(words) != 0:
  143 + for word in words:
  144 + word_vec = get_wv(model, word['base'], RANDOM_WORD_VECTORS)
  145 + if word_vec is None:
  146 + unknown_count += 1
  147 + else:
  148 + vec += word_vec
  149 + significant_words = len(words) - unknown_count
  150 + if significant_words != 0:
  151 + vec = vec / float(significant_words)
  152 + else:
  153 + vec = random_vec()
  154 + return vec
  155 +
  156 +
  157 +def get_distance_bucket(distance):
  158 + if 0 <= distance <= 4:
  159 + return distance
  160 + elif 5 <= distance <= 7:
  161 + return 5
  162 + elif 8 <= distance <= 15:
  163 + return 6
  164 + elif 16 <= distance <= 31:
  165 + return 7
  166 + elif 32 <= distance <= 63:
  167 + return 8
  168 + elif distance >= 64:
  169 + return 9
  170 + return 10
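get_distance_bucket maps a non-negative word or mention distance onto one of the 11 one-hot slots used in distances_vec; the final return 10 only triggers for negative distances. A quick summary of the mapping:

    distance:  0  1  2  3  4  5-7  8-15  16-31  32-63  >=64  <0
    bucket:    0  1  2  3  4   5     6      7      8     9   10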
... ...
corneferencer/resolvers/resolve.py 0 → 100644
  1 +from conf import NEURAL_MODEL, THRESHOLD
  2 +from corneferencer.resolvers.vectors import create_pair_vector
  3 +
  4 +
  5 +# incremental resolve algorithm
  6 +def incremental(text):
  7 + last_set_id = 1
  8 + for i, ana in enumerate(text.mentions):
  9 + if i > 0:
  10 + best_prediction = 0.0
  11 + best_ante = None
  12 + for ante in reversed(text.mentions[:i]):  # candidate antecedents: preceding mentions, nearest first
  13 + pair_vec = create_pair_vector(ante, ana)
  14 + prediction = NEURAL_MODEL.predict(pair_vec)
  15 + accuracy = prediction[0]
  16 + if accuracy > THRESHOLD and accuracy > best_prediction:
  17 + best_prediction = accuracy
  18 + best_ante = ante
  19 + if best_ante is not None:
  20 + if best_ante.set:
  21 + ana.set = best_ante.set
  22 + else:
  23 + str_set_id = 'set_%d' % last_set_id
  24 + best_ante.set = str_set_id
  25 + ana.set = str_set_id
  26 + last_set_id += 1
  27 +
  28 +
  29 +# entity based resolve algorithm
  30 +def entity_based(text):
  31 + sets = []
  32 + last_set_id = 1
  33 + for i, ana in enumerate(text.mentions):
  34 + if i > 0:
  35 + best_fit = get_best_set(sets, ana)
  36 + if best_fit is not None:
  37 + ana.set = best_fit['set_id']
  38 + best_fit['mentions'].append(ana)
  39 + else:
  40 + str_set_id = 'set_%d' % last_set_id
  41 + sets.append({'set_id': str_set_id,
  42 + 'mentions': [ana]})
  43 + ana.set = str_set_id
  44 + last_set_id += 1
  45 + else:
  46 + str_set_id = 'set_%d' % last_set_id
  47 + sets.append({'set_id': str_set_id,
  48 + 'mentions': [ana]})
  49 + ana.set = str_set_id
  50 + last_set_id += 1
  51 +
  52 + remove_singletons(sets)
  53 +
  54 +
  55 +def get_best_set(sets, ana):
  56 + best_prediction = 0.0
  57 + best_set = None
  58 + for s in sets:
  59 + accuracy = predict_set(s['mentions'], ana)
  60 + if accuracy > THRESHOLD and accuracy >= best_prediction:
  61 + best_prediction = accuracy
  62 + best_set = s
  63 + return best_set
  64 +
  65 +
  66 +def predict_set(mentions, ana):
  67 + accuracy_sum = 0.0
  68 + for mnt in mentions:
  69 + pair_vec = create_pair_vector(mnt, ana)
  70 + prediction = NEURAL_MODEL.predict(pair_vec)
  71 + accuracy = prediction[0]
  72 + accuracy_sum += accuracy
  73 + return accuracy_sum / float(len(mentions))
  74 +
  75 +
  76 +def remove_singletons(sets):
  77 + for s in sets:
  78 + if len(s['mentions']) == 1:
  79 + s['mentions'][0].set = ''
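The incremental resolver links each anaphor to the highest-scoring preceding mention whose score exceeds THRESHOLD, opening a new set when the chosen antecedent has none yet. A self-contained toy sketch of that selection rule, with a stub scorer standing in for the Keras model (names are illustrative, not part of the module):

    def pick_antecedent(candidates, ana, score, threshold=0.5):
        # candidates: preceding mentions, nearest first; score: callable returning a float in [0, 1]
        best_ante, best_score = None, 0.0
        for ante in candidates:
            s = score(ante, ana)
            if s > threshold and s > best_score:
                best_ante, best_score = ante, s
        return best_ante

    # dummy scorer: only 'm2' looks like a plausible antecedent of 'm3'
    print(pick_antecedent(['m2', 'm1'], 'm3', lambda a, b: 0.7 if a == 'm2' else 0.2))  # -> m2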
... ...
corneferencer/resolvers/vectors.py 0 → 100644
  1 +import numpy
  2 +
  3 +from corneferencer.resolvers import features
  4 +
  5 +# the network expects input of shape (1, NUMBER_OF_FEATURES), hence the extra list wrapping below
  6 +def create_pair_vector(ante, ana):
  7 + vec = []
  8 + # ante_features = get_mention_features(ante)
  9 + # vec.extend(ante_features)
  10 + # ana_features = get_mention_features(ana)
  11 + # vec.extend(ana_features)
  12 + vec.extend(ante.features)
  13 + vec.extend(ana.features)
  14 + pair_features = get_pair_features(ante, ana)
  15 + vec.extend(pair_features)
  16 + return numpy.asarray([vec], dtype=numpy.float32)
  17 +
  18 +
  19 +def get_mention_features(mention):
  20 + vec = []
  21 + vec.extend(features.head_vec(mention))
  22 + vec.extend(features.first_word_vec(mention))
  23 + vec.extend(features.last_word_vec(mention))
  24 + vec.extend(features.first_after_vec(mention))
  25 + vec.extend(features.second_after_vec(mention))
  26 + vec.extend(features.first_before_vec(mention))
  27 + vec.extend(features.second_before_vec(mention))
  28 + vec.extend(features.preceding_context_vec(mention))
  29 + vec.extend(features.following_context_vec(mention))
  30 + vec.extend(features.mention_vec(mention))
  31 + vec.extend(features.sentence_vec(mention))
  32 + return vec
  33 +
  34 +
  35 +def get_pair_features(ante, ana):
  36 + vec = []
  37 + vec.extend(features.distances_vec(ante, ana))
  38 + vec.append(features.head_match(ante, ana))
  39 + vec.append(features.exact_match(ante, ana))
  40 + vec.append(features.base_match(ante, ana))
  41 + return vec
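These helpers fix the input width of the network: each mention contributes 11 word2vec-based vectors of W2V_SIZE = 50 values (head, first/last word, two words before and after, preceding/following context, mention and sentence averages), and the pair features add two 11-slot distance one-hots, an intersection flag and three match flags. That is consistent with NUMBER_OF_FEATURES = 1126 in conf.py:

    W2V_SIZE = 50
    mention_features = 11 * W2V_SIZE      # 550 values per mention
    pair_features = 11 + 11 + 1 + 3       # distances, intersection flag, head/exact/base match
    assert 2 * mention_features + pair_features == 1126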
... ...
corneferencer/utils.py 0 → 100644
  1 +from __future__ import print_function
  2 +
  3 +import sys
  4 +
  5 +from keras.models import Model
  6 +from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization
  7 +
  8 +
  9 +def eprint(*args, **kwargs):
  10 + print(*args, file=sys.stderr, **kwargs)
  11 +
  12 +
  13 +def initialize_neural_model(number_of_features):
  14 + inputs = Input(shape=(number_of_features,))
  15 + output_from_1st_layer = Dense(1000, activation='relu')(inputs)
  16 + output_from_1st_layer = Dropout(0.5)(output_from_1st_layer)
  17 + output_from_1st_layer = BatchNormalization()(output_from_1st_layer)
  18 + output_from_2nd_layer = Dense(500, activation='relu')(output_from_1st_layer)
  19 + output_from_2nd_layer = Dropout(0.5)(output_from_2nd_layer)
  20 + output_from_2nd_layer = BatchNormalization()(output_from_2nd_layer)
  21 + output = Dense(1, activation='sigmoid')(output_from_2nd_layer)
  22 +
  23 + model = Model(inputs, output)
  24 + model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['accuracy'])
  25 + return model
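A minimal shape check of the network defined above (assuming keras with a working backend is installed; the input width matches NUMBER_OF_FEATURES from conf.py):

    import numpy
    from corneferencer.utils import initialize_neural_model

    model = initialize_neural_model(1126)
    score = model.predict(numpy.zeros((1, 1126), dtype=numpy.float32))
    print(score.shape)  # (1, 1): one sigmoid score per mention pair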
... ...
requirements.txt
  1 +lxml
  2 +natsort
  3 +gensim
  4 +numpy
  5 +keras
... ...
setup.py deleted