Commit e988534f8511a66d209675f9057b736926851bf6

Authored by Mateusz Kopeć
1 parent 2640b21b

rouge-n

... ... @@ -29,6 +29,11 @@
29 29 <artifactId>junit</artifactId>
30 30 <version>4.8.1</version>
31 31 </dependency>
  32 + <dependency>
  33 + <groupId>com.google.guava</groupId>
  34 + <artifactId>guava</artifactId>
  35 + <version>18.0</version>
  36 + </dependency>
32 37 </dependencies>
33 38 <build>
34 39 <plugins>
... ...
src/main/java/mkopec/summ/eval/Main.java
1 1 package mkopec.summ.eval;
2 2  
3 3 import java.io.File;
  4 +import java.io.IOException;
  5 +import java.nio.file.Files;
  6 +import java.nio.file.Paths;
4 7 import java.util.HashMap;
5 8 import java.util.HashSet;
6 9 import java.util.Map;
7 10 import java.util.Set;
  11 +import java.util.stream.Collectors;
8 12  
9 13 import org.apache.log4j.Logger;
10 14  
  15 +import mkopec.summ.eval.rouge.RougeN;
  16 +
11 17 public class Main {
12 18 private static final Logger LOG = Logger.getLogger(Main.class);
13 19  
... ... @@ -24,27 +30,61 @@ public class Main {
24 30 return;
25 31 }
26 32  
27   - Map<String, Set<File>> id2goldSummaryFiles = new HashMap<>();
28   - Map<String, Set<File>> id2sysSummaryFiles = new HashMap<>();
29   - loadFiles(goldDir, id2goldSummaryFiles);
30   - loadFiles(sysDir, id2sysSummaryFiles);
31   -
  33 + Map<String, Set<File>> id2goldSummaryFiles = loadFiles(goldDir);
  34 + Map<String, Set<File>> id2sysSummaryFiles = loadFiles(sysDir);
32 35 if (!id2goldSummaryFiles.keySet().equals(id2sysSummaryFiles.keySet())) {
33 36 LOG.error("Different text ids in gold and sys dirs!");
34 37 return;
35 38 }
36   - LOG.info(id2goldSummaryFiles.size()+" gold and sys text(s) loaded.");
  39 + LOG.info(id2goldSummaryFiles.size() + " gold and sys text(s) loaded.");
  40 +
  41 + for (String id : id2goldSummaryFiles.keySet()) {
  42 + Set<File> goldSummaryFiles = id2goldSummaryFiles.get(id);
  43 + Set<File> sysSummaryFiles = id2sysSummaryFiles.get(id);
  44 +
  45 + for (File sysSummaryFile : sysSummaryFiles) {
  46 + String systemSummary = loadSummaryFromFile(sysSummaryFile);
  47 + String systemId = getSystemIdFromFilename(sysSummaryFile);
  48 + Set<String> goldSummaries = goldSummaryFiles.stream().map(f -> loadSummaryFromFile(f))
  49 + .collect(Collectors.toSet());
  50 + if (goldSummaries.contains(null) || systemSummary == null) {
  51 + return;
  52 + }
  53 +
  54 + for (int i = 1; i < 4; i++) {
  55 + double score = RougeN.score(i, systemSummary, goldSummaries);
  56 + System.out.println(i + "\t" + systemId + "\t" + score);
  57 + System.out.println();
  58 + }
  59 + }
  60 + }
  61 +
  62 + }
  63 +
  64 + private static String getSystemIdFromFilename(File sysSummaryFile) {
  65 + String name = sysSummaryFile.getName();
  66 + return name.substring(name.lastIndexOf("_") + 1, name.lastIndexOf("."));
  67 + }
  68 +
  69 + private static String loadSummaryFromFile(File sysSummaryFile) {
  70 + try {
  71 + return new String(Files.readAllBytes(Paths.get(sysSummaryFile.getPath())));
  72 + } catch (IOException e) {
  73 + LOG.error("Error reading summary file: " + e);
  74 + return null;
  75 + }
37 76 }
38 77  
39   - private static void loadFiles(File goldDir,
40   - Map<String, Set<File>> id2goldSummaryFiles) {
  78 + private static Map<String, Set<File>> loadFiles(File goldDir) {
  79 + Map<String, Set<File>> result = new HashMap<>();
41 80 for (File goldFile : goldDir.listFiles()) {
42 81 String[] spl = goldFile.getName().split("_");
43 82 String id = spl[0];
44 83  
45   - if (!id2goldSummaryFiles.containsKey(id))
46   - id2goldSummaryFiles.put(id, new HashSet<>());
47   - id2goldSummaryFiles.get(id).add(goldFile);
  84 + if (!result.containsKey(id))
  85 + result.put(id, new HashSet<>());
  86 + result.get(id).add(goldFile);
48 87 }
  88 + return result;
49 89 }
50 90 }
... ...
src/main/java/mkopec/summ/eval/NgramQueue.java 0 → 100644
  1 +package mkopec.summ.eval;
  2 +
  3 +import java.util.LinkedList;
  4 +import java.util.List;
  5 +
  6 +public class NgramQueue<T> extends LinkedList<T> {
  7 +
  8 + private static final long serialVersionUID = -3001965727065327823L;
  9 +
  10 + private int ngramLength;
  11 +
  12 + public NgramQueue(int length) {
  13 + super();
  14 + this.ngramLength = length;
  15 + }
  16 +
  17 + @Override
  18 + public boolean add(T e) {
  19 + if (this.size() == ngramLength)
  20 + this.poll();
  21 + return super.add(e);
  22 + }
  23 +
  24 + public List<T> getNGram(int i) {
  25 + return this.subList(Math.max(0, this.size() - i), this.size());
  26 + }
  27 +}
... ...
src/main/java/mkopec/summ/eval/rouge/RougeN.java 0 → 100644
  1 +package mkopec.summ.eval.rouge;
  2 +
  3 +import java.util.ArrayList;
  4 +import java.util.Arrays;
  5 +import java.util.List;
  6 +import java.util.Set;
  7 +import java.util.stream.Collectors;
  8 +
  9 +import com.google.common.collect.HashMultiset;
  10 +import com.google.common.collect.Multiset;
  11 +
  12 +import mkopec.summ.eval.NgramQueue;
  13 +
  14 +public class RougeN {
  15 +
  16 + public static double score(int n, String systemSummary, Set<String> goldSummaries) {
  17 + Multiset<List<String>> systemNgrams = HashMultiset.create();
  18 + countNgrams(systemNgrams, n, systemSummary);
  19 +
  20 + int numerator = 0;
  21 + int denominator = 0;
  22 + for (String goldSummary : goldSummaries) {
  23 + Multiset<List<String>> goldNgrams = HashMultiset.create();
  24 + countNgrams(goldNgrams, n, goldSummary);
  25 +
  26 + for (List<String> goldNgram : goldNgrams) {
  27 + int goldCount = goldNgrams.count(goldNgram);
  28 + int systemCount = systemNgrams.count(goldNgram);
  29 + numerator += Math.min(goldCount, systemCount);
  30 + denominator += goldCount;
  31 + }
  32 + }
  33 +
  34 + return 1.0 * numerator / denominator;
  35 + }
  36 +
  37 + protected static void countNgrams(Multiset<List<String>> ngrams, int n, String text) {
  38 + List<String> tokens = tokenize(text).stream().map(String::toLowerCase).collect(Collectors.toList());
  39 + NgramQueue<String> ngram = new NgramQueue<>(n);
  40 + boolean ngramFull = false;
  41 + for (String token : tokens) {
  42 + ngram.add(token);
  43 + if (!ngramFull)
  44 + if (ngram.size() == n)
  45 + ngramFull = true;
  46 + else
  47 + continue;
  48 + ngrams.add(new ArrayList<>(ngram));
  49 + }
  50 + }
  51 +
  52 + protected static List<String> tokenize(String text) {
  53 + return Arrays.asList(text.split("[^\\p{L}0-9]+"));
  54 + }
  55 +
  56 +}
... ...
src/test/java/mkopec/summ/eval/rouge/RougeNTest.java 0 → 100644
  1 +package mkopec.summ.eval.rouge;
  2 +
  3 +import java.util.List;
  4 +
  5 +import org.junit.Assert;
  6 +import org.junit.Test;
  7 +
  8 +import com.google.common.collect.HashMultiset;
  9 +import com.google.common.collect.Lists;
  10 +import com.google.common.collect.Multiset;
  11 +
  12 +public class RougeNTest {
  13 +
  14 + @Test
  15 + public void testNgramCounter() {
  16 + String text = "Ala ma kota. Kota ma też Ala. Ma też kota Ala. Też kota Ala.";
  17 + Multiset<List<String>> unigrams = HashMultiset.create();
  18 + RougeN.countNgrams(unigrams, 1, text);
  19 + Assert.assertEquals(4, unigrams.count(Lists.newArrayList("kota")));
  20 + Assert.assertEquals(3, unigrams.count(Lists.newArrayList("ma")));
  21 + Assert.assertEquals(3, unigrams.count(Lists.newArrayList("też")));
  22 +
  23 + Multiset<List<String>> bigrams = HashMultiset.create();
  24 + RougeN.countNgrams(bigrams, 2, text);
  25 + Assert.assertEquals(2, bigrams.count(Lists.newArrayList("ala", "ma")));
  26 + Assert.assertEquals(1, bigrams.count(Lists.newArrayList("kota", "kota")));
  27 + Assert.assertEquals(2, bigrams.count(Lists.newArrayList("ma", "też")));
  28 +
  29 + Multiset<List<String>> trigrams = HashMultiset.create();
  30 + RougeN.countNgrams(trigrams, 3, text);
  31 + Assert.assertEquals(2, trigrams.count(Lists.newArrayList("też", "kota", "ala")));
  32 + Assert.assertEquals(1, trigrams.count(Lists.newArrayList("ala", "ma", "też")));
  33 + Assert.assertEquals(1, trigrams.count(Lists.newArrayList("kota", "kota", "ma")));
  34 +
  35 + Multiset<List<String>> ngrams1 = HashMultiset.create();
  36 + RougeN.countNgrams(ngrams1, RougeN.tokenize(text).size(), text);
  37 + Assert.assertEquals(1, ngrams1.size());
  38 +
  39 + RougeN.countNgrams(ngrams1, RougeN.tokenize(text).size(), text);
  40 + Assert.assertEquals(2, ngrams1.size());
  41 +
  42 + Multiset<List<String>> ngrams2 = HashMultiset.create();
  43 + RougeN.countNgrams(ngrams2, RougeN.tokenize(text).size() + 1, text);
  44 + Assert.assertEquals(0, ngrams2.size());
  45 + }
  46 +
  47 + @Test
  48 + public void testTokenization() {
  49 + String text = "Ala ma kota. Kot ma \"Alę\".\n\n\n I do tego 999 nóg.";
  50 + List<String> tokens = RougeN.tokenize(text);
  51 +
  52 + Assert.assertEquals("Ala", tokens.get(0));
  53 + Assert.assertEquals("kota", tokens.get(2));
  54 + Assert.assertEquals("Alę", tokens.get(5));
  55 + Assert.assertEquals("999", tokens.get(9));
  56 + Assert.assertEquals("nóg", tokens.get(10));
  57 +
  58 + Assert.assertEquals(11, tokens.size());
  59 + }
  60 +}
... ...