From e1126cdba70bd5287871ebbe89e9ae6635bb5a01 Mon Sep 17 00:00:00 2001
From: Mateusz Kopeć <m.kopec@ipipan.waw.pl>
Date: Sat, 8 Oct 2016 19:45:58 +0200
Subject: [PATCH] rough draft

---
 .gitignore | 18 ++++++++++++++++++
 nicolas-cli/pom.xml | 14 ++++++++++++++
 nicolas-core/pom.xml | 22 ++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java | 33 ++++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java | 11 +++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Utils.java | 200 ++++++++++++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel2.java | 120 ++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureExtractor.java | 83 ++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureHelper.java | 187 ++++++++++++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/Interpretation.java | 77 ++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java | 199 ++++++++++++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionModel.java | 37 +++++++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionScorer.java | 59 +++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/PrepareTrainingData.java | 78 ++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java | 47 +++++++++++++++++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java | 44 ++++++++++++++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/PrepareTrainingData.java | 91 +++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java | 103 ++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceScorer.java | 32 ++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java | 47 +++++++++++++++++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java | 41 +++++++++++++++++++++++++++++++++++++++++
 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/Zero.java | 128 ++++++++++++++++++++++++++++++++++++
 nicolas-model/pom.xml | 14 ++++++++++++++
 nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/frequent_bases.txt | 237 +++++++++++++++++++++++++++++++++++++++++++++++
 nicolas-train/pom.xml | 14 ++++++++++++++
 nicolas-zero/pom.xml | 14 ++++++++++++++
 pom.xml | 101 +++++++++++++++++++++++++++++++++++++++
 28 files changed, 2105 insertions(+), 0 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 nicolas-cli/pom.xml
 create mode 100644 nicolas-core/pom.xml
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Utils.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel2.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureExtractor.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureHelper.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/Interpretation.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionModel.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionScorer.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/PrepareTrainingData.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/PrepareTrainingData.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceScorer.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java
 create mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/Zero.java
 create mode 100644 nicolas-model/pom.xml
 create mode 100644 nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/frequent_bases.txt
 create mode 100644 nicolas-train/pom.xml
 create mode 100644 nicolas-zero/pom.xml
 create mode 100644 pom.xml

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..28f546a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,18 @@
+# Created by .ignore support plugin (hsz.mobi)
+### Java template
+*.class
+target/
+
+# Mobile Tools for Java (J2ME)
+.mtj.tmp/
+
+# Package Files #
+*.jar
+*.war
+*.ear
+
+# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
+hs_err_pid*
+
+.idea
+*.iml
\ No newline at end of file
diff --git a/nicolas-cli/pom.xml b/nicolas-cli/pom.xml
new file mode 100644
index 0000000..e65a5b6
--- /dev/null
+++ b/nicolas-cli/pom.xml
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <parent>
+        <artifactId>nicolas-container</artifactId>
+        <groupId>pl.waw.ipipan.zil.summ</groupId>
+        <version>1.0-SNAPSHOT</version>
+    </parent>
+
+    <artifactId>nicolas-cli</artifactId>
+
+</project>
\ No newline at end of file
diff --git a/nicolas-core/pom.xml b/nicolas-core/pom.xml
new file mode 100644
index 0000000..63fc157
--- /dev/null
+++ b/nicolas-core/pom.xml
@@ -0,0 +1,22 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <parent>
+        <artifactId>nicolas-container</artifactId>
+        <groupId>pl.waw.ipipan.zil.summ</groupId>
+        <version>1.0-SNAPSHOT</version>
+    </parent>
+
+    <artifactId>nicolas</artifactId>
+
+    <dependencies>
+        <dependency>
+            <groupId>pl.waw.ipipan.zil.summ</groupId>
+            <artifactId>nicolas-model</artifactId>
+            <version>${project.version}</version>
+            <scope>runtime</scope>
+        </dependency>
+    </dependencies>
+</project>
\ No newline at end of file
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java
new file mode 100644
index 0000000..f4a6ed6
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java
@@ -0,0 +1,33 @@
+package pl.waw.ipipan.zil.summ.nicolas;
+
+import weka.classifiers.Classifier;
+import weka.classifiers.trees.RandomForest;
+
+
+public class Constants {
+
+    public static final String MENTIONS_MODEL_PATH = "mentions_model.bin";
+    public static final String SENTENCES_MODEL_PATH = "sentences_model.bin";
+    public static final String MENTIONS_DATASET_PATH = "mentions_train.arff";
+    public static final String SENTENCES_DATASET_PATH = "sentences_train.arff";
+
+    private Constants() {
+    }
+
+    public static Classifier getClassifier() {
+        RandomForest classifier = new RandomForest();
+        classifier.setNumIterations(250);
+        classifier.setSeed(0);
+        classifier.setNumExecutionSlots(8);
+        return classifier;
+    }
+
+
+    // Currently configured identically to getClassifier(); kept as a separate
+    // factory so the sentence model can be tuned independently of the mention model.
+    public static Classifier getSentencesClassifier() {
+        RandomForest classifier = new RandomForest();
+        classifier.setNumIterations(250);
+        classifier.setSeed(0);
+        classifier.setNumExecutionSlots(8);
+        return classifier;
+    }
+}
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java
new file mode 100644
index 0000000..b137fe9
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java
@@ -0,0 +1,11 @@
+package pl.waw.ipipan.zil.summ.nicolas;
+
+import pl.waw.ipipan.zil.multiservice.thrift.types.TText;
+
+public class Nicolas {
+
+    public String summarizeThrift(TText text, int targetTokenCount) {
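+        // Draft stub: the actual summarization pipeline (mention selection
+        // followed by sentence selection) currently lives in apply.ApplyModel2;
+        // this entry point only returns a placeholder string for now.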
+        return "test nicolas";
+    }
+
+}
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Utils.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Utils.java
new file mode 100644
index 0000000..6b0ff0a
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Utils.java
@@ -0,0 +1,200 @@
+package pl.waw.ipipan.zil.summ.nicolas;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import com.google.common.io.Files;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TMention;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TText;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TToken;
+import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor;
+import pl.waw.ipipan.zil.summ.nicolas.mention.MentionScorer;
+import pl.waw.ipipan.zil.summ.nicolas.sentence.SentenceFeatureExtractor;
+import weka.classifiers.Classifier;
+import weka.core.Attribute;
+import weka.core.DenseInstance;
+import weka.core.Instance;
+import weka.core.Instances;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.util.*;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static java.util.stream.Collectors.toList;
+
+public class Utils {
+
+    private static final Logger LOG = LoggerFactory.getLogger(Utils.class);
+
+    private static final String DATASET_NAME = "Dataset";
+
+    public static Map<TMention, Instance> extractInstancesFromMentions(TText preprocessedText, MentionFeatureExtractor featureExtractor) {
+        List<TSentence> sentences = preprocessedText.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList());
+        Map<TMention, Map<Attribute, Double>> mention2features = featureExtractor.calculateFeatures(preprocessedText);
+
+        LOG.info("Extracting " + featureExtractor.getAttributesList().size() + " features of each mention.");
+        Map<TMention, Instance> mention2instance = Maps.newHashMap();
+        for (TMention tMention : sentences.stream().flatMap(s -> s.getMentions().stream()).collect(toList())) {
+            Instance instance = new DenseInstance(featureExtractor.getAttributesList().size());
+            Map<Attribute, Double> mentionFeatures = mention2features.get(tMention);
+            for (Attribute attribute : featureExtractor.getAttributesList()) {
+                instance.setValue(attribute, mentionFeatures.get(attribute));
+            }
+            mention2instance.put(tMention, instance);
+        }
+        return mention2instance;
+    }
+
+    public static Map<TSentence, Instance> extractInstancesFromSentences(TText preprocessedText, SentenceFeatureExtractor featureExtractor, Set<TMention> goodMentions) {
+        List<TSentence> sentences = preprocessedText.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList());
+        Map<TSentence, Map<Attribute, Double>> sentence2features = featureExtractor.calculateFeatures(preprocessedText, goodMentions);
+
+        LOG.info("Extracting " + featureExtractor.getAttributesList().size() + " features of each sentence.");
+        Map<TSentence, Instance> sentence2instance = Maps.newHashMap();
+        for (TSentence sentence : sentences) {
+            Instance instance = new DenseInstance(featureExtractor.getAttributesList().size());
+            Map<Attribute, Double> sentenceFeatures = sentence2features.get(sentence);
+            for (Attribute attribute : featureExtractor.getAttributesList()) {
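+                // Copy each precomputed feature value into the Weka instance; the
+                // Attribute objects must be the very ones the extractor created.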
+                instance.setValue(attribute, sentenceFeatures.get(attribute));
+            }
+            sentence2instance.put(sentence, instance);
+        }
+        return sentence2instance;
+    }
+
+    public static Instances createNewInstances(ArrayList<Attribute> attributesList) {
+        Instances instances = new Instances(DATASET_NAME, attributesList, 0);
+        instances.setClassIndex(0);
+        return instances;
+    }
+
+    public static Classifier loadClassifier(String path) throws IOException, ClassNotFoundException {
+        LOG.info("Loading classifier...");
+        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(path))) {
+            Classifier classifier = (Classifier) ois.readObject();
+            LOG.info("Done. " + classifier.toString());
+            return classifier;
+        }
+    }
+
+    public static Map<String, TText> loadPreprocessedTexts(String path) {
+        Map<String, TText> id2text = Maps.newHashMap();
+        for (File processedFullTextFile : new File(path).listFiles()) {
+            TText processedFullText = loadThrifted(processedFullTextFile);
+            id2text.put(processedFullTextFile.getName().split("\\.")[0], processedFullText);
+        }
+        LOG.info(id2text.size() + " preprocessed texts found.");
+        return id2text;
+    }
+
+
+    public static TText loadThrifted(File originalFile) {
+        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(originalFile))) {
+            return (TText) ois.readObject();
+        } catch (ClassNotFoundException | IOException e) {
+            LOG.error("Error reading serialized file: " + e);
+            return null;
+        }
+    }
+
+    public static List<String> tokenize(String text) {
+        return Arrays.asList(text.split("[^\\p{L}0-9]+"));
+    }
+
+    public static List<String> tokenizeOnWhitespace(String text) {
+        return Arrays.asList(text.split(" +"));
+    }
+
+    public static Map<TMention, String> loadMention2HeadOrth(List<TSentence> sents) {
+        Map<TMention, String> mention2orth = Maps.newHashMap();
+        for (TSentence s : sents) {
+            Map<String, String> tokId2orth = s.getTokens().stream().collect(Collectors.toMap(TToken::getId, TToken::getOrth));
+            Map<String, Boolean> tokId2nps = s.getTokens().stream().collect(Collectors.toMap(TToken::getId, TToken::isNoPrecedingSpace));
+
+            for (TMention m : s.getMentions()) {
+                StringBuffer mentionOrth = new StringBuffer();
+                for (String tokId : m.getHeadIds()) {
+                    if (!tokId2nps.get(tokId))
+                        mentionOrth.append(" ");
+                    mentionOrth.append(tokId2orth.get(tokId));
+                }
+                mention2orth.put(m, mentionOrth.toString().trim());
+            }
+        }
+        return mention2orth;
+    }
+
+    private static final Collection<String> STOPWORDS = Sets.newHashSet();
+
+    static {
+        STOPWORDS.addAll(Lists.newArrayList("i", "się", "to", "co"));
+    }
+
+    public static Map<TMention, String> loadMention2Orth(List<TSentence> sents) {
+        Map<TMention, String> mention2orth = Maps.newHashMap();
+        for (TSentence s : sents) {
+            Map<String, TToken> tokId2tok = s.getTokens().stream().collect(Collectors.toMap(TToken::getId, Function.identity()));
+
+            for (TMention m : s.getMentions()) {
+                StringBuffer mentionOrth = new StringBuffer();
+                for (String tokId : m.getChildIds()) {
+                    TToken token = tokId2tok.get(tokId);
+                    if (STOPWORDS.contains(token.getChosenInterpretation().getBase().toLowerCase())) {
+                        continue;
+                    }
+
+                    if (!token.isNoPrecedingSpace())
+                        mentionOrth.append(" ");
+                    mentionOrth.append(token.getOrth());
+                }
+                mention2orth.put(m, mentionOrth.toString().trim());
+            }
+        }
+        return mention2orth;
+    }
+
+    public static Map<TMention, String> loadMention2Base(List<TSentence> sents) {
+        Map<TMention, String> mention2base = Maps.newHashMap();
+        for (TSentence s : sents) {
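+            // Map each token id to its base form, then concatenate the (lowercased)
+            // base forms of a mention's tokens into a single lookup key.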
+            Map<String, String> tokId2base = s.getTokens().stream().collect(Collectors.toMap(tok -> tok.getId(), tok -> tok.getChosenInterpretation().getBase()));
+
+            for (TMention m : s.getMentions()) {
+                StringBuilder mentionBase = new StringBuilder();
+                for (String tokId : m.getChildIds()) {
+                    mentionBase.append(" ");
+                    mentionBase.append(tokId2base.get(tokId));
+                }
+                mention2base.put(m, mentionBase.toString().toLowerCase().trim());
+            }
+        }
+        return mention2base;
+    }
+
+    public static String loadSentence2Orth(TSentence sentence) {
+        StringBuilder sb = new StringBuilder();
+        for (TToken token : sentence.getTokens()) {
+            if (!token.isNoPrecedingSpace())
+                sb.append(" ");
+            sb.append(token.getOrth());
+        }
+        return sb.toString().trim();
+    }
+
+    public static Set<TMention> loadGoldGoodMentions(String id, TText text, boolean dev) throws IOException {
+        String optimalSummary = Files.toString(new File("src/main/resources/optimal_summaries/" + (dev ? "dev" : "test") + "/" + id + "_theoretic_ub_rouge_1.txt"), Charsets.UTF_8);
+
+        MentionScorer scorer = new MentionScorer();
+        Map<TMention, Double> mention2score = scorer.calculateMentionScores(optimalSummary, text);
+
+        mention2score.keySet().removeIf(tMention -> mention2score.get(tMention) != 1.0);
+        return mention2score.keySet();
+    }
+}
\ No newline at end of file
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel2.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel2.java
new file mode 100644
index 0000000..f687d4a
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel2.java
@@ -0,0 +1,120 @@
+package pl.waw.ipipan.zil.summ.nicolas.apply;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TMention;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TText;
+import pl.waw.ipipan.zil.summ.nicolas.Constants;
+import pl.waw.ipipan.zil.summ.nicolas.Utils;
+import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor;
+import pl.waw.ipipan.zil.summ.nicolas.mention.MentionModel;
+import pl.waw.ipipan.zil.summ.nicolas.sentence.SentenceFeatureExtractor;
+import weka.classifiers.Classifier;
+import weka.core.Instance;
+import weka.core.Instances;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileWriter;
+import java.util.*;
+
+import static java.util.stream.Collectors.toList;
+
+public class ApplyModel2 {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ApplyModel2.class);
+
+    private static final String TEST_PREPROCESSED_DATA_PATH = "src/main/resources/preprocessed_full_texts/test";
+    private static final String TARGET_DIR = "summaries";
+
+    public static void main(String[] args) throws Exception {
+        Classifier mentionClassifier = Utils.loadClassifier(Constants.MENTIONS_MODEL_PATH);
+        MentionFeatureExtractor featureExtractor = new MentionFeatureExtractor();
+
+        Classifier sentenceClassifier = Utils.loadClassifier(Constants.SENTENCES_MODEL_PATH);
+        SentenceFeatureExtractor sentenceFeatureExtractor = new SentenceFeatureExtractor();
+
+        Map<String, TText> id2preprocessedText = Utils.loadPreprocessedTexts(TEST_PREPROCESSED_DATA_PATH);
+        int i = 1;
+        double avgSize = 0;
+        for (Map.Entry<String, TText> entry : id2preprocessedText.entrySet()) {
+            TText text = entry.getValue();
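+
+            // Per-text pipeline: (1) classify which mentions to keep, (2) derive
+            // the summary length budget (about 20% of the text's tokens), then
+            // (3) score sentences and greedily select them up to the budget.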
+            Set<TMention> goodMentions
+                    = MentionModel.detectGoodMentions(mentionClassifier, featureExtractor, text);
+
+            int targetSize = calculateTargetSize(text);
+            String summary = calculateSummary(text, goodMentions, targetSize, sentenceClassifier, sentenceFeatureExtractor);
+            int size = Utils.tokenize(summary).size();
+            avgSize += size;
+            try (BufferedWriter bw = new BufferedWriter(new FileWriter(new File(TARGET_DIR, entry.getKey() + "_emily3.txt")))) {
+                bw.append(summary);
+            }
+
+            LOG.info(i++ + "/" + id2preprocessedText.size() + " id: " + entry.getKey());
+        }
+
+        LOG.info("Avg size:" + avgSize / id2preprocessedText.size());
+    }
+
+    private static int calculateTargetSize(TText text) {
+        List<TSentence> sents = text.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList());
+        StringBuffer body = new StringBuffer();
+        for (TSentence sent : sents)
+            body.append(Utils.loadSentence2Orth(sent) + " ");
+        int tokenCount = Utils.tokenizeOnWhitespace(body.toString().trim()).size();
+        return (int) (0.2 * tokenCount);
+    }
+
+    private static String calculateSummary(TText thrifted, Set<TMention> goodMentions, int targetSize, Classifier sentenceClassifier, SentenceFeatureExtractor sentenceFeatureExtractor) throws Exception {
+        List<TSentence> selectedSentences = selectSummarySentences(thrifted, goodMentions, targetSize, sentenceClassifier, sentenceFeatureExtractor);
+
+        StringBuffer sb = new StringBuffer();
+        for (TSentence sent : selectedSentences) {
+            sb.append(" " + Utils.loadSentence2Orth(sent));
+        }
+        return sb.toString().trim();
+    }
+
+    private static List<TSentence> selectSummarySentences(TText thrifted, Set<TMention> goodMentions, int targetSize, Classifier sentenceClassifier, SentenceFeatureExtractor sentenceFeatureExtractor) throws Exception {
+
+        List<TSentence> sents = thrifted.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(toList());
+
+        Instances instances = Utils.createNewInstances(sentenceFeatureExtractor.getAttributesList());
+        Map<TSentence, Instance> sentence2instance = Utils.extractInstancesFromSentences(thrifted, sentenceFeatureExtractor, goodMentions);
+
+        Map<TSentence, Double> sentence2score = Maps.newHashMap();
+        for (Map.Entry<TSentence, Instance> entry : sentence2instance.entrySet()) {
+            Instance instance = entry.getValue();
+            instance.setDataset(instances);
+            double score = sentenceClassifier.classifyInstance(instance);
+            sentence2score.put(entry.getKey(), score);
+        }
+
+        List<TSentence> sortedSents = Lists.newArrayList(sents);
+        Collections.sort(sortedSents, Comparator.comparing(sentence2score::get).reversed());
+
+        int size = 0;
+        Random r = new Random(1);
+        Set<TSentence> summary = Sets.newHashSet();
+        for (TSentence sent : sortedSents) {
+            size += Utils.tokenizeOnWhitespace(Utils.loadSentence2Orth(sent)).size();
+            // once the budget is exceeded, stop with probability 0.6 before adding
+            // the current sentence, otherwise include it and stop right after
+            if (r.nextDouble() > 0.4 && size > targetSize)
+                break;
+            summary.add(sent);
+            if (size > targetSize)
+                break;
+        }
+        List<TSentence> selectedSentences = Lists.newArrayList();
+        for (TSentence sent : sents) {
+            if (summary.contains(sent))
+                selectedSentences.add(sent);
+        }
+        return selectedSentences;
+    }
+
+}
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureExtractor.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureExtractor.java
new file mode 100644
index 0000000..39de47f
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureExtractor.java
@@ -0,0 +1,83 @@
+package pl.waw.ipipan.zil.summ.nicolas.features;
+
+import com.google.common.collect.*;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import weka.core.Attribute;
+
+import java.util.*;
+
+public class FeatureExtractor {
+
+    protected static final Logger LOG = LoggerFactory.getLogger(FeatureExtractor.class);
+
+    private final List<Attribute> sortedAttributes = Lists.newArrayList();
+
+    private final BiMap<String, Attribute> name2attribute = HashBiMap.create();
+
+    private final Set<String> normalizedAttributes = Sets.newHashSet();
+
+    public ArrayList<Attribute> getAttributesList() {
+        return Lists.newArrayList(sortedAttributes);
+    }
+
+    protected Attribute getAttributeByName(String name) {
+        return name2attribute.get(name);
+    }
+
+    protected void addNumericAttribute(String attributeName) {
+        name2attribute.put(attributeName, new Attribute(attributeName));
+    }
+
+    protected void addBinaryAttribute(String attributeName) {
+        name2attribute.put(attributeName, new Attribute(attributeName, Lists.newArrayList("f", "t")));
+    }
+
+    protected void addNominalAttribute(String attributeName, List<String> values) {
+        name2attribute.put(attributeName, new Attribute(attributeName, values));
+    }
+
+    protected void addNumericAttributeNormalized(String attributeName) {
+        addNumericAttribute(attributeName);
+        addNumericAttribute(attributeName + "_normalized");
+        normalizedAttributes.add(attributeName);
+    }
+
+    protected void fillSortedAttributes(String scoreAttName) {
+        sortedAttributes.addAll(name2attribute.values());
+        sortedAttributes.remove(getAttributeByName(scoreAttName));
+        Collections.sort(sortedAttributes, (o1, o2) -> name2attribute.inverse().get(o1).compareTo(name2attribute.inverse().get(o2)));
+        sortedAttributes.add(0, getAttributeByName(scoreAttName));
+    }
+
+    protected <T> void addNormalizedAttributeValues(Map<T, Map<Attribute, Double>> entity2attributes) {
+        // Min-max normalization: scale each flagged attribute to [0, 1] across all
+        // entities of the text (assumes max > min for every normalized attribute).
+        Map<Attribute, Double> attribute2max = Maps.newHashMap();
+        Map<Attribute, Double> attribute2min = Maps.newHashMap();
+        for (T entity : entity2attributes.keySet()) {
+            Map<Attribute, Double> entityAttributes = entity2attributes.get(entity);
+            for (String attributeName : normalizedAttributes) {
+                Attribute attribute = getAttributeByName(attributeName);
+                Double value = entityAttributes.get(attribute);
+
+                attribute2max.putIfAbsent(attribute, Double.MIN_VALUE);
+                attribute2max.compute(attribute, (k, v) -> Math.max(v, value));
+
+                attribute2min.putIfAbsent(attribute, Double.MAX_VALUE);
+                attribute2min.compute(attribute, (k, v) -> Math.min(v, value));
+            }
+        }
+        for (T mention : entity2attributes.keySet()) {
+            Map<Attribute, Double> entityAttributes = entity2attributes.get(mention);
+            for (Attribute attribute : attribute2max.keySet()) {
+                Attribute normalizedAttribute = getAttributeByName(name2attribute.inverse().get(attribute) + "_normalized");
+                entityAttributes.put(normalizedAttribute,
+                        (entityAttributes.get(attribute) - attribute2min.get(attribute))
+                                / (attribute2max.get(attribute) - attribute2min.get(attribute)));
+            }
+        }
+    }
+
+    protected double toBinary(boolean bool) {
+        return bool ? 1.0 : 0.0;
+    }
+}
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureHelper.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureHelper.java
new file mode 100644
index 0000000..4dc2446
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureHelper.java
@@ -0,0 +1,187 @@
+package pl.waw.ipipan.zil.summ.nicolas.features;
+
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import pl.waw.ipipan.zil.multiservice.thrift.types.*;
+import pl.waw.ipipan.zil.summ.nicolas.Utils;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static java.util.stream.Collectors.toList;
+import static java.util.stream.Collectors.toMap;
+
+/**
+ * Created by me2 on 04.04.16.
+ */
+public class FeatureHelper {
+
+    private final List<TMention> mentions;
+    private final Map<String, TMention> mentionId2mention;
+    private final Map<TCoreference, List<TMention>> coref2mentions = Maps.newHashMap();
+    private final Map<TMention, TCoreference> mention2coref = Maps.newHashMap();
+    private final Map<TMention, TSentence> mention2sent = Maps.newHashMap();
+    private final Map<TMention, TParagraph> mention2par = Maps.newHashMap();
+    private final Map<TMention, String> mention2Orth = Maps.newHashMap();
+    private final Map<TMention, String> mention2Base = Maps.newHashMap();
+    private final Map<TMention, TToken> mention2head = Maps.newHashMap();
+    private final Set<TMention> mentionsInNamedEntities = Sets.newHashSet();
+
+    private final Map<TMention, Integer> mention2Index = Maps.newHashMap();
+    private final Map<TSentence, Integer> sent2Index = Maps.newHashMap();
+    private final Map<TParagraph, Integer> par2Index = Maps.newHashMap();
+    private final Map<TSentence, Integer> sent2IndexInPar = Maps.newHashMap();
+    private final Map<TMention, Integer> mention2indexInPar = Maps.newHashMap();
+    private final Map<TMention, Integer> mention2indexInSent = Maps.newHashMap();
+
+
+    public FeatureHelper(TText preprocessedText) {
+        mentions = preprocessedText.getParagraphs().stream()
+                .flatMap(p -> p.getSentences().stream())
+                .flatMap(s -> s.getMentions().stream()).collect(Collectors.toList());
+
+        mentionId2mention = mentions.stream().collect(Collectors.toMap(TMention::getId, Function.identity()));
+
+        for (TCoreference coref : preprocessedText.getCoreferences()) {
+            List<TMention> ments = coref.getMentionIds().stream().map(mentionId2mention::get).collect(toList());
+            for (TMention m : ments) {
+                mention2coref.put(m, coref);
+            }
+            coref2mentions.put(coref, ments);
+        }
+
+        int parIdx = 0;
+        int sentIdx = 0;
+        int mentionIdx = 0;
+        for (TParagraph par : preprocessedText.getParagraphs()) {
+            Map<TMention, String> m2o = Utils.loadMention2Orth(par.getSentences());
+            mention2Orth.putAll(m2o);
+            Map<TMention, String> m2b = Utils.loadMention2Base(par.getSentences());
+            mention2Base.putAll(m2b);
+
+            int sentIdxInPar = 0;
+            int mentionIdxInPar = 0;
+            for (TSentence sent : par.getSentences()) {
+
+                Map<String, TToken> tokenId2token = sent.getTokens().stream().collect(toMap(TToken::getId, Function.identity()));
+
+                Map<String, Set<TNamedEntity>> tokenId2namedEntities = Maps.newHashMap();
+                for (TNamedEntity namedEntity : sent.getNames()) {
+                    for (String childId : namedEntity.getChildIds()) {
+                        tokenId2namedEntities.putIfAbsent(childId, Sets.newHashSet());
+                        tokenId2namedEntities.get(childId).add(namedEntity);
+                    }
+                }
+
+                int mentionIdxInSent = 0;
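+                // Walk mentions in textual order, recording global, per-paragraph
+                // and per-sentence indices, head tokens, and named-entity overlap.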
+                for (TMention mention : sent.getMentions()) {
+                    mention2sent.put(mention, sent);
+                    mention2par.put(mention, par);
+                    mention2Index.put(mention, mentionIdx++);
+                    mention2indexInSent.put(mention, mentionIdxInSent++);
+                    mention2indexInPar.put(mention, mentionIdxInPar++);
+
+                    String firstHeadTokenId = mention.getHeadIds().iterator().next();
+                    mention2head.put(mention, tokenId2token.get(firstHeadTokenId));
+                    if (tokenId2namedEntities.containsKey(firstHeadTokenId))
+                        mentionsInNamedEntities.add(mention);
+                }
+                sent2Index.put(sent, sentIdx++);
+                sent2IndexInPar.put(sent, sentIdxInPar++);
+            }
+
+            par2Index.put(par, parIdx++);
+        }
+    }
+
+    public List<TMention> getMentions() {
+        return mentions;
+    }
+
+    public int getMentionIndexInChain(TMention mention) {
+        return coref2mentions.get(mention2coref.get(mention)).indexOf(mention);
+    }
+
+    public int getChainLength(TMention mention) {
+        return coref2mentions.get(mention2coref.get(mention)).size();
+    }
+
+    public String getSentenceLastTokenOrth(TSentence sent) {
+        return sent.getTokens().get(sent.getTokensSize() - 1).getOrth();
+    }
+
+    public String getMentionOrth(TMention mention) {
+        return mention2Orth.get(mention);
+    }
+
+    public String getMentionBase(TMention mention) {
+        return mention2Base.get(mention);
+    }
+
+    public int getMentionIndex(TMention mention) {
+        return mention2Index.get(mention);
+    }
+
+    public int getMentionIndexInSent(TMention mention) {
+        return mention2indexInSent.get(mention);
+    }
+
+    public int getMentionIndexInPar(TMention mention) {
+        return mention2indexInPar.get(mention);
+    }
+
+    public int getParIndex(TParagraph paragraph) {
+        return par2Index.get(paragraph);
+    }
+
+    public int getSentIndex(TSentence sent) {
+        return sent2Index.get(sent);
+    }
+
+    public int getSentIndexInPar(TSentence sent) {
+        return sent2IndexInPar.get(sent);
+    }
+
+    public TParagraph getMentionParagraph(TMention mention) {
+        return mention2par.get(mention);
+    }
+
+    public TSentence getMentionSentence(TMention mention) {
+        return mention2sent.get(mention);
+    }
+
+    public TMention getFirstChainMention(TMention mention) {
+        return mentionId2mention.get(mention2coref.get(mention).getMentionIdsIterator().next());
+    }
+
+    public TToken getMentionHeadToken(TMention mention) {
+        return mention2head.get(mention);
+    }
+
+    public boolean isMentionNamedEntity(TMention mention) {
+        return mentionsInNamedEntities.contains(mention);
+    }
+
+    public boolean isNested(TMention mention) {
+        // exclude the mention itself, otherwise every mention trivially matches
+        return mentions.stream().anyMatch(m -> m != mention && m.getChildIds().containsAll(mention.getChildIds()));
+    }
+
+    public boolean isNesting(TMention mention) {
+        // exclude the mention itself, otherwise every mention trivially matches
+        return mentions.stream().anyMatch(m -> m != mention && mention.getChildIds().containsAll(m.getChildIds()));
+    }
+
+    public Set<TCoreference> getClusters() {
+        return coref2mentions.keySet();
+    }
+
+    public Set<TMention> getCoreferentMentions(TMention tMention) {
+        return getMentionCluster(tMention).getMentionIds().stream().map(this.mentionId2mention::get).collect(Collectors.toSet());
+    }
+
+    public TCoreference getMentionCluster(TMention tMention) {
+        return this.mention2coref.get(tMention);
+    }
+}
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/Interpretation.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/Interpretation.java
new file mode 100644
index 0000000..11d0cda
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/Interpretation.java
@@ -0,0 +1,77 @@
+package pl.waw.ipipan.zil.summ.nicolas.features;
+
+import pl.waw.ipipan.zil.multiservice.thrift.types.TInterpretation;
+
+
+public class Interpretation {
+    private String ctag = "null";
+    private String casee = "null";
+    private String gender = "null";
+    private String number = "null";
+    private String person = "null";
+
+    public Interpretation(TInterpretation chosenInterpretation) {
+        ctag = chosenInterpretation.getCtag();
+        String[] split = chosenInterpretation.getMsd().split(":");
+        switch (ctag) {
+            case "ger":
+            case "subst":
+            case "pact":
+            case "ppas":
+            case "num":
+            case "numcol":
+            case "adj":
+                number = split[0];
+                casee = split[1];
+                gender = split[2];
+                break;
+            case "ppron12":
+            case "ppron3":
+                number = split[0];
+                casee = split[1];
+                gender = split[2];
+                person = split[3];
+                break;
+            case "siebie":
+                casee = split[0];
+                break;
+            case "fin":
+            case "bedzie":
+            case "aglt":
+            case "impt":
+                number = split[0];
+                person = split[1];
+                break;
+            case "praet":
+            case "winien":
+                number = split[0];
+                gender = split[1];
+                break;
+            case "prep":
+                casee = split[0];
+                break;
+            default:
+                break;
+        }
+    }
+
+    public String getCase() {
+        return casee;
+    }
+
+    public String getGender() {
+        return gender;
+    }
+
+    public String getNumber() {
+        return number;
+    }
+
+    public String getPerson() {
+        return person;
+    }
+
+    public String getCtag() {
+        return ctag;
+    }
+}
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java
new file mode 100644
index 0000000..f6ccab9
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java
@@ -0,0 +1,199 @@
+package pl.waw.ipipan.zil.summ.nicolas.mention;
+
+import com.google.common.collect.*;
+import pl.waw.ipipan.zil.multiservice.thrift.types.*;
+import pl.waw.ipipan.zil.summ.nicolas.features.FeatureExtractor;
+import pl.waw.ipipan.zil.summ.nicolas.features.FeatureHelper;
+import pl.waw.ipipan.zil.summ.nicolas.features.Interpretation;
+import weka.core.Attribute;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+
+public class MentionFeatureExtractor extends FeatureExtractor {
+
+    private final List<String> frequentBases = Lists.newArrayList();
+
+    public MentionFeatureExtractor() {
+
+        //coref
+        addNumericAttributeNormalized("chain_length");
+
+        // text characteristics
+        addNumericAttribute("text_char_count"); // referenced in calculateFeatures below
+        addNumericAttribute("text_token_count");
+        addNumericAttribute("text_sent_count");
+        addNumericAttribute("text_par_count");
+        addNumericAttribute("text_mention_count");
+        addNumericAttribute("text_cluster_count");
+
+        //mention characteristics
+        for (String prefix : Lists.newArrayList("mention", "chain_first_mention")) {
+            // mention characteristics
+            addNumericAttributeNormalized(prefix + "_index");
+            addNumericAttributeNormalized(prefix + "_index_in_sent");
+            addNumericAttributeNormalized(prefix + "_index_in_par");
+            addNumericAttributeNormalized(prefix + "_index_in_chain");
+            addBinaryAttribute(prefix + "_capitalized");
+            addBinaryAttribute(prefix + "_all_caps");
+            addNumericAttributeNormalized(prefix + "_char_count");
+            addNumericAttributeNormalized(prefix + "_token_count");
+            addBinaryAttribute(prefix + "_is_zero");
+            addBinaryAttribute(prefix + "_is_named");
+            addBinaryAttribute(prefix + "_is_pronoun");
+            addNominalAttribute(prefix + "_ctag", Lists.newArrayList("other", "null", "impt", "subst", "aglt", "ppron3", "ger", "praet", "fin", "num", "interp", "siebie", "brev", "interj", "ppron12", "adj", "burk", "pcon", "bedzie", "adv", "prep", "depr", "xxx", "winien", "conj", "qub", "adja", "ppas", "comp", "pact"));
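+            // The nominal attributes reserve "other"/"null" buckets so unseen or
+            // missing tag values can still be encoded (see addNominalAttributeValue).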
"ger", "praet", "fin", "num", "interp", "siebie", "brev", "interj", "ppron12", "adj", "burk", "pcon", "bedzie", "adv", "prep", "depr", "xxx", "winien", "conj", "qub", "adja", "ppas", "comp", "pact")); + addNominalAttribute(prefix + "_person", Lists.newArrayList("other", "null", "pri", "sec", "ter")); + addNominalAttribute(prefix + "_case", Lists.newArrayList("other", "null", "nom", "acc", "dat", "gen", "loc", "inst", "voc")); + addNominalAttribute(prefix + "_number", Lists.newArrayList("other", "null", "sg", "pl")); + addNominalAttribute(prefix + "_gender", Lists.newArrayList("other", "null", "f", "m1", "m2", "m3", "n")); + + // relation to other + addBinaryAttribute(prefix + "_is_nested"); + addBinaryAttribute(prefix + "_is_nesting"); + + // par characteristics + addNumericAttributeNormalized(prefix + "_par_idx"); + addNumericAttributeNormalized(prefix + "_par_token_count"); + addNumericAttributeNormalized(prefix + "_par_sent_count"); + + // sent characteristics + addNumericAttributeNormalized(prefix + "_sent_token_count"); + addNumericAttributeNormalized(prefix + "_sent_mention_count"); + addNumericAttributeNormalized(prefix + "_sent_idx"); + addNumericAttributeNormalized(prefix + "_sent_idx_in_par"); + addBinaryAttribute(prefix + "_sent_ends_with_dot"); + addBinaryAttribute(prefix + "_sent_ends_with_questionmark"); + + // frequent bases + loadFrequentBases(); + for (String base : frequentBases) { + addBinaryAttribute(prefix + "_" + encodeBase(base)); + } + } + + addNominalAttribute("score", Lists.newArrayList("bad", "good")); + fillSortedAttributes("score"); + } + + private String encodeBase(String base) { + return "base_equal_" + base.replaceAll(" ", "_").replaceAll("\"", "Q"); + } + + private void loadFrequentBases() { + try { + Stream<String> lines = Files.lines(new File("frequent_bases.txt").toPath()); + this.frequentBases.addAll(lines.map(String::trim).collect(Collectors.toList())); + } catch (IOException e) { + e.printStackTrace(); + } + } + + public Map<TMention, Map<Attribute, Double>> calculateFeatures(TText preprocessedText) { + Map<TMention, Map<Attribute, Double>> result = Maps.newHashMap(); + + FeatureHelper helper = new FeatureHelper(preprocessedText); + + addScoreFeature(result, helper.getMentions()); + + for (TMention mention : helper.getMentions()) { + Map<Attribute, Double> attribute2value = result.get(mention); + + //mention + addMentionAttributes(helper, mention, attribute2value, "mention"); + + //first chain mention + TMention firstChainMention = helper.getFirstChainMention(mention); + addMentionAttributes(helper, firstChainMention, attribute2value, "chain_first_mention"); + + //coref + attribute2value.put(getAttributeByName("chain_length"), (double) helper.getChainLength(mention)); + + //text + List<TParagraph> pars = preprocessedText.getParagraphs(); + List<TSentence> sents = pars.stream().flatMap(p -> p.getSentences().stream()).collect(Collectors.toList()); + List<TToken> tokens = sents.stream().flatMap(s -> s.getTokens().stream()).collect(Collectors.toList()); + attribute2value.put(getAttributeByName("text_char_count"), tokens.stream().mapToDouble(t -> t.getOrth().length()).sum()); + attribute2value.put(getAttributeByName("text_token_count"), (double) tokens.size()); + attribute2value.put(getAttributeByName("text_sent_count"), (double) sents.size()); + attribute2value.put(getAttributeByName("text_par_count"), (double) pars.size()); + attribute2value.put(getAttributeByName("text_mention_count"), (double) helper.getMentions().size()); + 
attribute2value.put(getAttributeByName("text_cluster_count"), (double) helper.getClusters().size()); + + assert (attribute2value.size() == getAttributesList().size()); + } + addNormalizedAttributeValues(result); + + return result; + } + + private void addMentionAttributes(FeatureHelper helper, TMention mention, Map<Attribute, Double> attribute2value, String attributePrefix) { + // mention characteristics + attribute2value.put(getAttributeByName(attributePrefix + "_index"), (double) helper.getMentionIndex(mention)); + attribute2value.put(getAttributeByName(attributePrefix + "_index_in_sent"), (double) helper.getMentionIndexInSent(mention)); + attribute2value.put(getAttributeByName(attributePrefix + "_index_in_par"), (double) helper.getMentionIndexInPar(mention)); + attribute2value.put(getAttributeByName(attributePrefix + "_index_in_chain"), (double) helper.getMentionIndexInChain(mention)); + attribute2value.put(getAttributeByName(attributePrefix + "_token_count"), (double) mention.getChildIdsSize()); + attribute2value.put(getAttributeByName(attributePrefix + "_is_zero"), toBinary(mention.isZeroSubject())); + attribute2value.put(getAttributeByName(attributePrefix + "_is_pronoun"), toBinary(helper.getMentionHeadToken(mention).getChosenInterpretation().getCtag().matches("ppron.*"))); + attribute2value.put(getAttributeByName(attributePrefix + "_is_named"), toBinary(helper.isMentionNamedEntity(mention))); + + Interpretation interp = new Interpretation(helper.getMentionHeadToken(mention).getChosenInterpretation()); + addNominalAttributeValue(interp.getCtag(), attribute2value, attributePrefix + "_ctag"); + addNominalAttributeValue(interp.getPerson(), attribute2value, attributePrefix + "_person"); + addNominalAttributeValue(interp.getNumber(), attribute2value, attributePrefix + "_number"); + addNominalAttributeValue(interp.getGender(), attribute2value, attributePrefix + "_gender"); + addNominalAttributeValue(interp.getCase(), attribute2value, attributePrefix + "_case"); + + // relation to other mentions + attribute2value.put(getAttributeByName(attributePrefix + "_is_nested"), toBinary(helper.isNested(mention))); + attribute2value.put(getAttributeByName(attributePrefix + "_is_nesting"), toBinary(helper.isNesting(mention))); + + String orth = helper.getMentionOrth(mention); + attribute2value.put(getAttributeByName(attributePrefix + "_capitalized"), toBinary(orth.length() != 0 && orth.substring(0, 1).toUpperCase().equals(orth.substring(0, 1)))); + attribute2value.put(getAttributeByName(attributePrefix + "_all_caps"), toBinary(orth.toUpperCase().equals(orth))); + attribute2value.put(getAttributeByName(attributePrefix + "_char_count"), (double) orth.length()); + + // par characteristics + TParagraph mentionParagraph = helper.getMentionParagraph(mention); + attribute2value.put(getAttributeByName(attributePrefix + "_par_idx"), (double) helper.getParIndex(mentionParagraph)); + attribute2value.put(getAttributeByName(attributePrefix + "_par_token_count"), mentionParagraph.getSentences().stream().map(s -> s.getTokens().size()).mapToDouble(s -> s).sum()); + attribute2value.put(getAttributeByName(attributePrefix + "_par_sent_count"), (double) mentionParagraph.getSentences().size()); + + // sent characteristics + TSentence mentionSentence = helper.getMentionSentence(mention); + attribute2value.put(getAttributeByName(attributePrefix + "_sent_token_count"), (double) mentionSentence.getTokensSize()); + attribute2value.put(getAttributeByName(attributePrefix + "_sent_mention_count"), (double) 
+        attribute2value.put(getAttributeByName(attributePrefix + "_sent_idx"), (double) helper.getSentIndex(mentionSentence));
+        attribute2value.put(getAttributeByName(attributePrefix + "_sent_idx_in_par"), (double) helper.getSentIndexInPar(mentionSentence));
+        attribute2value.put(getAttributeByName(attributePrefix + "_sent_ends_with_dot"), toBinary(helper.getSentenceLastTokenOrth(mentionSentence).equals(".")));
+        attribute2value.put(getAttributeByName(attributePrefix + "_sent_ends_with_questionmark"), toBinary(helper.getSentenceLastTokenOrth(mentionSentence).equals("?")));
+
+        // frequent bases
+        String mentionBase = helper.getMentionBase(mention);
+        for (String base : frequentBases) {
+            attribute2value.put(getAttributeByName(attributePrefix + "_" + encodeBase(base)), toBinary(mentionBase.equals(base)));
+        }
+    }
+
+    private void addNominalAttributeValue(String value, Map<Attribute, Double> attribute2value, String attributeName) {
+        Attribute att = getAttributeByName(attributeName);
+        int index = att.indexOfValue(value);
+        if (index == -1)
+            LOG.warn(value + " not found for attribute " + attributeName);
+        attribute2value.put(att, (double) (index == -1 ? att.indexOfValue("other") : index));
+    }
+
+
+    private void addScoreFeature(Map<TMention, Map<Attribute, Double>> result, List<TMention> mentions) {
+        for (TMention m : mentions) {
+            Map<Attribute, Double> map = Maps.newHashMap();
+            map.put(getAttributeByName("score"), weka.core.Utils.missingValue());
+            result.put(m, map);
+        }
+    }
+
+}
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionModel.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionModel.java
new file mode 100644
index 0000000..7e85be6
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionModel.java
@@ -0,0 +1,37 @@
+package pl.waw.ipipan.zil.summ.nicolas.mention;
+
+import com.google.common.collect.Sets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TMention;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TText;
+import pl.waw.ipipan.zil.summ.nicolas.Utils;
+import weka.classifiers.Classifier;
+import weka.core.Instance;
+import weka.core.Instances;
+
+import java.util.Map;
+import java.util.Set;
+
+public class MentionModel {
+
+    private static final Logger LOG = LoggerFactory.getLogger(MentionModel.class);
+
+    public static Set<TMention> detectGoodMentions(Classifier classifier, MentionFeatureExtractor featureExtractor, TText text) throws Exception {
+        Set<TMention> goodMentions = Sets.newHashSet();
+
+        Instances instances = Utils.createNewInstances(featureExtractor.getAttributesList());
+        Map<TMention, Instance> mention2instance = Utils.extractInstancesFromMentions(text, featureExtractor);
+        for (Map.Entry<TMention, Instance> entry : mention2instance.entrySet()) {
+            Instance instance = entry.getValue();
+            instance.setDataset(instances);
+            instance.setClassMissing();
+            boolean good = classifier.classifyInstance(instance) > 0.5;
+            if (good)
+                goodMentions.add(entry.getKey());
+        }
+        LOG.info("\t" + goodMentions.size() + "\t" + mention2instance.size());
+        return goodMentions;
+    }
+
+}
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionScorer.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionScorer.java
new file mode 100644
index 0000000..a16edec
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionScorer.java
@@ -0,0 +1,59 @@
+package pl.waw.ipipan.zil.summ.nicolas.mention;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Multiset;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TMention;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TText;
+import pl.waw.ipipan.zil.summ.nicolas.Utils;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+public class MentionScorer {
+
+
+    public Map<TMention, Double> calculateMentionScores(String optimalSummary, TText text) {
+        Multiset<String> tokenCounts = HashMultiset.create(Utils.tokenize(optimalSummary.toLowerCase()));
+
+        List<TSentence> sentences = text.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(Collectors.toList());
+        Map<TMention, String> mention2Orth = Utils.loadMention2Orth(sentences);
+
+        return booleanTokenIntersection(mention2Orth, tokenCounts);
+    }
+
+    private static Map<TMention, Double> booleanTokenIntersection(Map<TMention, String> mention2Orth, Multiset<String> tokenCounts) {
+        Map<TMention, Double> mention2score = Maps.newHashMap();
+        for (Map.Entry<TMention, String> entry : mention2Orth.entrySet()) {
+            TMention mention = entry.getKey();
+            String mentionOrth = mention2Orth.get(mention);
+            for (String token : Utils.tokenize(mentionOrth)) {
+                if (tokenCounts.contains(token.toLowerCase())) {
+                    mention2score.put(mention, 1.0);
+                    break;
+                }
+            }
+            mention2score.putIfAbsent(mention, 0.0);
+        }
+        return mention2score;
+    }
+
+    private static Map<TMention, Double> booleanTokenInclusion(Map<TMention, String> mention2Orth, Multiset<String> tokenCounts) {
+        Map<TMention, Double> mention2score = Maps.newHashMap();
+        for (Map.Entry<TMention, String> entry : mention2Orth.entrySet()) {
+            TMention mention = entry.getKey();
+            String mentionOrth = mention2Orth.get(mention);
+            int present = 0;
+            for (String token : Utils.tokenize(mentionOrth)) {
+                if (tokenCounts.contains(token.toLowerCase())) {
+                    present++;
+                }
+            }
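+            // Score 1.0 iff at least half of the mention's tokens occur in the
+            // optimal summary; this stricter variant is currently unused in favour
+            // of booleanTokenIntersection above.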
+            mention2score.putIfAbsent(mention, ((present * 2) >= Utils.tokenize(mentionOrth).size()) ? 1.0 : 0.0);
+        }
+        return mention2score;
+    }
+}
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/PrepareTrainingData.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/PrepareTrainingData.java
new file mode 100644
index 0000000..7c84f89
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/PrepareTrainingData.java
@@ -0,0 +1,78 @@
+package pl.waw.ipipan.zil.summ.nicolas.mention;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Maps;
+import com.google.common.io.Files;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TMention;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TText;
+import pl.waw.ipipan.zil.summ.nicolas.Constants;
+import pl.waw.ipipan.zil.summ.nicolas.Utils;
+import weka.core.Instance;
+import weka.core.Instances;
+import weka.core.converters.ArffSaver;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Map;
+
+
+public class PrepareTrainingData {
+
+    private static final Logger LOG = LogManager.getLogger(PrepareTrainingData.class);
+
+    public static final String PREPROCESSED_FULL_TEXTS_DIR_PATH = "src/main/resources/preprocessed_full_texts/dev";
+    public static final String OPTIMAL_SUMMARIES_DIR_PATH = "src/main/resources/optimal_summaries/dev";
+
+    public static void main(String[] args) throws IOException {
+
+        Map<String, TText> id2preprocessedText = Utils.loadPreprocessedTexts(PREPROCESSED_FULL_TEXTS_DIR_PATH);
+        Map<String, String> id2optimalSummary = loadOptimalSummaries();
+
+        MentionScorer mentionScorer = new MentionScorer();
+        MentionFeatureExtractor featureExtractor = new MentionFeatureExtractor();
+
+        Instances instances = Utils.createNewInstances(featureExtractor.getAttributesList());
+
+        int i = 1;
+        for (String textId : id2preprocessedText.keySet()) {
+            LOG.info(i++ + "/" + id2preprocessedText.size());
+
+            TText preprocessedText = id2preprocessedText.get(textId);
+            String optimalSummary = id2optimalSummary.get(textId);
+            if (optimalSummary == null)
+                continue;
+            Map<TMention, Double> mention2score = mentionScorer.calculateMentionScores(optimalSummary, preprocessedText);
+
+            Map<TMention, Instance> mention2instance = Utils.extractInstancesFromMentions(preprocessedText, featureExtractor);
+            for (Map.Entry<TMention, Instance> entry : mention2instance.entrySet()) {
+                TMention mention = entry.getKey();
+                Instance instance = entry.getValue();
+                instance.setDataset(instances);
+                instance.setClassValue(mention2score.get(mention));
+                instances.add(instance);
+            }
+        }
+        saveInstancesToFile(instances);
+    }
+
+    private static void saveInstancesToFile(Instances instances) throws IOException {
+        ArffSaver saver = new ArffSaver();
+        saver.setInstances(instances);
+        saver.setFile(new File(Constants.MENTIONS_DATASET_PATH));
+        saver.writeBatch();
+    }
+
+    private static Map<String, String> loadOptimalSummaries() throws IOException {
+        Map<String, String> id2optimalSummary = Maps.newHashMap();
+        for (File optimalSummaryFile : new File(OPTIMAL_SUMMARIES_DIR_PATH).listFiles()) {
+            String optimalSummary = Files.toString(optimalSummaryFile, Charsets.UTF_8);
+            id2optimalSummary.put(optimalSummaryFile.getName().split("_")[0], optimalSummary);
+        }
+        LOG.info(id2optimalSummary.size() + " optimal summaries found.");
+        return id2optimalSummary;
+    }
+
+
+}
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java
new file mode 100644
index 0000000..e26b543
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java
@@ -0,0 +1,47 @@
+package pl.waw.ipipan.zil.summ.nicolas.mention;
+
+import org.apache.commons.lang3.time.StopWatch;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import pl.waw.ipipan.zil.summ.nicolas.Constants;
+import weka.classifiers.Classifier;
+import weka.core.Instances;
+import weka.core.converters.ArffLoader;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.ObjectOutputStream;
+
+
+public class TrainModel {
+    private static final Logger LOG = LoggerFactory.getLogger(TrainModel.class);
+
+    public static void main(String[] args) throws Exception {
+
+        ArffLoader loader = new ArffLoader();
+        loader.setFile(new File(Constants.MENTIONS_DATASET_PATH));
+        Instances instances = loader.getDataSet();
+        instances.setClassIndex(0);
+        LOG.info(instances.size() + " instances loaded.");
+        LOG.info(instances.numAttributes() + " attributes for each instance.");
+
+        StopWatch watch = new StopWatch();
+        watch.start();
+
+        Classifier classifier = Constants.getClassifier();
+
+        LOG.info("Building classifier...");
+        classifier.buildClassifier(instances);
+        LOG.info("...done.");
+
+        try (ObjectOutputStream oos = new ObjectOutputStream(
+                new FileOutputStream(Constants.MENTIONS_MODEL_PATH))) {
+            oos.writeObject(classifier);
+        }
+
+        watch.stop();
+        LOG.info("Elapsed time: " + watch);
+
+        LOG.info(classifier.toString());
+    }
+}
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java
new file mode 100644
index 0000000..db2147d
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java
@@ -0,0 +1,44 @@
+package pl.waw.ipipan.zil.summ.nicolas.mention.test;
+
+import org.apache.commons.lang3.time.StopWatch;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import pl.waw.ipipan.zil.summ.nicolas.Constants;
+import weka.classifiers.Classifier;
+import weka.classifiers.evaluation.Evaluation;
+import weka.core.Instances;
+import weka.core.converters.ArffLoader;
+
+import java.io.File;
+import java.util.Random;
+
+
+public class Crossvalidate {
+
+    private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class);
+
+    public static void main(String[] args) throws Exception {
+
+        ArffLoader loader = new ArffLoader();
+        loader.setFile(new File(Constants.MENTIONS_DATASET_PATH));
+        Instances instances = loader.getDataSet();
+        instances.setClassIndex(0);
+        LOG.info(instances.size() + " instances loaded.");
+        LOG.info(instances.numAttributes() + " attributes for each instance.");
+
+//        while (instances.size() > 10000)
+//            instances.remove(instances.size() - 1);
+
+        StopWatch watch = new StopWatch();
+        watch.start();
+
+        Classifier tree = Constants.getClassifier();
+
+        Evaluation eval = new Evaluation(instances);
+        eval.crossValidateModel(tree, instances, 10, new Random(1));
+        LOG.info(eval.toSummaryString());
+
+        watch.stop();
+        LOG.info("Elapsed time: " + watch);
+    }
+}
diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java
new file mode 100644
index 0000000..0fc9685
--- /dev/null
+++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java
b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java @@ -0,0 +1,54 @@ +package pl.waw.ipipan.zil.summ.nicolas.mention.test; + +import org.apache.commons.lang3.time.StopWatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import pl.waw.ipipan.zil.summ.nicolas.Constants; +import weka.classifiers.Classifier; +import weka.classifiers.evaluation.Evaluation; +import weka.core.Instances; +import weka.core.converters.ArffLoader; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.ObjectInputStream; + +/** + * Evaluates the saved mention model on the full mention dataset. Note that this measures fit on the data the model was trained on, not generalization; use Crossvalidate for an unbiased estimate. + */ +public class Validate { + private static final Logger LOG = LoggerFactory.getLogger(Validate.class); + + public static void main(String[] args) throws Exception { + + ArffLoader loader = new ArffLoader(); + loader.setFile(new File(Constants.MENTIONS_DATASET_PATH)); + Instances instances = loader.getDataSet(); + instances.setClassIndex(0); + LOG.info(instances.size() + " instances loaded."); + LOG.info(instances.numAttributes() + " attributes for each instance."); + + Classifier classifier = loadClassifier(); + + StopWatch watch = new StopWatch(); + watch.start(); + + Evaluation eval = new Evaluation(instances); + eval.evaluateModel(classifier, instances); + + LOG.info(eval.toSummaryString()); + + watch.stop(); + LOG.info("Elapsed time: " + watch); + } + + private static Classifier loadClassifier() throws IOException, ClassNotFoundException { + LOG.info("Loading classifier..."); + try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(Constants.MENTIONS_MODEL_PATH))) { + Classifier classifier = (Classifier) ois.readObject(); + LOG.info("Done. " + classifier.toString()); + return classifier; + } + } +} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/PrepareTrainingData.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/PrepareTrainingData.java new file mode 100644 index 0000000..b7e2219 --- /dev/null +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/PrepareTrainingData.java @@ -0,0 +1,91 @@ +package pl.waw.ipipan.zil.summ.nicolas.sentence; + +import com.google.common.base.Charsets; +import com.google.common.collect.Maps; +import com.google.common.io.Files; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; +import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; +import pl.waw.ipipan.zil.summ.nicolas.Constants; +import pl.waw.ipipan.zil.summ.nicolas.Utils; +import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor; +import pl.waw.ipipan.zil.summ.nicolas.mention.MentionModel; +import weka.classifiers.Classifier; +import weka.core.Instance; +import weka.core.Instances; +import weka.core.converters.ArffSaver; + +import java.io.File; +import java.io.IOException; +import java.util.Map; +import java.util.Set; + + +public class PrepareTrainingData { + + private static final Logger LOG = LogManager.getLogger(PrepareTrainingData.class); + + private static final String PREPROCESSED_FULL_TEXTS_DIR_PATH = "src/main/resources/preprocessed_full_texts/dev"; + private static final String OPTIMAL_SUMMARIES_DIR_PATH = "src/main/resources/optimal_summaries/dev"; + + public static void main(String[] args) throws Exception { + + Map<String, TText> id2preprocessedText =
Utils.loadPreprocessedTexts(PREPROCESSED_FULL_TEXTS_DIR_PATH); + Map<String, String> id2optimalSummary = loadOptimalSummaries(); + + SentenceScorer sentenceScorer = new SentenceScorer(); + SentenceFeatureExtractor featureExtractor = new SentenceFeatureExtractor(); + + Instances instances = Utils.createNewInstances(featureExtractor.getAttributesList()); + + Classifier classifier = Utils.loadClassifier(Constants.MENTIONS_MODEL_PATH); + MentionFeatureExtractor mentionFeatureExtractor = new MentionFeatureExtractor(); + + int i = 1; + for (String textId : id2preprocessedText.keySet()) { + LOG.info(i++ + "/" + id2preprocessedText.size()); + + TText preprocessedText = id2preprocessedText.get(textId); + String optimalSummary = id2optimalSummary.get(textId); + if (optimalSummary == null) + continue; + Map<TSentence, Double> sentence2score = sentenceScorer.calculateSentenceScores(optimalSummary, preprocessedText); + + Set<TMention> goodMentions + = MentionModel.detectGoodMentions(classifier, mentionFeatureExtractor, preprocessedText); + // alternatively, gold-standard mentions could be used here instead of model predictions: Utils.loadGoldGoodMentions(textId, preprocessedText, true) + + Map<TSentence, Instance> sentence2instance = Utils.extractInstancesFromSentences(preprocessedText, featureExtractor, goodMentions); + for (Map.Entry<TSentence, Instance> entry : sentence2instance.entrySet()) { + TSentence sentence = entry.getKey(); + Instance instance = entry.getValue(); + instance.setDataset(instances); + instance.setClassValue(sentence2score.get(sentence)); + instances.add(instance); + } + } + saveInstancesToFile(instances); + } + + private static void saveInstancesToFile(Instances instances) throws IOException { + ArffSaver saver = new ArffSaver(); + saver.setInstances(instances); + saver.setFile(new File(Constants.SENTENCES_DATASET_PATH)); + saver.writeBatch(); + } + + private static Map<String, String> loadOptimalSummaries() throws IOException { + Map<String, String> id2optimalSummary = Maps.newHashMap(); + for (File optimalSummaryFile : new File(OPTIMAL_SUMMARIES_DIR_PATH).listFiles()) { + String optimalSummary = Files.toString(optimalSummaryFile, Charsets.UTF_8); + id2optimalSummary.put(optimalSummaryFile.getName().split("_")[0], optimalSummary); + } + LOG.info(id2optimalSummary.size() + " optimal summaries found."); + return id2optimalSummary; + } + + +} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java new file mode 100644 index 0000000..ce045af --- /dev/null +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java @@ -0,0 +1,103 @@ +package pl.waw.ipipan.zil.summ.nicolas.sentence; + +import com.google.common.collect.Maps; +import pl.waw.ipipan.zil.multiservice.thrift.types.*; +import pl.waw.ipipan.zil.summ.nicolas.features.FeatureExtractor; +import pl.waw.ipipan.zil.summ.nicolas.features.FeatureHelper; +import weka.core.Attribute; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +public class SentenceFeatureExtractor extends FeatureExtractor { + + public SentenceFeatureExtractor() { + + addNumericAttributeNormalized("sent_mention_cluster_count"); + addNumericAttributeNormalized("sent_good_mention_cluster_count"); + addNumericAttributeNormalized("sent_good_mention_cluster_good_count"); + addNumericAttributeNormalized("sent_cluster_count"); +
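// "good" counts below refer to the goodMentions set passed to calculateFeatures, i.e. mentions the mention model (or the gold standard, during experiments) selects as summary-worthy +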
addNumericAttributeNormalized("sent_good_cluster_count"); + addNumericAttributeNormalized("sent_mention_count"); + addNumericAttributeNormalized("sent_good_mention_count"); + + addNumericAttributeNormalized("sent_token_length"); + addNumericAttributeNormalized("sent_idx"); + addNumericAttributeNormalized("sent_idx_in_par"); + addBinaryAttribute("sent_ends_with_dot"); + addBinaryAttribute("sent_ends_with_questionmark"); + + addNumericAttributeNormalized("par_idx"); + addNumericAttributeNormalized("par_token_count"); + addNumericAttributeNormalized("par_sent_count"); + + addNumericAttribute("text_token_count"); + addNumericAttribute("text_sent_count"); + addNumericAttribute("text_par_count"); + addNumericAttribute("text_mention_count"); + addNumericAttribute("text_cluster_count"); + + addNumericAttribute("score"); + fillSortedAttributes("score"); + } + + public Map<TSentence, Map<Attribute, Double>> calculateFeatures(TText preprocessedText, Set<TMention> goodMentions) { + + int sentenceIdx = 0; + int parIdx = 0; + + FeatureHelper helper = new FeatureHelper(preprocessedText); + List<TParagraph> pars = preprocessedText.getParagraphs(); + List<TSentence> sents = pars.stream().flatMap(p -> p.getSentences().stream()).collect(Collectors.toList()); + List<TToken> tokens = sents.stream().flatMap(s -> s.getTokens().stream()).collect(Collectors.toList()); + + Map<TSentence, Map<Attribute, Double>> sentence2features = Maps.newLinkedHashMap(); + for (TParagraph paragraph : preprocessedText.getParagraphs()) { + int sentenceIdxInPar = 0; + for (TSentence sentence : paragraph.getSentences()) { + Map<Attribute, Double> feature2value = Maps.newHashMap(); + + feature2value.put(getAttributeByName("sent_mention_cluster_count"), sentence.getMentions().stream().mapToDouble(helper::getChainLength).sum()); + feature2value.put(getAttributeByName("sent_good_mention_cluster_count"), sentence.getMentions().stream().filter(goodMentions::contains).mapToDouble(helper::getChainLength).sum()); + feature2value.put(getAttributeByName("sent_good_mention_cluster_good_count"), (double) sentence.getMentions().stream().filter(goodMentions::contains).flatMap(m -> helper.getCoreferentMentions(m).stream()).filter(goodMentions::contains).count()); + feature2value.put(getAttributeByName("sent_cluster_count"), (double) sentence.getMentions().stream().map(helper::getMentionCluster).collect(Collectors.toSet()).size()); + feature2value.put(getAttributeByName("sent_good_cluster_count"), (double) sentence.getMentions().stream().filter(goodMentions::contains).map(helper::getMentionCluster).collect(Collectors.toSet()).size()); + feature2value.put(getAttributeByName("sent_mention_count"), (double) sentence.getMentions().size()); + feature2value.put(getAttributeByName("sent_good_mention_count"), (double) sentence.getMentions().stream().filter(goodMentions::contains).count()); + + feature2value.put(getAttributeByName("sent_token_length"), (double) sentence.getTokens().size()); + feature2value.put(getAttributeByName("sent_idx_in_par"), (double) sentenceIdxInPar); + feature2value.put(getAttributeByName("sent_idx"), (double) sentenceIdx); + feature2value.put(getAttributeByName("sent_ends_with_dot"), toBinary(helper.getSentenceLastTokenOrth(sentence).equals("."))); + feature2value.put(getAttributeByName("sent_ends_with_questionmark"), toBinary(helper.getSentenceLastTokenOrth(sentence).equals("?"))); + + feature2value.put(getAttributeByName("par_idx"), (double) parIdx); + feature2value.put(getAttributeByName("par_token_count"), 
paragraph.getSentences().stream().map(s -> s.getTokens().size()).mapToDouble(s -> s).sum()); + feature2value.put(getAttributeByName("par_sent_count"), (double) paragraph.getSentences().size()); + + feature2value.put(getAttributeByName("text_char_count"), tokens.stream().mapToDouble(t -> t.getOrth().length()).sum()); + feature2value.put(getAttributeByName("text_token_count"), (double) tokens.size()); + feature2value.put(getAttributeByName("text_sent_count"), (double) sents.size()); + feature2value.put(getAttributeByName("text_par_count"), (double) pars.size()); + feature2value.put(getAttributeByName("text_mention_count"), (double) helper.getMentions().size()); + feature2value.put(getAttributeByName("text_cluster_count"), (double) helper.getClusters().size()); + + feature2value.put(getAttributeByName("score"), weka.core.Utils.missingValue()); + + feature2value.remove(null); + assert (feature2value.size() == getAttributesList().size()); + + sentence2features.put(sentence, feature2value); + + sentenceIdx++; + sentenceIdxInPar++; + } + parIdx++; + } + addNormalizedAttributeValues(sentence2features); + + return sentence2features; + } +} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceScorer.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceScorer.java new file mode 100644 index 0000000..f96ea34 --- /dev/null +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceScorer.java @@ -0,0 +1,32 @@ +package pl.waw.ipipan.zil.summ.nicolas.sentence; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Maps; +import com.google.common.collect.Multiset; +import pl.waw.ipipan.zil.multiservice.thrift.types.TParagraph; +import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; +import pl.waw.ipipan.zil.summ.nicolas.Utils; + +import java.util.List; +import java.util.Map; + +public class SentenceScorer { + public Map<TSentence, Double> calculateSentenceScores(String optimalSummary, TText preprocessedText) { + Multiset<String> tokenCounts = HashMultiset.create(Utils.tokenize(optimalSummary.toLowerCase())); + + Map<TSentence, Double> sentence2score = Maps.newHashMap(); + for (TParagraph paragraph : preprocessedText.getParagraphs()) + for (TSentence sentence : paragraph.getSentences()) { + double score = 0.0; + + String orth = Utils.loadSentence2Orth(sentence); + List<String> tokens = Utils.tokenize(orth); + for (String token : tokens) { + score += tokenCounts.contains(token.toLowerCase()) ? 
1.0 : 0.0; + } + sentence2score.put(sentence, score / tokens.size()); + } + return sentence2score; + } +} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java new file mode 100644 index 0000000..71a4dec --- /dev/null +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java @@ -0,0 +1,47 @@ +package pl.waw.ipipan.zil.summ.nicolas.sentence; + +import org.apache.commons.lang3.time.StopWatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import pl.waw.ipipan.zil.summ.nicolas.Constants; +import weka.classifiers.Classifier; +import weka.core.Instances; +import weka.core.converters.ArffLoader; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.ObjectOutputStream; + + +public class TrainModel { + private static final Logger LOG = LoggerFactory.getLogger(TrainModel.class); + + public static void main(String[] args) throws Exception { + + ArffLoader loader = new ArffLoader(); + loader.setFile(new File(Constants.SENTENCES_DATASET_PATH)); + Instances instances = loader.getDataSet(); + instances.setClassIndex(0); + LOG.info(instances.size() + " instances loaded."); + LOG.info(instances.numAttributes() + " attributes for each instance."); + + StopWatch watch = new StopWatch(); + watch.start(); + + Classifier classifier = Constants.getSentencesClassifier(); + + LOG.info("Building classifier..."); + classifier.buildClassifier(instances); + LOG.info("...done."); + + try (ObjectOutputStream oos = new ObjectOutputStream( + new FileOutputStream(Constants.SENTENCES_MODEL_PATH))) { + oos.writeObject(classifier); + } + + watch.stop(); + LOG.info("Elapsed time: " + watch); + + LOG.info(classifier.toString()); + } +} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java new file mode 100644 index 0000000..a46f64e --- /dev/null +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java @@ -0,0 +1,41 @@ +package pl.waw.ipipan.zil.summ.nicolas.sentence.test; + +import org.apache.commons.lang3.time.StopWatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import pl.waw.ipipan.zil.summ.nicolas.Constants; +import weka.classifiers.Classifier; +import weka.classifiers.evaluation.Evaluation; +import weka.core.Instances; +import weka.core.converters.ArffLoader; + +import java.io.File; +import java.util.Random; + + +public class Crossvalidate { + + private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); + + public static void main(String[] args) throws Exception { + + ArffLoader loader = new ArffLoader(); + loader.setFile(new File(Constants.SENTENCES_DATASET_PATH)); + Instances instances = loader.getDataSet(); + instances.setClassIndex(0); + LOG.info(instances.size() + " instances loaded."); + LOG.info(instances.numAttributes() + " attributes for each instance."); + + StopWatch watch = new StopWatch(); + watch.start(); + + Classifier tree = Constants.getSentencesClassifier(); + + Evaluation eval = new Evaluation(instances); + eval.crossValidateModel(tree, instances, 10, new Random(1)); + LOG.info(eval.toSummaryString()); + + watch.stop(); + LOG.info("Elapsed time: " + watch); + } +} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/Zero.java 
b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/Zero.java new file mode 100644 index 0000000..cb3c5ef --- /dev/null +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/Zero.java @@ -0,0 +1,128 @@ +package pl.waw.ipipan.zil.summ.nicolas.zero; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVPrinter; +import org.apache.commons.csv.QuoteMode; +import org.apache.commons.io.IOUtils; +import pl.waw.ipipan.zil.multiservice.thrift.types.*; +import pl.waw.ipipan.zil.summ.nicolas.Utils; + +import java.io.File; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Draft analysis of candidate zero-subject contexts: counts mentions with a nominative head whose coreference cluster also contains a nominative mention in the directly preceding summary sentence, and dumps the candidates to a TSV file for manual inspection. + */ +public class Zero { + + private static final String IDS_PATH = "summaries_dev"; + private static final String THRIFTED_PATH = "src/main/resources/preprocessed_full_texts/dev/"; + + public static void main(String[] args) throws IOException { + + Map<String, TText> id2preprocessedText = Utils.loadPreprocessedTexts(THRIFTED_PATH); + Map<String, List<String>> id2sentIds = loadSentenceIds(IDS_PATH); + + int mentionCount = 0; + int mentionInNom = 0; + int mentionInNomSequential = 0; + + List<List<Object>> rows = Lists.newArrayList(); + for (Map.Entry<String, TText> entry : id2preprocessedText.entrySet()) { + String textId = entry.getKey(); + + TText text = entry.getValue(); + List<String> sentenceIds = id2sentIds.get(textId); + if (sentenceIds == null) + continue; // defensive: no sentence id file for this text + + Map<String, Set<String>> mentionId2Cluster = Maps.newHashMap(); + for (TCoreference coreference : text.getCoreferences()) { + for (String mentionId : coreference.getMentionIds()) { + mentionId2Cluster.put(mentionId, Sets.newHashSet(coreference.getMentionIds())); + } + } + + Set<String> prevSentenceNominativeMentionIds = Sets.newHashSet(); + TSentence prevSentence = null; + for (TParagraph p : text.getParagraphs()) { + Map<TMention, String> tMentionStringMap = Utils.loadMention2Orth(p.getSentences()); + + for (TSentence sentence : p.getSentences()) { + if (!sentenceIds.contains(sentence.getId())) + continue; + Set<String> currentSentenceNominativeMentionIds = Sets.newHashSet(); + + Map<String, TToken> tokenId2Token = Maps.newHashMap(); + for (TToken t : sentence.getTokens()) + tokenId2Token.put(t.getId(), t); + + for (TMention mention : sentence.getMentions()) { + mentionCount++; + + for (String tokenId : mention.getHeadIds()) { + TInterpretation interp = tokenId2Token.get(tokenId).getChosenInterpretation(); + if (isInNominative(interp)) { + mentionInNom++; + + currentSentenceNominativeMentionIds.add(mention.getId()); + // getOrDefault guards against mentions that belong to no coreference chain + if (mentionId2Cluster.getOrDefault(mention.getId(), Sets.newHashSet()).stream().anyMatch(prevSentenceNominativeMentionIds::contains)) { + mentionInNomSequential++; + System.out.println(tMentionStringMap.get(mention) + + "\n\t" + Utils.loadSentence2Orth(prevSentence) + + "\n\t" + Utils.loadSentence2Orth(sentence)); + + List<Object> row = Lists.newArrayList(); + row.add("C"); + row.add(textId); + row.add(tMentionStringMap.get(mention)); + row.add(Utils.loadSentence2Orth(prevSentence)); + row.add(Utils.loadSentence2Orth(sentence)); + rows.add(row); + } + break; + } + } + } + + prevSentence = sentence; + prevSentenceNominativeMentionIds = currentSentenceNominativeMentionIds; + } + } + } + + System.out.println(mentionCount +
" mentions"); + System.out.println(mentionInNom + " mention in nom"); + System.out.println(mentionInNomSequential + " mention in nom with previous in nom"); + + try (CSVPrinter csvPrinter = new CSVPrinter(new FileWriter("zeros.tsv"), CSVFormat.DEFAULT.withDelimiter('\t').withEscape('\\').withQuoteMode(QuoteMode.NONE).withQuote('"'))) { + for (List<Object> row : rows) { + csvPrinter.printRecord(row); + } + } + + } + + private static boolean isInNominative(TInterpretation interp) { + return interp.getCtag().equals("subst") && Arrays.stream(interp.getMsd().split(":")).anyMatch(t -> t.equals("nom")); + } + + private static Map<String, List<String>> loadSentenceIds(String idsPath) throws IOException { + Map<String, List<String>> result = Maps.newHashMap(); + for (File f : new File(idsPath).listFiles()) { + String id = f.getName().split("_")[0]; + List<String> sentenceIds = IOUtils.readLines(new FileReader(f)); + result.put(id, sentenceIds); + } + return result; + } +} diff --git a/nicolas-model/pom.xml b/nicolas-model/pom.xml new file mode 100644 index 0000000..6d7c3ac --- /dev/null +++ b/nicolas-model/pom.xml @@ -0,0 +1,14 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <artifactId>nicolas-container</artifactId> + <groupId>pl.waw.ipipan.zil.summ</groupId> + <version>1.0-SNAPSHOT</version> + </parent> + + <artifactId>nicolas-model</artifactId> + +</project> \ No newline at end of file diff --git a/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/frequent_bases.txt b/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/frequent_bases.txt new file mode 100644 index 0000000..973881a --- /dev/null +++ b/nicolas-model/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/frequent_bases.txt @@ -0,0 +1,237 @@ +on +to +co +rok +być +wszystko +polska +człowiek +sobie +raz +my +mieć +czas +państwo +praca +osoba +sprawa +ja +kraj +pieniądz +nikt +kto +przykład +nic +koniec +rząd +prawo +życie +miejsce +móc +fot +problem +władza +miesiąc +rzecz +stan +świat +wszyscy +mówić +rozmowa +coś +sytuacja +powód +początek +wiedzieć +dzień +uwaga +strona +udział +in +musieć +polityk +ktoś +ogół +polityka +chcieć +walka +zmiana +decyzja +ciąg +m . 
+pan +szansa +polak +przypadek +większość +pytanie +wzgląd +warszawa +proca +pomoc +prezydent +społeczeństwo +wynik +dziecko +prawda +związek +gospodarka +część +wojna +tydzień +granica +głos +przyszłość +autor +wybory +rynek +cel +ustawa +uważać +ten rok +droga +dom +rys +myśleć +firma +zasada +fakt +kolej +nadzieja +dolar +wraz +miasto +rozwój +ten sposób +europa +temat +siła +rodzina +minister +historia +wpływ +współpraca +środek +informacja +procent +wniosek +unia europejski +niemcy +podstawa +reforma +partia +interes +ten sprawa +kandydat +sukces +sposób +wątpliwość +złoty +sld +pracownik +stanowisko +dyskusja +telewizja +pewność +odpowiedź +rzeczywistość +program +cena +działanie +system +unia +ręka +odpowiedzialność +środowisko +solidarność +demokracja +maić +ramy +badanie +media +wartość +wybór +głowa +zostać +usa +pracować +porozumienie +widzieć +zdanie +akcja +wolność +spotkanie +przeszłość +stosunek +okazja +prowadzić +zachód +kobieta +obywatel +sąd +ubiegły rok +dziennikarz +kultura +grupa +opinia publiczny +obrona +bezpieczeństwo +opinia +rzeczpospolita +dokument +racja +szkoła +góra +warunek +organizacja +oko +godzina +tysiąc +ten czas +możliwość +błąd +ziemia +parlament +ten pora +chwila +naród +konflikt +działalność +sejm +powrót +premier +działać +rada +zdrowie +wiek +dodatek +poziom +widzenie +żyć +powiedzieć +inwestycja +rosja +niemiec +samochód +skutek +punkt +rola +mieszkaniec +wyborca +koszt +budżet +szef +styczeń +instytucja +pełnia +ulica +aws +ochrona +dostęp +zagrożenie +zgoda +ue +" rzeczpospolita " +liczba +wieś +połowa \ No newline at end of file diff --git a/nicolas-train/pom.xml b/nicolas-train/pom.xml new file mode 100644 index 0000000..0773393 --- /dev/null +++ b/nicolas-train/pom.xml @@ -0,0 +1,14 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <artifactId>nicolas-container</artifactId> + <groupId>pl.waw.ipipan.zil.summ</groupId> + <version>1.0-SNAPSHOT</version> + </parent> + + <artifactId>nicolas-train</artifactId> + +</project> \ No newline at end of file diff --git a/nicolas-zero/pom.xml b/nicolas-zero/pom.xml new file mode 100644 index 0000000..26bf7dd --- /dev/null +++ b/nicolas-zero/pom.xml @@ -0,0 +1,14 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <artifactId>nicolas-container</artifactId> + <groupId>pl.waw.ipipan.zil.summ</groupId> + <version>1.0-SNAPSHOT</version> + </parent> + + <artifactId>nicolas-zero</artifactId> + +</project> \ No newline at end of file diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..4af7327 --- /dev/null +++ b/pom.xml @@ -0,0 +1,101 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + + <groupId>pl.waw.ipipan.zil.summ</groupId> + <artifactId>nicolas-container</artifactId> + <packaging>pom</packaging> + <version>1.0-SNAPSHOT</version> + + <modules> + 
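<!-- in this draft all code lives in nicolas-core; nicolas-model only ships the frequent_bases.txt resource, while nicolas-cli, nicolas-train and nicolas-zero are empty skeletons for now --> +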
<module>nicolas-core</module> + <module>nicolas-cli</module> + <module>nicolas-model</module> + <module>nicolas-train</module> + <module>nicolas-zero</module> + </modules> + + <properties> + <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> + <java.version.build>1.8</java.version.build> + </properties> + + <prerequisites> + <maven>3.0.5</maven> + </prerequisites> + + <developers> + <developer> + <name>Mateusz Kopeć</name> + <organization>ICS PAS</organization> + <email>m.kopec@ipipan.waw.pl</email> + </developer> + </developers> + + <dependencies> + <dependency> + <groupId>pl.waw.ipipan.zil.summ</groupId> + <artifactId>pscapi</artifactId> + <version>1.0-SNAPSHOT</version> + </dependency> + <dependency> + <groupId>pl.waw.ipipan.zil.multiservice</groupId> + <artifactId>utils</artifactId> + <version>1.0-SNAPSHOT</version> + </dependency> + + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-csv</artifactId> + <version>1.3</version> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + <version>19.0</version> + </dependency> + <dependency> + <groupId>nz.ac.waikato.cms.weka</groupId> + <artifactId>weka-dev</artifactId> + <version>3.9.0</version> + </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-lang3</artifactId> + <version>3.4</version> + </dependency> + <dependency> + <groupId>commons-io</groupId> + <artifactId>commons-io</artifactId> + <version>2.5</version> + </dependency> + </dependencies> + + + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-compiler-plugin</artifactId> + <version>3.1</version> + <configuration> + <source>${java.version.build}</source> + <target>${java.version.build}</target> + </configuration> + </plugin> + </plugins> + </build> + + <distributionManagement> + <repository> + <id>deployment</id> + <url>http://maven.nlp.ipipan.waw.pl/content/repositories/releases/</url> + </repository> + <snapshotRepository> + <id>deployment</id> + <url>http://maven.nlp.ipipan.waw.pl/content/repositories/snapshots/</url> + </snapshotRepository> + </distributionManagement> +</project> \ No newline at end of file -- libgit2 0.22.2