diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java index 38d2989..603608d 100644 --- a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java @@ -1,8 +1,13 @@ package pl.waw.ipipan.zil.summ.nicolas.common; import com.google.common.base.Charsets; +import com.google.common.collect.ImmutableList; import weka.classifiers.Classifier; -import weka.classifiers.functions.Logistic; +import weka.classifiers.functions.SMO; +import weka.classifiers.meta.AdaBoostM1; +import weka.classifiers.meta.AttributeSelectedClassifier; +import weka.classifiers.rules.JRip; +import weka.classifiers.trees.J48; import weka.classifiers.trees.RandomForest; import java.nio.charset.Charset; @@ -20,6 +25,8 @@ public class Constants { public static final Charset ENCODING = Charsets.UTF_8; + public static final ImmutableList<String> POS_TAGS = ImmutableList.of("end", "other", "null", "impt", "imps", "inf", "pred", "subst", "aglt", "ppron3", "ger", "praet", "fin", "num", "interp", "siebie", "brev", "interj", "ppron12", "adj", "burk", "pcon", "bedzie", "adv", "prep", "depr", "xxx", "winien", "conj", "qub", "adja", "ppas", "comp", "pact"); + private Constants() { } @@ -33,14 +40,14 @@ public class Constants { public static Classifier getSentencesClassifier() { RandomForest classifier = new RandomForest(); - classifier.setNumIterations(250); + classifier.setNumIterations(10); classifier.setSeed(0); classifier.setNumExecutionSlots(8); return classifier; } public static Classifier getZerosClassifier() { - Logistic classifier = new Logistic(); + Classifier classifier = new J48(); return classifier; } } diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java index 23d1958..6d7d07f 100644 --- a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java @@ -30,12 +30,13 @@ public class FeatureHelper { private final Map<TMention, TToken> mention2head = Maps.newHashMap(); private final Set<TMention> mentionsInNamedEntities = Sets.newHashSet(); - private final Map<TMention, Integer> mention2Index = Maps.newHashMap(); + private final Map<TMention, Integer> mention2index = Maps.newHashMap(); private final Map<TSentence, Integer> sent2Index = Maps.newHashMap(); private final Map<TParagraph, Integer> par2Index = Maps.newHashMap(); private final Map<TSentence, Integer> sent2IndexInPar = Maps.newHashMap(); private final Map<TMention, Integer> mention2indexInPar = Maps.newHashMap(); private final Map<TMention, Integer> mention2indexInSent = Maps.newHashMap(); + private final Map<TMention, Integer> mention2firstTokenIndex = Maps.newHashMap(); public FeatureHelper(TText preprocessedText) { @@ -82,7 +83,8 @@ public class FeatureHelper { for (TMention mention : sent.getMentions()) { mention2sent.put(mention, sent); mention2par.put(mention, par); - mention2Index.put(mention, mentionIdx++); + mention2index.put(mention, mentionIdx++); + mention2firstTokenIndex.put(mention, sent.getTokens().indexOf(tokenId2token.get(mention.getChildIds().iterator().next()))); mention2indexInSent.put(mention, mentionIdxInSent++); mention2indexInPar.put(mention, mentionIdxInPar++); @@ -124,7 +126,11 @@ public class FeatureHelper { } public int getMentionIndex(TMention mention) { - return mention2Index.get(mention); + return mention2index.get(mention); + } + + public int getMentionFirstTokenIndex(TMention mention) { + return mention2firstTokenIndex.get(mention); } public int getMentionIndexInSent(TMention mention) { @@ -200,4 +206,19 @@ public class FeatureHelper { public TText getText() { return text; } + + public TToken getTokenAfterMention(TMention mention) { + Integer idx = mention2firstTokenIndex.get(mention) + mention.getChildIds().size(); + List<TToken> sentenceTokens = mention2sent.get(mention).getTokens(); + if (idx >= sentenceTokens.size()) + return null; + return sentenceTokens.get(idx); + } + + public TToken getTokenBeforeMention(TMention mention) { + Integer idx = mention2firstTokenIndex.get(mention); + if (idx == 0) + return null; + return mention2sent.get(mention).getTokens().get(idx - 1); + } } diff --git a/nicolas-common/src/test/java/pl/waw/ipipan/zil/summ/nicolas/common/UtilsTest.java b/nicolas-common/src/test/java/pl/waw/ipipan/zil/summ/nicolas/common/UtilsTest.java new file mode 100644 index 0000000..715fbf5 --- /dev/null +++ b/nicolas-common/src/test/java/pl/waw/ipipan/zil/summ/nicolas/common/UtilsTest.java @@ -0,0 +1,22 @@ +package pl.waw.ipipan.zil.summ.nicolas.common; + +import org.junit.Test; +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; + +import java.io.InputStream; + +import static org.junit.Assert.assertEquals; + +public class UtilsTest { + + private static final String SAMPLE_TEXT_PATH = "/199704210011.bin"; + + @Test + public void shouldDeserializeTextIgnoringClassVersionId() throws Exception { + try (InputStream stream = UtilsTest.class.getResourceAsStream(SAMPLE_TEXT_PATH)) { + TText text = Utils.loadThrifted(stream); + assertEquals(26, text.getParagraphs().size()); + assertEquals(2, text.getParagraphs().get(4).getSentences().size()); + } + } +} \ No newline at end of file diff --git a/nicolas-common/src/test/resources/199704210011.bin b/nicolas-common/src/test/resources/199704210011.bin new file mode 100644 index 0000000..cf072c2 --- /dev/null +++ b/nicolas-common/src/test/resources/199704210011.bin diff --git a/nicolas-core/pom.xml b/nicolas-core/pom.xml index b291cda..4557c3e 100644 --- a/nicolas-core/pom.xml +++ b/nicolas-core/pom.xml @@ -21,11 +21,8 @@ <groupId>pl.waw.ipipan.zil.summ</groupId> <artifactId>nicolas-model</artifactId> </dependency> - <dependency> - <groupId>pl.waw.ipipan.zil.summ</groupId> - <artifactId>nicolas-zero</artifactId> - </dependency> + <!-- internal --> <dependency> <groupId>pl.waw.ipipan.zil.summ</groupId> <artifactId>pscapi</artifactId> @@ -35,6 +32,7 @@ <artifactId>utils</artifactId> </dependency> + <!-- third party --> <dependency> <groupId>nz.ac.waikato.cms.weka</groupId> <artifactId>weka-dev</artifactId> @@ -51,5 +49,17 @@ <groupId>org.apache.commons</groupId> <artifactId>commons-lang3</artifactId> </dependency> + + <!-- logging --> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + </dependency> + + <!-- test --> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + </dependency> </dependencies> </project> \ No newline at end of file diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/EvalUtils.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/EvalUtils.java new file mode 100644 index 0000000..741ee9b --- /dev/null +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/EvalUtils.java @@ -0,0 +1,96 @@ +package pl.waw.ipipan.zil.summ.nicolas.eval; + +import org.apache.commons.lang3.time.StopWatch; +import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import weka.classifiers.Classifier; +import weka.classifiers.bayes.BayesNet; +import weka.classifiers.bayes.NaiveBayes; +import weka.classifiers.evaluation.Evaluation; +import weka.classifiers.functions.LinearRegression; +import weka.classifiers.functions.Logistic; +import weka.classifiers.functions.SMOreg; +import weka.classifiers.functions.SimpleLogistic; +import weka.classifiers.lazy.IBk; +import weka.classifiers.lazy.KStar; +import weka.classifiers.lazy.LWL; +import weka.classifiers.rules.DecisionTable; +import weka.classifiers.rules.JRip; +import weka.classifiers.rules.PART; +import weka.classifiers.trees.HoeffdingTree; +import weka.classifiers.trees.J48; +import weka.classifiers.trees.LMT; +import weka.classifiers.trees.RandomForest; +import weka.core.Instances; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.Optional; +import java.util.Random; + +public class EvalUtils { + + private static final Logger LOG = LoggerFactory.getLogger(EvalUtils.class); + public static final int NUM_FOLDS = 10; + + private EvalUtils() { + } + + public static void crossvalidateClassification(Instances instances) throws Exception { + StopWatch watch = new StopWatch(); + watch.start(); + + Optional<Pair<Double, String>> max = Arrays.stream(new Classifier[]{new J48(), new RandomForest(), new HoeffdingTree(), new LMT(), + new Logistic(), + new SimpleLogistic(), new BayesNet(), new NaiveBayes(), + new KStar(), new IBk(), new LWL(), + new DecisionTable(), new JRip(), new PART()}).parallel().map(cls -> { + Evaluation eval = null; + try { + eval = new Evaluation(instances); + eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(1)); + } catch (Exception e) { + e.printStackTrace(); + } + double acc = eval.correct() / eval.numInstances(); + String name = cls.getClass().getSimpleName(); + LOG.info(name + " : " + acc); + + return Pair.of(acc, name); + }).max(Comparator.comparingDouble(Pair::getLeft)); + LOG.info("#########"); + LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); + + watch.stop(); + LOG.info("Elapsed time: " + watch); + } + + public static void crossvalidateRegression(Instances instances) { + StopWatch watch = new StopWatch(); + watch.start(); + + Optional<Pair<Double, String>> max = Arrays.stream(new Classifier[]{ + new RandomForest(), new LinearRegression(), new SMOreg()}).parallel().map(cls -> { + Evaluation eval = null; + double acc = 0; + try { + eval = new Evaluation(instances); + eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(1)); + acc = eval.correlationCoefficient(); + + } catch (Exception e) { + e.printStackTrace(); + } + String name = cls.getClass().getSimpleName(); + LOG.info(name + " : " + acc); + + return Pair.of(acc, name); + }).max(Comparator.comparingDouble(Pair::getLeft)); + LOG.info("#########"); + LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); + + watch.stop(); + LOG.info("Elapsed time: " + watch); + } +} \ No newline at end of file diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java index ad239f9..ec671aa 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java @@ -2,6 +2,7 @@ package pl.waw.ipipan.zil.summ.nicolas.mention; import com.google.common.collect.*; import pl.waw.ipipan.zil.multiservice.thrift.types.*; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureExtractor; import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; import pl.waw.ipipan.zil.summ.nicolas.common.features.Interpretation; @@ -45,7 +46,7 @@ public class MentionFeatureExtractor extends FeatureExtractor { addBinaryAttribute(prefix + "_is_zero"); addBinaryAttribute(prefix + "_is_named"); addBinaryAttribute(prefix + "_is_pronoun"); - addNominalAttribute(prefix + "_ctag", Lists.newArrayList("other", "null", "impt", "subst", "aglt", "ppron3", "ger", "praet", "fin", "num", "interp", "siebie", "brev", "interj", "ppron12", "adj", "burk", "pcon", "bedzie", "adv", "prep", "depr", "xxx", "winien", "conj", "qub", "adja", "ppas", "comp", "pact")); + addNominalAttribute(prefix + "_ctag", Constants.POS_TAGS); addNominalAttribute(prefix + "_person", Lists.newArrayList("other", "null", "pri", "sec", "ter")); addNominalAttribute(prefix + "_case", Lists.newArrayList("other", "null", "nom", "acc", "dat", "gen", "loc", "inst", "voc")); addNominalAttribute(prefix + "_number", Lists.newArrayList("other", "null", "sg", "pl")); diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java index a52ef2f..58b8c8f 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java @@ -17,6 +17,9 @@ public class Crossvalidate { private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); + private Crossvalidate() { + } + public static void main(String[] args) throws Exception { ArffLoader loader = new ArffLoader(); @@ -26,9 +29,6 @@ public class Crossvalidate { LOG.info(instances.size() + " instances loaded."); LOG.info(instances.numAttributes() + " attributes for each instance."); -// while (instances.size() > 10000) -// instances.remove(instances.size() - 1); - StopWatch watch = new StopWatch(); watch.start(); diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java index 48a8ccf..8b60024 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java @@ -14,9 +14,7 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.ObjectInputStream; -/** - * Created by me2 on 05.04.16. - */ + public class Validate { private static final Logger LOG = LoggerFactory.getLogger(Validate.class); diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java index 457a857..09cc621 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java @@ -1,22 +1,22 @@ package pl.waw.ipipan.zil.summ.nicolas.sentence.test; -import org.apache.commons.lang3.time.StopWatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import pl.waw.ipipan.zil.summ.nicolas.common.Constants; -import weka.classifiers.Classifier; -import weka.classifiers.evaluation.Evaluation; +import pl.waw.ipipan.zil.summ.nicolas.eval.EvalUtils; import weka.core.Instances; import weka.core.converters.ArffLoader; import java.io.File; -import java.util.Random; public class Crossvalidate { private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); + private Crossvalidate() { + } + public static void main(String[] args) throws Exception { ArffLoader loader = new ArffLoader(); @@ -26,16 +26,6 @@ public class Crossvalidate { LOG.info(instances.size() + " instances loaded."); LOG.info(instances.numAttributes() + " attributes for each instance."); - StopWatch watch = new StopWatch(); - watch.start(); - - Classifier tree = Constants.getSentencesClassifier(); - - Evaluation eval = new Evaluation(instances); - eval.crossValidateModel(tree, instances, 10, new Random(1)); - LOG.info(eval.toSummaryString()); - - watch.stop(); - LOG.info("Elapsed time: " + watch); + EvalUtils.crossvalidateRegression(instances); } } diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java index 9ce2a4b..f862b31 100644 --- a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java @@ -61,6 +61,8 @@ public class CandidateFinder { } private static boolean isInNominative(TInterpretation interp) { - return interp.getCtag().equals("subst") && Arrays.stream(interp.getMsd().split(":")).anyMatch(t -> t.equals("nom")); + boolean isNominative = Arrays.stream(interp.getMsd().split(":")).anyMatch(t -> t.equals("nom")); + boolean isSubst = interp.getCtag().equals("subst"); + return isSubst && isNominative; } } diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainingDataExtractor.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/PrepareTrainingData.java index fcdc68d..38fb018 100644 --- a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainingDataExtractor.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/PrepareTrainingData.java @@ -1,15 +1,14 @@ -package pl.waw.ipipan.zil.summ.nicolas.zero.train; +package pl.waw.ipipan.zil.summ.nicolas.zero; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import org.apache.commons.io.IOUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.common.Utils; import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; -import pl.waw.ipipan.zil.summ.nicolas.zero.CandidateFinder; -import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroFeatureExtractor; -import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate; import weka.core.Attribute; import weka.core.DenseInstance; import weka.core.Instance; @@ -23,13 +22,15 @@ import java.util.List; import java.util.Map; import java.util.Set; -public class TrainingDataExtractor { +public class PrepareTrainingData { + + private static final Logger LOG = LoggerFactory.getLogger(PrepareTrainingData.class); private static final String IDS_PATH = "corpora/summaries_dev"; private static final String THRIFTED_PATH = "corpora/preprocessed_full_texts/dev/"; private static final String GOLD_ZEROS_PATH = "/zeros.tsv"; - private TrainingDataExtractor() { + private PrepareTrainingData() { } public static void main(String[] args) throws IOException { @@ -42,7 +43,10 @@ public class TrainingDataExtractor { Instances instances = Utils.createNewInstances(featureExtractor.getAttributesList()); + int i = 1; for (Map.Entry<String, TText> entry : id2preprocessedText.entrySet()) { + LOG.info(i++ + "/" + id2preprocessedText.size()); + String textId = entry.getKey(); TText text = entry.getValue(); diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainModel.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/TrainModel.java index 34df6cf..77c5a30 100644 --- a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainModel.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/TrainModel.java @@ -1,4 +1,4 @@ -package pl.waw.ipipan.zil.summ.nicolas.zero.train; +package pl.waw.ipipan.zil.summ.nicolas.zero; import org.apache.commons.lang3.time.StopWatch; import org.slf4j.Logger; diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java index 43e1333..8111368 100644 --- a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java @@ -2,14 +2,18 @@ package pl.waw.ipipan.zil.summ.nicolas.zero; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.collect.Sets; import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; +import pl.waw.ipipan.zil.multiservice.thrift.types.TToken; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureExtractor; import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; import weka.core.Attribute; import java.util.List; import java.util.Map; +import java.util.Set; public class ZeroFeatureExtractor extends FeatureExtractor { @@ -18,13 +22,26 @@ public class ZeroFeatureExtractor extends FeatureExtractor { for (String prefix : new String[]{"antecedent", "candidate"}) { addNumericAttribute(prefix + "_index_in_sent"); + addNumericAttribute(prefix + "_first_token_index_in_sent"); addNumericAttribute(prefix + "_token_count"); - addBinaryAttribute(prefix + "_is_zero"); - addBinaryAttribute(prefix + "_is_pronoun"); addBinaryAttribute(prefix + "_is_named"); + addNumericAttribute(prefix + "_sentence_mention_count"); + addNominalAttribute(prefix + "_next_token_pos", Constants.POS_TAGS); + addNominalAttribute(prefix + "_prev_token_pos", Constants.POS_TAGS); + addBinaryAttribute(prefix + "_is_nested"); + addBinaryAttribute(prefix + "_is_nesting"); } + addNumericAttribute("chain_length"); + addBinaryAttribute("pair_equal_orth"); + addBinaryAttribute("pair_equal_ignore_case_orth"); + addBinaryAttribute("pair_equal_base"); + addBinaryAttribute("pair_equal_number"); + addBinaryAttribute("pair_equal_head_base"); + + addNumericAttribute("pair_sent_distance"); + addNumericAttribute("pair_par_distance"); addNominalAttribute("score", Lists.newArrayList("bad", "good")); fillSortedAttributes("score"); @@ -53,17 +70,57 @@ public class ZeroFeatureExtractor extends FeatureExtractor { addMentionFeatures(helper, candidateFeatures, mention, "candidate"); addMentionFeatures(helper, candidateFeatures, antecedent, "antecedent"); - candidateFeatures.put(getAttributeByName("pair_equal_orth"), toBinary(helper.getMentionOrth(mention).equalsIgnoreCase(helper.getMentionOrth(antecedent)))); + candidateFeatures.put(getAttributeByName("pair_equal_orth"), toBinary(helper.getMentionOrth(mention).equals(helper.getMentionOrth(antecedent)))); + candidateFeatures.put(getAttributeByName("pair_equal_base"), toBinary(helper.getMentionBase(mention).equalsIgnoreCase(helper.getMentionBase(antecedent)))); + candidateFeatures.put(getAttributeByName("pair_equal_ignore_case_orth"), toBinary(helper.getMentionOrth(mention).equalsIgnoreCase(helper.getMentionOrth(antecedent)))); + candidateFeatures.put(getAttributeByName("pair_equal_head_base"), toBinary(helper.getMentionHeadToken(mention).getChosenInterpretation().getBase().equalsIgnoreCase(helper.getMentionHeadToken(antecedent).getChosenInterpretation().getBase()))); + + candidateFeatures.put(getAttributeByName("pair_sent_distance"), (double) Math.abs(helper.getSentIndex(helper.getMentionSentence(mention)) - helper.getSentIndex(helper.getMentionSentence(antecedent)))); + candidateFeatures.put(getAttributeByName("pair_par_distance"), (double) Math.abs(helper.getParIndex(helper.getMentionParagraph(mention)) - helper.getParIndex(helper.getMentionParagraph(antecedent)))); + + String mentionNumber = getNumber(helper.getMentionHeadToken(mention)); + String antecedentNumber = getNumber(helper.getMentionHeadToken(antecedent)); + candidateFeatures.put(getAttributeByName("pair_equal_number"), toBinary(mentionNumber != null && mentionNumber.equals(antecedentNumber))); + + candidateFeatures.put(getAttributeByName("chain_length"), (double) helper.getChainLength(mention)); return candidateFeatures; } + private String getNumber(TToken token) { + Set<String> msd = Sets.newHashSet(token.getChosenInterpretation().getMsd().split(":")); + if (msd.contains("sg")) + return "sg"; + else if (msd.contains("pl")) + return "pl"; + else + return null; + } + private void addMentionFeatures(FeatureHelper helper, Map<Attribute, Double> candidateFeatures, TMention mention, String attributePrefix) { candidateFeatures.put(getAttributeByName(attributePrefix + "_index_in_sent"), (double) helper.getMentionIndexInSent(mention)); + candidateFeatures.put(getAttributeByName(attributePrefix + "_first_token_index_in_sent"), (double) helper.getMentionFirstTokenIndex(mention)); + candidateFeatures.put(getAttributeByName(attributePrefix + "_token_count"), (double) mention.getChildIdsSize()); - candidateFeatures.put(getAttributeByName(attributePrefix + "_is_zero"), toBinary(mention.isZeroSubject())); - candidateFeatures.put(getAttributeByName(attributePrefix + "_is_pronoun"), toBinary(helper.getMentionHeadToken(mention).getChosenInterpretation().getCtag().matches("ppron.*"))); candidateFeatures.put(getAttributeByName(attributePrefix + "_is_named"), toBinary(helper.isMentionNamedEntity(mention))); + candidateFeatures.put(getAttributeByName(attributePrefix + "_sentence_mention_count"), (double) helper.getMentionSentence(mention).getMentions().size()); + + TToken nextToken = helper.getTokenAfterMention(mention); + addNominalAttributeValue(nextToken == null ? "end" : nextToken.getChosenInterpretation().getCtag(), candidateFeatures, attributePrefix + "_next_token_pos"); + TToken prevToken = helper.getTokenBeforeMention(mention); + addNominalAttributeValue(prevToken == null ? "end" : prevToken.getChosenInterpretation().getCtag(), candidateFeatures, attributePrefix + "_prev_token_pos"); + + candidateFeatures.put(getAttributeByName(attributePrefix + "_is_nested"), toBinary(helper.isNested(mention))); + candidateFeatures.put(getAttributeByName(attributePrefix + "_is_nesting"), toBinary(helper.isNesting(mention))); + } + private void addNominalAttributeValue(String value, Map<Attribute, Double> attribute2value, String attributeName) { + Attribute att = getAttributeByName(attributeName); + int index = att.indexOfValue(value); + if (index == -1) + LOG.warn(value + " not found for attribute " + attributeName); + attribute2value.put(att, (double) (index == -1 ? att.indexOfValue("other") : index)); + } } + diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/ZeroScorer.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroScorer.java index 3eef010..65db887 100644 --- a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/ZeroScorer.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroScorer.java @@ -1,4 +1,4 @@ -package pl.waw.ipipan.zil.summ.nicolas.zero.train; +package pl.waw.ipipan.zil.summ.nicolas.zero; import com.google.common.collect.Maps; import org.apache.commons.csv.CSVFormat; @@ -7,7 +7,6 @@ import org.apache.commons.csv.CSVRecord; import org.apache.commons.csv.QuoteMode; import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; -import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate; import java.io.IOException; import java.io.InputStream; diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectCandidate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectCandidate.java index 6d0a76f..6d0a76f 100644 --- a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectCandidate.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectCandidate.java diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java index 329f31a..5da90a5 100644 --- a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java @@ -5,7 +5,6 @@ import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.common.Utils; -import pl.waw.ipipan.zil.summ.nicolas.zero.train.TrainingDataExtractor; import weka.classifiers.Classifier; import weka.core.Instance; import weka.core.Instances; @@ -32,7 +31,7 @@ public class ZeroSubjectInjector { Set<String> summarySentenceIds = selectedSentences.stream().map(TSentence::getId).collect(Collectors.toSet()); List<ZeroSubjectCandidate> zeroSubjectCandidates = CandidateFinder.findZeroSubjectCandidates(text, summarySentenceIds); Map<ZeroSubjectCandidate, Instance> candidate2instance = - TrainingDataExtractor.extractInstancesFromZeroCandidates(zeroSubjectCandidates, text, featureExtractor); + PrepareTrainingData.extractInstancesFromZeroCandidates(zeroSubjectCandidates, text, featureExtractor); Set<String> result = Sets.newHashSet(); for (Map.Entry<ZeroSubjectCandidate, Instance> entry : candidate2instance.entrySet()) { diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/test/Crossvalidate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/test/Crossvalidate.java new file mode 100644 index 0000000..252ae6e --- /dev/null +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/test/Crossvalidate.java @@ -0,0 +1,31 @@ +package pl.waw.ipipan.zil.summ.nicolas.zero.test; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; +import pl.waw.ipipan.zil.summ.nicolas.eval.EvalUtils; +import weka.core.Instances; +import weka.core.converters.ArffLoader; + +import java.io.File; + + +public class Crossvalidate { + + private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); + + private Crossvalidate() { + } + + public static void main(String[] args) throws Exception { + + ArffLoader loader = new ArffLoader(); + loader.setFile(new File(Constants.ZERO_DATASET_PATH)); + Instances instances = loader.getDataSet(); + instances.setClassIndex(0); + LOG.info(instances.size() + " instances loaded."); + LOG.info(instances.numAttributes() + " attributes for each instance."); + + EvalUtils.crossvalidateClassification(instances); + } +} diff --git a/nicolas-zero/src/main/resources/zeros.tsv b/nicolas-core/src/main/resources/zeros.tsv index 5ac27a6..5ac27a6 100644 --- a/nicolas-zero/src/main/resources/zeros.tsv +++ b/nicolas-core/src/main/resources/zeros.tsv diff --git a/nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java b/nicolas-core/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java index 4ab4ee2..4ab4ee2 100644 --- a/nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java +++ b/nicolas-core/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java diff --git a/nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin b/nicolas-core/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin index e30b245..e30b245 100644 --- a/nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin +++ b/nicolas-core/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin diff --git a/nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt b/nicolas-core/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt index 10ac642..10ac642 100644 --- a/nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt +++ b/nicolas-core/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt diff --git a/nicolas-train/src/test/java/pl/waw/ipipan/zil/summ/nicolas/train/multiservice/NLPProcessTest.java b/nicolas-train/src/test/java/pl/waw/ipipan/zil/summ/nicolas/train/multiservice/NLPProcessTest.java new file mode 100644 index 0000000..018c352 --- /dev/null +++ b/nicolas-train/src/test/java/pl/waw/ipipan/zil/summ/nicolas/train/multiservice/NLPProcessTest.java @@ -0,0 +1,17 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.multiservice; + +import org.junit.Test; +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; + +import java.io.File; + +public class NLPProcessTest { + @Test + public void shouldProcessSampleText() throws Exception { + String text = "Ala ma kota. Ala ma też psa."; + TText processed = NLPProcess.annotate(text); + processed.getParagraphs().stream().flatMap(p->p.getSentences().stream()).forEach(s->System.out.println(s.getId())); + File targetFile = new File("sample_serialized_text.bin"); + NLPProcess.serialize(processed, targetFile); + } +} \ No newline at end of file diff --git a/nicolas-zero/pom.xml b/nicolas-zero/pom.xml deleted file mode 100644 index 666e517..0000000 --- a/nicolas-zero/pom.xml +++ /dev/null @@ -1,48 +0,0 @@ -<?xml version="1.0" encoding="UTF-8"?> -<project xmlns="http://maven.apache.org/POM/4.0.0" - xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" - xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> - <modelVersion>4.0.0</modelVersion> - <parent> - <artifactId>nicolas-container</artifactId> - <groupId>pl.waw.ipipan.zil.summ</groupId> - <version>1.0-SNAPSHOT</version> - </parent> - - <artifactId>nicolas-zero</artifactId> - - <dependencies> - <!-- project --> - <dependency> - <groupId>pl.waw.ipipan.zil.summ</groupId> - <artifactId>nicolas-common</artifactId> - </dependency> - - <!-- third party --> - <dependency> - <groupId>org.apache.commons</groupId> - <artifactId>commons-csv</artifactId> - </dependency> - <dependency> - <groupId>commons-io</groupId> - <artifactId>commons-io</artifactId> - </dependency> - <dependency> - <groupId>org.apache.commons</groupId> - <artifactId>commons-lang3</artifactId> - </dependency> - - <!-- logging --> - <dependency> - <groupId>org.slf4j</groupId> - <artifactId>slf4j-api</artifactId> - </dependency> - - <!-- test --> - <dependency> - <groupId>junit</groupId> - <artifactId>junit</artifactId> - </dependency> - </dependencies> - -</project> \ No newline at end of file diff --git a/pom.xml b/pom.xml index bbdbd9b..81c53ae 100644 --- a/pom.xml +++ b/pom.xml @@ -15,7 +15,6 @@ <module>nicolas-cli</module> <module>nicolas-model</module> <module>nicolas-train</module> - <module>nicolas-zero</module> <module>nicolas-common</module> </modules>