From c29c36cde4fa9ea17a625fa2b0e8ea5c1192ec25 Mon Sep 17 00:00:00 2001
From: bniton <bartek.niton@gmail.com>
Date: Fri, 30 Nov 2018 00:23:46 +0100
Subject: [PATCH] Add prepare data script and other minor improvements.

---
 conf.py                             |   1 -
 corneferencer/entities.py           |   7 +++++--
 corneferencer/inout/mmax.py         |  19 +++++++++----------
 corneferencer/inout/tei.py          |  17 ++++++++++-------
 corneferencer/main.py               |  52 +++++++++++++++++++++++++++++-----------------------
 corneferencer/prepare_data.py       | 140 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 corneferencer/resolvers/features.py | 210 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------------------------------------------------------------------------------------------
 corneferencer/resolvers/resolve.py  | 113 +++++++++++++++++++++++------------------------------------------------------------------------------------------
 corneferencer/utils.py              |   1 -
 9 files changed, 321 insertions(+), 239 deletions(-)
 create mode 100644 corneferencer/prepare_data.py

diff --git a/conf.py b/conf.py
index 648e472..b4c42e8 100644
--- a/conf.py
+++ b/conf.py
@@ -30,7 +30,6 @@ W2V_MODEL_PATH = os.path.join(MAIN_PATH, 'models', W2V_MODEL_NAME)
 W2V_MODEL = Word2Vec.load(W2V_MODEL_PATH)
 
 NEURAL_MODEL_PATH = os.path.join(MAIN_PATH, 'models', NEURAL_MODEL_NAME)
-NEURAL_MODEL = utils.initialize_neural_model(NEURAL_MODEL_ARCHITECTURE, NUMBER_OF_FEATURES, NEURAL_MODEL_PATH)
 
 FREQ_LIST_PATH = os.path.join(MAIN_PATH, 'freq', FREQ_LIST_NAME)
 FREQ_LIST = utils.load_freq_list(FREQ_LIST_PATH)
diff --git a/corneferencer/entities.py b/corneferencer/entities.py
index c1e1509..cc52b52 100644
--- a/corneferencer/entities.py
+++ b/corneferencer/entities.py
@@ -1,4 +1,4 @@
-from corneferencer.resolvers.vectors import get_mention_features
+from corneferencer.resolvers import vectors
 
 
 class Text:
@@ -19,6 +19,9 @@ class Text:
                 return mnt
         return None
 
+    def get_mentions(self):
+        return self.mentions
+
     def get_sets(self):
         sets = {}
         for mnt in self.mentions:
@@ -62,4 +65,4 @@ class Mention:
         self.sentence_id = sentence_id
         self.first_in_sentence = first_in_sentence
         self.first_in_paragraph = first_in_paragraph
-        self.features = get_mention_features(self)
+        self.features = vectors.get_mention_features(self)
diff --git a/corneferencer/inout/mmax.py b/corneferencer/inout/mmax.py
index e30f551..62feb49 100644
--- a/corneferencer/inout/mmax.py
+++ b/corneferencer/inout/mmax.py
@@ -3,11 +3,11 @@ import shutil
 
 from lxml import etree
 
-from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST
+import conf
 from corneferencer.entities import Mention, Text
 
 
-def read(inpath):
+def read(inpath, clear_mentions=conf.CLEAR_INPUT):
     textname = os.path.splitext(os.path.basename(inpath))[0]
     textdir = os.path.dirname(inpath)
 
@@ -15,11 +15,11 @@ def read(inpath):
     words_path = os.path.join(textdir, '%s_words.xml' % textname)
 
     text = Text(textname)
-    text.mentions = read_mentions(mentions_path, words_path)
+    text.mentions = read_mentions(mentions_path, words_path, clear_mentions)
     return text
 
 
-def read_mentions(mentions_path, words_path):
+def read_mentions(mentions_path, words_path, clear_mentions=conf.CLEAR_INPUT):
     mentions = []
     mentions_tree = etree.parse(mentions_path)
     markables = mentions_tree.xpath("//ns:markable",
@@ -43,7 +43,7 @@ def read_mentions(mentions_path, words_path):
 
         head = get_head(head_orth, mention_words)
         mention_group = ''
-        if markable.attrib['mention_group'] != 'empty' and not CLEAR_INPUT:
+        if markable.attrib['mention_group'] != 'empty' and not clear_mentions:
             mention_group = markable.attrib['mention_group']
         mention = Mention(mnt_id=markable.attrib['id'],
                           text=span_to_text(span, words, 'orth'),
@@ -189,7 +189,7 @@ def get_prec_context(mention_start, words):
     while context_start >= 0:
         if not word_to_ignore(words[context_start]):
             context.append(words[context_start])
-        if len(context) == CONTEXT:
+        if len(context) == conf.CONTEXT:
             break
         context_start -= 1
     context.reverse()
@@ -222,7 +222,7 @@ def get_follow_context(mention_end, words):
     while context_end < len(words):
         if not word_to_ignore(words[context_end]):
             context.append(words[context_end])
-        if len(context) == CONTEXT:
+        if len(context) == conf.CONTEXT:
             break
         context_end += 1
     return context
@@ -349,9 +349,8 @@ def get_rarest_word(words):
     rarest_word = words[0]
     for i, word in enumerate(words):
         word_freq = 0
-        if word['base'] in FREQ_LIST:
-            word_freq = FREQ_LIST[word['base']]
-
+        if word['base'] in conf.FREQ_LIST:
+            word_freq = conf.FREQ_LIST[word['base']]
         if i == 0 or word_freq < min_freq:
             min_freq = word_freq
             rarest_word = word
diff --git a/corneferencer/inout/tei.py b/corneferencer/inout/tei.py
index 81d4afc..b166719 100644
--- a/corneferencer/inout/tei.py
+++ b/corneferencer/inout/tei.py
@@ -4,7 +4,7 @@ import shutil
 
 from lxml import etree
 
-from conf import CLEAR_INPUT, CONTEXT, FREQ_LIST
+import conf
 from corneferencer.entities import Mention, Text
 from corneferencer.utils import eprint
 
@@ -18,7 +18,7 @@ NSMAP = {None: TEI_NS,
          'xi': XI_NS}
 
 
-def read(inpath):
+def read(inpath, clear_mentions=conf.CLEAR_INPUT):
     textname = os.path.basename(inpath)
 
     text = Text(textname)
@@ -49,7 +49,7 @@ def read(inpath):
         eprint("Error: missing mentions layer for text %s!" % textname)
         return None
 
-    if os.path.exists(ann_coreference) and not CLEAR_INPUT:
+    if os.path.exists(ann_coreference) and not clear_mentions:
         add_coreference_layer(ann_coreference, text)
 
     return text
@@ -215,6 +215,9 @@ def get_mention(mention, mnt_id, segments, segments_ids, paragraph_id, sentence_
             semh_id = get_fval(f).split('#')[-1]
             semh = segments[semh_id]
 
+    if len(mnt_segments) == 0:
+        mnt_segments.append(semh)
+
     (sent_segments, prec_context, follow_context,
      first_in_sentence, first_in_paragraph) = get_context(mnt_segments, segments, segments_ids)
 
@@ -272,7 +275,7 @@ def get_prec_context(mention_start, segments, segments_ids):
     while context_start >= 0:
         if not word_to_ignore(segments[segments_ids[context_start]]):
             context.append(segments[segments_ids[context_start]])
-        if len(context) == CONTEXT:
+        if len(context) == conf.CONTEXT:
             break
         context_start -= 1
     context.reverse()
@@ -285,7 +288,7 @@ def get_follow_context(mention_end, segments, segments_ids):
     while context_end < len(segments):
         if not word_to_ignore(segments[segments_ids[context_end]]):
             context.append(segments[segments_ids[context_end]])
-        if len(context) == CONTEXT:
+        if len(context) == conf.CONTEXT:
             break
         context_end += 1
     return context
@@ -341,8 +344,8 @@ def get_rarest_word(words):
     rarest_word = words[0]
     for i, word in enumerate(words):
         word_freq = 0
-        if word['base'] in FREQ_LIST:
-            word_freq = FREQ_LIST[word['base']]
+        if word['base'] in conf.FREQ_LIST:
+            word_freq = conf.FREQ_LIST[word['base']]
 
         if i == 0 or word_freq < min_freq:
             min_freq = word_freq
diff --git a/corneferencer/main.py b/corneferencer/main.py
index f9aebc7..9f85c5d 100644
--- a/corneferencer/main.py
+++ b/corneferencer/main.py
@@ -4,9 +4,11 @@ import sys
 from argparse import ArgumentParser
 from natsort import natsorted
 
-sys.path.append(os.path.abspath(os.path.join('..')))
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
 
 import conf
+import utils
 from inout import mmax, tei
 from inout.constants import INPUT_FORMATS
 from resolvers import resolve
@@ -27,22 +29,25 @@ def main():
         if conf.NEURAL_MODEL_ARCHITECTURE == 'siamese':
             resolver = conf.NEURAL_MODEL_ARCHITECTURE
             eprint("Warning: Using %s resolver because of selected neural model architecture!" %
-                conf.NEURAL_MODEL_ARCHITECTURE)
-        process_texts(args.input, args.output, args.format, resolver, args.threshold)
+                   conf.NEURAL_MODEL_ARCHITECTURE)
+        process_texts(args.input, args.output, args.format, resolver, args.threshold, args.model)
 
 
 def parse_arguments():
     parser = ArgumentParser(description='Corneferencer: coreference resolver using neural nets.')
+    parser.add_argument('-f', '--format', type=str, action='store',
+                        dest='format', default=INPUT_FORMATS[0],
+                        help='input format; default: %s; possibilities: %s'
+                             % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
     parser.add_argument('-i', '--input', type=str, action='store',
                         dest='input', default='',
                         help='input file or dir path')
+    parser.add_argument('-m', '--model', type=str, action='store',
+                        dest='model', default='',
+                        help='neural model path; default: %s' % conf.NEURAL_MODEL_PATH)
     parser.add_argument('-o', '--output', type=str, action='store',
                         dest='output', default='',
                         help='output path; if not specified writes output to standard output')
-    parser.add_argument('-f', '--format', type=str, action='store',
-                        dest='format', default=INPUT_FORMATS[0],
-                        help='input format; default: %s; possibilities: %s'
-                             % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
     parser.add_argument('-r', '--resolver', type=str, action='store',
                         dest='resolver', default=RESOLVERS[0],
                         help='resolve algorithm; default: %s; possibilities: %s'
@@ -55,16 +60,17 @@ def parse_arguments():
     return args
 
 
-def process_texts(inpath, outpath, informat, resolver, threshold):
+def process_texts(inpath, outpath, informat, resolver, threshold, model_path):
+    model = utils.initialize_neural_model(conf.NEURAL_MODEL_ARCHITECTURE, conf.NUMBER_OF_FEATURES, model_path)
     if os.path.isdir(inpath):
-        process_directory(inpath, outpath, informat, resolver, threshold)
+        process_directory(inpath, outpath, informat, resolver, threshold, model)
     elif os.path.isfile(inpath):
-        process_text(inpath, outpath, informat, resolver, threshold)
+        process_text(inpath, outpath, informat, resolver, threshold, model)
     else:
         eprint("Error: Specified input does not exist!")
 
 
-def process_directory(inpath, outpath, informat, resolver, threshold):
+def process_directory(inpath, outpath, informat, resolver, threshold, model):
     inpath = os.path.abspath(inpath)
     outpath = os.path.abspath(outpath)
 
@@ -75,38 +81,38 @@ def process_directory(inpath, outpath, informat, resolver, threshold):
         textname = os.path.splitext(os.path.basename(filename))[0]
         textoutput = os.path.join(outpath, textname)
         textinput = os.path.join(inpath, filename)
-        process_text(textinput, textoutput, informat, resolver, threshold)
+        process_text(textinput, textoutput, informat, resolver, threshold, model)
 
 
-def process_text(inpath, outpath, informat, resolver, threshold):
+def process_text(inpath, outpath, informat, resolver, threshold, model):
     basename = os.path.basename(inpath)
     if informat == 'mmax' and basename.endswith('.mmax'):
         print (basename)
         text = mmax.read(inpath)
         if resolver == 'incremental':
-            resolve.incremental(text, threshold)
+            resolve.incremental(text, threshold, model)
         elif resolver == 'entity_based':
-            resolve.entity_based(text, threshold)
+            resolve.entity_based(text, threshold, model)
         elif resolver == 'closest':
-            resolve.closest(text, threshold)
+            resolve.closest(text, threshold, model)
         elif resolver == 'siamese':
-            resolve.siamese(text, threshold)
+            resolve.siamese(text, threshold, model)
         elif resolver == 'all2all':
-            resolve.all2all(text, threshold)
+            resolve.all2all(text, threshold, model)
         mmax.write(inpath, outpath, text)
     elif informat == 'tei':
         print (basename)
         text = tei.read(inpath)
         if resolver == 'incremental':
-            resolve.incremental(text, threshold)
+            resolve.incremental(text, threshold, model)
         elif resolver == 'entity_based':
-            resolve.entity_based(text, threshold)
+            resolve.entity_based(text, threshold, model)
         elif resolver == 'closest':
-            resolve.closest(text, threshold)
+            resolve.closest(text, threshold, model)
         elif resolver == 'siamese':
-            resolve.siamese(text, threshold)
+            resolve.siamese(text, threshold, model)
         elif resolver == 'all2all':
-            resolve.all2all(text, threshold)
+            resolve.all2all(text, threshold, model)
         tei.write(inpath, outpath, text)
 
 
diff --git a/corneferencer/prepare_data.py b/corneferencer/prepare_data.py
new file mode 100644
index 0000000..6591929
--- /dev/null
+++ b/corneferencer/prepare_data.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+
+import codecs
+import os
+import random
+import sys
+
+from itertools import combinations
+from argparse import ArgumentParser
+from natsort import natsorted
+
+sys.path.append(os.path.abspath(os.path.join('..')))
+
+from inout import mmax, tei
+from inout.constants import INPUT_FORMATS
+from utils import eprint
+from corneferencer.resolvers import vectors
+
+
+POS_COUNT = 0
+NEG_COUNT = 0
+
+
+def main():
+    args = parse_arguments()
+    if not args.input:
+        eprint("Error: Input file(s) not specified!")
+    elif args.format not in INPUT_FORMATS:
+        eprint("Error: Unknown input file format!")
+    else:
+        process_texts(args.input, args.output, args.format, args.proportion)
+
+
+def parse_arguments():
+    parser = ArgumentParser(description='Corneferencer: data preparator for neural nets training.')
+    parser.add_argument('-i', '--input', type=str, action='store',
+                        dest='input', default='',
+                        help='input dir path')
+    parser.add_argument('-o', '--output', type=str, action='store',
+                        dest='output', default='',
+                        help='output path; if not specified writes output to standard output')
+    parser.add_argument('-f', '--format', type=str, action='store',
+                        dest='format', default=INPUT_FORMATS[0],
+                        help='input format; default: %s; possibilities: %s'
+                             % (INPUT_FORMATS[0], ', '.join(INPUT_FORMATS)))
+    parser.add_argument('-p', '--proportion', type=int, action='store',
+                        dest='proportion', default=5,
+                        help='negative examples proportion; default: 5')
+    args = parser.parse_args()
+    return args
+
+
+def process_texts(inpath, outpath, informat, proportion):
+    if os.path.isdir(inpath):
+        process_directory(inpath, outpath, informat, proportion)
+    else:
+        eprint("Error: Specified input does not exist or is not a directory!")
+
+
+def process_directory(inpath, outpath, informat, proportion):
+    inpath = os.path.abspath(inpath)
+    outpath = os.path.abspath(outpath)
+
+    try:
+        create_data_vectors(inpath, outpath, informat, proportion)
+    finally:
+        print ('Positives: ', POS_COUNT)
+        print ('Negatives: ', NEG_COUNT)
+
+
+def create_data_vectors(inpath, outpath, informat, proportion):
+    features_file = codecs.open(outpath, 'w', 'utf-8')
+
+    files = os.listdir(inpath)
+    files = natsorted(files)
+
+    for filename in files:
+        textname = os.path.splitext(os.path.basename(filename))[0]
+        textinput = os.path.join(inpath, filename)
+
+        print ('Processing text: %s' % textname)
+        text = None
+        if informat == 'mmax' and filename.endswith('.mmax'):
+            text = mmax.read(textinput, False)
+        elif informat == 'tei':
+            text = tei.read(textinput, False)
+
+        positives, negatives = diff_mentions(text, proportion)
+        write_features(features_file, positives, negatives)
+
+
+def diff_mentions(text, proportion):
+    sets = text.get_sets()
+    all_mentions = text.get_mentions()
+    positives = get_positives(sets)
+    positives, negatives = get_negatives_and_update_positives(all_mentions, positives, proportion)
+    return positives, negatives
+
+
+def get_positives(sets):
+    positives = []
+    for set_id in sets:
+        coref_set = sets[set_id]
+        positives.extend(list(combinations(coref_set, 2)))
+    return positives
+
+
+def get_negatives_and_update_positives(all_mentions, positives, proportion):
+    all_pairs = list(combinations(all_mentions, 2))
+
+    all_pairs = set(all_pairs)
+    negatives = [pair for pair in all_pairs if pair not in positives]
+    samples_count = proportion * len(positives)
+    if samples_count > len(negatives):
+        samples_count = len(negatives)
+        if proportion == 1:
+            positives = random.sample(set(positives), samples_count)
+        print (u'Więcej przypadków pozytywnych niż negatywnych!')
+    negatives = random.sample(set(negatives), samples_count)
+    return positives, negatives
+
+
+def write_features(features_file, positives, negatives):
+    global POS_COUNT
+    POS_COUNT += len(positives)
+    for pair in positives:
+        vector = vectors.get_pair_vector(pair[0], pair[1])
+        vector.append(1.0)
+        features_file.write(u'%s\n' % u'\t'.join([str(feature) for feature in vector]))
+
+    global NEG_COUNT
+    NEG_COUNT += len(negatives)
+    for pair in negatives:
+        vector = vectors.get_pair_vector(pair[0], pair[1])
+        vector.append(0.0)
+        features_file.write(u'%s\n' % u'\t'.join([str(feature) for feature in vector]))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/corneferencer/resolvers/features.py b/corneferencer/resolvers/features.py
index 987eae8..b4a83f0 100644
--- a/corneferencer/resolvers/features.py
+++ b/corneferencer/resolvers/features.py
@@ -72,97 +72,97 @@ def sentence_vec(mention):
 
 
 def mention_type(mention):
-    type_vec = [0] * 4
+    type_vec = [0.0] * 4
     if mention.head is None:
-        type_vec[3] = 1
+        type_vec[3] = 1.0
     elif mention.head['ctag'] in constants.NOUN_TAGS:
-        type_vec[0] = 1
+        type_vec[0] = 1.0
     elif mention.head['ctag'] in constants.PPRON_TAGS:
-        type_vec[1] = 1
+        type_vec[1] = 1.0
     elif mention.head['ctag'] in constants.ZERO_TAGS:
-        type_vec[2] = 1
+        type_vec[2] = 1.0
     else:
-        type_vec[3] = 1
+        type_vec[3] = 1.0
     return type_vec
 
 
 def is_first_second_person(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if mention.head['person'] in constants.FIRST_SECOND_PERSON:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def is_demonstrative(mention):
     if mention.words[0]['base'].lower() in constants.INDICATIVE_PRONS_BASES:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def is_demonstrative_nominal(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if is_demonstrative(mention) and mention.head['ctag'] in constants.NOUN_TAGS:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def is_demonstrative_pronoun(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if (is_demonstrative(mention) and
             (mention.head['ctag'] in constants.PPRON_TAGS or mention.head['ctag'] in constants.ZERO_TAGS)):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def is_refl_pronoun(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if mention.head['ctag'] in constants.SIEBIE_TAGS:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def is_first_in_sentence(mention):
     if mention.first_in_sentence:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def is_zero_or_pronoun(mention):
     if mention.head is None:
-        return 0
+        return 0.0
     if mention.head['ctag'] in constants.PPRON_TAGS or mention.head['ctag'] in constants.ZERO_TAGS:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def head_contains_digit(mention):
     _digits = re.compile('\d')
     if _digits.search(mention.head_orth):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def mention_contains_digit(mention):
     _digits = re.compile('\d')
     if _digits.search(mention.text):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def contains_letter(mention):
     if any(c.isalpha() for c in mention.text):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def post_modified(mention):
     if mention.head_orth != mention.words[-1]['orth']:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 # pair features
@@ -171,20 +171,20 @@ def distances_vec(ante, ana):
 
     mnts_intersect = pair_intersect(ante, ana)
 
-    words_dist = [0] * 11
+    words_dist = [0.0] * 11
     words_bucket = 0
-    if mnts_intersect != 1:
+    if mnts_intersect != 1.0:
         words_bucket = get_distance_bucket(ana.start_in_words - ante.end_in_words)
-    words_dist[words_bucket] = 1
+    words_dist[words_bucket] = 1.0
     vec.extend(words_dist)
 
-    mentions_dist = [0] * 11
+    mentions_dist = [0.0] * 11
     mentions_bucket = 0
-    if mnts_intersect != 1:
+    if mnts_intersect != 1.0:
         mentions_bucket = get_distance_bucket(ana.position_in_mentions - ante.position_in_mentions)
     if words_bucket == 10:
         mentions_bucket = 10
-    mentions_dist[mentions_bucket] = 1
+    mentions_dist[mentions_bucket] = 1.0
     vec.extend(mentions_dist)
 
     vec.append(mnts_intersect)
@@ -196,45 +196,45 @@ def pair_intersect(ante, ana):
     for ante_word in ante.words:
         for ana_word in ana.words:
             if ana_word['id'] == ante_word['id']:
-                return 1
-    return 0
+                return 1.0
+    return 0.0
 
 
 def head_match(ante, ana):
     if ante.head_orth.lower() == ana.head_orth.lower():
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def exact_match(ante, ana):
     if ante.text.lower() == ana.text.lower():
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def base_match(ante, ana):
     if ante.lemmatized_text.lower() == ana.lemmatized_text.lower():
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def ante_contains_rarest_from_ana(ante, ana):
     ana_rarest = ana.rarest
     for word in ante.words:
         if word['base'] == ana_rarest['base']:
-            return 1
-    return 0
+            return 1.0
+    return 0.0
 
 
 def agreement(ante, ana, tag_name):
-    agr_vec = [0] * 3
+    agr_vec = [0.0] * 3
     if (ante.head is None or ana.head is None or
             ante.head[tag_name] == 'unk' or ana.head[tag_name] == 'unk'):
-        agr_vec[2] = 1
+        agr_vec[2] = 1.0
     elif ante.head[tag_name] == ana.head[tag_name]:
-        agr_vec[0] = 1
+        agr_vec[0] = 1.0
     else:
-        agr_vec[1] = 1
+        agr_vec[1] = 1.0
     return agr_vec
 
 
@@ -243,72 +243,72 @@ def is_acronym(ante, ana):
         return check_one_way_acronym(ana.text, ante.text)
     if ante.text.upper() == ante.text:
         return check_one_way_acronym(ante.text, ana.text)
-    return 0
+    return 0.0
 
 
 def same_sentence(ante, ana):
     if ante.sentence_id == ana.sentence_id:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def neighbouring_sentence(ante, ana):
     if ana.sentence_id - ante.sentence_id == 1:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def cousin_sentence(ante, ana):
     if ana.sentence_id - ante.sentence_id == 2:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def distant_sentence(ante, ana):
     if ana.sentence_id - ante.sentence_id > 2:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def same_paragraph(ante, ana):
     if ante.paragraph_id == ana.paragraph_id:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def flat_gender_agreement(ante, ana):
-    agr_vec = [0] * 3
+    agr_vec = [0.0] * 3
     if (ante.head is None or ana.head is None or
             ante.head['gender'] == 'unk' or ana.head['gender'] == 'unk'):
-        agr_vec[2] = 1
+        agr_vec[2] = 1.0
     elif (ante.head['gender'] == ana.head['gender'] or
             (ante.head['gender'] in constants.MASCULINE_TAGS and ana.head['gender'] in constants.MASCULINE_TAGS)):
-        agr_vec[0] = 1
+        agr_vec[0] = 1.0
     else:
-        agr_vec[1] = 1
+        agr_vec[1] = 1.0
     return agr_vec
 
 
 def left_match(ante, ana):
     if (ante.text.lower().startswith(ana.text.lower()) or
             ana.text.lower().startswith(ante.text.lower())):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def right_match(ante, ana):
     if (ante.text.lower().endswith(ana.text.lower()) or
             ana.text.lower().endswith(ante.text.lower())):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def abbrev2(ante, ana):
     ante_abbrev = get_abbrev(ante)
     ana_abbrev = get_abbrev(ana)
     if ante.head_orth == ana_abbrev or ana.head_orth == ante_abbrev:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def string_kernel(ante, ana):
@@ -326,7 +326,7 @@ def head_string_kernel(ante, ana):
 def wordnet_synonyms(ante, ana):
     ante_synonyms = set()
     if ante.head is None or ana.head is None:
-        return 0
+        return 0.0
 
     if ante.head['base'] in conf.LEMMA2SYNONYMS:
         ante_synonyms = conf.LEMMA2SYNONYMS[ante.head['base']]
@@ -336,13 +336,13 @@ def wordnet_synonyms(ante, ana):
         ana_synonyms = conf.LEMMA2SYNONYMS[ana.head['base']]
 
     if ana.head['base'] in ante_synonyms or ante.head['base'] in ana_synonyms:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def wordnet_ana_is_hypernym(ante, ana):
     if ante.head is None or ana.head is None:
-        return 0
+        return 0.0
 
     ante_hypernyms = set()
     if ante.head['base'] in conf.LEMMA2HYPERNYMS:
@@ -353,16 +353,16 @@ def wordnet_ana_is_hypernym(ante, ana):
         ana_hypernyms = conf.LEMMA2HYPERNYMS[ana.head['base']]
 
     if not ante_hypernyms or not ana_hypernyms:
-        return 0
+        return 0.0
 
     if ana.head['base'] in ante_hypernyms:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def wordnet_ante_is_hypernym(ante, ana):
     if ante.head is None or ana.head is None:
-        return 0
+        return 0.0
 
     ana_hypernyms = set()
     if ana.head['base'] in conf.LEMMA2HYPERNYMS:
@@ -373,18 +373,18 @@ def wordnet_ante_is_hypernym(ante, ana):
         ante_hypernyms = conf.LEMMA2HYPERNYMS[ante.head['base']]
 
     if not ante_hypernyms or not ana_hypernyms:
-        return 0
+        return 0.0
 
     if ante.head['base'] in ana_hypernyms:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def wikipedia_link(ante, ana):
     ante_base = ante.lemmatized_text.lower()
     ana_base = ana.lemmatized_text.lower()
     if ante_base == ana_base:
-        return 1
+        return 1.0
 
     ante_links = set()
     if ante_base in conf.TITLE2LINKS:
@@ -395,16 +395,16 @@ def wikipedia_link(ante, ana):
         ana_links = conf.TITLE2LINKS[ana_base]
 
     if ana_base in ante_links or ante_base in ana_links:
-        return 1
+        return 1.0
 
-    return 0
+    return 0.0
 
 
 def wikipedia_mutual_link(ante, ana):
     ante_base = ante.lemmatized_text.lower()
     ana_base = ana.lemmatized_text.lower()
     if ante_base == ana_base:
-        return 1
+        return 1.0
 
     ante_links = set()
     if ante_base in conf.TITLE2LINKS:
@@ -415,52 +415,52 @@ def wikipedia_mutual_link(ante, ana):
         ana_links = conf.TITLE2LINKS[ana_base]
 
     if ana_base in ante_links and ante_base in ana_links:
-        return 1
+        return 1.0
 
-    return 0
+    return 0.0
 
 
 def wikipedia_redirect(ante, ana):
     ante_base = ante.lemmatized_text.lower()
     ana_base = ana.lemmatized_text.lower()
     if ante_base == ana_base:
-        return 1
+        return 1.0
 
     if ante_base in conf.TITLE2REDIRECT and conf.TITLE2REDIRECT[ante_base] == ana_base:
-        return 1
+        return 1.0
 
     if ana_base in conf.TITLE2REDIRECT and conf.TITLE2REDIRECT[ana_base] == ante_base:
-        return 1
+        return 1.0
 
-    return 0
+    return 0.0
 
 
 def samesent_anapron_antefirstinpar(ante, ana):
     if same_sentence(ante, ana) and is_zero_or_pronoun(ana) and ante.first_in_paragraph:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def samesent_antefirstinpar_personnumbermatch(ante, ana):
     if (same_sentence(ante, ana) and ante.first_in_paragraph
             and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def adjsent_anapron_adjmen_personnumbermatch(ante, ana):
     if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana)
             and ana.position_in_mentions - ante.position_in_mentions == 1
             and agreement(ante, ana, 'number')[0] and agreement(ante, ana, 'person')[0]):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def adjsent_anapron_adjmen(ante, ana):
     if (neighbouring_sentence(ante, ana) and is_zero_or_pronoun(ana)
             and ana.position_in_mentions - ante.position_in_mentions == 1):
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 # supporting functions
@@ -523,8 +523,8 @@ def check_one_way_acronym(acronym, expression):
             if expr2:
                 initials += expr2[0].upper()
     if acronym == initials:
-        return 1
-    return 0
+        return 1.0
+    return 0.0
 
 
 def get_abbrev(mention):
diff --git a/corneferencer/resolvers/resolve.py b/corneferencer/resolvers/resolve.py
index c3a1038..bcca721 100644
--- a/corneferencer/resolvers/resolve.py
+++ b/corneferencer/resolvers/resolve.py
@@ -1,17 +1,15 @@
 import numpy
 
-from conf import NEURAL_MODEL
-from corneferencer.resolvers import features
-from corneferencer.resolvers.vectors import get_pair_features, get_pair_vector
+from corneferencer.resolvers import features, vectors
 
 
-def siamese(text, threshold):
+def siamese(text, threshold, neural_model):
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
             for ante in reversed(text.mentions[:i]):
                 if not features.pair_intersect(ante, ana):
-                    pair_features = get_pair_features(ante, ana)
+                    pair_features = vectors.get_pair_features(ante, ana)
 
                     ante_vec = []
                     ante_vec.extend(ante.features)
@@ -23,7 +21,7 @@ def siamese(text, threshold):
                     ana_vec.extend(pair_features)
                     ana_sample = numpy.asarray([ana_vec], dtype=numpy.float32)
 
-                    prediction = NEURAL_MODEL.predict([ante_sample, ana_sample])[0]
+                    prediction = neural_model.predict([ante_sample, ana_sample])[0]
 
                     if prediction < threshold:
                         if ante.set:
@@ -37,7 +35,7 @@ def siamese(text, threshold):
 
 
 # incremental resolve algorithm
-def incremental(text, threshold):
+def incremental(text, threshold, neural_model):
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
@@ -45,9 +43,9 @@ def incremental(text, threshold):
             best_ante = None
             for ante in text.mentions[:i]:
                 if not features.pair_intersect(ante, ana):
-                    pair_vec = get_pair_vector(ante, ana)
+                    pair_vec = vectors.get_pair_vector(ante, ana)
                     sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                    prediction = NEURAL_MODEL.predict(sample)[0]
+                    prediction = neural_model.predict(sample)[0]
                     if prediction > threshold and prediction >= best_prediction:
                         best_prediction = prediction
                         best_ante = ante
@@ -62,86 +60,21 @@ def incremental(text, threshold):
 
 
 # all2all resolve algorithm
-def all2all_debug(text, threshold):
-    last_set_id = 0
-    for pos1, mnt1 in enumerate(text.mentions):
-        best_prediction = 0.0
-        best_link = None
-        for pos2, mnt2 in enumerate(text.mentions):
-            if (mnt1.set != mnt2.set or not mnt1.set) and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2):
-                ante = mnt1
-                ana = mnt2
-                if pos2 < pos1:
-                    ante = mnt2
-                    ana = mnt1
-                pair_vec = get_pair_vector(ante, ana)
-                sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                prediction = NEURAL_MODEL.predict(sample)[0]
-                if prediction > threshold and prediction > best_prediction:
-                    best_prediction = prediction
-                    best_link = mnt2
-        if best_link is not None:
-            if best_link.set and not mnt1.set:
-                mnt1.set = best_link.set
-            elif best_link.set and mnt1.set:
-                text.merge_sets(best_link.set, mnt1.set)
-            elif not best_link.set and not mnt1.set:
-                str_set_id = 'set_%d' % last_set_id
-                best_link.set = str_set_id
-                mnt1.set = str_set_id
-                last_set_id += 1
-
-
-def all2all_v1(text, threshold):
-    last_set_id = 0
-    for pos1, mnt1 in enumerate(text.mentions):
-        best_prediction = 0.0
-        best_link = None
-        for pos2, mnt2 in enumerate(text.mentions):
-            if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
-                    and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
-                ante = mnt1
-                ana = mnt2
-                if pos2 < pos1:
-                    ante = mnt2
-                    ana = mnt1
-                pair_vec = get_pair_vector(ante, ana)
-                sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                prediction = NEURAL_MODEL.predict(sample)[0]
-                if prediction > threshold and prediction > best_prediction:
-                    best_prediction = prediction
-                    best_link = mnt2
-        if best_link is not None:
-            if best_link.set and not mnt1.set:
-                mnt1.set = best_link.set
-            elif not best_link.set and mnt1.set:
-                best_link.set = mnt1.set
-            elif best_link.set and mnt1.set:
-                text.merge_sets(best_link.set, mnt1.set)
-            elif not best_link.set and not mnt1.set:
-                str_set_id = 'set_%d' % last_set_id
-                best_link.set = str_set_id
-                mnt1.set = str_set_id
-                last_set_id += 1
-
-
-def all2all(text, threshold):
+def all2all(text, threshold, neural_model):
     last_set_id = 0
     sets = text.get_sets()
     for pos1, mnt1 in enumerate(text.mentions):
         best_prediction = 0.0
         best_link = None
         for pos2, mnt2 in enumerate(text.mentions):
-            if ((mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
-                    and pos1 != pos2 and not features.pair_intersect(mnt1, mnt2)):
+            if (pos2 > pos1 and
+                    (mnt1.set != mnt2.set or not mnt1.set or not mnt2.set)
+                    and not features.pair_intersect(mnt1, mnt2)):
                 ante = mnt1
                 ana = mnt2
-                if pos2 < pos1:
-                    ante = mnt2
-                    ana = mnt1
-                pair_vec = get_pair_vector(ante, ana)
+                pair_vec = vectors.get_pair_vector(ante, ana)
                 sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                prediction = NEURAL_MODEL.predict(sample)[0]
+                prediction = neural_model.predict(sample)[0]
                 if prediction > threshold and prediction > best_prediction:
                     best_prediction = prediction
                     best_link = mnt2
@@ -163,12 +96,12 @@ def all2all(text, threshold):
 
 
 # entity based resolve algorithm
-def entity_based(text, threshold):
+def entity_based(text, threshold, neural_model):
     sets = []
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
-            best_fit = get_best_set(sets, ana, threshold)
+            best_fit = get_best_set(sets, ana, threshold, neural_model)
             if best_fit is not None:
                 ana.set = best_fit['set_id']
                 best_fit['mentions'].append(ana)
@@ -188,25 +121,25 @@ def entity_based(text, threshold):
     remove_singletons(sets)
 
 
-def get_best_set(sets, ana, threshold):
+def get_best_set(sets, ana, threshold, neural_model):
     best_prediction = 0.0
     best_set = None
     for s in sets:
-        accuracy = predict_set(s['mentions'], ana)
+        accuracy = predict_set(s['mentions'], ana, neural_model)
         if accuracy > threshold and accuracy >= best_prediction:
             best_prediction = accuracy
             best_set = s
     return best_set
 
 
-def predict_set(mentions, ana):
+def predict_set(mentions, ana, neural_model):
     prediction_sum = 0.0
     for mnt in mentions:
         prediction = 0.0
         if not features.pair_intersect(mnt, ana):
-            pair_vec = get_pair_vector(mnt, ana)
+            pair_vec = vectors.get_pair_vector(mnt, ana)
             sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-            prediction = NEURAL_MODEL.predict(sample)[0]
+            prediction = neural_model.predict(sample)[0]
         prediction_sum += prediction
     return prediction_sum / float(len(mentions))
 
@@ -218,15 +151,15 @@ def remove_singletons(sets):
 
 
 # closest resolve algorithm
-def closest(text, threshold):
+def closest(text, threshold, neural_model):
     last_set_id = 0
     for i, ana in enumerate(text.mentions):
         if i > 0:
             for ante in reversed(text.mentions[:i]):
                 if not features.pair_intersect(ante, ana):
-                    pair_vec = get_pair_vector(ante, ana)
+                    pair_vec = vectors.get_pair_vector(ante, ana)
                     sample = numpy.asarray([pair_vec], dtype=numpy.float32)
-                    prediction = NEURAL_MODEL.predict(sample)[0]
+                    prediction = neural_model.predict(sample)[0]
                     if prediction > threshold:
                         if ante.set:
                             ana.set = ante.set
diff --git a/corneferencer/utils.py b/corneferencer/utils.py
index ea6b125..15c4ca1 100644
--- a/corneferencer/utils.py
+++ b/corneferencer/utils.py
@@ -7,7 +7,6 @@ import javaobj
 
 from keras.models import Sequential, Model
 from keras.layers import Input, Dense, Dropout, Activation, BatchNormalization, Lambda
-from keras.optimizers import RMSprop, Adam
 from keras import backend as K
 
 
--
libgit2 0.22.2