Commit dcb5f3d5e5a615839b70cff10410ab54d9ea2f38 (1 parent: 156b3707)
Commit message: "almost finished :)"
Showing 17 changed files with 186 additions and 151 deletions
nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java
1 | 1 | package pl.waw.ipipan.zil.summ.nicolas; |
2 | 2 | |
3 | 3 | import com.google.common.collect.ImmutableList; |
4 | +import com.google.common.collect.ImmutableSet; | |
4 | 5 | |
5 | 6 | import java.nio.charset.Charset; |
6 | 7 | import java.nio.charset.StandardCharsets; |
... | ... | @@ -8,6 +9,11 @@ import java.nio.charset.StandardCharsets; |
8 | 9 | public class Constants { |
9 | 10 | |
10 | 11 | public static final ImmutableList<String> POS_TAGS = ImmutableList.of("end", "other", "null", "impt", "imps", "inf", "pred", "subst", "aglt", "ppron3", "ger", "praet", "fin", "num", "interp", "siebie", "brev", "interj", "ppron12", "adj", "burk", "pcon", "bedzie", "adv", "prep", "depr", "xxx", "winien", "conj", "qub", "adja", "ppas", "comp", "pact"); |
12 | + | |
13 | + public static final ImmutableSet<String> STOP_POS_TAGS = | |
14 | + ImmutableSet.of("brev", "conj", "prep", "interp", "qub", "interj", "siebie", "xxx", "inf", "comp", | |
15 | + "bedzie", "aglt"); | |
16 | + | |
11 | 17 | public static final Charset ENCODING = StandardCharsets.UTF_8; |
12 | 18 | |
13 | 19 | private static final String ROOT_PATH = "/pl/waw/ipipan/zil/summ/nicolas/"; |
... | ... | @@ -19,7 +25,6 @@ public class Constants { |
19 | 25 | |
20 | 26 | private static final String RESOURCES_PATH = ROOT_PATH + "resources/"; |
21 | 27 | public static final String FREQUENT_BASES_RESOURCE_PATH = RESOURCES_PATH + "frequent_bases.txt"; |
22 | - public static final String STOPWORDS_PATH = RESOURCES_PATH + "stopwords.txt"; | |
23 | 28 | |
24 | 29 | private Constants() { |
25 | 30 | } |
... | ... |
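The new `STOP_POS_TAGS` set replaces the removed `stopwords.txt` resource: downstream code now filters tokens by the part-of-speech tag (`ctag`) of their chosen interpretation rather than by base form, as the `TextUtils` and `MentionScorer` hunks below show. A minimal sketch of that filter, assuming the thrift accessors used elsewhere in this diff (`getChosenInterpretation().getCtag()`); the helper class itself is hypothetical and not part of this commit:

```java
import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence;
import pl.waw.ipipan.zil.multiservice.thrift.types.TToken;
import pl.waw.ipipan.zil.summ.nicolas.Constants;

import java.util.List;
import java.util.stream.Collectors;

// Hypothetical helper (not part of this commit): keep only tokens whose
// POS tag is not listed in Constants.STOP_POS_TAGS.
class StopTagFilterSketch {
    static List<TToken> contentTokens(TSentence sentence) {
        return sentence.getTokens().stream()
                .filter(t -> !Constants.STOP_POS_TAGS.contains(t.getChosenInterpretation().getCtag()))
                .collect(Collectors.toList());
    }
}
```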
nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/InstanceUtils.java
... | ... | @@ -35,7 +35,7 @@ public class InstanceUtils { |
35 | 35 | List<TSentence> sentences = preprocessedText.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList()); |
36 | 36 | Map<TMention, Map<Attribute, Double>> mention2features = featureExtractor.calculateFeatures(preprocessedText); |
37 | 37 | |
38 | - LOG.info("Extracting {} features of each mention.", featureExtractor.getAttributesList().size()); | |
38 | + LOG.debug("Extracting {} features of each mention.", featureExtractor.getAttributesList().size()); | |
39 | 39 | Map<TMention, Instance> mention2instance = Maps.newHashMap(); |
40 | 40 | for (TMention tMention : sentences.stream().flatMap(s -> s.getMentions().stream()).collect(toList())) { |
41 | 41 | Instance instance = new DenseInstance(featureExtractor.getAttributesList().size()); |
... | ... | @@ -45,7 +45,7 @@ public class InstanceUtils { |
45 | 45 | } |
46 | 46 | mention2instance.put(tMention, instance); |
47 | 47 | } |
48 | - LOG.info("Extracted features of {} mentions.", mention2instance.size()); | |
48 | + LOG.debug("Extracted features of {} mentions.", mention2instance.size()); | |
49 | 49 | return mention2instance; |
50 | 50 | } |
51 | 51 | |
... | ... | @@ -53,7 +53,7 @@ public class InstanceUtils { |
53 | 53 | List<TSentence> sentences = preprocessedText.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList()); |
54 | 54 | Map<TSentence, Map<Attribute, Double>> sentence2features = featureExtractor.calculateFeatures(preprocessedText, goodMentions); |
55 | 55 | |
56 | - LOG.info("Extracting {} features of each sentence.", featureExtractor.getAttributesList().size()); | |
56 | + LOG.debug("Extracting {} features of each sentence.", featureExtractor.getAttributesList().size()); | |
57 | 57 | Map<TSentence, Instance> sentence2instance = Maps.newHashMap(); |
58 | 58 | for (TSentence sentence : sentences) { |
59 | 59 | Instance instance = new DenseInstance(featureExtractor.getAttributesList().size()); |
... | ... | @@ -63,7 +63,7 @@ public class InstanceUtils { |
63 | 63 | } |
64 | 64 | sentence2instance.put(sentence, instance); |
65 | 65 | } |
66 | - LOG.info("Extracted features of {} sentences.", sentence2instance.size()); | |
66 | + LOG.debug("Extracted features of {} sentences.", sentence2instance.size()); | |
67 | 67 | return sentence2instance; |
68 | 68 | } |
69 | 69 | |
... | ... |
nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/ResourceUtils.java
... | ... | @@ -24,10 +24,6 @@ public class ResourceUtils { |
24 | 24 | return loadUniqueLowercaseSortedNonemptyLinesFromResource(Constants.FREQUENT_BASES_RESOURCE_PATH); |
25 | 25 | } |
26 | 26 | |
27 | - public static List<String> loadStopwords() throws IOException { | |
28 | - return loadUniqueLowercaseSortedNonemptyLinesFromResource(Constants.STOPWORDS_PATH); | |
29 | - } | |
30 | - | |
31 | 27 | public static Classifier loadModelFromResource(String modelResourcePath) throws IOException { |
32 | 28 | LOG.info("Loading classifier from path: {}...", modelResourcePath); |
33 | 29 | try (InputStream stream = ResourceUtils.class.getResourceAsStream(modelResourcePath)) { |
... | ... |
nicolas-lib/src/main/java/pl/waw/ipipan/zil/summ/nicolas/utils/TextUtils.java
... | ... | @@ -3,10 +3,12 @@ package pl.waw.ipipan.zil.summ.nicolas.utils; |
3 | 3 | import com.google.common.collect.Sets; |
4 | 4 | import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; |
5 | 5 | import pl.waw.ipipan.zil.multiservice.thrift.types.TToken; |
6 | +import pl.waw.ipipan.zil.summ.nicolas.Constants; | |
6 | 7 | |
7 | 8 | import java.util.Arrays; |
8 | 9 | import java.util.List; |
9 | 10 | import java.util.Set; |
11 | +import java.util.stream.Collectors; | |
10 | 12 | |
11 | 13 | public class TextUtils { |
12 | 14 | |
... | ... | @@ -29,7 +31,6 @@ public class TextUtils { |
29 | 31 | StringBuilder sb = new StringBuilder(); |
30 | 32 | for (TToken token : sentence.getTokens()) { |
31 | 33 | if (tokenIdsToSkip.contains(token.getId())) { |
32 | - System.out.println("Skipping " + token.getOrth() + " in sentence: " + loadSentence2Orth(sentence)); | |
33 | 34 | continue; |
34 | 35 | } |
35 | 36 | if (!token.isNoPrecedingSpace()) |
... | ... | @@ -38,4 +39,11 @@ public class TextUtils { |
38 | 39 | } |
39 | 40 | return sb.toString().trim(); |
40 | 41 | } |
42 | + | |
43 | + public static String loadSentence2OrthExcludingStoptags(TSentence sentence) { | |
44 | + Set<String> tokenIdsToSkip = sentence.getTokens().stream() | |
45 | + .filter(token -> Constants.STOP_POS_TAGS.contains(token.getChosenInterpretation().getCtag())) | |
46 | + .map(TToken::getId).collect(Collectors.toSet()); | |
47 | + return loadSentence2Orth(sentence, tokenIdsToSkip); | |
48 | + } | |
41 | 49 | } |
... | ... |
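A usage sketch for the new `loadSentence2OrthExcludingStoptags` helper, mirroring how `SentenceScorer` calls it further down; the surrounding class and the fallback to the unfiltered orth are hypothetical additions for illustration:

```java
import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence;
import pl.waw.ipipan.zil.summ.nicolas.utils.TextUtils;

// Hypothetical usage (not part of this commit): render a sentence with
// stop-POS tokens removed, falling back to the full orth if nothing remains.
class SentenceOrthExample {
    static String scoringText(TSentence sentence) {
        String filtered = TextUtils.loadSentence2OrthExcludingStoptags(sentence);
        return filtered.isEmpty() ? TextUtils.loadSentence2Orth(sentence) : filtered;
    }
}
```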
nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/resources/frequent_bases.txt
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/PathConstants.java
... | ... | @@ -43,6 +43,8 @@ public class PathConstants { |
43 | 43 | |
44 | 44 | public static final File SUMMARY_LENGTHS_FILE = new File(WORKING_DIR, "summary-lengths.tsv"); |
45 | 45 | |
46 | + public static final String TARGET_MODEL_DIR = "nicolas-model/src/main/resources"; | |
47 | + | |
46 | 48 | private PathConstants() { |
47 | 49 | } |
48 | 50 | |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/search/Crossvalidate.java
... | ... | @@ -34,22 +34,33 @@ import java.util.Optional; |
34 | 34 | import java.util.Random; |
35 | 35 | import java.util.logging.LogManager; |
36 | 36 | |
37 | +import static pl.waw.ipipan.zil.summ.nicolas.PathConstants.*; | |
38 | + | |
37 | 39 | |
38 | 40 | class Crossvalidate { |
39 | 41 | |
40 | 42 | private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); |
41 | 43 | |
42 | - private static final int NUM_FOLDS = 10; | |
44 | + private static final int NUM_FOLDS = 5; | |
45 | + private static final int MAX_INSTANCES = 10000; | |
46 | + | |
47 | + private static final int SEED = 1; | |
43 | 48 | |
44 | 49 | private Crossvalidate() { |
45 | 50 | } |
46 | 51 | |
47 | - static void crossvalidateClassifiers(String datasetPath) throws IOException { | |
52 | + public static void main(String[] args) throws IOException { | |
53 | + crossvalidateClassifiers(MENTION_ARFF.getPath()); | |
54 | + crossvalidateRegressors(SENTENCE_ARFF.getPath()); | |
55 | + crossvalidateClassifiers(ZERO_ARFF.getPath()); | |
56 | + } | |
57 | + | |
58 | + private static void crossvalidateClassifiers(String datasetPath) throws IOException { | |
48 | 59 | Instances instances = loadInstances(datasetPath); |
49 | 60 | crossvalidateClassification(instances); |
50 | 61 | } |
51 | 62 | |
52 | - static void crossvalidateRegressors(String datasetPath) throws IOException { | |
63 | + private static void crossvalidateRegressors(String datasetPath) throws IOException { | |
53 | 64 | Instances instances = loadInstances(datasetPath); |
54 | 65 | crossvalidateRegression(instances); |
55 | 66 | } |
... | ... | @@ -62,6 +73,9 @@ class Crossvalidate { |
62 | 73 | Instances instances = loader.getDataSet(); |
63 | 74 | instances.setClassIndex(0); |
64 | 75 | LOG.info("{} instances loaded.", instances.size()); |
76 | + instances.randomize(new Random(SEED)); | |
77 | + while (instances.size() > MAX_INSTANCES) | |
78 | + instances.delete(instances.size() - 1); | |
65 | 79 | LOG.info("{} attributes for each instance.", instances.numAttributes()); |
66 | 80 | return instances; |
67 | 81 | } |
... | ... | @@ -70,7 +84,8 @@ class Crossvalidate { |
70 | 84 | StopWatch watch = new StopWatch(); |
71 | 85 | watch.start(); |
72 | 86 | |
73 | - Optional<Pair<Double, String>> max = Arrays.stream(new Classifier[]{new J48(), new RandomForest(), new HoeffdingTree(), new LMT(), | |
87 | + Optional<Pair<Double, String>> max = Arrays.stream(new Classifier[]{ | |
88 | + new J48(), new RandomForest(), new HoeffdingTree(), new LMT(), | |
74 | 89 | new Logistic(), new ZeroR(), |
75 | 90 | new SimpleLogistic(), new BayesNet(), new NaiveBayes(), |
76 | 91 | new KStar(), new IBk(), new LWL(), |
... | ... | @@ -81,7 +96,7 @@ class Crossvalidate { |
81 | 96 | Evaluation eval; |
82 | 97 | try { |
83 | 98 | eval = new Evaluation(instances); |
84 | - eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(1)); | |
99 | + eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(SEED)); | |
85 | 100 | } catch (Exception e) { |
86 | 101 | LOG.error("Error evaluating model", e); |
87 | 102 | return Pair.of(0.0, name); |
... | ... | @@ -90,9 +105,13 @@ class Crossvalidate { |
90 | 105 | LOG.info(name + " : " + acc); |
91 | 106 | return Pair.of(acc, name); |
92 | 107 | }).max(Comparator.comparingDouble(Pair::getLeft)); |
93 | - LOG.info("#########"); | |
94 | - LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); | |
95 | 108 | |
109 | + LOG.info("#########"); | |
110 | + if (max.isPresent()) { | |
111 | + LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); | |
112 | + } else { | |
113 | + LOG.info("Empty algorithms list"); | |
114 | + } | |
96 | 115 | watch.stop(); |
97 | 116 | LOG.info("Elapsed time: {}", watch); |
98 | 117 | } |
... | ... | @@ -114,7 +133,7 @@ class Crossvalidate { |
114 | 133 | String name = cls.getClass().getSimpleName(); |
115 | 134 | try { |
116 | 135 | Evaluation eval = new Evaluation(instances); |
117 | - eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(1)); | |
136 | + eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(SEED)); | |
118 | 137 | acc = eval.correlationCoefficient(); |
119 | 138 | } catch (Exception e) { |
120 | 139 | LOG.error("Error evaluating model", e); |
... | ... | @@ -122,9 +141,13 @@ class Crossvalidate { |
122 | 141 | LOG.info(name + " : " + acc); |
123 | 142 | return Pair.of(acc, name); |
124 | 143 | }).max(Comparator.comparingDouble(Pair::getLeft)); |
125 | - LOG.info("#########"); | |
126 | - LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); | |
127 | 144 | |
145 | + LOG.info("#########"); | |
146 | + if (max.isPresent()) { | |
147 | + LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); | |
148 | + } else { | |
149 | + LOG.info("Empty algorithms list"); | |
150 | + } | |
128 | 151 | watch.stop(); |
129 | 152 | LOG.info("Elapsed time: {}", watch); |
130 | 153 | } |
... | ... |
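A minimal, self-contained sketch of the evaluation pattern this hunk introduces — load an ARFF dataset, randomize and cap the number of instances with a fixed seed, then cross-validate. It uses the same Weka classes the file already relies on; the dataset path, the cap, and the single classifier are placeholders rather than the project's actual configuration:

```java
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.trees.RandomForest;
import weka.core.Instances;
import weka.core.converters.ArffLoader;

import java.io.File;
import java.util.Random;

// Hypothetical standalone sketch (not part of this commit).
class CrossvalidateSketch {
    public static void main(String[] args) throws Exception {
        ArffLoader loader = new ArffLoader();
        loader.setFile(new File(args[0]));                 // path to an ARFF dataset
        Instances instances = loader.getDataSet();
        instances.setClassIndex(0);

        instances.randomize(new Random(1));                // fixed seed, as SEED above
        while (instances.size() > 10000)                   // cap, as MAX_INSTANCES above
            instances.delete(instances.size() - 1);

        Evaluation eval = new Evaluation(instances);
        Classifier cls = new RandomForest();
        eval.crossValidateModel(cls, instances, 5, new Random(1)); // 5 folds, as NUM_FOLDS
        System.out.println("Accuracy: " + eval.pctCorrect());
    }
}
```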
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java
1 | 1 | package pl.waw.ipipan.zil.summ.nicolas.train; |
2 | 2 | |
3 | 3 | import pl.waw.ipipan.zil.summ.nicolas.train.pipeline.*; |
4 | -import pl.waw.ipipan.zil.summ.nicolas.train.resources.ExtractMostFrequentMentions; | |
5 | -import pl.waw.ipipan.zil.summ.nicolas.train.resources.ExtractStopwords; | |
6 | 4 | |
7 | 5 | public class Main { |
8 | 6 | |
... | ... | @@ -14,7 +12,6 @@ public class Main { |
14 | 12 | DownloadTrainingResources.main(args); |
15 | 13 | ExtractGoldSummaries.main(args); |
16 | 14 | CreateOptimalSummaries.main(args); |
17 | - ExtractStopwords.main(args); | |
18 | 15 | ExtractMostFrequentMentions.main(args); |
19 | 16 | PrepareTrainingData.main(args); |
20 | 17 | TrainAllModels.main(args); |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java
... | ... | @@ -7,23 +7,16 @@ import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; |
7 | 7 | import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; |
8 | 8 | import pl.waw.ipipan.zil.multiservice.thrift.types.TText; |
9 | 9 | import pl.waw.ipipan.zil.multiservice.thrift.types.TToken; |
10 | -import pl.waw.ipipan.zil.summ.nicolas.utils.ResourceUtils; | |
11 | 10 | import pl.waw.ipipan.zil.summ.nicolas.utils.TextUtils; |
12 | 11 | |
13 | -import java.io.IOException; | |
14 | 12 | import java.util.List; |
15 | 13 | import java.util.Map; |
16 | -import java.util.Set; | |
17 | 14 | import java.util.function.Function; |
18 | 15 | import java.util.stream.Collectors; |
19 | 16 | |
20 | -public class MentionScorer { | |
21 | - | |
22 | - private final Set<String> STOPWORDS; | |
17 | +import static pl.waw.ipipan.zil.summ.nicolas.Constants.STOP_POS_TAGS; | |
23 | 18 | |
24 | - public MentionScorer() throws IOException { | |
25 | - STOPWORDS = ResourceUtils.loadStopwords().stream().collect(Collectors.toSet()); | |
26 | - } | |
19 | +public class MentionScorer { | |
27 | 20 | |
28 | 21 | public Map<TMention, Double> calculateMentionScores(String optimalSummary, TText text) { |
29 | 22 | Multiset<String> tokenCounts = HashMultiset.create(TextUtils.tokenize(optimalSummary.toLowerCase())); |
... | ... | @@ -34,24 +27,23 @@ public class MentionScorer { |
34 | 27 | return booleanTokenIntersection(mention2Orth, tokenCounts); |
35 | 28 | } |
36 | 29 | |
37 | - private Map<TMention, String> loadMention2OrthExcludingStopwords(List<TSentence> sents) { | |
30 | + private Map<TMention, String> loadMention2OrthExcludingStopwords(List<TSentence> sentences) { | |
38 | 31 | Map<TMention, String> mention2orth = Maps.newHashMap(); |
39 | - for (TSentence s : sents) { | |
40 | - Map<String, TToken> tokId2tok = s.getTokens().stream().collect(Collectors.toMap(TToken::getId, Function.identity())); | |
32 | + for (TSentence sentence : sentences) { | |
33 | + Map<String, TToken> tokId2tok = sentence.getTokens().stream().collect(Collectors.toMap(TToken::getId, Function.identity())); | |
41 | 34 | |
42 | - for (TMention m : s.getMentions()) { | |
35 | + for (TMention mention : sentence.getMentions()) { | |
43 | 36 | StringBuilder mentionOrth = new StringBuilder(); |
44 | - for (String tokId : m.getChildIds()) { | |
37 | + for (String tokId : mention.getChildIds()) { | |
45 | 38 | TToken token = tokId2tok.get(tokId); |
46 | - if (STOPWORDS.contains(token.getChosenInterpretation().getBase().toLowerCase())) { | |
39 | + if (STOP_POS_TAGS.contains(token.getChosenInterpretation().getCtag())) | |
47 | 40 | continue; |
48 | - } | |
49 | 41 | |
50 | 42 | if (!token.isNoPrecedingSpace()) |
51 | 43 | mentionOrth.append(" "); |
52 | 44 | mentionOrth.append(token.getOrth()); |
53 | 45 | } |
54 | - mention2orth.put(m, mentionOrth.toString().trim()); | |
46 | + mention2orth.put(mention, mentionOrth.toString().trim()); | |
55 | 47 | } |
56 | 48 | } |
57 | 49 | return mention2orth; |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java
... | ... | @@ -21,7 +21,7 @@ public class SentenceScorer { |
21 | 21 | for (TSentence sentence : paragraph.getSentences()) { |
22 | 22 | double score = 0.0; |
23 | 23 | |
24 | - String orth = TextUtils.loadSentence2Orth(sentence); | |
24 | + String orth = TextUtils.loadSentence2OrthExcludingStoptags(sentence); | |
25 | 25 | List<String> tokens = TextUtils.tokenize(orth); |
26 | 26 | for (String token : tokens) { |
27 | 27 | score += tokenCounts.contains(token.toLowerCase()) ? 1.0 : 0.0; |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java
1 | 1 | package pl.waw.ipipan.zil.summ.nicolas.train.model; |
2 | 2 | |
3 | 3 | import weka.classifiers.Classifier; |
4 | +import weka.classifiers.meta.AttributeSelectedClassifier; | |
5 | +import weka.classifiers.trees.LMT; | |
4 | 6 | import weka.classifiers.trees.RandomForest; |
5 | 7 | |
6 | 8 | public class Settings { |
7 | 9 | |
8 | - private static final int NUM_ITERATIONS = 20; | |
10 | + private static final int NUM_ITERATIONS = 100; | |
9 | 11 | private static final int NUM_EXECUTION_SLOTS = 8; |
10 | 12 | private static final int SEED = 0; |
11 | 13 | |
... | ... | @@ -29,10 +31,9 @@ public class Settings { |
29 | 31 | } |
30 | 32 | |
31 | 33 | public static Classifier getZeroClassifier() { |
32 | - RandomForest classifier = new RandomForest(); | |
33 | - classifier.setNumIterations(NUM_ITERATIONS); | |
34 | - classifier.setSeed(SEED); | |
35 | - classifier.setNumExecutionSlots(NUM_EXECUTION_SLOTS); | |
34 | + AttributeSelectedClassifier classifier = new AttributeSelectedClassifier(); | |
35 | + LMT subClassifier = new LMT(); | |
36 | + classifier.setClassifier(subClassifier); | |
36 | 37 | return classifier; |
37 | 38 | } |
38 | 39 | |
... | ... |
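The zero-candidate model switches from a RandomForest to attribute selection wrapped around a logistic model tree. A short training sketch under that assumption; the helper class is hypothetical and the `Instances` argument stands in for the prepared zero ARFF dataset:

```java
import weka.classifiers.Classifier;
import weka.classifiers.meta.AttributeSelectedClassifier;
import weka.classifiers.trees.LMT;
import weka.core.Instances;

// Hypothetical sketch (not part of this commit): train the classifier
// configured by Settings.getZeroClassifier() on a prepared dataset.
class ZeroModelSketch {
    static Classifier trainZeroModel(Instances trainingData) throws Exception {
        AttributeSelectedClassifier classifier = new AttributeSelectedClassifier();
        classifier.setClassifier(new LMT());       // LMT as the base learner
        classifier.buildClassifier(trainingData);  // default attribute evaluator and search
        return classifier;
    }
}
```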
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java
... | ... | @@ -59,14 +59,34 @@ public class CreateOptimalSummaries { |
59 | 59 | |
60 | 60 | int summaryWordCount = 0; |
61 | 61 | StringBuilder summary = new StringBuilder(); |
62 | + List<Integer> bestNgramCounts = null; | |
62 | 63 | while (averageGoldWordCount >= summaryWordCount) { |
64 | + bestNgramCounts = getBestNgramCounts(ngram2counts); | |
63 | 65 | List<String> ngram = pickBestNgram(ngram2counts); |
64 | 66 | summary.append(" ").append(String.join(" ", ngram)); |
65 | 67 | summaryWordCount += ngram.size(); |
66 | 68 | } |
69 | + | |
70 | + for (Map.Entry<List<String>, List<Integer>> entry : ngram2counts.entrySet()) { | |
71 | + if (entry.getValue().equals(bestNgramCounts)) { | |
72 | + List<String> ngram = entry.getKey(); | |
73 | + summary.append(" ").append(String.join(" ", ngram)); | |
74 | + } | |
75 | + } | |
76 | + | |
67 | 77 | return summary.toString().trim(); |
68 | 78 | } |
69 | 79 | |
80 | + private static List<Integer> getBestNgramCounts(Map<List<String>, List<Integer>> ngram2counts) { | |
81 | + Optional<List<String>> optional = ngram2counts.keySet().stream() | |
82 | + .sorted(Comparator.comparing((List<String> ngram) -> ngram2counts.get(ngram).size()).reversed()).findFirst(); | |
83 | + if (!optional.isPresent()) { | |
84 | + throw new IllegalArgumentException("No more ngrams to pick!"); | |
85 | + } | |
86 | + List<String> optimalNgram = optional.get(); | |
87 | + return ngram2counts.get(optimalNgram); | |
88 | + } | |
89 | + | |
70 | 90 | private static List<String> pickBestNgram(Map<List<String>, List<Integer>> ngram2counts) { |
71 | 91 | Optional<List<String>> optional = ngram2counts.keySet().stream() |
72 | 92 | .sorted(Comparator.comparing((List<String> ngram) -> ngram2counts.get(ngram).size()).reversed()).findFirst(); |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractMostFrequentMentions.java
new file mode 100644
1 | +package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; | |
2 | + | |
3 | +import com.google.common.collect.*; | |
4 | +import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; | |
5 | +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; | |
6 | +import pl.waw.ipipan.zil.summ.nicolas.Constants; | |
7 | +import pl.waw.ipipan.zil.summ.nicolas.CorpusHelper; | |
8 | +import pl.waw.ipipan.zil.summ.nicolas.PathConstants; | |
9 | +import pl.waw.ipipan.zil.summ.nicolas.features.FeatureHelper; | |
10 | +import pl.waw.ipipan.zil.summ.nicolas.utils.thrift.ThriftUtils; | |
11 | + | |
12 | +import javax.xml.bind.JAXBException; | |
13 | +import java.io.BufferedWriter; | |
14 | +import java.io.FileWriter; | |
15 | +import java.io.IOException; | |
16 | +import java.util.List; | |
17 | +import java.util.Map; | |
18 | +import java.util.Set; | |
19 | + | |
20 | +import static pl.waw.ipipan.zil.summ.nicolas.PathConstants.PREPROCESSED_CORPUS_DIR; | |
21 | + | |
22 | +public class ExtractMostFrequentMentions { | |
23 | + | |
24 | + private static final int MIN_MENTION_DOCUMENT_COUNT = 50; | |
25 | + | |
26 | + private ExtractMostFrequentMentions() { | |
27 | + } | |
28 | + | |
29 | + public static void main(String[] args) throws IOException, JAXBException { | |
30 | + List<String> mostFrequentMentionBases = getMostFrequentMentionBases(); | |
31 | + try (BufferedWriter bw = new BufferedWriter(new FileWriter(PathConstants.TARGET_MODEL_DIR + Constants.FREQUENT_BASES_RESOURCE_PATH))) { | |
32 | + for (String base : mostFrequentMentionBases) { | |
33 | + bw.write(base + "\n"); | |
34 | + } | |
35 | + } | |
36 | + } | |
37 | + | |
38 | + private static List<String> getMostFrequentMentionBases() throws IOException { | |
39 | + Set<String> trainTextIds = CorpusHelper.loadTrainTextIds(); | |
40 | + Map<String, TText> id2preprocessedText = ThriftUtils.loadThriftTextsFromFolder(PREPROCESSED_CORPUS_DIR, trainTextIds::contains); | |
41 | + | |
42 | + Multiset<String> mentionBases = HashMultiset.create(); | |
43 | + | |
44 | + for (TText text : id2preprocessedText.values()) { | |
45 | + FeatureHelper featureHelper = new FeatureHelper(text); | |
46 | + Set<String> textMentionBases = Sets.newHashSet(); | |
47 | + for (TMention mention : featureHelper.getMentions()) { | |
48 | + String mentionBase = featureHelper.getMentionBase(mention); | |
49 | + textMentionBases.add(mentionBase); | |
50 | + } | |
51 | + mentionBases.addAll(textMentionBases); | |
52 | + } | |
53 | + | |
54 | + ImmutableMultiset<String> sorted = Multisets.copyHighestCountFirst(mentionBases); | |
55 | + List<String> mostFrequentMentions = Lists.newArrayList(); | |
56 | + for (Multiset.Entry<String> entry : sorted.entrySet()) { | |
57 | + if (entry.getCount() < MIN_MENTION_DOCUMENT_COUNT) | |
58 | + break; | |
59 | + mostFrequentMentions.add(entry.getElement()); | |
60 | + } | |
61 | + return mostFrequentMentions; | |
62 | + } | |
63 | +} | |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java
... | ... | @@ -20,8 +20,6 @@ public class TrainAllModels { |
20 | 20 | |
21 | 21 | private static final Logger LOG = LoggerFactory.getLogger(TrainAllModels.class); |
22 | 22 | |
23 | - private static final String TARGET_MODEL_DIR = "nicolas-model/src/main/resources"; | |
24 | - | |
25 | 23 | private TrainAllModels() { |
26 | 24 | } |
27 | 25 | |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractMostFrequentMentions.java deleted
1 | -package pl.waw.ipipan.zil.summ.nicolas.train.resources; | |
2 | - | |
3 | -import com.google.common.collect.*; | |
4 | -import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; | |
5 | -import pl.waw.ipipan.zil.multiservice.thrift.types.TText; | |
6 | -import pl.waw.ipipan.zil.summ.nicolas.utils.thrift.ThriftUtils; | |
7 | -import pl.waw.ipipan.zil.summ.pscapi.io.PSC_IO; | |
8 | -import pl.waw.ipipan.zil.summ.pscapi.xml.Text; | |
9 | - | |
10 | -import javax.xml.bind.JAXBException; | |
11 | -import java.io.File; | |
12 | -import java.io.IOException; | |
13 | -import java.util.Comparator; | |
14 | -import java.util.List; | |
15 | -import java.util.Map; | |
16 | -import java.util.Set; | |
17 | -import java.util.stream.Collectors; | |
18 | - | |
19 | -public class ExtractMostFrequentMentions { | |
20 | - | |
21 | - public static final String GOLD_DATA_PATH = "/home/me2/Dropbox/3_nauka/3_doktorat/3_korpus_streszczen/dist/src/data/"; | |
22 | - | |
23 | - public static final String THRIFTED_PREFIX = "/home/me2/Desktop/thrifted_texts/thrifted_all/"; | |
24 | - public static final String THRIFTED_SUFFIX = "/original"; | |
25 | - | |
26 | - public static void main(String[] args) throws IOException, JAXBException { | |
27 | - | |
28 | - Set<String> devIds = Sets.newHashSet(); | |
29 | - | |
30 | - File goldDir = new File(GOLD_DATA_PATH); | |
31 | - for (File file : goldDir.listFiles()) { | |
32 | - Text goldText = PSC_IO.readText(file); | |
33 | - if (goldText.getSummaries().getSummary().stream().anyMatch(s -> s.getType().equals("abstract"))) | |
34 | - continue; | |
35 | - | |
36 | - devIds.add(file.getName().replace(".xml", "")); | |
37 | - } | |
38 | - | |
39 | - | |
40 | - System.out.println(devIds.size()); | |
41 | - | |
42 | - Multiset<String> mentionCounts = HashMultiset.create(); | |
43 | - for (String id : devIds) { | |
44 | - Set<String> distinctTextMentions = Sets.newHashSet(); | |
45 | - File input = new File(THRIFTED_PREFIX + id + THRIFTED_SUFFIX); | |
46 | - TText thrifted = ThriftUtils.loadThriftTextFromFile(input); | |
47 | - List<TSentence> sents = thrifted.getParagraphs().stream() | |
48 | - .flatMap(p -> p.getSentences().stream()).collect(Collectors.toList()); | |
49 | - | |
50 | - Map<String, String> tokenId2base = Maps.newHashMap(); | |
51 | - sents.stream() | |
52 | - .flatMap(s -> s.getTokens().stream()) | |
53 | - .forEach(token -> tokenId2base.put(token.getId(), token.getChosenInterpretation().getBase())); | |
54 | - | |
55 | - sents.stream().flatMap(s -> s.getMentions().stream()).forEach(m -> { | |
56 | - StringBuffer sb = new StringBuffer(); | |
57 | - for (String tokId : m.getChildIds()) { | |
58 | - sb.append(tokenId2base.get(tokId) + " "); | |
59 | - } | |
60 | - distinctTextMentions.add(sb.toString().trim().toLowerCase()); | |
61 | - }); | |
62 | - | |
63 | - mentionCounts.addAll(distinctTextMentions); | |
64 | - } | |
65 | - | |
66 | - System.out.println(mentionCounts.elementSet().size()); | |
67 | - List<String> sorted = Lists.newArrayList(); | |
68 | - sorted.addAll(mentionCounts.elementSet()); | |
69 | - sorted.sort(Comparator.comparing(mentionCounts::count).reversed()); | |
70 | - int i = 0; | |
71 | - for (String mention : sorted) { | |
72 | - if (mentionCounts.count(mention) < 50) | |
73 | - break; | |
74 | - System.out.println(mention); | |
75 | - } | |
76 | - | |
77 | - } | |
78 | -} |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/resources/ExtractStopwords.java deleted
pom.xml
... | ... | @@ -35,9 +35,23 @@ |
35 | 35 | <slf4j-api.version>1.7.24</slf4j-api.version> |
36 | 36 | <junit.version>4.12</junit.version> |
37 | 37 | <zip4j.version>1.3.2</zip4j.version> |
38 | - <mockito-core.version>2.7.17</mockito-core.version> | |
38 | + <mockito-core.version>2.7.22</mockito-core.version> | |
39 | 39 | <jcommander.version>1.64</jcommander.version> |
40 | 40 | <libthrift.version>0.9.0</libthrift.version> |
41 | + | |
42 | + <jacoco-maven-plugin.version>0.7.8</jacoco-maven-plugin.version> | |
43 | + <maven-compiler-plugin.version>3.5.1</maven-compiler-plugin.version> | |
44 | + <maven-site-plugin.version>3.5.1</maven-site-plugin.version> | |
45 | + <maven-dependency-plugin.version>3.0.0</maven-dependency-plugin.version> | |
46 | + <maven-jar-plugin.version>3.0.2</maven-jar-plugin.version> | |
47 | + <maven-resources-plugin.version>3.0.1</maven-resources-plugin.version> | |
48 | + <maven-clean-plugin.version>3.0.0</maven-clean-plugin.version> | |
49 | + <maven-install-plugin.version>2.5.2</maven-install-plugin.version> | |
50 | + <maven-deploy-plugin.version>2.8.2</maven-deploy-plugin.version> | |
51 | + <maven-assembly-plugin.version>2.6</maven-assembly-plugin.version> | |
52 | + <maven-project-info-reports-plugin.version>2.9</maven-project-info-reports-plugin.version> | |
53 | + <maven-surefire-plugin.version>2.19.1</maven-surefire-plugin.version> | |
54 | + <maven-failsafe-plugin.version>2.19.1</maven-failsafe-plugin.version> | |
41 | 55 | </properties> |
42 | 56 | |
43 | 57 | <prerequisites> |
... | ... | @@ -183,47 +197,47 @@ |
183 | 197 | <plugin> |
184 | 198 | <groupId>org.apache.maven.plugins</groupId> |
185 | 199 | <artifactId>maven-dependency-plugin</artifactId> |
186 | - <version>3.0.0</version> | |
200 | + <version>${maven-dependency-plugin.version}</version> | |
187 | 201 | </plugin> |
188 | 202 | <plugin> |
189 | 203 | <groupId>org.apache.maven.plugins</groupId> |
190 | 204 | <artifactId>maven-jar-plugin</artifactId> |
191 | - <version>3.0.2</version> | |
205 | + <version>${maven-jar-plugin.version}</version> | |
192 | 206 | </plugin> |
193 | 207 | <plugin> |
194 | 208 | <groupId>org.apache.maven.plugins</groupId> |
195 | 209 | <artifactId>maven-resources-plugin</artifactId> |
196 | - <version>3.0.1</version> | |
210 | + <version>${maven-resources-plugin.version}</version> | |
197 | 211 | </plugin> |
198 | 212 | <plugin> |
199 | 213 | <groupId>org.apache.maven.plugins</groupId> |
200 | 214 | <artifactId>maven-clean-plugin</artifactId> |
201 | - <version>3.0.0</version> | |
215 | + <version>${maven-clean-plugin.version}</version> | |
202 | 216 | </plugin> |
203 | 217 | <plugin> |
204 | 218 | <groupId>org.apache.maven.plugins</groupId> |
205 | 219 | <artifactId>maven-site-plugin</artifactId> |
206 | - <version>3.5.1</version> | |
220 | + <version>${maven-site-plugin.version}</version> | |
207 | 221 | </plugin> |
208 | 222 | <plugin> |
209 | 223 | <groupId>org.apache.maven.plugins</groupId> |
210 | 224 | <artifactId>maven-install-plugin</artifactId> |
211 | - <version>2.5.2</version> | |
225 | + <version>${maven-install-plugin.version}</version> | |
212 | 226 | </plugin> |
213 | 227 | <plugin> |
214 | 228 | <groupId>org.apache.maven.plugins</groupId> |
215 | 229 | <artifactId>maven-deploy-plugin</artifactId> |
216 | - <version>2.8.2</version> | |
230 | + <version>${maven-deploy-plugin.version}</version> | |
217 | 231 | </plugin> |
218 | 232 | <plugin> |
219 | 233 | <groupId>org.apache.maven.plugins</groupId> |
220 | 234 | <artifactId>maven-assembly-plugin</artifactId> |
221 | - <version>2.6</version> | |
235 | + <version>${maven-assembly-plugin.version}</version> | |
222 | 236 | </plugin> |
223 | 237 | <plugin> |
224 | 238 | <groupId>org.apache.maven.plugins</groupId> |
225 | 239 | <artifactId>maven-compiler-plugin</artifactId> |
226 | - <version>3.5.1</version> | |
240 | + <version>${maven-compiler-plugin.version}</version> | |
227 | 241 | <configuration> |
228 | 242 | <source>${java.version.build}</source> |
229 | 243 | <target>${java.version.build}</target> |
... | ... | @@ -231,8 +245,13 @@ |
231 | 245 | </plugin> |
232 | 246 | <plugin> |
233 | 247 | <groupId>org.apache.maven.plugins</groupId> |
248 | + <artifactId>maven-project-info-reports-plugin</artifactId> | |
249 | + <version>${maven-project-info-reports-plugin.version}</version> | |
250 | + </plugin> | |
251 | + <plugin> | |
252 | + <groupId>org.apache.maven.plugins</groupId> | |
234 | 253 | <artifactId>maven-surefire-plugin</artifactId> |
235 | - <version>2.19.1</version> | |
254 | + <version>${maven-surefire-plugin.version}</version> | |
236 | 255 | <configuration> |
237 | 256 | <!-- Sets the VM argument line used when unit tests are run. --> |
238 | 257 | <argLine>${surefireArgLine}</argLine> |
... | ... | @@ -247,7 +266,7 @@ |
247 | 266 | <plugin> |
248 | 267 | <groupId>org.apache.maven.plugins</groupId> |
249 | 268 | <artifactId>maven-failsafe-plugin</artifactId> |
250 | - <version>2.19.1</version> | |
269 | + <version>${maven-failsafe-plugin.version}</version> | |
251 | 270 | <executions> |
252 | 271 | <execution> |
253 | 272 | <id>integration-test</id> |
... | ... | @@ -276,7 +295,7 @@ |
276 | 295 | <plugin> |
277 | 296 | <groupId>org.jacoco</groupId> |
278 | 297 | <artifactId>jacoco-maven-plugin</artifactId> |
279 | - <version>0.7.8</version> | |
298 | + <version>${jacoco-maven-plugin.version}</version> | |
280 | 299 | <executions> |
281 | 300 | <!-- |
282 | 301 | Prepares the property pointing to the JaCoCo runtime agent which |
... | ... |