Commit c6d4dcbe7d1026f1268c05672b892611a729a291
1 parent
0b7a08f9
bugfixes and refactor
Showing
3 changed files
with
166 additions
and
129 deletions
src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
... | ... | @@ -6,20 +6,14 @@ import java.nio.file.Files; |
6 | 6 | import java.nio.file.Paths; |
7 | 7 | import java.util.HashMap; |
8 | 8 | import java.util.HashSet; |
9 | -import java.util.List; | |
10 | 9 | import java.util.Map; |
11 | 10 | import java.util.Set; |
12 | -import java.util.TreeSet; | |
13 | 11 | import java.util.stream.Collectors; |
14 | 12 | |
15 | -import org.apache.commons.math3.distribution.NormalDistribution; | |
16 | -import org.apache.commons.math3.stat.StatUtils; | |
17 | -import org.apache.commons.math3.stat.descriptive.SummaryStatistics; | |
18 | -import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; | |
19 | -import org.apache.commons.math3.stat.inference.TTest; | |
20 | 13 | import org.apache.logging.log4j.LogManager; |
21 | 14 | import org.apache.logging.log4j.Logger; |
22 | 15 | |
16 | +import pl.waw.ipipan.zil.summ.eval.output.OutputHelper; | |
23 | 17 | import pl.waw.ipipan.zil.summ.eval.rouge.RougeN; |
24 | 18 | |
25 | 19 | public class Main { |
... | ... | @@ -33,7 +27,7 @@ public class Main { |
33 | 27 | |
34 | 28 | public static void main(String[] args) { |
35 | 29 | if (args.length != 2) { |
36 | - LOG.error("Wrong number of arguments!"); | |
30 | + LOG.error("Wrong number of arguments! Try: Main goldDir sysDir"); | |
37 | 31 | return; |
38 | 32 | } |
39 | 33 | |
... | ... | @@ -53,6 +47,21 @@ public class Main { |
53 | 47 | } |
54 | 48 | LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded."); |
55 | 49 | |
50 | + ResultAccumulator ra; | |
51 | + try { | |
52 | + ra = countResults(id2goldSummaryFiles, id2sysSummaryFiles); | |
53 | + } catch (Exception e) { | |
54 | + LOG.error("Error counting results: " + e); | |
55 | + return; | |
56 | + } | |
57 | + | |
58 | + OutputHelper.checkNormality(ra, P_VALUE); | |
59 | + OutputHelper.printResults(ra, LATEX); | |
60 | + OutputHelper.printSignificance(ra, P_VALUE, LATEX); | |
61 | + } | |
62 | + | |
63 | + private static ResultAccumulator countResults(Map<String, Set<File>> id2goldSummaryFiles, | |
64 | + Map<String, Set<File>> id2sysSummaryFiles) throws Exception { | |
56 | 65 | ResultAccumulator ra = new ResultAccumulator(); |
57 | 66 | int totalComparedTexts = 0; |
58 | 67 | |
... | ... | @@ -63,132 +72,27 @@ public class Main { |
63 | 72 | Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f)) |
64 | 73 | .collect(Collectors.toSet()); |
65 | 74 | if (goldSummaries.contains(null)) { |
66 | - LOG.error("Empty gold summary for id: " + id); | |
67 | - return; | |
75 | + throw new Exception("Empty gold summary for id: " + id); | |
68 | 76 | } |
69 | 77 | for (File sysSummaryFile : sysSummaryFiles) { |
70 | 78 | String systemSummary = loadSummaryFromFile(sysSummaryFile); |
71 | 79 | String systemId = getSystemIdFromFilename(sysSummaryFile); |
72 | 80 | if (systemSummary == null) { |
73 | - LOG.error("Empty system summary for id: " + id); | |
74 | - return; | |
81 | + throw new Exception("Empty system summary for id: " + id); | |
75 | 82 | } |
76 | 83 | for (int i : ROUGE_N) { |
77 | 84 | double score = RougeN.score(i, systemSummary, goldSummaries); |
78 | - ra.logScore("ROUGE_" + i, systemId, score); | |
85 | + ra.logScore("ROUGE_" + i, systemId, score, id); | |
79 | 86 | |
80 | 87 | double maxScore = RougeN.scoreMax(i, systemSummary, goldSummaries); |
81 | - ra.logScore("ROUGE_M_" + i, systemId, maxScore); | |
88 | + ra.logScore("ROUGE_M_" + i, systemId, maxScore, id); | |
82 | 89 | } |
83 | 90 | } |
84 | 91 | totalComparedTexts++; |
85 | 92 | } |
86 | 93 | |
87 | 94 | LOG.info(totalComparedTexts + " - total compared texts."); |
88 | - | |
89 | - checkNormality(ra); | |
90 | - printResults(ra, LATEX); | |
91 | - printSignificance(ra, LATEX); | |
92 | - } | |
93 | - | |
94 | - private static void checkNormality(ResultAccumulator ra) { | |
95 | - int normal = 0; | |
96 | - int all = 0; | |
97 | - for (String systemId : ra.getSystemIds()) | |
98 | - for (String metricId : ra.getMetricIds()) { | |
99 | - all++; | |
100 | - double[] standardized = StatUtils.normalize(ra.getSystemResults(systemId, metricId)); | |
101 | - double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1), | |
102 | - standardized); | |
103 | - if (p > P_VALUE) | |
104 | - normal++; | |
105 | - else | |
106 | - LOG.info("System " + systemId + " for " + metricId + " rejected K-S normality test with p-value = " | |
107 | - + p); | |
108 | - } | |
109 | - | |
110 | - LOG.info(normal + " normal of all " + all); | |
111 | - } | |
112 | - | |
113 | - private static void printResults(ResultAccumulator ra, boolean latex) { | |
114 | - System.out.println("###################### Results #############################"); | |
115 | - | |
116 | - if (latex) | |
117 | - System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); | |
118 | - | |
119 | - System.out.print("System"); | |
120 | - for (String metricId : ra.getMetricIds()) | |
121 | - System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); | |
122 | - if (latex) | |
123 | - System.out.print("\\\\\\hline\\hline"); | |
124 | - System.out.println(); | |
125 | - | |
126 | - for (String systemId : ra.getSystemIds()) { | |
127 | - System.out.print(systemId); | |
128 | - for (String metricId : ra.getMetricIds()) { | |
129 | - SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId); | |
130 | - | |
131 | - System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(), | |
132 | - r.getStandardDeviation())); | |
133 | - } | |
134 | - if (latex) | |
135 | - System.out.print("\\\\\\hline"); | |
136 | - System.out.println(); | |
137 | - } | |
138 | - if (latex) | |
139 | - System.out.println("\\end{tabular}"); | |
140 | - } | |
141 | - | |
142 | - private static void printSignificance(ResultAccumulator ra, boolean latex) { | |
143 | - System.out.println("###################### Stat. sign. #############################"); | |
144 | - int i = 1; | |
145 | - Map<String, Integer> sys2id = new HashMap<>(); | |
146 | - for (String systemId : ra.getSystemIds()) | |
147 | - sys2id.put(systemId, i++); | |
148 | - if (latex) | |
149 | - System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); | |
150 | - System.out.print("System"); | |
151 | - for (String metricId : ra.getMetricIds()) | |
152 | - System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); | |
153 | - if (latex) | |
154 | - System.out.print("\\\\\\hline\\hline"); | |
155 | - System.out.println(); | |
156 | - | |
157 | - for (String systemId : ra.getSystemIds()) { | |
158 | - System.out.print(sys2id.get(systemId) + ". " + systemId); | |
159 | - for (String metricId : ra.getMetricIds()) { | |
160 | - Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId); | |
161 | - List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted() | |
162 | - .collect(Collectors.toList()); | |
163 | - | |
164 | - System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids))); | |
165 | - } | |
166 | - if (latex) | |
167 | - System.out.print("\\\\\\hline"); | |
168 | - System.out.println(); | |
169 | - } | |
170 | - if (latex) | |
171 | - System.out.println("\\end{tabular}"); | |
172 | - } | |
173 | - | |
174 | - private static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId) { | |
175 | - TreeSet<String> signifWorseSystems = new TreeSet<>(); | |
176 | - for (String systemId2 : ra.getSystemIds()) | |
177 | - if (!systemId.equals(systemId2)) { | |
178 | - | |
179 | - SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId); | |
180 | - SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId); | |
181 | - | |
182 | - if (scores1.getMean() <= scores2.getMean()) | |
183 | - continue; | |
184 | - | |
185 | - double p = new TTest().tTest(scores1, scores2) / 2; | |
186 | - | |
187 | - if (p < P_VALUE) | |
188 | - signifWorseSystems.add(systemId2); | |
189 | - | |
190 | - } | |
191 | - return signifWorseSystems; | |
95 | + return ra; | |
192 | 96 | } |
193 | 97 | |
194 | 98 | private static String getSystemIdFromFilename(File sysSummaryFile) { |
... | ... |
src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java
... | ... | @@ -7,20 +7,19 @@ import java.util.Map; |
7 | 7 | import java.util.Set; |
8 | 8 | import java.util.stream.Collectors; |
9 | 9 | |
10 | -import org.apache.commons.lang3.ArrayUtils; | |
11 | 10 | import org.apache.commons.math3.stat.descriptive.SummaryStatistics; |
12 | 11 | |
13 | -import com.google.common.collect.Lists; | |
12 | +import com.google.common.collect.Maps; | |
14 | 13 | |
15 | 14 | public class ResultAccumulator { |
16 | 15 | |
17 | 16 | private Map<String, CumulativeResult> systemId2result = new HashMap<>(); |
18 | 17 | private Set<String> metricIds = new HashSet<>(); |
19 | 18 | |
20 | - public void logScore(String metricId, String systemId, double score) { | |
19 | + public void logScore(String metricId, String systemId, double score, String textId) { | |
21 | 20 | systemId2result.putIfAbsent(systemId, new CumulativeResult()); |
22 | 21 | CumulativeResult cr = systemId2result.get(systemId); |
23 | - cr.addScore(metricId, score); | |
22 | + cr.addScore(metricId, score, textId); | |
24 | 23 | metricIds.add(metricId); |
25 | 24 | } |
26 | 25 | |
... | ... | @@ -32,7 +31,7 @@ public class ResultAccumulator { |
32 | 31 | return metricIds.stream().sorted().collect(Collectors.toList()); |
33 | 32 | } |
34 | 33 | |
35 | - public double[] getSystemResults(String systemId, String metricId) { | |
34 | + public Map<String, Double> getSystemResults(String systemId, String metricId) { | |
36 | 35 | return systemId2result.get(systemId).getResults(metricId); |
37 | 36 | } |
38 | 37 | |
... | ... | @@ -42,17 +41,17 @@ public class ResultAccumulator { |
42 | 41 | |
43 | 42 | private class CumulativeResult { |
44 | 43 | Map<String, SummaryStatistics> metricId2scores = new HashMap<>(); |
45 | - Map<String, List<Double>> metricId2allScores = new HashMap<>(); | |
44 | + Map<String, Map<String, Double>> metricId2textId2scores = new HashMap<>(); | |
46 | 45 | |
47 | - public void addScore(String metricId, double score) { | |
46 | + public void addScore(String metricId, double score, String textId) { | |
48 | 47 | metricId2scores.putIfAbsent(metricId, new SummaryStatistics()); |
49 | 48 | metricId2scores.get(metricId).addValue(score); |
50 | - metricId2allScores.putIfAbsent(metricId, Lists.newArrayList()); | |
51 | - metricId2allScores.get(metricId).add(score); | |
49 | + metricId2textId2scores.putIfAbsent(metricId, Maps.newHashMap()); | |
50 | + metricId2textId2scores.get(metricId).put(textId, score); | |
52 | 51 | } |
53 | 52 | |
54 | - public double[] getResults(String metricId) { | |
55 | - return ArrayUtils.toPrimitive(metricId2allScores.get(metricId).toArray(new Double[0])); | |
53 | + public Map<String, Double> getResults(String metricId) { | |
54 | + return metricId2textId2scores.get(metricId); | |
56 | 55 | } |
57 | 56 | |
58 | 57 | public SummaryStatistics getSummaryStatistics(String metricId) { |
... | ... |
src/main/java/pl/waw/ipipan/zil/summ/eval/output/OutputHelper.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.eval.output; | |
2 | + | |
3 | +import java.util.HashMap; | |
4 | +import java.util.List; | |
5 | +import java.util.Map; | |
6 | +import java.util.Set; | |
7 | +import java.util.TreeSet; | |
8 | +import java.util.stream.Collectors; | |
9 | + | |
10 | +import org.apache.commons.lang3.ArrayUtils; | |
11 | +import org.apache.commons.math3.distribution.NormalDistribution; | |
12 | +import org.apache.commons.math3.stat.StatUtils; | |
13 | +import org.apache.commons.math3.stat.descriptive.SummaryStatistics; | |
14 | +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; | |
15 | +import org.apache.commons.math3.stat.inference.TTest; | |
16 | + | |
17 | +import pl.waw.ipipan.zil.summ.eval.ResultAccumulator; | |
18 | + | |
19 | +public class OutputHelper { | |
20 | + | |
21 | + public static void checkNormality(ResultAccumulator ra, double pValue) { | |
22 | + int normal = 0; | |
23 | + int all = 0; | |
24 | + for (String systemId : ra.getSystemIds()) | |
25 | + for (String metricId : ra.getMetricIds()) { | |
26 | + all++; | |
27 | + Map<String, Double> systemResults = ra.getSystemResults(systemId, metricId); | |
28 | + double[] values = ArrayUtils.toPrimitive(systemResults.values().toArray(new Double[0])); | |
29 | + double[] standardized = StatUtils.normalize(values); | |
30 | + if (standardized.length < 2) { | |
31 | + System.out.println("Skipping normality check because at least 2 samples are required."); | |
32 | + return; | |
33 | + } | |
34 | + double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1), | |
35 | + standardized); | |
36 | + if (p > pValue) | |
37 | + normal++; | |
38 | + else | |
39 | + System.out.println("System " + systemId + " for " + metricId | |
40 | + + " rejected K-S normality test with p-value = " + p); | |
41 | + } | |
42 | + | |
43 | + System.out.println(normal + " normal of all " + all); | |
44 | + } | |
45 | + | |
46 | + public static void printResults(ResultAccumulator ra, boolean latex) { | |
47 | + System.out.println("###################### Results #############################"); | |
48 | + | |
49 | + if (latex) | |
50 | + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); | |
51 | + | |
52 | + System.out.print("System"); | |
53 | + for (String metricId : ra.getMetricIds()) | |
54 | + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); | |
55 | + if (latex) | |
56 | + System.out.print("\\\\\\hline\\hline"); | |
57 | + System.out.println(); | |
58 | + | |
59 | + for (String systemId : ra.getSystemIds()) { | |
60 | + System.out.print(systemId); | |
61 | + for (String metricId : ra.getMetricIds()) { | |
62 | + SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId); | |
63 | + | |
64 | + System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(), | |
65 | + r.getStandardDeviation())); | |
66 | + } | |
67 | + if (latex) | |
68 | + System.out.print("\\\\\\hline"); | |
69 | + System.out.println(); | |
70 | + } | |
71 | + if (latex) | |
72 | + System.out.println("\\end{tabular}"); | |
73 | + } | |
74 | + | |
75 | + public static void printSignificance(ResultAccumulator ra, double pValue, boolean latex) { | |
76 | + if (ra.getSystemResults(ra.getSystemIds().iterator().next(), ra.getMetricIds().iterator().next()).values() | |
77 | + .size() < 2) { | |
78 | + System.out.println("Skipping significant difference check because at least 2 samples are required."); | |
79 | + return; | |
80 | + } | |
81 | + | |
82 | + System.out.println("###################### Stat. sign. #############################"); | |
83 | + int i = 1; | |
84 | + Map<String, Integer> sys2id = new HashMap<>(); | |
85 | + for (String systemId : ra.getSystemIds()) | |
86 | + sys2id.put(systemId, i++); | |
87 | + if (latex) | |
88 | + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); | |
89 | + System.out.print("System"); | |
90 | + for (String metricId : ra.getMetricIds()) | |
91 | + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); | |
92 | + if (latex) | |
93 | + System.out.print("\\\\\\hline\\hline"); | |
94 | + System.out.println(); | |
95 | + | |
96 | + for (String systemId : ra.getSystemIds()) { | |
97 | + System.out.print(sys2id.get(systemId) + ". " + systemId); | |
98 | + | |
99 | + for (String metricId : ra.getMetricIds()) { | |
100 | + Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId, pValue); | |
101 | + List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted() | |
102 | + .collect(Collectors.toList()); | |
103 | + | |
104 | + System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids))); | |
105 | + } | |
106 | + if (latex) | |
107 | + System.out.print("\\\\\\hline"); | |
108 | + System.out.println(); | |
109 | + } | |
110 | + if (latex) | |
111 | + System.out.println("\\end{tabular}"); | |
112 | + } | |
113 | + | |
114 | + public static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId, | |
115 | + double pValue) { | |
116 | + TreeSet<String> signifWorseSystems = new TreeSet<>(); | |
117 | + for (String systemId2 : ra.getSystemIds()) | |
118 | + if (!systemId.equals(systemId2)) { | |
119 | + | |
120 | + SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId); | |
121 | + SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId); | |
122 | + | |
123 | + if (scores1.getMean() <= scores2.getMean()) | |
124 | + continue; | |
125 | + | |
126 | + double p = new TTest().tTest(scores1, scores2) / 2; | |
127 | + | |
128 | + if (p < pValue) | |
129 | + signifWorseSystems.add(systemId2); | |
130 | + } | |
131 | + return signifWorseSystems; | |
132 | + } | |
133 | + | |
134 | +} | |
... | ... |