Commit c6d4dcbe7d1026f1268c05672b892611a729a291

Authored by Mateusz Kopeć
1 parent 0b7a08f9

bugfixes and refactor

src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
@@ -6,20 +6,14 @@ import java.nio.file.Files; @@ -6,20 +6,14 @@ import java.nio.file.Files;
6 import java.nio.file.Paths; 6 import java.nio.file.Paths;
7 import java.util.HashMap; 7 import java.util.HashMap;
8 import java.util.HashSet; 8 import java.util.HashSet;
9 -import java.util.List;  
10 import java.util.Map; 9 import java.util.Map;
11 import java.util.Set; 10 import java.util.Set;
12 -import java.util.TreeSet;  
13 import java.util.stream.Collectors; 11 import java.util.stream.Collectors;
14 12
15 -import org.apache.commons.math3.distribution.NormalDistribution;  
16 -import org.apache.commons.math3.stat.StatUtils;  
17 -import org.apache.commons.math3.stat.descriptive.SummaryStatistics;  
18 -import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;  
19 -import org.apache.commons.math3.stat.inference.TTest;  
20 import org.apache.logging.log4j.LogManager; 13 import org.apache.logging.log4j.LogManager;
21 import org.apache.logging.log4j.Logger; 14 import org.apache.logging.log4j.Logger;
22 15
  16 +import pl.waw.ipipan.zil.summ.eval.output.OutputHelper;
23 import pl.waw.ipipan.zil.summ.eval.rouge.RougeN; 17 import pl.waw.ipipan.zil.summ.eval.rouge.RougeN;
24 18
25 public class Main { 19 public class Main {
@@ -33,7 +27,7 @@ public class Main { @@ -33,7 +27,7 @@ public class Main {
33 27
34 public static void main(String[] args) { 28 public static void main(String[] args) {
35 if (args.length != 2) { 29 if (args.length != 2) {
36 - LOG.error("Wrong number of arguments!"); 30 + LOG.error("Wrong number of arguments! Try: Main goldDir sysDir");
37 return; 31 return;
38 } 32 }
39 33
@@ -53,6 +47,21 @@ public class Main { @@ -53,6 +47,21 @@ public class Main {
53 } 47 }
54 LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded."); 48 LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded.");
55 49
  50 + ResultAccumulator ra;
  51 + try {
  52 + ra = countResults(id2goldSummaryFiles, id2sysSummaryFiles);
  53 + } catch (Exception e) {
  54 + LOG.error("Error counting results: " + e);
  55 + return;
  56 + }
  57 +
  58 + OutputHelper.checkNormality(ra, P_VALUE);
  59 + OutputHelper.printResults(ra, LATEX);
  60 + OutputHelper.printSignificance(ra, P_VALUE, LATEX);
  61 + }
  62 +
  63 + private static ResultAccumulator countResults(Map<String, Set<File>> id2goldSummaryFiles,
  64 + Map<String, Set<File>> id2sysSummaryFiles) throws Exception {
56 ResultAccumulator ra = new ResultAccumulator(); 65 ResultAccumulator ra = new ResultAccumulator();
57 int totalComparedTexts = 0; 66 int totalComparedTexts = 0;
58 67
@@ -63,132 +72,27 @@ public class Main { @@ -63,132 +72,27 @@ public class Main {
63 Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f)) 72 Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f))
64 .collect(Collectors.toSet()); 73 .collect(Collectors.toSet());
65 if (goldSummaries.contains(null)) { 74 if (goldSummaries.contains(null)) {
66 - LOG.error("Empty gold summary for id: " + id);  
67 - return; 75 + throw new Exception("Empty gold summary for id: " + id);
68 } 76 }
69 for (File sysSummaryFile : sysSummaryFiles) { 77 for (File sysSummaryFile : sysSummaryFiles) {
70 String systemSummary = loadSummaryFromFile(sysSummaryFile); 78 String systemSummary = loadSummaryFromFile(sysSummaryFile);
71 String systemId = getSystemIdFromFilename(sysSummaryFile); 79 String systemId = getSystemIdFromFilename(sysSummaryFile);
72 if (systemSummary == null) { 80 if (systemSummary == null) {
73 - LOG.error("Empty system summary for id: " + id);  
74 - return; 81 + throw new Exception("Empty system summary for id: " + id);
75 } 82 }
76 for (int i : ROUGE_N) { 83 for (int i : ROUGE_N) {
77 double score = RougeN.score(i, systemSummary, goldSummaries); 84 double score = RougeN.score(i, systemSummary, goldSummaries);
78 - ra.logScore("ROUGE_" + i, systemId, score); 85 + ra.logScore("ROUGE_" + i, systemId, score, id);
79 86
80 double maxScore = RougeN.scoreMax(i, systemSummary, goldSummaries); 87 double maxScore = RougeN.scoreMax(i, systemSummary, goldSummaries);
81 - ra.logScore("ROUGE_M_" + i, systemId, maxScore); 88 + ra.logScore("ROUGE_M_" + i, systemId, maxScore, id);
82 } 89 }
83 } 90 }
84 totalComparedTexts++; 91 totalComparedTexts++;
85 } 92 }
86 93
87 LOG.info(totalComparedTexts + " - total compared texts."); 94 LOG.info(totalComparedTexts + " - total compared texts.");
88 -  
89 - checkNormality(ra);  
90 - printResults(ra, LATEX);  
91 - printSignificance(ra, LATEX);  
92 - }  
93 -  
94 - private static void checkNormality(ResultAccumulator ra) {  
95 - int normal = 0;  
96 - int all = 0;  
97 - for (String systemId : ra.getSystemIds())  
98 - for (String metricId : ra.getMetricIds()) {  
99 - all++;  
100 - double[] standardized = StatUtils.normalize(ra.getSystemResults(systemId, metricId));  
101 - double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1),  
102 - standardized);  
103 - if (p > P_VALUE)  
104 - normal++;  
105 - else  
106 - LOG.info("System " + systemId + " for " + metricId + " rejected K-S normality test with p-value = "  
107 - + p);  
108 - }  
109 -  
110 - LOG.info(normal + " normal of all " + all);  
111 - }  
112 -  
113 - private static void printResults(ResultAccumulator ra, boolean latex) {  
114 - System.out.println("###################### Results #############################");  
115 -  
116 - if (latex)  
117 - System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}");  
118 -  
119 - System.out.print("System");  
120 - for (String metricId : ra.getMetricIds())  
121 - System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-"));  
122 - if (latex)  
123 - System.out.print("\\\\\\hline\\hline");  
124 - System.out.println();  
125 -  
126 - for (String systemId : ra.getSystemIds()) {  
127 - System.out.print(systemId);  
128 - for (String metricId : ra.getMetricIds()) {  
129 - SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId);  
130 -  
131 - System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(),  
132 - r.getStandardDeviation()));  
133 - }  
134 - if (latex)  
135 - System.out.print("\\\\\\hline");  
136 - System.out.println();  
137 - }  
138 - if (latex)  
139 - System.out.println("\\end{tabular}");  
140 - }  
141 -  
142 - private static void printSignificance(ResultAccumulator ra, boolean latex) {  
143 - System.out.println("###################### Stat. sign. #############################");  
144 - int i = 1;  
145 - Map<String, Integer> sys2id = new HashMap<>();  
146 - for (String systemId : ra.getSystemIds())  
147 - sys2id.put(systemId, i++);  
148 - if (latex)  
149 - System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}");  
150 - System.out.print("System");  
151 - for (String metricId : ra.getMetricIds())  
152 - System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-"));  
153 - if (latex)  
154 - System.out.print("\\\\\\hline\\hline");  
155 - System.out.println();  
156 -  
157 - for (String systemId : ra.getSystemIds()) {  
158 - System.out.print(sys2id.get(systemId) + ". " + systemId);  
159 - for (String metricId : ra.getMetricIds()) {  
160 - Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId);  
161 - List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted()  
162 - .collect(Collectors.toList());  
163 -  
164 - System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids)));  
165 - }  
166 - if (latex)  
167 - System.out.print("\\\\\\hline");  
168 - System.out.println();  
169 - }  
170 - if (latex)  
171 - System.out.println("\\end{tabular}");  
172 - }  
173 -  
174 - private static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId) {  
175 - TreeSet<String> signifWorseSystems = new TreeSet<>();  
176 - for (String systemId2 : ra.getSystemIds())  
177 - if (!systemId.equals(systemId2)) {  
178 -  
179 - SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId);  
180 - SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId);  
181 -  
182 - if (scores1.getMean() <= scores2.getMean())  
183 - continue;  
184 -  
185 - double p = new TTest().tTest(scores1, scores2) / 2;  
186 -  
187 - if (p < P_VALUE)  
188 - signifWorseSystems.add(systemId2);  
189 -  
190 - }  
191 - return signifWorseSystems; 95 + return ra;
192 } 96 }
193 97
194 private static String getSystemIdFromFilename(File sysSummaryFile) { 98 private static String getSystemIdFromFilename(File sysSummaryFile) {
src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java
@@ -7,20 +7,19 @@ import java.util.Map; @@ -7,20 +7,19 @@ import java.util.Map;
7 import java.util.Set; 7 import java.util.Set;
8 import java.util.stream.Collectors; 8 import java.util.stream.Collectors;
9 9
10 -import org.apache.commons.lang3.ArrayUtils;  
11 import org.apache.commons.math3.stat.descriptive.SummaryStatistics; 10 import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
12 11
13 -import com.google.common.collect.Lists; 12 +import com.google.common.collect.Maps;
14 13
15 public class ResultAccumulator { 14 public class ResultAccumulator {
16 15
17 private Map<String, CumulativeResult> systemId2result = new HashMap<>(); 16 private Map<String, CumulativeResult> systemId2result = new HashMap<>();
18 private Set<String> metricIds = new HashSet<>(); 17 private Set<String> metricIds = new HashSet<>();
19 18
20 - public void logScore(String metricId, String systemId, double score) { 19 + public void logScore(String metricId, String systemId, double score, String textId) {
21 systemId2result.putIfAbsent(systemId, new CumulativeResult()); 20 systemId2result.putIfAbsent(systemId, new CumulativeResult());
22 CumulativeResult cr = systemId2result.get(systemId); 21 CumulativeResult cr = systemId2result.get(systemId);
23 - cr.addScore(metricId, score); 22 + cr.addScore(metricId, score, textId);
24 metricIds.add(metricId); 23 metricIds.add(metricId);
25 } 24 }
26 25
@@ -32,7 +31,7 @@ public class ResultAccumulator { @@ -32,7 +31,7 @@ public class ResultAccumulator {
32 return metricIds.stream().sorted().collect(Collectors.toList()); 31 return metricIds.stream().sorted().collect(Collectors.toList());
33 } 32 }
34 33
35 - public double[] getSystemResults(String systemId, String metricId) { 34 + public Map<String, Double> getSystemResults(String systemId, String metricId) {
36 return systemId2result.get(systemId).getResults(metricId); 35 return systemId2result.get(systemId).getResults(metricId);
37 } 36 }
38 37
@@ -42,17 +41,17 @@ public class ResultAccumulator { @@ -42,17 +41,17 @@ public class ResultAccumulator {
42 41
43 private class CumulativeResult { 42 private class CumulativeResult {
44 Map<String, SummaryStatistics> metricId2scores = new HashMap<>(); 43 Map<String, SummaryStatistics> metricId2scores = new HashMap<>();
45 - Map<String, List<Double>> metricId2allScores = new HashMap<>(); 44 + Map<String, Map<String, Double>> metricId2textId2scores = new HashMap<>();
46 45
47 - public void addScore(String metricId, double score) { 46 + public void addScore(String metricId, double score, String textId) {
48 metricId2scores.putIfAbsent(metricId, new SummaryStatistics()); 47 metricId2scores.putIfAbsent(metricId, new SummaryStatistics());
49 metricId2scores.get(metricId).addValue(score); 48 metricId2scores.get(metricId).addValue(score);
50 - metricId2allScores.putIfAbsent(metricId, Lists.newArrayList());  
51 - metricId2allScores.get(metricId).add(score); 49 + metricId2textId2scores.putIfAbsent(metricId, Maps.newHashMap());
  50 + metricId2textId2scores.get(metricId).put(textId, score);
52 } 51 }
53 52
54 - public double[] getResults(String metricId) {  
55 - return ArrayUtils.toPrimitive(metricId2allScores.get(metricId).toArray(new Double[0])); 53 + public Map<String, Double> getResults(String metricId) {
  54 + return metricId2textId2scores.get(metricId);
56 } 55 }
57 56
58 public SummaryStatistics getSummaryStatistics(String metricId) { 57 public SummaryStatistics getSummaryStatistics(String metricId) {
src/main/java/pl/waw/ipipan/zil/summ/eval/output/OutputHelper.java 0 → 100644
  1 +package pl.waw.ipipan.zil.summ.eval.output;
  2 +
  3 +import java.util.HashMap;
  4 +import java.util.List;
  5 +import java.util.Map;
  6 +import java.util.Set;
  7 +import java.util.TreeSet;
  8 +import java.util.stream.Collectors;
  9 +
  10 +import org.apache.commons.lang3.ArrayUtils;
  11 +import org.apache.commons.math3.distribution.NormalDistribution;
  12 +import org.apache.commons.math3.stat.StatUtils;
  13 +import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
  14 +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
  15 +import org.apache.commons.math3.stat.inference.TTest;
  16 +
  17 +import pl.waw.ipipan.zil.summ.eval.ResultAccumulator;
  18 +
  19 +public class OutputHelper {
  20 +
  21 + public static void checkNormality(ResultAccumulator ra, double pValue) {
  22 + int normal = 0;
  23 + int all = 0;
  24 + for (String systemId : ra.getSystemIds())
  25 + for (String metricId : ra.getMetricIds()) {
  26 + all++;
  27 + Map<String, Double> systemResults = ra.getSystemResults(systemId, metricId);
  28 + double[] values = ArrayUtils.toPrimitive(systemResults.values().toArray(new Double[0]));
  29 + double[] standardized = StatUtils.normalize(values);
  30 + if (standardized.length < 2) {
  31 + System.out.println("Skipping normality check because at least 2 samples are required.");
  32 + return;
  33 + }
  34 + double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1),
  35 + standardized);
  36 + if (p > pValue)
  37 + normal++;
  38 + else
  39 + System.out.println("System " + systemId + " for " + metricId
  40 + + " rejected K-S normality test with p-value = " + p);
  41 + }
  42 +
  43 + System.out.println(normal + " normal of all " + all);
  44 + }
  45 +
  46 + public static void printResults(ResultAccumulator ra, boolean latex) {
  47 + System.out.println("###################### Results #############################");
  48 +
  49 + if (latex)
  50 + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}");
  51 +
  52 + System.out.print("System");
  53 + for (String metricId : ra.getMetricIds())
  54 + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-"));
  55 + if (latex)
  56 + System.out.print("\\\\\\hline\\hline");
  57 + System.out.println();
  58 +
  59 + for (String systemId : ra.getSystemIds()) {
  60 + System.out.print(systemId);
  61 + for (String metricId : ra.getMetricIds()) {
  62 + SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId);
  63 +
  64 + System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(),
  65 + r.getStandardDeviation()));
  66 + }
  67 + if (latex)
  68 + System.out.print("\\\\\\hline");
  69 + System.out.println();
  70 + }
  71 + if (latex)
  72 + System.out.println("\\end{tabular}");
  73 + }
  74 +
  75 + public static void printSignificance(ResultAccumulator ra, double pValue, boolean latex) {
  76 + if (ra.getSystemResults(ra.getSystemIds().iterator().next(), ra.getMetricIds().iterator().next()).values()
  77 + .size() < 2) {
  78 + System.out.println("Skipping significant difference check because at least 2 samples are required.");
  79 + return;
  80 + }
  81 +
  82 + System.out.println("###################### Stat. sign. #############################");
  83 + int i = 1;
  84 + Map<String, Integer> sys2id = new HashMap<>();
  85 + for (String systemId : ra.getSystemIds())
  86 + sys2id.put(systemId, i++);
  87 + if (latex)
  88 + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}");
  89 + System.out.print("System");
  90 + for (String metricId : ra.getMetricIds())
  91 + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-"));
  92 + if (latex)
  93 + System.out.print("\\\\\\hline\\hline");
  94 + System.out.println();
  95 +
  96 + for (String systemId : ra.getSystemIds()) {
  97 + System.out.print(sys2id.get(systemId) + ". " + systemId);
  98 +
  99 + for (String metricId : ra.getMetricIds()) {
  100 + Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId, pValue);
  101 + List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted()
  102 + .collect(Collectors.toList());
  103 +
  104 + System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids)));
  105 + }
  106 + if (latex)
  107 + System.out.print("\\\\\\hline");
  108 + System.out.println();
  109 + }
  110 + if (latex)
  111 + System.out.println("\\end{tabular}");
  112 + }
  113 +
  114 + public static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId,
  115 + double pValue) {
  116 + TreeSet<String> signifWorseSystems = new TreeSet<>();
  117 + for (String systemId2 : ra.getSystemIds())
  118 + if (!systemId.equals(systemId2)) {
  119 +
  120 + SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId);
  121 + SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId);
  122 +
  123 + if (scores1.getMean() <= scores2.getMean())
  124 + continue;
  125 +
  126 + double p = new TTest().tTest(scores1, scores2) / 2;
  127 +
  128 + if (p < pValue)
  129 + signifWorseSystems.add(systemId2);
  130 + }
  131 + return signifWorseSystems;
  132 + }
  133 +
  134 +}