Commit e8e1d7c9fba7b03c7edd2b71254d5d11b9815a12
1 parent
595d6aa6
significance tests
Showing
3 changed files
with
39 additions
and
3 deletions
src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
... | ... | @@ -8,8 +8,10 @@ import java.util.HashMap; |
8 | 8 | import java.util.HashSet; |
9 | 9 | import java.util.Map; |
10 | 10 | import java.util.Set; |
11 | +import java.util.TreeSet; | |
11 | 12 | import java.util.stream.Collectors; |
12 | 13 | |
14 | +import org.apache.commons.math3.distribution.TDistribution; | |
13 | 15 | import org.apache.logging.log4j.LogManager; |
14 | 16 | import org.apache.logging.log4j.Logger; |
15 | 17 | |
... | ... | @@ -17,6 +19,8 @@ import pl.waw.ipipan.zil.summ.eval.rouge.RougeN; |
17 | 19 | |
18 | 20 | public class Main { |
19 | 21 | |
22 | + private static final double P_VALUE = 0.05; | |
23 | + | |
20 | 24 | private static final Logger LOG = LogManager.getLogger(Main.class); |
21 | 25 | |
22 | 26 | private static final int[] ROUGE_N = new int[] { 1, 2, 3 }; |
... | ... | @@ -84,13 +88,39 @@ public class Main { |
84 | 88 | for (String systemId : ra.getSystemIds()) { |
85 | 89 | System.out.println("############## " + systemId); |
86 | 90 | for (String metricId : ra.getMetricIds()) { |
91 | + Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId); | |
87 | 92 | Result r = ra.getSystemResult(systemId, metricId); |
88 | - System.out.println(String.format("\t%s\t%.2f\t(+/- %.2f)", metricId, r.getMean(), r.getSD())); | |
93 | + | |
94 | + System.out.print(String.format("\t%s\t%.2f\t(+/- %.2f), signif. better than: %s", metricId, r.getMean(), | |
95 | + r.getSD(), String.join(", ", signifWorseSystems))); | |
96 | + System.out.println(); | |
89 | 97 | } |
90 | 98 | System.out.println(); |
91 | 99 | } |
92 | 100 | } |
93 | 101 | |
102 | + private static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId) { | |
103 | + TreeSet<String> signifWorseSystems = new TreeSet<>(); | |
104 | + for (String systemId2 : ra.getSystemIds()) | |
105 | + if (!systemId.equals(systemId2)) { | |
106 | + Result r1 = ra.getSystemResult(systemId, metricId); | |
107 | + Result r2 = ra.getSystemResult(systemId2, metricId); | |
108 | + | |
109 | + double meanDifference = r1.getMean() - r2.getMean(); | |
110 | + double mse = (Math.pow(r1.getSD(), 2) + Math.pow(r2.getSD(), 2)) / 2; | |
111 | + int degreesOfFreedom = (r1.getN() - 1) + (r2.getN() - 1); | |
112 | + double t = meanDifference / Math.sqrt(2 * mse / degreesOfFreedom); | |
113 | + | |
114 | + TDistribution tDistribution = new TDistribution(degreesOfFreedom); | |
115 | + double p = 1 - tDistribution.cumulativeProbability(t); | |
116 | + | |
117 | + if (p < P_VALUE) | |
118 | + signifWorseSystems.add(String.format("%s (%.4f)", systemId2, p)); | |
119 | + | |
120 | + } | |
121 | + return signifWorseSystems; | |
122 | + } | |
123 | + | |
94 | 124 | private static String getSystemIdFromFilename(File sysSummaryFile) { |
95 | 125 | String name = sysSummaryFile.getName(); |
96 | 126 | return name.substring(name.lastIndexOf("_") + 1, name.lastIndexOf(".")); |
... | ... |
src/main/java/pl/waw/ipipan/zil/summ/eval/Result.java
... | ... | @@ -4,10 +4,16 @@ public class Result { |
4 | 4 | |
5 | 5 | private double sd; |
6 | 6 | private double mean; |
7 | + private int n; | |
7 | 8 | |
8 | - public Result(double mean, double sd) { | |
9 | + public Result(double mean, double sd, int n) { | |
9 | 10 | this.mean = mean; |
10 | 11 | this.sd = sd; |
12 | + this.n = n; | |
13 | + } | |
14 | + | |
15 | + public int getN() { | |
16 | + return n; | |
11 | 17 | } |
12 | 18 | |
13 | 19 | public double getMean() { |
... | ... |
src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java
... | ... | @@ -43,7 +43,7 @@ public class ResultAccumulator { |
43 | 43 | |
44 | 44 | public Result getResult(String metricId) { |
45 | 45 | SummaryStatistics scores = metricId2scores.get(metricId); |
46 | - return new Result(scores.getMean(), scores.getStandardDeviation()); | |
46 | + return new Result(scores.getMean(), scores.getStandardDeviation(), (int) scores.getN()); | |
47 | 47 | } |
48 | 48 | } |
49 | 49 | } |
... | ... |