Commit 88415dbf2896c80d1c6362b9288378c637425ed0
1 parent
91b27b24
refactor, add zero features
Showing
25 changed files
with
303 additions
and
98 deletions
nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Constants.java
1 | 1 | package pl.waw.ipipan.zil.summ.nicolas.common; |
2 | 2 | |
3 | 3 | import com.google.common.base.Charsets; |
4 | +import com.google.common.collect.ImmutableList; | |
4 | 5 | import weka.classifiers.Classifier; |
5 | -import weka.classifiers.functions.Logistic; | |
6 | +import weka.classifiers.functions.SMO; | |
7 | +import weka.classifiers.meta.AdaBoostM1; | |
8 | +import weka.classifiers.meta.AttributeSelectedClassifier; | |
9 | +import weka.classifiers.rules.JRip; | |
10 | +import weka.classifiers.trees.J48; | |
6 | 11 | import weka.classifiers.trees.RandomForest; |
7 | 12 | |
8 | 13 | import java.nio.charset.Charset; |
... | ... | @@ -20,6 +25,8 @@ public class Constants { |
20 | 25 | |
21 | 26 | public static final Charset ENCODING = Charsets.UTF_8; |
22 | 27 | |
28 | + public static final ImmutableList<String> POS_TAGS = ImmutableList.of("end", "other", "null", "impt", "imps", "inf", "pred", "subst", "aglt", "ppron3", "ger", "praet", "fin", "num", "interp", "siebie", "brev", "interj", "ppron12", "adj", "burk", "pcon", "bedzie", "adv", "prep", "depr", "xxx", "winien", "conj", "qub", "adja", "ppas", "comp", "pact"); | |
29 | + | |
23 | 30 | private Constants() { |
24 | 31 | } |
25 | 32 | |
... | ... | @@ -33,14 +40,14 @@ public class Constants { |
33 | 40 | |
34 | 41 | public static Classifier getSentencesClassifier() { |
35 | 42 | RandomForest classifier = new RandomForest(); |
36 | - classifier.setNumIterations(250); | |
43 | + classifier.setNumIterations(10); | |
37 | 44 | classifier.setSeed(0); |
38 | 45 | classifier.setNumExecutionSlots(8); |
39 | 46 | return classifier; |
40 | 47 | } |
41 | 48 | |
42 | 49 | public static Classifier getZerosClassifier() { |
43 | - Logistic classifier = new Logistic(); | |
50 | + Classifier classifier = new J48(); | |
44 | 51 | return classifier; |
45 | 52 | } |
46 | 53 | } |
... | ... |
nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/features/FeatureHelper.java
... | ... | @@ -30,12 +30,13 @@ public class FeatureHelper { |
30 | 30 | private final Map<TMention, TToken> mention2head = Maps.newHashMap(); |
31 | 31 | private final Set<TMention> mentionsInNamedEntities = Sets.newHashSet(); |
32 | 32 | |
33 | - private final Map<TMention, Integer> mention2Index = Maps.newHashMap(); | |
33 | + private final Map<TMention, Integer> mention2index = Maps.newHashMap(); | |
34 | 34 | private final Map<TSentence, Integer> sent2Index = Maps.newHashMap(); |
35 | 35 | private final Map<TParagraph, Integer> par2Index = Maps.newHashMap(); |
36 | 36 | private final Map<TSentence, Integer> sent2IndexInPar = Maps.newHashMap(); |
37 | 37 | private final Map<TMention, Integer> mention2indexInPar = Maps.newHashMap(); |
38 | 38 | private final Map<TMention, Integer> mention2indexInSent = Maps.newHashMap(); |
39 | + private final Map<TMention, Integer> mention2firstTokenIndex = Maps.newHashMap(); | |
39 | 40 | |
40 | 41 | |
41 | 42 | public FeatureHelper(TText preprocessedText) { |
... | ... | @@ -82,7 +83,8 @@ public class FeatureHelper { |
82 | 83 | for (TMention mention : sent.getMentions()) { |
83 | 84 | mention2sent.put(mention, sent); |
84 | 85 | mention2par.put(mention, par); |
85 | - mention2Index.put(mention, mentionIdx++); | |
86 | + mention2index.put(mention, mentionIdx++); | |
87 | + mention2firstTokenIndex.put(mention, sent.getTokens().indexOf(tokenId2token.get(mention.getChildIds().iterator().next()))); | |
86 | 88 | mention2indexInSent.put(mention, mentionIdxInSent++); |
87 | 89 | mention2indexInPar.put(mention, mentionIdxInPar++); |
88 | 90 | |
... | ... | @@ -124,7 +126,11 @@ public class FeatureHelper { |
124 | 126 | } |
125 | 127 | |
126 | 128 | public int getMentionIndex(TMention mention) { |
127 | - return mention2Index.get(mention); | |
129 | + return mention2index.get(mention); | |
130 | + } | |
131 | + | |
132 | + public int getMentionFirstTokenIndex(TMention mention) { | |
133 | + return mention2firstTokenIndex.get(mention); | |
128 | 134 | } |
129 | 135 | |
130 | 136 | public int getMentionIndexInSent(TMention mention) { |
... | ... | @@ -200,4 +206,19 @@ public class FeatureHelper { |
200 | 206 | public TText getText() { |
201 | 207 | return text; |
202 | 208 | } |
209 | + | |
210 | + public TToken getTokenAfterMention(TMention mention) { | |
211 | + Integer idx = mention2firstTokenIndex.get(mention) + mention.getChildIds().size(); | |
212 | + List<TToken> sentenceTokens = mention2sent.get(mention).getTokens(); | |
213 | + if (idx >= sentenceTokens.size()) | |
214 | + return null; | |
215 | + return sentenceTokens.get(idx); | |
216 | + } | |
217 | + | |
218 | + public TToken getTokenBeforeMention(TMention mention) { | |
219 | + Integer idx = mention2firstTokenIndex.get(mention); | |
220 | + if (idx == 0) | |
221 | + return null; | |
222 | + return mention2sent.get(mention).getTokens().get(idx - 1); | |
223 | + } | |
203 | 224 | } |
... | ... |
nicolas-common/src/test/java/pl/waw/ipipan/zil/summ/nicolas/common/UtilsTest.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.nicolas.common; | |
2 | + | |
3 | +import org.junit.Test; | |
4 | +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; | |
5 | + | |
6 | +import java.io.InputStream; | |
7 | + | |
8 | +import static org.junit.Assert.assertEquals; | |
9 | + | |
10 | +public class UtilsTest { | |
11 | + | |
12 | + private static final String SAMPLE_TEXT_PATH = "/199704210011.bin"; | |
13 | + | |
14 | + @Test | |
15 | + public void shouldDeserializeTextIgnoringClassVersionId() throws Exception { | |
16 | + try (InputStream stream = UtilsTest.class.getResourceAsStream(SAMPLE_TEXT_PATH)) { | |
17 | + TText text = Utils.loadThrifted(stream); | |
18 | + assertEquals(26, text.getParagraphs().size()); | |
19 | + assertEquals(2, text.getParagraphs().get(4).getSentences().size()); | |
20 | + } | |
21 | + } | |
22 | +} | |
0 | 23 | \ No newline at end of file |
... | ... |
nicolas-common/src/test/resources/199704210011.bin
0 → 100644
No preview for this file type
nicolas-core/pom.xml
... | ... | @@ -21,11 +21,8 @@ |
21 | 21 | <groupId>pl.waw.ipipan.zil.summ</groupId> |
22 | 22 | <artifactId>nicolas-model</artifactId> |
23 | 23 | </dependency> |
24 | - <dependency> | |
25 | - <groupId>pl.waw.ipipan.zil.summ</groupId> | |
26 | - <artifactId>nicolas-zero</artifactId> | |
27 | - </dependency> | |
28 | 24 | |
25 | + <!-- internal --> | |
29 | 26 | <dependency> |
30 | 27 | <groupId>pl.waw.ipipan.zil.summ</groupId> |
31 | 28 | <artifactId>pscapi</artifactId> |
... | ... | @@ -35,6 +32,7 @@ |
35 | 32 | <artifactId>utils</artifactId> |
36 | 33 | </dependency> |
37 | 34 | |
35 | + <!-- third party --> | |
38 | 36 | <dependency> |
39 | 37 | <groupId>nz.ac.waikato.cms.weka</groupId> |
40 | 38 | <artifactId>weka-dev</artifactId> |
... | ... | @@ -51,5 +49,17 @@ |
51 | 49 | <groupId>org.apache.commons</groupId> |
52 | 50 | <artifactId>commons-lang3</artifactId> |
53 | 51 | </dependency> |
52 | + | |
53 | + <!-- logging --> | |
54 | + <dependency> | |
55 | + <groupId>org.slf4j</groupId> | |
56 | + <artifactId>slf4j-api</artifactId> | |
57 | + </dependency> | |
58 | + | |
59 | + <!-- test --> | |
60 | + <dependency> | |
61 | + <groupId>junit</groupId> | |
62 | + <artifactId>junit</artifactId> | |
63 | + </dependency> | |
54 | 64 | </dependencies> |
55 | 65 | </project> |
56 | 66 | \ No newline at end of file |
... | ... |
nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/EvalUtils.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.nicolas.eval; | |
2 | + | |
3 | +import org.apache.commons.lang3.time.StopWatch; | |
4 | +import org.apache.commons.lang3.tuple.Pair; | |
5 | +import org.slf4j.Logger; | |
6 | +import org.slf4j.LoggerFactory; | |
7 | +import weka.classifiers.Classifier; | |
8 | +import weka.classifiers.bayes.BayesNet; | |
9 | +import weka.classifiers.bayes.NaiveBayes; | |
10 | +import weka.classifiers.evaluation.Evaluation; | |
11 | +import weka.classifiers.functions.LinearRegression; | |
12 | +import weka.classifiers.functions.Logistic; | |
13 | +import weka.classifiers.functions.SMOreg; | |
14 | +import weka.classifiers.functions.SimpleLogistic; | |
15 | +import weka.classifiers.lazy.IBk; | |
16 | +import weka.classifiers.lazy.KStar; | |
17 | +import weka.classifiers.lazy.LWL; | |
18 | +import weka.classifiers.rules.DecisionTable; | |
19 | +import weka.classifiers.rules.JRip; | |
20 | +import weka.classifiers.rules.PART; | |
21 | +import weka.classifiers.trees.HoeffdingTree; | |
22 | +import weka.classifiers.trees.J48; | |
23 | +import weka.classifiers.trees.LMT; | |
24 | +import weka.classifiers.trees.RandomForest; | |
25 | +import weka.core.Instances; | |
26 | + | |
27 | +import java.util.Arrays; | |
28 | +import java.util.Comparator; | |
29 | +import java.util.Optional; | |
30 | +import java.util.Random; | |
31 | + | |
32 | +public class EvalUtils { | |
33 | + | |
34 | + private static final Logger LOG = LoggerFactory.getLogger(EvalUtils.class); | |
35 | + public static final int NUM_FOLDS = 10; | |
36 | + | |
37 | + private EvalUtils() { | |
38 | + } | |
39 | + | |
40 | + public static void crossvalidateClassification(Instances instances) throws Exception { | |
41 | + StopWatch watch = new StopWatch(); | |
42 | + watch.start(); | |
43 | + | |
44 | + Optional<Pair<Double, String>> max = Arrays.stream(new Classifier[]{new J48(), new RandomForest(), new HoeffdingTree(), new LMT(), | |
45 | + new Logistic(), | |
46 | + new SimpleLogistic(), new BayesNet(), new NaiveBayes(), | |
47 | + new KStar(), new IBk(), new LWL(), | |
48 | + new DecisionTable(), new JRip(), new PART()}).parallel().map(cls -> { | |
49 | + Evaluation eval = null; | |
50 | + try { | |
51 | + eval = new Evaluation(instances); | |
52 | + eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(1)); | |
53 | + } catch (Exception e) { | |
54 | + e.printStackTrace(); | |
55 | + } | |
56 | + double acc = eval.correct() / eval.numInstances(); | |
57 | + String name = cls.getClass().getSimpleName(); | |
58 | + LOG.info(name + " : " + acc); | |
59 | + | |
60 | + return Pair.of(acc, name); | |
61 | + }).max(Comparator.comparingDouble(Pair::getLeft)); | |
62 | + LOG.info("#########"); | |
63 | + LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); | |
64 | + | |
65 | + watch.stop(); | |
66 | + LOG.info("Elapsed time: " + watch); | |
67 | + } | |
68 | + | |
69 | + public static void crossvalidateRegression(Instances instances) { | |
70 | + StopWatch watch = new StopWatch(); | |
71 | + watch.start(); | |
72 | + | |
73 | + Optional<Pair<Double, String>> max = Arrays.stream(new Classifier[]{ | |
74 | + new RandomForest(), new LinearRegression(), new SMOreg()}).parallel().map(cls -> { | |
75 | + Evaluation eval = null; | |
76 | + double acc = 0; | |
77 | + try { | |
78 | + eval = new Evaluation(instances); | |
79 | + eval.crossValidateModel(cls, instances, NUM_FOLDS, new Random(1)); | |
80 | + acc = eval.correlationCoefficient(); | |
81 | + | |
82 | + } catch (Exception e) { | |
83 | + e.printStackTrace(); | |
84 | + } | |
85 | + String name = cls.getClass().getSimpleName(); | |
86 | + LOG.info(name + " : " + acc); | |
87 | + | |
88 | + return Pair.of(acc, name); | |
89 | + }).max(Comparator.comparingDouble(Pair::getLeft)); | |
90 | + LOG.info("#########"); | |
91 | + LOG.info("Best: " + max.get().getRight() + " : " + max.get().getLeft()); | |
92 | + | |
93 | + watch.stop(); | |
94 | + LOG.info("Elapsed time: " + watch); | |
95 | + } | |
96 | +} | |
0 | 97 | \ No newline at end of file |
... | ... |
nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/MentionFeatureExtractor.java
... | ... | @@ -2,6 +2,7 @@ package pl.waw.ipipan.zil.summ.nicolas.mention; |
2 | 2 | |
3 | 3 | import com.google.common.collect.*; |
4 | 4 | import pl.waw.ipipan.zil.multiservice.thrift.types.*; |
5 | +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; | |
5 | 6 | import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureExtractor; |
6 | 7 | import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; |
7 | 8 | import pl.waw.ipipan.zil.summ.nicolas.common.features.Interpretation; |
... | ... | @@ -45,7 +46,7 @@ public class MentionFeatureExtractor extends FeatureExtractor { |
45 | 46 | addBinaryAttribute(prefix + "_is_zero"); |
46 | 47 | addBinaryAttribute(prefix + "_is_named"); |
47 | 48 | addBinaryAttribute(prefix + "_is_pronoun"); |
48 | - addNominalAttribute(prefix + "_ctag", Lists.newArrayList("other", "null", "impt", "subst", "aglt", "ppron3", "ger", "praet", "fin", "num", "interp", "siebie", "brev", "interj", "ppron12", "adj", "burk", "pcon", "bedzie", "adv", "prep", "depr", "xxx", "winien", "conj", "qub", "adja", "ppas", "comp", "pact")); | |
49 | + addNominalAttribute(prefix + "_ctag", Constants.POS_TAGS); | |
49 | 50 | addNominalAttribute(prefix + "_person", Lists.newArrayList("other", "null", "pri", "sec", "ter")); |
50 | 51 | addNominalAttribute(prefix + "_case", Lists.newArrayList("other", "null", "nom", "acc", "dat", "gen", "loc", "inst", "voc")); |
51 | 52 | addNominalAttribute(prefix + "_number", Lists.newArrayList("other", "null", "sg", "pl")); |
... | ... |
nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Crossvalidate.java
... | ... | @@ -17,6 +17,9 @@ public class Crossvalidate { |
17 | 17 | |
18 | 18 | private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); |
19 | 19 | |
20 | + private Crossvalidate() { | |
21 | + } | |
22 | + | |
20 | 23 | public static void main(String[] args) throws Exception { |
21 | 24 | |
22 | 25 | ArffLoader loader = new ArffLoader(); |
... | ... | @@ -26,9 +29,6 @@ public class Crossvalidate { |
26 | 29 | LOG.info(instances.size() + " instances loaded."); |
27 | 30 | LOG.info(instances.numAttributes() + " attributes for each instance."); |
28 | 31 | |
29 | -// while (instances.size() > 10000) | |
30 | -// instances.remove(instances.size() - 1); | |
31 | - | |
32 | 32 | StopWatch watch = new StopWatch(); |
33 | 33 | watch.start(); |
34 | 34 | |
... | ... |
nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/mention/test/Validate.java
... | ... | @@ -14,9 +14,7 @@ import java.io.FileInputStream; |
14 | 14 | import java.io.IOException; |
15 | 15 | import java.io.ObjectInputStream; |
16 | 16 | |
17 | -/** | |
18 | - * Created by me2 on 05.04.16. | |
19 | - */ | |
17 | + | |
20 | 18 | public class Validate { |
21 | 19 | private static final Logger LOG = LoggerFactory.getLogger(Validate.class); |
22 | 20 | |
... | ... |
nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/sentence/test/Crossvalidate.java
1 | 1 | package pl.waw.ipipan.zil.summ.nicolas.sentence.test; |
2 | 2 | |
3 | -import org.apache.commons.lang3.time.StopWatch; | |
4 | 3 | import org.slf4j.Logger; |
5 | 4 | import org.slf4j.LoggerFactory; |
6 | 5 | import pl.waw.ipipan.zil.summ.nicolas.common.Constants; |
7 | -import weka.classifiers.Classifier; | |
8 | -import weka.classifiers.evaluation.Evaluation; | |
6 | +import pl.waw.ipipan.zil.summ.nicolas.eval.EvalUtils; | |
9 | 7 | import weka.core.Instances; |
10 | 8 | import weka.core.converters.ArffLoader; |
11 | 9 | |
12 | 10 | import java.io.File; |
13 | -import java.util.Random; | |
14 | 11 | |
15 | 12 | |
16 | 13 | public class Crossvalidate { |
17 | 14 | |
18 | 15 | private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); |
19 | 16 | |
17 | + private Crossvalidate() { | |
18 | + } | |
19 | + | |
20 | 20 | public static void main(String[] args) throws Exception { |
21 | 21 | |
22 | 22 | ArffLoader loader = new ArffLoader(); |
... | ... | @@ -26,16 +26,6 @@ public class Crossvalidate { |
26 | 26 | LOG.info(instances.size() + " instances loaded."); |
27 | 27 | LOG.info(instances.numAttributes() + " attributes for each instance."); |
28 | 28 | |
29 | - StopWatch watch = new StopWatch(); | |
30 | - watch.start(); | |
31 | - | |
32 | - Classifier tree = Constants.getSentencesClassifier(); | |
33 | - | |
34 | - Evaluation eval = new Evaluation(instances); | |
35 | - eval.crossValidateModel(tree, instances, 10, new Random(1)); | |
36 | - LOG.info(eval.toSummaryString()); | |
37 | - | |
38 | - watch.stop(); | |
39 | - LOG.info("Elapsed time: " + watch); | |
29 | + EvalUtils.crossvalidateRegression(instances); | |
40 | 30 | } |
41 | 31 | } |
... | ... |
nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java renamed to nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinder.java
... | ... | @@ -61,6 +61,8 @@ public class CandidateFinder { |
61 | 61 | } |
62 | 62 | |
63 | 63 | private static boolean isInNominative(TInterpretation interp) { |
64 | - return interp.getCtag().equals("subst") && Arrays.stream(interp.getMsd().split(":")).anyMatch(t -> t.equals("nom")); | |
64 | + boolean isNominative = Arrays.stream(interp.getMsd().split(":")).anyMatch(t -> t.equals("nom")); | |
65 | + boolean isSubst = interp.getCtag().equals("subst"); | |
66 | + return isSubst && isNominative; | |
65 | 67 | } |
66 | 68 | } |
... | ... |
nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainingDataExtractor.java renamed to nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/PrepareTrainingData.java
1 | -package pl.waw.ipipan.zil.summ.nicolas.zero.train; | |
1 | +package pl.waw.ipipan.zil.summ.nicolas.zero; | |
2 | 2 | |
3 | 3 | import com.google.common.collect.Maps; |
4 | 4 | import com.google.common.collect.Sets; |
5 | 5 | import org.apache.commons.io.IOUtils; |
6 | +import org.slf4j.Logger; | |
7 | +import org.slf4j.LoggerFactory; | |
6 | 8 | import pl.waw.ipipan.zil.multiservice.thrift.types.TText; |
7 | 9 | import pl.waw.ipipan.zil.summ.nicolas.common.Constants; |
8 | 10 | import pl.waw.ipipan.zil.summ.nicolas.common.Utils; |
9 | 11 | import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; |
10 | -import pl.waw.ipipan.zil.summ.nicolas.zero.CandidateFinder; | |
11 | -import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroFeatureExtractor; | |
12 | -import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate; | |
13 | 12 | import weka.core.Attribute; |
14 | 13 | import weka.core.DenseInstance; |
15 | 14 | import weka.core.Instance; |
... | ... | @@ -23,13 +22,15 @@ import java.util.List; |
23 | 22 | import java.util.Map; |
24 | 23 | import java.util.Set; |
25 | 24 | |
26 | -public class TrainingDataExtractor { | |
25 | +public class PrepareTrainingData { | |
26 | + | |
27 | + private static final Logger LOG = LoggerFactory.getLogger(PrepareTrainingData.class); | |
27 | 28 | |
28 | 29 | private static final String IDS_PATH = "corpora/summaries_dev"; |
29 | 30 | private static final String THRIFTED_PATH = "corpora/preprocessed_full_texts/dev/"; |
30 | 31 | private static final String GOLD_ZEROS_PATH = "/zeros.tsv"; |
31 | 32 | |
32 | - private TrainingDataExtractor() { | |
33 | + private PrepareTrainingData() { | |
33 | 34 | } |
34 | 35 | |
35 | 36 | public static void main(String[] args) throws IOException { |
... | ... | @@ -42,7 +43,10 @@ public class TrainingDataExtractor { |
42 | 43 | |
43 | 44 | Instances instances = Utils.createNewInstances(featureExtractor.getAttributesList()); |
44 | 45 | |
46 | + int i = 1; | |
45 | 47 | for (Map.Entry<String, TText> entry : id2preprocessedText.entrySet()) { |
48 | + LOG.info(i++ + "/" + id2preprocessedText.size()); | |
49 | + | |
46 | 50 | String textId = entry.getKey(); |
47 | 51 | |
48 | 52 | TText text = entry.getValue(); |
... | ... |
nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/TrainModel.java renamed to nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/TrainModel.java
nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java renamed to nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroFeatureExtractor.java
... | ... | @@ -2,14 +2,18 @@ package pl.waw.ipipan.zil.summ.nicolas.zero; |
2 | 2 | |
3 | 3 | import com.google.common.collect.Lists; |
4 | 4 | import com.google.common.collect.Maps; |
5 | +import com.google.common.collect.Sets; | |
5 | 6 | import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; |
6 | 7 | import pl.waw.ipipan.zil.multiservice.thrift.types.TText; |
8 | +import pl.waw.ipipan.zil.multiservice.thrift.types.TToken; | |
9 | +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; | |
7 | 10 | import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureExtractor; |
8 | 11 | import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; |
9 | 12 | import weka.core.Attribute; |
10 | 13 | |
11 | 14 | import java.util.List; |
12 | 15 | import java.util.Map; |
16 | +import java.util.Set; | |
13 | 17 | |
14 | 18 | |
15 | 19 | public class ZeroFeatureExtractor extends FeatureExtractor { |
... | ... | @@ -18,13 +22,26 @@ public class ZeroFeatureExtractor extends FeatureExtractor { |
18 | 22 | |
19 | 23 | for (String prefix : new String[]{"antecedent", "candidate"}) { |
20 | 24 | addNumericAttribute(prefix + "_index_in_sent"); |
25 | + addNumericAttribute(prefix + "_first_token_index_in_sent"); | |
21 | 26 | addNumericAttribute(prefix + "_token_count"); |
22 | - addBinaryAttribute(prefix + "_is_zero"); | |
23 | - addBinaryAttribute(prefix + "_is_pronoun"); | |
24 | 27 | addBinaryAttribute(prefix + "_is_named"); |
28 | + addNumericAttribute(prefix + "_sentence_mention_count"); | |
29 | + addNominalAttribute(prefix + "_next_token_pos", Constants.POS_TAGS); | |
30 | + addNominalAttribute(prefix + "_prev_token_pos", Constants.POS_TAGS); | |
31 | + addBinaryAttribute(prefix + "_is_nested"); | |
32 | + addBinaryAttribute(prefix + "_is_nesting"); | |
25 | 33 | } |
26 | 34 | |
35 | + addNumericAttribute("chain_length"); | |
36 | + | |
27 | 37 | addBinaryAttribute("pair_equal_orth"); |
38 | + addBinaryAttribute("pair_equal_ignore_case_orth"); | |
39 | + addBinaryAttribute("pair_equal_base"); | |
40 | + addBinaryAttribute("pair_equal_number"); | |
41 | + addBinaryAttribute("pair_equal_head_base"); | |
42 | + | |
43 | + addNumericAttribute("pair_sent_distance"); | |
44 | + addNumericAttribute("pair_par_distance"); | |
28 | 45 | |
29 | 46 | addNominalAttribute("score", Lists.newArrayList("bad", "good")); |
30 | 47 | fillSortedAttributes("score"); |
... | ... | @@ -53,17 +70,57 @@ public class ZeroFeatureExtractor extends FeatureExtractor { |
53 | 70 | addMentionFeatures(helper, candidateFeatures, mention, "candidate"); |
54 | 71 | addMentionFeatures(helper, candidateFeatures, antecedent, "antecedent"); |
55 | 72 | |
56 | - candidateFeatures.put(getAttributeByName("pair_equal_orth"), toBinary(helper.getMentionOrth(mention).equalsIgnoreCase(helper.getMentionOrth(antecedent)))); | |
73 | + candidateFeatures.put(getAttributeByName("pair_equal_orth"), toBinary(helper.getMentionOrth(mention).equals(helper.getMentionOrth(antecedent)))); | |
74 | + candidateFeatures.put(getAttributeByName("pair_equal_base"), toBinary(helper.getMentionBase(mention).equalsIgnoreCase(helper.getMentionBase(antecedent)))); | |
75 | + candidateFeatures.put(getAttributeByName("pair_equal_ignore_case_orth"), toBinary(helper.getMentionOrth(mention).equalsIgnoreCase(helper.getMentionOrth(antecedent)))); | |
76 | + candidateFeatures.put(getAttributeByName("pair_equal_head_base"), toBinary(helper.getMentionHeadToken(mention).getChosenInterpretation().getBase().equalsIgnoreCase(helper.getMentionHeadToken(antecedent).getChosenInterpretation().getBase()))); | |
77 | + | |
78 | + candidateFeatures.put(getAttributeByName("pair_sent_distance"), (double) Math.abs(helper.getSentIndex(helper.getMentionSentence(mention)) - helper.getSentIndex(helper.getMentionSentence(antecedent)))); | |
79 | + candidateFeatures.put(getAttributeByName("pair_par_distance"), (double) Math.abs(helper.getParIndex(helper.getMentionParagraph(mention)) - helper.getParIndex(helper.getMentionParagraph(antecedent)))); | |
80 | + | |
81 | + String mentionNumber = getNumber(helper.getMentionHeadToken(mention)); | |
82 | + String antecedentNumber = getNumber(helper.getMentionHeadToken(antecedent)); | |
83 | + candidateFeatures.put(getAttributeByName("pair_equal_number"), toBinary(mentionNumber != null && mentionNumber.equals(antecedentNumber))); | |
84 | + | |
85 | + candidateFeatures.put(getAttributeByName("chain_length"), (double) helper.getChainLength(mention)); | |
57 | 86 | |
58 | 87 | return candidateFeatures; |
59 | 88 | } |
60 | 89 | |
90 | + private String getNumber(TToken token) { | |
91 | + Set<String> msd = Sets.newHashSet(token.getChosenInterpretation().getMsd().split(":")); | |
92 | + if (msd.contains("sg")) | |
93 | + return "sg"; | |
94 | + else if (msd.contains("pl")) | |
95 | + return "pl"; | |
96 | + else | |
97 | + return null; | |
98 | + } | |
99 | + | |
61 | 100 | private void addMentionFeatures(FeatureHelper helper, Map<Attribute, Double> candidateFeatures, TMention mention, String attributePrefix) { |
62 | 101 | candidateFeatures.put(getAttributeByName(attributePrefix + "_index_in_sent"), (double) helper.getMentionIndexInSent(mention)); |
102 | + candidateFeatures.put(getAttributeByName(attributePrefix + "_first_token_index_in_sent"), (double) helper.getMentionFirstTokenIndex(mention)); | |
103 | + | |
63 | 104 | candidateFeatures.put(getAttributeByName(attributePrefix + "_token_count"), (double) mention.getChildIdsSize()); |
64 | - candidateFeatures.put(getAttributeByName(attributePrefix + "_is_zero"), toBinary(mention.isZeroSubject())); | |
65 | - candidateFeatures.put(getAttributeByName(attributePrefix + "_is_pronoun"), toBinary(helper.getMentionHeadToken(mention).getChosenInterpretation().getCtag().matches("ppron.*"))); | |
66 | 105 | candidateFeatures.put(getAttributeByName(attributePrefix + "_is_named"), toBinary(helper.isMentionNamedEntity(mention))); |
106 | + candidateFeatures.put(getAttributeByName(attributePrefix + "_sentence_mention_count"), (double) helper.getMentionSentence(mention).getMentions().size()); | |
107 | + | |
108 | + TToken nextToken = helper.getTokenAfterMention(mention); | |
109 | + addNominalAttributeValue(nextToken == null ? "end" : nextToken.getChosenInterpretation().getCtag(), candidateFeatures, attributePrefix + "_next_token_pos"); | |
110 | + TToken prevToken = helper.getTokenBeforeMention(mention); | |
111 | + addNominalAttributeValue(prevToken == null ? "end" : prevToken.getChosenInterpretation().getCtag(), candidateFeatures, attributePrefix + "_prev_token_pos"); | |
112 | + | |
113 | + candidateFeatures.put(getAttributeByName(attributePrefix + "_is_nested"), toBinary(helper.isNested(mention))); | |
114 | + candidateFeatures.put(getAttributeByName(attributePrefix + "_is_nesting"), toBinary(helper.isNesting(mention))); | |
115 | + | |
67 | 116 | } |
68 | 117 | |
118 | + private void addNominalAttributeValue(String value, Map<Attribute, Double> attribute2value, String attributeName) { | |
119 | + Attribute att = getAttributeByName(attributeName); | |
120 | + int index = att.indexOfValue(value); | |
121 | + if (index == -1) | |
122 | + LOG.warn(value + " not found for attribute " + attributeName); | |
123 | + attribute2value.put(att, (double) (index == -1 ? att.indexOfValue("other") : index)); | |
124 | + } | |
69 | 125 | } |
126 | + | |
... | ... |
nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/train/ZeroScorer.java renamed to nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroScorer.java
1 | -package pl.waw.ipipan.zil.summ.nicolas.zero.train; | |
1 | +package pl.waw.ipipan.zil.summ.nicolas.zero; | |
2 | 2 | |
3 | 3 | import com.google.common.collect.Maps; |
4 | 4 | import org.apache.commons.csv.CSVFormat; |
... | ... | @@ -7,7 +7,6 @@ import org.apache.commons.csv.CSVRecord; |
7 | 7 | import org.apache.commons.csv.QuoteMode; |
8 | 8 | import pl.waw.ipipan.zil.summ.nicolas.common.Constants; |
9 | 9 | import pl.waw.ipipan.zil.summ.nicolas.common.features.FeatureHelper; |
10 | -import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate; | |
11 | 10 | |
12 | 11 | import java.io.IOException; |
13 | 12 | import java.io.InputStream; |
... | ... |
nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectCandidate.java renamed to nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectCandidate.java
nicolas-zero/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java renamed to nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/ZeroSubjectInjector.java
... | ... | @@ -5,7 +5,6 @@ import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; |
5 | 5 | import pl.waw.ipipan.zil.multiservice.thrift.types.TText; |
6 | 6 | import pl.waw.ipipan.zil.summ.nicolas.common.Constants; |
7 | 7 | import pl.waw.ipipan.zil.summ.nicolas.common.Utils; |
8 | -import pl.waw.ipipan.zil.summ.nicolas.zero.train.TrainingDataExtractor; | |
9 | 8 | import weka.classifiers.Classifier; |
10 | 9 | import weka.core.Instance; |
11 | 10 | import weka.core.Instances; |
... | ... | @@ -32,7 +31,7 @@ public class ZeroSubjectInjector { |
32 | 31 | Set<String> summarySentenceIds = selectedSentences.stream().map(TSentence::getId).collect(Collectors.toSet()); |
33 | 32 | List<ZeroSubjectCandidate> zeroSubjectCandidates = CandidateFinder.findZeroSubjectCandidates(text, summarySentenceIds); |
34 | 33 | Map<ZeroSubjectCandidate, Instance> candidate2instance = |
35 | - TrainingDataExtractor.extractInstancesFromZeroCandidates(zeroSubjectCandidates, text, featureExtractor); | |
34 | + PrepareTrainingData.extractInstancesFromZeroCandidates(zeroSubjectCandidates, text, featureExtractor); | |
36 | 35 | |
37 | 36 | Set<String> result = Sets.newHashSet(); |
38 | 37 | for (Map.Entry<ZeroSubjectCandidate, Instance> entry : candidate2instance.entrySet()) { |
... | ... |
nicolas-core/src/main/java/pl/waw/ipipan/zil/summ/nicolas/zero/test/Crossvalidate.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.nicolas.zero.test; | |
2 | + | |
3 | +import org.slf4j.Logger; | |
4 | +import org.slf4j.LoggerFactory; | |
5 | +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; | |
6 | +import pl.waw.ipipan.zil.summ.nicolas.eval.EvalUtils; | |
7 | +import weka.core.Instances; | |
8 | +import weka.core.converters.ArffLoader; | |
9 | + | |
10 | +import java.io.File; | |
11 | + | |
12 | + | |
13 | +public class Crossvalidate { | |
14 | + | |
15 | + private static final Logger LOG = LoggerFactory.getLogger(Crossvalidate.class); | |
16 | + | |
17 | + private Crossvalidate() { | |
18 | + } | |
19 | + | |
20 | + public static void main(String[] args) throws Exception { | |
21 | + | |
22 | + ArffLoader loader = new ArffLoader(); | |
23 | + loader.setFile(new File(Constants.ZERO_DATASET_PATH)); | |
24 | + Instances instances = loader.getDataSet(); | |
25 | + instances.setClassIndex(0); | |
26 | + LOG.info(instances.size() + " instances loaded."); | |
27 | + LOG.info(instances.numAttributes() + " attributes for each instance."); | |
28 | + | |
29 | + EvalUtils.crossvalidateClassification(instances); | |
30 | + } | |
31 | +} | |
... | ... |
nicolas-zero/src/main/resources/zeros.tsv renamed to nicolas-core/src/main/resources/zeros.tsv
nicolas-zero/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java renamed to nicolas-core/src/test/java/pl/waw/ipipan/zil/summ/nicolas/zero/CandidateFinderTest.java
nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin renamed to nicolas-core/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_serialized_text.bin
No preview for this file type
nicolas-zero/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt renamed to nicolas-core/src/test/resources/pl/waw/ipipan/zil/summ/nicolas/zero/sample_summary_sentence_ids.txt
nicolas-train/src/test/java/pl/waw/ipipan/zil/summ/nicolas/train/multiservice/NLPProcessTest.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.nicolas.train.multiservice; | |
2 | + | |
3 | +import org.junit.Test; | |
4 | +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; | |
5 | + | |
6 | +import java.io.File; | |
7 | + | |
8 | +public class NLPProcessTest { | |
9 | + @Test | |
10 | + public void shouldProcessSampleText() throws Exception { | |
11 | + String text = "Ala ma kota. Ala ma też psa."; | |
12 | + TText processed = NLPProcess.annotate(text); | |
13 | + processed.getParagraphs().stream().flatMap(p->p.getSentences().stream()).forEach(s->System.out.println(s.getId())); | |
14 | + File targetFile = new File("sample_serialized_text.bin"); | |
15 | + NLPProcess.serialize(processed, targetFile); | |
16 | + } | |
17 | +} | |
0 | 18 | \ No newline at end of file |
... | ... |
nicolas-zero/pom.xml deleted
1 | -<?xml version="1.0" encoding="UTF-8"?> | |
2 | -<project xmlns="http://maven.apache.org/POM/4.0.0" | |
3 | - xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | |
4 | - xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | |
5 | - <modelVersion>4.0.0</modelVersion> | |
6 | - <parent> | |
7 | - <artifactId>nicolas-container</artifactId> | |
8 | - <groupId>pl.waw.ipipan.zil.summ</groupId> | |
9 | - <version>1.0-SNAPSHOT</version> | |
10 | - </parent> | |
11 | - | |
12 | - <artifactId>nicolas-zero</artifactId> | |
13 | - | |
14 | - <dependencies> | |
15 | - <!-- project --> | |
16 | - <dependency> | |
17 | - <groupId>pl.waw.ipipan.zil.summ</groupId> | |
18 | - <artifactId>nicolas-common</artifactId> | |
19 | - </dependency> | |
20 | - | |
21 | - <!-- third party --> | |
22 | - <dependency> | |
23 | - <groupId>org.apache.commons</groupId> | |
24 | - <artifactId>commons-csv</artifactId> | |
25 | - </dependency> | |
26 | - <dependency> | |
27 | - <groupId>commons-io</groupId> | |
28 | - <artifactId>commons-io</artifactId> | |
29 | - </dependency> | |
30 | - <dependency> | |
31 | - <groupId>org.apache.commons</groupId> | |
32 | - <artifactId>commons-lang3</artifactId> | |
33 | - </dependency> | |
34 | - | |
35 | - <!-- logging --> | |
36 | - <dependency> | |
37 | - <groupId>org.slf4j</groupId> | |
38 | - <artifactId>slf4j-api</artifactId> | |
39 | - </dependency> | |
40 | - | |
41 | - <!-- test --> | |
42 | - <dependency> | |
43 | - <groupId>junit</groupId> | |
44 | - <artifactId>junit</artifactId> | |
45 | - </dependency> | |
46 | - </dependencies> | |
47 | - | |
48 | -</project> | |
49 | 0 | \ No newline at end of file |