diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java index b137fe9..c6573ba 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java @@ -1,11 +1,87 @@ package pl.waw.ipipan.zil.summ.nicolas; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; +import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; +import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor; +import pl.waw.ipipan.zil.summ.nicolas.mention.MentionModel; +import pl.waw.ipipan.zil.summ.nicolas.sentence.SentenceFeatureExtractor; +import weka.classifiers.Classifier; +import weka.core.Instance; +import weka.core.Instances; + +import java.io.IOException; +import java.util.*; + +import static java.util.stream.Collectors.toList; public class Nicolas { - public String summarizeThrift(TText text, int targetTokenCount) { - return "test nicolas"; + private final Classifier sentenceClassifier; + private final Classifier mentionClassifier; + private final MentionFeatureExtractor featureExtractor; + private final SentenceFeatureExtractor sentenceFeatureExtractor; + + public Nicolas() throws IOException, ClassNotFoundException { + mentionClassifier = Utils.loadClassifier(Constants.MENTIONS_MODEL_PATH); + featureExtractor = new MentionFeatureExtractor(); + + sentenceClassifier = Utils.loadClassifier(Constants.SENTENCES_MODEL_PATH); + sentenceFeatureExtractor = new SentenceFeatureExtractor(); + } + + public String summarizeThrift(TText text, int targetTokenCount) throws Exception { + Set<TMention> goodMentions + = MentionModel.detectGoodMentions(mentionClassifier, featureExtractor, text); + return calculateSummary(text, goodMentions, targetTokenCount, sentenceClassifier, sentenceFeatureExtractor); + } + + private static String calculateSummary(TText thrifted, Set<TMention> goodMentions, int targetSize, Classifier sentenceClassifier, SentenceFeatureExtractor sentenceFeatureExtractor) throws Exception { + List<TSentence> selectedSentences = selectSummarySentences(thrifted, goodMentions, targetSize, sentenceClassifier, sentenceFeatureExtractor); + + StringBuilder sb = new StringBuilder(); + for (TSentence sent : selectedSentences) { + sb.append(" ").append(Utils.loadSentence2Orth(sent)); + } + return sb.toString().trim(); } + private static List<TSentence> selectSummarySentences(TText thrifted, Set<TMention> goodMentions, int targetSize, Classifier sentenceClassifier, SentenceFeatureExtractor sentenceFeatureExtractor) throws Exception { + List<TSentence> sents = thrifted.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList()); + + Instances instances = Utils.createNewInstances(sentenceFeatureExtractor.getAttributesList()); + Map<TSentence, Instance> sentence2instance = Utils.extractInstancesFromSentences(thrifted, sentenceFeatureExtractor, goodMentions); + + Map<TSentence, Double> sentence2score = Maps.newHashMap(); + for (Map.Entry<TSentence, Instance> entry : sentence2instance.entrySet()) { + Instance instance = entry.getValue(); + instance.setDataset(instances); + double score = sentenceClassifier.classifyInstance(instance); + sentence2score.put(entry.getKey(), score); + } + + List<TSentence> sortedSents = Lists.newArrayList(sents); + Collections.sort(sortedSents, Comparator.comparing(sentence2score::get).reversed()); + + int size = 0; + Random r = new Random(1); + Set<TSentence> summary = Sets.newHashSet(); + for (TSentence sent : sortedSents) { + size += Utils.tokenizeOnWhitespace(Utils.loadSentence2Orth(sent)).size(); + if (r.nextDouble() > 0.4 && size > targetSize) + break; + summary.add(sent); + if (size > targetSize) + break; + } + List<TSentence> selectedSentences = Lists.newArrayList(); + for (TSentence sent : sents) { + if (summary.contains(sent)) + selectedSentences.add(sent); + } + return selectedSentences; + } }