Commit 161751402aa52d5f0d9f404629bf35770cc26c1f
1 parent
e8e1d7c9
1.0 version
Showing
4 changed files
with
139 additions
and
25 deletions
pom.xml
... | ... | @@ -39,6 +39,11 @@ |
39 | 39 | <artifactId>commons-math3</artifactId> |
40 | 40 | <version>3.5</version> |
41 | 41 | </dependency> |
42 | + <dependency> | |
43 | + <groupId>org.apache.commons</groupId> | |
44 | + <artifactId>commons-lang3</artifactId> | |
45 | + <version>3.4</version> | |
46 | + </dependency> | |
42 | 47 | </dependencies> |
43 | 48 | |
44 | 49 | <build> |
... | ... |
src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
... | ... | @@ -6,12 +6,17 @@ import java.nio.file.Files; |
6 | 6 | import java.nio.file.Paths; |
7 | 7 | import java.util.HashMap; |
8 | 8 | import java.util.HashSet; |
9 | +import java.util.List; | |
9 | 10 | import java.util.Map; |
10 | 11 | import java.util.Set; |
11 | 12 | import java.util.TreeSet; |
12 | 13 | import java.util.stream.Collectors; |
13 | 14 | |
14 | -import org.apache.commons.math3.distribution.TDistribution; | |
15 | +import org.apache.commons.math3.distribution.NormalDistribution; | |
16 | +import org.apache.commons.math3.stat.StatUtils; | |
17 | +import org.apache.commons.math3.stat.descriptive.SummaryStatistics; | |
18 | +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; | |
19 | +import org.apache.commons.math3.stat.inference.TTest; | |
15 | 20 | import org.apache.logging.log4j.LogManager; |
16 | 21 | import org.apache.logging.log4j.Logger; |
17 | 22 | |
... | ... | @@ -19,6 +24,7 @@ import pl.waw.ipipan.zil.summ.eval.rouge.RougeN; |
19 | 24 | |
20 | 25 | public class Main { |
21 | 26 | |
27 | + private static final boolean LATEX = false; | |
22 | 28 | private static final double P_VALUE = 0.05; |
23 | 29 | |
24 | 30 | private static final Logger LOG = LogManager.getLogger(Main.class); |
... | ... | @@ -48,7 +54,6 @@ public class Main { |
48 | 54 | LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded."); |
49 | 55 | |
50 | 56 | ResultAccumulator ra = new ResultAccumulator(); |
51 | - int totalGoldSummaries = 0; | |
52 | 57 | int totalComparedTexts = 0; |
53 | 58 | |
54 | 59 | for (String id : id2goldSummaryFiles.keySet()) { |
... | ... | @@ -71,51 +76,116 @@ public class Main { |
71 | 76 | for (int i : ROUGE_N) { |
72 | 77 | double score = RougeN.score(i, systemSummary, goldSummaries); |
73 | 78 | ra.logScore("ROUGE_" + i, systemId, score); |
79 | + | |
80 | + double maxScore = RougeN.scoreMax(i, systemSummary, goldSummaries); | |
81 | + ra.logScore("ROUGE_M_" + i, systemId, maxScore); | |
74 | 82 | } |
75 | 83 | } |
76 | - | |
77 | - totalGoldSummaries += goldSummaries.size(); | |
78 | 84 | totalComparedTexts++; |
79 | 85 | } |
80 | 86 | |
81 | 87 | LOG.info(totalComparedTexts + " - total compared texts."); |
82 | - LOG.info(totalGoldSummaries * 1.0 / totalComparedTexts + " - average gold summaries per text."); | |
83 | 88 | |
84 | - printResults(ra); | |
89 | + checkNormality(ra); | |
90 | + printResults(ra, LATEX); | |
91 | + printSignificance(ra, LATEX); | |
92 | + } | |
93 | + | |
94 | + private static void checkNormality(ResultAccumulator ra) { | |
95 | + int normal = 0; | |
96 | + int all = 0; | |
97 | + for (String systemId : ra.getSystemIds()) | |
98 | + for (String metricId : ra.getMetricIds()) { | |
99 | + all++; | |
100 | + double[] standardized = StatUtils.normalize(ra.getSystemResults(systemId, metricId)); | |
101 | + double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1), | |
102 | + standardized); | |
103 | + if (p > P_VALUE) | |
104 | + normal++; | |
105 | + else | |
106 | + LOG.info("System " + systemId + " for " + metricId + " rejected K-S normality test with p-value = " | |
107 | + + p); | |
108 | + } | |
109 | + | |
110 | + LOG.info(normal + " normal of all " + all); | |
111 | + } | |
112 | + | |
113 | + private static void printResults(ResultAccumulator ra, boolean latex) { | |
114 | + System.out.println("###################### Results #############################"); | |
115 | + | |
116 | + if (latex) | |
117 | + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); | |
118 | + | |
119 | + System.out.print("System"); | |
120 | + for (String metricId : ra.getMetricIds()) | |
121 | + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); | |
122 | + if (latex) | |
123 | + System.out.print("\\\\\\hline\\hline"); | |
124 | + System.out.println(); | |
125 | + | |
126 | + for (String systemId : ra.getSystemIds()) { | |
127 | + System.out.print(systemId); | |
128 | + for (String metricId : ra.getMetricIds()) { | |
129 | + SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId); | |
130 | + | |
131 | + System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(), | |
132 | + r.getStandardDeviation())); | |
133 | + } | |
134 | + if (latex) | |
135 | + System.out.print("\\\\\\hline"); | |
136 | + System.out.println(); | |
137 | + } | |
138 | + if (latex) | |
139 | + System.out.println("\\end{tabular}"); | |
85 | 140 | } |
86 | 141 | |
87 | - private static void printResults(ResultAccumulator ra) { | |
142 | + private static void printSignificance(ResultAccumulator ra, boolean latex) { | |
143 | + System.out.println("###################### Stat. sign. #############################"); | |
144 | + int i = 1; | |
145 | + Map<String, Integer> sys2id = new HashMap<>(); | |
146 | + for (String systemId : ra.getSystemIds()) | |
147 | + sys2id.put(systemId, i++); | |
148 | + if (latex) | |
149 | + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); | |
150 | + System.out.print("System"); | |
151 | + for (String metricId : ra.getMetricIds()) | |
152 | + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); | |
153 | + if (latex) | |
154 | + System.out.print("\\\\\\hline\\hline"); | |
155 | + System.out.println(); | |
156 | + | |
88 | 157 | for (String systemId : ra.getSystemIds()) { |
89 | - System.out.println("############## " + systemId); | |
158 | + System.out.print(sys2id.get(systemId) + ". " + systemId); | |
90 | 159 | for (String metricId : ra.getMetricIds()) { |
91 | 160 | Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId); |
92 | - Result r = ra.getSystemResult(systemId, metricId); | |
161 | + List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted() | |
162 | + .collect(Collectors.toList()); | |
93 | 163 | |
94 | - System.out.print(String.format("\t%s\t%.2f\t(+/- %.2f), signif. better than: %s", metricId, r.getMean(), | |
95 | - r.getSD(), String.join(", ", signifWorseSystems))); | |
96 | - System.out.println(); | |
164 | + System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids))); | |
97 | 165 | } |
166 | + if (latex) | |
167 | + System.out.print("\\\\\\hline"); | |
98 | 168 | System.out.println(); |
99 | 169 | } |
170 | + if (latex) | |
171 | + System.out.println("\\end{tabular}"); | |
100 | 172 | } |
101 | 173 | |
102 | 174 | private static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId) { |
103 | 175 | TreeSet<String> signifWorseSystems = new TreeSet<>(); |
104 | 176 | for (String systemId2 : ra.getSystemIds()) |
105 | 177 | if (!systemId.equals(systemId2)) { |
106 | - Result r1 = ra.getSystemResult(systemId, metricId); | |
107 | - Result r2 = ra.getSystemResult(systemId2, metricId); | |
108 | 178 | |
109 | - double meanDifference = r1.getMean() - r2.getMean(); | |
110 | - double mse = (Math.pow(r1.getSD(), 2) + Math.pow(r2.getSD(), 2)) / 2; | |
111 | - int degreesOfFreedom = (r1.getN() - 1) + (r2.getN() - 1); | |
112 | - double t = meanDifference / Math.sqrt(2 * mse / degreesOfFreedom); | |
179 | + SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId); | |
180 | + SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId); | |
181 | + | |
182 | + if (scores1.getMean() <= scores2.getMean()) | |
183 | + continue; | |
113 | 184 | |
114 | - TDistribution tDistribution = new TDistribution(degreesOfFreedom); | |
115 | - double p = 1 - tDistribution.cumulativeProbability(t); | |
185 | + double p = new TTest().tTest(scores1, scores2) / 2; | |
116 | 186 | |
117 | 187 | if (p < P_VALUE) |
118 | - signifWorseSystems.add(String.format("%s (%.4f)", systemId2, p)); | |
188 | + signifWorseSystems.add(systemId2); | |
119 | 189 | |
120 | 190 | } |
121 | 191 | return signifWorseSystems; |
... | ... |
src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java
... | ... | @@ -7,8 +7,11 @@ import java.util.Map; |
7 | 7 | import java.util.Set; |
8 | 8 | import java.util.stream.Collectors; |
9 | 9 | |
10 | +import org.apache.commons.lang3.ArrayUtils; | |
10 | 11 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics; |
11 | 12 | |
13 | +import com.google.common.collect.Lists; | |
14 | + | |
12 | 15 | public class ResultAccumulator { |
13 | 16 | |
14 | 17 | private Map<String, CumulativeResult> systemId2result = new HashMap<>(); |
... | ... | @@ -29,21 +32,32 @@ public class ResultAccumulator { |
29 | 32 | return metricIds.stream().sorted().collect(Collectors.toList()); |
30 | 33 | } |
31 | 34 | |
32 | - public Result getSystemResult(String systemId, String metricId) { | |
33 | - return systemId2result.get(systemId).getResult(metricId); | |
35 | + public double[] getSystemResults(String systemId, String metricId) { | |
36 | + return systemId2result.get(systemId).getResults(metricId); | |
37 | + } | |
38 | + | |
39 | + public SummaryStatistics getSystemSummaryStatistics(String systemId, String metricId) { | |
40 | + return systemId2result.get(systemId).getSummaryStatistics(metricId); | |
34 | 41 | } |
35 | 42 | |
36 | 43 | private class CumulativeResult { |
37 | 44 | Map<String, SummaryStatistics> metricId2scores = new HashMap<>(); |
45 | + Map<String, List<Double>> metricId2allScores = new HashMap<>(); | |
38 | 46 | |
39 | 47 | public void addScore(String metricId, double score) { |
40 | 48 | metricId2scores.putIfAbsent(metricId, new SummaryStatistics()); |
41 | 49 | metricId2scores.get(metricId).addValue(score); |
50 | + metricId2allScores.putIfAbsent(metricId, Lists.newArrayList()); | |
51 | + metricId2allScores.get(metricId).add(score); | |
52 | + } | |
53 | + | |
54 | + public double[] getResults(String metricId) { | |
55 | + return ArrayUtils.toPrimitive(metricId2allScores.get(metricId).toArray(new Double[0])); | |
42 | 56 | } |
43 | 57 | |
44 | - public Result getResult(String metricId) { | |
58 | + public SummaryStatistics getSummaryStatistics(String metricId) { | |
45 | 59 | SummaryStatistics scores = metricId2scores.get(metricId); |
46 | - return new Result(scores.getMean(), scores.getStandardDeviation(), (int) scores.getN()); | |
60 | + return scores; | |
47 | 61 | } |
48 | 62 | } |
49 | 63 | } |
... | ... |
src/main/java/pl/waw/ipipan/zil/summ/eval/rouge/RougeN.java
... | ... | @@ -34,6 +34,31 @@ public class RougeN { |
34 | 34 | return 1.0 * numerator / denominator; |
35 | 35 | } |
36 | 36 | |
37 | + public static double scoreMax(int n, String systemSummary, Set<String> goldSummaries) { | |
38 | + Multiset<List<String>> systemNgrams = HashMultiset.create(); | |
39 | + countNgrams(systemNgrams, n, systemSummary); | |
40 | + | |
41 | + double maxScore = -1.0; | |
42 | + for (String goldSummary : goldSummaries) { | |
43 | + Multiset<List<String>> goldNgrams = HashMultiset.create(); | |
44 | + countNgrams(goldNgrams, n, goldSummary); | |
45 | + | |
46 | + int numerator = 0; | |
47 | + int denominator = 0; | |
48 | + for (List<String> goldNgram : goldNgrams) { | |
49 | + int goldCount = goldNgrams.count(goldNgram); | |
50 | + int systemCount = systemNgrams.count(goldNgram); | |
51 | + numerator += Math.min(goldCount, systemCount); | |
52 | + denominator += goldCount; | |
53 | + } | |
54 | + | |
55 | + double summaryScore = 1.0 * numerator / denominator; | |
56 | + maxScore = Math.max(summaryScore, maxScore); | |
57 | + } | |
58 | + | |
59 | + return maxScore; | |
60 | + } | |
61 | + | |
37 | 62 | protected static void countNgrams(Multiset<List<String>> ngrams, int n, String text) { |
38 | 63 | List<String> tokens = tokenize(text).stream().map(String::toLowerCase).collect(Collectors.toList()); |
39 | 64 | NgramQueue<String> ngram = new NgramQueue<>(n); |
... | ... |