diff --git a/src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java b/src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java index da976ef..4ac8233 100644 --- a/src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java +++ b/src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java @@ -6,20 +6,14 @@ import java.nio.file.Files; import java.nio.file.Paths; import java.util.HashMap; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Set; -import java.util.TreeSet; import java.util.stream.Collectors; -import org.apache.commons.math3.distribution.NormalDistribution; -import org.apache.commons.math3.stat.StatUtils; -import org.apache.commons.math3.stat.descriptive.SummaryStatistics; -import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; -import org.apache.commons.math3.stat.inference.TTest; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import pl.waw.ipipan.zil.summ.eval.output.OutputHelper; import pl.waw.ipipan.zil.summ.eval.rouge.RougeN; public class Main { @@ -33,7 +27,7 @@ public class Main { public static void main(String[] args) { if (args.length != 2) { - LOG.error("Wrong number of arguments!"); + LOG.error("Wrong number of arguments! Try: Main goldDir sysDir"); return; } @@ -53,6 +47,21 @@ public class Main { } LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded."); + ResultAccumulator ra; + try { + ra = countResults(id2goldSummaryFiles, id2sysSummaryFiles); + } catch (Exception e) { + LOG.error("Error counting results: " + e); + return; + } + + OutputHelper.checkNormality(ra, P_VALUE); + OutputHelper.printResults(ra, LATEX); + OutputHelper.printSignificance(ra, P_VALUE, LATEX); + } + + private static ResultAccumulator countResults(Map<String, Set<File>> id2goldSummaryFiles, + Map<String, Set<File>> id2sysSummaryFiles) throws Exception { ResultAccumulator ra = new ResultAccumulator(); int totalComparedTexts = 0; @@ -63,132 +72,27 @@ public class Main { Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f)) .collect(Collectors.toSet()); if (goldSummaries.contains(null)) { - LOG.error("Empty gold summary for id: " + id); - return; + throw new Exception("Empty gold summary for id: " + id); } for (File sysSummaryFile : sysSummaryFiles) { String systemSummary = loadSummaryFromFile(sysSummaryFile); String systemId = getSystemIdFromFilename(sysSummaryFile); if (systemSummary == null) { - LOG.error("Empty system summary for id: " + id); - return; + throw new Exception("Empty system summary for id: " + id); } for (int i : ROUGE_N) { double score = RougeN.score(i, systemSummary, goldSummaries); - ra.logScore("ROUGE_" + i, systemId, score); + ra.logScore("ROUGE_" + i, systemId, score, id); double maxScore = RougeN.scoreMax(i, systemSummary, goldSummaries); - ra.logScore("ROUGE_M_" + i, systemId, maxScore); + ra.logScore("ROUGE_M_" + i, systemId, maxScore, id); } } totalComparedTexts++; } LOG.info(totalComparedTexts + " - total compared texts."); - - checkNormality(ra); - printResults(ra, LATEX); - printSignificance(ra, LATEX); - } - - private static void checkNormality(ResultAccumulator ra) { - int normal = 0; - int all = 0; - for (String systemId : ra.getSystemIds()) - for (String metricId : ra.getMetricIds()) { - all++; - double[] standardized = StatUtils.normalize(ra.getSystemResults(systemId, metricId)); - double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1), - standardized); - if (p > P_VALUE) - normal++; - else - LOG.info("System " + systemId + " for " + metricId + " rejected K-S normality test with p-value = " - + p); - } - - LOG.info(normal + " normal of all " + all); - } - - private static void printResults(ResultAccumulator ra, boolean latex) { - System.out.println("###################### Results #############################"); - - if (latex) - System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); - - System.out.print("System"); - for (String metricId : ra.getMetricIds()) - System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); - if (latex) - System.out.print("\\\\\\hline\\hline"); - System.out.println(); - - for (String systemId : ra.getSystemIds()) { - System.out.print(systemId); - for (String metricId : ra.getMetricIds()) { - SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId); - - System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(), - r.getStandardDeviation())); - } - if (latex) - System.out.print("\\\\\\hline"); - System.out.println(); - } - if (latex) - System.out.println("\\end{tabular}"); - } - - private static void printSignificance(ResultAccumulator ra, boolean latex) { - System.out.println("###################### Stat. sign. #############################"); - int i = 1; - Map<String, Integer> sys2id = new HashMap<>(); - for (String systemId : ra.getSystemIds()) - sys2id.put(systemId, i++); - if (latex) - System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); - System.out.print("System"); - for (String metricId : ra.getMetricIds()) - System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); - if (latex) - System.out.print("\\\\\\hline\\hline"); - System.out.println(); - - for (String systemId : ra.getSystemIds()) { - System.out.print(sys2id.get(systemId) + ". " + systemId); - for (String metricId : ra.getMetricIds()) { - Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId); - List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted() - .collect(Collectors.toList()); - - System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids))); - } - if (latex) - System.out.print("\\\\\\hline"); - System.out.println(); - } - if (latex) - System.out.println("\\end{tabular}"); - } - - private static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId) { - TreeSet<String> signifWorseSystems = new TreeSet<>(); - for (String systemId2 : ra.getSystemIds()) - if (!systemId.equals(systemId2)) { - - SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId); - SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId); - - if (scores1.getMean() <= scores2.getMean()) - continue; - - double p = new TTest().tTest(scores1, scores2) / 2; - - if (p < P_VALUE) - signifWorseSystems.add(systemId2); - - } - return signifWorseSystems; + return ra; } private static String getSystemIdFromFilename(File sysSummaryFile) { diff --git a/src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java b/src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java index 0b37628..08d71b1 100644 --- a/src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java +++ b/src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java @@ -7,20 +7,19 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.math3.stat.descriptive.SummaryStatistics; -import com.google.common.collect.Lists; +import com.google.common.collect.Maps; public class ResultAccumulator { private Map<String, CumulativeResult> systemId2result = new HashMap<>(); private Set<String> metricIds = new HashSet<>(); - public void logScore(String metricId, String systemId, double score) { + public void logScore(String metricId, String systemId, double score, String textId) { systemId2result.putIfAbsent(systemId, new CumulativeResult()); CumulativeResult cr = systemId2result.get(systemId); - cr.addScore(metricId, score); + cr.addScore(metricId, score, textId); metricIds.add(metricId); } @@ -32,7 +31,7 @@ public class ResultAccumulator { return metricIds.stream().sorted().collect(Collectors.toList()); } - public double[] getSystemResults(String systemId, String metricId) { + public Map<String, Double> getSystemResults(String systemId, String metricId) { return systemId2result.get(systemId).getResults(metricId); } @@ -42,17 +41,17 @@ public class ResultAccumulator { private class CumulativeResult { Map<String, SummaryStatistics> metricId2scores = new HashMap<>(); - Map<String, List<Double>> metricId2allScores = new HashMap<>(); + Map<String, Map<String, Double>> metricId2textId2scores = new HashMap<>(); - public void addScore(String metricId, double score) { + public void addScore(String metricId, double score, String textId) { metricId2scores.putIfAbsent(metricId, new SummaryStatistics()); metricId2scores.get(metricId).addValue(score); - metricId2allScores.putIfAbsent(metricId, Lists.newArrayList()); - metricId2allScores.get(metricId).add(score); + metricId2textId2scores.putIfAbsent(metricId, Maps.newHashMap()); + metricId2textId2scores.get(metricId).put(textId, score); } - public double[] getResults(String metricId) { - return ArrayUtils.toPrimitive(metricId2allScores.get(metricId).toArray(new Double[0])); + public Map<String, Double> getResults(String metricId) { + return metricId2textId2scores.get(metricId); } public SummaryStatistics getSummaryStatistics(String metricId) { diff --git a/src/main/java/pl/waw/ipipan/zil/summ/eval/output/OutputHelper.java b/src/main/java/pl/waw/ipipan/zil/summ/eval/output/OutputHelper.java new file mode 100644 index 0000000..726af9d --- /dev/null +++ b/src/main/java/pl/waw/ipipan/zil/summ/eval/output/OutputHelper.java @@ -0,0 +1,134 @@ +package pl.waw.ipipan.zil.summ.eval.output; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeSet; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.math3.distribution.NormalDistribution; +import org.apache.commons.math3.stat.StatUtils; +import org.apache.commons.math3.stat.descriptive.SummaryStatistics; +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; +import org.apache.commons.math3.stat.inference.TTest; + +import pl.waw.ipipan.zil.summ.eval.ResultAccumulator; + +public class OutputHelper { + + public static void checkNormality(ResultAccumulator ra, double pValue) { + int normal = 0; + int all = 0; + for (String systemId : ra.getSystemIds()) + for (String metricId : ra.getMetricIds()) { + all++; + Map<String, Double> systemResults = ra.getSystemResults(systemId, metricId); + double[] values = ArrayUtils.toPrimitive(systemResults.values().toArray(new Double[0])); + double[] standardized = StatUtils.normalize(values); + if (standardized.length < 2) { + System.out.println("Skipping normality check because at least 2 samples are required."); + return; + } + double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1), + standardized); + if (p > pValue) + normal++; + else + System.out.println("System " + systemId + " for " + metricId + + " rejected K-S normality test with p-value = " + p); + } + + System.out.println(normal + " normal of all " + all); + } + + public static void printResults(ResultAccumulator ra, boolean latex) { + System.out.println("###################### Results #############################"); + + if (latex) + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); + + System.out.print("System"); + for (String metricId : ra.getMetricIds()) + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); + if (latex) + System.out.print("\\\\\\hline\\hline"); + System.out.println(); + + for (String systemId : ra.getSystemIds()) { + System.out.print(systemId); + for (String metricId : ra.getMetricIds()) { + SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId); + + System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(), + r.getStandardDeviation())); + } + if (latex) + System.out.print("\\\\\\hline"); + System.out.println(); + } + if (latex) + System.out.println("\\end{tabular}"); + } + + public static void printSignificance(ResultAccumulator ra, double pValue, boolean latex) { + if (ra.getSystemResults(ra.getSystemIds().iterator().next(), ra.getMetricIds().iterator().next()).values() + .size() < 2) { + System.out.println("Skipping significant difference check because at least 2 samples are required."); + return; + } + + System.out.println("###################### Stat. sign. #############################"); + int i = 1; + Map<String, Integer> sys2id = new HashMap<>(); + for (String systemId : ra.getSystemIds()) + sys2id.put(systemId, i++); + if (latex) + System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}"); + System.out.print("System"); + for (String metricId : ra.getMetricIds()) + System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-")); + if (latex) + System.out.print("\\\\\\hline\\hline"); + System.out.println(); + + for (String systemId : ra.getSystemIds()) { + System.out.print(sys2id.get(systemId) + ". " + systemId); + + for (String metricId : ra.getMetricIds()) { + Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId, pValue); + List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted() + .collect(Collectors.toList()); + + System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids))); + } + if (latex) + System.out.print("\\\\\\hline"); + System.out.println(); + } + if (latex) + System.out.println("\\end{tabular}"); + } + + public static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId, + double pValue) { + TreeSet<String> signifWorseSystems = new TreeSet<>(); + for (String systemId2 : ra.getSystemIds()) + if (!systemId.equals(systemId2)) { + + SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId); + SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId); + + if (scores1.getMean() <= scores2.getMean()) + continue; + + double p = new TTest().tTest(scores1, scores2) / 2; + + if (p < pValue) + signifWorseSystems.add(systemId2); + } + return signifWorseSystems; + } + +}