Commit dcb5f3d5e5a615839b70cff10410ab54d9ea2f38

Authored by Mateusz Kopeć
1 parent 156b3707

almost finished :)

nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java
1 package pl.waw.ipipan.zil.summ.nicolas; 1 package pl.waw.ipipan.zil.summ.nicolas;
2 2
3 import com.google.common.collect.ImmutableList; 3 import com.google.common.collect.ImmutableList;
  4 +import com.google.common.collect.ImmutableSet;
4 5
5 import java.nio.charset.Charset; 6 import java.nio.charset.Charset;
6 import java.nio.charset.StandardCharsets; 7 import java.nio.charset.StandardCharsets;
@@ -8,6 +9,11 @@ import java.nio.charset.StandardCharsets; @@ -8,6 +9,11 @@ import java.nio.charset.StandardCharsets;
8 public class Constants { 9 public class Constants {
9 10
10 public static final ImmutableList<String> POS_TAGS = ImmutableList.of("end", "other", "null", "impt", "imps", "inf", "pred", "subst", "aglt", "ppron3", "ger", "praet", "fin", "num", "interp", "siebie", "brev", "interj", "ppron12", "adj", "burk", "pcon", "bedzie", "adv", "prep", "depr", "xxx", "winien", "conj", "qub", "adja", "ppas", "comp", "pact"); 11 public static final ImmutableList<String> POS_TAGS = ImmutableList.of("end", "other", "null", "impt", "imps", "inf", "pred", "subst", "aglt", "ppron3", "ger", "praet", "fin", "num", "interp", "siebie", "brev", "interj", "ppron12", "adj", "burk", "pcon", "bedzie", "adv", "prep", "depr", "xxx", "winien", "conj", "qub", "adja", "ppas", "comp", "pact");
  12 +
  13 + public static final ImmutableSet<String> STOP_POS_TAGS =
  14 + ImmutableSet.of("brev", "conj", "prep", "interp", "qub", "interj", "siebie", "xxx", "inf", "comp",
  15 + "bedzie", "aglt");
  16 +
11 public static final Charset ENCODING = StandardCharsets.UTF_8; 17 public static final Charset ENCODING = StandardCharsets.UTF_8;
12 18
13 private static final String ROOT_PATH = "/pl/waw/ipipan/zil/summ/nicolas/"; 19 private static final String ROOT_PATH = "/pl/waw/ipipan/zil/summ/nicolas/";
@@ -19,7 +25,6 @@ public class Constants { @@ -19,7 +25,6 @@ public class Constants {
19 25
20 private static final String RESOURCES_PATH = ROOT_PATH + "resources/"; 26 private static final String RESOURCES_PATH = ROOT_PATH + "resources/";
21 public static final String FREQUENT_BASES_RESOURCE_PATH = RESOURCES_PATH + "frequent_bases.txt"; 27 public static final String FREQUENT_BASES_RESOURCE_PATH = RESOURCES_PATH + "frequent_bases.txt";
22 - public static final String STOPWORDS_PATH = RESOURCES_PATH + "stopwords.txt";  
23 28
24 private Constants() { 29 private Constants() {
25 } 30 }
nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/InstanceUtils.java
@@ -35,7 +35,7 @@ public class InstanceUtils { @@ -35,7 +35,7 @@ public class InstanceUtils {
35 List<TSentence> sentences = preprocessedText.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList()); 35 List<TSentence> sentences = preprocessedText.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList());
36 Map<TMention, Map<Attribute, Double>> mention2features = featureExtractor.calculateFeatures(preprocessedText); 36 Map<TMention, Map<Attribute, Double>> mention2features = featureExtractor.calculateFeatures(preprocessedText);
37 37
38 - LOG.info("Extracting {} features of each mention.", featureExtractor.getAttributesList().size()); 38 + LOG.debug("Extracting {} features of each mention.", featureExtractor.getAttributesList().size());
39 Map<TMention, Instance> mention2instance = Maps.newHashMap(); 39 Map<TMention, Instance> mention2instance = Maps.newHashMap();
40 for (TMention tMention : sentences.stream().flatMap(s -> s.getMentions().stream()).collect(toList())) { 40 for (TMention tMention : sentences.stream().flatMap(s -> s.getMentions().stream()).collect(toList())) {
41 Instance instance = new DenseInstance(featureExtractor.getAttributesList().size()); 41 Instance instance = new DenseInstance(featureExtractor.getAttributesList().size());
@@ -45,7 +45,7 @@ public class InstanceUtils { @@ -45,7 +45,7 @@ public class InstanceUtils {
45 } 45 }
46 mention2instance.put(tMention, instance); 46 mention2instance.put(tMention, instance);
47 } 47 }
48 - LOG.info("Extracted features of {} mentions.", mention2instance.size()); 48 + LOG.debug("Extracted features of {} mentions.", mention2instance.size());
49 return mention2instance; 49 return mention2instance;
50 } 50 }
51 51
@@ -53,7 +53,7 @@ public class InstanceUtils { @@ -53,7 +53,7 @@ public class InstanceUtils {
53 List<TSentence> sentences = preprocessedText.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList()); 53 List<TSentence> sentences = preprocessedText.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList());
54 Map<TSentence, Map<Attribute, Double>> sentence2features = featureExtractor.calculateFeatures(preprocessedText, goodMentions); 54 Map<TSentence, Map<Attribute, Double>> sentence2features = featureExtractor.calculateFeatures(preprocessedText, goodMentions);
55 55
56 - LOG.info("Extracting {} features of each sentence.", featureExtractor.getAttributesList().size()); 56 + LOG.debug("Extracting {} features of each sentence.", featureExtractor.getAttributesList().size());
57 Map<TSentence, Instance> sentence2instance = Maps.newHashMap(); 57 Map<TSentence, Instance> sentence2instance = Maps.newHashMap();
58 for (TSentence sentence : sentences) { 58 for (TSentence sentence : sentences) {
59 Instance instance = new DenseInstance(featureExtractor.getAttributesList().size()); 59 Instance instance = new DenseInstance(featureExtractor.getAttributesList().size());
@@ -63,7 +63,7 @@ public class InstanceUtils { @@ -63,7 +63,7 @@ public class InstanceUtils {
63 } 63 }
64 sentence2instance.put(sentence, instance); 64 sentence2instance.put(sentence, instance);
65 } 65 }
66 - LOG.info("Extracted features of {} sentences.", sentence2instance.size()); 66 + LOG.debug("Extracted features of {} sentences.", sentence2instance.size());
67 return sentence2instance; 67 return sentence2instance;
68 } 68 }
69 69
nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/ResourceUtils.java
@@ -24,10 +24,6 @@ public class ResourceUtils { @@ -24,10 +24,6 @@ public class ResourceUtils {
24 return loadUniqueLowercaseSortedNonemptyLinesFromResource(Constants.FREQUENT_BASES_RESOURCE_PATH); 24 return loadUniqueLowercaseSortedNonemptyLinesFromResource(Constants.FREQUENT_BASES_RESOURCE_PATH);
25 } 25 }
26 26
27 - public static List<String> loadStopwords() throws IOException {  
28 - return loadUniqueLowercaseSortedNonemptyLinesFromResource(Constants.STOPWORDS_PATH);  
29 - }  
30 -  
31 public static Classifier loadModelFromResource(String modelResourcePath) throws IOException { 27 public static Classifier loadModelFromResource(String modelResourcePath) throws IOException {
32 LOG.info("Loading classifier from path: {}...", modelResourcePath); 28 LOG.info("Loading classifier from path: {}...", modelResourcePath);
33 try (InputStream stream = ResourceUtils.class.getResourceAsStream(modelResourcePath)) { 29 try (InputStream stream = ResourceUtils.class.getResourceAsStream(modelResourcePath)) {
nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/TextUtils.java
@@ -3,10 +3,12 @@ package pl.waw.ipipan.zil.summ.nicolas.utils; @@ -3,10 +3,12 @@ package pl.waw.ipipan.zil.summ.nicolas.utils;
3 import com.google.common.collect.Sets; 3 import com.google.common.collect.Sets;
4 import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; 4 import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence;
5 import pl.waw.ipipan.zil.multiservice.thrift.types.TToken; 5 import pl.waw.ipipan.zil.multiservice.thrift.types.TToken;
  6 +import pl.waw.ipipan.zil.summ.nicolas.Constants;
6 7
7 import java.util.Arrays; 8 import java.util.Arrays;
8 import java.util.List; 9 import java.util.List;
9 import java.util.Set; 10 import java.util.Set;
  11 +import java.util.stream.Collectors;
10 12
11 public class TextUtils { 13 public class TextUtils {
12 14
@@ -29,7 +31,6 @@ public class TextUtils { @@ -29,7 +31,6 @@ public class TextUtils {
29 StringBuilder sb = new StringBuilder(); 31 StringBuilder sb = new StringBuilder();
30 for (TToken token : sentence.getTokens()) { 32 for (TToken token : sentence.getTokens()) {
31 if (tokenIdsToSkip.contains(token.getId())) { 33 if (tokenIdsToSkip.contains(token.getId())) {
32 - System.out.println("Skipping " + token.getOrth() + " in sentence: " + loadSentence2Orth(sentence));  
33 continue; 34 continue;
34 } 35 }
35 if (!token.isNoPrecedingSpace()) 36 if (!token.isNoPrecedingSpace())
@@ -38,4 +39,11 @@ public class TextUtils { @@ -38,4 +39,11 @@ public class TextUtils {
38 } 39 }
39 return sb.toString().trim(); 40 return sb.toString().trim();
40 } 41 }
  42 +
  43 + public static String loadSentence2OrthExcludingStoptags(TSentence sentence) {
  44 + Set<String> tokenIdsToSkip = sentence.getTokens().stream()
  45 + .filter(token -> Constants.STOP_POS_TAGS.contains(token.getChosenInterpretation().getCtag()))
  46 + .map(TToken::getId).collect(Collectors.toSet());
  47 + return loadSentence2Orth(sentence, tokenIdsToSkip);
  48 + }
41 } 49 }
nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/resources/frequent_bases.txt
@@ -234,4 +234,4 @@ ue @@ -234,4 +234,4 @@ ue
234 " rzeczpospolita " 234 " rzeczpospolita "
235 liczba 235 liczba
236 wieś 236 wieś
237 -połowa  
238 \ No newline at end of file 237 \ No newline at end of file
  238 +połowa
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/PathConstants.java
@@ -43,6 +43,8 @@ public class PathConstants { @@ -43,6 +43,8 @@ public class PathConstants {
43 43
44 public static final File SUMMARY_LENGTHS_FILE = new File(WORKING_DIR, "summary-lengths.tsv"); 44 public static final File SUMMARY_LENGTHS_FILE = new File(WORKING_DIR, "summary-lengths.tsv");
45 45
  46 + public static final String TARGET_MODEL_DIR = "nicolas-model/src/main/resources";
  47 +
46 private PathConstants() { 48 private PathConstants() {
47 } 49 }
48 50
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/search/Crossvalidate.java
@@ -34,22 +34,33 @@ import java.util.Optional; @@ -34,22 +34,33 @@ import java.util.Optional;
34 import java.util.Random; 34 import java.util.Random;
35 import java.util.logging.LogManager; 35 import java.util.logging.LogManager;
36 36
  37 +import static pl.waw.ipipan.zil.summ.nicolas.PathConstants.*;
  38 +
37 39
38 class Crossvalidate { 40 class Crossvalidate {
39 41
40 private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); 42 private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class);
41 43
42 - private static final int NUM_FOLDS = 10; 44 + private static final int NUM_FOLDS = 5;
  45 + private static final int MAX_INSTANCES = 10000;
  46 +
  47 + private static final int SEED = 1;
43 48
44 private Crossvalidate() { 49 private Crossvalidate() {
45 } 50 }
46 51
47 - static void crossvalidateClassifiers(String datasetPath) throws IOException { 52 + public static void main(String[] args) throws IOException {
  53 + crossvalidateClassifiers(MENTION_ARFF.getPath());
  54 + crossvalidateRegressors(SENTENCE_ARFF.getPath());
  55 + crossvalidateClassifiers(ZERO_ARFF.getPath());
  56 + }
  57 +
  58 + private static void crossvalidateClassifiers(String datasetPath) throws IOException {
48 Instances instances = loadInstances(datasetPath); 59 Instances instances = loadInstances(datasetPath);
49 crossvalidateClassification(instances); 60 crossvalidateClassification(instances);
50 } 61 }
51 62
52 - static void crossvalidateRegressors(String datasetPath) throws IOException { 63 + private static void crossvalidateRegressors(String datasetPath) throws IOException {
53 Instances instances = loadInstances(datasetPath); 64 Instances instances = loadInstances(datasetPath);
54 crossvalidateRegression(instances); 65 crossvalidateRegression(instances);
55 } 66 }
@@ -62,6 +73,9 @@ class Crossvalidate { @@ -62,6 +73,9 @@ class Crossvalidate {
62 Instances instances = loader.getDataSet(); 73 Instances instances = loader.getDataSet();
63 instances.setClassIndex(0); 74 instances.setClassIndex(0);
64 LOG.info("{} instances loaded.", instances.size()); 75 LOG.info("{} instances loaded.", instances.size());
  76 + instances.randomize(new Random(SEED));
  77 + while (instances.size() > MAX_INSTANCES)
  78 + instances.delete(instances.size() - 1);
65 LOG.info("{} attributes for each instance.", instances.numAttributes()); 79 LOG.info("{} attributes for each instance.", instances.numAttributes());
66 return instances; 80 return instances;
67 } 81 }
@@ -70,7 +84,8 @@ class Crossvalidate { @@ -70,7 +84,8 @@ class Crossvalidate {
70 StopWatch watch = new StopWatch(); 84 StopWatch watch = new StopWatch();
71 watch.start(); 85 watch.start();
72 86
73 - Optional<Pair<Double, String>> max = Arrays.stream(new Classifier[]{new J48(), new RandomForest(), new HoeffdingTree(), new LMT(), 87 + Optional<Pair<Double, String>> max = Arrays.stream(new Classifier[]{
  88 + new J48(), new RandomForest(), new HoeffdingTree(), new LMT(),
74 new Logistic(), new ZeroR(), 89 new Logistic(), new ZeroR(),
75 new SimpleLogistic(), new BayesNet(), new NaiveBayes(), 90 new SimpleLogistic(), new BayesNet(), new NaiveBayes(),
76 new KStar(), new IBk(), new LWL(), 91 new KStar(), new IBk(), new LWL(),
@@ -81,7 +96,7 @@ class Crossvalidate { @@ -81,7 +96,7 @@ class Crossvalidate {
81 Evaluation eval; 96 Evaluation eval;
82 try { 97 try {
83 eval = new Evaluation(instances); 98 eval = new Evaluation(instances);
84 - eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(1)); 99 + eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(SEED));
85 } catch (Exception e) { 100 } catch (Exception e) {
86 LOG.error("Error evaluating model", e); 101 LOG.error("Error evaluating model", e);
87 return Pair.of(0.0, name); 102 return Pair.of(0.0, name);
@@ -90,9 +105,13 @@ class Crossvalidate { @@ -90,9 +105,13 @@ class Crossvalidate {
90 LOG.info(name + " : " + acc); 105 LOG.info(name + " : " + acc);
91 return Pair.of(acc, name); 106 return Pair.of(acc, name);
92 }).max(Comparator.comparingDouble(Pair::getLeft)); 107 }).max(Comparator.comparingDouble(Pair::getLeft));
93 - LOG.info("#########");  
94 - LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft());  
95 108
  109 + LOG.info("#########");
  110 + if (max.isPresent()) {
  111 + LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft());
  112 + } else {
  113 + LOG.info("Empty algorithms list");
  114 + }
96 watch.stop(); 115 watch.stop();
97 LOG.info("Elapsed time: {}", watch); 116 LOG.info("Elapsed time: {}", watch);
98 } 117 }
@@ -114,7 +133,7 @@ class Crossvalidate { @@ -114,7 +133,7 @@ class Crossvalidate {
114 String name = cls.getClass().getSimpleName(); 133 String name = cls.getClass().getSimpleName();
115 try { 134 try {
116 Evaluation eval = new Evaluation(instances); 135 Evaluation eval = new Evaluation(instances);
117 - eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(1)); 136 + eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(SEED));
118 acc = eval.correlationCoefficient(); 137 acc = eval.correlationCoefficient();
119 } catch (Exception e) { 138 } catch (Exception e) {
120 LOG.error("Error evaluating model", e); 139 LOG.error("Error evaluating model", e);
@@ -122,9 +141,13 @@ class Crossvalidate { @@ -122,9 +141,13 @@ class Crossvalidate {
122 LOG.info(name + " : " + acc); 141 LOG.info(name + " : " + acc);
123 return Pair.of(acc, name); 142 return Pair.of(acc, name);
124 }).max(Comparator.comparingDouble(Pair::getLeft)); 143 }).max(Comparator.comparingDouble(Pair::getLeft));
125 - LOG.info("#########");  
126 - LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft());  
127 144
  145 + LOG.info("#########");
  146 + if (max.isPresent()) {
  147 + LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft());
  148 + } else {
  149 + LOG.info("Empty algorithms list");
  150 + }
128 watch.stop(); 151 watch.stop();
129 LOG.info("Elapsed time: {}", watch); 152 LOG.info("Elapsed time: {}", watch);
130 } 153 }
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java
1 package pl.waw.ipipan.zil.summ.nicolas.train; 1 package pl.waw.ipipan.zil.summ.nicolas.train;
2 2
3 import pl.waw.ipipan.zil.summ.nicolas.train.pipeline.*; 3 import pl.waw.ipipan.zil.summ.nicolas.train.pipeline.*;
4 -import pl.waw.ipipan.zil.summ.nicolas.train.resources.ExtractMostFrequentMentions;  
5 -import pl.waw.ipipan.zil.summ.nicolas.train.resources.ExtractStopwords;  
6 4
7 public class Main { 5 public class Main {
8 6
@@ -14,7 +12,6 @@ public class Main { @@ -14,7 +12,6 @@ public class Main {
14 DownloadTrainingResources.main(args); 12 DownloadTrainingResources.main(args);
15 ExtractGoldSummaries.main(args); 13 ExtractGoldSummaries.main(args);
16 CreateOptimalSummaries.main(args); 14 CreateOptimalSummaries.main(args);
17 - ExtractStopwords.main(args);  
18 ExtractMostFrequentMentions.main(args); 15 ExtractMostFrequentMentions.main(args);
19 PrepareTrainingData.main(args); 16 PrepareTrainingData.main(args);
20 TrainAllModels.main(args); 17 TrainAllModels.main(args);
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java
@@ -7,23 +7,16 @@ import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; @@ -7,23 +7,16 @@ import pl.waw.ipipan.zil.multiservice.thrift.types.TMention;
7 import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; 7 import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence;
8 import pl.waw.ipipan.zil.multiservice.thrift.types.TText; 8 import pl.waw.ipipan.zil.multiservice.thrift.types.TText;
9 import pl.waw.ipipan.zil.multiservice.thrift.types.TToken; 9 import pl.waw.ipipan.zil.multiservice.thrift.types.TToken;
10 -import pl.waw.ipipan.zil.summ.nicolas.utils.ResourceUtils;  
11 import pl.waw.ipipan.zil.summ.nicolas.utils.TextUtils; 10 import pl.waw.ipipan.zil.summ.nicolas.utils.TextUtils;
12 11
13 -import java.io.IOException;  
14 import java.util.List; 12 import java.util.List;
15 import java.util.Map; 13 import java.util.Map;
16 -import java.util.Set;  
17 import java.util.function.Function; 14 import java.util.function.Function;
18 import java.util.stream.Collectors; 15 import java.util.stream.Collectors;
19 16
20 -public class MentionScorer {  
21 -  
22 - private final Set<String> STOPWORDS; 17 +import static pl.waw.ipipan.zil.summ.nicolas.Constants.STOP_POS_TAGS;
23 18
24 - public MentionScorer() throws IOException {  
25 - STOPWORDS = ResourceUtils.loadStopwords().stream().collect(Collectors.toSet());  
26 - } 19 +public class MentionScorer {
27 20
28 public Map<TMention, Double> calculateMentionScores(String optimalSummary, TText text) { 21 public Map<TMention, Double> calculateMentionScores(String optimalSummary, TText text) {
29 Multiset<String> tokenCounts = HashMultiset.create(TextUtils.tokenize(optimalSummary.toLowerCase())); 22 Multiset<String> tokenCounts = HashMultiset.create(TextUtils.tokenize(optimalSummary.toLowerCase()));
@@ -34,24 +27,23 @@ public class MentionScorer { @@ -34,24 +27,23 @@ public class MentionScorer {
34 return booleanTokenIntersection(mention2Orth, tokenCounts); 27 return booleanTokenIntersection(mention2Orth, tokenCounts);
35 } 28 }
36 29
37 - private Map<TMention, String> loadMention2OrthExcludingStopwords(List<TSentence> sents) { 30 + private Map<TMention, String> loadMention2OrthExcludingStopwords(List<TSentence> sentences) {
38 Map<TMention, String> mention2orth = Maps.newHashMap(); 31 Map<TMention, String> mention2orth = Maps.newHashMap();
39 - for (TSentence s : sents) {  
40 - Map<String, TToken> tokId2tok = s.getTokens().stream().collect(Collectors.toMap(TToken::getId, Function.identity())); 32 + for (TSentence sentence : sentences) {
  33 + Map<String, TToken> tokId2tok = sentence.getTokens().stream().collect(Collectors.toMap(TToken::getId, Function.identity()));
41 34
42 - for (TMention m : s.getMentions()) { 35 + for (TMention mention : sentence.getMentions()) {
43 StringBuilder mentionOrth = new StringBuilder(); 36 StringBuilder mentionOrth = new StringBuilder();
44 - for (String tokId : m.getChildIds()) { 37 + for (String tokId : mention.getChildIds()) {
45 TToken token = tokId2tok.get(tokId); 38 TToken token = tokId2tok.get(tokId);
46 - if (STOPWORDS.contains(token.getChosenInterpretation().getBase().toLowerCase())) { 39 + if (STOP_POS_TAGS.contains(token.getChosenInterpretation().getCtag()))
47 continue; 40 continue;
48 - }  
49 41
50 if (!token.isNoPrecedingSpace()) 42 if (!token.isNoPrecedingSpace())
51 mentionOrth.append(" "); 43 mentionOrth.append(" ");
52 mentionOrth.append(token.getOrth()); 44 mentionOrth.append(token.getOrth());
53 } 45 }
54 - mention2orth.put(m, mentionOrth.toString().trim()); 46 + mention2orth.put(mention, mentionOrth.toString().trim());
55 } 47 }
56 } 48 }
57 return mention2orth; 49 return mention2orth;
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java
@@ -21,7 +21,7 @@ public class SentenceScorer { @@ -21,7 +21,7 @@ public class SentenceScorer {
21 for (TSentence sentence : paragraph.getSentences()) { 21 for (TSentence sentence : paragraph.getSentences()) {
22 double score = 0.0; 22 double score = 0.0;
23 23
24 - String orth = TextUtils.loadSentence2Orth(sentence); 24 + String orth = TextUtils.loadSentence2OrthExcludingStoptags(sentence);
25 List<String> tokens = TextUtils.tokenize(orth); 25 List<String> tokens = TextUtils.tokenize(orth);
26 for (String token : tokens) { 26 for (String token : tokens) {
27 score += tokenCounts.contains(token.toLowerCase()) ? 1.0 : 0.0; 27 score += tokenCounts.contains(token.toLowerCase()) ? 1.0 : 0.0;
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java
1 package pl.waw.ipipan.zil.summ.nicolas.train.model; 1 package pl.waw.ipipan.zil.summ.nicolas.train.model;
2 2
3 import weka.classifiers.Classifier; 3 import weka.classifiers.Classifier;
  4 +import weka.classifiers.meta.AttributeSelectedClassifier;
  5 +import weka.classifiers.trees.LMT;
4 import weka.classifiers.trees.RandomForest; 6 import weka.classifiers.trees.RandomForest;
5 7
6 public class Settings { 8 public class Settings {
7 9
8 - private static final int NUM_ITERATIONS = 20; 10 + private static final int NUM_ITERATIONS = 100;
9 private static final int NUM_EXECUTION_SLOTS = 8; 11 private static final int NUM_EXECUTION_SLOTS = 8;
10 private static final int SEED = 0; 12 private static final int SEED = 0;
11 13
@@ -29,10 +31,9 @@ public class Settings { @@ -29,10 +31,9 @@ public class Settings {
29 } 31 }
30 32
31 public static Classifier getZeroClassifier() { 33 public static Classifier getZeroClassifier() {
32 - RandomForest classifier = new RandomForest();  
33 - classifier.setNumIterations(NUM_ITERATIONS);  
34 - classifier.setSeed(SEED);  
35 - classifier.setNumExecutionSlots(NUM_EXECUTION_SLOTS); 34 + AttributeSelectedClassifier classifier = new AttributeSelectedClassifier();
  35 + LMT subClassifier = new LMT();
  36 + classifier.setClassifier(subClassifier);
36 return classifier; 37 return classifier;
37 } 38 }
38 39
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java
@@ -59,14 +59,34 @@ public class CreateOptimalSummaries { @@ -59,14 +59,34 @@ public class CreateOptimalSummaries {
59 59
60 int summaryWordCount = 0; 60 int summaryWordCount = 0;
61 StringBuilder summary = new StringBuilder(); 61 StringBuilder summary = new StringBuilder();
  62 + List<Integer> bestNgramCounts = null;
62 while (averageGoldWordCount >= summaryWordCount) { 63 while (averageGoldWordCount >= summaryWordCount) {
  64 + bestNgramCounts = getBestNgramCounts(ngram2counts);
63 List<String> ngram = pickBestNgram(ngram2counts); 65 List<String> ngram = pickBestNgram(ngram2counts);
64 summary.append(" ").append(String.join(" ", ngram)); 66 summary.append(" ").append(String.join(" ", ngram));
65 summaryWordCount += ngram.size(); 67 summaryWordCount += ngram.size();
66 } 68 }
  69 +
  70 + for (Map.Entry<List<String>, List<Integer>> entry : ngram2counts.entrySet()) {
  71 + if (entry.getValue().equals(bestNgramCounts)) {
  72 + List<String> ngram = entry.getKey();
  73 + summary.append(" ").append(String.join(" ", ngram));
  74 + }
  75 + }
  76 +
67 return summary.toString().trim(); 77 return summary.toString().trim();
68 } 78 }
69 79
  80 + private static List<Integer> getBestNgramCounts(Map<List<String>, List<Integer>> ngram2counts) {
  81 + Optional<List<String>> optional = ngram2counts.keySet().stream()
  82 + .sorted(Comparator.comparing((List<String> ngram) -> ngram2counts.get(ngram).size()).reversed()).findFirst();
  83 + if (!optional.isPresent()) {
  84 + throw new IllegalArgumentException("No more ngrams to pick!");
  85 + }
  86 + List<String> optimalNgram = optional.get();
  87 + return ngram2counts.get(optimalNgram);
  88 + }
  89 +
70 private static List<String> pickBestNgram(Map<List<String>, List<Integer>> ngram2counts) { 90 private static List<String> pickBestNgram(Map<List<String>, List<Integer>> ngram2counts) {
71 Optional<List<String>> optional = ngram2counts.keySet().stream() 91 Optional<List<String>> optional = ngram2counts.keySet().stream()
72 .sorted(Comparator.comparing((List<String> ngram) -> ngram2counts.get(ngram).size()).reversed()).findFirst(); 92 .sorted(Comparator.comparing((List<String> ngram) -> ngram2counts.get(ngram).size()).reversed()).findFirst();
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractMostFrequentMentions.java 0 → 100644
  1 +package pl.waw.ipipan.zil.summ.nicolas.train.pipeline;
  2 +
  3 +import com.google.common.collect.*;
  4 +import pl.waw.ipipan.zil.multiservice.thrift.types.TMention;
  5 +import pl.waw.ipipan.zil.multiservice.thrift.types.TText;
  6 +import pl.waw.ipipan.zil.summ.nicolas.Constants;
  7 +import pl.waw.ipipan.zil.summ.nicolas.CorpusHelper;
  8 +import pl.waw.ipipan.zil.summ.nicolas.PathConstants;
  9 +import pl.waw.ipipan.zil.summ.nicolas.features.FeatureHelper;
  10 +import pl.waw.ipipan.zil.summ.nicolas.utils.thrift.ThriftUtils;
  11 +
  12 +import javax.xml.bind.JAXBException;
  13 +import java.io.BufferedWriter;
  14 +import java.io.FileWriter;
  15 +import java.io.IOException;
  16 +import java.util.List;
  17 +import java.util.Map;
  18 +import java.util.Set;
  19 +
  20 +import static pl.waw.ipipan.zil.summ.nicolas.PathConstants.PREPROCESSED_CORPUS_DIR;
  21 +
  22 +public class ExtractMostFrequentMentions {
  23 +
  24 + private static final int MIN_MENTION_DOCUMENT_COUNT = 50;
  25 +
  26 + private ExtractMostFrequentMentions() {
  27 + }
  28 +
  29 + public static void main(String[] args) throws IOException, JAXBException {
  30 + List<String> mostFrequentMentionBases = getMostFrequentMentionBases();
  31 + try (BufferedWriter bw = new BufferedWriter(new FileWriter(PathConstants.TARGET_MODEL_DIR + Constants.FREQUENT_BASES_RESOURCE_PATH))) {
  32 + for (String base : mostFrequentMentionBases) {
  33 + bw.write(base + "\n");
  34 + }
  35 + }
  36 + }
  37 +
  38 + private static List<String> getMostFrequentMentionBases() throws IOException {
  39 + Set<String> trainTextIds = CorpusHelper.loadTrainTextIds();
  40 + Map<String, TText> id2preprocessedText = ThriftUtils.loadThriftTextsFromFolder(PREPROCESSED_CORPUS_DIR, trainTextIds::contains);
  41 +
  42 + Multiset<String> mentionBases = HashMultiset.create();
  43 +
  44 + for (TText text : id2preprocessedText.values()) {
  45 + FeatureHelper featureHelper = new FeatureHelper(text);
  46 + Set<String> textMentionBases = Sets.newHashSet();
  47 + for (TMention mention : featureHelper.getMentions()) {
  48 + String mentionBase = featureHelper.getMentionBase(mention);
  49 + textMentionBases.add(mentionBase);
  50 + }
  51 + mentionBases.addAll(textMentionBases);
  52 + }
  53 +
  54 + ImmutableMultiset<String> sorted = Multisets.copyHighestCountFirst(mentionBases);
  55 + List<String> mostFrequentMentions = Lists.newArrayList();
  56 + for (Multiset.Entry<String> entry : sorted.entrySet()) {
  57 + if (entry.getCount() < MIN_MENTION_DOCUMENT_COUNT)
  58 + break;
  59 + mostFrequentMentions.add(entry.getElement());
  60 + }
  61 + return mostFrequentMentions;
  62 + }
  63 +}
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java
@@ -20,8 +20,6 @@ public class TrainAllModels { @@ -20,8 +20,6 @@ public class TrainAllModels {
20 20
21 private static final Logger LOG = LoggerFactory.getLogger(TrainAllModels.class); 21 private static final Logger LOG = LoggerFactory.getLogger(TrainAllModels.class);
22 22
23 - private static final String TARGET_MODEL_DIR = "nicolas-model/src/main/resources";  
24 -  
25 private TrainAllModels() { 23 private TrainAllModels() {
26 } 24 }
27 25
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractMostFrequentMentions.java deleted
1 -package pl.waw.ipipan.zil.summ.nicolas.train.resources;  
2 -  
3 -import com.google.common.collect.*;  
4 -import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence;  
5 -import pl.waw.ipipan.zil.multiservice.thrift.types.TText;  
6 -import pl.waw.ipipan.zil.summ.nicolas.utils.thrift.ThriftUtils;  
7 -import pl.waw.ipipan.zil.summ.pscapi.io.PSC_IO;  
8 -import pl.waw.ipipan.zil.summ.pscapi.xml.Text;  
9 -  
10 -import javax.xml.bind.JAXBException;  
11 -import java.io.File;  
12 -import java.io.IOException;  
13 -import java.util.Comparator;  
14 -import java.util.List;  
15 -import java.util.Map;  
16 -import java.util.Set;  
17 -import java.util.stream.Collectors;  
18 -  
19 -public class ExtractMostFrequentMentions {  
20 -  
21 - public static final String GOLD_DATA_PATH = "/home/me2/Dropbox/3_nauka/3_doktorat/3_korpus_streszczen/dist/src/data/";  
22 -  
23 - public static final String THRIFTED_PREFIX = "/home/me2/Desktop/thrifted_texts/thrifted_all/";  
24 - public static final String THRIFTED_SUFFIX = "/original";  
25 -  
26 - public static void main(String[] args) throws IOException, JAXBException {  
27 -  
28 - Set<String> devIds = Sets.newHashSet();  
29 -  
30 - File goldDir = new File(GOLD_DATA_PATH);  
31 - for (File file : goldDir.listFiles()) {  
32 - Text goldText = PSC_IO.readText(file);  
33 - if (goldText.getSummaries().getSummary().stream().anyMatch(s -> s.getType().equals("abstract")))  
34 - continue;  
35 -  
36 - devIds.add(file.getName().replace(".xml", ""));  
37 - }  
38 -  
39 -  
40 - System.out.println(devIds.size());  
41 -  
42 - Multiset<String> mentionCounts = HashMultiset.create();  
43 - for (String id : devIds) {  
44 - Set<String> distinctTextMentions = Sets.newHashSet();  
45 - File input = new File(THRIFTED_PREFIX + id + THRIFTED_SUFFIX);  
46 - TText thrifted = ThriftUtils.loadThriftTextFromFile(input);  
47 - List<TSentence> sents = thrifted.getParagraphs().stream()  
48 - .flatMap(p -> p.getSentences().stream()).collect(Collectors.toList());  
49 -  
50 - Map<String, String> tokenId2base = Maps.newHashMap();  
51 - sents.stream()  
52 - .flatMap(s -> s.getTokens().stream())  
53 - .forEach(token -> tokenId2base.put(token.getId(), token.getChosenInterpretation().getBase()));  
54 -  
55 - sents.stream().flatMap(s -> s.getMentions().stream()).forEach(m -> {  
56 - StringBuffer sb = new StringBuffer();  
57 - for (String tokId : m.getChildIds()) {  
58 - sb.append(tokenId2base.get(tokId) + " ");  
59 - }  
60 - distinctTextMentions.add(sb.toString().trim().toLowerCase());  
61 - });  
62 -  
63 - mentionCounts.addAll(distinctTextMentions);  
64 - }  
65 -  
66 - System.out.println(mentionCounts.elementSet().size());  
67 - List<String> sorted = Lists.newArrayList();  
68 - sorted.addAll(mentionCounts.elementSet());  
69 - sorted.sort(Comparator.comparing(mentionCounts::count).reversed());  
70 - int i = 0;  
71 - for (String mention : sorted) {  
72 - if (mentionCounts.count(mention) < 50)  
73 - break;  
74 - System.out.println(mention);  
75 - }  
76 -  
77 - }  
78 -}  
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractStopwords.java deleted
1 -package pl.waw.ipipan.zil.summ.nicolas.train.resources;  
2 -  
3 -public class ExtractStopwords {  
4 -  
5 - private ExtractStopwords() {  
6 - }  
7 -  
8 - public static void main(String[] args) {  
9 -  
10 - }  
11 -}  
@@ -35,9 +35,23 @@ @@ -35,9 +35,23 @@
35 <slf4j-api.version>1.7.24</slf4j-api.version> 35 <slf4j-api.version>1.7.24</slf4j-api.version>
36 <junit.version>4.12</junit.version> 36 <junit.version>4.12</junit.version>
37 <zip4j.version>1.3.2</zip4j.version> 37 <zip4j.version>1.3.2</zip4j.version>
38 - <mockito-core.version>2.7.17</mockito-core.version> 38 + <mockito-core.version>2.7.22</mockito-core.version>
39 <jcommander.version>1.64</jcommander.version> 39 <jcommander.version>1.64</jcommander.version>
40 <libthrift.version>0.9.0</libthrift.version> 40 <libthrift.version>0.9.0</libthrift.version>
  41 +
  42 + <jacoco-maven-plugin.version>0.7.8</jacoco-maven-plugin.version>
  43 + <maven-compiler-plugin.version>3.5.1</maven-compiler-plugin.version>
  44 + <maven-site-plugin.version>3.5.1</maven-site-plugin.version>
  45 + <maven-dependency-plugin.version>3.0.0</maven-dependency-plugin.version>
  46 + <maven-jar-plugin.version>3.0.2</maven-jar-plugin.version>
  47 + <maven-resources-plugin.version>3.0.1</maven-resources-plugin.version>
  48 + <maven-clean-plugin.version>3.0.0</maven-clean-plugin.version>
  49 + <maven-install-plugin.version>2.5.2</maven-install-plugin.version>
  50 + <maven-deploy-plugin.version>2.8.2</maven-deploy-plugin.version>
  51 + <maven-assembly-plugin.version>2.6</maven-assembly-plugin.version>
  52 + <maven-project-info-reports-plugin.version>2.9</maven-project-info-reports-plugin.version>
  53 + <maven-surefire-plugin.version>2.19.1</maven-surefire-plugin.version>
  54 + <maven-failsafe-plugin.version>2.19.1</maven-failsafe-plugin.version>
41 </properties> 55 </properties>
42 56
43 <prerequisites> 57 <prerequisites>
@@ -183,47 +197,47 @@ @@ -183,47 +197,47 @@
183 <plugin> 197 <plugin>
184 <groupId>org.apache.maven.plugins</groupId> 198 <groupId>org.apache.maven.plugins</groupId>
185 <artifactId>maven-dependency-plugin</artifactId> 199 <artifactId>maven-dependency-plugin</artifactId>
186 - <version>3.0.0</version> 200 + <version>${maven-dependency-plugin.version}</version>
187 </plugin> 201 </plugin>
188 <plugin> 202 <plugin>
189 <groupId>org.apache.maven.plugins</groupId> 203 <groupId>org.apache.maven.plugins</groupId>
190 <artifactId>maven-jar-plugin</artifactId> 204 <artifactId>maven-jar-plugin</artifactId>
191 - <version>3.0.2</version> 205 + <version>${maven-jar-plugin.version}</version>
192 </plugin> 206 </plugin>
193 <plugin> 207 <plugin>
194 <groupId>org.apache.maven.plugins</groupId> 208 <groupId>org.apache.maven.plugins</groupId>
195 <artifactId>maven-resources-plugin</artifactId> 209 <artifactId>maven-resources-plugin</artifactId>
196 - <version>3.0.1</version> 210 + <version>${maven-resources-plugin.version}</version>
197 </plugin> 211 </plugin>
198 <plugin> 212 <plugin>
199 <groupId>org.apache.maven.plugins</groupId> 213 <groupId>org.apache.maven.plugins</groupId>
200 <artifactId>maven-clean-plugin</artifactId> 214 <artifactId>maven-clean-plugin</artifactId>
201 - <version>3.0.0</version> 215 + <version>${maven-clean-plugin.version}</version>
202 </plugin> 216 </plugin>
203 <plugin> 217 <plugin>
204 <groupId>org.apache.maven.plugins</groupId> 218 <groupId>org.apache.maven.plugins</groupId>
205 <artifactId>maven-site-plugin</artifactId> 219 <artifactId>maven-site-plugin</artifactId>
206 - <version>3.5.1</version> 220 + <version>${maven-site-plugin.version}</version>
207 </plugin> 221 </plugin>
208 <plugin> 222 <plugin>
209 <groupId>org.apache.maven.plugins</groupId> 223 <groupId>org.apache.maven.plugins</groupId>
210 <artifactId>maven-install-plugin</artifactId> 224 <artifactId>maven-install-plugin</artifactId>
211 - <version>2.5.2</version> 225 + <version>${maven-install-plugin.version}</version>
212 </plugin> 226 </plugin>
213 <plugin> 227 <plugin>
214 <groupId>org.apache.maven.plugins</groupId> 228 <groupId>org.apache.maven.plugins</groupId>
215 <artifactId>maven-deploy-plugin</artifactId> 229 <artifactId>maven-deploy-plugin</artifactId>
216 - <version>2.8.2</version> 230 + <version>${maven-deploy-plugin.version}</version>
217 </plugin> 231 </plugin>
218 <plugin> 232 <plugin>
219 <groupId>org.apache.maven.plugins</groupId> 233 <groupId>org.apache.maven.plugins</groupId>
220 <artifactId>maven-assembly-plugin</artifactId> 234 <artifactId>maven-assembly-plugin</artifactId>
221 - <version>2.6</version> 235 + <version>${maven-assembly-plugin.version}</version>
222 </plugin> 236 </plugin>
223 <plugin> 237 <plugin>
224 <groupId>org.apache.maven.plugins</groupId> 238 <groupId>org.apache.maven.plugins</groupId>
225 <artifactId>maven-compiler-plugin</artifactId> 239 <artifactId>maven-compiler-plugin</artifactId>
226 - <version>3.5.1</version> 240 + <version>${maven-compiler-plugin.version}</version>
227 <configuration> 241 <configuration>
228 <source>${java.version.build}</source> 242 <source>${java.version.build}</source>
229 <target>${java.version.build}</target> 243 <target>${java.version.build}</target>
@@ -231,8 +245,13 @@ @@ -231,8 +245,13 @@
231 </plugin> 245 </plugin>
232 <plugin> 246 <plugin>
233 <groupId>org.apache.maven.plugins</groupId> 247 <groupId>org.apache.maven.plugins</groupId>
  248 + <artifactId>maven-project-info-reports-plugin</artifactId>
  249 + <version>${maven-project-info-reports-plugin.version}</version>
  250 + </plugin>
  251 + <plugin>
  252 + <groupId>org.apache.maven.plugins</groupId>
234 <artifactId>maven-surefire-plugin</artifactId> 253 <artifactId>maven-surefire-plugin</artifactId>
235 - <version>2.19.1</version> 254 + <version>${maven-surefire-plugin.version}</version>
236 <configuration> 255 <configuration>
237 <!-- Sets the VM argument line used when unit tests are run. --> 256 <!-- Sets the VM argument line used when unit tests are run. -->
238 <argLine>${surefireArgLine}</argLine> 257 <argLine>${surefireArgLine}</argLine>
@@ -247,7 +266,7 @@ @@ -247,7 +266,7 @@
247 <plugin> 266 <plugin>
248 <groupId>org.apache.maven.plugins</groupId> 267 <groupId>org.apache.maven.plugins</groupId>
249 <artifactId>maven-failsafe-plugin</artifactId> 268 <artifactId>maven-failsafe-plugin</artifactId>
250 - <version>2.19.1</version> 269 + <version>${maven-failsafe-plugin.version}</version>
251 <executions> 270 <executions>
252 <execution> 271 <execution>
253 <id>integration-test</id> 272 <id>integration-test</id>
@@ -276,7 +295,7 @@ @@ -276,7 +295,7 @@
276 <plugin> 295 <plugin>
277 <groupId>org.jacoco</groupId> 296 <groupId>org.jacoco</groupId>
278 <artifactId>jacoco-maven-plugin</artifactId> 297 <artifactId>jacoco-maven-plugin</artifactId>
279 - <version>0.7.8</version> 298 + <version>${jacoco-maven-plugin.version}</version>
280 <executions> 299 <executions>
281 <!-- 300 <!--
282 Prepares the property pointing to the JaCoCo runtime agent which 301 Prepares the property pointing to the JaCoCo runtime agent which