Commit 595d6aa646dc3dde01d92d775e7d465b8709f7a6

Authored by Mateusz Kopeć
1 parent 2692e607

calculating mean and sd scores

... ... @@ -34,8 +34,13 @@
34 34 <artifactId>guava</artifactId>
35 35 <version>18.0</version>
36 36 </dependency>
  37 + <dependency>
  38 + <groupId>org.apache.commons</groupId>
  39 + <artifactId>commons-math3</artifactId>
  40 + <version>3.5</version>
  41 + </dependency>
37 42 </dependencies>
38   -
  43 +
39 44 <build>
40 45 <plugins>
41 46 <plugin>
... ...
src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
... ... @@ -16,8 +16,11 @@ import org.apache.logging.log4j.Logger;
16 16 import pl.waw.ipipan.zil.summ.eval.rouge.RougeN;
17 17  
18 18 public class Main {
  19 +
19 20 private static final Logger LOG = LogManager.getLogger(Main.class);
20 21  
  22 + private static final int[] ROUGE_N = new int[] { 1, 2, 3 };
  23 +
21 24 public static void main(String[] args) {
22 25 if (args.length != 2) {
23 26 LOG.error("Wrong number of arguments!");
... ... @@ -35,31 +38,57 @@ public class Main {
35 38 Map<String, Set<File>> id2sysSummaryFiles = loadFiles(sysDir);
36 39 if (!id2goldSummaryFiles.keySet().equals(id2sysSummaryFiles.keySet())) {
37 40 LOG.error("Different text ids in gold and sys dirs!");
  41 + LOG.error(id2goldSummaryFiles.size() + " gold ids, " + id2sysSummaryFiles.size() + " sys ids.");
38 42 return;
39 43 }
40 44 LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded.");
41 45  
  46 + ResultAccumulator ra = new ResultAccumulator();
  47 + int totalGoldSummaries = 0;
  48 + int totalComparedTexts = 0;
  49 +
42 50 for (String id : id2goldSummaryFiles.keySet()) {
43 51 Set<File> goldSummaryFiles = id2goldSummaryFiles.get(id);
44 52 Set<File> sysSummaryFiles = id2sysSummaryFiles.get(id);
45 53  
  54 + Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f))
  55 + .collect(Collectors.toSet());
  56 + if (goldSummaries.contains(null)) {
  57 + LOG.error("Empty gold summary for id: " + id);
  58 + return;
  59 + }
46 60 for (File sysSummaryFile : sysSummaryFiles) {
47 61 String systemSummary = loadSummaryFromFile(sysSummaryFile);
48 62 String systemId = getSystemIdFromFilename(sysSummaryFile);
49   - Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f))
50   - .collect(Collectors.toSet());
51   - if (goldSummaries.contains(null) || systemSummary == null) {
  63 + if (systemSummary == null) {
  64 + LOG.error("Empty system summary for id: " + id);
52 65 return;
53 66 }
54   -
55   - for (int i = 1; i < 4; i++) {
  67 + for (int i : ROUGE_N) {
56 68 double score = RougeN.score(i, systemSummary, goldSummaries);
57   - System.out.println(i + "\t" + systemId + "\t" + score);
58   - System.out.println();
  69 + ra.logScore("ROUGE_" + i, systemId, score);
59 70 }
60 71 }
  72 +
  73 + totalGoldSummaries += goldSummaries.size();
  74 + totalComparedTexts++;
61 75 }
62 76  
  77 + LOG.info(totalComparedTexts + " - total compared texts.");
  78 + LOG.info(totalGoldSummaries * 1.0 / totalComparedTexts + " - average gold summaries per text.");
  79 +
  80 + printResults(ra);
  81 + }
  82 +
  83 + private static void printResults(ResultAccumulator ra) {
  84 + for (String systemId : ra.getSystemIds()) {
  85 + System.out.println("############## " + systemId);
  86 + for (String metricId : ra.getMetricIds()) {
  87 + Result r = ra.getSystemResult(systemId, metricId);
  88 + System.out.println(String.format("\t%s\t%.2f\t(+/- %.2f)", metricId, r.getMean(), r.getSD()));
  89 + }
  90 + System.out.println();
  91 + }
63 92 }
64 93  
65 94 private static String getSystemIdFromFilename(File sysSummaryFile) {
... ...
src/main/java/pl/waw/ipipan/zil/summ/eval/Result.java 0 → 100644
  1 +package pl.waw.ipipan.zil.summ.eval;
  2 +
  3 +public class Result {
  4 +
  5 + private double sd;
  6 + private double mean;
  7 +
  8 + public Result(double mean, double sd) {
  9 + this.mean = mean;
  10 + this.sd = sd;
  11 + }
  12 +
  13 + public double getMean() {
  14 + return mean;
  15 + }
  16 +
  17 + public double getSD() {
  18 + return sd;
  19 + }
  20 +
  21 +}
... ...
src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java 0 → 100644
  1 +package pl.waw.ipipan.zil.summ.eval;
  2 +
  3 +import java.util.HashMap;
  4 +import java.util.HashSet;
  5 +import java.util.List;
  6 +import java.util.Map;
  7 +import java.util.Set;
  8 +import java.util.stream.Collectors;
  9 +
  10 +import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
  11 +
  12 +public class ResultAccumulator {
  13 +
  14 + private Map<String, CumulativeResult> systemId2result = new HashMap<>();
  15 + private Set<String> metricIds = new HashSet<>();
  16 +
  17 + public void logScore(String metricId, String systemId, double score) {
  18 + systemId2result.putIfAbsent(systemId, new CumulativeResult());
  19 + CumulativeResult cr = systemId2result.get(systemId);
  20 + cr.addScore(metricId, score);
  21 + metricIds.add(metricId);
  22 + }
  23 +
  24 + public List<String> getSystemIds() {
  25 + return systemId2result.keySet().stream().sorted().collect(Collectors.toList());
  26 + }
  27 +
  28 + public List<String> getMetricIds() {
  29 + return metricIds.stream().sorted().collect(Collectors.toList());
  30 + }
  31 +
  32 + public Result getSystemResult(String systemId, String metricId) {
  33 + return systemId2result.get(systemId).getResult(metricId);
  34 + }
  35 +
  36 + private class CumulativeResult {
  37 + Map<String, SummaryStatistics> metricId2scores = new HashMap<>();
  38 +
  39 + public void addScore(String metricId, double score) {
  40 + metricId2scores.putIfAbsent(metricId, new SummaryStatistics());
  41 + metricId2scores.get(metricId).addValue(score);
  42 + }
  43 +
  44 + public Result getResult(String metricId) {
  45 + SummaryStatistics scores = metricId2scores.get(metricId);
  46 + return new Result(scores.getMean(), scores.getStandardDeviation());
  47 + }
  48 + }
  49 +}
... ...
src/main/resources/log4j2.xml
... ... @@ -2,12 +2,16 @@
2 2 <Configuration status="WARN">
3 3 <Appenders>
4 4 <Console name="Console" target="SYSTEM_OUT">
5   - <PatternLayout pattern="%d{HH:mm:ss.SSS} [%t] %-5level %logger{36} - %msg%n" />
  5 + <PatternLayout pattern="%d{HH:mm:ss.SSS} [%t] %-5level %logger{3} - %msg%n" />
6 6 </Console>
7 7 </Appenders>
8 8 <Loggers>
9 9 <Root level="error">
10 10 <AppenderRef ref="Console" />
11 11 </Root>
  12 + <Logger name="pl.waw.ipipan.zil.summ.eval" level="debug"
  13 + additivity="false">
  14 + <AppenderRef ref="Console" />
  15 + </Logger>
12 16 </Loggers>
13 17 </Configuration>
14 18 \ No newline at end of file
... ...