Commit a5aab8fa541c22a199c89f883b329921ed5f7602

Authored by Mateusz Kopeć
1 parent f940218e

implement thrifted API

nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java
1 1 package pl.waw.ipipan.zil.summ.nicolas;
2 2  
  3 +import com.google.common.collect.Lists;
  4 +import com.google.common.collect.Maps;
  5 +import com.google.common.collect.Sets;
  6 +import pl.waw.ipipan.zil.multiservice.thrift.types.TMention;
  7 +import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence;
3 8 import pl.waw.ipipan.zil.multiservice.thrift.types.TText;
  9 +import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor;
  10 +import pl.waw.ipipan.zil.summ.nicolas.mention.MentionModel;
  11 +import pl.waw.ipipan.zil.summ.nicolas.sentence.SentenceFeatureExtractor;
  12 +import weka.classifiers.Classifier;
  13 +import weka.core.Instance;
  14 +import weka.core.Instances;
  15 +
  16 +import java.io.IOException;
  17 +import java.util.*;
  18 +
  19 +import static java.util.stream.Collectors.toList;
4 20  
5 21 public class Nicolas {
6 22  
7   - public String summarizeThrift(TText text, int targetTokenCount) {
8   - return "test nicolas";
  23 + private final Classifier sentenceClassifier;
  24 + private final Classifier mentionClassifier;
  25 + private final MentionFeatureExtractor featureExtractor;
  26 + private final SentenceFeatureExtractor sentenceFeatureExtractor;
  27 +
  28 + public Nicolas() throws IOException, ClassNotFoundException {
  29 + mentionClassifier = Utils.loadClassifier(Constants.MENTIONS_MODEL_PATH);
  30 + featureExtractor = new MentionFeatureExtractor();
  31 +
  32 + sentenceClassifier = Utils.loadClassifier(Constants.SENTENCES_MODEL_PATH);
  33 + sentenceFeatureExtractor = new SentenceFeatureExtractor();
  34 + }
  35 +
  36 + public String summarizeThrift(TText text, int targetTokenCount) throws Exception {
  37 + Set<TMention> goodMentions
  38 + = MentionModel.detectGoodMentions(mentionClassifier, featureExtractor, text);
  39 + return calculateSummary(text, goodMentions, targetTokenCount, sentenceClassifier, sentenceFeatureExtractor);
  40 + }
  41 +
  42 + private static String calculateSummary(TText thrifted, Set<TMention> goodMentions, int targetSize, Classifier sentenceClassifier, SentenceFeatureExtractor sentenceFeatureExtractor) throws Exception {
  43 + List<TSentence> selectedSentences = selectSummarySentences(thrifted, goodMentions, targetSize, sentenceClassifier, sentenceFeatureExtractor);
  44 +
  45 + StringBuilder sb = new StringBuilder();
  46 + for (TSentence sent : selectedSentences) {
  47 + sb.append(" ").append(Utils.loadSentence2Orth(sent));
  48 + }
  49 + return sb.toString().trim();
9 50 }
10 51  
  52 + private static List<TSentence> selectSummarySentences(TText thrifted, Set<TMention> goodMentions, int targetSize, Classifier sentenceClassifier, SentenceFeatureExtractor sentenceFeatureExtractor) throws Exception {
  53 + List<TSentence> sents = thrifted.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList());
  54 +
  55 + Instances instances = Utils.createNewInstances(sentenceFeatureExtractor.getAttributesList());
  56 + Map<TSentence, Instance> sentence2instance = Utils.extractInstancesFromSentences(thrifted, sentenceFeatureExtractor, goodMentions);
  57 +
  58 + Map<TSentence, Double> sentence2score = Maps.newHashMap();
  59 + for (Map.Entry<TSentence, Instance> entry : sentence2instance.entrySet()) {
  60 + Instance instance = entry.getValue();
  61 + instance.setDataset(instances);
  62 + double score = sentenceClassifier.classifyInstance(instance);
  63 + sentence2score.put(entry.getKey(), score);
  64 + }
  65 +
  66 + List<TSentence> sortedSents = Lists.newArrayList(sents);
  67 + Collections.sort(sortedSents, Comparator.comparing(sentence2score::get).reversed());
  68 +
  69 + int size = 0;
  70 + Random r = new Random(1);
  71 + Set<TSentence> summary = Sets.newHashSet();
  72 + for (TSentence sent : sortedSents) {
  73 + size += Utils.tokenizeOnWhitespace(Utils.loadSentence2Orth(sent)).size();
  74 + if (r.nextDouble() > 0.4 && size > targetSize)
  75 + break;
  76 + summary.add(sent);
  77 + if (size > targetSize)
  78 + break;
  79 + }
  80 + List<TSentence> selectedSentences = Lists.newArrayList();
  81 + for (TSentence sent : sents) {
  82 + if (summary.contains(sent))
  83 + selectedSentences.add(sent);
  84 + }
  85 + return selectedSentences;
  86 + }
11 87 }
... ...