From c6d4dcbe7d1026f1268c05672b892611a729a291 Mon Sep 17 00:00:00 2001
From: Mateusz Kopeć <m.kopec@ipipan.waw.pl>
Date: Sat, 24 Oct 2015 09:39:43 +0200
Subject: [PATCH] Move result output into OutputHelper, key scores by text id, guard small samples

---
 src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java                | 140 ++++++++++++++++++++++----------------------------------------------------------------------------------------------------------------------
 src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java   |  21 ++++++++++-----------
 src/main/java/pl/waw/ipipan/zil/summ/eval/output/OutputHelper.java | 134 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 166 insertions(+), 129 deletions(-)
 create mode 100644 src/main/java/pl/waw/ipipan/zil/summ/eval/output/OutputHelper.java

diff --git a/src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java b/src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
index da976ef..4ac8233 100644
--- a/src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
+++ b/src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
@@ -6,20 +6,14 @@ import java.nio.file.Files;
 import java.nio.file.Paths;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.TreeSet;
 import java.util.stream.Collectors;
 
-import org.apache.commons.math3.distribution.NormalDistribution;
-import org.apache.commons.math3.stat.StatUtils;
-import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
-import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
-import org.apache.commons.math3.stat.inference.TTest;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 
+import pl.waw.ipipan.zil.summ.eval.output.OutputHelper;
 import pl.waw.ipipan.zil.summ.eval.rouge.RougeN;
 
 public class Main {
@@ -33,7 +27,7 @@ public class Main {
 
 	public static void main(String[] args) {
 		if (args.length != 2) {
-			LOG.error("Wrong number of arguments!");
+			LOG.error("Wrong number of arguments! Try: Main goldDir sysDir");
 			return;
 		}
 
@@ -53,6 +47,21 @@ public class Main {
 		}
 		LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded.");
 
+		ResultAccumulator ra;
+		try {
+			ra = countResults(id2goldSummaryFiles, id2sysSummaryFiles);
+		} catch (Exception e) {
+			LOG.error("Error counting results: " + e);
+			return;
+		}
+
+		OutputHelper.checkNormality(ra, P_VALUE);
+		OutputHelper.printResults(ra, LATEX);
+		OutputHelper.printSignificance(ra, P_VALUE, LATEX);
+	}
+
+	private static ResultAccumulator countResults(Map<String, Set<File>> id2goldSummaryFiles,
+			Map<String, Set<File>> id2sysSummaryFiles) throws Exception {
 		ResultAccumulator ra = new ResultAccumulator();
 		int totalComparedTexts = 0;
 
@@ -63,132 +72,27 @@ public class Main {
 			Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f))
 					.collect(Collectors.toSet());
 			if (goldSummaries.contains(null)) {
-				LOG.error("Empty gold summary for id: " + id);
-				return;
+				throw new Exception("Empty gold summary for id: " + id);
 			}
 			for (File sysSummaryFile : sysSummaryFiles) {
 				String systemSummary = loadSummaryFromFile(sysSummaryFile);
 				String systemId = getSystemIdFromFilename(sysSummaryFile);
 				if (systemSummary == null) {
-					LOG.error("Empty system summary for id: " + id);
-					return;
+					throw new Exception("Empty system summary for id: " + id);
 				}
 				for (int i : ROUGE_N) {
 					double score = RougeN.score(i, systemSummary, goldSummaries);
-					ra.logScore("ROUGE_" + i, systemId, score);
+					ra.logScore("ROUGE_" + i, systemId, score, id);
 
 					double maxScore = RougeN.scoreMax(i, systemSummary, goldSummaries);
-					ra.logScore("ROUGE_M_" + i, systemId, maxScore);
+					ra.logScore("ROUGE_M_" + i, systemId, maxScore, id);
 				}
 			}
 			totalComparedTexts++;
 		}
 
 		LOG.info(totalComparedTexts + " - total compared texts.");
-
-		checkNormality(ra);
-		printResults(ra, LATEX);
-		printSignificance(ra, LATEX);
-	}
-
-	private static void checkNormality(ResultAccumulator ra) {
-		int normal = 0;
-		int all = 0;
-		for (String systemId : ra.getSystemIds())
-			for (String metricId : ra.getMetricIds()) {
-				all++;
-				double[] standardized = StatUtils.normalize(ra.getSystemResults(systemId, metricId));
-				double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1),
-						standardized);
-				if (p > P_VALUE)
-					normal++;
-				else
-					LOG.info("System " + systemId + " for " + metricId + " rejected K-S normality test with p-value = "
-							+ p);
-			}
-
-		LOG.info(normal + " normal of all " + all);
-	}
-
-	private static void printResults(ResultAccumulator ra, boolean latex) {
-		System.out.println("###################### Results #############################");
-
-		if (latex)
-			System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}");
-
-		System.out.print("System");
-		for (String metricId : ra.getMetricIds())
-			System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-"));
-		if (latex)
-			System.out.print("\\\\\\hline\\hline");
-		System.out.println();
-
-		for (String systemId : ra.getSystemIds()) {
-			System.out.print(systemId);
-			for (String metricId : ra.getMetricIds()) {
-				SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId);
-
-				System.out.print(String.format("\t" + (latex ? "&" : "") + " %.3f (%.3f)", r.getMean(),
-						r.getStandardDeviation()));
-			}
-			if (latex)
-				System.out.print("\\\\\\hline");
-			System.out.println();
-		}
-		if (latex)
-			System.out.println("\\end{tabular}");
-	}
-
-	private static void printSignificance(ResultAccumulator ra, boolean latex) {
-		System.out.println("###################### Stat. sign. #############################");
-		int i = 1;
-		Map<String, Integer> sys2id = new HashMap<>();
-		for (String systemId : ra.getSystemIds())
-			sys2id.put(systemId, i++);
-		if (latex)
-			System.out.println("\\begin{tabular}{l|l|l|l|l|l|l}");
-		System.out.print("System");
-		for (String metricId : ra.getMetricIds())
-			System.out.print("\t" + (latex ? "&" : "") + metricId.replaceAll("_", "-"));
-		if (latex)
-			System.out.print("\\\\\\hline\\hline");
-		System.out.println();
-
-		for (String systemId : ra.getSystemIds()) {
-			System.out.print(sys2id.get(systemId) + ". " + systemId);
-			for (String metricId : ra.getMetricIds()) {
-				Set<String> signifWorseSystems = findSignifWorseSystems(ra, systemId, metricId);
-				List<String> ids = signifWorseSystems.stream().map(s -> sys2id.get(s) + "").sorted()
-						.collect(Collectors.toList());
-
-				System.out.print(String.format("\t" + (latex ? "&" : "") + " %s", String.join(", ", ids)));
-			}
-			if (latex)
-				System.out.print("\\\\\\hline");
-			System.out.println();
-		}
-		if (latex)
-			System.out.println("\\end{tabular}");
-	}
-
-	private static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId) {
-		TreeSet<String> signifWorseSystems = new TreeSet<>();
-		for (String systemId2 : ra.getSystemIds())
-			if (!systemId.equals(systemId2)) {
-
-				SummaryStatistics scores1 = ra.getSystemSummaryStatistics(systemId, metricId);
-				SummaryStatistics scores2 = ra.getSystemSummaryStatistics(systemId2, metricId);
-
-				if (scores1.getMean() <= scores2.getMean())
-					continue;
-
-				double p = new TTest().tTest(scores1, scores2) / 2;
-
-				if (p < P_VALUE)
-					signifWorseSystems.add(systemId2);
-
-			}
-		return signifWorseSystems;
+		return ra;
 	}
 
 	private static String getSystemIdFromFilename(File sysSummaryFile) {
diff --git a/src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java b/src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java
index 0b37628..08d71b1 100644
--- a/src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java
+++ b/src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java
@@ -7,20 +7,19 @@ import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
 
-import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
 
-import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
 
 public class ResultAccumulator {
 
 	private Map<String, CumulativeResult> systemId2result = new HashMap<>();
 	private Set<String> metricIds = new HashSet<>();
 
-	public void logScore(String metricId, String systemId, double score) {
+	public void logScore(String metricId, String systemId, double score, String textId) {
 		systemId2result.putIfAbsent(systemId, new CumulativeResult());
 		CumulativeResult cr = systemId2result.get(systemId);
-		cr.addScore(metricId, score);
+		cr.addScore(metricId, score, textId);
 		metricIds.add(metricId);
 	}
 
@@ -32,7 +31,7 @@ public class ResultAccumulator {
 		return metricIds.stream().sorted().collect(Collectors.toList());
 	}
 
-	public double[] getSystemResults(String systemId, String metricId) {
+	public Map<String, Double> getSystemResults(String systemId, String metricId) {
 		return systemId2result.get(systemId).getResults(metricId);
 	}
 
@@ -42,17 +41,17 @@ public class ResultAccumulator {
 
 	private class CumulativeResult {
 		Map<String, SummaryStatistics> metricId2scores = new HashMap<>();
-		Map<String, List<Double>> metricId2allScores = new HashMap<>();
+		Map<String, Map<String, Double>> metricId2textId2scores = new HashMap<>();
 
-		public void addScore(String metricId, double score) {
+		public void addScore(String metricId, double score, String textId) {
 			metricId2scores.putIfAbsent(metricId, new SummaryStatistics());
 			metricId2scores.get(metricId).addValue(score);
-			metricId2allScores.putIfAbsent(metricId, Lists.newArrayList());
-			metricId2allScores.get(metricId).add(score);
+			metricId2textId2scores.putIfAbsent(metricId, Maps.newHashMap());
+			metricId2textId2scores.get(metricId).put(textId, score);
 		}
 
-		public double[] getResults(String metricId) {
-			return ArrayUtils.toPrimitive(metricId2allScores.get(metricId).toArray(new Double[0]));
+		public Map<String, Double> getResults(String metricId) {
+			return metricId2textId2scores.get(metricId);
 		}
 
 		public SummaryStatistics getSummaryStatistics(String metricId) {
diff --git a/src/main/java/pl/waw/ipipan/zil/summ/eval/output/OutputHelper.java b/src/main/java/pl/waw/ipipan/zil/summ/eval/output/OutputHelper.java
new file mode 100644
index 0000000..726af9d
--- /dev/null
+++ b/src/main/java/pl/waw/ipipan/zil/summ/eval/output/OutputHelper.java
@@ -0,0 +1,134 @@
+package pl.waw.ipipan.zil.summ.eval.output;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.stream.Collectors;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.math3.distribution.NormalDistribution;
+import org.apache.commons.math3.stat.StatUtils;
+import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
+import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
+import org.apache.commons.math3.stat.inference.TTest;
+
+import pl.waw.ipipan.zil.summ.eval.ResultAccumulator;
+
+public class OutputHelper {
+
+	/** K-S tests every system/metric score sample for normality; samples with fewer than 2 values are skipped. */
+	public static void checkNormality(ResultAccumulator ra, double pValue) {
+		int normal = 0;
+		int all = 0;
+		for (String systemId : ra.getSystemIds())
+			for (String metricId : ra.getMetricIds()) {
+				Map<String, Double> systemResults = ra.getSystemResults(systemId, metricId);
+				double[] values = ArrayUtils.toPrimitive(systemResults.values().toArray(new Double[0]));
+				if (values.length < 2) {
+					System.out.println("Skipping normality check of system " + systemId + " for " + metricId
+							+ " because at least 2 samples are required.");
+					continue; // skip only this sample; a bare return also dropped the rest and the summary
+				}
+				all++; // count only the samples that are actually tested
+				double p = new KolmogorovSmirnovTest().kolmogorovSmirnovTest(new NormalDistribution(0, 1),
+						StatUtils.normalize(values));
+				if (p > pValue)
+					normal++;
+				else
+					System.out.println("System " + systemId + " for " + metricId
+							+ " rejected K-S normality test with p-value = " + p);
+			}
+		System.out.println(normal + " normal of all " + all);
+	}
+
+	/** Prints mean and (standard deviation) of each metric for each system, tab-separated or as a LaTeX tabular. */
+	public static void printResults(ResultAccumulator ra, boolean latex) {
+		System.out.println("###################### Results #############################");
+
+		String cellPrefix = "\t" + (latex ? "&" : "");
+		// One "l" column for the system name plus one per metric, so the tabular
+		// stays valid for any number of metrics (the spec used to be fixed at 6).
+		StringBuilder columnSpec = new StringBuilder("l");
+		StringBuilder header = new StringBuilder("System");
+		for (String metricId : ra.getMetricIds()) {
+			columnSpec.append("|l");
+			header.append(cellPrefix).append(metricId.replaceAll("_", "-"));
+		}
+		if (latex)
+			System.out.println("\\begin{tabular}{" + columnSpec + "}");
+		System.out.println(header + (latex ? "\\\\\\hline\\hline" : ""));
+
+		for (String systemId : ra.getSystemIds()) {
+			System.out.print(systemId);
+			for (String metricId : ra.getMetricIds()) {
+				SummaryStatistics r = ra.getSystemSummaryStatistics(systemId, metricId);
+				System.out.print(cellPrefix + String.format(" %.3f (%.3f)", r.getMean(), r.getStandardDeviation()));
+			}
+			System.out.println(latex ? "\\\\\\hline" : "");
+		}
+		if (latex)
+			System.out.println("\\end{tabular}");
+	}
+
+	/** Prints, per system and metric, the numbers of the systems it beats significantly at level {@code pValue}. */
+	public static void printSignificance(ResultAccumulator ra, double pValue, boolean latex) {
+		// Check for emptiness first: next() on an empty iterator throws NoSuchElementException.
+		boolean enoughSamples = ra.getSystemIds().iterator().hasNext() && ra.getMetricIds().iterator().hasNext()
+				&& ra.getSystemResults(ra.getSystemIds().iterator().next(), ra.getMetricIds().iterator().next())
+						.size() >= 2;
+		if (!enoughSamples) {
+			System.out.println("Skipping significant difference check because at least 2 samples are required.");
+			return;
+		}
+		System.out.println("###################### Stat. sign. #############################");
+		int i = 1;
+		Map<String, Integer> sys2id = new HashMap<>();
+		for (String systemId : ra.getSystemIds())
+			sys2id.put(systemId, i++);
+		String cellPrefix = "\t" + (latex ? "&" : "");
+		StringBuilder columnSpec = new StringBuilder("l"); // system column, then one "|l" per metric
+		StringBuilder header = new StringBuilder("System");
+		for (String metricId : ra.getMetricIds()) {
+			columnSpec.append("|l");
+			header.append(cellPrefix).append(metricId.replaceAll("_", "-"));
+		}
+		if (latex)
+			System.out.println("\\begin{tabular}{" + columnSpec + "}");
+		System.out.println(header + (latex ? "\\\\\\hline\\hline" : ""));
+		for (String systemId : ra.getSystemIds()) {
+			System.out.print(sys2id.get(systemId) + ". " + systemId);
+			for (String metricId : ra.getMetricIds()) {
+				// Sort the system numbers numerically; sorting them as strings put "10" before "2".
+				List<String> ids = findSignifWorseSystems(ra, systemId, metricId, pValue).stream().map(sys2id::get)
+						.sorted().map(String::valueOf).collect(Collectors.toList());
+				System.out.print(cellPrefix + " " + String.join(", ", ids));
+			}
+			System.out.println(latex ? "\\\\\\hline" : "");
+		}
+		if (latex)
+			System.out.println("\\end{tabular}");
+	}
+
+	/** Returns the systems whose mean {@code metricId} score is significantly (one-sided) below that of {@code systemId}. */
+	public static Set<String> findSignifWorseSystems(ResultAccumulator ra, String systemId, String metricId,
+			double pValue) {
+		Set<String> worse = new TreeSet<>();
+		for (String otherId : ra.getSystemIds()) {
+			if (systemId.equals(otherId))
+				continue;
+			SummaryStatistics own = ra.getSystemSummaryStatistics(systemId, metricId);
+			SummaryStatistics other = ra.getSystemSummaryStatistics(otherId, metricId);
+			// Only a system with a strictly lower mean can be reported as worse.
+			if (own.getMean() <= other.getMean())
+				continue;
+			// tTest() yields a two-sided p-value; halve it for the one-sided comparison.
+			double oneSidedP = new TTest().tTest(own, other) / 2;
+			if (oneSidedP < pValue)
+				worse.add(otherId);
+		}
+		return worse;
+	}
+
+}
--
libgit2 0.22.2