Commit 161751402aa52d5f0d9f404629bf35770cc26c1f

Authored by Mateusz Kopeć
1 parent e8e1d7c9

1.0 version

... ... @@ -39,6 +39,11 @@
39 39 <artifactId>commons-math3</artifactId>
40 40 <version>3.5</version>
41 41 </dependency>
  42 + <dependency>
  43 + <groupId>org.apache.commons</groupId>
  44 + <artifactId>commons-lang3</artifactId>
  45 + <version>3.4</version>
  46 + </dependency>
42 47 </dependencies>
43 48  
44 49 <build>
... ...
src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
... ... @@ -6,12 +6,17 @@ import java.nio.file.Files;
6 6 import java.nio.file.Paths;
7 7 import java.util.HashMap;
8 8 import java.util.HashSet;
  9 +import java.util.List;
9 10 import java.util.Map;
10 11 import java.util.Set;
11 12 import java.util.TreeSet;
12 13 import java.util.stream.Collectors;
13 14  
14   -import org.apache.commons.math3.distribution.TDistribution;
  15 +import org.apache.commons.math3.distribution.NormalDistribution;
  16 +import org.apache.commons.math3.stat.StatUtils;
  17 +import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
  18 +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
  19 +import org.apache.commons.math3.stat.inference.TTest;
15 20 import org.apache.logging.log4j.LogManager;
16 21 import org.apache.logging.log4j.Logger;
17 22  
... ... @@ -19,6 +24,7 @@ import pl.waw.ipipan.zil.summ.eval.rouge.RougeN;
19 24  
20 25 public class Main {
21 26  
  27 + private static final boolean LATEX = false;
22 28 private static final double P_VALUE = 0.05;
23 29  
24 30 private static final Logger LOG = LogManager.getLogger(Main.class);
... ... @@ -48,7 +54,6 @@ public class Main {
48 54 LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded.");
49 55  
50 56 ResultAccumulator ra = new ResultAccumulator();
51   - int totalGoldSummaries = 0;
52 57 int totalComparedTexts = 0;
53 58  
54 59 for (String id : id2goldSummaryFiles.keySet()) {
... ... @@ -71,51 +76,116 @@ public class Main {
71 76 for (int i : ROUGE_N) {
72 77 double score = RougeN.score(i, systemSummary, goldSummaries);
73 78 ra.logScore("ROUGE_" + i, systemId, score);
  79 +
  80 + double maxScore = RougeN.scoreMax(i, systemSummary, goldSummaries);
  81 + ra.logScore("ROUGE_M_" + i, systemId, maxScore);
74 82 }
75 83 }
76   -
77   - totalGoldSummaries += goldSummaries.size();
78 84 totalComparedTexts++;
79 85 }
80 86  
81 87 LOG.info(totalComparedTexts + " - total compared texts.");
82   - LOG.info(totalGoldSummaries * 1.0 / totalComparedTexts + " - average gold summaries per text.");
83 88  
84   - printResults(ra);
  89 + checkNormality(ra);
  90 + printResults(ra, LATEX);
  91 + printSignificance(ra, LATEX);
  92 + }
  93 +
  94 + private static void checkNormality(ResultAccumulator ra) {
  95 + int normal = 0;
  96 + int all = 0;
  97 + for (String systemId : ra.getSystemIds())
  98 + for (String metricId : ra.getMetricIds()) {
  99 + all++;
  100 + double[] standardized = StatUtils.normalize(ra.getSystemResults(systemId, metricId));
  101 + double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1),
  102 + standardized);
  103 + if (p > P_VALUE)
  104 + normal++;
  105 + else
  106 + LOG.info("System " + systemId + " for " + metricId + " rejected K-S normality test with p-value = "
  107 + + p);
  108 + }
  109 +
  110 + LOG.info(normal + " normal of all " + all);
  111 + }
  112 +
  113 + private static void printResults(ResultAccumulator ra, boolean latex) {
  114 + System.out.println("###################### Results #############################");
  115 +
  116 + if (latex)
  117 + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}");
  118 +
  119 + System.out.print("System");
  120 + for (String metricId : ra.getMetricIds())
  121 + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-"));
  122 + if (latex)
  123 + System.out.print("\\\\\\hline\\hline");
  124 + System.out.println();
  125 +
  126 + for (String systemId : ra.getSystemIds()) {
  127 + System.out.print(systemId);
  128 + for (String metricId : ra.getMetricIds()) {
  129 + SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId);
  130 +
  131 + System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(),
  132 + r.getStandardDeviation()));
  133 + }
  134 + if (latex)
  135 + System.out.print("\\\\\\hline");
  136 + System.out.println();
  137 + }
  138 + if (latex)
  139 + System.out.println("\\end{tabular}");
85 140 }
86 141  
87   - private static void printResults(ResultAccumulator ra) {
  142 + private static void printSignificance(ResultAccumulator ra, boolean latex) {
  143 + System.out.println("###################### Stat. sign. #############################");
  144 + int i = 1;
  145 + Map<String, Integer> sys2id = new HashMap<>();
  146 + for (String systemId : ra.getSystemIds())
  147 + sys2id.put(systemId, i++);
  148 + if (latex)
  149 + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}");
  150 + System.out.print("System");
  151 + for (String metricId : ra.getMetricIds())
  152 + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-"));
  153 + if (latex)
  154 + System.out.print("\\\\\\hline\\hline");
  155 + System.out.println();
  156 +
88 157 for (String systemId : ra.getSystemIds()) {
89   - System.out.println("############## " + systemId);
  158 + System.out.print(sys2id.get(systemId) + ". " + systemId);
90 159 for (String metricId : ra.getMetricIds()) {
91 160 Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId);
92   - Result r = ra.getSystemResult(systemId, metricId);
  161 + List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted()
  162 + .collect(Collectors.toList());
93 163  
94   - System.out.print(String.format("\t%s\t%.2f\t(+/- %.2f), signif. better than: %s", metricId, r.getMean(),
95   - r.getSD(), String.join(", ", signifWorseSystems)));
96   - System.out.println();
  164 + System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids)));
97 165 }
  166 + if (latex)
  167 + System.out.print("\\\\\\hline");
98 168 System.out.println();
99 169 }
  170 + if (latex)
  171 + System.out.println("\\end{tabular}");
100 172 }
101 173  
102 174 private static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId) {
103 175 TreeSet<String> signifWorseSystems = new TreeSet<>();
104 176 for (String systemId2 : ra.getSystemIds())
105 177 if (!systemId.equals(systemId2)) {
106   - Result r1 = ra.getSystemResult(systemId, metricId);
107   - Result r2 = ra.getSystemResult(systemId2, metricId);
108 178  
109   - double meanDifference = r1.getMean() - r2.getMean();
110   - double mse = (Math.pow(r1.getSD(), 2) + Math.pow(r2.getSD(), 2)) / 2;
111   - int degreesOfFreedom = (r1.getN() - 1) + (r2.getN() - 1);
112   - double t = meanDifference / Math.sqrt(2 * mse / degreesOfFreedom);
  179 + SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId);
  180 + SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId);
  181 +
  182 + if (scores1.getMean() <= scores2.getMean())
  183 + continue;
113 184  
114   - TDistribution tDistribution = new TDistribution(degreesOfFreedom);
115   - double p = 1 - tDistribution.cumulativeProbability(t);
  185 + double p = new TTest().tTest(scores1, scores2) / 2;
116 186  
117 187 if (p < P_VALUE)
118   - signifWorseSystems.add(String.format("%s (%.4f)", systemId2, p));
  188 + signifWorseSystems.add(systemId2);
119 189  
120 190 }
121 191 return signifWorseSystems;
... ...
src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java
... ... @@ -7,8 +7,11 @@ import java.util.Map;
7 7 import java.util.Set;
8 8 import java.util.stream.Collectors;
9 9  
  10 +import org.apache.commons.lang3.ArrayUtils;
10 11 import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
11 12  
  13 +import com.google.common.collect.Lists;
  14 +
12 15 public class ResultAccumulator {
13 16  
14 17 private Map<String, CumulativeResult> systemId2result = new HashMap<>();
... ... @@ -29,21 +32,32 @@ public class ResultAccumulator {
29 32 return metricIds.stream().sorted().collect(Collectors.toList());
30 33 }
31 34  
32   - public Result getSystemResult(String systemId, String metricId) {
33   - return systemId2result.get(systemId).getResult(metricId);
  35 + public double[] getSystemResults(String systemId, String metricId) {
  36 + return systemId2result.get(systemId).getResults(metricId);
  37 + }
  38 +
  39 + public SummaryStatistics getSystemSummaryStatistics(String systemId, String metricId) {
  40 + return systemId2result.get(systemId).getSummaryStatistics(metricId);
34 41 }
35 42  
36 43 private class CumulativeResult {
37 44 Map<String, SummaryStatistics> metricId2scores = new HashMap<>();
  45 + Map<String, List<Double>> metricId2allScores = new HashMap<>();
38 46  
39 47 public void addScore(String metricId, double score) {
40 48 metricId2scores.putIfAbsent(metricId, new SummaryStatistics());
41 49 metricId2scores.get(metricId).addValue(score);
  50 + metricId2allScores.putIfAbsent(metricId, Lists.newArrayList());
  51 + metricId2allScores.get(metricId).add(score);
  52 + }
  53 +
  54 + public double[] getResults(String metricId) {
  55 + return ArrayUtils.toPrimitive(metricId2allScores.get(metricId).toArray(new Double[0]));
42 56 }
43 57  
44   - public Result getResult(String metricId) {
  58 + public SummaryStatistics getSummaryStatistics(String metricId) {
45 59 SummaryStatistics scores = metricId2scores.get(metricId);
46   - return new Result(scores.getMean(), scores.getStandardDeviation(), (int) scores.getN());
  60 + return scores;
47 61 }
48 62 }
49 63 }
... ...
src/main/java/pl/waw/ipipan/zil/summ/eval/rouge/RougeN.java
... ... @@ -34,6 +34,31 @@ public class RougeN {
34 34 return 1.0 * numerator / denominator;
35 35 }
36 36  
  37 + public static double scoreMax(int n, String systemSummary, Set<String> goldSummaries) {
  38 + Multiset<List<String>> systemNgrams = HashMultiset.create();
  39 + countNgrams(systemNgrams, n, systemSummary);
  40 +
  41 + double maxScore = -1.0;
  42 + for (String goldSummary : goldSummaries) {
  43 + Multiset<List<String>> goldNgrams = HashMultiset.create();
  44 + countNgrams(goldNgrams, n, goldSummary);
  45 +
  46 + int numerator = 0;
  47 + int denominator = 0;
  48 + for (List<String> goldNgram : goldNgrams) {
  49 + int goldCount = goldNgrams.count(goldNgram);
  50 + int systemCount = systemNgrams.count(goldNgram);
  51 + numerator += Math.min(goldCount, systemCount);
  52 + denominator += goldCount;
  53 + }
  54 +
  55 + double summaryScore = 1.0 * numerator / denominator;
  56 + maxScore = Math.max(summaryScore, maxScore);
  57 + }
  58 +
  59 + return maxScore;
  60 + }
  61 +
37 62 protected static void countNgrams(Multiset<List<String>> ngrams, int n, String text) {
38 63 List<String> tokens = tokenize(text).stream().map(String::toLowerCase).collect(Collectors.toList());
39 64 NgramQueue<String> ngram = new NgramQueue<>(n);
... ...