Commit 7e387f1cdc557ac810c9c4118ddff9e36c78a776
1 parent
2169abf8
training code
Showing
25 changed files
with
397 additions
and
696 deletions
.gitignore
nicolas-cli/pom.xml
... | ... | @@ -53,28 +53,31 @@ |
53 | 53 | <build> |
54 | 54 | <plugins> |
55 | 55 | <plugin> |
56 | + <groupId>org.apache.maven.plugins</groupId> | |
56 | 57 | <artifactId>maven-assembly-plugin</artifactId> |
57 | - <configuration> | |
58 | - <appendAssemblyId>false</appendAssemblyId> | |
59 | - <archive> | |
60 | - <manifest> | |
61 | - <mainClass>pl.waw.ipipan.zil.summ.nicolas.cli.Main</mainClass> | |
62 | - </manifest> | |
63 | - </archive> | |
64 | - <descriptorRefs> | |
65 | - <descriptorRef>jar-with-dependencies</descriptorRef> | |
66 | - </descriptorRefs> | |
67 | - </configuration> | |
68 | 58 | <executions> |
69 | 59 | <execution> |
70 | - <id>make-assembly</id> | |
60 | + <id>jar-with-dependencies</id> | |
71 | 61 | <phase>package</phase> |
72 | 62 | <goals> |
73 | 63 | <goal>single</goal> |
74 | 64 | </goals> |
65 | + <configuration> | |
66 | + <descriptorRefs> | |
67 | + <descriptorRef>jar-with-dependencies</descriptorRef> | |
68 | + </descriptorRefs> | |
69 | + <appendAssemblyId>false</appendAssemblyId> | |
70 | + <finalName>nicolas-cli</finalName> | |
71 | + <archive> | |
72 | + <manifest> | |
73 | + <mainClass>pl.waw.ipipan.zil.summ.nicolas.cli.Main</mainClass> | |
74 | + </manifest> | |
75 | + </archive> | |
76 | + </configuration> | |
75 | 77 | </execution> |
76 | 78 | </executions> |
77 | 79 | </plugin> |
80 | + | |
78 | 81 | </plugins> |
79 | 82 | </build> |
80 | 83 | </project> |
81 | 84 | \ No newline at end of file |
... | ... |
nicolas-cli/src/main/java/pl/waw/ipipan/zil/summ/nicolas/cli/Cli.java
... | ... | @@ -10,18 +10,19 @@ import org.slf4j.Logger; |
10 | 10 | import org.slf4j.LoggerFactory; |
11 | 11 | |
12 | 12 | import java.io.File; |
13 | +import java.io.IOException; | |
13 | 14 | |
14 | 15 | class Cli { |
15 | 16 | |
16 | 17 | private static final Logger LOG = LoggerFactory.getLogger(Cli.class); |
17 | 18 | |
18 | - @Parameter(names = {"-help", "-h"}, description = "Print help") | |
19 | + @Parameter(names = {"-help", "-h"}, description = "Print help", help = true) | |
19 | 20 | private boolean help = false; |
20 | 21 | |
21 | 22 | @Parameter(names = {"-input", "-i"}, description = "Input text file to summarize", required = true, validateWith = FileValidator.class, converter = FileConverter.class) |
22 | 23 | private File inputFile; |
23 | 24 | |
24 | - @Parameter(names = {"-output", "-o"}, description = "Output file path for summary", required = true, validateWith = FileValidator.class, converter = FileConverter.class) | |
25 | + @Parameter(names = {"-output", "-o"}, description = "Output file path for summary", required = true, validateWith = OutputFileValidator.class, converter = FileConverter.class) | |
25 | 26 | private File outputFile; |
26 | 27 | |
27 | 28 | @Parameter(names = {"-target", "-t"}, description = "Target summary token count", required = true, validateWith = PositiveInteger.class) |
... | ... | @@ -84,4 +85,19 @@ class Cli { |
84 | 85 | } |
85 | 86 | |
86 | 87 | } |
88 | + | |
89 | + public static class OutputFileValidator implements IParameterValidator { | |
90 | + | |
91 | + @Override | |
92 | + public void validate(String name, String value) { | |
93 | + File file = new File(value); | |
94 | + try { | |
95 | + file.createNewFile(); | |
96 | + } catch (IOException ex) { | |
97 | + throw new ParameterException("Parameter " + name | |
98 | + + " should be a valid file path (found " + value + ")", ex); | |
99 | + } | |
100 | + } | |
101 | + | |
102 | + } | |
87 | 103 | } |
... | ... |
nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java
nicolas-train/pom.xml
... | ... | @@ -40,6 +40,10 @@ |
40 | 40 | <groupId>pl.waw.ipipan.zil.multiservice</groupId> |
41 | 41 | <artifactId>utils</artifactId> |
42 | 42 | </dependency> |
43 | + <dependency> | |
44 | + <groupId>pl.waw.ipipan.zil.summ</groupId> | |
45 | + <artifactId>eval</artifactId> | |
46 | + </dependency> | |
43 | 47 | |
44 | 48 | <!-- third party --> |
45 | 49 | <dependency> |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.nicolas.train; | |
2 | + | |
3 | +import pl.waw.ipipan.zil.summ.nicolas.train.pipeline.*; | |
4 | + | |
5 | +public class Main { | |
6 | + | |
7 | + private Main() { | |
8 | + } | |
9 | + | |
10 | + public static void main(String[] args) throws Exception { | |
11 | + DownloadCorpus.main(args); | |
12 | + DownloadTrainingResources.main(args); | |
13 | + ExtractGoldSummaries.main(args); | |
14 | + CreateOptimalSummaries.main(args); | |
15 | + PrepareTrainingData.main(args); | |
16 | + TrainAllModels.main(args); | |
17 | + } | |
18 | +} | |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/PathConstants.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.nicolas.train; | |
2 | + | |
3 | +import net.lingala.zip4j.core.ZipFile; | |
4 | +import net.lingala.zip4j.exception.ZipException; | |
5 | +import org.apache.commons.io.FileUtils; | |
6 | +import org.slf4j.Logger; | |
7 | +import org.slf4j.LoggerFactory; | |
8 | + | |
9 | +import java.io.File; | |
10 | +import java.io.IOException; | |
11 | +import java.net.URL; | |
12 | + | |
13 | +public class PathConstants { | |
14 | + | |
15 | + private static final Logger LOG = LoggerFactory.getLogger(PathConstants.class); | |
16 | + | |
17 | + public static final String CORPUS_DOWNLOAD_URL = "http://zil.ipipan.waw.pl/PolishSummariesCorpus?action=AttachFile&do=get&target=PSC_1.0.zip"; | |
18 | + public static final String PREPROCESSED_CORPUS_DOWNLOAD_URL = "http://zil.ipipan.waw.pl/Nicolas?action=AttachFile&do=get&target=all-preprocessed.zip"; | |
19 | + public static final String SUMMARY_SENTENCE_IDS_DOWNLOAD_URL = "http://zil.ipipan.waw.pl/Nicolas?action=AttachFile&do=get&target=train-zero-sentence-ids.zip"; | |
20 | + public static final String ZERO_TRAINING_CORPUS_URL = "http://zil.ipipan.waw.pl/Nicolas?action=AttachFile&do=get&target=train-zero.tsv"; | |
21 | + | |
22 | + public static final File WORKING_DIR = new File("data"); | |
23 | + | |
24 | + public static final File ZIPPED_CORPUS_FILE = new File(WORKING_DIR, "PSC_1.0.zip"); | |
25 | + public static final File ZIPPED_PREPROCESSED_CORPUS_FILE = new File(WORKING_DIR, "all-preprocessed.zip"); | |
26 | + public static final File ZIPPED_SUMMARY_SENTENCE_IDS_FILE = new File(WORKING_DIR, "train-zero-sentence-ids.zip"); | |
27 | + | |
28 | + public static final File EXTRACTED_CORPUS_DIR = new File(WORKING_DIR, "corpus"); | |
29 | + public static final File EXTRACTED_CORPUS_DATA_DIR = new File(new File(EXTRACTED_CORPUS_DIR, "PSC_1.0"), "data"); | |
30 | + public static final File SUMMARY_SENTENCE_IDS_DIR = new File(WORKING_DIR, "train-zero-sentence-ids"); | |
31 | + public static final File PREPROCESSED_CORPUS_DIR = new File(WORKING_DIR, "all-preprocessed"); | |
32 | + public static final File GOLD_TEST_SUMMARIES_DIR = new File(WORKING_DIR, "test-gold"); | |
33 | + public static final File GOLD_TRAIN_SUMMARIES_DIR = new File(WORKING_DIR, "train-gold"); | |
34 | + public static final File OPTIMAL_SUMMARIES_DIR = new File(WORKING_DIR, "train-optimal"); | |
35 | + public static final File ZERO_TRAINING_CORPUS = new File(WORKING_DIR, "train-zero.tsv"); | |
36 | + | |
37 | + public static final File ARFF_DIR = new File(WORKING_DIR, "train-arff"); | |
38 | + public static final File MENTION_ARFF = new File(ARFF_DIR, "mentions.arff"); | |
39 | + public static final File SENTENCE_ARFF = new File(ARFF_DIR, "sentences.arff"); | |
40 | + public static final File ZERO_ARFF = new File(ARFF_DIR, "zeros.arff"); | |
41 | + | |
42 | + private PathConstants() { | |
43 | + } | |
44 | + | |
45 | + public static File createFolder(File folder) { | |
46 | + if (folder.mkdir()) { | |
47 | + LOG.info("Created directory at: {}.", folder.getPath()); | |
48 | + } else { | |
49 | + LOG.info("Directory already present at: {}.", folder.getPath()); | |
50 | + } | |
51 | + return folder; | |
52 | + } | |
53 | + | |
54 | + public static void downloadFile(String fileUrl, File targetFile) throws IOException { | |
55 | + if (!targetFile.exists()) { | |
56 | + LOG.info("Downloading file from url {} to file {} ...", fileUrl, targetFile); | |
57 | + FileUtils.copyURLToFile(new URL(fileUrl), targetFile); | |
58 | + LOG.info("done."); | |
59 | + } else { | |
60 | + LOG.info("File {} already downloaded.", targetFile); | |
61 | + } | |
62 | + } | |
63 | + | |
64 | + public static void downloadFileAndExtract(String url, File targetZipFile, File targetDir) throws IOException, ZipException { | |
65 | + downloadFile(url, targetZipFile); | |
66 | + extractZipFile(targetZipFile, targetDir); | |
67 | + } | |
68 | + | |
69 | + private static void extractZipFile(File targetZipFile, File targetDir) throws ZipException { | |
70 | + if (targetDir.exists()) { | |
71 | + LOG.info("Zip file {} already extracted to dir {}.", targetZipFile, targetDir); | |
72 | + } else { | |
73 | + createFolder(targetDir); | |
74 | + ZipFile zipFile = new ZipFile(targetZipFile); | |
75 | + zipFile.extractAll(targetDir.getPath()); | |
76 | + LOG.info("Extracted zip file: {} to dir: {}.", targetZipFile, targetDir); | |
77 | + } | |
78 | + } | |
79 | +} | |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/MentionScorer.java renamed to nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/SentenceScorer.java renamed to nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java
1 | -package pl.waw.ipipan.zil.summ.nicolas.train.model.sentence; | |
1 | +package pl.waw.ipipan.zil.summ.nicolas.train.model; | |
2 | 2 | |
3 | 3 | import com.google.common.collect.HashMultiset; |
4 | 4 | import com.google.common.collect.Maps; |
5 | 5 | import com.google.common.collect.Multiset; |
6 | -import com.google.common.collect.Sets; | |
7 | 6 | import pl.waw.ipipan.zil.multiservice.thrift.types.TParagraph; |
8 | 7 | import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; |
9 | 8 | import pl.waw.ipipan.zil.multiservice.thrift.types.TText; |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/ModelConstants.java renamed to nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java
1 | -package pl.waw.ipipan.zil.summ.nicolas.train.model.common; | |
1 | +package pl.waw.ipipan.zil.summ.nicolas.train.model; | |
2 | 2 | |
3 | 3 | import weka.classifiers.Classifier; |
4 | 4 | import weka.classifiers.trees.RandomForest; |
5 | 5 | |
6 | -public class ModelConstants { | |
6 | +public class Settings { | |
7 | 7 | |
8 | - public static final String MENTION_DATASET_PATH = "data/arff/mentions_train.arff"; | |
9 | - public static final String SENTENCE_DATASET_PATH = "data/arff/sentences_train.arff"; | |
10 | - public static final String ZERO_DATASET_PATH = "data/arff/zeros_train.arff"; | |
11 | - | |
12 | - private static final int NUM_ITERATIONS = 250; | |
8 | + private static final int NUM_ITERATIONS = 20; | |
13 | 9 | private static final int NUM_EXECUTION_SLOTS = 8; |
14 | 10 | private static final int SEED = 0; |
15 | 11 | |
16 | - private ModelConstants() { | |
12 | + private Settings() { | |
17 | 13 | } |
18 | 14 | |
19 | 15 | public static Classifier getMentionClassifier() { |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/ZeroScorer.java renamed to nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/ZeroScorer.java
1 | -package pl.waw.ipipan.zil.summ.nicolas.train.model.zero; | |
1 | +package pl.waw.ipipan.zil.summ.nicolas.train.model; | |
2 | 2 | |
3 | 3 | import com.google.common.collect.Maps; |
4 | 4 | import org.apache.commons.csv.CSVFormat; |
... | ... | @@ -9,9 +9,7 @@ import pl.waw.ipipan.zil.summ.nicolas.common.Constants; |
9 | 9 | import pl.waw.ipipan.zil.summ.nicolas.features.FeatureHelper; |
10 | 10 | import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate; |
11 | 11 | |
12 | -import java.io.IOException; | |
13 | -import java.io.InputStream; | |
14 | -import java.io.InputStreamReader; | |
12 | +import java.io.*; | |
15 | 13 | import java.util.List; |
16 | 14 | import java.util.Map; |
17 | 15 | |
... | ... | @@ -21,8 +19,8 @@ public class ZeroScorer { |
21 | 19 | |
22 | 20 | private final Map<String, Boolean> candidateEncoding2Decision = Maps.newHashMap(); |
23 | 21 | |
24 | - public ZeroScorer(String goldZerosResourcePath) throws IOException { | |
25 | - try (InputStream stream = ZeroScorer.class.getResourceAsStream(goldZerosResourcePath); | |
22 | + public ZeroScorer(File zeroTrainingCorpusFile) throws IOException { | |
23 | + try (InputStream stream = new FileInputStream(zeroTrainingCorpusFile); | |
26 | 24 | InputStreamReader reader = new InputStreamReader(stream, Constants.ENCODING); |
27 | 25 | CSVParser parser = new CSVParser(reader, CSVFormat.DEFAULT.withDelimiter(DELIMITER).withEscape('|').withQuoteMode(QuoteMode.NONE).withQuote('~'))) { |
28 | 26 | List<CSVRecord> records = parser.getRecords(); |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/TrainModelCommon.java deleted
1 | -package pl.waw.ipipan.zil.summ.nicolas.train.model.common; | |
2 | - | |
3 | -import org.apache.commons.lang3.time.StopWatch; | |
4 | -import org.slf4j.Logger; | |
5 | -import org.slf4j.LoggerFactory; | |
6 | -import weka.classifiers.Classifier; | |
7 | -import weka.core.Instances; | |
8 | -import weka.core.converters.ArffLoader; | |
9 | - | |
10 | -import java.io.File; | |
11 | -import java.io.FileOutputStream; | |
12 | -import java.io.ObjectOutputStream; | |
13 | -import java.util.logging.LogManager; | |
14 | - | |
15 | -@SuppressWarnings("squid:S2118") | |
16 | -public class TrainModelCommon { | |
17 | - | |
18 | - private static final Logger LOG = LoggerFactory.getLogger(TrainModelCommon.class); | |
19 | - | |
20 | - private static final String TARGET_MODEL_DIR = "nicolas-model/src/main/resources"; | |
21 | - | |
22 | - private TrainModelCommon() { | |
23 | - } | |
24 | - | |
25 | - public static void trainAndSaveModel(String datasetPath, Classifier classifier, String targetPath) throws Exception { | |
26 | - LogManager.getLogManager().reset(); // disable WEKA logging | |
27 | - | |
28 | - ArffLoader loader = new ArffLoader(); | |
29 | - loader.setFile(new File(datasetPath)); | |
30 | - Instances instances = loader.getDataSet(); | |
31 | - instances.setClassIndex(0); | |
32 | - LOG.info("{} instances loaded.", instances.size()); | |
33 | - LOG.info("{} attributes for each instance.", instances.numAttributes()); | |
34 | - | |
35 | - StopWatch watch = new StopWatch(); | |
36 | - watch.start(); | |
37 | - | |
38 | - LOG.info("Building classifier..."); | |
39 | - classifier.buildClassifier(instances); | |
40 | - LOG.info("...done. Build classifier: {}", classifier); | |
41 | - | |
42 | - String target = TARGET_MODEL_DIR + targetPath; | |
43 | - LOG.info("Saving classifier at: {}", target); | |
44 | - try (ObjectOutputStream oos = new ObjectOutputStream( | |
45 | - new FileOutputStream(target))) { | |
46 | - oos.writeObject(classifier); | |
47 | - } | |
48 | - | |
49 | - watch.stop(); | |
50 | - LOG.info("Elapsed time: {}", watch); | |
51 | - } | |
52 | -} |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/TrainMentionModel.java deleted
1 | -package pl.waw.ipipan.zil.summ.nicolas.train.model.mention; | |
2 | - | |
3 | -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; | |
4 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; | |
5 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.TrainModelCommon; | |
6 | -import weka.classifiers.Classifier; | |
7 | - | |
8 | -public class TrainMentionModel { | |
9 | - | |
10 | - private TrainMentionModel() { | |
11 | - } | |
12 | - | |
13 | - public static void main(String[] args) throws Exception { | |
14 | - Classifier classifier = ModelConstants.getMentionClassifier(); | |
15 | - String datasetPath = ModelConstants.MENTION_DATASET_PATH; | |
16 | - String targetPath = Constants.MENTION_MODEL_RESOURCE_PATH; | |
17 | - TrainModelCommon.trainAndSaveModel(datasetPath, classifier, targetPath); | |
18 | - } | |
19 | - | |
20 | -} |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/TrainSentenceModel.java deleted
1 | -package pl.waw.ipipan.zil.summ.nicolas.train.model.sentence; | |
2 | - | |
3 | -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; | |
4 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; | |
5 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.TrainModelCommon; | |
6 | -import weka.classifiers.Classifier; | |
7 | - | |
8 | -public class TrainSentenceModel { | |
9 | - | |
10 | - private TrainSentenceModel() { | |
11 | - } | |
12 | - | |
13 | - public static void main(String[] args) throws Exception { | |
14 | - Classifier classifier = ModelConstants.getSentenceClassifier(); | |
15 | - String datasetPath = ModelConstants.SENTENCE_DATASET_PATH; | |
16 | - String targetPath = Constants.SENTENCE_MODEL_RESOURCE_PATH; | |
17 | - TrainModelCommon.trainAndSaveModel(datasetPath, classifier, targetPath); | |
18 | - } | |
19 | - | |
20 | -} |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/TrainZeroModel.java deleted
1 | -package pl.waw.ipipan.zil.summ.nicolas.train.model.zero; | |
2 | - | |
3 | -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; | |
4 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; | |
5 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.TrainModelCommon; | |
6 | -import weka.classifiers.Classifier; | |
7 | - | |
8 | -public class TrainZeroModel { | |
9 | - | |
10 | - private TrainZeroModel() { | |
11 | - } | |
12 | - | |
13 | - public static void main(String[] args) throws Exception { | |
14 | - Classifier classifier = ModelConstants.getZeroClassifier(); | |
15 | - String datasetPath = ModelConstants.ZERO_DATASET_PATH; | |
16 | - String targetPath = Constants.ZERO_MODEL_RESOURCE_PATH; | |
17 | - TrainModelCommon.trainAndSaveModel(datasetPath, classifier, targetPath); | |
18 | - } | |
19 | - | |
20 | -} |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; | |
2 | + | |
3 | + | |
4 | +import com.google.common.collect.HashMultiset; | |
5 | +import com.google.common.collect.Lists; | |
6 | +import com.google.common.collect.Maps; | |
7 | +import com.google.common.collect.Multiset; | |
8 | +import org.apache.commons.io.FileUtils; | |
9 | +import pl.waw.ipipan.zil.summ.eval.Main; | |
10 | +import pl.waw.ipipan.zil.summ.eval.rouge.RougeN; | |
11 | +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; | |
12 | + | |
13 | +import java.io.File; | |
14 | +import java.io.IOException; | |
15 | +import java.util.*; | |
16 | +import java.util.stream.Collectors; | |
17 | + | |
18 | +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; | |
19 | + | |
20 | +public class CreateOptimalSummaries { | |
21 | + | |
22 | + private static final int ROUGE_N = 1; | |
23 | + private static final String SUMMARY_SUFFIX = "_optimal.txt"; | |
24 | + | |
25 | + private CreateOptimalSummaries() { | |
26 | + } | |
27 | + | |
28 | + public static void main(String[] args) throws IOException { | |
29 | + createFolder(OPTIMAL_SUMMARIES_DIR); | |
30 | + Map<String, Set<File>> id2goldSummaryFiles = Main.loadFiles(GOLD_TRAIN_SUMMARIES_DIR); | |
31 | + createUpperBoundSummaries(id2goldSummaryFiles, OPTIMAL_SUMMARIES_DIR); | |
32 | + } | |
33 | + | |
34 | + private static void createUpperBoundSummaries(Map<String, Set<File>> id2goldSummaryFiles, File targetDir) | |
35 | + throws IOException { | |
36 | + for (Map.Entry<String, Set<File>> entry : id2goldSummaryFiles.entrySet()) { | |
37 | + String id = entry.getKey(); | |
38 | + Set<File> goldSummaryFiles = entry.getValue(); | |
39 | + Set<String> goldSummaries = goldSummaryFiles.stream().map(Main::loadSummaryFromFile) | |
40 | + .collect(Collectors.toSet()); | |
41 | + int averageGoldWordCount = (int) goldSummaries.stream().mapToDouble(s -> RougeN.tokenize(s).size()) | |
42 | + .average().orElse(0); | |
43 | + String optimalSummary = createOptimalSummary(goldSummaries, ROUGE_N, averageGoldWordCount); | |
44 | + File targetFile = new File(targetDir, id + SUMMARY_SUFFIX); | |
45 | + FileUtils.writeStringToFile(targetFile, optimalSummary, Constants.ENCODING); | |
46 | + } | |
47 | + } | |
48 | + | |
49 | + private static String createOptimalSummary(Set<String> goldSummaries, int i, int averageGoldWordCount) { | |
50 | + Map<List<String>, List<Integer>> ngram2counts = Maps.newHashMap(); | |
51 | + for (String goldSummary : goldSummaries) { | |
52 | + Multiset<List<String>> goldNgrams = HashMultiset.create(); | |
53 | + RougeN.countNgrams(goldNgrams, i, goldSummary); | |
54 | + for (List<String> ngram : goldNgrams.elementSet()) { | |
55 | + ngram2counts.putIfAbsent(ngram, Lists.newArrayList()); | |
56 | + ngram2counts.get(ngram).add(goldNgrams.count(ngram)); | |
57 | + } | |
58 | + } | |
59 | + | |
60 | + int summaryWordCount = 0; | |
61 | + StringBuilder summary = new StringBuilder(); | |
62 | + while (averageGoldWordCount >= summaryWordCount) { | |
63 | + List<String> ngram = pickBestNgram(ngram2counts); | |
64 | + summary.append(" ").append(String.join(" ", ngram)); | |
65 | + summaryWordCount += ngram.size(); | |
66 | + } | |
67 | + return summary.toString().trim(); | |
68 | + } | |
69 | + | |
70 | + private static List<String> pickBestNgram(Map<List<String>, List<Integer>> ngram2counts) { | |
71 | + Optional<List<String>> optional = ngram2counts.keySet().stream() | |
72 | + .sorted(Comparator.comparing((List<String> ngram) -> ngram2counts.get(ngram).size()).reversed()).findFirst(); | |
73 | + if (!optional.isPresent()) { | |
74 | + throw new IllegalArgumentException("No more ngrams to pick!"); | |
75 | + } | |
76 | + List<String> optimalNgram = optional.get(); | |
77 | + List<Integer> counts = ngram2counts.get(optimalNgram); | |
78 | + List<Integer> newCounts = Lists.newArrayList(); | |
79 | + for (Integer c : counts) | |
80 | + if (c > 1) | |
81 | + newCounts.add(c - 1); | |
82 | + if (newCounts.isEmpty()) | |
83 | + ngram2counts.remove(optimalNgram); | |
84 | + else | |
85 | + ngram2counts.put(optimalNgram, newCounts); | |
86 | + return optimalNgram; | |
87 | + } | |
88 | +} | |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadAndPreprocessCorpus.java deleted
1 | -package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; | |
2 | - | |
3 | -import net.lingala.zip4j.core.ZipFile; | |
4 | -import org.apache.commons.io.FileUtils; | |
5 | -import org.slf4j.Logger; | |
6 | -import org.slf4j.LoggerFactory; | |
7 | - | |
8 | -import java.io.File; | |
9 | -import java.net.URL; | |
10 | - | |
11 | -public class DownloadAndPreprocessCorpus { | |
12 | - | |
13 | - private static final Logger LOG = LoggerFactory.getLogger(DownloadAndPreprocessCorpus.class); | |
14 | - | |
15 | - private static final String WORKING_DIR = "data"; | |
16 | - private static final String CORPUS_DOWNLOAD_URL = "http://zil.ipipan.waw.pl/PolishSummariesCorpus?action=AttachFile&do=get&target=PSC_1.0.zip"; | |
17 | - | |
18 | - private DownloadAndPreprocessCorpus() { | |
19 | - } | |
20 | - | |
21 | - public static void main(String[] args) throws Exception { | |
22 | - File workDir = createFolder(WORKING_DIR); | |
23 | - | |
24 | - File corpusFile = new File(workDir, "corpus.zip"); | |
25 | - if (!corpusFile.exists()) { | |
26 | - LOG.info("Downloading corpus file..."); | |
27 | - FileUtils.copyURLToFile(new URL(CORPUS_DOWNLOAD_URL), corpusFile); | |
28 | - LOG.info("done."); | |
29 | - } else { | |
30 | - LOG.info("Corpus file already downloaded."); | |
31 | - } | |
32 | - | |
33 | - File extractedCorpusDir = new File(workDir, "corpus"); | |
34 | - if (extractedCorpusDir.exists()) { | |
35 | - LOG.info("Corpus file already extracted."); | |
36 | - } else { | |
37 | - ZipFile zipFile = new ZipFile(corpusFile); | |
38 | - zipFile.extractAll(extractedCorpusDir.getPath()); | |
39 | - LOG.info("Extracted corpus file."); | |
40 | - } | |
41 | - | |
42 | - File pscDir = new File(extractedCorpusDir, "PSC_1.0"); | |
43 | - File dataDir = new File(pscDir, "data"); | |
44 | - | |
45 | - File preprocessed = new File(WORKING_DIR, "preprocessed"); | |
46 | - createFolder(preprocessed.getPath()); | |
47 | - Preprocess.main(new String[]{dataDir.getPath(), preprocessed.getPath()}); | |
48 | - } | |
49 | - | |
50 | - private static File createFolder(String path) { | |
51 | - File folder = new File(path); | |
52 | - if (folder.mkdir()) { | |
53 | - LOG.info("Created directory at: {}.", path); | |
54 | - } else { | |
55 | - LOG.info("Directory already present at: {}.", path); | |
56 | - } | |
57 | - return folder; | |
58 | - } | |
59 | -} |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadCorpus.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; | |
2 | + | |
3 | +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; | |
4 | + | |
5 | +public class DownloadCorpus { | |
6 | + | |
7 | + private DownloadCorpus() { | |
8 | + } | |
9 | + | |
10 | + public static void main(String[] args) throws Exception { | |
11 | + createFolder(WORKING_DIR); | |
12 | + downloadFileAndExtract(CORPUS_DOWNLOAD_URL, ZIPPED_CORPUS_FILE, EXTRACTED_CORPUS_DIR); | |
13 | + } | |
14 | + | |
15 | +} | |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadTrainingResources.java
0 → 100644
1 | +package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; | |
2 | + | |
3 | +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; | |
4 | + | |
5 | +public class DownloadTrainingResources { | |
6 | + | |
7 | + private DownloadTrainingResources() { | |
8 | + } | |
9 | + | |
10 | + public static void main(String[] args) throws Exception { | |
11 | + createFolder(WORKING_DIR); | |
12 | + downloadFileAndExtract(PREPROCESSED_CORPUS_DOWNLOAD_URL, ZIPPED_PREPROCESSED_CORPUS_FILE, PREPROCESSED_CORPUS_DIR); | |
13 | + downloadFileAndExtract(SUMMARY_SENTENCE_IDS_DOWNLOAD_URL, ZIPPED_SUMMARY_SENTENCE_IDS_FILE, SUMMARY_SENTENCE_IDS_DIR); | |
14 | + downloadFile(ZERO_TRAINING_CORPUS_URL, ZERO_TRAINING_CORPUS); | |
15 | + } | |
16 | + | |
17 | + | |
18 | +} | |
... | ... |
nicolas-eval/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/ExtractGoldSummaries.java renamed to nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractGoldSummaries.java
1 | -package pl.waw.ipipan.zil.summ.nicolas.eval; | |
1 | +package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; | |
2 | 2 | |
3 | 3 | import pl.waw.ipipan.zil.summ.nicolas.common.Utils; |
4 | 4 | import pl.waw.ipipan.zil.summ.pscapi.io.PSC_IO; |
... | ... | @@ -9,32 +9,43 @@ import javax.xml.bind.JAXBException; |
9 | 9 | import java.io.File; |
10 | 10 | import java.io.IOException; |
11 | 11 | import java.util.List; |
12 | -import java.util.Set; | |
12 | +import java.util.function.Predicate; | |
13 | 13 | import java.util.stream.Collectors; |
14 | +import java.util.stream.Stream; | |
14 | 15 | |
15 | -import static pl.waw.ipipan.zil.summ.nicolas.eval.Constants.loadTestTextIds; | |
16 | +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; | |
16 | 17 | |
17 | 18 | public class ExtractGoldSummaries { |
18 | 19 | |
20 | + private static final String ABSTRACT_SUMMARY_TYPE = "abstract"; | |
21 | + private static final int SUMMARY_RATIO = 20; | |
22 | + | |
23 | + private static final Predicate<Text> IS_TEST = text -> text.getSummaries().getSummary().stream().anyMatch(summary -> summary.getType().equals(ABSTRACT_SUMMARY_TYPE)); | |
24 | + | |
25 | + | |
19 | 26 | private ExtractGoldSummaries() { |
20 | 27 | } |
21 | 28 | |
22 | 29 | public static void main(String[] args) throws IOException, JAXBException { |
23 | - File corpusDir = new File("data/corpus/PSC_1.0/data"); | |
24 | - File targetDir = new File("data/summaries-gold"); | |
25 | - targetDir.mkdir(); | |
30 | + createFolder(GOLD_TEST_SUMMARIES_DIR); | |
31 | + createFolder(GOLD_TRAIN_SUMMARIES_DIR); | |
26 | 32 | |
27 | - Set<String> testTextIds = loadTestTextIds(); | |
28 | - File[] files = corpusDir.listFiles(); | |
33 | + File[] files = EXTRACTED_CORPUS_DATA_DIR.listFiles(); | |
29 | 34 | if (files != null) { |
30 | 35 | for (File file : files) { |
31 | 36 | Text text = PSC_IO.readText(file); |
32 | - if (!testTextIds.contains(text.getId())) | |
33 | - continue; | |
34 | 37 | |
35 | - List<Summary> goldSummaries = text.getSummaries().getSummary().stream().filter(summary -> summary.getType().equals("abstract") && summary.getRatio().equals(20)).collect(Collectors.toList()); | |
38 | + List<Summary> goldSummaries; | |
39 | + Stream<Summary> stream = text.getSummaries().getSummary().stream(); | |
40 | + boolean isTest = IS_TEST.test(text); | |
41 | + if (isTest) { | |
42 | + goldSummaries = stream.filter(summary -> summary.getType().equals(ABSTRACT_SUMMARY_TYPE) && summary.getRatio().equals(SUMMARY_RATIO)).collect(Collectors.toList()); | |
43 | + } else { | |
44 | + goldSummaries = stream.filter(summary -> !summary.getType().equals(ABSTRACT_SUMMARY_TYPE) && summary.getRatio().equals(SUMMARY_RATIO)).collect(Collectors.toList()); | |
45 | + } | |
36 | 46 | |
37 | 47 | for (Summary summary : goldSummaries) { |
48 | + File targetDir = isTest ? GOLD_TEST_SUMMARIES_DIR : GOLD_TRAIN_SUMMARIES_DIR; | |
38 | 49 | File targetFile = new File(targetDir, text.getId() + "_" + summary.getAuthor() + ".txt"); |
39 | 50 | Utils.writeStringToFile(summary.getBody(), targetFile); |
40 | 51 | } |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/PrepareTrainingData.java
... | ... | @@ -11,22 +11,18 @@ import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; |
11 | 11 | import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; |
12 | 12 | import pl.waw.ipipan.zil.multiservice.thrift.types.TText; |
13 | 13 | import pl.waw.ipipan.zil.summ.nicolas.InstanceUtils; |
14 | -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; | |
15 | 14 | import pl.waw.ipipan.zil.summ.nicolas.common.ThriftUtils; |
16 | 15 | import pl.waw.ipipan.zil.summ.nicolas.common.Utils; |
17 | 16 | import pl.waw.ipipan.zil.summ.nicolas.features.FeatureHelper; |
18 | 17 | import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor; |
19 | -import pl.waw.ipipan.zil.summ.nicolas.mention.MentionModel; | |
20 | 18 | import pl.waw.ipipan.zil.summ.nicolas.sentence.SentenceFeatureExtractor; |
21 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; | |
22 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.mention.MentionScorer; | |
23 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.sentence.SentenceScorer; | |
24 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.zero.ZeroScorer; | |
19 | +import pl.waw.ipipan.zil.summ.nicolas.train.model.MentionScorer; | |
20 | +import pl.waw.ipipan.zil.summ.nicolas.train.model.SentenceScorer; | |
21 | +import pl.waw.ipipan.zil.summ.nicolas.train.model.ZeroScorer; | |
25 | 22 | import pl.waw.ipipan.zil.summ.nicolas.zero.CandidateFinder; |
26 | 23 | import pl.waw.ipipan.zil.summ.nicolas.zero.InstanceCreator; |
27 | 24 | import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroFeatureExtractor; |
28 | 25 | import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate; |
29 | -import weka.classifiers.Classifier; | |
30 | 26 | import weka.core.Instance; |
31 | 27 | import weka.core.Instances; |
32 | 28 | import weka.core.converters.ArffSaver; |
... | ... | @@ -34,31 +30,26 @@ import weka.core.converters.ArffSaver; |
34 | 30 | import java.io.File; |
35 | 31 | import java.io.FileReader; |
36 | 32 | import java.io.IOException; |
37 | -import java.io.InputStream; | |
33 | +import java.util.Arrays; | |
38 | 34 | import java.util.List; |
39 | 35 | import java.util.Map; |
40 | 36 | import java.util.Set; |
41 | 37 | import java.util.function.Predicate; |
42 | 38 | import java.util.stream.Collectors; |
43 | 39 | |
40 | +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; | |
41 | + | |
44 | 42 | public class PrepareTrainingData { |
45 | 43 | |
46 | 44 | private static final Logger LOG = LoggerFactory.getLogger(PrepareTrainingData.class); |
47 | 45 | |
48 | - private static final String THRIFT_TEXTS_PATH = "data/preprocessed"; | |
49 | - private static final String OPTIMAL_SUMMARIES_DIR_PATH = "data/summaries-optimal"; | |
50 | - private static final String SUMMARY_SENTENCE_IDS = "data/summaries-sentence-ids"; | |
51 | - | |
52 | - private static final String ZERO_TRAINING_DATA_RESOURCE_PATH = "/pl/waw/ipipan/zil/summ/nicolas/train/train_zero.tsv"; | |
53 | - private static final String TRAIN_TEXT_IDS_RESOURCE_PATH = "/pl/waw/ipipan/zil/summ/nicolas/train/train_text_ids.txt"; | |
54 | - | |
55 | 46 | private PrepareTrainingData() { |
56 | 47 | } |
57 | 48 | |
58 | 49 | public static void main(String[] args) throws Exception { |
59 | 50 | Set<String> trainTextIds = loadTrainTextIds(); |
60 | 51 | |
61 | - Map<String, TText> id2preprocessedText = ThriftUtils.loadThriftTextsFromFolder(new File(THRIFT_TEXTS_PATH), trainTextIds::contains); | |
52 | + Map<String, TText> id2preprocessedText = ThriftUtils.loadThriftTextsFromFolder(PREPROCESSED_CORPUS_DIR, trainTextIds::contains); | |
62 | 53 | Map<String, String> id2optimalSummary = loadOptimalSummaries(trainTextIds::contains); |
63 | 54 | |
64 | 55 | prepareMentionsDataset(id2preprocessedText, id2optimalSummary); |
... | ... | @@ -66,7 +57,7 @@ public class PrepareTrainingData { |
66 | 57 | prepareZerosDataset(id2preprocessedText); |
67 | 58 | } |
68 | 59 | |
69 | - public static void prepareMentionsDataset(Map<String, TText> id2preprocessedText, Map<String, String> id2optimalSummary) throws IOException { | |
60 | + private static void prepareMentionsDataset(Map<String, TText> id2preprocessedText, Map<String, String> id2optimalSummary) throws IOException { | |
70 | 61 | MentionScorer mentionScorer = new MentionScorer(); |
71 | 62 | MentionFeatureExtractor featureExtractor = new MentionFeatureExtractor(); |
72 | 63 | |
... | ... | @@ -74,7 +65,7 @@ public class PrepareTrainingData { |
74 | 65 | |
75 | 66 | int i = 1; |
76 | 67 | for (Map.Entry<String, TText> entry : id2preprocessedText.entrySet()) { |
77 | - LOG.info("{}/{}", i++, id2preprocessedText.size()); | |
68 | + logProgress(id2preprocessedText, i++); | |
78 | 69 | |
79 | 70 | String id = entry.getKey(); |
80 | 71 | TText preprocessedText = entry.getValue(); |
... | ... | @@ -92,29 +83,33 @@ public class PrepareTrainingData { |
92 | 83 | instances.add(instance); |
93 | 84 | } |
94 | 85 | } |
95 | - saveInstancesToFile(instances, new File(ModelConstants.MENTION_DATASET_PATH)); | |
86 | + saveInstancesToFile(instances, MENTION_ARFF); | |
96 | 87 | } |
97 | 88 | |
98 | - private static Set<String> loadTrainTextIds() throws IOException { | |
99 | - try (InputStream inputStream = PrepareTrainingData.class.getResourceAsStream(TRAIN_TEXT_IDS_RESOURCE_PATH)) { | |
100 | - List<String> testTextIds = IOUtils.readLines(inputStream, Constants.ENCODING); | |
101 | - return testTextIds.stream().map(String::trim).collect(Collectors.toSet()); | |
89 | + private static void logProgress(Map<String, TText> id2preprocessedText, int i) { | |
90 | + if (i % 10 == 0) { | |
91 | + LOG.info("{}/{}", i, id2preprocessedText.size()); | |
102 | 92 | } |
103 | 93 | } |
104 | 94 | |
105 | - public static void prepareSentencesDataset(Map<String, TText> id2preprocessedText, Map<String, String> id2optimalSummary) throws Exception { | |
95 | + private static Set<String> loadTrainTextIds() throws IOException { | |
96 | + File[] optimalSummaries = OPTIMAL_SUMMARIES_DIR.listFiles(); | |
97 | + if (optimalSummaries == null) | |
98 | + throw new IOException("No optimal summaries at " + OPTIMAL_SUMMARIES_DIR); | |
99 | + | |
100 | + return Arrays.stream(optimalSummaries).map(file -> file.getName().split("_")[0]).collect(Collectors.toSet()); | |
101 | + } | |
102 | + | |
103 | + private static void prepareSentencesDataset(Map<String, TText> id2preprocessedText, Map<String, String> id2optimalSummary) throws Exception { | |
106 | 104 | |
107 | 105 | SentenceScorer sentenceScorer = new SentenceScorer(); |
108 | 106 | SentenceFeatureExtractor featureExtractor = new SentenceFeatureExtractor(); |
109 | 107 | |
110 | 108 | Instances instances = Utils.createNewInstances(featureExtractor.getAttributesList()); |
111 | 109 | |
112 | - Classifier classifier = Utils.loadClassifierFromResource(Constants.MENTION_MODEL_RESOURCE_PATH); | |
113 | - MentionFeatureExtractor mentionFeatureExtractor = new MentionFeatureExtractor(); | |
114 | - | |
115 | 110 | int i = 1; |
116 | 111 | for (String textId : id2preprocessedText.keySet()) { |
117 | - LOG.info("{}/{}", i++, id2preprocessedText.size()); | |
112 | + logProgress(id2preprocessedText, i++); | |
118 | 113 | |
119 | 114 | TText preprocessedText = id2preprocessedText.get(textId); |
120 | 115 | String optimalSummary = id2optimalSummary.get(textId); |
... | ... | @@ -123,9 +118,7 @@ public class PrepareTrainingData { |
123 | 118 | Map<TSentence, Double> sentence2score = sentenceScorer.calculateSentenceScores(optimalSummary, preprocessedText); |
124 | 119 | |
125 | 120 | Set<TMention> goodMentions |
126 | - = MentionModel.detectGoodMentions(classifier, mentionFeatureExtractor, preprocessedText); | |
127 | -// Set<TMention> goodMentions | |
128 | -// = Utils.loadGoldGoodMentions(textId, preprocessedText, true); | |
121 | + = loadGoldGoodMentions(textId, preprocessedText, id2optimalSummary); | |
129 | 122 | |
130 | 123 | Map<TSentence, Instance> sentence2instance = InstanceUtils.extractInstancesFromSentences(preprocessedText, featureExtractor, goodMentions); |
131 | 124 | for (Map.Entry<TSentence, Instance> entry : sentence2instance.entrySet()) { |
... | ... | @@ -136,21 +129,31 @@ public class PrepareTrainingData { |
136 | 129 | instances.add(instance); |
137 | 130 | } |
138 | 131 | } |
139 | - saveInstancesToFile(instances, new File(ModelConstants.SENTENCE_DATASET_PATH)); | |
132 | + saveInstancesToFile(instances, SENTENCE_ARFF); | |
133 | + } | |
134 | + | |
135 | + private static Set<TMention> loadGoldGoodMentions(String id, TText text, Map<String, String> id2optimalSummary) throws IOException { | |
136 | + String optimalSummary = id2optimalSummary.get(id); | |
137 | + | |
138 | + MentionScorer scorer = new MentionScorer(); | |
139 | + Map<TMention, Double> mention2score = scorer.calculateMentionScores(optimalSummary, text); | |
140 | + | |
141 | + mention2score.keySet().removeIf(tMention -> mention2score.get(tMention) < 1); | |
142 | + return mention2score.keySet(); | |
140 | 143 | } |
141 | 144 | |
142 | - public static void prepareZerosDataset(Map<String, TText> id2preprocessedText) throws IOException { | |
145 | + private static void prepareZerosDataset(Map<String, TText> id2preprocessedText) throws IOException { | |
143 | 146 | |
144 | - Map<String, Set<String>> id2sentIds = loadSentenceIds(SUMMARY_SENTENCE_IDS); | |
147 | + Map<String, Set<String>> id2sentIds = loadSentenceIds(SUMMARY_SENTENCE_IDS_DIR); | |
145 | 148 | |
146 | - ZeroScorer zeroScorer = new ZeroScorer(ZERO_TRAINING_DATA_RESOURCE_PATH); | |
149 | + ZeroScorer zeroScorer = new ZeroScorer(ZERO_TRAINING_CORPUS); | |
147 | 150 | ZeroFeatureExtractor featureExtractor = new ZeroFeatureExtractor(); |
148 | 151 | |
149 | 152 | Instances instances = Utils.createNewInstances(featureExtractor.getAttributesList()); |
150 | 153 | |
151 | 154 | int i = 1; |
152 | 155 | for (Map.Entry<String, TText> entry : id2preprocessedText.entrySet()) { |
153 | - LOG.info(i++ + "/" + id2preprocessedText.size()); | |
156 | + logProgress(id2preprocessedText, i++); | |
154 | 157 | |
155 | 158 | String textId = entry.getKey(); |
156 | 159 | |
... | ... | @@ -170,12 +173,12 @@ public class PrepareTrainingData { |
170 | 173 | } |
171 | 174 | } |
172 | 175 | |
173 | - saveInstancesToFile(instances, new File(ModelConstants.ZERO_DATASET_PATH)); | |
176 | + saveInstancesToFile(instances, ZERO_ARFF); | |
174 | 177 | } |
175 | 178 | |
176 | - private static Map<String, Set<String>> loadSentenceIds(String idsPath) throws IOException { | |
179 | + private static Map<String, Set<String>> loadSentenceIds(File idsFolder) throws IOException { | |
177 | 180 | Map<String, Set<String>> result = Maps.newHashMap(); |
178 | - File[] files = new File(idsPath).listFiles(); | |
181 | + File[] files = idsFolder.listFiles(); | |
179 | 182 | if (files != null) |
180 | 183 | for (File f : files) { |
181 | 184 | String id = f.getName().split("_")[0]; |
... | ... | @@ -194,7 +197,7 @@ public class PrepareTrainingData { |
194 | 197 | |
195 | 198 | private static Map<String, String> loadOptimalSummaries(Predicate<String> idFilter) throws IOException { |
196 | 199 | Map<String, String> id2optimalSummary = Maps.newHashMap(); |
197 | - File[] files = new File(OPTIMAL_SUMMARIES_DIR_PATH).listFiles(); | |
200 | + File[] files = OPTIMAL_SUMMARIES_DIR.listFiles(); | |
198 | 201 | if (files != null) |
199 | 202 | for (File optimalSummaryFile : files) { |
200 | 203 | String textId = optimalSummaryFile.getName().split("_")[0]; |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/Preprocess.java renamed to nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/PreprocessCorpus.java
... | ... | @@ -9,37 +9,27 @@ import pl.waw.ipipan.zil.summ.pscapi.xml.Text; |
9 | 9 | import java.io.File; |
10 | 10 | import java.util.Arrays; |
11 | 11 | |
12 | -public class Preprocess { | |
12 | +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; | |
13 | 13 | |
14 | - private static final Logger LOG = LoggerFactory.getLogger(Preprocess.class); | |
14 | +public class PreprocessCorpus { | |
15 | + | |
16 | + private static final Logger LOG = LoggerFactory.getLogger(PreprocessCorpus.class); | |
15 | 17 | |
16 | 18 | private static final String CORPUS_FILE_SUFFIX = ".xml"; |
17 | 19 | private static final String OUTPUT_FILE_SUFFIX = ".thrift"; |
18 | 20 | |
19 | - private Preprocess() { | |
21 | + private PreprocessCorpus() { | |
20 | 22 | } |
21 | 23 | |
22 | 24 | public static void main(String[] args) { |
23 | - if (args.length != 2) { | |
24 | - LOG.error("Wrong usage! Try " + Preprocess.class.getSimpleName() + " dirWithCorpusFiles targetDir"); | |
25 | - return; | |
26 | - } | |
27 | - File corpusDir = new File(args[0]); | |
28 | - if (!corpusDir.isDirectory()) { | |
29 | - LOG.error("Corpus directory does not exist: {}", corpusDir); | |
30 | - return; | |
31 | - } | |
32 | - File targetDir = new File(args[1]); | |
33 | - if (!targetDir.isDirectory()) { | |
34 | - LOG.error("Target directory does not exist: {}", targetDir); | |
35 | - return; | |
36 | - } | |
25 | + | |
26 | + createFolder(PREPROCESSED_CORPUS_DIR); | |
37 | 27 | |
38 | 28 | int ok = 0; |
39 | 29 | int err = 0; |
40 | - File[] files = corpusDir.listFiles(f -> f.getName().endsWith(CORPUS_FILE_SUFFIX)); | |
30 | + File[] files = EXTRACTED_CORPUS_DATA_DIR.listFiles(f -> f.getName().endsWith(CORPUS_FILE_SUFFIX)); | |
41 | 31 | if (files == null || files.length == 0) { |
42 | - LOG.error("No corpus files found at: {}", corpusDir); | |
32 | + LOG.error("No corpus files found at: {}", EXTRACTED_CORPUS_DATA_DIR); | |
43 | 33 | return; |
44 | 34 | } |
45 | 35 | Arrays.sort(files); |
... | ... | @@ -49,7 +39,7 @@ public class Preprocess { |
49 | 39 | for (File file : files) { |
50 | 40 | try { |
51 | 41 | Text text = PSC_IO.readText(file); |
52 | - File targetFile = new File(targetDir, file.getName().replaceFirst(CORPUS_FILE_SUFFIX + "$", OUTPUT_FILE_SUFFIX)); | |
42 | + File targetFile = new File(PREPROCESSED_CORPUS_DIR, file.getName().replaceFirst(CORPUS_FILE_SUFFIX + "$", OUTPUT_FILE_SUFFIX)); | |
53 | 43 | processor.preprocessToFile(text.getBody(), targetFile); |
54 | 44 | ok++; |
55 | 45 | } catch (Exception e) { |
... | ... |
nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java
1 | 1 | package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; |
2 | 2 | |
3 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.mention.TrainMentionModel; | |
4 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.sentence.TrainSentenceModel; | |
5 | -import pl.waw.ipipan.zil.summ.nicolas.train.model.zero.TrainZeroModel; | |
3 | +import org.apache.commons.lang3.time.StopWatch; | |
4 | +import org.slf4j.Logger; | |
5 | +import org.slf4j.LoggerFactory; | |
6 | +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; | |
7 | +import pl.waw.ipipan.zil.summ.nicolas.train.model.Settings; | |
8 | +import weka.classifiers.Classifier; | |
9 | +import weka.core.Instances; | |
10 | +import weka.core.converters.ArffLoader; | |
11 | + | |
12 | +import java.io.File; | |
13 | +import java.io.FileOutputStream; | |
14 | +import java.io.ObjectOutputStream; | |
15 | +import java.util.logging.LogManager; | |
16 | + | |
17 | +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; | |
6 | 18 | |
7 | 19 | public class TrainAllModels { |
8 | 20 | |
21 | + private static final Logger LOG = LoggerFactory.getLogger(TrainAllModels.class); | |
22 | + | |
23 | + private static final String TARGET_MODEL_DIR = "nicolas-model/src/main/resources"; | |
24 | + | |
9 | 25 | private TrainAllModels() { |
10 | 26 | } |
11 | 27 | |
12 | 28 | public static void main(String[] args) throws Exception { |
13 | - TrainMentionModel.main(args); | |
14 | - TrainSentenceModel.main(args); | |
15 | - TrainZeroModel.main(args); | |
29 | + trainAndSaveModel(MENTION_ARFF, Settings.getMentionClassifier(), Constants.MENTION_MODEL_RESOURCE_PATH); | |
30 | + trainAndSaveModel(SENTENCE_ARFF, Settings.getSentenceClassifier(), Constants.SENTENCE_MODEL_RESOURCE_PATH); | |
31 | + trainAndSaveModel(ZERO_ARFF, Settings.getZeroClassifier(), Constants.ZERO_MODEL_RESOURCE_PATH); | |
32 | + } | |
33 | + | |
34 | + private static void trainAndSaveModel(File dataset, Classifier classifier, String targetPath) throws Exception { | |
35 | + LogManager.getLogManager().reset(); // disable WEKA logging | |
36 | + | |
37 | + ArffLoader loader = new ArffLoader(); | |
38 | + loader.setFile(dataset); | |
39 | + Instances instances = loader.getDataSet(); | |
40 | + instances.setClassIndex(0); | |
41 | + LOG.info("{} instances loaded.", instances.size()); | |
42 | + LOG.info("{} attributes for each instance.", instances.numAttributes()); | |
43 | + | |
44 | + StopWatch watch = new StopWatch(); | |
45 | + watch.start(); | |
46 | + | |
47 | + LOG.info("Building classifier..."); | |
48 | + classifier.buildClassifier(instances); | |
49 | + LOG.info("...done. Build classifier: {}", classifier); | |
50 | + | |
51 | + String target = TARGET_MODEL_DIR + targetPath; | |
52 | + LOG.info("Saving classifier at: {}", target); | |
53 | + try (ObjectOutputStream oos = new ObjectOutputStream( | |
54 | + new FileOutputStream(target))) { | |
55 | + oos.writeObject(classifier); | |
56 | + } | |
57 | + | |
58 | + watch.stop(); | |
59 | + LOG.info("Elapsed time: {}", watch); | |
16 | 60 | } |
17 | 61 | } |
... | ... |
nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/train_text_ids.txt deleted
1 | -199704210011 | |
2 | -199704210013 | |
3 | -199704250031 | |
4 | -199704260017 | |
5 | -199801030156 | |
6 | -199801100009 | |
7 | -199801150038 | |
8 | -199801150133 | |
9 | -199801170001 | |
10 | -199801170129 | |
11 | -199801170130 | |
12 | -199801200002 | |
13 | -199801200132 | |
14 | -199801210007 | |
15 | -199801220030 | |
16 | -199801220127 | |
17 | -199801230001 | |
18 | -199801230095 | |
19 | -199801240116 | |
20 | -199801240123 | |
21 | -199801260113 | |
22 | -199801270108 | |
23 | -199801280128 | |
24 | -199801290020 | |
25 | -199801310032 | |
26 | -199802040201 | |
27 | -199901180149 | |
28 | -199901190049 | |
29 | -199901230088 | |
30 | -199901250006 | |
31 | -199901250008 | |
32 | -199901250111 | |
33 | -199901250113 | |
34 | -199901300064 | |
35 | -199901300098 | |
36 | -199902240123 | |
37 | -199906220027 | |
38 | -199906220037 | |
39 | -199906220038 | |
40 | -199906220056 | |
41 | -199906220065 | |
42 | -199906230040 | |
43 | -199906230052 | |
44 | -199906240040 | |
45 | -199906240088 | |
46 | -199906250007 | |
47 | -199906250091 | |
48 | -199906260015 | |
49 | -199906260018 | |
50 | -199906260038 | |
51 | -199907030016 | |
52 | -199907030018 | |
53 | -199907030042 | |
54 | -199907030059 | |
55 | -199907050032 | |
56 | -199907050040 | |
57 | -199907050047 | |
58 | -199907050071 | |
59 | -199907270095 | |
60 | -199907270137 | |
61 | -199907270145 | |
62 | -199909210045 | |
63 | -199909250054 | |
64 | -199909300064 | |
65 | -199909300065 | |
66 | -199909300066 | |
67 | -199910020049 | |
68 | -199910020050 | |
69 | -199910090047 | |
70 | -199910090049 | |
71 | -199910090051 | |
72 | -199910110055 | |
73 | -199910110057 | |
74 | -199910210058 | |
75 | -199910210059 | |
76 | -199910270041 | |
77 | -199910280054 | |
78 | -199910280055 | |
79 | -199910280057 | |
80 | -199910300026 | |
81 | -199911030039 | |
82 | -199911030040 | |
83 | -199911030041 | |
84 | -199911060031 | |
85 | -199911060042 | |
86 | -199911060043 | |
87 | -199911080054 | |
88 | -199911080055 | |
89 | -199911080056 | |
90 | -199911100061 | |
91 | -199911100062 | |
92 | -199911100063 | |
93 | -199911130036 | |
94 | -199911130037 | |
95 | -199911130038 | |
96 | -199911180042 | |
97 | -199911180043 | |
98 | -199911180044 | |
99 | -199911220059 | |
100 | -199911220061 | |
101 | -199911220066 | |
102 | -199911230041 | |
103 | -199911240035 | |
104 | -199911240037 | |
105 | -199911240038 | |
106 | -199911250055 | |
107 | -199911250057 | |
108 | -199912020059 | |
109 | -199912090045 | |
110 | -199912090047 | |
111 | -199912090061 | |
112 | -199912110041 | |
113 | -199912110042 | |
114 | -199912130055 | |
115 | -199912130057 | |
116 | -199912170065 | |
117 | -199912180052 | |
118 | -199912210018 | |
119 | -199912210037 | |
120 | -199912210040 | |
121 | -199912220045 | |
122 | -199912220046 | |
123 | -199912220047 | |
124 | -199912230058 | |
125 | -199912230059 | |
126 | -199912230097 | |
127 | -199912280028 | |
128 | -199912280044 | |
129 | -199912280045 | |
130 | -199912310085 | |
131 | -199912310087 | |
132 | -200001030047 | |
133 | -200001030106 | |
134 | -200001040030 | |
135 | -200001040031 | |
136 | -200001060052 | |
137 | -200001060053 | |
138 | -200001060055 | |
139 | -200001070062 | |
140 | -200001070066 | |
141 | -200001080040 | |
142 | -200001080041 | |
143 | -200001140061 | |
144 | -200001140064 | |
145 | -200001170049 | |
146 | -200001170051 | |
147 | -200001170052 | |
148 | -200001170053 | |
149 | -200001180040 | |
150 | -200001200056 | |
151 | -200001220023 | |
152 | -200001220118 | |
153 | -200001240016 | |
154 | -200001290042 | |
155 | -200001310048 | |
156 | -200001310049 | |
157 | -200001310050 | |
158 | -200001310054 | |
159 | -200002090042 | |
160 | -200002090043 | |
161 | -200002120045 | |
162 | -200002120046 | |
163 | -200002160046 | |
164 | -200002160047 | |
165 | -200002250063 | |
166 | -200002250065 | |
167 | -200002250066 | |
168 | -200002290044 | |
169 | -200002290045 | |
170 | -200002290046 | |
171 | -200002290047 | |
172 | -200002290048 | |
173 | -200003010058 | |
174 | -200003010059 | |
175 | -200003060054 | |
176 | -200003060055 | |
177 | -200003060057 | |
178 | -200003110047 | |
179 | -200003110048 | |
180 | -200003110049 | |
181 | -200003210044 | |
182 | -200003210045 | |
183 | -200004120021 | |
184 | -200004120022 | |
185 | -200004120023 | |
186 | -200004150048 | |
187 | -200004150049 | |
188 | -200004150050 | |
189 | -200004170026 | |
190 | -200004170065 | |
191 | -200004220044 | |
192 | -200004220045 | |
193 | -200004220046 | |
194 | -200004220047 | |
195 | -200004220048 | |
196 | -200005060030 | |
197 | -200005150055 | |
198 | -200005150059 | |
199 | -200005300045 | |
200 | -200005300047 | |
201 | -200005300048 | |
202 | -200006010065 | |
203 | -200006010066 | |
204 | -200006010067 | |
205 | -200006050056 | |
206 | -200006050057 | |
207 | -200006050058 | |
208 | -200006050059 | |
209 | -200006050061 | |
210 | -200006050068 | |
211 | -200006070056 | |
212 | -200006080033 | |
213 | -200006120031 | |
214 | -200006130055 | |
215 | -200006130057 | |
216 | -200006130059 | |
217 | -200006260069 | |
218 | -200006260071 | |
219 | -200006270059 | |
220 | -200007120068 | |
221 | -200007120070 | |
222 | -200007120072 | |
223 | -200007170026 | |
224 | -200007180051 | |
225 | -200007240034 | |
226 | -200007270050 | |
227 | -200007280033 | |
228 | -200008040071 | |
229 | -200008040073 | |
230 | -200008250077 | |
231 | -200008250079 | |
232 | -200008260055 | |
233 | -200008310046 | |
234 | -200010120066 | |
235 | -200010120074 | |
236 | -200010130063 | |
237 | -200010140048 | |
238 | -200010140049 | |
239 | -200010160039 | |
240 | -200010160048 | |
241 | -200010160049 | |
242 | -200010180059 | |
243 | -200010180063 | |
244 | -200010190066 | |
245 | -200010190068 | |
246 | -200011210063 | |
247 | -200011210064 | |
248 | -200011210066 | |
249 | -200012050066 | |
250 | -200012050067 | |
251 | -200012050068 | |
252 | -200012050069 | |
253 | -200012050070 | |
254 | -200012050071 | |
255 | -200012080134 | |
256 | -200012080137 | |
257 | -200012110069 | |
258 | -200012110070 | |
259 | -200012110071 | |
260 | -200012110075 | |
261 | -200012120028 | |
262 | -200012120068 | |
263 | -200012120072 | |
264 | -200012130056 | |
265 | -200012130100 | |
266 | -200012130102 | |
267 | -200012130103 | |
268 | -200012140095 | |
269 | -200012140096 | |
270 | -200012140097 | |
271 | -200012140098 | |
272 | -200012140099 | |
273 | -200012140100 | |
274 | -200012150076 | |
275 | -200012160048 | |
276 | -200012160049 | |
277 | -200012180083 | |
278 | -200012180084 | |
279 | -200012180088 | |
280 | -200012230028 | |
281 | -200012230045 | |
282 | -200012230046 | |
283 | -200012230047 | |
284 | -200012230048 | |
285 | -200012230050 | |
286 | -200012270055 | |
287 | -200012270056 | |
288 | -200101020059 | |
289 | -200101020062 | |
290 | -200101020063 | |
291 | -200101020075 | |
292 | -200101130048 | |
293 | -200101130050 | |
294 | -200101130051 | |
295 | -200101130055 | |
296 | -200101150043 | |
297 | -200101150045 | |
298 | -200101180050 | |
299 | -200101180051 | |
300 | -200101180052 | |
301 | -200101200048 | |
302 | -200101220047 | |
303 | -200101220053 | |
304 | -200102070011 | |
305 | -200102070016 | |
306 | -200102120034 | |
307 | -200102120057 | |
308 | -200102130014 | |
309 | -200102150001 | |
310 | -200102150014 | |
311 | -200102160011 | |
312 | -200102190016 | |
313 | -200102220001 | |
314 | -200102220013 | |
315 | -200102270041 | |
316 | -200102270062 | |
317 | -200102280169 | |
318 | -200103010049 | |
319 | -200103060022 | |
320 | -200103060032 | |
321 | -200103060057 | |
322 | -200103080026 | |
323 | -200103080030 | |
324 | -200103080036 | |
325 | -200103100019 | |
326 | -200103100021 | |
327 | -200103100058 | |
328 | -200103100062 | |
329 | -200103130008 | |
330 | -200103130023 | |
331 | -200103130069 | |
332 | -200103200066 | |
333 | -200103200080 | |
334 | -200103270069 | |
335 | -200103310092 | |
336 | -200104020007 | |
337 | -200104050011 | |
338 | -200104100021 | |
339 | -200104100023 | |
340 | -200104170015 | |
341 | -200104170040 | |
342 | -200104170055 | |
343 | -200104170057 | |
344 | -200104190039 | |
345 | -200104190066 | |
346 | -200104230031 | |
347 | -200104230069 | |
348 | -200104260051 | |
349 | -200104260053 | |
350 | -200104300213 | |
351 | -200104300215 | |
352 | -200104300217 | |
353 | -200105020092 | |
354 | -200105050042 | |
355 | -200105050043 | |
356 | -200105050046 | |
357 | -200105050048 | |
358 | -200105070017 | |
359 | -200105140050 | |
360 | -200105140052 | |
361 | -200105220096 | |
362 | -200105290074 | |
363 | -200105290075 | |
364 | -200106120068 | |
365 | -200106120069 | |
366 | -200106180051 | |
367 | -200106180053 | |
368 | -200106200064 | |
369 | -200106220086 | |
370 | -200106220087 | |
371 | -200106220088 | |
372 | -200106220090 | |
373 | -200106250050 | |
374 | -200107120071 | |
375 | -200107120073 | |
376 | -200107210129 | |
377 | -200107240070 | |
378 | -200107250080 | |
379 | -200108060051 | |
380 | -200108060155 | |
381 | -200108060156 | |
382 | -200108060157 | |
383 | -200108070038 | |
384 | -200108160040 | |
385 | -200108180123 | |
386 | -200108200033 | |
387 | -200108210066 | |
388 | -200108210074 | |
389 | -200108270077 | |
390 | -200108280064 | |
391 | -200109060061 | |
392 | -200109130091 | |
393 | -200109250092 | |
394 | -200109260097 | |
395 | -200109270116 | |
396 | -200110020075 | |
397 | -200110150056 | |
398 | -200110150062 | |
399 | -200110200070 | |
400 | -200110200071 | |
401 | -200110220068 | |
402 | -200111080086 | |
403 | -200111140055 | |
404 | -200111210078 | |
405 | -200111240060 | |
406 | -200112040031 | |
407 | -200112040077 | |
408 | -200112050063 | |
409 | -200112100041 | |
410 | -200112190067 | |
411 | -200201280011 | |
412 | -200201290029 | |
413 | -200202280078 | |
414 | -200203280057 | |
415 | -200203290107 |