diff --git a/.gitignore b/.gitignore index 28f546a..ed83e22 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,4 @@ target/ hs_err_pid* .idea -*.iml \ No newline at end of file +*.iml diff --git a/nicolas-common/pom.xml b/nicolas-common/pom.xml index 6dbb4fe..62a9c6c 100644 --- a/nicolas-common/pom.xml +++ b/nicolas-common/pom.xml @@ -27,6 +27,10 @@ <groupId>nz.ac.waikato.cms.weka</groupId> <artifactId>weka-dev</artifactId> </dependency> + <dependency> + <groupId>commons-io</groupId> + <artifactId>commons-io</artifactId> + </dependency> <!-- logging --> <dependency> diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java index 603608d..4d2ab97 100644 --- a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java @@ -2,26 +2,21 @@ package pl.waw.ipipan.zil.summ.nicolas.common; import com.google.common.base.Charsets; import com.google.common.collect.ImmutableList; -import weka.classifiers.Classifier; -import weka.classifiers.functions.SMO; -import weka.classifiers.meta.AdaBoostM1; -import weka.classifiers.meta.AttributeSelectedClassifier; -import weka.classifiers.rules.JRip; -import weka.classifiers.trees.J48; -import weka.classifiers.trees.RandomForest; import java.nio.charset.Charset; public class Constants { - public static final String MENTIONS_MODEL_PATH = "mentions_model.bin"; - public static final String SENTENCES_MODEL_PATH = "sentences_model.bin"; - public static final String ZERO_MODEL_PATH = "zeros_model.bin"; + private static final String ROOT_PATH = "/pl/waw/ipipan/zil/summ/nicolas/"; - public static final String MENTIONS_DATASET_PATH = "mentions_train.arff"; - public static final String SENTENCES_DATASET_PATH = "sentences_train.arff"; - public static final String ZERO_DATASET_PATH = "zeros_train.arff"; + private static final String MODELS_PATH = ROOT_PATH + "models/"; + public static final String MENTION_MODEL_RESOURCE_PATH = MODELS_PATH + "mention_model.bin"; + public static final String SENTENCE_MODEL_RESOURCE_PATH = MODELS_PATH + "sentence_model.bin"; + public static final String ZERO_MODEL_RESOURCE_PATH = MODELS_PATH + "zero_model.bin"; + + private static final String RESOURCES_PATH = ROOT_PATH + "resources/"; + public static final String FREQUENT_BASES_RESOURCE_PATH = RESOURCES_PATH + "frequent_bases.txt"; public static final Charset ENCODING = Charsets.UTF_8; @@ -30,24 +25,4 @@ public class Constants { private Constants() { } - public static Classifier getMentionClassifier() { - RandomForest classifier = new RandomForest(); - classifier.setNumIterations(250); - classifier.setSeed(0); - classifier.setNumExecutionSlots(8); - return classifier; - } - - public static Classifier getSentencesClassifier() { - RandomForest classifier = new RandomForest(); - classifier.setNumIterations(10); - classifier.setSeed(0); - classifier.setNumExecutionSlots(8); - return classifier; - } - - public static Classifier getZerosClassifier() { - Classifier classifier = new J48(); - return classifier; - } } diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java index 4c2b173..5524abc 100644 --- a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java @@ -3,6 +3,7 @@ package pl.waw.ipipan.zil.summ.nicolas.common; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; +import org.apache.commons.io.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; @@ -24,6 +25,47 @@ public class Utils { private static final String DATASET_NAME = "Dataset"; + private Utils() { + } + + public static Classifier loadModelFromResource(String modelResourcePath) throws IOException { + LOG.info("Loading classifier from path: {}...", modelResourcePath); + try (InputStream stream = Utils.class.getResourceAsStream(modelResourcePath)) { + if (stream == null) { + throw new IOException("Model not found at: " + modelResourcePath); + } + try (ObjectInputStream ois = new ObjectInputStream(stream)) { + Classifier classifier = (Classifier) ois.readObject(); + LOG.info("Done. Loaded classifier: {}", classifier.getClass().getSimpleName()); + return classifier; + } catch (ClassNotFoundException e) { + LOG.error("Error loading serialized classifier, class not found.", e); + throw new IOException(e); + } + } + } + + public static TText loadThriftTextFromResource(String textResourcePath) throws IOException { + try (InputStream stream = Utils.class.getResourceAsStream(textResourcePath)) { + if (stream == null) { + throw new IOException("Resource not found at: " + textResourcePath); + } + try (VersionIgnoringObjectInputStream ois = new VersionIgnoringObjectInputStream(stream)) { + return (TText) ois.readObject(); + } catch (ClassNotFoundException e) { + LOG.error("Error reading serialized thrift text file, class not found.", e); + throw new IOException(e); + } + } + } + + public static List<String> loadLinesFromResource(String resourcePath) throws IOException { + try (InputStream stream = Utils.class.getResourceAsStream(resourcePath)) { + return IOUtils.readLines(stream, Constants.ENCODING); + } + } + + @SuppressWarnings("squid:S1319") //weka requires explicit ArrayList public static Instances createNewInstances(ArrayList<Attribute> attributesList) { Instances instances = new Instances(DATASET_NAME, attributesList, 0); instances.setClassIndex(0); diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/VersionIgnoringObjectInputStream.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/VersionIgnoringObjectInputStream.java index 9d03cd8..fbbb2a9 100644 --- a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/VersionIgnoringObjectInputStream.java +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/VersionIgnoringObjectInputStream.java @@ -8,10 +8,12 @@ import java.io.ObjectStreamClass; public class VersionIgnoringObjectInputStream extends ObjectInputStream { - public VersionIgnoringObjectInputStream(InputStream in) throws IOException { + VersionIgnoringObjectInputStream(InputStream in) throws IOException { super(in); } + @Override + @SuppressWarnings("squid:S1166") protected ObjectStreamClass readClassDescriptor() throws IOException, ClassNotFoundException { ObjectStreamClass resultClassDescriptor = super.readClassDescriptor(); // initially streams descriptor Class localClass; // the class in the local JVM that this descriptor represents. diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureExtractor.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureExtractor.java index 3c80046..8b1a6a9 100644 --- a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureExtractor.java +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureExtractor.java @@ -17,6 +17,7 @@ public class FeatureExtractor { private final Set<String> normalizedAttributes = Sets.newHashSet(); + @SuppressWarnings("squid:S1319") //weka requires explicit ArrayList public ArrayList<Attribute> getAttributesList() { return Lists.newArrayList(sortedAttributes); } @@ -46,15 +47,14 @@ public class FeatureExtractor { protected void fillSortedAttributes(String scoreAttName) { sortedAttributes.addAll(name2attribute.values()); sortedAttributes.remove(getAttributeByName(scoreAttName)); - Collections.sort(sortedAttributes, (o1, o2) -> name2attribute.inverse().get(o1).compareTo(name2attribute.inverse().get(o2))); + sortedAttributes.sort(Comparator.comparing(name2attribute.inverse()::get)); sortedAttributes.add(0, getAttributeByName(scoreAttName)); } protected <T> void addNormalizedAttributeValues(Map<T, Map<Attribute, Double>> entity2attributes) { Map<Attribute, Double> attribute2max = Maps.newHashMap(); Map<Attribute, Double> attribute2min = Maps.newHashMap(); - for (T entity : entity2attributes.keySet()) { - Map<Attribute, Double> entityAttributes = entity2attributes.get(entity); + for (Map<Attribute, Double> entityAttributes : entity2attributes.values()) { for (String attributeName : normalizedAttributes) { Attribute attribute = getAttributeByName(attributeName); Double value = entityAttributes.get(attribute); @@ -66,8 +66,7 @@ public class FeatureExtractor { attribute2min.compute(attribute, (k, v) -> Math.min(v, value)); } } - for (T mention : entity2attributes.keySet()) { - Map<Attribute, Double> entityAttributes = entity2attributes.get(mention); + for (Map<Attribute, Double> entityAttributes : entity2attributes.values()) { for (Attribute attribute : attribute2max.keySet()) { Attribute normalizedAttribute = getAttributeByName(name2attribute.inverse().get(attribute) + "_normalized"); entityAttributes.put(normalizedAttribute, diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java index 6d7d07f..55f3941 100644 --- a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java @@ -174,11 +174,11 @@ public class FeatureHelper { } public boolean isNested(TMention mention) { - return mentions.stream().anyMatch(m -> m.getChildIds().containsAll(mention.getChildIds())); + return mentions.stream().anyMatch(m -> !m.equals(mention) && m.getChildIds().containsAll(mention.getChildIds())); } public boolean isNesting(TMention mention) { - return mentions.stream().anyMatch(m -> mention.getChildIds().containsAll(m.getChildIds())); + return mentions.stream().anyMatch(m -> !m.equals(mention) && mention.getChildIds().containsAll(m.getChildIds())); } public Set<TCoreference> getClusters() { diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/Interpretation.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/Interpretation.java index 3ed81d8..bfff430 100644 --- a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/Interpretation.java +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/Interpretation.java @@ -33,6 +33,7 @@ public class Interpretation { person = split[3]; break; case "siebie": + case "prep": casee = split[0]; break; case "fin": @@ -47,9 +48,6 @@ public class Interpretation { number = split[0]; gender = split[1]; break; - case "prep": - casee = split[0]; - break; default: break; } diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java deleted file mode 100644 index 1372c06..0000000 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java +++ /dev/null @@ -1,47 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.mention; - -import org.apache.commons.lang3.time.StopWatch; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; -import weka.classifiers.Classifier; -import weka.core.Instances; -import weka.core.converters.ArffLoader; - -import java.io.File; -import java.io.FileOutputStream; -import java.io.ObjectOutputStream; - - -public class TrainModel { - private static final Logger LOG = LoggerFactory.getLogger(TrainModel.class); - - public static void main(String[] args) throws Exception { - - ArffLoader loader = new ArffLoader(); - loader.setFile(new File(Constants.MENTIONS_DATASET_PATH)); - Instances instances = loader.getDataSet(); - instances.setClassIndex(0); - LOG.info(instances.size() + " instances loaded."); - LOG.info(instances.numAttributes() + " attributes for each instance."); - - StopWatch watch = new StopWatch(); - watch.start(); - - Classifier classifier = Constants.getMentionClassifier(); - - LOG.info("Building classifier..."); - classifier.buildClassifier(instances); - LOG.info("...done."); - - try (ObjectOutputStream oos = new ObjectOutputStream( - new FileOutputStream(Constants.MENTIONS_MODEL_PATH))) { - oos.writeObject(classifier); - } - - watch.stop(); - LOG.info("Elapsed time: " + watch); - - LOG.info(classifier.toString()); - } -} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java deleted file mode 100644 index 1dece76..0000000 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java +++ /dev/null @@ -1,30 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.mention.test; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; -import pl.waw.ipipan.zil.summ.nicolas.eval.EvalUtils; -import weka.core.Instances; -import weka.core.converters.ArffLoader; - -import java.io.File; - - -public class Crossvalidate { - - private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); - - private Crossvalidate() { - } - - public static void main(String[] args) throws Exception { - ArffLoader loader = new ArffLoader(); - loader.setFile(new File(Constants.MENTIONS_DATASET_PATH)); - Instances instances = loader.getDataSet(); - instances.setClassIndex(0); - LOG.info(instances.size() + " instances loaded."); - LOG.info(instances.numAttributes() + " attributes for each instance."); - - EvalUtils.crossvalidateClassification(instances); - } -} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java deleted file mode 100644 index 8b60024..0000000 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java +++ /dev/null @@ -1,52 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.mention.test; - -import org.apache.commons.lang3.time.StopWatch; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; -import weka.classifiers.Classifier; -import weka.classifiers.evaluation.Evaluation; -import weka.core.Instances; -import weka.core.converters.ArffLoader; - -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.ObjectInputStream; - - -public class Validate { - private static final Logger LOG = LoggerFactory.getLogger(Validate.class); - - public static void main(String[] args) throws Exception { - - ArffLoader loader = new ArffLoader(); - loader.setFile(new File(Constants.MENTIONS_DATASET_PATH)); - Instances instances = loader.getDataSet(); - instances.setClassIndex(0); - LOG.info(instances.size() + " instances loaded."); - LOG.info(instances.numAttributes() + " attributes for each instance."); - - Classifier classifier = loadClassifier(); - - StopWatch watch = new StopWatch(); - watch.start(); - - Evaluation eval = new Evaluation(instances); - eval.evaluateModel(classifier, instances); - - LOG.info(eval.toSummaryString()); - - watch.stop(); - LOG.info("Elapsed time: " + watch); - } - - private static Classifier loadClassifier() throws IOException, ClassNotFoundException { - LOG.info("Loading classifier..."); - try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(Constants.MENTIONS_MODEL_PATH))) { - Classifier classifier = (Classifier) ois.readObject(); - LOG.info("Done. " + classifier.toString()); - return classifier; - } - } -} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java deleted file mode 100644 index 8b3741c..0000000 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java +++ /dev/null @@ -1,47 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.sentence; - -import org.apache.commons.lang3.time.StopWatch; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; -import weka.classifiers.Classifier; -import weka.core.Instances; -import weka.core.converters.ArffLoader; - -import java.io.File; -import java.io.FileOutputStream; -import java.io.ObjectOutputStream; - - -public class TrainModel { - private static final Logger LOG = LoggerFactory.getLogger(TrainModel.class); - - public static void main(String[] args) throws Exception { - - ArffLoader loader = new ArffLoader(); - loader.setFile(new File(Constants.SENTENCES_DATASET_PATH)); - Instances instances = loader.getDataSet(); - instances.setClassIndex(0); - LOG.info(instances.size() + " instances loaded."); - LOG.info(instances.numAttributes() + " attributes for each instance."); - - StopWatch watch = new StopWatch(); - watch.start(); - - Classifier classifier = Constants.getSentencesClassifier(); - - LOG.info("Building classifier..."); - classifier.buildClassifier(instances); - LOG.info("...done."); - - try (ObjectOutputStream oos = new ObjectOutputStream( - new FileOutputStream(Constants.SENTENCES_MODEL_PATH))) { - oos.writeObject(classifier); - } - - watch.stop(); - LOG.info("Elapsed time: " + watch); - - LOG.info(classifier.toString()); - } -} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java deleted file mode 100644 index 09cc621..0000000 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java +++ /dev/null @@ -1,31 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.sentence.test; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; -import pl.waw.ipipan.zil.summ.nicolas.eval.EvalUtils; -import weka.core.Instances; -import weka.core.converters.ArffLoader; - -import java.io.File; - - -public class Crossvalidate { - - private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); - - private Crossvalidate() { - } - - public static void main(String[] args) throws Exception { - - ArffLoader loader = new ArffLoader(); - loader.setFile(new File(Constants.SENTENCES_DATASET_PATH)); - Instances instances = loader.getDataSet(); - instances.setClassIndex(0); - LOG.info(instances.size() + " instances loaded."); - LOG.info(instances.numAttributes() + " attributes for each instance."); - - EvalUtils.crossvalidateRegression(instances); - } -} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/test/Crossvalidate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/test/Crossvalidate.java deleted file mode 100644 index 252ae6e..0000000 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/test/Crossvalidate.java +++ /dev/null @@ -1,31 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.zero.test; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; -import pl.waw.ipipan.zil.summ.nicolas.eval.EvalUtils; -import weka.core.Instances; -import weka.core.converters.ArffLoader; - -import java.io.File; - - -public class Crossvalidate { - - private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); - - private Crossvalidate() { - } - - public static void main(String[] args) throws Exception { - - ArffLoader loader = new ArffLoader(); - loader.setFile(new File(Constants.ZERO_DATASET_PATH)); - Instances instances = loader.getDataSet(); - instances.setClassIndex(0); - LOG.info(instances.size() + " instances loaded."); - LOG.info(instances.numAttributes() + " attributes for each instance."); - - EvalUtils.crossvalidateClassification(instances); - } -} diff --git a/nicolas-core/pom.xml b/nicolas-lib/pom.xml index 4557c3e..5f80d90 100644 --- a/nicolas-core/pom.xml +++ b/nicolas-lib/pom.xml @@ -9,7 +9,7 @@ <version>1.0-SNAPSHOT</version> </parent> - <artifactId>nicolas</artifactId> + <artifactId>nicolas-lib</artifactId> <dependencies> <!-- project --> diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java index 8003c5a..3b5b55a 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java @@ -11,6 +11,7 @@ import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor; import pl.waw.ipipan.zil.summ.nicolas.mention.MentionModel; import pl.waw.ipipan.zil.summ.nicolas.sentence.SentenceFeatureExtractor; import pl.waw.ipipan.zil.summ.nicolas.sentence.SentenceModel; +import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroFeatureExtractor; import weka.classifiers.Classifier; import java.io.IOException; @@ -20,22 +21,27 @@ import static java.util.stream.Collectors.toList; public class Nicolas { - private final Classifier sentenceClassifier; - private final Classifier mentionClassifier; - private final MentionFeatureExtractor featureExtractor; + private final Classifier mentionModel; + private final Classifier sentenceModel; + private final Classifier zeroModel; + + private final MentionFeatureExtractor mentionFeatureExtractor; private final SentenceFeatureExtractor sentenceFeatureExtractor; + private final ZeroFeatureExtractor zeroFeatureExtractor; public Nicolas() throws IOException, ClassNotFoundException { - mentionClassifier = Utils.loadClassifier(Constants.MENTIONS_MODEL_PATH); - featureExtractor = new MentionFeatureExtractor(); + mentionModel = Utils.loadModelFromResource(Constants.MENTION_MODEL_RESOURCE_PATH); + sentenceModel = Utils.loadModelFromResource(Constants.SENTENCE_MODEL_RESOURCE_PATH); + zeroModel = Utils.loadModelFromResource(Constants.ZERO_MODEL_RESOURCE_PATH); - sentenceClassifier = Utils.loadClassifier(Constants.SENTENCES_MODEL_PATH); + mentionFeatureExtractor = new MentionFeatureExtractor(); sentenceFeatureExtractor = new SentenceFeatureExtractor(); + zeroFeatureExtractor = new ZeroFeatureExtractor(); } public String summarizeThrift(TText text, int targetTokenCount) throws Exception { Set<TMention> goodMentions - = MentionModel.detectGoodMentions(mentionClassifier, featureExtractor, text); + = MentionModel.detectGoodMentions(mentionModel, mentionFeatureExtractor, text); return calculateSummary(text, goodMentions, targetTokenCount); } @@ -52,10 +58,10 @@ public class Nicolas { private List<TSentence> selectSummarySentences(TText thrifted, Set<TMention> goodMentions, int targetSize) throws Exception { List<TSentence> sents = thrifted.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList()); - Map<TSentence, Double> sentence2score = SentenceModel.scoreSentences(thrifted, goodMentions, sentenceClassifier, sentenceFeatureExtractor); + Map<TSentence, Double> sentence2score = SentenceModel.scoreSentences(thrifted, goodMentions, sentenceModel, sentenceFeatureExtractor); List<TSentence> sortedSents = Lists.newArrayList(sents); - Collections.sort(sortedSents, Comparator.comparing(sentence2score::get).reversed()); + sortedSents.sort(Comparator.comparing(sentence2score::get).reversed()); int size = 0; Random r = new Random(1); diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/ThriftUtils.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/ThriftUtils.java index 785fad7..9b56c74 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/ThriftUtils.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/ThriftUtils.java @@ -1,22 +1,17 @@ package pl.waw.ipipan.zil.summ.nicolas; -import com.google.common.base.Charsets; import com.google.common.collect.Maps; -import com.google.common.io.Files; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor; -import pl.waw.ipipan.zil.summ.nicolas.mention.MentionScorer; import pl.waw.ipipan.zil.summ.nicolas.sentence.SentenceFeatureExtractor; import weka.core.Attribute; import weka.core.DenseInstance; import weka.core.Instance; -import java.io.File; -import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Set; @@ -30,16 +25,6 @@ public class ThriftUtils { private ThriftUtils() { } - public static Set<TMention> loadGoldGoodMentions(String id, TText text, boolean dev) throws IOException { - String optimalSummary = Files.toString(new File("src/main/resources/optimal_summaries/" + (dev ? "dev" : "test") + "/" + id + "_theoretic_ub_rouge_1.txt"), Charsets.UTF_8); - - MentionScorer scorer = new MentionScorer(); - Map<TMention, Double> mention2score = scorer.calculateMentionScores(optimalSummary, text); - - mention2score.keySet().removeIf(tMention -> mention2score.get(tMention) != 1.0); - return mention2score.keySet(); - } - public static Map<TMention, Instance> extractInstancesFromMentions(TText preprocessedText, MentionFeatureExtractor featureExtractor) { List<TSentence> sentences = preprocessedText.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList()); Map<TMention, Map<Attribute, Double>> mention2features = featureExtractor.calculateFeatures(preprocessedText); diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel2.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel.java index 4554ccc..43c34c7 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel2.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel.java @@ -26,18 +26,18 @@ import java.util.*; import static java.util.stream.Collectors.toList; -public class ApplyModel2 { +public class ApplyModel { - private static final Logger LOG = LoggerFactory.getLogger(ApplyModel2.class); + private static final Logger LOG = LoggerFactory.getLogger(ApplyModel.class); private static final String TEST_PREPROCESSED_DATA_PATH = "corpora/preprocessed_full_texts/test"; private static final String TARGET_DIR = "corpora/summaries"; public static void main(String[] args) throws Exception { - Classifier mentionClassifier = Utils.loadClassifier(Constants.MENTIONS_MODEL_PATH); + Classifier mentionClassifier = Utils.loadClassifier(Constants.MENTION_MODEL_RESOURCE_PATH); MentionFeatureExtractor featureExtractor = new MentionFeatureExtractor(); - Classifier sentenceClassifier = Utils.loadClassifier(Constants.SENTENCES_MODEL_PATH); + Classifier sentenceClassifier = Utils.loadClassifier(Constants.SENTENCE_MODEL_RESOURCE_PATH); SentenceFeatureExtractor sentenceFeatureExtractor = new SentenceFeatureExtractor(); ZeroSubjectInjector zeroSubjectInjector = new ZeroSubjectInjector(); @@ -102,7 +102,7 @@ public class ApplyModel2 { } List<TSentence> sortedSents = Lists.newArrayList(sents); - Collections.sort(sortedSents, Comparator.comparing(sentence2score::get).reversed()); + sortedSents.sort(Comparator.comparing(sentence2score::get).reversed()); int size = 0; Random r = new Random(1); diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java index ec671aa..cada060 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java @@ -1,26 +1,27 @@ package pl.waw.ipipan.zil.summ.nicolas.mention; -import com.google.common.collect.*; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import pl.waw.ipipan.zil.multiservice.thrift.types.*; import pl.waw.ipipan.zil.summ.nicolas.common.Constants; +import pl.waw.ipipan.zil.summ.nicolas.common.Utils; import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureExtractor; import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; import pl.waw.ipipan.zil.summ.nicolas.common.features.Interpretation; import weka.core.Attribute; -import java.io.File; import java.io.IOException; -import java.nio.file.Files; -import java.util.*; +import java.util.List; +import java.util.Map; import java.util.stream.Collectors; -import java.util.stream.Stream; public class MentionFeatureExtractor extends FeatureExtractor { - private final List<String> frequentBases = Lists.newArrayList(); + private final List<String> frequentBases; - public MentionFeatureExtractor() { + public MentionFeatureExtractor() throws IOException { + frequentBases = loadFrequentBases(); //coref addNumericAttributeNormalized("chain_length"); @@ -70,7 +71,6 @@ public class MentionFeatureExtractor extends FeatureExtractor { addBinaryAttribute(prefix + "_sent_ends_with_questionmark"); // frequent bases - loadFrequentBases(); for (String base : frequentBases) { addBinaryAttribute(prefix + "_" + encodeBase(base)); } @@ -80,17 +80,12 @@ public class MentionFeatureExtractor extends FeatureExtractor { fillSortedAttributes("score"); } - private String encodeBase(String base) { - return "base_equal_" + base.replaceAll(" ", "_").replaceAll("\"", "Q"); + private List<String> loadFrequentBases() throws IOException { + return Utils.loadLinesFromResource(Constants.FREQUENT_BASES_RESOURCE_PATH).stream().map(String::trim).sorted().distinct().collect(Collectors.toList()); } - private void loadFrequentBases() { - try { - Stream<String> lines = Files.lines(new File("frequent_bases.txt").toPath()); - this.frequentBases.addAll(lines.map(String::trim).collect(Collectors.toList())); - } catch (IOException e) { - e.printStackTrace(); - } + private String encodeBase(String base) { + return "base_equal_" + base.replaceAll(" ", "_").replaceAll("\"", "Q"); } public Map<TMention, Map<Attribute, Double>> calculateFeatures(TText preprocessedText) { @@ -123,8 +118,6 @@ public class MentionFeatureExtractor extends FeatureExtractor { attribute2value.put(getAttributeByName("text_par_count"), (double) pars.size()); attribute2value.put(getAttributeByName("text_mention_count"), (double) helper.getMentions().size()); attribute2value.put(getAttributeByName("text_cluster_count"), (double) helper.getClusters().size()); - - assert (attribute2value.size() == getAttributesList().size()); } addNormalizedAttributeValues(result); diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionModel.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionModel.java index 3f65c48..3f65c48 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionModel.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionModel.java diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java index 3da019e..c8db783 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java @@ -87,7 +87,6 @@ public class SentenceFeatureExtractor extends FeatureExtractor { feature2value.put(getAttributeByName("score"), weka.core.Utils.missingValue()); feature2value.remove(null); - assert (feature2value.size() == getAttributesList().size()); sentence2features.put(sentence, feature2value); diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceModel.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceModel.java index c9a43d0..c9a43d0 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceModel.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceModel.java diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java index f862b31..f862b31 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java diff --git a/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/InstanceCreator.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/InstanceCreator.java new file mode 100644 index 0000000..8873735 --- /dev/null +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/InstanceCreator.java @@ -0,0 +1,31 @@ +package pl.waw.ipipan.zil.summ.nicolas.zero; + +import com.google.common.collect.Maps; +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; +import weka.core.Attribute; +import weka.core.DenseInstance; +import weka.core.Instance; + +import java.util.List; +import java.util.Map; + +public class InstanceCreator { + + private InstanceCreator() { + } + + public static Map<ZeroSubjectCandidate, Instance> extractInstancesFromZeroCandidates(List<ZeroSubjectCandidate> candidates, TText text, ZeroFeatureExtractor featureExtractor) { + Map<ZeroSubjectCandidate, Map<Attribute, Double>> candidate2features = featureExtractor.calculateFeatures(candidates, text); + Map<ZeroSubjectCandidate, Instance> candidate2instance = Maps.newHashMap(); + for (Map.Entry<ZeroSubjectCandidate, Map<Attribute, Double>> entry : candidate2features.entrySet()) { + Instance instance = new DenseInstance(featureExtractor.getAttributesList().size()); + Map<Attribute, Double> sentenceFeatures = entry.getValue(); + for (Attribute attribute : featureExtractor.getAttributesList()) { + instance.setValue(attribute, sentenceFeatures.get(attribute)); + } + candidate2instance.put(entry.getKey(), instance); + } + return candidate2instance; + } + +} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java index 8111368..d57879d 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java @@ -4,6 +4,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; +import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; import pl.waw.ipipan.zil.multiservice.thrift.types.TToken; import pl.waw.ipipan.zil.summ.nicolas.common.Constants; @@ -18,18 +19,56 @@ import java.util.Set; public class ZeroFeatureExtractor extends FeatureExtractor { + private static final String SCORE = "score"; + + private static final String ANTECEDENT_PREFIX = "antecedent"; + private static final String CANDIDATE_PREFIX = "candidate"; + + private static final String SENTENCE_ENDS_WITH_QUESTION_MARK = "_sentence_ends_with_question_mark"; + private static final String IS_NAMED = "_is_named"; + private static final String TOKEN_COUNT = "_token_count"; + private static final String FIRST_TOKEN_INDEX_IN_SENT = "_first_token_index_in_sent"; + private static final String INDEX_IN_SENT = "_index_in_sent"; + private static final String PREV_TOKEN_POS = "_prev_token_pos"; + private static final String NEXT_TOKEN_POS = "_next_token_pos"; + private static final String IS_NESTING = "_is_nesting"; + private static final String IS_NESTED = "_is_nested"; + private static final String SENTENCE_MENTION_COUNT = "_sentence_mention_count"; + private static final String SENTENCE_TOKEN_LENGTH = "_sentence_token_length"; + private static final String IS_PAN_OR_PANI = "_is_pan_or_pani"; + + // private static final Set<String> PREV_TOKEN_LEMMAS = Sets.newHashSet( +// "zespół", "tylko", "gdy", ".", ":", "też", "kandydat", "do", "dziś", "bo", "by", "z", "a", "jednak", "jak", "który", "ale", "czy", "i", "się", "rok", "-", "\"", "to", "być", "że", ","); + private static final Set<String> PREV_TOKEN_LEMMAS = Sets.newHashSet("to", "z", "do", "o", "czyli", "nie", "\"", "też", "jak", "czy"); + + private static final Set<String> NEXT_TOKEN_LEMMAS = Sets.newHashSet(); +// private static final Set<String> NEXT_TOKEN_LEMMAS = Sets.newHashSet( +// "mówić", "ii", "twierdzić", "już", "(", "budzić", "stanowić", "powinien", "do", "stać", "musieć", "stanąć", "móc", "o", "chcieć", "się", "-", "zostać", ":", "?", "i", "na", "z", "mieć", "\"", "to", "w", "nie", "być", ".", ","); + + private static final String PREV_TOKEN_LEMMA = "_prev_token_lemma_equal_"; + private static final String NEXT_TOKEN_LEMMA = "_next_token_lemma_equal_"; + public ZeroFeatureExtractor() { - for (String prefix : new String[]{"antecedent", "candidate"}) { - addNumericAttribute(prefix + "_index_in_sent"); - addNumericAttribute(prefix + "_first_token_index_in_sent"); - addNumericAttribute(prefix + "_token_count"); - addBinaryAttribute(prefix + "_is_named"); - addNumericAttribute(prefix + "_sentence_mention_count"); - addNominalAttribute(prefix + "_next_token_pos", Constants.POS_TAGS); - addNominalAttribute(prefix + "_prev_token_pos", Constants.POS_TAGS); - addBinaryAttribute(prefix + "_is_nested"); - addBinaryAttribute(prefix + "_is_nesting"); + for (String prefix : new String[]{ANTECEDENT_PREFIX, CANDIDATE_PREFIX}) { + addNumericAttribute(prefix + INDEX_IN_SENT); + addNumericAttribute(prefix + FIRST_TOKEN_INDEX_IN_SENT); + addNumericAttribute(prefix + TOKEN_COUNT); + addBinaryAttribute(prefix + IS_NAMED); + addBinaryAttribute(prefix + IS_PAN_OR_PANI); + addNominalAttribute(prefix + NEXT_TOKEN_POS, Constants.POS_TAGS); + addNominalAttribute(prefix + PREV_TOKEN_POS, Constants.POS_TAGS); + for (String prevLemma : PREV_TOKEN_LEMMAS) { + addBinaryAttribute(prefix + PREV_TOKEN_LEMMA + prevLemma); + } + for (String nextLemma : NEXT_TOKEN_LEMMAS) { + addBinaryAttribute(prefix + NEXT_TOKEN_LEMMA + nextLemma); + } + addBinaryAttribute(prefix + IS_NESTED); + addBinaryAttribute(prefix + IS_NESTING); + addNumericAttribute(prefix + SENTENCE_MENTION_COUNT); + addNumericAttribute(prefix + SENTENCE_TOKEN_LENGTH); + addBinaryAttribute(prefix + SENTENCE_ENDS_WITH_QUESTION_MARK); } addNumericAttribute("chain_length"); @@ -43,8 +82,8 @@ public class ZeroFeatureExtractor extends FeatureExtractor { addNumericAttribute("pair_sent_distance"); addNumericAttribute("pair_par_distance"); - addNominalAttribute("score", Lists.newArrayList("bad", "good")); - fillSortedAttributes("score"); + addNominalAttribute(SCORE, Lists.newArrayList("bad", "good")); + fillSortedAttributes(SCORE); } public Map<ZeroSubjectCandidate, Map<Attribute, Double>> calculateFeatures(List<ZeroSubjectCandidate> candidates, TText text) { @@ -62,13 +101,13 @@ public class ZeroFeatureExtractor extends FeatureExtractor { private Map<Attribute, Double> calculateFeatures(ZeroSubjectCandidate candidate, FeatureHelper helper) { Map<Attribute, Double> candidateFeatures = Maps.newHashMap(); - candidateFeatures.put(getAttributeByName("score"), weka.core.Utils.missingValue()); + candidateFeatures.put(getAttributeByName(SCORE), weka.core.Utils.missingValue()); TMention mention = candidate.getZeroCandidateMention(); TMention antecedent = candidate.getPreviousSentence().getMentions().stream().filter(ante -> helper.getCoreferentMentions(mention).contains(ante)).findFirst().get(); - addMentionFeatures(helper, candidateFeatures, mention, "candidate"); - addMentionFeatures(helper, candidateFeatures, antecedent, "antecedent"); + addMentionFeatures(helper, candidateFeatures, mention, CANDIDATE_PREFIX); + addMentionFeatures(helper, candidateFeatures, antecedent, ANTECEDENT_PREFIX); candidateFeatures.put(getAttributeByName("pair_equal_orth"), toBinary(helper.getMentionOrth(mention).equals(helper.getMentionOrth(antecedent)))); candidateFeatures.put(getAttributeByName("pair_equal_base"), toBinary(helper.getMentionBase(mention).equalsIgnoreCase(helper.getMentionBase(antecedent)))); @@ -98,28 +137,41 @@ public class ZeroFeatureExtractor extends FeatureExtractor { } private void addMentionFeatures(FeatureHelper helper, Map<Attribute, Double> candidateFeatures, TMention mention, String attributePrefix) { - candidateFeatures.put(getAttributeByName(attributePrefix + "_index_in_sent"), (double) helper.getMentionIndexInSent(mention)); - candidateFeatures.put(getAttributeByName(attributePrefix + "_first_token_index_in_sent"), (double) helper.getMentionFirstTokenIndex(mention)); + candidateFeatures.put(getAttributeByName(attributePrefix + INDEX_IN_SENT), (double) helper.getMentionIndexInSent(mention)); + candidateFeatures.put(getAttributeByName(attributePrefix + FIRST_TOKEN_INDEX_IN_SENT), (double) helper.getMentionFirstTokenIndex(mention)); - candidateFeatures.put(getAttributeByName(attributePrefix + "_token_count"), (double) mention.getChildIdsSize()); - candidateFeatures.put(getAttributeByName(attributePrefix + "_is_named"), toBinary(helper.isMentionNamedEntity(mention))); - candidateFeatures.put(getAttributeByName(attributePrefix + "_sentence_mention_count"), (double) helper.getMentionSentence(mention).getMentions().size()); + candidateFeatures.put(getAttributeByName(attributePrefix + TOKEN_COUNT), (double) mention.getChildIdsSize()); + candidateFeatures.put(getAttributeByName(attributePrefix + IS_NAMED), toBinary(helper.isMentionNamedEntity(mention))); + candidateFeatures.put(getAttributeByName(attributePrefix + IS_PAN_OR_PANI), toBinary(helper.getMentionBase(mention).matches("(pan)|(pani)"))); TToken nextToken = helper.getTokenAfterMention(mention); - addNominalAttributeValue(nextToken == null ? "end" : nextToken.getChosenInterpretation().getCtag(), candidateFeatures, attributePrefix + "_next_token_pos"); + addNominalAttributeValue(nextToken == null ? "end" : nextToken.getChosenInterpretation().getCtag(), candidateFeatures, attributePrefix + NEXT_TOKEN_POS); + String nextTokenLemma = nextToken == null ? "" : nextToken.getChosenInterpretation().getBase(); + for (String nextLemma : NEXT_TOKEN_LEMMAS) { + candidateFeatures.put(getAttributeByName(attributePrefix + NEXT_TOKEN_LEMMA + nextLemma), toBinary(nextTokenLemma.equalsIgnoreCase(nextLemma))); + } + TToken prevToken = helper.getTokenBeforeMention(mention); - addNominalAttributeValue(prevToken == null ? "end" : prevToken.getChosenInterpretation().getCtag(), candidateFeatures, attributePrefix + "_prev_token_pos"); + addNominalAttributeValue(prevToken == null ? "end" : prevToken.getChosenInterpretation().getCtag(), candidateFeatures, attributePrefix + PREV_TOKEN_POS); + String prevTokenLemma = prevToken == null ? "" : prevToken.getChosenInterpretation().getBase(); + for (String prevLemma : PREV_TOKEN_LEMMAS) { + candidateFeatures.put(getAttributeByName(attributePrefix + PREV_TOKEN_LEMMA + prevLemma), toBinary(prevTokenLemma.equalsIgnoreCase(prevLemma))); + } - candidateFeatures.put(getAttributeByName(attributePrefix + "_is_nested"), toBinary(helper.isNested(mention))); - candidateFeatures.put(getAttributeByName(attributePrefix + "_is_nesting"), toBinary(helper.isNesting(mention))); + candidateFeatures.put(getAttributeByName(attributePrefix + IS_NESTED), toBinary(helper.isNested(mention))); + candidateFeatures.put(getAttributeByName(attributePrefix + IS_NESTING), toBinary(helper.isNesting(mention))); + TSentence mentionSentence = helper.getMentionSentence(mention); + candidateFeatures.put(getAttributeByName(attributePrefix + SENTENCE_MENTION_COUNT), (double) mentionSentence.getMentions().size()); + candidateFeatures.put(getAttributeByName(attributePrefix + SENTENCE_TOKEN_LENGTH), (double) mentionSentence.getTokens().size()); + candidateFeatures.put(getAttributeByName(attributePrefix + SENTENCE_ENDS_WITH_QUESTION_MARK), toBinary(mentionSentence.getTokens().get(mentionSentence.getTokensSize() - 1).getOrth().equals("?"))); } private void addNominalAttributeValue(String value, Map<Attribute, Double> attribute2value, String attributeName) { Attribute att = getAttributeByName(attributeName); int index = att.indexOfValue(value); if (index == -1) - LOG.warn(value + " not found for attribute " + attributeName); + LOG.warn(value + "not found for attribute " + attributeName); attribute2value.put(att, (double) (index == -1 ? att.indexOfValue("other") : index)); } } diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectCandidate.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectCandidate.java index 6d0a76f..6d0a76f 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectCandidate.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectCandidate.java diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java index 5da90a5..239aff9 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java @@ -8,8 +8,8 @@ import pl.waw.ipipan.zil.summ.nicolas.common.Utils; import weka.classifiers.Classifier; import weka.core.Instance; import weka.core.Instances; +import weka.core.SerializationHelper; -import java.io.IOException; import java.util.List; import java.util.Map; import java.util.Set; @@ -21,8 +21,8 @@ public class ZeroSubjectInjector { private final Classifier classifier; private final Instances instances; - public ZeroSubjectInjector() throws IOException, ClassNotFoundException { - classifier = Utils.loadClassifier(Constants.ZERO_MODEL_PATH); + public ZeroSubjectInjector() throws Exception { + classifier = (Classifier) SerializationHelper.read(Constants.ZERO_MODEL_RESOURCE_PATH); featureExtractor = new ZeroFeatureExtractor(); instances = Utils.createNewInstances(featureExtractor.getAttributesList()); } @@ -31,7 +31,7 @@ public class ZeroSubjectInjector { Set<String> summarySentenceIds = selectedSentences.stream().map(TSentence::getId).collect(Collectors.toSet()); List<ZeroSubjectCandidate> zeroSubjectCandidates = CandidateFinder.findZeroSubjectCandidates(text, summarySentenceIds); Map<ZeroSubjectCandidate, Instance> candidate2instance = - PrepareTrainingData.extractInstancesFromZeroCandidates(zeroSubjectCandidates, text, featureExtractor); + InstanceCreator.extractInstancesFromZeroCandidates(zeroSubjectCandidates, text, featureExtractor); Set<String> result = Sets.newHashSet(); for (Map.Entry<ZeroSubjectCandidate, Instance> entry : candidate2instance.entrySet()) { diff --git a/nicolas-lib/src/test/java/pl/waw/ipipan/zil/summ/nicolas/NicolasTest.java b/nicolas-lib/src/test/java/pl/waw/ipipan/zil/summ/nicolas/NicolasTest.java new file mode 100644 index 0000000..9dae6f4 --- /dev/null +++ b/nicolas-lib/src/test/java/pl/waw/ipipan/zil/summ/nicolas/NicolasTest.java @@ -0,0 +1,30 @@ +package pl.waw.ipipan.zil.summ.nicolas; + +import org.junit.BeforeClass; +import org.junit.Test; +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; +import pl.waw.ipipan.zil.summ.nicolas.common.Utils; + +import static org.junit.Assert.assertTrue; + +public class NicolasTest { + + private static final String SAMPLE_THRIFT_TEXT_RESOURCE_PATH = "/pl/waw/ipipan/zil/summ/nicolas/sample_serialized_text.thrift"; + + private static Nicolas nicolas; + + @BeforeClass + public static void shouldLoadModels() throws Exception { + nicolas = new Nicolas(); + } + + @Test + public void shouldSummarizeThriftText() throws Exception { + TText thriftText = Utils.loadThriftTextFromResource(SAMPLE_THRIFT_TEXT_RESOURCE_PATH); + String summary = nicolas.summarizeThrift(thriftText, 5); + int summaryTokensCount = Utils.tokenizeOnWhitespace(summary).size(); + assertTrue(summaryTokensCount > 0); + assertTrue(summaryTokensCount < 10); + } + +} \ No newline at end of file diff --git a/nicolas-core/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java b/nicolas-lib/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java index 4ab4ee2..992ff2d 100644 --- a/nicolas-core/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java +++ b/nicolas-lib/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java @@ -18,7 +18,7 @@ import static org.junit.Assert.assertEquals; public class CandidateFinderTest { - private static final String SAMPLE_TEXT_PATH = "/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin"; + private static final String SAMPLE_TEXT_PATH = "/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.thrift"; private static final String SAMPLE_TEXT_SUMMARY_IDS_PATH = "/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt"; @Test diff --git a/nicolas-core/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin b/nicolas-lib/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/sample_serialized_text.thrift index e30b245..e30b245 100644 --- a/nicolas-core/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin +++ b/nicolas-lib/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/sample_serialized_text.thrift diff --git a/nicolas-lib/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.thrift b/nicolas-lib/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.thrift new file mode 100644 index 0000000..e30b245 --- /dev/null +++ b/nicolas-lib/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.thrift diff --git a/nicolas-core/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt b/nicolas-lib/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt index 10ac642..10ac642 100644 --- a/nicolas-core/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt +++ b/nicolas-lib/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt diff --git a/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/models/.gitignore b/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/models/.gitignore new file mode 100644 index 0000000..f3ac583 --- /dev/null +++ b/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/models/.gitignore @@ -0,0 +1 @@ +*.bin \ No newline at end of file diff --git a/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/models/README.md b/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/models/README.md new file mode 100644 index 0000000..d8e0bae --- /dev/null +++ b/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/models/README.md @@ -0,0 +1 @@ +To generate models in this folder, use nicolas-trainer module. \ No newline at end of file diff --git a/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/frequent_bases.txt b/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/resources/frequent_bases.txt index 973881a..973881a 100644 --- a/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/frequent_bases.txt +++ b/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/resources/frequent_bases.txt diff --git a/nicolas-train/pom.xml b/nicolas-train/pom.xml index 62ae3a7..6d71d47 100644 --- a/nicolas-train/pom.xml +++ b/nicolas-train/pom.xml @@ -12,6 +12,16 @@ <artifactId>nicolas-train</artifactId> <dependencies> + <!-- project --> + <dependency> + <groupId>pl.waw.ipipan.zil.summ</groupId> + <artifactId>nicolas-common</artifactId> + </dependency> + <dependency> + <groupId>pl.waw.ipipan.zil.summ</groupId> + <artifactId>nicolas-lib</artifactId> + </dependency> + <!-- internal --> <dependency> <groupId>pl.waw.ipipan.zil.summ</groupId> @@ -22,10 +32,28 @@ <artifactId>utils</artifactId> </dependency> + <!-- third party --> + <dependency> + <groupId>nz.ac.waikato.cms.weka</groupId> + <artifactId>weka-dev</artifactId> + </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-lang3</artifactId> + </dependency> + <dependency> + <groupId>net.lingala.zip4j</groupId> + <artifactId>zip4j</artifactId> + </dependency> + <!-- logging --> <dependency> <groupId>org.slf4j</groupId> <artifactId>slf4j-api</artifactId> </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-simple</artifactId> + </dependency> </dependencies> </project> \ No newline at end of file diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/DownloadAndPreprocessCorpus.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/DownloadAndPreprocessCorpus.java new file mode 100644 index 0000000..439a33b --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/DownloadAndPreprocessCorpus.java @@ -0,0 +1,60 @@ +package pl.waw.ipipan.zil.summ.nicolas.train; + +import net.lingala.zip4j.core.ZipFile; +import org.apache.commons.io.FileUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import pl.waw.ipipan.zil.summ.nicolas.train.multiservice.NLPProcess; + +import java.io.File; +import java.net.URL; + +public class DownloadAndPreprocessCorpus { + + private static final Logger LOG = LoggerFactory.getLogger(DownloadAndPreprocessCorpus.class); + + private static final String WORKING_DIR = "data"; + private static final String CORPUS_DOWNLOAD_URL = "http://zil.ipipan.waw.pl/PolishSummariesCorpus?action=AttachFile&do=get&target=PSC_1.0.zip"; + + private DownloadAndPreprocessCorpus() { + } + + public static void main(String[] args) throws Exception { + File workDir = createFolder(WORKING_DIR); + + File corpusFile = new File(workDir, "corpus.zip"); + if (!corpusFile.exists()) { + LOG.info("Downloading corpus file..."); + FileUtils.copyURLToFile(new URL(CORPUS_DOWNLOAD_URL), corpusFile); + LOG.info("done."); + } else { + LOG.info("Corpus file already downloaded."); + } + + File extractedCorpusDir = new File(workDir, "corpus"); + if (extractedCorpusDir.exists()) { + LOG.info("Corpus file already extracted."); + } else { + ZipFile zipFile = new ZipFile(corpusFile); + zipFile.extractAll(extractedCorpusDir.getPath()); + LOG.info("Extracted corpus file."); + } + + File pscDir = new File(extractedCorpusDir, "PSC_1.0"); + File dataDir = new File(pscDir, "data"); + + File preprocessed = new File(WORKING_DIR, "preprocessed"); + createFolder(preprocessed.getPath()); + NLPProcess.main(new String[]{dataDir.getPath(), preprocessed.getPath()}); + } + + private static File createFolder(String path) { + File folder = new File(path); + if (folder.mkdir()) { + LOG.info("Created directory at: {}.", path); + } else { + LOG.info("Directory already present at: {}.", path); + } + return folder; + } +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/TrainAllModels.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/TrainAllModels.java new file mode 100644 index 0000000..b736a93 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/TrainAllModels.java @@ -0,0 +1,17 @@ +package pl.waw.ipipan.zil.summ.nicolas.train; + +import pl.waw.ipipan.zil.summ.nicolas.train.model.mention.TrainMentionModel; +import pl.waw.ipipan.zil.summ.nicolas.train.model.sentence.TrainSentenceModel; +import pl.waw.ipipan.zil.summ.nicolas.train.model.zero.TrainZeroModel; + +public class TrainAllModels { + + private TrainAllModels() { + } + + public static void main(String[] args) throws Exception { + TrainMentionModel.main(args); + TrainSentenceModel.main(args); + TrainZeroModel.main(args); + } +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Trainer.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Trainer.java deleted file mode 100644 index c4b4d7c..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Trainer.java +++ /dev/null @@ -1,8 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train; - -public class Trainer { - - public static void main(String[] args) { - - } -} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/ModelConstants.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/ModelConstants.java new file mode 100644 index 0000000..a7087d3 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/ModelConstants.java @@ -0,0 +1,43 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.model.common; + +import weka.classifiers.Classifier; +import weka.classifiers.trees.RandomForest; + +public class ModelConstants { + + public static final String MENTION_DATASET_PATH = "mentions_train.arff"; + public static final String SENTENCE_DATASET_PATH = "sentences_train.arff"; + public static final String ZERO_DATASET_PATH = "zeros_train.arff"; + + private static final int NUM_ITERATIONS = 16; + private static final int NUM_EXECUTION_SLOTS = 8; + private static final int SEED = 0; + + private ModelConstants() { + } + + public static Classifier getMentionClassifier() { + RandomForest classifier = new RandomForest(); + classifier.setNumIterations(NUM_ITERATIONS); + classifier.setSeed(SEED); + classifier.setNumExecutionSlots(NUM_EXECUTION_SLOTS); + return classifier; + } + + public static Classifier getSentenceClassifier() { + RandomForest classifier = new RandomForest(); + classifier.setNumIterations(16); + classifier.setSeed(0); + classifier.setNumExecutionSlots(8); + return classifier; + } + + public static Classifier getZeroClassifier() { + RandomForest classifier = new RandomForest(); + classifier.setNumIterations(16); + classifier.setSeed(0); + classifier.setNumExecutionSlots(8); + return classifier; + } + +} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/TrainModel.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/TrainModelCommon.java index 77c5a30..9a0ae09 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/TrainModel.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/TrainModelCommon.java @@ -1,9 +1,9 @@ -package pl.waw.ipipan.zil.summ.nicolas.zero; +package pl.waw.ipipan.zil.summ.nicolas.train.model.common; import org.apache.commons.lang3.time.StopWatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; +import pl.waw.ipipan.zil.summ.nicolas.train.model.zero.TrainZeroModel; import weka.classifiers.Classifier; import weka.core.Instances; import weka.core.converters.ArffLoader; @@ -11,41 +11,43 @@ import weka.core.converters.ArffLoader; import java.io.File; import java.io.FileOutputStream; import java.io.ObjectOutputStream; +import java.util.logging.LogManager; +@SuppressWarnings("squid:S2118") +public class TrainModelCommon { -public class TrainModel { + private static final Logger LOG = LoggerFactory.getLogger(TrainZeroModel.class); - private static final Logger LOG = LoggerFactory.getLogger(TrainModel.class); + private static final String TARGET_MODEL_DIR = "nicolas-model/src/main/resources"; - private TrainModel() { + private TrainModelCommon() { } - public static void main(String[] args) throws Exception { + public static void trainAndSaveModel(String datasetPath, Classifier classifier, String targetPath) throws Exception { + LogManager.getLogManager().reset(); // disable WEKA logging ArffLoader loader = new ArffLoader(); - loader.setFile(new File(Constants.ZERO_DATASET_PATH)); + loader.setFile(new File(datasetPath)); Instances instances = loader.getDataSet(); instances.setClassIndex(0); - LOG.info(instances.size() + " instances loaded."); - LOG.info(instances.numAttributes() + " attributes for each instance."); + LOG.info("{} instances loaded.", instances.size()); + LOG.info("{} attributes for each instance.", instances.numAttributes()); StopWatch watch = new StopWatch(); watch.start(); - Classifier classifier = Constants.getZerosClassifier(); - LOG.info("Building classifier..."); classifier.buildClassifier(instances); - LOG.info("...done."); + LOG.info("...done. Build classifier: {}", classifier); + String target = TARGET_MODEL_DIR + targetPath; + LOG.info("Saving classifier at: {}", target); try (ObjectOutputStream oos = new ObjectOutputStream( - new FileOutputStream(Constants.ZERO_MODEL_PATH))) { + new FileOutputStream(target))) { oos.writeObject(classifier); } watch.stop(); - LOG.info("Elapsed time: " + watch); - - LOG.info(classifier.toString()); + LOG.info("Elapsed time: {}", watch); } } diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionScorer.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/MentionScorer.java index 9180ac4..29eaa5f 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionScorer.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/MentionScorer.java @@ -1,4 +1,4 @@ -package pl.waw.ipipan.zil.summ.nicolas.mention; +package pl.waw.ipipan.zil.summ.nicolas.train.model.mention; import com.google.common.collect.HashMultiset; import com.google.common.collect.Maps; @@ -14,7 +14,6 @@ import java.util.stream.Collectors; public class MentionScorer { - public Map<TMention, Double> calculateMentionScores(String optimalSummary, TText text) { Multiset<String> tokenCounts = HashMultiset.create(Utils.tokenize(optimalSummary.toLowerCase())); @@ -39,20 +38,4 @@ public class MentionScorer { } return mention2score; } - - private static Map<TMention, Double> booleanTokenInclusion(Map<TMention, String> mention2Orth, Multiset<String> tokenCounts) { - Map<TMention, Double> mention2score = Maps.newHashMap(); - for (Map.Entry<TMention, String> entry : mention2Orth.entrySet()) { - TMention mention = entry.getKey(); - String mentionOrth = mention2Orth.get(mention); - int present = 0; - for (String token : Utils.tokenize(mentionOrth)) { - if (tokenCounts.contains(token.toLowerCase())) { - present++; - } - } - mention2score.putIfAbsent(mention, ((present * 2) >= Utils.tokenize(mentionOrth).size()) ? 1.0 : 0.0); - } - return mention2score; - } } diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/PrepareTrainingData.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/PrepareTrainingData.java index 13f606a..7a6f6b5 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/PrepareTrainingData.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/PrepareTrainingData.java @@ -1,4 +1,4 @@ -package pl.waw.ipipan.zil.summ.nicolas.mention; +package pl.waw.ipipan.zil.summ.nicolas.train.model.mention; import com.google.common.base.Charsets; import com.google.common.collect.Maps; @@ -7,9 +7,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.ThriftUtils; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.common.Utils; +import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor; +import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; import weka.core.Instance; import weka.core.Instances; import weka.core.converters.ArffSaver; @@ -23,8 +25,11 @@ public class PrepareTrainingData { private static final Logger LOG = LoggerFactory.getLogger(PrepareTrainingData.class); - public static final String PREPROCESSED_FULL_TEXTS_DIR_PATH = "src/main/resources/preprocessed_full_texts/dev"; - public static final String OPTIMAL_SUMMARIES_DIR_PATH = "src/main/resources/optimal_summaries/dev"; + private static final String PREPROCESSED_FULL_TEXTS_DIR_PATH = "src/main/resources/preprocessed_full_texts/dev"; + private static final String OPTIMAL_SUMMARIES_DIR_PATH = "src/main/resources/optimal_summaries/dev"; + + private PrepareTrainingData() { + } public static void main(String[] args) throws IOException { @@ -37,19 +42,20 @@ public class PrepareTrainingData { Instances instances = Utils.createNewInstances(featureExtractor.getAttributesList()); int i = 1; - for (String textId : id2preprocessedText.keySet()) { + for (Map.Entry<String, TText> entry : id2preprocessedText.entrySet()) { LOG.info(i++ + "/" + id2preprocessedText.size()); - TText preprocessedText = id2preprocessedText.get(textId); - String optimalSummary = id2optimalSummary.get(textId); + String id = entry.getKey(); + TText preprocessedText = entry.getValue(); + String optimalSummary = id2optimalSummary.get(id); if (optimalSummary == null) continue; Map<TMention, Double> mention2score = mentionScorer.calculateMentionScores(optimalSummary, preprocessedText); Map<TMention, Instance> mention2instance = ThriftUtils.extractInstancesFromMentions(preprocessedText, featureExtractor); - for (Map.Entry<TMention, Instance> entry : mention2instance.entrySet()) { - TMention mention = entry.getKey(); - Instance instance = entry.getValue(); + for (Map.Entry<TMention, Instance> entry2 : mention2instance.entrySet()) { + TMention mention = entry2.getKey(); + Instance instance = entry2.getValue(); instance.setDataset(instances); instance.setClassValue(mention2score.get(mention)); instances.add(instance); @@ -61,7 +67,7 @@ public class PrepareTrainingData { private static void saveInstancesToFile(Instances instances) throws IOException { ArffSaver saver = new ArffSaver(); saver.setInstances(instances); - saver.setFile(new File(Constants.MENTIONS_DATASET_PATH)); + saver.setFile(new File(ModelConstants.MENTION_DATASET_PATH)); saver.writeBatch(); } diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/TrainMentionModel.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/TrainMentionModel.java new file mode 100644 index 0000000..c9104c5 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/TrainMentionModel.java @@ -0,0 +1,20 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.model.mention; + +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; +import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; +import pl.waw.ipipan.zil.summ.nicolas.train.model.common.TrainModelCommon; +import weka.classifiers.Classifier; + +public class TrainMentionModel { + + private TrainMentionModel() { + } + + public static void main(String[] args) throws Exception { + Classifier classifier = ModelConstants.getMentionClassifier(); + String datasetPath = ModelConstants.MENTION_DATASET_PATH; + String targetPath = Constants.MENTION_MODEL_RESOURCE_PATH; + TrainModelCommon.trainAndSaveModel(datasetPath, classifier, targetPath); + } + +} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/PrepareTrainingData.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/PrepareTrainingData.java index 31fa380..a892620 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/PrepareTrainingData.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/PrepareTrainingData.java @@ -1,4 +1,4 @@ -package pl.waw.ipipan.zil.summ.nicolas.sentence; +package pl.waw.ipipan.zil.summ.nicolas.train.model.sentence; import com.google.common.base.Charsets; import com.google.common.collect.Maps; @@ -8,11 +8,13 @@ import org.slf4j.LoggerFactory; import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.ThriftUtils; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.common.Utils; import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor; import pl.waw.ipipan.zil.summ.nicolas.mention.MentionModel; +import pl.waw.ipipan.zil.summ.nicolas.sentence.SentenceFeatureExtractor; +import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; import weka.classifiers.Classifier; import weka.core.Instance; import weka.core.Instances; @@ -31,6 +33,9 @@ public class PrepareTrainingData { private static final String PREPROCESSED_FULL_TEXTS_DIR_PATH = "src/main/resources/preprocessed_full_texts/dev"; private static final String OPTIMAL_SUMMARIES_DIR_PATH = "src/main/resources/optimal_summaries/dev"; + private PrepareTrainingData() { + } + public static void main(String[] args) throws Exception { Map<String, TText> id2preprocessedText = Utils.loadPreprocessedTexts(PREPROCESSED_FULL_TEXTS_DIR_PATH); @@ -41,7 +46,7 @@ public class PrepareTrainingData { Instances instances = Utils.createNewInstances(featureExtractor.getAttributesList()); - Classifier classifier = Utils.loadClassifier(Constants.MENTIONS_MODEL_PATH); + Classifier classifier = Utils.loadClassifier(Constants.MENTION_MODEL_RESOURCE_PATH); MentionFeatureExtractor mentionFeatureExtractor = new MentionFeatureExtractor(); int i = 1; @@ -74,7 +79,7 @@ public class PrepareTrainingData { private static void saveInstancesToFile(Instances instances) throws IOException { ArffSaver saver = new ArffSaver(); saver.setInstances(instances); - saver.setFile(new File(Constants.SENTENCES_DATASET_PATH)); + saver.setFile(new File(ModelConstants.SENTENCE_DATASET_PATH)); saver.writeBatch(); } diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceScorer.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/SentenceScorer.java index e53ffa7..ef985a5 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceScorer.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/SentenceScorer.java @@ -1,4 +1,4 @@ -package pl.waw.ipipan.zil.summ.nicolas.sentence; +package pl.waw.ipipan.zil.summ.nicolas.train.model.sentence; import com.google.common.collect.HashMultiset; import com.google.common.collect.Maps; diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/TrainSentenceModel.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/TrainSentenceModel.java new file mode 100644 index 0000000..b79e6ca --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/TrainSentenceModel.java @@ -0,0 +1,20 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.model.sentence; + +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; +import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; +import pl.waw.ipipan.zil.summ.nicolas.train.model.common.TrainModelCommon; +import weka.classifiers.Classifier; + +public class TrainSentenceModel { + + private TrainSentenceModel() { + } + + public static void main(String[] args) throws Exception { + Classifier classifier = ModelConstants.getSentenceClassifier(); + String datasetPath = ModelConstants.SENTENCE_DATASET_PATH; + String targetPath = Constants.SENTENCE_MODEL_RESOURCE_PATH; + TrainModelCommon.trainAndSaveModel(datasetPath, classifier, targetPath); + } + +} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/PrepareTrainingData.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/PrepareTrainingData.java index 38fb018..4cb918f 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/PrepareTrainingData.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/PrepareTrainingData.java @@ -1,4 +1,4 @@ -package pl.waw.ipipan.zil.summ.nicolas.zero; +package pl.waw.ipipan.zil.summ.nicolas.train.model.zero; import com.google.common.collect.Maps; import com.google.common.collect.Sets; @@ -6,11 +6,13 @@ import org.apache.commons.io.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.common.Utils; import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; -import weka.core.Attribute; -import weka.core.DenseInstance; +import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; +import pl.waw.ipipan.zil.summ.nicolas.zero.CandidateFinder; +import pl.waw.ipipan.zil.summ.nicolas.zero.InstanceCreator; +import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroFeatureExtractor; +import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate; import weka.core.Instance; import weka.core.Instances; import weka.core.converters.ArffSaver; @@ -54,7 +56,7 @@ public class PrepareTrainingData { FeatureHelper featureHelper = new FeatureHelper(text); List<ZeroSubjectCandidate> zeroSubjectCandidates = CandidateFinder.findZeroSubjectCandidates(text, sentenceIds); - Map<ZeroSubjectCandidate, Instance> candidate2instance = extractInstancesFromZeroCandidates(zeroSubjectCandidates, text, featureExtractor); + Map<ZeroSubjectCandidate, Instance> candidate2instance = InstanceCreator.extractInstancesFromZeroCandidates(zeroSubjectCandidates, text, featureExtractor); for (Map.Entry<ZeroSubjectCandidate, Instance> entry2 : candidate2instance.entrySet()) { boolean good = zeroScorer.isValidCandidate(entry2.getKey(), featureHelper); @@ -68,24 +70,11 @@ public class PrepareTrainingData { saveInstancesToFile(instances); } - public static Map<ZeroSubjectCandidate, Instance> extractInstancesFromZeroCandidates(List<ZeroSubjectCandidate> candidates, TText text, ZeroFeatureExtractor featureExtractor) { - Map<ZeroSubjectCandidate, Map<Attribute, Double>> candidate2features = featureExtractor.calculateFeatures(candidates, text); - Map<ZeroSubjectCandidate, Instance> candidate2instance = Maps.newHashMap(); - for (Map.Entry<ZeroSubjectCandidate, Map<Attribute, Double>> entry : candidate2features.entrySet()) { - Instance instance = new DenseInstance(featureExtractor.getAttributesList().size()); - Map<Attribute, Double> sentenceFeatures = entry.getValue(); - for (Attribute attribute : featureExtractor.getAttributesList()) { - instance.setValue(attribute, sentenceFeatures.get(attribute)); - } - candidate2instance.put(entry.getKey(), instance); - } - return candidate2instance; - } private static void saveInstancesToFile(Instances instances) throws IOException { ArffSaver saver = new ArffSaver(); saver.setInstances(instances); - saver.setFile(new File(Constants.ZERO_DATASET_PATH)); + saver.setFile(new File(ModelConstants.ZERO_DATASET_PATH)); saver.writeBatch(); } diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/TrainZeroModel.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/TrainZeroModel.java new file mode 100644 index 0000000..7770a29 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/TrainZeroModel.java @@ -0,0 +1,20 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.model.zero; + +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; +import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; +import pl.waw.ipipan.zil.summ.nicolas.train.model.common.TrainModelCommon; +import weka.classifiers.Classifier; + +public class TrainZeroModel { + + private TrainZeroModel() { + } + + public static void main(String[] args) throws Exception { + Classifier classifier = ModelConstants.getZeroClassifier(); + String datasetPath = ModelConstants.ZERO_DATASET_PATH; + String targetPath = Constants.ZERO_MODEL_RESOURCE_PATH; + TrainModelCommon.trainAndSaveModel(datasetPath, classifier, targetPath); + } + +} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroScorer.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/ZeroScorer.java index f34183b..495ca21 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroScorer.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/ZeroScorer.java @@ -1,4 +1,4 @@ -package pl.waw.ipipan.zil.summ.nicolas.zero; +package pl.waw.ipipan.zil.summ.nicolas.train.model.zero; import com.google.common.collect.Maps; import org.apache.commons.csv.CSVFormat; @@ -7,6 +7,7 @@ import org.apache.commons.csv.CSVRecord; import org.apache.commons.csv.QuoteMode; import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; +import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate; import java.io.IOException; import java.io.InputStream; diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/multiservice/NLPProcess.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/multiservice/NLPProcess.java index bef8a7c..2922942 100644 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/multiservice/NLPProcess.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/multiservice/NLPProcess.java @@ -24,6 +24,9 @@ public class NLPProcess { private static final MultiserviceProxy MSPROXY = new MultiserviceProxy(HOST, PORT); + private static final String CORPUS_FILE_SUFFIX = ".xml"; + private static final String OUTPUT_FILE_SUFFIX = ".thrift"; + private NLPProcess() { } @@ -34,23 +37,27 @@ public class NLPProcess { } File corpusDir = new File(args[0]); if (!corpusDir.isDirectory()) { - LOG.error("Corpus directory does not exist: " + corpusDir); + LOG.error("Corpus directory does not exist: {}", corpusDir); return; } File targetDir = new File(args[1]); if (!targetDir.isDirectory()) { - LOG.error("Target directory does not exist: " + targetDir); + LOG.error("Target directory does not exist: {}", targetDir); return; } int ok = 0; int err = 0; - File[] files = corpusDir.listFiles(f -> f.getName().endsWith(".xml")); + File[] files = corpusDir.listFiles(f -> f.getName().endsWith(CORPUS_FILE_SUFFIX)); + if (files == null || files.length == 0) { + LOG.error("No corpus files found at: {}", corpusDir); + return; + } Arrays.sort(files); for (File file : files) { try { Text text = PSC_IO.readText(file); - File targetFile = new File(targetDir, file.getName().replaceFirst(".xml$", ".bin")); + File targetFile = new File(targetDir, file.getName().replaceFirst(CORPUS_FILE_SUFFIX + "$", OUTPUT_FILE_SUFFIX)); annotateNLP(text, targetFile); ok++; } catch (Exception e) { @@ -58,8 +65,8 @@ public class NLPProcess { LOG.error("Problem with text in " + file + ", " + e); } } - LOG.info(ok + " texts processed successfully."); - LOG.info(err + " texts with errors."); + LOG.info("{} texts processed successfully.", ok); + LOG.info("{} texts with errors.", err); } private static void annotateNLP(Text text, File targetFile) throws Exception { @@ -77,8 +84,8 @@ public class NLPProcess { } public static void serialize(TText ttext, File targetFile) throws IOException { - try (FileOutputStream fout = new FileOutputStream(targetFile); - ObjectOutputStream oos = new ObjectOutputStream(fout)) { + try (FileOutputStream fileOutputStream = new FileOutputStream(targetFile); + ObjectOutputStream oos = new ObjectOutputStream(fileOutputStream)) { oos.writeObject(ttext); } } diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/EvalUtils.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/search/CrossvalidateCommon.java index d0f79fc..b0239df 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/EvalUtils.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/search/CrossvalidateCommon.java @@ -1,4 +1,4 @@ -package pl.waw.ipipan.zil.summ.nicolas.eval; +package pl.waw.ipipan.zil.summ.nicolas.train.search; import org.apache.commons.lang3.time.StopWatch; import org.apache.commons.lang3.tuple.Pair; @@ -14,6 +14,7 @@ import weka.classifiers.functions.SimpleLogistic; import weka.classifiers.lazy.IBk; import weka.classifiers.lazy.KStar; import weka.classifiers.lazy.LWL; +import weka.classifiers.meta.AttributeSelectedClassifier; import weka.classifiers.rules.DecisionTable; import weka.classifiers.rules.JRip; import weka.classifiers.rules.PART; @@ -23,21 +24,49 @@ import weka.classifiers.trees.J48; import weka.classifiers.trees.LMT; import weka.classifiers.trees.RandomForest; import weka.core.Instances; +import weka.core.converters.ArffLoader; +import java.io.File; +import java.io.IOException; import java.util.Arrays; import java.util.Comparator; import java.util.Optional; import java.util.Random; +import java.util.logging.LogManager; -public class EvalUtils { - private static final Logger LOG = LoggerFactory.getLogger(EvalUtils.class); - public static final int NUM_FOLDS = 10; +class CrossvalidateCommon { - private EvalUtils() { + private static final Logger LOG = LoggerFactory.getLogger(CrossvalidateCommon.class); + + private static final int NUM_FOLDS = 10; + + private CrossvalidateCommon() { + } + + static void crossvalidateClassifiers(String datasetPath) throws IOException { + Instances instances = loadInstances(datasetPath); + crossvalidateClassification(instances); + } + + static void crossvalidateRegressors(String datasetPath) throws IOException { + Instances instances = loadInstances(datasetPath); + crossvalidateRegression(instances); } - public static void crossvalidateClassification(Instances instances) throws Exception { + private static Instances loadInstances(String datasetPath) throws IOException { + LogManager.getLogManager().reset(); // disable WEKA logging + + ArffLoader loader = new ArffLoader(); + loader.setFile(new File(datasetPath)); + Instances instances = loader.getDataSet(); + instances.setClassIndex(0); + LOG.info("{} instances loaded.", instances.size()); + LOG.info("{} attributes for each instance.", instances.numAttributes()); + return instances; + } + + private static void crossvalidateClassification(Instances instances) throws IOException { StopWatch watch = new StopWatch(); watch.start(); @@ -45,52 +74,58 @@ public class EvalUtils { new Logistic(), new ZeroR(), new SimpleLogistic(), new BayesNet(), new NaiveBayes(), new KStar(), new IBk(), new LWL(), - new DecisionTable(), new JRip(), new PART()}).parallel().map(cls -> { - Evaluation eval = null; + new DecisionTable(), new JRip(), new PART(), + createAttributeSelectedClassifier()}).parallel().map(cls -> { + String name = cls.getClass().getSimpleName(); + double acc = 0; + Evaluation eval; try { eval = new Evaluation(instances); eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(1)); } catch (Exception e) { - e.printStackTrace(); + LOG.error("Error evaluating model", e); + return Pair.of(0.0, name); } - double acc = eval.correct() / eval.numInstances(); - String name = cls.getClass().getSimpleName(); + acc = eval.correct() / eval.numInstances(); LOG.info(name + " : " + acc); - return Pair.of(acc, name); }).max(Comparator.comparingDouble(Pair::getLeft)); LOG.info("#########"); LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); watch.stop(); - LOG.info("Elapsed time: " + watch); + LOG.info("Elapsed time: {}", watch); + } + + + private static Classifier createAttributeSelectedClassifier() { + AttributeSelectedClassifier attributeSelectedClassifier = new AttributeSelectedClassifier(); + attributeSelectedClassifier.setClassifier(new LMT()); + return attributeSelectedClassifier; } - public static void crossvalidateRegression(Instances instances) { + private static void crossvalidateRegression(Instances instances) { StopWatch watch = new StopWatch(); watch.start(); Optional<Pair<Double, String>> max = Arrays.stream(new Classifier[]{ new RandomForest(), new LinearRegression(), new KStar()}).parallel().map(cls -> { - Evaluation eval = null; double acc = 0; + String name = cls.getClass().getSimpleName(); try { - eval = new Evaluation(instances); + Evaluation eval = new Evaluation(instances); eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(1)); acc = eval.correlationCoefficient(); - } catch (Exception e) { - e.printStackTrace(); + LOG.error("Error evaluating model", e); } - String name = cls.getClass().getSimpleName(); LOG.info(name + " : " + acc); - return Pair.of(acc, name); }).max(Comparator.comparingDouble(Pair::getLeft)); LOG.info("#########"); LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); watch.stop(); - LOG.info("Elapsed time: " + watch); + LOG.info("Elapsed time: {}", watch); } -} \ No newline at end of file +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/search/CrossvalidateMention.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/search/CrossvalidateMention.java new file mode 100644 index 0000000..4d25877 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/search/CrossvalidateMention.java @@ -0,0 +1,14 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.search; + +import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; + + +public class CrossvalidateMention { + + private CrossvalidateMention() { + } + + public static void main(String[] args) throws Exception { + CrossvalidateCommon.crossvalidateClassifiers(ModelConstants.MENTION_DATASET_PATH); + } +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/search/CrossvalidateSentence.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/search/CrossvalidateSentence.java new file mode 100644 index 0000000..66003e0 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/search/CrossvalidateSentence.java @@ -0,0 +1,14 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.search; + +import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; + + +public class CrossvalidateSentence { + + private CrossvalidateSentence() { + } + + public static void main(String[] args) throws Exception { + CrossvalidateCommon.crossvalidateRegressors(ModelConstants.SENTENCE_DATASET_PATH); + } +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/search/CrossvalidateZero.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/search/CrossvalidateZero.java new file mode 100644 index 0000000..f7f1276 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/search/CrossvalidateZero.java @@ -0,0 +1,14 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.search; + +import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; + + +public class CrossvalidateZero { + + private CrossvalidateZero() { + } + + public static void main(String[] args) throws Exception { + CrossvalidateCommon.crossvalidateClassifiers(ModelConstants.ZERO_DATASET_PATH); + } +} diff --git a/nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/zero/dev_ids.txt b/nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/zero/dev_ids.txt new file mode 100644 index 0000000..6b0ff86 --- /dev/null +++ b/nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/zero/dev_ids.txt @@ -0,0 +1,415 @@ +199704210011 +199704210013 +199704250031 +199704260017 +199801030156 +199801100009 +199801150038 +199801150133 +199801170001 +199801170129 +199801170130 +199801200002 +199801200132 +199801210007 +199801220030 +199801220127 +199801230001 +199801230095 +199801240116 +199801240123 +199801260113 +199801270108 +199801280128 +199801290020 +199801310032 +199802040201 +199901180149 +199901190049 +199901230088 +199901250006 +199901250008 +199901250111 +199901250113 +199901300064 +199901300098 +199902240123 +199906220027 +199906220037 +199906220038 +199906220056 +199906220065 +199906230040 +199906230052 +199906240040 +199906240088 +199906250007 +199906250091 +199906260015 +199906260018 +199906260038 +199907030016 +199907030018 +199907030042 +199907030059 +199907050032 +199907050040 +199907050047 +199907050071 +199907270095 +199907270137 +199907270145 +199909210045 +199909250054 +199909300064 +199909300065 +199909300066 +199910020049 +199910020050 +199910090047 +199910090049 +199910090051 +199910110055 +199910110057 +199910210058 +199910210059 +199910270041 +199910280054 +199910280055 +199910280057 +199910300026 +199911030039 +199911030040 +199911030041 +199911060031 +199911060042 +199911060043 +199911080054 +199911080055 +199911080056 +199911100061 +199911100062 +199911100063 +199911130036 +199911130037 +199911130038 +199911180042 +199911180043 +199911180044 +199911220059 +199911220061 +199911220066 +199911230041 +199911240035 +199911240037 +199911240038 +199911250055 +199911250057 +199912020059 +199912090045 +199912090047 +199912090061 +199912110041 +199912110042 +199912130055 +199912130057 +199912170065 +199912180052 +199912210018 +199912210037 +199912210040 +199912220045 +199912220046 +199912220047 +199912230058 +199912230059 +199912230097 +199912280028 +199912280044 +199912280045 +199912310085 +199912310087 +200001030047 +200001030106 +200001040030 +200001040031 +200001060052 +200001060053 +200001060055 +200001070062 +200001070066 +200001080040 +200001080041 +200001140061 +200001140064 +200001170049 +200001170051 +200001170052 +200001170053 +200001180040 +200001200056 +200001220023 +200001220118 +200001240016 +200001290042 +200001310048 +200001310049 +200001310050 +200001310054 +200002090042 +200002090043 +200002120045 +200002120046 +200002160046 +200002160047 +200002250063 +200002250065 +200002250066 +200002290044 +200002290045 +200002290046 +200002290047 +200002290048 +200003010058 +200003010059 +200003060054 +200003060055 +200003060057 +200003110047 +200003110048 +200003110049 +200003210044 +200003210045 +200004120021 +200004120022 +200004120023 +200004150048 +200004150049 +200004150050 +200004170026 +200004170065 +200004220044 +200004220045 +200004220046 +200004220047 +200004220048 +200005060030 +200005150055 +200005150059 +200005300045 +200005300047 +200005300048 +200006010065 +200006010066 +200006010067 +200006050056 +200006050057 +200006050058 +200006050059 +200006050061 +200006050068 +200006070056 +200006080033 +200006120031 +200006130055 +200006130057 +200006130059 +200006260069 +200006260071 +200006270059 +200007120068 +200007120070 +200007120072 +200007170026 +200007180051 +200007240034 +200007270050 +200007280033 +200008040071 +200008040073 +200008250077 +200008250079 +200008260055 +200008310046 +200010120066 +200010120074 +200010130063 +200010140048 +200010140049 +200010160039 +200010160048 +200010160049 +200010180059 +200010180063 +200010190066 +200010190068 +200011210063 +200011210064 +200011210066 +200012050066 +200012050067 +200012050068 +200012050069 +200012050070 +200012050071 +200012080134 +200012080137 +200012110069 +200012110070 +200012110071 +200012110075 +200012120028 +200012120068 +200012120072 +200012130056 +200012130100 +200012130102 +200012130103 +200012140095 +200012140096 +200012140097 +200012140098 +200012140099 +200012140100 +200012150076 +200012160048 +200012160049 +200012180083 +200012180084 +200012180088 +200012230028 +200012230045 +200012230046 +200012230047 +200012230048 +200012230050 +200012270055 +200012270056 +200101020059 +200101020062 +200101020063 +200101020075 +200101130048 +200101130050 +200101130051 +200101130055 +200101150043 +200101150045 +200101180050 +200101180051 +200101180052 +200101200048 +200101220047 +200101220053 +200102070011 +200102070016 +200102120034 +200102120057 +200102130014 +200102150001 +200102150014 +200102160011 +200102190016 +200102220001 +200102220013 +200102270041 +200102270062 +200102280169 +200103010049 +200103060022 +200103060032 +200103060057 +200103080026 +200103080030 +200103080036 +200103100019 +200103100021 +200103100058 +200103100062 +200103130008 +200103130023 +200103130069 +200103200066 +200103200080 +200103270069 +200103310092 +200104020007 +200104050011 +200104100021 +200104100023 +200104170015 +200104170040 +200104170055 +200104170057 +200104190039 +200104190066 +200104230031 +200104230069 +200104260051 +200104260053 +200104300213 +200104300215 +200104300217 +200105020092 +200105050042 +200105050043 +200105050046 +200105050048 +200105070017 +200105140050 +200105140052 +200105220096 +200105290074 +200105290075 +200106120068 +200106120069 +200106180051 +200106180053 +200106200064 +200106220086 +200106220087 +200106220088 +200106220090 +200106250050 +200107120071 +200107120073 +200107210129 +200107240070 +200107250080 +200108060051 +200108060155 +200108060156 +200108060157 +200108070038 +200108160040 +200108180123 +200108200033 +200108210066 +200108210074 +200108270077 +200108280064 +200109060061 +200109130091 +200109250092 +200109260097 +200109270116 +200110020075 +200110150056 +200110150062 +200110200070 +200110200071 +200110220068 +200111080086 +200111140055 +200111210078 +200111240060 +200112040031 +200112040077 +200112050063 +200112100041 +200112190067 +200201280011 +200201290029 +200202280078 +200203280057 +200203290107 diff --git a/nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/zero/test_ids.txt b/nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/zero/test_ids.txt new file mode 100644 index 0000000..d0c556d --- /dev/null +++ b/nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/zero/test_ids.txt @@ -0,0 +1,154 @@ +199704210012 +199704210042 +199704220007 +199704220018 +199704220021 +199704220044 +199704230006 +199704230014 +199704230029 +199704230043 +199704240008 +199704240019 +199704240020 +199704240021 +199704250018 +199704250022 +199704260014 +199704260015 +199704260016 +199704280023 +199704280025 +199704280027 +199704280031 +199704300031 +199704300042 +199704300046 +199801020010 +199801020031 +199801020035 +199801020070 +199801020076 +199801020079 +199801030068 +199801030090 +199801030091 +199801030129 +199801030148 +199801030158 +199801050023 +199801050059 +199801130087 +199801130129 +199801140182 +199801160119 +199801200106 +199801220140 +199801240061 +199801240096 +199801260047 +199801260070 +199801270055 +199801270110 +199801280123 +199801280158 +199801280159 +199801280241 +199801290022 +199801310003 +199801310037 +199802030127 +199802040159 +199802040182 +199802040202 +199805220133 +199808280158 +199901190073 +199901190115 +199901250112 +199901250117 +199901270103 +199901270120 +199901270122 +199901290095 +199901300101 +199902240095 +199906220029 +199906230024 +199906240084 +199906260027 +199907050045 +199907050076 +199907140166 +199907200002 +199907270004 +199908260001 +199909090036 +199909250018 +199909270029 +199910020027 +199910020029 +199910270011 +199911060044 +199911100038 +199911100064 +199911200030 +199911220063 +199912020060 +199912180026 +199912180034 +199912220030 +199912280024 +199912280046 +199912300021 +199912300029 +200001030029 +200001030053 +200001060034 +200001100035 +200001100046 +200001170029 +200001170033 +200001170060 +200001290045 +200002220027 +200002240034 +200002250031 +200003060062 +200003110050 +200004280047 +200004290022 +200006050119 +200006260079 +200006290045 +200007150033 +200008040076 +200008220042 +200008220046 +200010130049 +200010160054 +200012130034 +200012140084 +200012290046 +200104040019 +200106050035 +200108180109 +200108300032 +200111120045 +200111150042 +200111150047 +200111200036 +200111270049 +200112030055 +200112280057 +200201220038 +200201220050 +200202020036 +200202200032 +200202210054 +200202270044 +200203010070 +200203190026 +200203260050 +200203280017 +200203290078 diff --git a/nicolas-core/src/main/resources/zeros.tsv b/nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/zero/zeros.tsv index b9bcdca..b9bcdca 100644 --- a/nicolas-core/src/main/resources/zeros.tsv +++ b/nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/zero/zeros.tsv diff --git a/nicolas-train/src/test/java/pl/waw/ipipan/zil/summ/nicolas/train/multiservice/NLPProcessTest.java b/nicolas-train/src/test/java/pl/waw/ipipan/zil/summ/nicolas/train/multiservice/NLPProcessIT.java index 018c352..d66b72a 100644 --- a/nicolas-train/src/test/java/pl/waw/ipipan/zil/summ/nicolas/train/multiservice/NLPProcessTest.java +++ b/nicolas-train/src/test/java/pl/waw/ipipan/zil/summ/nicolas/train/multiservice/NLPProcessIT.java @@ -1,17 +1,31 @@ package pl.waw.ipipan.zil.summ.nicolas.train.multiservice; +import com.google.common.collect.Lists; +import org.junit.ClassRule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; import java.io.File; +import java.util.List; +import java.util.stream.Collectors; + +import static junit.framework.TestCase.assertEquals; + +public class NLPProcessIT { + + @ClassRule + public static TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder(); -public class NLPProcessTest { @Test public void shouldProcessSampleText() throws Exception { String text = "Ala ma kota. Ala ma też psa."; TText processed = NLPProcess.annotate(text); - processed.getParagraphs().stream().flatMap(p->p.getSentences().stream()).forEach(s->System.out.println(s.getId())); - File targetFile = new File("sample_serialized_text.bin"); + List<String> ids = processed.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).map(TSentence::getId).collect(Collectors.toList()); + assertEquals(Lists.newArrayList("s-2.1", "s-2.2"), ids); + + File targetFile = TEMPORARY_FOLDER.newFile(); NLPProcess.serialize(processed, targetFile); } } \ No newline at end of file diff --git a/pom.xml b/pom.xml index 81c53ae..115e398 100644 --- a/pom.xml +++ b/pom.xml @@ -11,7 +11,7 @@ <packaging>pom</packaging> <modules> - <module>nicolas-core</module> + <module>nicolas-lib</module> <module>nicolas-cli</module> <module>nicolas-model</module> <module>nicolas-train</module> @@ -26,12 +26,13 @@ <utils.version>1.0</utils.version> <commons-csv.version>1.4</commons-csv.version> - <guava.version>19.0</guava.version> - <weka-dev.version>3.9.0</weka-dev.version> + <guava.version>20.0</guava.version> + <weka-dev.version>3.9.1</weka-dev.version> <commons-lang3.version>3.5</commons-lang3.version> <commons-io.version>2.5</commons-io.version> - <slf4j-api.version>1.7.12</slf4j-api.version> + <slf4j-api.version>1.7.22</slf4j-api.version> <junit.version>4.12</junit.version> + <zip4j.version>1.3.2</zip4j.version> </properties> <prerequisites> @@ -65,6 +66,16 @@ <artifactId>nicolas-zero</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>pl.waw.ipipan.zil.summ</groupId> + <artifactId>nicolas-lib</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>pl.waw.ipipan.zil.summ</groupId> + <artifactId>nicolas-train</artifactId> + <version>${project.version}</version> + </dependency> <!-- internal --> <dependency> @@ -93,6 +104,12 @@ <groupId>nz.ac.waikato.cms.weka</groupId> <artifactId>weka-dev</artifactId> <version>${weka-dev.version}</version> + <exclusions> + <exclusion> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-simple</artifactId> + </exclusion> + </exclusions> </dependency> <dependency> <groupId>org.apache.commons</groupId> @@ -104,6 +121,11 @@ <artifactId>commons-io</artifactId> <version>${commons-io.version}</version> </dependency> + <dependency> + <groupId>net.lingala.zip4j</groupId> + <artifactId>zip4j</artifactId> + <version>${zip4j.version}</version> + </dependency> <!-- logging --> <dependency> @@ -111,6 +133,11 @@ <artifactId>slf4j-api</artifactId> <version>${slf4j-api.version}</version> </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-simple</artifactId> + <version>${slf4j-api.version}</version> + </dependency> <!-- test --> <dependency>