Commit e8e1d7c9fba7b03c7edd2b71254d5d11b9815a12

Authored by Mateusz Kopeć
1 parent 595d6aa6

significance tests

src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
... ... @@ -8,8 +8,10 @@ import java.util.HashMap;
8 8 import java.util.HashSet;
9 9 import java.util.Map;
10 10 import java.util.Set;
  11 +import java.util.TreeSet;
11 12 import java.util.stream.Collectors;
12 13  
  14 +import org.apache.commons.math3.distribution.TDistribution;
13 15 import org.apache.logging.log4j.LogManager;
14 16 import org.apache.logging.log4j.Logger;
15 17  
... ... @@ -17,6 +19,8 @@ import pl.waw.ipipan.zil.summ.eval.rouge.RougeN;
17 19  
18 20 public class Main {
19 21  
  22 + private static final double P_VALUE = 0.05;
  23 +
20 24 private static final Logger LOG = LogManager.getLogger(Main.class);
21 25  
22 26 private static final int[] ROUGE_N = new int[] { 1, 2, 3 };
... ... @@ -84,13 +88,39 @@ public class Main {
84 88 for (String systemId : ra.getSystemIds()) {
85 89 System.out.println("############## " + systemId);
86 90 for (String metricId : ra.getMetricIds()) {
  91 + Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId);
87 92 Result r = ra.getSystemResult(systemId, metricId);
88   - System.out.println(String.format("\t%s\t%.2f\t(+/- %.2f)", metricId, r.getMean(), r.getSD()));
  93 +
  94 + System.out.print(String.format("\t%s\t%.2f\t(+/- %.2f), signif. better than: %s", metricId, r.getMean(),
  95 + r.getSD(), String.join(", ", signifWorseSystems)));
  96 + System.out.println();
89 97 }
90 98 System.out.println();
91 99 }
92 100 }
93 101  
  102 + private static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId) {
  103 + TreeSet<String> signifWorseSystems = new TreeSet<>();
  104 + for (String systemId2 : ra.getSystemIds())
  105 + if (!systemId.equals(systemId2)) {
  106 + Result r1 = ra.getSystemResult(systemId, metricId);
  107 + Result r2 = ra.getSystemResult(systemId2, metricId);
  108 +
  109 + double meanDifference = r1.getMean() - r2.getMean();
  110 + double mse = (Math.pow(r1.getSD(), 2) + Math.pow(r2.getSD(), 2)) / 2;
  111 + int degreesOfFreedom = (r1.getN() - 1) + (r2.getN() - 1);
  112 + double t = meanDifference / Math.sqrt(2 * mse / degreesOfFreedom);
  113 +
  114 + TDistribution tDistribution = new TDistribution(degreesOfFreedom);
  115 + double p = 1 - tDistribution.cumulativeProbability(t);
  116 +
  117 + if (p < P_VALUE)
  118 + signifWorseSystems.add(String.format("%s (%.4f)", systemId2, p));
  119 +
  120 + }
  121 + return signifWorseSystems;
  122 + }
  123 +
94 124 private static String getSystemIdFromFilename(File sysSummaryFile) {
95 125 String name = sysSummaryFile.getName();
96 126 return name.substring(name.lastIndexOf("_") + 1, name.lastIndexOf("."));
... ...
src/main/java/pl/waw/ipipan/zil/summ/eval/Result.java
... ... @@ -4,10 +4,16 @@ public class Result {
4 4  
5 5 private double sd;
6 6 private double mean;
  7 + private int n;
7 8  
8   - public Result(double mean, double sd) {
  9 + public Result(double mean, double sd, int n) {
9 10 this.mean = mean;
10 11 this.sd = sd;
  12 + this.n = n;
  13 + }
  14 +
  15 + public int getN() {
  16 + return n;
11 17 }
12 18  
13 19 public double getMean() {
... ...
src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java
... ... @@ -43,7 +43,7 @@ public class ResultAccumulator {
43 43  
44 44 public Result getResult(String metricId) {
45 45 SummaryStatistics scores = metricId2scores.get(metricId);
46   - return new Result(scores.getMean(), scores.getStandardDeviation());
  46 + return new Result(scores.getMean(), scores.getStandardDeviation(), (int) scores.getN());
47 47 }
48 48 }
49 49 }
... ...