Commit 595d6aa646dc3dde01d92d775e7d465b8709f7a6
1 parent
2692e607
calculating mean and sd scores
Showing
5 changed files
with
117 additions
and
9 deletions
pom.xml
... | ... | @@ -34,8 +34,13 @@ |
34 | 34 | <artifactId>guava</artifactId> |
35 | 35 | <version>18.0</version> |
36 | 36 | </dependency> |
37 | + <dependency> | |
38 | + <groupId>org.apache.commons</groupId> | |
39 | + <artifactId>commons-math3</artifactId> | |
40 | + <version>3.5</version> | |
41 | + </dependency> | |
37 | 42 | </dependencies> |
38 | - | |
43 | + | |
39 | 44 | <build> |
40 | 45 | <plugins> |
41 | 46 | <plugin> |
... | ... |
src/main/java/pl/waw/ipipan/zil/summ/eval/Main.java
... | ... | @@ -16,8 +16,11 @@ import org.apache.logging.log4j.Logger; |
16 | 16 | import pl.waw.ipipan.zil.summ.eval.rouge.RougeN; |
17 | 17 | |
18 | 18 | public class Main { |
19 | + | |
19 | 20 | private static final Logger LOG = LogManager.getLogger(Main.class); |
20 | 21 | |
22 | + private static final int[] ROUGE_N = new int[] { 1, 2, 3 }; | |
23 | + | |
21 | 24 | public static void main(String[] args) { |
22 | 25 | if (args.length != 2) { |
23 | 26 | LOG.error("Wrong number of arguments!"); |
... | ... | @@ -35,31 +38,57 @@ public class Main { |
35 | 38 | Map<String, Set<File>> id2sysSummaryFiles = loadFiles(sysDir); |
36 | 39 | if (!id2goldSummaryFiles.keySet().equals(id2sysSummaryFiles.keySet())) { |
37 | 40 | LOG.error("Different text ids in gold and sys dirs!"); |
41 | + LOG.error(id2goldSummaryFiles.size() + " gold ids, " + id2sysSummaryFiles.size() + " sys ids."); | |
38 | 42 | return; |
39 | 43 | } |
40 | 44 | LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded."); |
41 | 45 | |
46 | + ResultAccumulator ra = new ResultAccumulator(); | |
47 | + int totalGoldSummaries = 0; | |
48 | + int totalComparedTexts = 0; | |
49 | + | |
42 | 50 | for (String id : id2goldSummaryFiles.keySet()) { |
43 | 51 | Set<File> goldSummaryFiles = id2goldSummaryFiles.get(id); |
44 | 52 | Set<File> sysSummaryFiles = id2sysSummaryFiles.get(id); |
45 | 53 | |
54 | + Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f)) | |
55 | + .collect(Collectors.toSet()); | |
56 | + if (goldSummaries.contains(null)) { | |
57 | + LOG.error("Empty gold summary for id: " + id); | |
58 | + return; | |
59 | + } | |
46 | 60 | for (File sysSummaryFile : sysSummaryFiles) { |
47 | 61 | String systemSummary = loadSummaryFromFile(sysSummaryFile); |
48 | 62 | String systemId = getSystemIdFromFilename(sysSummaryFile); |
49 | - Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f)) | |
50 | - .collect(Collectors.toSet()); | |
51 | - if (goldSummaries.contains(null) || systemSummary == null) { | |
63 | + if (systemSummary == null) { | |
64 | + LOG.error("Empty system summary for id: " + id); | |
52 | 65 | return; |
53 | 66 | } |
54 | - | |
55 | - for (int i = 1; i < 4; i++) { | |
67 | + for (int i : ROUGE_N) { | |
56 | 68 | double score = RougeN.score(i, systemSummary, goldSummaries); |
57 | - System.out.println(i + "\t" + systemId + "\t" + score); | |
58 | - System.out.println(); | |
69 | + ra.logScore("ROUGE_" + i, systemId, score); | |
59 | 70 | } |
60 | 71 | } |
72 | + | |
73 | + totalGoldSummaries += goldSummaries.size(); | |
74 | + totalComparedTexts++; | |
61 | 75 | } |
62 | 76 | |
77 | + LOG.info(totalComparedTexts + " - total compared texts."); | |
78 | + LOG.info(totalGoldSummaries * 1.0 / totalComparedTexts + " - average gold summaries per text."); | |
79 | + | |
80 | + printResults(ra); | |
81 | + } | |
82 | + | |
83 | + private static void printResults(ResultAccumulator ra) { | |
84 | + for (String systemId : ra.getSystemIds()) { | |
85 | + System.out.println("############## " + systemId); | |
86 | + for (String metricId : ra.getMetricIds()) { | |
87 | + Result r = ra.getSystemResult(systemId, metricId); | |
88 | + System.out.println(String.format("\t%s\t%.2f\t(+/- %.2f)", metricId, r.getMean(), r.getSD())); | |
89 | + } | |
90 | + System.out.println(); | |
91 | + } | |
63 | 92 | } |
64 | 93 | |
65 | 94 | private static String getSystemIdFromFilename(File sysSummaryFile) { |
... | ... |
src/main/java/pl/waw/ipipan/zil/summ/eval/Result.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.eval; | |
2 | + | |
3 | +public class Result { | |
4 | + | |
5 | + private double sd; | |
6 | + private double mean; | |
7 | + | |
8 | + public Result(double mean, double sd) { | |
9 | + this.mean = mean; | |
10 | + this.sd = sd; | |
11 | + } | |
12 | + | |
13 | + public double getMean() { | |
14 | + return mean; | |
15 | + } | |
16 | + | |
17 | + public double getSD() { | |
18 | + return sd; | |
19 | + } | |
20 | + | |
21 | +} | |
... | ... |
src/main/java/pl/waw/ipipan/zil/summ/eval/ResultAccumulator.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.eval; | |
2 | + | |
3 | +import java.util.HashMap; | |
4 | +import java.util.HashSet; | |
5 | +import java.util.List; | |
6 | +import java.util.Map; | |
7 | +import java.util.Set; | |
8 | +import java.util.stream.Collectors; | |
9 | + | |
10 | +import org.apache.commons.math3.stat.descriptive.SummaryStatistics; | |
11 | + | |
12 | +public class ResultAccumulator { | |
13 | + | |
14 | + private Map<String, CumulativeResult> systemId2result = new HashMap<>(); | |
15 | + private Set<String> metricIds = new HashSet<>(); | |
16 | + | |
17 | + public void logScore(String metricId, String systemId, double score) { | |
18 | + systemId2result.putIfAbsent(systemId, new CumulativeResult()); | |
19 | + CumulativeResult cr = systemId2result.get(systemId); | |
20 | + cr.addScore(metricId, score); | |
21 | + metricIds.add(metricId); | |
22 | + } | |
23 | + | |
24 | + public List<String> getSystemIds() { | |
25 | + return systemId2result.keySet().stream().sorted().collect(Collectors.toList()); | |
26 | + } | |
27 | + | |
28 | + public List<String> getMetricIds() { | |
29 | + return metricIds.stream().sorted().collect(Collectors.toList()); | |
30 | + } | |
31 | + | |
32 | + public Result getSystemResult(String systemId, String metricId) { | |
33 | + return systemId2result.get(systemId).getResult(metricId); | |
34 | + } | |
35 | + | |
36 | + private class CumulativeResult { | |
37 | + Map<String, SummaryStatistics> metricId2scores = new HashMap<>(); | |
38 | + | |
39 | + public void addScore(String metricId, double score) { | |
40 | + metricId2scores.putIfAbsent(metricId, new SummaryStatistics()); | |
41 | + metricId2scores.get(metricId).addValue(score); | |
42 | + } | |
43 | + | |
44 | + public Result getResult(String metricId) { | |
45 | + SummaryStatistics scores = metricId2scores.get(metricId); | |
46 | + return new Result(scores.getMean(), scores.getStandardDeviation()); | |
47 | + } | |
48 | + } | |
49 | +} | |
... | ... |
src/main/resources/log4j2.xml
... | ... | @@ -2,12 +2,16 @@ |
2 | 2 | <Configuration status="WARN"> |
3 | 3 | <Appenders> |
4 | 4 | <Console name="Console" target="SYSTEM_OUT"> |
5 | - <PatternLayout pattern="%d{HH:mm:ss.SSS} [%t] %-5level %logger{36} - %msg%n" /> | |
5 | + <PatternLayout pattern="%d{HH:mm:ss.SSS} [%t] %-5level %logger{3} - %msg%n" /> | |
6 | 6 | </Console> |
7 | 7 | </Appenders> |
8 | 8 | <Loggers> |
9 | 9 | <Root level="error"> |
10 | 10 | <AppenderRef ref="Console" /> |
11 | 11 | </Root> |
12 | + <Logger name="pl.waw.ipipan.zil.summ.eval" level="debug" | |
13 | + additivity="false"> | |
14 | + <AppenderRef ref="Console" /> | |
15 | + </Logger> | |
12 | 16 | </Loggers> |
13 | 17 | </Configuration> |
14 | 18 | \ No newline at end of file |
... | ... |