From 91b27b24a9fb0d6427debc133c923e3188f9a768 Mon Sep 17 00:00:00 2001 From: Mateusz Kopeć <m.kopec@ipipan.waw.pl> Date: Mon, 7 Nov 2016 22:43:13 +0100 Subject: [PATCH] zeros corpus wip --- nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java | 46 ++++++++++++++++++++++++++++++++++++++++++++++ nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java | 12 ++++++++++-- nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/VersionIgnoringObjectInputStream.java | 33 +++++++++++++++++++++++++++++++++ nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureExtractor.java | 83 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java | 203 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/Interpretation.java | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-core/pom.xml | 4 ++++ nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java | 33 --------------------------------- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java | 1 + nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel2.java | 21 +++++++++++++-------- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureExtractor.java | 83 ----------------------------------------------------------------------------------- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureHelper.java | 185 
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/Interpretation.java | 77 ----------------------------------------------------------------------------- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java | 6 +++--- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionScorer.java | 2 +- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/PrepareTrainingData.java | 2 +- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java | 4 ++-- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java | 4 ++-- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java | 2 +- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/PrepareTrainingData.java | 2 +- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java | 4 ++-- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceScorer.java | 1 + nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java | 2 +- nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java | 2 +- nicolas-zero/pom.xml | 4 ++++ nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java | 5 ++++- nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/Zero.java | 76 ---------------------------------------------------------------------------- nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java | 46 
++++++++++++++++++++++++++++++++++++++++++++++ nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainModel.java | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainingDataExtractor.java | 97 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/ZeroScorer.java | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java | 24 ++++++++---------------- nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjectorTest.java | 11 ----------- nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin | Bin 0 -> 4103 bytes nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt | 2 ++ pom.xml | 5 +++++ 37 files changed, 822 insertions(+), 507 deletions(-) create mode 100644 nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java create mode 100644 nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/VersionIgnoringObjectInputStream.java create mode 100644 nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureExtractor.java create mode 100644 nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java create mode 100644 nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/Interpretation.java delete mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java delete mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureExtractor.java delete mode 100644 nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureHelper.java delete mode 100644 
nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/Interpretation.java delete mode 100644 nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/Zero.java create mode 100644 nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java create mode 100644 nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainModel.java create mode 100644 nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainingDataExtractor.java create mode 100644 nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/ZeroScorer.java delete mode 100644 nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjectorTest.java create mode 100644 nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin create mode 100644 nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java new file mode 100644 index 0000000..38d2989 --- /dev/null +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java @@ -0,0 +1,46 @@ +package pl.waw.ipipan.zil.summ.nicolas.common; + +import com.google.common.base.Charsets; +import weka.classifiers.Classifier; +import weka.classifiers.functions.Logistic; +import weka.classifiers.trees.RandomForest; + +import java.nio.charset.Charset; + + +public class Constants { + + public static final String MENTIONS_MODEL_PATH = "mentions_model.bin"; + public static final String SENTENCES_MODEL_PATH = "sentences_model.bin"; + public static final String ZERO_MODEL_PATH = "zeros_model.bin"; + + public static final String MENTIONS_DATASET_PATH = "mentions_train.arff"; + public static final String SENTENCES_DATASET_PATH = "sentences_train.arff"; + public static final String ZERO_DATASET_PATH 
= "zeros_train.arff"; + + public static final Charset ENCODING = Charsets.UTF_8; + + private Constants() { + } + + public static Classifier getMentionClassifier() { + RandomForest classifier = new RandomForest(); + classifier.setNumIterations(250); + classifier.setSeed(0); + classifier.setNumExecutionSlots(8); + return classifier; + } + + public static Classifier getSentencesClassifier() { + RandomForest classifier = new RandomForest(); + classifier.setNumIterations(250); + classifier.setSeed(0); + classifier.setNumExecutionSlots(8); + return classifier; + } + + public static Classifier getZerosClassifier() { + Logistic classifier = new Logistic(); + return classifier; + } +} diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java index b76153d..4c2b173 100644 --- a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java @@ -101,7 +101,7 @@ public class Utils { STOPWORDS.addAll(Lists.newArrayList("i", "się", "to", "co")); } - public static Map<TMention, String> loadMention2Orth(List<TSentence> sents) { + public static Map<TMention, String> loadMention2Orth(List<TSentence> sents, boolean discardStopwords) { Map<TMention, String> mention2orth = Maps.newHashMap(); for (TSentence s : sents) { Map<String, TToken> tokId2tok = s.getTokens().stream().collect(Collectors.toMap(TToken::getId, Function.identity())); @@ -110,7 +110,7 @@ public class Utils { StringBuffer mentionOrth = new StringBuffer(); for (String tokId : m.getChildIds()) { TToken token = tokId2tok.get(tokId); - if (STOPWORDS.contains(token.getChosenInterpretation().getBase().toLowerCase())) { + if (discardStopwords && STOPWORDS.contains(token.getChosenInterpretation().getBase().toLowerCase())) { continue; } @@ -142,8 +142,16 @@ public class Utils { } public static String 
loadSentence2Orth(TSentence sentence) { + return loadSentence2Orth(sentence, Sets.newHashSet()); + } + + public static String loadSentence2Orth(TSentence sentence, Set<String> tokenIdsToSkip) { StringBuilder sb = new StringBuilder(); for (TToken token : sentence.getTokens()) { + if (tokenIdsToSkip.contains(token.getId())) { + System.out.println("Skipping " + token.getOrth() + " in sentence: " + loadSentence2Orth(sentence)); + continue; + } if (!token.isNoPrecedingSpace()) sb.append(" "); sb.append(token.getOrth()); diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/VersionIgnoringObjectInputStream.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/VersionIgnoringObjectInputStream.java new file mode 100644 index 0000000..9d03cd8 --- /dev/null +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/VersionIgnoringObjectInputStream.java @@ -0,0 +1,33 @@ +package pl.waw.ipipan.zil.summ.nicolas.common; + +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectStreamClass; + + +public class VersionIgnoringObjectInputStream extends ObjectInputStream { + + public VersionIgnoringObjectInputStream(InputStream in) throws IOException { + super(in); + } + + protected ObjectStreamClass readClassDescriptor() throws IOException, ClassNotFoundException { + ObjectStreamClass resultClassDescriptor = super.readClassDescriptor(); // initially streams descriptor + Class localClass; // the class in the local JVM that this descriptor represents. 
+ try { + localClass = Class.forName(resultClassDescriptor.getName()); + } catch (ClassNotFoundException e) { + return resultClassDescriptor; + } + ObjectStreamClass localClassDescriptor = ObjectStreamClass.lookup(localClass); + if (localClassDescriptor != null) { // only if class implements serializable + final long localSUID = localClassDescriptor.getSerialVersionUID(); + final long streamSUID = resultClassDescriptor.getSerialVersionUID(); + if (streamSUID != localSUID) { // check for serialVersionUID mismatch. + resultClassDescriptor = localClassDescriptor; // Use local class descriptor for deserialization + } + } + return resultClassDescriptor; + } +} diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureExtractor.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureExtractor.java new file mode 100644 index 0000000..3c80046 --- /dev/null +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureExtractor.java @@ -0,0 +1,83 @@ +package pl.waw.ipipan.zil.summ.nicolas.common.features; + +import com.google.common.collect.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import weka.core.Attribute; + +import java.util.*; + +public class FeatureExtractor { + + protected static final Logger LOG = LoggerFactory.getLogger(FeatureExtractor.class); + + private final List<Attribute> sortedAttributes = Lists.newArrayList(); + + private final BiMap<String, Attribute> name2attribute = HashBiMap.create(); + + private final Set<String> normalizedAttributes = Sets.newHashSet(); + + public ArrayList<Attribute> getAttributesList() { + return Lists.newArrayList(sortedAttributes); + } + + protected Attribute getAttributeByName(String name) { + return name2attribute.get(name); + } + + protected void addNumericAttribute(String attributeName) { + name2attribute.put(attributeName, new Attribute(attributeName)); + } + + protected void addBinaryAttribute(String 
attributeName) { + name2attribute.put(attributeName, new Attribute(attributeName, Lists.newArrayList("f", "t"))); + } + + protected void addNominalAttribute(String attributeName, List<String> values) { + name2attribute.put(attributeName, new Attribute(attributeName, values)); + } + + protected void addNumericAttributeNormalized(String attributeName) { + addNumericAttribute(attributeName); + addNumericAttribute(attributeName + "_normalized"); + normalizedAttributes.add(attributeName); + } + + protected void fillSortedAttributes(String scoreAttName) { + sortedAttributes.addAll(name2attribute.values()); + sortedAttributes.remove(getAttributeByName(scoreAttName)); + Collections.sort(sortedAttributes, (o1, o2) -> name2attribute.inverse().get(o1).compareTo(name2attribute.inverse().get(o2))); + sortedAttributes.add(0, getAttributeByName(scoreAttName)); + } + + protected <T> void addNormalizedAttributeValues(Map<T, Map<Attribute, Double>> entity2attributes) { + Map<Attribute, Double> attribute2max = Maps.newHashMap(); + Map<Attribute, Double> attribute2min = Maps.newHashMap(); + for (T entity : entity2attributes.keySet()) { + Map<Attribute, Double> entityAttributes = entity2attributes.get(entity); + for (String attributeName : normalizedAttributes) { + Attribute attribute = getAttributeByName(attributeName); + Double value = entityAttributes.get(attribute); + + attribute2max.putIfAbsent(attribute, Double.MIN_VALUE); + attribute2max.compute(attribute, (k, v) -> Math.max(v, value)); + + attribute2min.putIfAbsent(attribute, Double.MAX_VALUE); + attribute2min.compute(attribute, (k, v) -> Math.min(v, value)); + } + } + for (T mention : entity2attributes.keySet()) { + Map<Attribute, Double> entityAttributes = entity2attributes.get(mention); + for (Attribute attribute : attribute2max.keySet()) { + Attribute normalizedAttribute = getAttributeByName(name2attribute.inverse().get(attribute) + "_normalized"); + entityAttributes.put(normalizedAttribute, + 
(entityAttributes.get(attribute) - attribute2min.get(attribute)) + / (attribute2max.get(attribute) - attribute2min.get(attribute))); + } + } + } + + protected double toBinary(boolean bool) { + return bool ? 1.0 : 0.0; + } +} diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java new file mode 100644 index 0000000..23d1958 --- /dev/null +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java @@ -0,0 +1,203 @@ +package pl.waw.ipipan.zil.summ.nicolas.common.features; + +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import pl.waw.ipipan.zil.multiservice.thrift.types.*; +import pl.waw.ipipan.zil.summ.nicolas.common.Utils; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toMap; + + +public class FeatureHelper { + + private final TText text; + + private final List<TMention> mentions; + private final Map<String, TMention> mentionId2mention; + private final Map<TCoreference, List<TMention>> coref2mentions = Maps.newHashMap(); + private final Map<TMention, TCoreference> mention2coref = Maps.newHashMap(); + private final Map<TMention, TSentence> mention2sent = Maps.newHashMap(); + private final Map<TMention, TParagraph> mention2par = Maps.newHashMap(); + private final Map<TMention, String> mention2Orth = Maps.newHashMap(); + private final Map<TMention, String> mention2Base = Maps.newHashMap(); + private final Map<TMention, TToken> mention2head = Maps.newHashMap(); + private final Set<TMention> mentionsInNamedEntities = Sets.newHashSet(); + + private final Map<TMention, Integer> mention2Index = Maps.newHashMap(); + private final Map<TSentence, Integer> sent2Index 
= Maps.newHashMap(); + private final Map<TParagraph, Integer> par2Index = Maps.newHashMap(); + private final Map<TSentence, Integer> sent2IndexInPar = Maps.newHashMap(); + private final Map<TMention, Integer> mention2indexInPar = Maps.newHashMap(); + private final Map<TMention, Integer> mention2indexInSent = Maps.newHashMap(); + + + public FeatureHelper(TText preprocessedText) { + text = preprocessedText; + + mentions = preprocessedText.getParagraphs().stream() + .flatMap(p -> p.getSentences().stream()) + .flatMap(s -> s.getMentions().stream()).collect(Collectors.toList()); + + mentionId2mention = mentions.stream().collect(Collectors.toMap(TMention::getId, Function.identity())); + + for (TCoreference coref : preprocessedText.getCoreferences()) { + List<TMention> ments = coref.getMentionIds().stream().map(mentionId2mention::get).collect(toList()); + for (TMention m : ments) { + mention2coref.put(m, coref); + } + coref2mentions.put(coref, ments); + } + + int parIdx = 0; + int sentIdx = 0; + int mentionIdx = 0; + for (TParagraph par : preprocessedText.getParagraphs()) { + Map<TMention, String> m2o = Utils.loadMention2Orth(par.getSentences(), false); + mention2Orth.putAll(m2o); + Map<TMention, String> m2b = Utils.loadMention2Base(par.getSentences()); + mention2Base.putAll(m2b); + + int sentIdxInPar = 0; + int mentionIdxInPar = 0; + for (TSentence sent : par.getSentences()) { + + Map<String, TToken> tokenId2token = sent.getTokens().stream().collect(toMap(TToken::getId, Function.identity())); + + Map<String, Set<TNamedEntity>> tokenId2namedEntities = Maps.newHashMap(); + for (TNamedEntity namedEntity : sent.getNames()) { + for (String childId : namedEntity.getChildIds()) { + tokenId2namedEntities.putIfAbsent(childId, Sets.newHashSet()); + tokenId2namedEntities.get(childId).add(namedEntity); + } + } + + int mentionIdxInSent = 0; + for (TMention mention : sent.getMentions()) { + mention2sent.put(mention, sent); + mention2par.put(mention, par); + mention2Index.put(mention, 
mentionIdx++); + mention2indexInSent.put(mention, mentionIdxInSent++); + mention2indexInPar.put(mention, mentionIdxInPar++); + + String firstHeadTokenId = mention.getHeadIds().iterator().next(); + mention2head.put(mention, tokenId2token.get(firstHeadTokenId)); + if (tokenId2namedEntities.containsKey(firstHeadTokenId)) + mentionsInNamedEntities.add(mention); + } + sent2Index.put(sent, sentIdx++); + sent2IndexInPar.put(sent, sentIdxInPar++); + } + + par2Index.put(par, parIdx++); + } + } + + public List<TMention> getMentions() { + return mentions; + } + + public int getMentionIndexInChain(TMention mention) { + return coref2mentions.get(mention2coref.get(mention)).indexOf(mention); + } + + public int getChainLength(TMention mention) { + return coref2mentions.get(mention2coref.get(mention)).size(); + } + + public String getSentenceLastTokenOrth(TSentence sent) { + return sent.getTokens().get(sent.getTokensSize() - 1).getOrth(); + } + + public String getMentionOrth(TMention mention) { + return mention2Orth.get(mention); + } + + public String getMentionBase(TMention mention) { + return mention2Base.get(mention); + } + + public int getMentionIndex(TMention mention) { + return mention2Index.get(mention); + } + + public int getMentionIndexInSent(TMention mention) { + return mention2indexInSent.get(mention); + } + + public int getMentionIndexInPar(TMention mention) { + return mention2indexInPar.get(mention); + } + + public int getParIndex(TParagraph paragraph) { + return par2Index.get(paragraph); + } + + public int getSentIndex(TSentence sent) { + return sent2Index.get(sent); + } + + public int getSentIndexInPar(TSentence sent) { + return sent2IndexInPar.get(sent); + } + + public TParagraph getMentionParagraph(TMention mention) { + return mention2par.get(mention); + } + + public TSentence getMentionSentence(TMention mention) { + return mention2sent.get(mention); + } + + public TMention getFirstChainMention(TMention mention) { + return 
mentionId2mention.get(mention2coref.get(mention).getMentionIdsIterator().next()); + } + + public TToken getMentionHeadToken(TMention mention) { + return mention2head.get(mention); + } + + public boolean isMentionNamedEntity(TMention mention) { + return mentionsInNamedEntities.contains(mention); + } + + public boolean isNested(TMention mention) { + return mentions.stream().anyMatch(m -> m.getChildIds().containsAll(mention.getChildIds())); + } + + public boolean isNesting(TMention mention) { + return mentions.stream().anyMatch(m -> mention.getChildIds().containsAll(m.getChildIds())); + } + + public Set<TCoreference> getClusters() { + return coref2mentions.keySet(); + } + + public Set<TMention> getCoreferentMentions(TMention tMention) { + return getMentionCluster(tMention).getMentionIds().stream().map(this.mentionId2mention::get).collect(Collectors.toSet()); + } + + public TCoreference getMentionCluster(TMention tMention) { + return this.mention2coref.get(tMention); + } + + public String getSentenceOrth(TSentence sentence) { + StringBuilder sb = new StringBuilder(); + for (TToken token : sentence.getTokens()) { + if (!token.isNoPrecedingSpace()) + sb.append(" "); + sb.append(token.getOrth()); + } + return sb.toString().trim(); + } + + public TText getText() { + return text; + } +} diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/Interpretation.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/Interpretation.java new file mode 100644 index 0000000..3ed81d8 --- /dev/null +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/Interpretation.java @@ -0,0 +1,77 @@ +package pl.waw.ipipan.zil.summ.nicolas.common.features; + +import pl.waw.ipipan.zil.multiservice.thrift.types.TInterpretation; + + +public class Interpretation { + private String ctag = "null"; + private String casee = "null"; + private String gender = "null"; + private String number = "null"; + private String 
person = "null"; + + public Interpretation(TInterpretation chosenInterpretation) { + ctag = chosenInterpretation.getCtag(); + String[] split = chosenInterpretation.getMsd().split(":"); + switch (ctag) { + case "ger": + case "subst": + case "pact": + case "ppas": + case "num": + case "numcol": + case "adj": + number = split[0]; + casee = split[1]; + gender = split[2]; + break; + case "ppron12": + case "ppron3": + number = split[0]; + casee = split[1]; + gender = split[2]; + person = split[3]; + break; + case "siebie": + casee = split[0]; + break; + case "fin": + case "bedzie": + case "aglt": + case "impt": + number = split[0]; + person = split[1]; + break; + case "praet": + case "winien": + number = split[0]; + gender = split[1]; + break; + case "prep": + casee = split[0]; + break; + default: + break; + } + } + + public String getCase() { + return casee; + } + + public String getGender() { + return gender; + } + + public String getNumber() { + return number; + } + + public String getPerson() { + return person; + } + + public String getCtag() { + return ctag; + } +} diff --git a/nicolas-core/pom.xml b/nicolas-core/pom.xml index 0047276..b291cda 100644 --- a/nicolas-core/pom.xml +++ b/nicolas-core/pom.xml @@ -21,6 +21,10 @@ <groupId>pl.waw.ipipan.zil.summ</groupId> <artifactId>nicolas-model</artifactId> </dependency> + <dependency> + <groupId>pl.waw.ipipan.zil.summ</groupId> + <artifactId>nicolas-zero</artifactId> + </dependency> <dependency> <groupId>pl.waw.ipipan.zil.summ</groupId> diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java deleted file mode 100644 index f4a6ed6..0000000 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Constants.java +++ /dev/null @@ -1,33 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas; - -import weka.classifiers.Classifier; -import weka.classifiers.trees.RandomForest; - - -public class Constants { - - public static 
final String MENTIONS_MODEL_PATH = "mentions_model.bin"; - public static final String SENTENCES_MODEL_PATH = "sentences_model.bin"; - public static final String MENTIONS_DATASET_PATH = "mentions_train.arff"; - public static final String SENTENCES_DATASET_PATH = "sentences_train.arff"; - - private Constants() { - } - - public static Classifier getClassifier() { - RandomForest classifier = new RandomForest(); - classifier.setNumIterations(250); - classifier.setSeed(0); - classifier.setNumExecutionSlots(8); - return classifier; - } - - - public static Classifier getSentencesClassifier() { - RandomForest classifier = new RandomForest(); - classifier.setNumIterations(250); - classifier.setSeed(0); - classifier.setNumExecutionSlots(8); - return classifier; - } -} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java index 96f3786..e4f86d4 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/Nicolas.java @@ -6,6 +6,7 @@ import com.google.common.collect.Sets; import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.common.Utils; import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor; import pl.waw.ipipan.zil.summ.nicolas.mention.MentionModel; diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel2.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel2.java index 2de5225..4554ccc 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel2.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/apply/ApplyModel2.java @@ -8,12 +8,13 @@ import 
org.slf4j.LoggerFactory; import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; -import pl.waw.ipipan.zil.summ.nicolas.Constants; import pl.waw.ipipan.zil.summ.nicolas.ThriftUtils; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.common.Utils; import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor; import pl.waw.ipipan.zil.summ.nicolas.mention.MentionModel; import pl.waw.ipipan.zil.summ.nicolas.sentence.SentenceFeatureExtractor; +import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectInjector; import weka.classifiers.Classifier; import weka.core.Instance; import weka.core.Instances; @@ -29,8 +30,8 @@ public class ApplyModel2 { private static final Logger LOG = LoggerFactory.getLogger(ApplyModel2.class); - private static final String TEST_PREPROCESSED_DATA_PATH = "src/main/resources/preprocessed_full_texts/test"; - private static final String TARGET_DIR = "summaries"; + private static final String TEST_PREPROCESSED_DATA_PATH = "corpora/preprocessed_full_texts/test"; + private static final String TARGET_DIR = "corpora/summaries"; public static void main(String[] args) throws Exception { Classifier mentionClassifier = Utils.loadClassifier(Constants.MENTIONS_MODEL_PATH); @@ -39,6 +40,8 @@ public class ApplyModel2 { Classifier sentenceClassifier = Utils.loadClassifier(Constants.SENTENCES_MODEL_PATH); SentenceFeatureExtractor sentenceFeatureExtractor = new SentenceFeatureExtractor(); + ZeroSubjectInjector zeroSubjectInjector = new ZeroSubjectInjector(); + Map<String, TText> id2preprocessedText = Utils.loadPreprocessedTexts(TEST_PREPROCESSED_DATA_PATH); int i = 1; double avgSize = 0; @@ -49,10 +52,10 @@ public class ApplyModel2 { = MentionModel.detectGoodMentions(mentionClassifier, featureExtractor, text); int targetSize = calculateTargetSize(text); - String summary = 
calculateSummary(text, goodMentions, targetSize, sentenceClassifier, sentenceFeatureExtractor); + String summary = calculateSummary(text, goodMentions, targetSize, sentenceClassifier, sentenceFeatureExtractor, zeroSubjectInjector); int size = Utils.tokenize(summary).size(); avgSize += size; - try (BufferedWriter bw = new BufferedWriter(new FileWriter(new File(TARGET_DIR, entry.getKey() + "_emily3.txt")))) { + try (BufferedWriter bw = new BufferedWriter(new FileWriter(new File(TARGET_DIR, entry.getKey() + "_emily4.txt")))) { bw.append(summary); } @@ -71,12 +74,14 @@ public class ApplyModel2 { return (int) (0.2 * tokenCount); } - private static String calculateSummary(TText thrifted, Set<TMention> goodMentions, int targetSize, Classifier sentenceClassifier, SentenceFeatureExtractor sentenceFeatureExtractor) throws Exception { + private static String calculateSummary(TText thrifted, Set<TMention> goodMentions, int targetSize, Classifier sentenceClassifier, SentenceFeatureExtractor sentenceFeatureExtractor, ZeroSubjectInjector zeroSubjectInjector) throws Exception { List<TSentence> selectedSentences = selectSummarySentences(thrifted, goodMentions, targetSize, sentenceClassifier, sentenceFeatureExtractor); - StringBuffer sb = new StringBuffer(); + Set<String> zeroSubjectTokenIds = zeroSubjectInjector.findZeroSubjectTokenIds(thrifted, selectedSentences); + + StringBuilder sb = new StringBuilder(); for (TSentence sent : selectedSentences) { - sb.append(" " + Utils.loadSentence2Orth(sent)); + sb.append(" " + Utils.loadSentence2Orth(sent, zeroSubjectTokenIds)); } return sb.toString().trim(); } diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureExtractor.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureExtractor.java deleted file mode 100644 index 39de47f..0000000 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureExtractor.java +++ /dev/null @@ -1,83 +0,0 @@ -package 
pl.waw.ipipan.zil.summ.nicolas.features; - -import com.google.common.collect.*; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import weka.core.Attribute; - -import java.util.*; - -public class FeatureExtractor { - - protected static final Logger LOG = LoggerFactory.getLogger(FeatureExtractor.class); - - private final List<Attribute> sortedAttributes = Lists.newArrayList(); - - private final BiMap<String, Attribute> name2attribute = HashBiMap.create(); - - private final Set<String> normalizedAttributes = Sets.newHashSet(); - - public ArrayList<Attribute> getAttributesList() { - return Lists.newArrayList(sortedAttributes); - } - - protected Attribute getAttributeByName(String name) { - return name2attribute.get(name); - } - - protected void addNumericAttribute(String attributeName) { - name2attribute.put(attributeName, new Attribute(attributeName)); - } - - protected void addBinaryAttribute(String attributeName) { - name2attribute.put(attributeName, new Attribute(attributeName, Lists.newArrayList("f", "t"))); - } - - protected void addNominalAttribute(String attributeName, List<String> values) { - name2attribute.put(attributeName, new Attribute(attributeName, values)); - } - - protected void addNumericAttributeNormalized(String attributeName) { - addNumericAttribute(attributeName); - addNumericAttribute(attributeName + "_normalized"); - normalizedAttributes.add(attributeName); - } - - protected void fillSortedAttributes(String scoreAttName) { - sortedAttributes.addAll(name2attribute.values()); - sortedAttributes.remove(getAttributeByName(scoreAttName)); - Collections.sort(sortedAttributes, (o1, o2) -> name2attribute.inverse().get(o1).compareTo(name2attribute.inverse().get(o2))); - sortedAttributes.add(0, getAttributeByName(scoreAttName)); - } - - protected <T> void addNormalizedAttributeValues(Map<T, Map<Attribute, Double>> entity2attributes) { - Map<Attribute, Double> attribute2max = Maps.newHashMap(); - Map<Attribute, Double> attribute2min = 
Maps.newHashMap(); - for (T entity : entity2attributes.keySet()) { - Map<Attribute, Double> entityAttributes = entity2attributes.get(entity); - for (String attributeName : normalizedAttributes) { - Attribute attribute = getAttributeByName(attributeName); - Double value = entityAttributes.get(attribute); - - attribute2max.putIfAbsent(attribute, Double.MIN_VALUE); - attribute2max.compute(attribute, (k, v) -> Math.max(v, value)); - - attribute2min.putIfAbsent(attribute, Double.MAX_VALUE); - attribute2min.compute(attribute, (k, v) -> Math.min(v, value)); - } - } - for (T mention : entity2attributes.keySet()) { - Map<Attribute, Double> entityAttributes = entity2attributes.get(mention); - for (Attribute attribute : attribute2max.keySet()) { - Attribute normalizedAttribute = getAttributeByName(name2attribute.inverse().get(attribute) + "_normalized"); - entityAttributes.put(normalizedAttribute, - (entityAttributes.get(attribute) - attribute2min.get(attribute)) - / (attribute2max.get(attribute) - attribute2min.get(attribute))); - } - } - } - - protected double toBinary(boolean bool) { - return bool ? 
1.0 : 0.0; - } -} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureHelper.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureHelper.java deleted file mode 100644 index d774b0a..0000000 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/FeatureHelper.java +++ /dev/null @@ -1,185 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.features; - -import com.google.common.collect.Maps; -import com.google.common.collect.Sets; -import pl.waw.ipipan.zil.multiservice.thrift.types.*; -import pl.waw.ipipan.zil.summ.nicolas.common.Utils; - -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.Function; -import java.util.stream.Collectors; - -import static java.util.stream.Collectors.toList; -import static java.util.stream.Collectors.toMap; - - -public class FeatureHelper { - - private final List<TMention> mentions; - private final Map<String, TMention> mentionId2mention; - private final Map<TCoreference, List<TMention>> coref2mentions = Maps.newHashMap(); - private final Map<TMention, TCoreference> mention2coref = Maps.newHashMap(); - private final Map<TMention, TSentence> mention2sent = Maps.newHashMap(); - private final Map<TMention, TParagraph> mention2par = Maps.newHashMap(); - private final Map<TMention, String> mention2Orth = Maps.newHashMap(); - private final Map<TMention, String> mention2Base = Maps.newHashMap(); - private final Map<TMention, TToken> mention2head = Maps.newHashMap(); - private final Set<TMention> mentionsInNamedEntities = Sets.newHashSet(); - - private final Map<TMention, Integer> mention2Index = Maps.newHashMap(); - private final Map<TSentence, Integer> sent2Index = Maps.newHashMap(); - private final Map<TParagraph, Integer> par2Index = Maps.newHashMap(); - private final Map<TSentence, Integer> sent2IndexInPar = Maps.newHashMap(); - private final Map<TMention, Integer> mention2indexInPar = Maps.newHashMap(); - private final 
Map<TMention, Integer> mention2indexInSent = Maps.newHashMap(); - - - public FeatureHelper(TText preprocessedText) { - mentions = preprocessedText.getParagraphs().stream() - .flatMap(p -> p.getSentences().stream()) - .flatMap(s -> s.getMentions().stream()).collect(Collectors.toList()); - - mentionId2mention = mentions.stream().collect(Collectors.toMap(TMention::getId, Function.identity())); - - for (TCoreference coref : preprocessedText.getCoreferences()) { - List<TMention> ments = coref.getMentionIds().stream().map(mentionId2mention::get).collect(toList()); - for (TMention m : ments) { - mention2coref.put(m, coref); - } - coref2mentions.put(coref, ments); - } - - int parIdx = 0; - int sentIdx = 0; - int mentionIdx = 0; - for (TParagraph par : preprocessedText.getParagraphs()) { - Map<TMention, String> m2o = Utils.loadMention2Orth(par.getSentences()); - mention2Orth.putAll(m2o); - Map<TMention, String> m2b = Utils.loadMention2Base(par.getSentences()); - mention2Base.putAll(m2b); - - int sentIdxInPar = 0; - int mentionIdxInPar = 0; - for (TSentence sent : par.getSentences()) { - - Map<String, TToken> tokenId2token = sent.getTokens().stream().collect(toMap(TToken::getId, Function.identity())); - - Map<String, Set<TNamedEntity>> tokenId2namedEntities = Maps.newHashMap(); - for (TNamedEntity namedEntity : sent.getNames()) { - for (String childId : namedEntity.getChildIds()) { - tokenId2namedEntities.putIfAbsent(childId, Sets.newHashSet()); - tokenId2namedEntities.get(childId).add(namedEntity); - } - } - - int mentionIdxInSent = 0; - for (TMention mention : sent.getMentions()) { - mention2sent.put(mention, sent); - mention2par.put(mention, par); - mention2Index.put(mention, mentionIdx++); - mention2indexInSent.put(mention, mentionIdxInSent++); - mention2indexInPar.put(mention, mentionIdxInPar++); - - String firstHeadTokenId = mention.getHeadIds().iterator().next(); - mention2head.put(mention, tokenId2token.get(firstHeadTokenId)); - if 
(tokenId2namedEntities.containsKey(firstHeadTokenId)) - mentionsInNamedEntities.add(mention); - } - sent2Index.put(sent, sentIdx++); - sent2IndexInPar.put(sent, sentIdxInPar++); - } - - par2Index.put(par, parIdx++); - } - } - - public List<TMention> getMentions() { - return mentions; - } - - public int getMentionIndexInChain(TMention mention) { - return coref2mentions.get(mention2coref.get(mention)).indexOf(mention); - } - - public int getChainLength(TMention mention) { - return coref2mentions.get(mention2coref.get(mention)).size(); - } - - public String getSentenceLastTokenOrth(TSentence sent) { - return sent.getTokens().get(sent.getTokensSize() - 1).getOrth(); - } - - public String getMentionOrth(TMention mention) { - return mention2Orth.get(mention); - } - - public String getMentionBase(TMention mention) { - return mention2Base.get(mention); - } - - public int getMentionIndex(TMention mention) { - return mention2Index.get(mention); - } - - public int getMentionIndexInSent(TMention mention) { - return mention2indexInSent.get(mention); - } - - public int getMentionIndexInPar(TMention mention) { - return mention2indexInPar.get(mention); - } - - public int getParIndex(TParagraph paragraph) { - return par2Index.get(paragraph); - } - - public int getSentIndex(TSentence sent) { - return sent2Index.get(sent); - } - - public int getSentIndexInPar(TSentence sent) { - return sent2IndexInPar.get(sent); - } - - public TParagraph getMentionParagraph(TMention mention) { - return mention2par.get(mention); - } - - public TSentence getMentionSentence(TMention mention) { - return mention2sent.get(mention); - } - - public TMention getFirstChainMention(TMention mention) { - return mentionId2mention.get(mention2coref.get(mention).getMentionIdsIterator().next()); - } - - public TToken getMentionHeadToken(TMention mention) { - return mention2head.get(mention); - } - - public boolean isMentionNamedEntity(TMention mention) { - return mentionsInNamedEntities.contains(mention); - } - - 
public boolean isNested(TMention mention) { - return mentions.stream().anyMatch(m -> m.getChildIds().containsAll(mention.getChildIds())); - } - - public boolean isNesting(TMention mention) { - return mentions.stream().anyMatch(m -> mention.getChildIds().containsAll(m.getChildIds())); - } - - public Set<TCoreference> getClusters() { - return coref2mentions.keySet(); - } - - public Set<TMention> getCoreferentMentions(TMention tMention) { - return getMentionCluster(tMention).getMentionIds().stream().map(this.mentionId2mention::get).collect(Collectors.toSet()); - } - - public TCoreference getMentionCluster(TMention tMention) { - return this.mention2coref.get(tMention); - } -} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/Interpretation.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/Interpretation.java deleted file mode 100644 index 11d0cda..0000000 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/features/Interpretation.java +++ /dev/null @@ -1,77 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.features; - -import pl.waw.ipipan.zil.multiservice.thrift.types.TInterpretation; - - -public class Interpretation { - private String ctag = "null"; - private String casee = "null"; - private String gender = "null"; - private String number = "null"; - private String person = "null"; - - public Interpretation(TInterpretation chosenInterpretation) { - ctag = chosenInterpretation.getCtag(); - String[] split = chosenInterpretation.getMsd().split(":"); - switch (ctag) { - case "ger": - case "subst": - case "pact": - case "ppas": - case "num": - case "numcol": - case "adj": - number = split[0]; - casee = split[1]; - gender = split[2]; - break; - case "ppron12": - case "ppron3": - number = split[0]; - casee = split[1]; - gender = split[2]; - person = split[3]; - break; - case "siebie": - casee = split[0]; - break; - case "fin": - case "bedzie": - case "aglt": - case "impt": - number = split[0]; - person = 
split[1]; - break; - case "praet": - case "winien": - number = split[0]; - gender = split[1]; - break; - case "prep": - casee = split[0]; - break; - default: - break; - } - } - - public String getCase() { - return casee; - } - - public String getGender() { - return gender; - } - - public String getNumber() { - return number; - } - - public String getPerson() { - return person; - } - - public String getCtag() { - return ctag; - } -} diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java index f6ccab9..ad239f9 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java @@ -2,9 +2,9 @@ package pl.waw.ipipan.zil.summ.nicolas.mention; import com.google.common.collect.*; import pl.waw.ipipan.zil.multiservice.thrift.types.*; -import pl.waw.ipipan.zil.summ.nicolas.features.FeatureExtractor; -import pl.waw.ipipan.zil.summ.nicolas.features.FeatureHelper; -import pl.waw.ipipan.zil.summ.nicolas.features.Interpretation; +import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureExtractor; +import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; +import pl.waw.ipipan.zil.summ.nicolas.common.features.Interpretation; import weka.core.Attribute; import java.io.File; diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionScorer.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionScorer.java index 5fa8e7c..9180ac4 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionScorer.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionScorer.java @@ -19,7 +19,7 @@ public class MentionScorer { Multiset<String> tokenCounts = 
HashMultiset.create(Utils.tokenize(optimalSummary.toLowerCase())); List<TSentence> sentences = text.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(Collectors.toList()); - Map<TMention, String> mention2Orth = Utils.loadMention2Orth(sentences); + Map<TMention, String> mention2Orth = Utils.loadMention2Orth(sentences, true); return booleanTokenIntersection(mention2Orth, tokenCounts); } diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/PrepareTrainingData.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/PrepareTrainingData.java index 3810574..13f606a 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/PrepareTrainingData.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/PrepareTrainingData.java @@ -7,7 +7,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; -import pl.waw.ipipan.zil.summ.nicolas.Constants; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.ThriftUtils; import pl.waw.ipipan.zil.summ.nicolas.common.Utils; import weka.core.Instance; diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java index e26b543..1372c06 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/TrainModel.java @@ -3,7 +3,7 @@ package pl.waw.ipipan.zil.summ.nicolas.mention; import org.apache.commons.lang3.time.StopWatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import pl.waw.ipipan.zil.summ.nicolas.Constants; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import weka.classifiers.Classifier; import weka.core.Instances; import 
weka.core.converters.ArffLoader; @@ -28,7 +28,7 @@ public class TrainModel { StopWatch watch = new StopWatch(); watch.start(); - Classifier classifier = Constants.getClassifier(); + Classifier classifier = Constants.getMentionClassifier(); LOG.info("Building classifier..."); classifier.buildClassifier(instances); diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java index db2147d..a52ef2f 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java @@ -3,7 +3,7 @@ package pl.waw.ipipan.zil.summ.nicolas.mention.test; import org.apache.commons.lang3.time.StopWatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import pl.waw.ipipan.zil.summ.nicolas.Constants; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import weka.classifiers.Classifier; import weka.classifiers.evaluation.Evaluation; import weka.core.Instances; @@ -32,7 +32,7 @@ public class Crossvalidate { StopWatch watch = new StopWatch(); watch.start(); - Classifier tree = Constants.getClassifier(); + Classifier tree = Constants.getMentionClassifier(); Evaluation eval = new Evaluation(instances); eval.crossValidateModel(tree, instances, 10, new Random(1)); diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java index 0fc9685..48a8ccf 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java @@ -3,7 +3,7 @@ package pl.waw.ipipan.zil.summ.nicolas.mention.test; import org.apache.commons.lang3.time.StopWatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; 
-import pl.waw.ipipan.zil.summ.nicolas.Constants; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import weka.classifiers.Classifier; import weka.classifiers.evaluation.Evaluation; import weka.core.Instances; diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/PrepareTrainingData.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/PrepareTrainingData.java index f9ab453..31fa380 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/PrepareTrainingData.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/PrepareTrainingData.java @@ -8,7 +8,7 @@ import org.slf4j.LoggerFactory; import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; -import pl.waw.ipipan.zil.summ.nicolas.Constants; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.ThriftUtils; import pl.waw.ipipan.zil.summ.nicolas.common.Utils; import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor; diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java index ce045af..3da019e 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceFeatureExtractor.java @@ -2,8 +2,8 @@ package pl.waw.ipipan.zil.summ.nicolas.sentence; import com.google.common.collect.Maps; import pl.waw.ipipan.zil.multiservice.thrift.types.*; -import pl.waw.ipipan.zil.summ.nicolas.features.FeatureExtractor; -import pl.waw.ipipan.zil.summ.nicolas.features.FeatureHelper; +import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureExtractor; +import 
pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; import weka.core.Attribute; import java.util.List; diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceScorer.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceScorer.java index 0ebb515..e53ffa7 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceScorer.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/SentenceScorer.java @@ -3,6 +3,7 @@ package pl.waw.ipipan.zil.summ.nicolas.sentence; import com.google.common.collect.HashMultiset; import com.google.common.collect.Maps; import com.google.common.collect.Multiset; +import com.google.common.collect.Sets; import pl.waw.ipipan.zil.multiservice.thrift.types.TParagraph; import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java index 71a4dec..8b3741c 100644 --- a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/TrainModel.java @@ -3,7 +3,7 @@ package pl.waw.ipipan.zil.summ.nicolas.sentence; import org.apache.commons.lang3.time.StopWatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import pl.waw.ipipan.zil.summ.nicolas.Constants; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import weka.classifiers.Classifier; import weka.core.Instances; import weka.core.converters.ArffLoader; diff --git a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java index a46f64e..457a857 100644 --- 
a/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java +++ b/nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java @@ -3,7 +3,7 @@ package pl.waw.ipipan.zil.summ.nicolas.sentence.test; import org.apache.commons.lang3.time.StopWatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import pl.waw.ipipan.zil.summ.nicolas.Constants; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import weka.classifiers.Classifier; import weka.classifiers.evaluation.Evaluation; import weka.core.Instances; diff --git a/nicolas-zero/pom.xml b/nicolas-zero/pom.xml index 6f4d656..666e517 100644 --- a/nicolas-zero/pom.xml +++ b/nicolas-zero/pom.xml @@ -27,6 +27,10 @@ <groupId>commons-io</groupId> <artifactId>commons-io</artifactId> </dependency> + <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-lang3</artifactId> + </dependency> <!-- logging --> <dependency> diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java index cceb8ac..9ce2a4b 100644 --- a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java +++ b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java @@ -12,7 +12,10 @@ import java.util.Set; public class CandidateFinder { - public List<ZeroSubjectCandidate> findZeroSubjectCandidates(TText text, Set<String> summarySentenceIds) { + private CandidateFinder() { + } + + public static List<ZeroSubjectCandidate> findZeroSubjectCandidates(TText text, Set<String> summarySentenceIds) { List<ZeroSubjectCandidate> candidates = Lists.newArrayList(); Map<String, Set<String>> mentionId2Cluster = Maps.newHashMap(); diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/Zero.java b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/Zero.java deleted file mode 
100644 index 1414f45..0000000 --- a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/Zero.java +++ /dev/null @@ -1,76 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.zero; - -import com.google.common.collect.Lists; -import com.google.common.collect.Maps; -import com.google.common.collect.Sets; -import org.apache.commons.csv.CSVFormat; -import org.apache.commons.csv.CSVPrinter; -import org.apache.commons.csv.QuoteMode; -import org.apache.commons.io.IOUtils; -import pl.waw.ipipan.zil.multiservice.thrift.types.TText; -import pl.waw.ipipan.zil.summ.nicolas.common.ThriftTextHelper; -import pl.waw.ipipan.zil.summ.nicolas.common.Utils; - -import java.io.File; -import java.io.FileReader; -import java.io.FileWriter; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.Set; - -public class Zero { - - private static final String IDS_PATH = "corpora/summaries_dev"; - private static final String THRIFTED_PATH = "corpora/preprocessed_full_texts/dev/"; - - private Zero() { - } - - public static void main(String[] args) throws IOException { - - CandidateFinder candidateFinder = new CandidateFinder(); - - Map<String, TText> id2preprocessedText = Utils.loadPreprocessedTexts(THRIFTED_PATH); - Map<String, Set<String>> id2sentIds = loadSentenceIds(IDS_PATH); - - List<List<Object>> rows = Lists.newArrayList(); - for (Map.Entry<String, TText> entry : id2preprocessedText.entrySet()) { - String textId = entry.getKey(); - - TText text = entry.getValue(); - ThriftTextHelper thriftTextHelper = new ThriftTextHelper(text); - - Set<String> sentenceIds = id2sentIds.get(textId); - - List<ZeroSubjectCandidate> zeroSubjectCandidates = candidateFinder.findZeroSubjectCandidates(text, sentenceIds); - - for (ZeroSubjectCandidate candidate : zeroSubjectCandidates) { - List<Object> row = Lists.newArrayList(); - row.add("C"); - row.add(textId); - row.add(thriftTextHelper.getMentionText(candidate.getZeroCandidateMention())); - 
row.add(thriftTextHelper.getSentenceText(candidate.getPreviousSentence())); - row.add(thriftTextHelper.getSentenceText(candidate.getSentence())); - rows.add(row); - } - } - - try (CSVPrinter csvPrinter = new CSVPrinter(new FileWriter("zeros.tsv"), CSVFormat.DEFAULT.withDelimiter('\t').withEscape('\\').withQuoteMode(QuoteMode.NONE).withQuote('"'))) { - for (List<Object> row : rows) { - csvPrinter.printRecord(row); - } - } - - } - - private static Map<String, Set<String>> loadSentenceIds(String idsPath) throws IOException { - Map<String, Set<String>> result = Maps.newHashMap(); - for (File f : new File(idsPath).listFiles()) { - String id = f.getName().split("_")[0]; - List<String> sentenceIds = IOUtils.readLines(new FileReader(f)); - result.put(id, Sets.newHashSet(sentenceIds)); - } - return result; - } -} diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java new file mode 100644 index 0000000..43e1333 --- /dev/null +++ b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java @@ -0,0 +1,69 @@ +package pl.waw.ipipan.zil.summ.nicolas.zero; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; +import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureExtractor; +import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; +import weka.core.Attribute; + +import java.util.List; +import java.util.Map; + + +public class ZeroFeatureExtractor extends FeatureExtractor { + + public ZeroFeatureExtractor() { + + for (String prefix : new String[]{"antecedent", "candidate"}) { + addNumericAttribute(prefix + "_index_in_sent"); + addNumericAttribute(prefix + "_token_count"); + addBinaryAttribute(prefix + "_is_zero"); + addBinaryAttribute(prefix + 
"_is_pronoun"); + addBinaryAttribute(prefix + "_is_named"); + } + + addBinaryAttribute("pair_equal_orth"); + + addNominalAttribute("score", Lists.newArrayList("bad", "good")); + fillSortedAttributes("score"); + } + + public Map<ZeroSubjectCandidate, Map<Attribute, Double>> calculateFeatures(List<ZeroSubjectCandidate> candidates, TText text) { + Map<ZeroSubjectCandidate, Map<Attribute, Double>> result = Maps.newHashMap(); + + FeatureHelper helper = new FeatureHelper(text); + for (ZeroSubjectCandidate candidate : candidates) { + Map<Attribute, Double> candidateFeatures = calculateFeatures(candidate, helper); + result.put(candidate, candidateFeatures); + } + + return result; + } + + private Map<Attribute, Double> calculateFeatures(ZeroSubjectCandidate candidate, FeatureHelper helper) { + + Map<Attribute, Double> candidateFeatures = Maps.newHashMap(); + candidateFeatures.put(getAttributeByName("score"), weka.core.Utils.missingValue()); + + TMention mention = candidate.getZeroCandidateMention(); + TMention antecedent = candidate.getPreviousSentence().getMentions().stream().filter(ante -> helper.getCoreferentMentions(mention).contains(ante)).findFirst().get(); + + addMentionFeatures(helper, candidateFeatures, mention, "candidate"); + addMentionFeatures(helper, candidateFeatures, antecedent, "antecedent"); + + candidateFeatures.put(getAttributeByName("pair_equal_orth"), toBinary(helper.getMentionOrth(mention).equalsIgnoreCase(helper.getMentionOrth(antecedent)))); + + return candidateFeatures; + } + + private void addMentionFeatures(FeatureHelper helper, Map<Attribute, Double> candidateFeatures, TMention mention, String attributePrefix) { + candidateFeatures.put(getAttributeByName(attributePrefix + "_index_in_sent"), (double) helper.getMentionIndexInSent(mention)); + candidateFeatures.put(getAttributeByName(attributePrefix + "_token_count"), (double) mention.getChildIdsSize()); + candidateFeatures.put(getAttributeByName(attributePrefix + "_is_zero"), 
toBinary(mention.isZeroSubject())); + candidateFeatures.put(getAttributeByName(attributePrefix + "_is_pronoun"), toBinary(helper.getMentionHeadToken(mention).getChosenInterpretation().getCtag().matches("ppron.*"))); + candidateFeatures.put(getAttributeByName(attributePrefix + "_is_named"), toBinary(helper.isMentionNamedEntity(mention))); + } + +} diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java index ca4f915..329f31a 100644 --- a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java +++ b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java @@ -1,5 +1,51 @@ package pl.waw.ipipan.zil.summ.nicolas.zero; +import com.google.common.collect.Sets; +import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; +import pl.waw.ipipan.zil.summ.nicolas.common.Utils; +import pl.waw.ipipan.zil.summ.nicolas.zero.train.TrainingDataExtractor; +import weka.classifiers.Classifier; +import weka.core.Instance; +import weka.core.Instances; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; public class ZeroSubjectInjector { + + private final ZeroFeatureExtractor featureExtractor; + private final Classifier classifier; + private final Instances instances; + + public ZeroSubjectInjector() throws IOException, ClassNotFoundException { + classifier = Utils.loadClassifier(Constants.ZERO_MODEL_PATH); + featureExtractor = new ZeroFeatureExtractor(); + instances = Utils.createNewInstances(featureExtractor.getAttributesList()); + } + + public Set<String> findZeroSubjectTokenIds(TText text, List<TSentence> selectedSentences) throws Exception { + Set<String> summarySentenceIds = 
selectedSentences.stream().map(TSentence::getId).collect(Collectors.toSet()); + List<ZeroSubjectCandidate> zeroSubjectCandidates = CandidateFinder.findZeroSubjectCandidates(text, summarySentenceIds); + Map<ZeroSubjectCandidate, Instance> candidate2instance = + TrainingDataExtractor.extractInstancesFromZeroCandidates(zeroSubjectCandidates, text, featureExtractor); + + Set<String> result = Sets.newHashSet(); + for (Map.Entry<ZeroSubjectCandidate, Instance> entry : candidate2instance.entrySet()) { + ZeroSubjectCandidate candidate = entry.getKey(); + Instance instance = entry.getValue(); + instance.setDataset(instances); + instance.setClassMissing(); + boolean good = classifier.classifyInstance(instance) > 0.5; + if (good) { + result.addAll(candidate.getZeroCandidateMention().getChildIds()); + } + } + return result; + } + } diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainModel.java b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainModel.java new file mode 100644 index 0000000..34df6cf --- /dev/null +++ b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainModel.java @@ -0,0 +1,51 @@ +package pl.waw.ipipan.zil.summ.nicolas.zero.train; + +import org.apache.commons.lang3.time.StopWatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; +import weka.classifiers.Classifier; +import weka.core.Instances; +import weka.core.converters.ArffLoader; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.ObjectOutputStream; + + +public class TrainModel { + + private static final Logger LOG = LoggerFactory.getLogger(TrainModel.class); + + private TrainModel() { + } + + public static void main(String[] args) throws Exception { + + ArffLoader loader = new ArffLoader(); + loader.setFile(new File(Constants.ZERO_DATASET_PATH)); + Instances instances = loader.getDataSet(); + instances.setClassIndex(0); + 
LOG.info(instances.size() + " instances loaded."); + LOG.info(instances.numAttributes() + " attributes for each instance."); + + StopWatch watch = new StopWatch(); + watch.start(); + + Classifier classifier = Constants.getZerosClassifier(); + + LOG.info("Building classifier..."); + classifier.buildClassifier(instances); + LOG.info("...done."); + + try (ObjectOutputStream oos = new ObjectOutputStream( + new FileOutputStream(Constants.ZERO_MODEL_PATH))) { + oos.writeObject(classifier); + } + + watch.stop(); + LOG.info("Elapsed time: " + watch); + + LOG.info(classifier.toString()); + } +}
diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainingDataExtractor.java b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainingDataExtractor.java
new file mode 100644
index 0000000..fcdc68d
--- /dev/null
+++ b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainingDataExtractor.java
@@ -0,0 +1,134 @@
+package pl.waw.ipipan.zil.summ.nicolas.zero.train;
+
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.apache.commons.io.IOUtils;
+import pl.waw.ipipan.zil.multiservice.thrift.types.TText;
+import pl.waw.ipipan.zil.summ.nicolas.common.Constants;
+import pl.waw.ipipan.zil.summ.nicolas.common.Utils;
+import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper;
+import pl.waw.ipipan.zil.summ.nicolas.zero.CandidateFinder;
+import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroFeatureExtractor;
+import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate;
+import weka.core.Attribute;
+import weka.core.DenseInstance;
+import weka.core.Instance;
+import weka.core.Instances;
+import weka.core.converters.ArffSaver;
+
+import java.io.File;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Command-line tool that builds one labelled Weka instance per zero subject candidate
+ * and saves the resulting dataset as ARFF under {@link Constants#ZERO_DATASET_PATH}.
+ */
+public class TrainingDataExtractor {
+
+    private static final String IDS_PATH = "corpora/summaries_dev";
+    private static final String THRIFTED_PATH = "corpora/preprocessed_full_texts/dev/";
+    private static final String GOLD_ZEROS_PATH = "/zeros.tsv";
+
+    private TrainingDataExtractor() {
+    }
+
+    /**
+     * Loads the preprocessed texts and their sentence ids, labels every zero subject
+     * candidate with the decision of {@link ZeroScorer} and writes the dataset to disk.
+     *
+     * @param args not used
+     * @throws IOException if an input cannot be read or the dataset cannot be written
+     */
+    public static void main(String[] args) throws IOException {
+
+        Map<String, TText> id2preprocessedText = Utils.loadPreprocessedTexts(THRIFTED_PATH);
+        Map<String, Set<String>> id2sentIds = loadSentenceIds(IDS_PATH);
+
+        ZeroScorer zeroScorer = new ZeroScorer(GOLD_ZEROS_PATH);
+        ZeroFeatureExtractor featureExtractor = new ZeroFeatureExtractor();
+
+        Instances instances = Utils.createNewInstances(featureExtractor.getAttributesList());
+
+        for (Map.Entry<String, TText> entry : id2preprocessedText.entrySet()) {
+            String textId = entry.getKey();
+
+            TText text = entry.getValue();
+            Set<String> sentenceIds = id2sentIds.get(textId);
+            FeatureHelper featureHelper = new FeatureHelper(text);
+
+            List<ZeroSubjectCandidate> zeroSubjectCandidates = CandidateFinder.findZeroSubjectCandidates(text, sentenceIds);
+            Map<ZeroSubjectCandidate, Instance> candidate2instance = extractInstancesFromZeroCandidates(zeroSubjectCandidates, text, featureExtractor);
+
+            for (Map.Entry<ZeroSubjectCandidate, Instance> entry2 : candidate2instance.entrySet()) {
+                boolean good = zeroScorer.isValidCandidate(entry2.getKey(), featureHelper);
+                Instance instance = entry2.getValue();
+                instance.setDataset(instances);
+                instance.setClassValue(good ? 1 : 0);
+                instances.add(instance);
+            }
+        }
+
+        saveInstancesToFile(instances);
+    }
+
+    /**
+     * Turns the features calculated for each candidate into a Weka instance. The instances
+     * are returned without a dataset and without a class value; the caller sets both.
+     *
+     * @param candidates       candidates to calculate features for
+     * @param text             text the candidates were found in
+     * @param featureExtractor provides the attributes and the feature values
+     * @return one instance per candidate for which features were calculated
+     */
+    public static Map<ZeroSubjectCandidate, Instance> extractInstancesFromZeroCandidates(List<ZeroSubjectCandidate> candidates, TText text, ZeroFeatureExtractor featureExtractor) {
+        Map<ZeroSubjectCandidate, Map<Attribute, Double>> candidate2features = featureExtractor.calculateFeatures(candidates, text);
+        Map<ZeroSubjectCandidate, Instance> candidate2instance = Maps.newHashMap();
+        for (Map.Entry<ZeroSubjectCandidate, Map<Attribute, Double>> entry : candidate2features.entrySet()) {
+            Instance instance = new DenseInstance(featureExtractor.getAttributesList().size());
+            Map<Attribute, Double> candidateFeatures = entry.getValue();
+            for (Attribute attribute : featureExtractor.getAttributesList()) {
+                instance.setValue(attribute, candidateFeatures.get(attribute));
+            }
+            candidate2instance.put(entry.getKey(), instance);
+        }
+        return candidate2instance;
+    }
+
+    /** Writes the instances as ARFF to {@link Constants#ZERO_DATASET_PATH}. */
+    private static void saveInstancesToFile(Instances instances) throws IOException {
+        ArffSaver saver = new ArffSaver();
+        saver.setInstances(instances);
+        saver.setFile(new File(Constants.ZERO_DATASET_PATH));
+        saver.writeBatch();
+    }
+
+    /**
+     * Reads the sentence ids, one per line, from every file of a directory. The part of a
+     * file name before the first underscore is used as the key of that file's ids.
+     *
+     * @param idsPath directory with the id files
+     * @return id sets keyed by file name prefix
+     * @throws IOException if the directory cannot be listed or a file cannot be read
+     */
+    private static Map<String, Set<String>> loadSentenceIds(String idsPath) throws IOException {
+        File[] idFiles = new File(idsPath).listFiles();
+        if (idFiles == null) {
+            // listFiles() yields null instead of throwing when the path is not a readable directory
+            throw new IOException("Cannot list sentence id files in " + idsPath);
+        }
+        Map<String, Set<String>> result = Maps.newHashMap();
+        for (File f : idFiles) {
+            String id = f.getName().split("_")[0];
+            // readLines does not close the reader it is given, so close it here
+            try (FileReader reader = new FileReader(f)) {
+                // NOTE(review): files sharing a prefix overwrite each other's ids - confirm this is intended
+                result.put(id, Sets.newHashSet(IOUtils.readLines(reader)));
+            }
+        }
+        return result;
+    }
+}
diff --git a/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/ZeroScorer.java b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/ZeroScorer.java
new file mode 100644
index 0000000..3eef010
--- /dev/null
+++ b/nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/ZeroScorer.java
@@ -0,0 +1,78 @@
+package pl.waw.ipipan.zil.summ.nicolas.zero.train;
+
+import com.google.common.collect.Maps;
+import org.apache.commons.csv.CSVFormat;
+import org.apache.commons.csv.CSVParser;
+import org.apache.commons.csv.CSVRecord;
+import org.apache.commons.csv.QuoteMode;
+import pl.waw.ipipan.zil.summ.nicolas.common.Constants;
+import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper;
+import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Tells whether a zero subject candidate is valid according to the gold decisions,
+ * which are read from a tab-separated resource.
+ */
+public class ZeroScorer {
+
+    private static final char DELIMITER = '\t';
+
+    private final Map<String, Boolean> candidateEncoding2Decision = Maps.newHashMap();
+
+    /**
+     * Reads the gold decisions. Fields 2, 3 and 4 of a record (mention, first sentence and
+     * second sentence orth) form the lookup key; the decision is positive when field 0
+     * equals "C", ignoring case.
+     *
+     * @param goldZerosPath resource name passed to {@link Class#getResourceAsStream(String)}
+     * @throws IOException if the resource is missing or cannot be read
+     */
+    public ZeroScorer(String goldZerosPath) throws IOException {
+        InputStream goldZeros = ZeroScorer.class.getResourceAsStream(goldZerosPath);
+        if (goldZeros == null) {
+            // getResourceAsStream returns null for a missing resource; fail with a message that names it
+            throw new IOException("Gold zeros resource not found: " + goldZerosPath);
+        }
+        try (InputStream stream = goldZeros;
+             InputStreamReader reader = new InputStreamReader(stream, Constants.ENCODING);
+             CSVParser parser = new CSVParser(reader, CSVFormat.DEFAULT.withDelimiter(DELIMITER).withEscape('|').withQuoteMode(QuoteMode.NONE).withQuote('~'))) {
+            List<CSVRecord> records = parser.getRecords();
+            for (CSVRecord record : records) {
+                candidateEncoding2Decision.put(encode(record.get(2), record.get(3), record.get(4)), record.get(0).equalsIgnoreCase("C"));
+            }
+        }
+    }
+
+    /** Joins the three orths with the delimiter into a single lookup key. */
+    private String encode(String mentionOrth, String firstSentenceOrth, String secondSentenceOrth) {
+        return mentionOrth + DELIMITER + firstSentenceOrth + DELIMITER + secondSentenceOrth;
+    }
+
+    /** Builds the lookup key of a candidate from its mention, previous sentence and sentence. */
+    private String encode(ZeroSubjectCandidate candidate, FeatureHelper helper) {
+        String mentionOrth = helper.getMentionOrth(candidate.getZeroCandidateMention());
+        String firstSentenceOrth = helper.getSentenceOrth(candidate.getPreviousSentence());
+        String secondSentenceOrth = helper.getSentenceOrth(candidate.getSentence());
+        return encode(mentionOrth, firstSentenceOrth, secondSentenceOrth);
+    }
+
+    /**
+     * Looks up the gold decision for a candidate.
+     *
+     * @param candidate candidate to check
+     * @param helper    used to obtain the orths that identify the candidate
+     * @return {@code true} if the gold decision for the candidate is positive; {@code false} if
+     *         it is not or if the gold file does not contain the candidate at all
+     */
+    public boolean isValidCandidate(ZeroSubjectCandidate candidate, FeatureHelper helper) {
+        // get() returns null for an unknown candidate; unboxing it directly would throw
+        return Boolean.TRUE.equals(candidateEncoding2Decision.get(encode(candidate, helper)));
+    }
+
+}
diff --git a/nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java b/nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java index 7948faa..4ab4ee2 100644 --- a/nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java +++ b/nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java @@ -2,12 +2,11 @@ package pl.waw.ipipan.zil.summ.nicolas.zero; import com.google.common.collect.Sets; import org.apache.commons.io.IOUtils; -import org.junit.BeforeClass; import org.junit.Test; import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; -import pl.waw.ipipan.zil.summ.nicolas.common.ThriftTextHelper; import pl.waw.ipipan.zil.summ.nicolas.common.Utils; +import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; import java.io.IOException; import java.io.InputStream; @@ -22,18 +21,11 @@ public class CandidateFinderTest { private static final String SAMPLE_TEXT_PATH = "/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin"; private static final String SAMPLE_TEXT_SUMMARY_IDS_PATH = "/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt"; - private static CandidateFinder candidateFinder; - - @BeforeClass - public static void init() { - candidateFinder = new CandidateFinder(); - } - @Test public void shouldFindZeroSubjectCandidateInSampleText() throws Exception { - ThriftTextHelper sampleTextHelper = loadSampleTextHelper(); + FeatureHelper sampleTextHelper = loadSampleTextHelper(); Set<String> summarySentenceIds = loadSampleTextSummarySentenceIds(); - List<ZeroSubjectCandidate> candidates = candidateFinder.findZeroSubjectCandidates(sampleTextHelper.getText(), summarySentenceIds); + List<ZeroSubjectCandidate> candidates = CandidateFinder.findZeroSubjectCandidates(sampleTextHelper.getText(),
summarySentenceIds); assertEquals(1, candidates.size()); ZeroSubjectCandidate zeroSubjectCandidate = candidates.get(0); @@ -41,9 +33,9 @@ public class CandidateFinderTest { TSentence secondSentence = zeroSubjectCandidate.getSentence(); TMention zeroCandidate = zeroSubjectCandidate.getZeroCandidateMention(); - assertEquals("Ala ma kota.", sampleTextHelper.getSentenceText(firstSentence)); - assertEquals("Ala ma też psa.", sampleTextHelper.getSentenceText(secondSentence)); - assertEquals("Ala", sampleTextHelper.getMentionText(zeroCandidate)); + assertEquals("Ala ma kota.", sampleTextHelper.getSentenceOrth(firstSentence)); + assertEquals("Ala ma też psa.", sampleTextHelper.getSentenceOrth(secondSentence)); + assertEquals("Ala", sampleTextHelper.getMentionOrth(zeroCandidate)); } private Set<String> loadSampleTextSummarySentenceIds() throws IOException { @@ -53,9 +45,9 @@ public class CandidateFinderTest { } } - private ThriftTextHelper loadSampleTextHelper() throws IOException { + private FeatureHelper loadSampleTextHelper() throws IOException { try (InputStream stream = CandidateFinderTest.class.getResourceAsStream(SAMPLE_TEXT_PATH)) { - return new ThriftTextHelper(Utils.loadThrifted(stream)); + return new FeatureHelper(Utils.loadThrifted(stream)); } } } \ No newline at end of file diff --git a/nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjectorTest.java b/nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjectorTest.java deleted file mode 100644 index e98bc27..0000000 --- a/nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjectorTest.java +++ /dev/null @@ -1,11 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.zero; - -import org.junit.Test; - -public class ZeroSubjectInjectorTest { - - @Test - public void shouldInit() throws Exception { - ZeroSubjectInjector injector = new ZeroSubjectInjector(); - } -} \ No newline at end of file diff --git 
a/nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin b/nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin new file mode 100644 index 0000000..e30b245 Binary files /dev/null and b/nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin differ diff --git a/nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt b/nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt new file mode 100644 index 0000000..10ac642 --- /dev/null +++ b/nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt @@ -0,0 +1,2 @@ +s-2.1 +s-2.2 diff --git a/pom.xml b/pom.xml index 3a2ba87..bbdbd9b 100644 --- a/pom.xml +++ b/pom.xml @@ -61,6 +61,11 @@ <artifactId>nicolas-common</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>pl.waw.ipipan.zil.summ</groupId> + <artifactId>nicolas-zero</artifactId> + <version>${project.version}</version> + </dependency> <!-- internal --> <dependency> -- libgit2 0.22.2