Commit c6d4dcbe7d1026f1268c05672b892611a729a291

Authored by Mateusz Kopeć
1 parent 0b7a08f9

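Bug fixes and refactor: move the normality check, results table and significance table from Main into OutputHelper, and record scores per text id in ResultAccumulator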

src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
... ... @@ -6,20 +6,14 @@ import java.nio.file.Files;
6 6 import java.nio.file.Paths;
7 7 import java.util.HashMap;
8 8 import java.util.HashSet;
9   -import java.util.List;
10 9 import java.util.Map;
11 10 import java.util.Set;
12   -import java.util.TreeSet;
13 11 import java.util.stream.Collectors;
14 12  
15   -import org.apache.commons.math3.distribution.NormalDistribution;
16   -import org.apache.commons.math3.stat.StatUtils;
17   -import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
18   -import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
19   -import org.apache.commons.math3.stat.inference.TTest;
20 13 import org.apache.logging.log4j.LogManager;
21 14 import org.apache.logging.log4j.Logger;
22 15  
  16 +import pl.waw.ipipan.zil.summ.eval.output.OutputHelper;
23 17 import pl.waw.ipipan.zil.summ.eval.rouge.RougeN;
24 18  
25 19 public class Main {
... ... @@ -33,7 +27,7 @@ public class Main {
33 27  
34 28 public static void main(String[] args) {
35 29 if (args.length != 2) {
36   - LOG.error("Wrong number of arguments!");
  30 + LOG.error("Wrong number of arguments! Try: Main goldDir sysDir");
37 31 return;
38 32 }
39 33  
... ... @@ -53,6 +47,21 @@ public class Main {
53 47 }
54 48 LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded.");
55 49  
  50 + ResultAccumulator ra;
  51 + try {
  52 + ra = countResults(id2goldSummaryFiles, id2sysSummaryFiles);
  53 + } catch (Exception e) {
  54 + LOG.error("Error counting results: " + e);
  55 + return;
  56 + }
  57 +
  58 + OutputHelper.checkNormality(ra, P_VALUE);
  59 + OutputHelper.printResults(ra, LATEX);
  60 + OutputHelper.printSignificance(ra, P_VALUE, LATEX);
  61 + }
  62 +
  63 + private static ResultAccumulator countResults(Map<String, Set<File>> id2goldSummaryFiles,
  64 + Map<String, Set<File>> id2sysSummaryFiles) throws Exception {
56 65 ResultAccumulator ra = new ResultAccumulator();
57 66 int totalComparedTexts = 0;
58 67  
... ... @@ -63,132 +72,27 @@ public class Main {
63 72 Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f))
64 73 .collect(Collectors.toSet());
65 74 if (goldSummaries.contains(null)) {
66   - LOG.error("Empty gold summary for id: " + id);
67   - return;
  75 + throw new Exception("Empty gold summary for id: " + id);
68 76 }
69 77 for (File sysSummaryFile : sysSummaryFiles) {
70 78 String systemSummary = loadSummaryFromFile(sysSummaryFile);
71 79 String systemId = getSystemIdFromFilename(sysSummaryFile);
72 80 if (systemSummary == null) {
73   - LOG.error("Empty system summary for id: " + id);
74   - return;
  81 + throw new Exception("Empty system summary for id: " + id);
75 82 }
76 83 for (int i : ROUGE_N) {
77 84 double score = RougeN.score(i, systemSummary, goldSummaries);
78   - ra.logScore("ROUGE_" + i, systemId, score);
  85 + ra.logScore("ROUGE_" + i, systemId, score, id);
79 86  
80 87 double maxScore = RougeN.scoreMax(i, systemSummary, goldSummaries);
81   - ra.logScore("ROUGE_M_" + i, systemId, maxScore);
  88 + ra.logScore("ROUGE_M_" + i, systemId, maxScore, id);
82 89 }
83 90 }
84 91 totalComparedTexts++;
85 92 }
86 93  
87 94 LOG.info(totalComparedTexts + " - total compared texts.");
88   -
89   - checkNormality(ra);
90   - printResults(ra, LATEX);
91   - printSignificance(ra, LATEX);
92   - }
93   -
94   - private static void checkNormality(ResultAccumulator ra) {
95   - int normal = 0;
96   - int all = 0;
97   - for (String systemId : ra.getSystemIds())
98   - for (String metricId : ra.getMetricIds()) {
99   - all++;
100   - double[] standardized = StatUtils.normalize(ra.getSystemResults(systemId, metricId));
101   - double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1),
102   - standardized);
103   - if (p > P_VALUE)
104   - normal++;
105   - else
106   - LOG.info("System " + systemId + " for " + metricId + " rejected K-S normality test with p-value = "
107   - + p);
108   - }
109   -
110   - LOG.info(normal + " normal of all " + all);
111   - }
112   -
113   - private static void printResults(ResultAccumulator ra, boolean latex) {
114   - System.out.println("###################### Results #############################");
115   -
116   - if (latex)
117   - System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}");
118   -
119   - System.out.print("System");
120   - for (String metricId : ra.getMetricIds())
121   - System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-"));
122   - if (latex)
123   - System.out.print("\\\\\\hline\\hline");
124   - System.out.println();
125   -
126   - for (String systemId : ra.getSystemIds()) {
127   - System.out.print(systemId);
128   - for (String metricId : ra.getMetricIds()) {
129   - SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId);
130   -
131   - System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(),
132   - r.getStandardDeviation()));
133   - }
134   - if (latex)
135   - System.out.print("\\\\\\hline");
136   - System.out.println();
137   - }
138   - if (latex)
139   - System.out.println("\\end{tabular}");
140   - }
141   -
142   - private static void printSignificance(ResultAccumulator ra, boolean latex) {
143   - System.out.println("###################### Stat. sign. #############################");
144   - int i = 1;
145   - Map<String, Integer> sys2id = new HashMap<>();
146   - for (String systemId : ra.getSystemIds())
147   - sys2id.put(systemId, i++);
148   - if (latex)
149   - System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}");
150   - System.out.print("System");
151   - for (String metricId : ra.getMetricIds())
152   - System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-"));
153   - if (latex)
154   - System.out.print("\\\\\\hline\\hline");
155   - System.out.println();
156   -
157   - for (String systemId : ra.getSystemIds()) {
158   - System.out.print(sys2id.get(systemId) + ". " + systemId);
159   - for (String metricId : ra.getMetricIds()) {
160   - Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId);
161   - List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted()
162   - .collect(Collectors.toList());
163   -
164   - System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids)));
165   - }
166   - if (latex)
167   - System.out.print("\\\\\\hline");
168   - System.out.println();
169   - }
170   - if (latex)
171   - System.out.println("\\end{tabular}");
172   - }
173   -
174   - private static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId) {
175   - TreeSet<String> signifWorseSystems = new TreeSet<>();
176   - for (String systemId2 : ra.getSystemIds())
177   - if (!systemId.equals(systemId2)) {
178   -
179   - SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId);
180   - SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId);
181   -
182   - if (scores1.getMean() <= scores2.getMean())
183   - continue;
184   -
185   - double p = new TTest().tTest(scores1, scores2) / 2;
186   -
187   - if (p < P_VALUE)
188   - signifWorseSystems.add(systemId2);
189   -
190   - }
191   - return signifWorseSystems;
  95 + return ra;
192 96 }
193 97  
194 98 private static String getSystemIdFromFilename(File sysSummaryFile) {
... ...
src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java
... ... @@ -7,20 +7,19 @@ import java.util.Map;
7 7 import java.util.Set;
8 8 import java.util.stream.Collectors;
9 9  
10   -import org.apache.commons.lang3.ArrayUtils;
11 10 import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
12 11  
13   -import com.google.common.collect.Lists;
  12 +import com.google.common.collect.Maps;
14 13  
15 14 public class ResultAccumulator {
16 15  
17 16 private Map<String, CumulativeResult> systemId2result = new HashMap<>();
18 17 private Set<String> metricIds = new HashSet<>();
19 18  
20   - public void logScore(String metricId, String systemId, double score) {
  19 + public void logScore(String metricId, String systemId, double score, String textId) {
21 20 systemId2result.putIfAbsent(systemId, new CumulativeResult());
22 21 CumulativeResult cr = systemId2result.get(systemId);
23   - cr.addScore(metricId, score);
  22 + cr.addScore(metricId, score, textId);
24 23 metricIds.add(metricId);
25 24 }
26 25  
... ... @@ -32,7 +31,7 @@ public class ResultAccumulator {
32 31 return metricIds.stream().sorted().collect(Collectors.toList());
33 32 }
34 33  
35   - public double[] getSystemResults(String systemId, String metricId) {
  34 + public Map<String, Double> getSystemResults(String systemId, String metricId) {
36 35 return systemId2result.get(systemId).getResults(metricId);
37 36 }
38 37  
... ... @@ -42,17 +41,17 @@ public class ResultAccumulator {
42 41  
43 42 private class CumulativeResult {
44 43 Map<String, SummaryStatistics> metricId2scores = new HashMap<>();
45   - Map<String, List<Double>> metricId2allScores = new HashMap<>();
  44 + Map<String, Map<String, Double>> metricId2textId2scores = new HashMap<>();
46 45  
47   - public void addScore(String metricId, double score) {
  46 + public void addScore(String metricId, double score, String textId) {
48 47 metricId2scores.putIfAbsent(metricId, new SummaryStatistics());
49 48 metricId2scores.get(metricId).addValue(score);
50   - metricId2allScores.putIfAbsent(metricId, Lists.newArrayList());
51   - metricId2allScores.get(metricId).add(score);
  49 + metricId2textId2scores.putIfAbsent(metricId, Maps.newHashMap());
  50 + metricId2textId2scores.get(metricId).put(textId, score);
52 51 }
53 52  
54   - public double[] getResults(String metricId) {
55   - return ArrayUtils.toPrimitive(metricId2allScores.get(metricId).toArray(new Double[0]));
  53 + public Map<String, Double> getResults(String metricId) {
  54 + return metricId2textId2scores.get(metricId);
56 55 }
57 56  
58 57 public SummaryStatistics getSummaryStatistics(String metricId) {
... ...
src/main/java/pl/waw/ipipan/zil/summ/eval/output/OutputHelper.java 0 → 100644
  1 +package pl.waw.ipipan.zil.summ.eval.output;
  2 +
  3 +import java.util.HashMap;
  4 +import java.util.List;
  5 +import java.util.Map;
  6 +import java.util.Set;
  7 +import java.util.TreeSet;
  8 +import java.util.stream.Collectors;
  9 +
  10 +import org.apache.commons.lang3.ArrayUtils;
  11 +import org.apache.commons.math3.distribution.NormalDistribution;
  12 +import org.apache.commons.math3.stat.StatUtils;
  13 +import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
  14 +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
  15 +import org.apache.commons.math3.stat.inference.TTest;
  16 +
  17 +import pl.waw.ipipan.zil.summ.eval.ResultAccumulator;
  18 +
  19 +public class OutputHelper {
  20 +
  21 + public static void checkNormality(ResultAccumulator ra, double pValue) {
  22 + int normal = 0;
  23 + int all = 0;
  24 + for (String systemId : ra.getSystemIds())
  25 + for (String metricId : ra.getMetricIds()) {
  26 + all++;
  27 + Map<String, Double> systemResults = ra.getSystemResults(systemId, metricId);
  28 + double[] values = ArrayUtils.toPrimitive(systemResults.values().toArray(new Double[0]));
  29 + double[] standardized = StatUtils.normalize(values);
  30 + if (standardized.length < 2) {
  31 + System.out.println("Skipping normality check because at least 2 samples are required.");
  32 + return;
  33 + }
  34 + double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1),
  35 + standardized);
  36 + if (p > pValue)
  37 + normal++;
  38 + else
  39 + System.out.println("System " + systemId + " for " + metricId
  40 + + " rejected K-S normality test with p-value = " + p);
  41 + }
  42 +
  43 + System.out.println(normal + " normal of all " + all);
  44 + }
  45 +
  46 + public static void printResults(ResultAccumulator ra, boolean latex) {
  47 + System.out.println("###################### Results #############################");
  48 +
  49 + if (latex)
  50 + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}");
  51 +
  52 + System.out.print("System");
  53 + for (String metricId : ra.getMetricIds())
  54 + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-"));
  55 + if (latex)
  56 + System.out.print("\\\\\\hline\\hline");
  57 + System.out.println();
  58 +
  59 + for (String systemId : ra.getSystemIds()) {
  60 + System.out.print(systemId);
  61 + for (String metricId : ra.getMetricIds()) {
  62 + SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId);
  63 +
  64 + System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(),
  65 + r.getStandardDeviation()));
  66 + }
  67 + if (latex)
  68 + System.out.print("\\\\\\hline");
  69 + System.out.println();
  70 + }
  71 + if (latex)
  72 + System.out.println("\\end{tabular}");
  73 + }
  74 +
  75 + public static void printSignificance(ResultAccumulator ra, double pValue, boolean latex) {
  76 + if (ra.getSystemResults(ra.getSystemIds().iterator().next(), ra.getMetricIds().iterator().next()).values()
  77 + .size() < 2) {
  78 + System.out.println("Skipping significant difference check because at least 2 samples are required.");
  79 + return;
  80 + }
  81 +
  82 + System.out.println("###################### Stat. sign. #############################");
  83 + int i = 1;
  84 + Map<String, Integer> sys2id = new HashMap<>();
  85 + for (String systemId : ra.getSystemIds())
  86 + sys2id.put(systemId, i++);
  87 + if (latex)
  88 + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}");
  89 + System.out.print("System");
  90 + for (String metricId : ra.getMetricIds())
  91 + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-"));
  92 + if (latex)
  93 + System.out.print("\\\\\\hline\\hline");
  94 + System.out.println();
  95 +
  96 + for (String systemId : ra.getSystemIds()) {
  97 + System.out.print(sys2id.get(systemId) + ". " + systemId);
  98 +
  99 + for (String metricId : ra.getMetricIds()) {
  100 + Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId, pValue);
  101 + List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted()
  102 + .collect(Collectors.toList());
  103 +
  104 + System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids)));
  105 + }
  106 + if (latex)
  107 + System.out.print("\\\\\\hline");
  108 + System.out.println();
  109 + }
  110 + if (latex)
  111 + System.out.println("\\end{tabular}");
  112 + }
  113 +
  114 + public static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId,
  115 + double pValue) {
  116 + TreeSet<String> signifWorseSystems = new TreeSet<>();
  117 + for (String systemId2 : ra.getSystemIds())
  118 + if (!systemId.equals(systemId2)) {
  119 +
  120 + SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId);
  121 + SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId);
  122 +
  123 + if (scores1.getMean() <= scores2.getMean())
  124 + continue;
  125 +
  126 + double p = new TTest().tTest(scores1, scores2) / 2;
  127 +
  128 + if (p < pValue)
  129 + signifWorseSystems.add(systemId2);
  130 + }
  131 + return signifWorseSystems;
  132 + }
  133 +
  134 +}
... ...