Commit c6d4dcbe7d1026f1268c05672b892611a729a291
1 parent: 0b7a08f9
bugfixes and refactor
Showing 3 changed files with 166 additions and 129 deletions
src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
@@ -6,20 +6,14 @@ import java.nio.file.Files; | @@ -6,20 +6,14 @@ import java.nio.file.Files; | ||
6 | import java.nio.file.Paths; | 6 | import java.nio.file.Paths; |
7 | import java.util.HashMap; | 7 | import java.util.HashMap; |
8 | import java.util.HashSet; | 8 | import java.util.HashSet; |
9 | -import java.util.List; | ||
10 | import java.util.Map; | 9 | import java.util.Map; |
11 | import java.util.Set; | 10 | import java.util.Set; |
12 | -import java.util.TreeSet; | ||
13 | import java.util.stream.Collectors; | 11 | import java.util.stream.Collectors; |
14 | 12 | ||
15 | -import org.apache.commons.math3.distribution.NormalDistribution; | ||
16 | -import org.apache.commons.math3.stat.StatUtils; | ||
17 | -import org.apache.commons.math3.stat.descriptive.SummaryStatistics; | ||
18 | -import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; | ||
19 | -import org.apache.commons.math3.stat.inference.TTest; | ||
20 | import org.apache.logging.log4j.LogManager; | 13 | import org.apache.logging.log4j.LogManager; |
21 | import org.apache.logging.log4j.Logger; | 14 | import org.apache.logging.log4j.Logger; |
22 | 15 | ||
16 | +import pl.waw.ipipan.zil.summ.eval.output.OutputHelper; | ||
23 | import pl.waw.ipipan.zil.summ.eval.rouge.RougeN; | 17 | import pl.waw.ipipan.zil.summ.eval.rouge.RougeN; |
24 | 18 | ||
25 | public class Main { | 19 | public class Main { |
@@ -33,7 +27,7 @@ public class Main { | @@ -33,7 +27,7 @@ public class Main { | ||
33 | 27 | ||
34 | public static void main(String[] args) { | 28 | public static void main(String[] args) { |
35 | if (args.length != 2) { | 29 | if (args.length != 2) { |
36 | - LOG.error("Wrong number of arguments!"); | 30 | + LOG.error("Wrong number of arguments! Try: Main goldDir sysDir"); |
37 | return; | 31 | return; |
38 | } | 32 | } |
39 | 33 | ||
@@ -53,6 +47,21 @@ public class Main { | @@ -53,6 +47,21 @@ public class Main { | ||
53 | } | 47 | } |
54 | LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded."); | 48 | LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded."); |
55 | 49 | ||
50 | + ResultAccumulator ra; | ||
51 | + try { | ||
52 | + ra = countResults(id2goldSummaryFiles, id2sysSummaryFiles); | ||
53 | + } catch (Exception e) { | ||
54 | + LOG.error("Error counting results: " + e); | ||
55 | + return; | ||
56 | + } | ||
57 | + | ||
58 | + OutputHelper.checkNormality(ra, P_VALUE); | ||
59 | + OutputHelper.printResults(ra, LATEX); | ||
60 | + OutputHelper.printSignificance(ra, P_VALUE, LATEX); | ||
61 | + } | ||
62 | + | ||
63 | + private static ResultAccumulator countResults(Map<String, Set<File>> id2goldSummaryFiles, | ||
64 | + Map<String, Set<File>> id2sysSummaryFiles) throws Exception { | ||
56 | ResultAccumulator ra = new ResultAccumulator(); | 65 | ResultAccumulator ra = new ResultAccumulator(); |
57 | int totalComparedTexts = 0; | 66 | int totalComparedTexts = 0; |
58 | 67 | ||
@@ -63,132 +72,27 @@ public class Main { | @@ -63,132 +72,27 @@ public class Main { | ||
63 | Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f)) | 72 | Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f)) |
64 | .collect(Collectors.toSet()); | 73 | .collect(Collectors.toSet()); |
65 | if (goldSummaries.contains(null)) { | 74 | if (goldSummaries.contains(null)) { |
66 | - LOG.error("Empty gold summary for id: " + id); | ||
67 | - return; | 75 | + throw new Exception("Empty gold summary for id: " + id); |
68 | } | 76 | } |
69 | for (File sysSummaryFile : sysSummaryFiles) { | 77 | for (File sysSummaryFile : sysSummaryFiles) { |
70 | String systemSummary = loadSummaryFromFile(sysSummaryFile); | 78 | String systemSummary = loadSummaryFromFile(sysSummaryFile); |
71 | String systemId = getSystemIdFromFilename(sysSummaryFile); | 79 | String systemId = getSystemIdFromFilename(sysSummaryFile); |
72 | if (systemSummary == null) { | 80 | if (systemSummary == null) { |
73 | - LOG.error("Empty system summary for id: " + id); | ||
74 | - return; | 81 | + throw new Exception("Empty system summary for id: " + id); |
75 | } | 82 | } |
76 | for (int i : ROUGE_N) { | 83 | for (int i : ROUGE_N) { |
77 | double score = RougeN.score(i, systemSummary, goldSummaries); | 84 | double score = RougeN.score(i, systemSummary, goldSummaries); |
78 | - ra.logScore("ROUGE_" + i, systemId, score); | 85 | + ra.logScore("ROUGE_" + i, systemId, score, id); |
79 | 86 | ||
80 | double maxScore = RougeN.scoreMax(i, systemSummary, goldSummaries); | 87 | double maxScore = RougeN.scoreMax(i, systemSummary, goldSummaries); |
81 | - ra.logScore("ROUGE_M_" + i, systemId, maxScore); | 88 | + ra.logScore("ROUGE_M_" + i, systemId, maxScore, id); |
82 | } | 89 | } |
83 | } | 90 | } |
84 | totalComparedTexts++; | 91 | totalComparedTexts++; |
85 | } | 92 | } |
86 | 93 | ||
87 | LOG.info(totalComparedTexts + " - total compared texts."); | 94 | LOG.info(totalComparedTexts + " - total compared texts."); |
88 | - | ||
89 | - checkNormality(ra); | ||
90 | - printResults(ra, LATEX); | ||
91 | - printSignificance(ra, LATEX); | ||
92 | - } | ||
93 | - | ||
94 | - private static void checkNormality(ResultAccumulator ra) { | ||
95 | - int normal = 0; | ||
96 | - int all = 0; | ||
97 | - for (String systemId : ra.getSystemIds()) | ||
98 | - for (String metricId : ra.getMetricIds()) { | ||
99 | - all++; | ||
100 | - double[] standardized = StatUtils.normalize(ra.getSystemResults(systemId, metricId)); | ||
101 | - double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1), | ||
102 | - standardized); | ||
103 | - if (p > P_VALUE) | ||
104 | - normal++; | ||
105 | - else | ||
106 | - LOG.info("System " + systemId + " for " + metricId + " rejected K-S normality test with p-value = " | ||
107 | - + p); | ||
108 | - } | ||
109 | - | ||
110 | - LOG.info(normal + " normal of all " + all); | ||
111 | - } | ||
112 | - | ||
113 | - private static void printResults(ResultAccumulator ra, boolean latex) { | ||
114 | - System.out.println("###################### Results #############################"); | ||
115 | - | ||
116 | - if (latex) | ||
117 | - System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); | ||
118 | - | ||
119 | - System.out.print("System"); | ||
120 | - for (String metricId : ra.getMetricIds()) | ||
121 | - System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); | ||
122 | - if (latex) | ||
123 | - System.out.print("\\\\\\hline\\hline"); | ||
124 | - System.out.println(); | ||
125 | - | ||
126 | - for (String systemId : ra.getSystemIds()) { | ||
127 | - System.out.print(systemId); | ||
128 | - for (String metricId : ra.getMetricIds()) { | ||
129 | - SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId); | ||
130 | - | ||
131 | - System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(), | ||
132 | - r.getStandardDeviation())); | ||
133 | - } | ||
134 | - if (latex) | ||
135 | - System.out.print("\\\\\\hline"); | ||
136 | - System.out.println(); | ||
137 | - } | ||
138 | - if (latex) | ||
139 | - System.out.println("\\end{tabular}"); | ||
140 | - } | ||
141 | - | ||
142 | - private static void printSignificance(ResultAccumulator ra, boolean latex) { | ||
143 | - System.out.println("###################### Stat. sign. #############################"); | ||
144 | - int i = 1; | ||
145 | - Map<String, Integer> sys2id = new HashMap<>(); | ||
146 | - for (String systemId : ra.getSystemIds()) | ||
147 | - sys2id.put(systemId, i++); | ||
148 | - if (latex) | ||
149 | - System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); | ||
150 | - System.out.print("System"); | ||
151 | - for (String metricId : ra.getMetricIds()) | ||
152 | - System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); | ||
153 | - if (latex) | ||
154 | - System.out.print("\\\\\\hline\\hline"); | ||
155 | - System.out.println(); | ||
156 | - | ||
157 | - for (String systemId : ra.getSystemIds()) { | ||
158 | - System.out.print(sys2id.get(systemId) + ". " + systemId); | ||
159 | - for (String metricId : ra.getMetricIds()) { | ||
160 | - Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId); | ||
161 | - List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted() | ||
162 | - .collect(Collectors.toList()); | ||
163 | - | ||
164 | - System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids))); | ||
165 | - } | ||
166 | - if (latex) | ||
167 | - System.out.print("\\\\\\hline"); | ||
168 | - System.out.println(); | ||
169 | - } | ||
170 | - if (latex) | ||
171 | - System.out.println("\\end{tabular}"); | ||
172 | - } | ||
173 | - | ||
174 | - private static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId) { | ||
175 | - TreeSet<String> signifWorseSystems = new TreeSet<>(); | ||
176 | - for (String systemId2 : ra.getSystemIds()) | ||
177 | - if (!systemId.equals(systemId2)) { | ||
178 | - | ||
179 | - SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId); | ||
180 | - SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId); | ||
181 | - | ||
182 | - if (scores1.getMean() <= scores2.getMean()) | ||
183 | - continue; | ||
184 | - | ||
185 | - double p = new TTest().tTest(scores1, scores2) / 2; | ||
186 | - | ||
187 | - if (p < P_VALUE) | ||
188 | - signifWorseSystems.add(systemId2); | ||
189 | - | ||
190 | - } | ||
191 | - return signifWorseSystems; | 95 | + return ra; |
192 | } | 96 | } |
193 | 97 | ||
194 | private static String getSystemIdFromFilename(File sysSummaryFile) { | 98 | private static String getSystemIdFromFilename(File sysSummaryFile) { |
src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java
@@ -7,20 +7,19 @@ import java.util.Map; | @@ -7,20 +7,19 @@ import java.util.Map; | ||
7 | import java.util.Set; | 7 | import java.util.Set; |
8 | import java.util.stream.Collectors; | 8 | import java.util.stream.Collectors; |
9 | 9 | ||
10 | -import org.apache.commons.lang3.ArrayUtils; | ||
11 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics; | 10 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics; |
12 | 11 | ||
13 | -import com.google.common.collect.Lists; | 12 | +import com.google.common.collect.Maps; |
14 | 13 | ||
15 | public class ResultAccumulator { | 14 | public class ResultAccumulator { |
16 | 15 | ||
17 | private Map<String, CumulativeResult> systemId2result = new HashMap<>(); | 16 | private Map<String, CumulativeResult> systemId2result = new HashMap<>(); |
18 | private Set<String> metricIds = new HashSet<>(); | 17 | private Set<String> metricIds = new HashSet<>(); |
19 | 18 | ||
20 | - public void logScore(String metricId, String systemId, double score) { | 19 | + public void logScore(String metricId, String systemId, double score, String textId) { |
21 | systemId2result.putIfAbsent(systemId, new CumulativeResult()); | 20 | systemId2result.putIfAbsent(systemId, new CumulativeResult()); |
22 | CumulativeResult cr = systemId2result.get(systemId); | 21 | CumulativeResult cr = systemId2result.get(systemId); |
23 | - cr.addScore(metricId, score); | 22 | + cr.addScore(metricId, score, textId); |
24 | metricIds.add(metricId); | 23 | metricIds.add(metricId); |
25 | } | 24 | } |
26 | 25 | ||
@@ -32,7 +31,7 @@ public class ResultAccumulator { | @@ -32,7 +31,7 @@ public class ResultAccumulator { | ||
32 | return metricIds.stream().sorted().collect(Collectors.toList()); | 31 | return metricIds.stream().sorted().collect(Collectors.toList()); |
33 | } | 32 | } |
34 | 33 | ||
35 | - public double[] getSystemResults(String systemId, String metricId) { | 34 | + public Map<String, Double> getSystemResults(String systemId, String metricId) { |
36 | return systemId2result.get(systemId).getResults(metricId); | 35 | return systemId2result.get(systemId).getResults(metricId); |
37 | } | 36 | } |
38 | 37 | ||
@@ -42,17 +41,17 @@ public class ResultAccumulator { | @@ -42,17 +41,17 @@ public class ResultAccumulator { | ||
42 | 41 | ||
43 | private class CumulativeResult { | 42 | private class CumulativeResult { |
44 | Map<String, SummaryStatistics> metricId2scores = new HashMap<>(); | 43 | Map<String, SummaryStatistics> metricId2scores = new HashMap<>(); |
45 | - Map<String, List<Double>> metricId2allScores = new HashMap<>(); | 44 | + Map<String, Map<String, Double>> metricId2textId2scores = new HashMap<>(); |
46 | 45 | ||
47 | - public void addScore(String metricId, double score) { | 46 | + public void addScore(String metricId, double score, String textId) { |
48 | metricId2scores.putIfAbsent(metricId, new SummaryStatistics()); | 47 | metricId2scores.putIfAbsent(metricId, new SummaryStatistics()); |
49 | metricId2scores.get(metricId).addValue(score); | 48 | metricId2scores.get(metricId).addValue(score); |
50 | - metricId2allScores.putIfAbsent(metricId, Lists.newArrayList()); | ||
51 | - metricId2allScores.get(metricId).add(score); | 49 | + metricId2textId2scores.putIfAbsent(metricId, Maps.newHashMap()); |
50 | + metricId2textId2scores.get(metricId).put(textId, score); | ||
52 | } | 51 | } |
53 | 52 | ||
54 | - public double[] getResults(String metricId) { | ||
55 | - return ArrayUtils.toPrimitive(metricId2allScores.get(metricId).toArray(new Double[0])); | 53 | + public Map<String, Double> getResults(String metricId) { |
54 | + return metricId2textId2scores.get(metricId); | ||
56 | } | 55 | } |
57 | 56 | ||
58 | public SummaryStatistics getSummaryStatistics(String metricId) { | 57 | public SummaryStatistics getSummaryStatistics(String metricId) { |
src/main/java/pl/waw/ipipan/zil/summ/eval/output/OutputHelper.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.eval.output; | ||
2 | + | ||
3 | +import java.util.HashMap; | ||
4 | +import java.util.List; | ||
5 | +import java.util.Map; | ||
6 | +import java.util.Set; | ||
7 | +import java.util.TreeSet; | ||
8 | +import java.util.stream.Collectors; | ||
9 | + | ||
10 | +import org.apache.commons.lang3.ArrayUtils; | ||
11 | +import org.apache.commons.math3.distribution.NormalDistribution; | ||
12 | +import org.apache.commons.math3.stat.StatUtils; | ||
13 | +import org.apache.commons.math3.stat.descriptive.SummaryStatistics; | ||
14 | +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; | ||
15 | +import org.apache.commons.math3.stat.inference.TTest; | ||
16 | + | ||
17 | +import pl.waw.ipipan.zil.summ.eval.ResultAccumulator; | ||
18 | + | ||
19 | +public class OutputHelper { | ||
20 | + | ||
21 | + public static void checkNormality(ResultAccumulator ra, double pValue) { | ||
22 | + int normal = 0; | ||
23 | + int all = 0; | ||
24 | + for (String systemId : ra.getSystemIds()) | ||
25 | + for (String metricId : ra.getMetricIds()) { | ||
26 | + all++; | ||
27 | + Map<String, Double> systemResults = ra.getSystemResults(systemId, metricId); | ||
28 | + double[] values = ArrayUtils.toPrimitive(systemResults.values().toArray(new Double[0])); | ||
29 | + double[] standardized = StatUtils.normalize(values); | ||
30 | + if (standardized.length < 2) { | ||
31 | + System.out.println("Skipping normality check because at least 2 samples are required."); | ||
32 | + return; | ||
33 | + } | ||
34 | + double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1), | ||
35 | + standardized); | ||
36 | + if (p > pValue) | ||
37 | + normal++; | ||
38 | + else | ||
39 | + System.out.println("System " + systemId + " for " + metricId | ||
40 | + + " rejected K-S normality test with p-value = " + p); | ||
41 | + } | ||
42 | + | ||
43 | + System.out.println(normal + " normal of all " + all); | ||
44 | + } | ||
45 | + | ||
46 | + public static void printResults(ResultAccumulator ra, boolean latex) { | ||
47 | + System.out.println("###################### Results #############################"); | ||
48 | + | ||
49 | + if (latex) | ||
50 | + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); | ||
51 | + | ||
52 | + System.out.print("System"); | ||
53 | + for (String metricId : ra.getMetricIds()) | ||
54 | + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); | ||
55 | + if (latex) | ||
56 | + System.out.print("\\\\\\hline\\hline"); | ||
57 | + System.out.println(); | ||
58 | + | ||
59 | + for (String systemId : ra.getSystemIds()) { | ||
60 | + System.out.print(systemId); | ||
61 | + for (String metricId : ra.getMetricIds()) { | ||
62 | + SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId); | ||
63 | + | ||
64 | + System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(), | ||
65 | + r.getStandardDeviation())); | ||
66 | + } | ||
67 | + if (latex) | ||
68 | + System.out.print("\\\\\\hline"); | ||
69 | + System.out.println(); | ||
70 | + } | ||
71 | + if (latex) | ||
72 | + System.out.println("\\end{tabular}"); | ||
73 | + } | ||
74 | + | ||
75 | + public static void printSignificance(ResultAccumulator ra, double pValue, boolean latex) { | ||
76 | + if (ra.getSystemResults(ra.getSystemIds().iterator().next(), ra.getMetricIds().iterator().next()).values() | ||
77 | + .size() < 2) { | ||
78 | + System.out.println("Skipping significant difference check because at least 2 samples are required."); | ||
79 | + return; | ||
80 | + } | ||
81 | + | ||
82 | + System.out.println("###################### Stat. sign. #############################"); | ||
83 | + int i = 1; | ||
84 | + Map<String, Integer> sys2id = new HashMap<>(); | ||
85 | + for (String systemId : ra.getSystemIds()) | ||
86 | + sys2id.put(systemId, i++); | ||
87 | + if (latex) | ||
88 | + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); | ||
89 | + System.out.print("System"); | ||
90 | + for (String metricId : ra.getMetricIds()) | ||
91 | + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); | ||
92 | + if (latex) | ||
93 | + System.out.print("\\\\\\hline\\hline"); | ||
94 | + System.out.println(); | ||
95 | + | ||
96 | + for (String systemId : ra.getSystemIds()) { | ||
97 | + System.out.print(sys2id.get(systemId) + ". " + systemId); | ||
98 | + | ||
99 | + for (String metricId : ra.getMetricIds()) { | ||
100 | + Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId, pValue); | ||
101 | + List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted() | ||
102 | + .collect(Collectors.toList()); | ||
103 | + | ||
104 | + System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids))); | ||
105 | + } | ||
106 | + if (latex) | ||
107 | + System.out.print("\\\\\\hline"); | ||
108 | + System.out.println(); | ||
109 | + } | ||
110 | + if (latex) | ||
111 | + System.out.println("\\end{tabular}"); | ||
112 | + } | ||
113 | + | ||
114 | + public static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId, | ||
115 | + double pValue) { | ||
116 | + TreeSet<String> signifWorseSystems = new TreeSet<>(); | ||
117 | + for (String systemId2 : ra.getSystemIds()) | ||
118 | + if (!systemId.equals(systemId2)) { | ||
119 | + | ||
120 | + SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId); | ||
121 | + SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId); | ||
122 | + | ||
123 | + if (scores1.getMean() <= scores2.getMean()) | ||
124 | + continue; | ||
125 | + | ||
126 | + double p = new TTest().tTest(scores1, scores2) / 2; | ||
127 | + | ||
128 | + if (p < pValue) | ||
129 | + signifWorseSystems.add(systemId2); | ||
130 | + } | ||
131 | + return signifWorseSystems; | ||
132 | + } | ||
133 | + | ||
134 | +} |