From dcb5f3d5e5a615839b70cff10410ab54d9ea2f38 Mon Sep 17 00:00:00 2001 From: Mateusz Kopeć <m.kopec@ipipan.waw.pl> Date: Wed, 12 Apr 2017 10:44:51 +0200 Subject: [PATCH] almost finished :) --- nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java | 7 ++++++- nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/InstanceUtils.java | 8 ++++---- nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/ResourceUtils.java | 4 ---- nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/TextUtils.java | 10 +++++++++- nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/resources/frequent_bases.txt | 2 +- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/PathConstants.java | 2 ++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/search/Crossvalidate.java | 43 +++++++++++++++++++++++++++++++++---------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java | 3 --- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java | 26 +++++++++----------------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java | 2 +- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java | 11 ++++++----- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java | 20 ++++++++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractMostFrequentMentions.java | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java | 2 -- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractMostFrequentMentions.java | 78 ------------------------------------------------------------------------------ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractStopwords.java | 11 ----------- pom.xml | 45 ++++++++++++++++++++++++++++++++------------- 17 files changed, 186 insertions(+), 151 deletions(-) create mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractMostFrequentMentions.java delete mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractMostFrequentMentions.java delete mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractStopwords.java diff --git a/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java index 401e396..4cc3cdc 100644 --- a/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java @@ -1,6 +1,7 @@ package pl.waw.ipipan.zil.summ.nicolas; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; @@ -8,6 +9,11 @@ import java.nio.charset.StandardCharsets; public class Constants { public static final ImmutableList<String> POS_TAGS = ImmutableList.of("end", "other", "null", "impt", "imps", "inf", "pred", "subst", "aglt", "ppron3", "ger", "praet", "fin", "num", "interp", "siebie", "brev", "interj", "ppron12", "adj", "burk", "pcon", "bedzie", "adv", "prep", "depr", "xxx", "winien", "conj", "qub", "adja", "ppas", "comp", "pact"); + + public static final ImmutableSet<String> STOP_POS_TAGS = + ImmutableSet.of("brev", "conj", "prep", "interp", "qub", "interj", "siebie", "xxx", "inf", "comp", + "bedzie", "aglt"); + public static final Charset ENCODING = StandardCharsets.UTF_8; private static final String ROOT_PATH = "/pl/waw/ipipan/zil/summ/nicolas/"; @@ -19,7 +25,6 @@ public class Constants { private static final String RESOURCES_PATH = ROOT_PATH + "resources/"; public static final String FREQUENT_BASES_RESOURCE_PATH = RESOURCES_PATH + "frequent_bases.txt"; - public static final String STOPWORDS_PATH = RESOURCES_PATH + "stopwords.txt"; private Constants() { } diff --git a/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/InstanceUtils.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/InstanceUtils.java index a7780f1..5c5a220 100644 --- a/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/InstanceUtils.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/InstanceUtils.java @@ -35,7 +35,7 @@ public class InstanceUtils { List<TSentence> sentences = preprocessedText.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList()); Map<TMention, Map<Attribute, Double>> mention2features = featureExtractor.calculateFeatures(preprocessedText); - LOG.info("Extracting {} features of each mention.", featureExtractor.getAttributesList().size()); + LOG.debug("Extracting {} features of each mention.", featureExtractor.getAttributesList().size()); Map<TMention, Instance> mention2instance = Maps.newHashMap(); for (TMention tMention : sentences.stream().flatMap(s -> s.getMentions().stream()).collect(toList())) { Instance instance = new DenseInstance(featureExtractor.getAttributesList().size()); @@ -45,7 +45,7 @@ public class InstanceUtils { } mention2instance.put(tMention, instance); } - LOG.info("Extracted features of {} mentions.", mention2instance.size()); + LOG.debug("Extracted features of {} mentions.", mention2instance.size()); return mention2instance; } @@ -53,7 +53,7 @@ public class InstanceUtils { List<TSentence> sentences = preprocessedText.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList()); Map<TSentence, Map<Attribute, Double>> sentence2features = featureExtractor.calculateFeatures(preprocessedText, goodMentions); - LOG.info("Extracting {} features of each sentence.", featureExtractor.getAttributesList().size()); + LOG.debug("Extracting {} features of each sentence.", featureExtractor.getAttributesList().size()); Map<TSentence, Instance> sentence2instance = Maps.newHashMap(); for (TSentence sentence : sentences) { Instance instance = new DenseInstance(featureExtractor.getAttributesList().size()); @@ -63,7 +63,7 @@ public class InstanceUtils { } sentence2instance.put(sentence, instance); } - LOG.info("Extracted features of {} sentences.", sentence2instance.size()); + LOG.debug("Extracted features of {} sentences.", sentence2instance.size()); return sentence2instance; } diff --git a/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/ResourceUtils.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/ResourceUtils.java index acdf7d2..5476e17 100644 --- a/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/ResourceUtils.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/ResourceUtils.java @@ -24,10 +24,6 @@ public class ResourceUtils { return loadUniqueLowercaseSortedNonemptyLinesFromResource(Constants.FREQUENT_BASES_RESOURCE_PATH); } - public static List<String> loadStopwords() throws IOException { - return loadUniqueLowercaseSortedNonemptyLinesFromResource(Constants.STOPWORDS_PATH); - } - public static Classifier loadModelFromResource(String modelResourcePath) throws IOException { LOG.info("Loading classifier from path: {}...", modelResourcePath); try (InputStream stream = ResourceUtils.class.getResourceAsStream(modelResourcePath)) { diff --git a/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/TextUtils.java b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/TextUtils.java index d561a70..711d445 100644 --- a/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/TextUtils.java +++ b/nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/TextUtils.java @@ -3,10 +3,12 @@ package pl.waw.ipipan.zil.summ.nicolas.utils; import com.google.common.collect.Sets; import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TToken; +import pl.waw.ipipan.zil.summ.nicolas.Constants; import java.util.Arrays; import java.util.List; import java.util.Set; +import java.util.stream.Collectors; public class TextUtils { @@ -29,7 +31,6 @@ public class TextUtils { StringBuilder sb = new StringBuilder(); for (TToken token : sentence.getTokens()) { if (tokenIdsToSkip.contains(token.getId())) { - System.out.println("Skipping " + token.getOrth() + " in sentence: " + loadSentence2Orth(sentence)); continue; } if (!token.isNoPrecedingSpace()) @@ -38,4 +39,11 @@ public class TextUtils { } return sb.toString().trim(); } + + public static String loadSentence2OrthExcludingStoptags(TSentence sentence) { + Set<String> tokenIdsToSkip = sentence.getTokens().stream() + .filter(token -> Constants.STOP_POS_TAGS.contains(token.getChosenInterpretation().getCtag())) + .map(TToken::getId).collect(Collectors.toSet()); + return loadSentence2Orth(sentence, tokenIdsToSkip); + } } diff --git a/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/resources/frequent_bases.txt b/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/resources/frequent_bases.txt index 973881a..9281ed0 100644 --- a/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/resources/frequent_bases.txt +++ b/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/resources/frequent_bases.txt @@ -234,4 +234,4 @@ ue " rzeczpospolita " liczba wieś -połowa \ No newline at end of file +połowa diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/PathConstants.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/PathConstants.java index 9a3a514..f7c0e1d 100644 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/PathConstants.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/PathConstants.java @@ -43,6 +43,8 @@ public class PathConstants { public static final File SUMMARY_LENGTHS_FILE = new File(WORKING_DIR, "summary-lengths.tsv"); + public static final String TARGET_MODEL_DIR = "nicolas-model/src/main/resources"; + private PathConstants() { } diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/search/Crossvalidate.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/search/Crossvalidate.java index 5cba028..27bc63a 100644 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/search/Crossvalidate.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/search/Crossvalidate.java @@ -34,22 +34,33 @@ import java.util.Optional; import java.util.Random; import java.util.logging.LogManager; +import static pl.waw.ipipan.zil.summ.nicolas.PathConstants.*; + class Crossvalidate { private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); - private static final int NUM_FOLDS = 10; + private static final int NUM_FOLDS = 5; + private static final int MAX_INSTANCES = 10000; + + private static final int SEED = 1; private Crossvalidate() { } - static void crossvalidateClassifiers(String datasetPath) throws IOException { + public static void main(String[] args) throws IOException { + crossvalidateClassifiers(MENTION_ARFF.getPath()); + crossvalidateRegressors(SENTENCE_ARFF.getPath()); + crossvalidateClassifiers(ZERO_ARFF.getPath()); + } + + private static void crossvalidateClassifiers(String datasetPath) throws IOException { Instances instances = loadInstances(datasetPath); crossvalidateClassification(instances); } - static void crossvalidateRegressors(String datasetPath) throws IOException { + private static void crossvalidateRegressors(String datasetPath) throws IOException { Instances instances = loadInstances(datasetPath); crossvalidateRegression(instances); } @@ -62,6 +73,9 @@ class Crossvalidate { Instances instances = loader.getDataSet(); instances.setClassIndex(0); LOG.info("{} instances loaded.", instances.size()); + instances.randomize(new Random(SEED)); + while (instances.size() > MAX_INSTANCES) + instances.delete(instances.size() - 1); LOG.info("{} attributes for each instance.", instances.numAttributes()); return instances; } @@ -70,7 +84,8 @@ class Crossvalidate { StopWatch watch = new StopWatch(); watch.start(); - Optional<Pair<Double, String>> max = Arrays.stream(new Classifier[]{new J48(), new RandomForest(), new HoeffdingTree(), new LMT(), + Optional<Pair<Double, String>> max = Arrays.stream(new Classifier[]{ + new J48(), new RandomForest(), new HoeffdingTree(), new LMT(), new Logistic(), new ZeroR(), new SimpleLogistic(), new BayesNet(), new NaiveBayes(), new KStar(), new IBk(), new LWL(), @@ -81,7 +96,7 @@ class Crossvalidate { Evaluation eval; try { eval = new Evaluation(instances); - eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(1)); + eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(SEED)); } catch (Exception e) { LOG.error("Error evaluating model", e); return Pair.of(0.0, name); @@ -90,9 +105,13 @@ class Crossvalidate { LOG.info(name + " : " + acc); return Pair.of(acc, name); }).max(Comparator.comparingDouble(Pair::getLeft)); - LOG.info("#########"); - LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); + LOG.info("#########"); + if (max.isPresent()) { + LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); + } else { + LOG.info("Empty algorithms list"); + } watch.stop(); LOG.info("Elapsed time: {}", watch); } @@ -114,7 +133,7 @@ class Crossvalidate { String name = cls.getClass().getSimpleName(); try { Evaluation eval = new Evaluation(instances); - eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(1)); + eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(SEED)); acc = eval.correlationCoefficient(); } catch (Exception e) { LOG.error("Error evaluating model", e); @@ -122,9 +141,13 @@ class Crossvalidate { LOG.info(name + " : " + acc); return Pair.of(acc, name); }).max(Comparator.comparingDouble(Pair::getLeft)); - LOG.info("#########"); - LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); + LOG.info("#########"); + if (max.isPresent()) { + LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); + } else { + LOG.info("Empty algorithms list"); + } watch.stop(); LOG.info("Elapsed time: {}", watch); } diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java index f418514..3c021fc 100644 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java @@ -1,8 +1,6 @@ package pl.waw.ipipan.zil.summ.nicolas.train; import pl.waw.ipipan.zil.summ.nicolas.train.pipeline.*; -import pl.waw.ipipan.zil.summ.nicolas.train.resources.ExtractMostFrequentMentions; -import pl.waw.ipipan.zil.summ.nicolas.train.resources.ExtractStopwords; public class Main { @@ -14,7 +12,6 @@ public class Main { DownloadTrainingResources.main(args); ExtractGoldSummaries.main(args); CreateOptimalSummaries.main(args); - ExtractStopwords.main(args); ExtractMostFrequentMentions.main(args); PrepareTrainingData.main(args); TrainAllModels.main(args); diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java index aec39ae..aa341fc 100644 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java @@ -7,23 +7,16 @@ import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; import pl.waw.ipipan.zil.multiservice.thrift.types.TToken; -import pl.waw.ipipan.zil.summ.nicolas.utils.ResourceUtils; import pl.waw.ipipan.zil.summ.nicolas.utils.TextUtils; -import java.io.IOException; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; -public class MentionScorer { - - private final Set<String> STOPWORDS; +import static pl.waw.ipipan.zil.summ.nicolas.Constants.STOP_POS_TAGS; - public MentionScorer() throws IOException { - STOPWORDS = ResourceUtils.loadStopwords().stream().collect(Collectors.toSet()); - } +public class MentionScorer { public Map<TMention, Double> calculateMentionScores(String optimalSummary, TText text) { Multiset<String> tokenCounts = HashMultiset.create(TextUtils.tokenize(optimalSummary.toLowerCase())); @@ -34,24 +27,23 @@ public class MentionScorer { return booleanTokenIntersection(mention2Orth, tokenCounts); } - private Map<TMention, String> loadMention2OrthExcludingStopwords(List<TSentence> sents) { + private Map<TMention, String> loadMention2OrthExcludingStopwords(List<TSentence> sentences) { Map<TMention, String> mention2orth = Maps.newHashMap(); - for (TSentence s : sents) { - Map<String, TToken> tokId2tok = s.getTokens().stream().collect(Collectors.toMap(TToken::getId, Function.identity())); + for (TSentence sentence : sentences) { + Map<String, TToken> tokId2tok = sentence.getTokens().stream().collect(Collectors.toMap(TToken::getId, Function.identity())); - for (TMention m : s.getMentions()) { + for (TMention mention : sentence.getMentions()) { StringBuilder mentionOrth = new StringBuilder(); - for (String tokId : m.getChildIds()) { + for (String tokId : mention.getChildIds()) { TToken token = tokId2tok.get(tokId); - if (STOPWORDS.contains(token.getChosenInterpretation().getBase().toLowerCase())) { + if (STOP_POS_TAGS.contains(token.getChosenInterpretation().getCtag())) continue; - } if (!token.isNoPrecedingSpace()) mentionOrth.append(" "); mentionOrth.append(token.getOrth()); } - mention2orth.put(m, mentionOrth.toString().trim()); + mention2orth.put(mention, mentionOrth.toString().trim()); } } return mention2orth; diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java index dcdc297..3259c3a 100644 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java @@ -21,7 +21,7 @@ public class SentenceScorer { for (TSentence sentence : paragraph.getSentences()) { double score = 0.0; - String orth = TextUtils.loadSentence2Orth(sentence); + String orth = TextUtils.loadSentence2OrthExcludingStoptags(sentence); List<String> tokens = TextUtils.tokenize(orth); for (String token : tokens) { score += tokenCounts.contains(token.toLowerCase()) ? 1.0 : 0.0; diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java index d73c945..923fb2b 100644 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java @@ -1,11 +1,13 @@ package pl.waw.ipipan.zil.summ.nicolas.train.model; import weka.classifiers.Classifier; +import weka.classifiers.meta.AttributeSelectedClassifier; +import weka.classifiers.trees.LMT; import weka.classifiers.trees.RandomForest; public class Settings { - private static final int NUM_ITERATIONS = 20; + private static final int NUM_ITERATIONS = 100; private static final int NUM_EXECUTION_SLOTS = 8; private static final int SEED = 0; @@ -29,10 +31,9 @@ public class Settings { } public static Classifier getZeroClassifier() { - RandomForest classifier = new RandomForest(); - classifier.setNumIterations(NUM_ITERATIONS); - classifier.setSeed(SEED); - classifier.setNumExecutionSlots(NUM_EXECUTION_SLOTS); + AttributeSelectedClassifier classifier = new AttributeSelectedClassifier(); + LMT subClassifier = new LMT(); + classifier.setClassifier(subClassifier); return classifier; } diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java index 8dde011..0f91f1e 100644 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java @@ -59,14 +59,34 @@ public class CreateOptimalSummaries { int summaryWordCount = 0; StringBuilder summary = new StringBuilder(); + List<Integer> bestNgramCounts = null; while (averageGoldWordCount >= summaryWordCount) { + bestNgramCounts = getBestNgramCounts(ngram2counts); List<String> ngram = pickBestNgram(ngram2counts); summary.append(" ").append(String.join(" ", ngram)); summaryWordCount += ngram.size(); } + + for (Map.Entry<List<String>, List<Integer>> entry : ngram2counts.entrySet()) { + if (entry.getValue().equals(bestNgramCounts)) { + List<String> ngram = entry.getKey(); + summary.append(" ").append(String.join(" ", ngram)); + } + } + return summary.toString().trim(); } + private static List<Integer> getBestNgramCounts(Map<List<String>, List<Integer>> ngram2counts) { + Optional<List<String>> optional = ngram2counts.keySet().stream() + .sorted(Comparator.comparing((List<String> ngram) -> ngram2counts.get(ngram).size()).reversed()).findFirst(); + if (!optional.isPresent()) { + throw new IllegalArgumentException("No more ngrams to pick!"); + } + List<String> optimalNgram = optional.get(); + return ngram2counts.get(optimalNgram); + } + private static List<String> pickBestNgram(Map<List<String>, List<Integer>> ngram2counts) { Optional<List<String>> optional = ngram2counts.keySet().stream() .sorted(Comparator.comparing((List<String> ngram) -> ngram2counts.get(ngram).size()).reversed()).findFirst(); diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractMostFrequentMentions.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractMostFrequentMentions.java new file mode 100644 index 0000000..44d8bb7 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractMostFrequentMentions.java @@ -0,0 +1,63 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; + +import com.google.common.collect.*; +import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; +import pl.waw.ipipan.zil.summ.nicolas.Constants; +import pl.waw.ipipan.zil.summ.nicolas.CorpusHelper; +import pl.waw.ipipan.zil.summ.nicolas.PathConstants; +import pl.waw.ipipan.zil.summ.nicolas.features.FeatureHelper; +import pl.waw.ipipan.zil.summ.nicolas.utils.thrift.ThriftUtils; + +import javax.xml.bind.JAXBException; +import java.io.BufferedWriter; +import java.io.FileWriter; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static pl.waw.ipipan.zil.summ.nicolas.PathConstants.PREPROCESSED_CORPUS_DIR; + +public class ExtractMostFrequentMentions { + + private static final int MIN_MENTION_DOCUMENT_COUNT = 50; + + private ExtractMostFrequentMentions() { + } + + public static void main(String[] args) throws IOException, JAXBException { + List<String> mostFrequentMentionBases = getMostFrequentMentionBases(); + try (BufferedWriter bw = new BufferedWriter(new FileWriter(PathConstants.TARGET_MODEL_DIR + Constants.FREQUENT_BASES_RESOURCE_PATH))) { + for (String base : mostFrequentMentionBases) { + bw.write(base + "\n"); + } + } + } + + private static List<String> getMostFrequentMentionBases() throws IOException { + Set<String> trainTextIds = CorpusHelper.loadTrainTextIds(); + Map<String, TText> id2preprocessedText = ThriftUtils.loadThriftTextsFromFolder(PREPROCESSED_CORPUS_DIR, trainTextIds::contains); + + Multiset<String> mentionBases = HashMultiset.create(); + + for (TText text : id2preprocessedText.values()) { + FeatureHelper featureHelper = new FeatureHelper(text); + Set<String> textMentionBases = Sets.newHashSet(); + for (TMention mention : featureHelper.getMentions()) { + String mentionBase = featureHelper.getMentionBase(mention); + textMentionBases.add(mentionBase); + } + mentionBases.addAll(textMentionBases); + } + + ImmutableMultiset<String> sorted = Multisets.copyHighestCountFirst(mentionBases); + List<String> mostFrequentMentions = Lists.newArrayList(); + for (Multiset.Entry<String> entry : sorted.entrySet()) { + if (entry.getCount() < MIN_MENTION_DOCUMENT_COUNT) + break; + mostFrequentMentions.add(entry.getElement()); + } + return mostFrequentMentions; + } +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java index d186dcc..e6d9588 100644 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java @@ -20,8 +20,6 @@ public class TrainAllModels { private static final Logger LOG = LoggerFactory.getLogger(TrainAllModels.class); - private static final String TARGET_MODEL_DIR = "nicolas-model/src/main/resources"; - private TrainAllModels() { } diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractMostFrequentMentions.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractMostFrequentMentions.java deleted file mode 100644 index 349888c..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractMostFrequentMentions.java +++ /dev/null @@ -1,78 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train.resources; - -import com.google.common.collect.*; -import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; -import pl.waw.ipipan.zil.multiservice.thrift.types.TText; -import pl.waw.ipipan.zil.summ.nicolas.utils.thrift.ThriftUtils; -import pl.waw.ipipan.zil.summ.pscapi.io.PSC_IO; -import pl.waw.ipipan.zil.summ.pscapi.xml.Text; - -import javax.xml.bind.JAXBException; -import java.io.File; -import java.io.IOException; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -public class ExtractMostFrequentMentions { - - public static final String GOLD_DATA_PATH = "/home/me2/Dropbox/3_nauka/3_doktorat/3_korpus_streszczen/dist/src/data/"; - - public static final String THRIFTED_PREFIX = "/home/me2/Desktop/thrifted_texts/thrifted_all/"; - public static final String THRIFTED_SUFFIX = "/original"; - - public static void main(String[] args) throws IOException, JAXBException { - - Set<String> devIds = Sets.newHashSet(); - - File goldDir = new File(GOLD_DATA_PATH); - for (File file : goldDir.listFiles()) { - Text goldText = PSC_IO.readText(file); - if (goldText.getSummaries().getSummary().stream().anyMatch(s -> s.getType().equals("abstract"))) - continue; - - devIds.add(file.getName().replace(".xml", "")); - } - - - System.out.println(devIds.size()); - - Multiset<String> mentionCounts = HashMultiset.create(); - for (String id : devIds) { - Set<String> distinctTextMentions = Sets.newHashSet(); - File input = new File(THRIFTED_PREFIX + id + THRIFTED_SUFFIX); - TText thrifted = ThriftUtils.loadThriftTextFromFile(input); - List<TSentence> sents = thrifted.getParagraphs().stream() - .flatMap(p -> p.getSentences().stream()).collect(Collectors.toList()); - - Map<String, String> tokenId2base = Maps.newHashMap(); - sents.stream() - .flatMap(s -> s.getTokens().stream()) - .forEach(token -> tokenId2base.put(token.getId(), token.getChosenInterpretation().getBase())); - - sents.stream().flatMap(s -> s.getMentions().stream()).forEach(m -> { - StringBuffer sb = new StringBuffer(); - for (String tokId : m.getChildIds()) { - sb.append(tokenId2base.get(tokId) + " "); - } - distinctTextMentions.add(sb.toString().trim().toLowerCase()); - }); - - mentionCounts.addAll(distinctTextMentions); - } - - System.out.println(mentionCounts.elementSet().size()); - List<String> sorted = Lists.newArrayList(); - sorted.addAll(mentionCounts.elementSet()); - sorted.sort(Comparator.comparing(mentionCounts::count).reversed()); - int i = 0; - for (String mention : sorted) { - if (mentionCounts.count(mention) < 50) - break; - System.out.println(mention); - } - - } -} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractStopwords.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractStopwords.java deleted file mode 100644 index 31db95c..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractStopwords.java +++ /dev/null @@ -1,11 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train.resources; - -public class ExtractStopwords { - - private ExtractStopwords() { - } - - public static void main(String[] args) { - - } -} diff --git a/pom.xml b/pom.xml index 0d91777..4bc9dcb 100644 --- a/pom.xml +++ b/pom.xml @@ -35,9 +35,23 @@ <slf4j-api.version>1.7.24</slf4j-api.version> <junit.version>4.12</junit.version> <zip4j.version>1.3.2</zip4j.version> - <mockito-core.version>2.7.17</mockito-core.version> + <mockito-core.version>2.7.22</mockito-core.version> <jcommander.version>1.64</jcommander.version> <libthrift.version>0.9.0</libthrift.version> + + <jacoco-maven-plugin.version>0.7.8</jacoco-maven-plugin.version> + <maven-compiler-plugin.version>3.5.1</maven-compiler-plugin.version> + <maven-site-plugin.version>3.5.1</maven-site-plugin.version> + <maven-dependency-plugin.version>3.0.0</maven-dependency-plugin.version> + <maven-jar-plugin.version>3.0.2</maven-jar-plugin.version> + <maven-resources-plugin.version>3.0.1</maven-resources-plugin.version> + <maven-clean-plugin.version>3.0.0</maven-clean-plugin.version> + <maven-install-plugin.version>2.5.2</maven-install-plugin.version> + <maven-deploy-plugin.version>2.8.2</maven-deploy-plugin.version> + <maven-assembly-plugin.version>2.6</maven-assembly-plugin.version> + <maven-project-info-reports-plugin.version>2.9</maven-project-info-reports-plugin.version> + <maven-surefire-plugin.version>2.19.1</maven-surefire-plugin.version> + <maven-failsafe-plugin.version>2.19.1</maven-failsafe-plugin.version> </properties> <prerequisites> @@ -183,47 +197,47 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-dependency-plugin</artifactId> - <version>3.0.0</version> + <version>${maven-dependency-plugin.version}</version> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-jar-plugin</artifactId> - <version>3.0.2</version> + <version>${maven-jar-plugin.version}</version> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-resources-plugin</artifactId> - <version>3.0.1</version> + <version>${maven-resources-plugin.version}</version> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-clean-plugin</artifactId> - <version>3.0.0</version> + <version>${maven-clean-plugin.version}</version> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-site-plugin</artifactId> - <version>3.5.1</version> + <version>${maven-site-plugin.version}</version> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-install-plugin</artifactId> - <version>2.5.2</version> + <version>${maven-install-plugin.version}</version> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-deploy-plugin</artifactId> - <version>2.8.2</version> + <version>${maven-deploy-plugin.version}</version> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-assembly-plugin</artifactId> - <version>2.6</version> + <version>${maven-assembly-plugin.version}</version> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> - <version>3.5.1</version> + <version>${maven-compiler-plugin.version}</version> <configuration> <source>${java.version.build}</source> <target>${java.version.build}</target> @@ -231,8 +245,13 @@ </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-project-info-reports-plugin</artifactId> + <version>${maven-project-info-reports-plugin.version}</version> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-surefire-plugin</artifactId> - <version>2.19.1</version> + <version>${maven-surefire-plugin.version}</version> <configuration> <!-- Sets the VM argument line used when unit tests are run. --> <argLine>${surefireArgLine}</argLine> @@ -247,7 +266,7 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-failsafe-plugin</artifactId> - <version>2.19.1</version> + <version>${maven-failsafe-plugin.version}</version> <executions> <execution> <id>integration-test</id> @@ -276,7 +295,7 @@ <plugin> <groupId>org.jacoco</groupId> <artifactId>jacoco-maven-plugin</artifactId> - <version>0.7.8</version> + <version>${jacoco-maven-plugin.version}</version> <executions> <!-- Prepares the property pointing to the JaCoCo runtime agent which -- libgit2 0.22.2