From 7e387f1cdc557ac810c9c4118ddff9e36c78a776 Mon Sep 17 00:00:00 2001 From: Mateusz Kopeć <m.kopec@ipipan.waw.pl> Date: Thu, 9 Mar 2017 15:33:49 +0100 Subject: [PATCH] training code --- .gitignore | 3 ++- nicolas-cli/pom.xml | 27 +++++++++++++++------------ nicolas-cli/src/main/java/pl/waw/ipipan/zil/summ/nicolas/cli/Cli.java | 20 ++++++++++++++++++-- nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java | 1 - nicolas-eval/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/ExtractGoldSummaries.java | 45 --------------------------------------------- nicolas-train/pom.xml | 4 ++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java | 18 ++++++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/PathConstants.java | 79 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java | 41 +++++++++++++++++++++++++++++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java | 32 ++++++++++++++++++++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java | 39 +++++++++++++++++++++++++++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/ZeroScorer.java | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/ModelConstants.java | 43 ------------------------------------------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/TrainModelCommon.java | 52 ---------------------------------------------------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/MentionScorer.java | 41 ----------------------------------------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/TrainMentionModel.java | 20 -------------------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/SentenceScorer.java | 33 --------------------------------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/TrainSentenceModel.java | 20 -------------------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/TrainZeroModel.java | 20 -------------------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/ZeroScorer.java | 52 ---------------------------------------------------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadAndPreprocessCorpus.java | 59 ----------------------------------------------------------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadCorpus.java | 15 +++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadTrainingResources.java | 18 ++++++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractGoldSummaries.java | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/PrepareTrainingData.java | 83 +++++++++++++++++++++++++++++++++++++++++++---------------------------------------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/Preprocess.java | 63 --------------------------------------------------------------- nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/PreprocessCorpus.java | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++------ nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/train_text_ids.txt | 415 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- train.sh | 5 +++++ 31 files changed, 626 insertions(+), 925 deletions(-) delete mode 100644 nicolas-eval/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/ExtractGoldSummaries.java create mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java create mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/PathConstants.java create mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java create mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java create mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java create mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/ZeroScorer.java delete mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/ModelConstants.java delete mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/TrainModelCommon.java delete mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/MentionScorer.java delete mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/TrainMentionModel.java delete mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/SentenceScorer.java delete mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/TrainSentenceModel.java delete mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/TrainZeroModel.java delete mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/ZeroScorer.java create mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java delete mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadAndPreprocessCorpus.java create mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadCorpus.java create mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadTrainingResources.java create mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractGoldSummaries.java delete mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/Preprocess.java create mode 100644 nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/PreprocessCorpus.java delete mode 100644 nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/train_text_ids.txt create mode 100755 train.sh diff --git a/.gitignore b/.gitignore index d70732d..c67355e 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,5 @@ hs_err_pid* .idea *.iml -/data \ No newline at end of file +/data +/summaries diff --git a/nicolas-cli/pom.xml b/nicolas-cli/pom.xml index 422e8f1..5062880 100644 --- a/nicolas-cli/pom.xml +++ b/nicolas-cli/pom.xml @@ -53,28 +53,31 @@ <build> <plugins> <plugin> + <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-assembly-plugin</artifactId> - <configuration> - <appendAssemblyId>false</appendAssemblyId> - <archive> - <manifest> - <mainClass>pl.waw.ipipan.zil.summ.nicolas.cli.Main</mainClass> - </manifest> - </archive> - <descriptorRefs> - <descriptorRef>jar-with-dependencies</descriptorRef> - </descriptorRefs> - </configuration> <executions> <execution> - <id>make-assembly</id> + <id>jar-with-dependencies</id> <phase>package</phase> <goals> <goal>single</goal> </goals> + <configuration> + <descriptorRefs> + <descriptorRef>jar-with-dependencies</descriptorRef> + </descriptorRefs> + <appendAssemblyId>false</appendAssemblyId> + <finalName>nicolas-cli</finalName> + <archive> + <manifest> + <mainClass>pl.waw.ipipan.zil.summ.nicolas.cli.Main</mainClass> + </manifest> + </archive> + </configuration> </execution> </executions> </plugin> + </plugins> </build> </project> \ No newline at end of file diff --git a/nicolas-cli/src/main/java/pl/waw/ipipan/zil/summ/nicolas/cli/Cli.java b/nicolas-cli/src/main/java/pl/waw/ipipan/zil/summ/nicolas/cli/Cli.java index ace95d1..d3257ea 100644 --- a/nicolas-cli/src/main/java/pl/waw/ipipan/zil/summ/nicolas/cli/Cli.java +++ b/nicolas-cli/src/main/java/pl/waw/ipipan/zil/summ/nicolas/cli/Cli.java @@ -10,18 +10,19 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; +import java.io.IOException; class Cli { private static final Logger LOG = LoggerFactory.getLogger(Cli.class); - @Parameter(names = {"-help", "-h"}, description = "Print help") + @Parameter(names = {"-help", "-h"}, description = "Print help", help = true) private boolean help = false; @Parameter(names = {"-input", "-i"}, description = "Input text file to summarize", required = true, validateWith = FileValidator.class, converter = FileConverter.class) private File inputFile; - @Parameter(names = {"-output", "-o"}, description = "Output file path for summary", required = true, validateWith = FileValidator.class, converter = FileConverter.class) + @Parameter(names = {"-output", "-o"}, description = "Output file path for summary", required = true, validateWith = OutputFileValidator.class, converter = FileConverter.class) private File outputFile; @Parameter(names = {"-target", "-t"}, description = "Target summary token count", required = true, validateWith = PositiveInteger.class) @@ -84,4 +85,19 @@ class Cli { } } + + public static class OutputFileValidator implements IParameterValidator { + + @Override + public void validate(String name, String value) { + File file = new File(value); + try { + file.createNewFile(); + } catch (IOException ex) { + throw new ParameterException("Parameter " + name + + " should be a valid file path (found " + value + ")", ex); + } + } + + } } diff --git a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java index fb3c1f4..ad7cbb0 100644 --- a/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java +++ b/nicolas-common/src/main/java/pl/waw/ipipan/zil/summ/nicolas/common/Utils.java @@ -182,5 +182,4 @@ public class Utils { return sb.toString().trim(); } - } \ No newline at end of file diff --git a/nicolas-eval/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/ExtractGoldSummaries.java b/nicolas-eval/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/ExtractGoldSummaries.java deleted file mode 100644 index 6f193c6..0000000 --- a/nicolas-eval/src/main/java/pl/waw/ipipan/zil/summ/nicolas/eval/ExtractGoldSummaries.java +++ /dev/null @@ -1,45 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.eval; - -import pl.waw.ipipan.zil.summ.nicolas.common.Utils; -import pl.waw.ipipan.zil.summ.pscapi.io.PSC_IO; -import pl.waw.ipipan.zil.summ.pscapi.xml.Summary; -import pl.waw.ipipan.zil.summ.pscapi.xml.Text; - -import javax.xml.bind.JAXBException; -import java.io.File; -import java.io.IOException; -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; - -import static pl.waw.ipipan.zil.summ.nicolas.eval.Constants.loadTestTextIds; - -public class ExtractGoldSummaries { - - private ExtractGoldSummaries() { - } - - public static void main(String[] args) throws IOException, JAXBException { - File corpusDir = new File("data/corpus/PSC_1.0/data"); - File targetDir = new File("data/summaries-gold"); - targetDir.mkdir(); - - Set<String> testTextIds = loadTestTextIds(); - File[] files = corpusDir.listFiles(); - if (files != null) { - for (File file : files) { - Text text = PSC_IO.readText(file); - if (!testTextIds.contains(text.getId())) - continue; - - List<Summary> goldSummaries = text.getSummaries().getSummary().stream().filter(summary -> summary.getType().equals("abstract") && summary.getRatio().equals(20)).collect(Collectors.toList()); - - for (Summary summary : goldSummaries) { - File targetFile = new File(targetDir, text.getId() + "_" + summary.getAuthor() + ".txt"); - Utils.writeStringToFile(summary.getBody(), targetFile); - } - } - } - } - -} diff --git a/nicolas-train/pom.xml b/nicolas-train/pom.xml index 3fe8055..0124f0f 100644 --- a/nicolas-train/pom.xml +++ b/nicolas-train/pom.xml @@ -40,6 +40,10 @@ <groupId>pl.waw.ipipan.zil.multiservice</groupId> <artifactId>utils</artifactId> </dependency> + <dependency> + <groupId>pl.waw.ipipan.zil.summ</groupId> + <artifactId>eval</artifactId> + </dependency> <!-- third party --> <dependency> diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java new file mode 100644 index 0000000..07e556a --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/Main.java @@ -0,0 +1,18 @@ +package pl.waw.ipipan.zil.summ.nicolas.train; + +import pl.waw.ipipan.zil.summ.nicolas.train.pipeline.*; + +public class Main { + + private Main() { + } + + public static void main(String[] args) throws Exception { + DownloadCorpus.main(args); + DownloadTrainingResources.main(args); + ExtractGoldSummaries.main(args); + CreateOptimalSummaries.main(args); + PrepareTrainingData.main(args); + TrainAllModels.main(args); + } +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/PathConstants.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/PathConstants.java new file mode 100644 index 0000000..44e67c3 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/PathConstants.java @@ -0,0 +1,79 @@ +package pl.waw.ipipan.zil.summ.nicolas.train; + +import net.lingala.zip4j.core.ZipFile; +import net.lingala.zip4j.exception.ZipException; +import org.apache.commons.io.FileUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.net.URL; + +public class PathConstants { + + private static final Logger LOG = LoggerFactory.getLogger(PathConstants.class); + + public static final String CORPUS_DOWNLOAD_URL = "http://zil.ipipan.waw.pl/PolishSummariesCorpus?action=AttachFile&do=get&target=PSC_1.0.zip"; + public static final String PREPROCESSED_CORPUS_DOWNLOAD_URL = "http://zil.ipipan.waw.pl/Nicolas?action=AttachFile&do=get&target=all-preprocessed.zip"; + public static final String SUMMARY_SENTENCE_IDS_DOWNLOAD_URL = "http://zil.ipipan.waw.pl/Nicolas?action=AttachFile&do=get&target=train-zero-sentence-ids.zip"; + public static final String ZERO_TRAINING_CORPUS_URL = "http://zil.ipipan.waw.pl/Nicolas?action=AttachFile&do=get&target=train-zero.tsv"; + + public static final File WORKING_DIR = new File("data"); + + public static final File ZIPPED_CORPUS_FILE = new File(WORKING_DIR, "PSC_1.0.zip"); + public static final File ZIPPED_PREPROCESSED_CORPUS_FILE = new File(WORKING_DIR, "all-preprocessed.zip"); + public static final File ZIPPED_SUMMARY_SENTENCE_IDS_FILE = new File(WORKING_DIR, "train-zero-sentence-ids.zip"); + + public static final File EXTRACTED_CORPUS_DIR = new File(WORKING_DIR, "corpus"); + public static final File EXTRACTED_CORPUS_DATA_DIR = new File(new File(EXTRACTED_CORPUS_DIR, "PSC_1.0"), "data"); + public static final File SUMMARY_SENTENCE_IDS_DIR = new File(WORKING_DIR, "train-zero-sentence-ids"); + public static final File PREPROCESSED_CORPUS_DIR = new File(WORKING_DIR, "all-preprocessed"); + public static final File GOLD_TEST_SUMMARIES_DIR = new File(WORKING_DIR, "test-gold"); + public static final File GOLD_TRAIN_SUMMARIES_DIR = new File(WORKING_DIR, "train-gold"); + public static final File OPTIMAL_SUMMARIES_DIR = new File(WORKING_DIR, "train-optimal"); + public static final File ZERO_TRAINING_CORPUS = new File(WORKING_DIR, "train-zero.tsv"); + + public static final File ARFF_DIR = new File(WORKING_DIR, "train-arff"); + public static final File MENTION_ARFF = new File(ARFF_DIR, "mentions.arff"); + public static final File SENTENCE_ARFF = new File(ARFF_DIR, "sentences.arff"); + public static final File ZERO_ARFF = new File(ARFF_DIR, "zeros.arff"); + + private PathConstants() { + } + + public static File createFolder(File folder) { + if (folder.mkdir()) { + LOG.info("Created directory at: {}.", folder.getPath()); + } else { + LOG.info("Directory already present at: {}.", folder.getPath()); + } + return folder; + } + + public static void downloadFile(String fileUrl, File targetFile) throws IOException { + if (!targetFile.exists()) { + LOG.info("Downloading file from url {} to file {} ...", fileUrl, targetFile); + FileUtils.copyURLToFile(new URL(fileUrl), targetFile); + LOG.info("done."); + } else { + LOG.info("File {} already downloaded.", targetFile); + } + } + + public static void downloadFileAndExtract(String url, File targetZipFile, File targetDir) throws IOException, ZipException { + downloadFile(url, targetZipFile); + extractZipFile(targetZipFile, targetDir); + } + + private static void extractZipFile(File targetZipFile, File targetDir) throws ZipException { + if (targetDir.exists()) { + LOG.info("Zip file {} already extracted to dir {}.", targetZipFile, targetDir); + } else { + createFolder(targetDir); + ZipFile zipFile = new ZipFile(targetZipFile); + zipFile.extractAll(targetDir.getPath()); + LOG.info("Extracted zip file: {} to dir: {}.", targetZipFile, targetDir); + } + } +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java new file mode 100644 index 0000000..60e679a --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/MentionScorer.java @@ -0,0 +1,41 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.model; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Maps; +import com.google.common.collect.Multiset; +import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; +import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; +import pl.waw.ipipan.zil.summ.nicolas.common.Utils; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class MentionScorer { + + public Map<TMention, Double> calculateMentionScores(String optimalSummary, TText text) { + Multiset<String> tokenCounts = HashMultiset.create(Utils.tokenize(optimalSummary.toLowerCase())); + + List<TSentence> sentences = text.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(Collectors.toList()); + Map<TMention, String> mention2Orth = Utils.loadMention2Orth(sentences, true); + + return booleanTokenIntersection(mention2Orth, tokenCounts); + } + + private static Map<TMention, Double> booleanTokenIntersection(Map<TMention, String> mention2Orth, Multiset<String> tokenCounts) { + Map<TMention, Double> mention2score = Maps.newHashMap(); + for (Map.Entry<TMention, String> entry : mention2Orth.entrySet()) { + TMention mention = entry.getKey(); + String mentionOrth = mention2Orth.get(mention); + for (String token : Utils.tokenize(mentionOrth)) { + if (tokenCounts.contains(token.toLowerCase())) { + mention2score.put(mention, 1.0); + break; + } + } + mention2score.putIfAbsent(mention, 0.0); + } + return mention2score; + } +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java new file mode 100644 index 0000000..61d01f0 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/SentenceScorer.java @@ -0,0 +1,32 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.model; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Maps; +import com.google.common.collect.Multiset; +import pl.waw.ipipan.zil.multiservice.thrift.types.TParagraph; +import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; +import pl.waw.ipipan.zil.multiservice.thrift.types.TText; +import pl.waw.ipipan.zil.summ.nicolas.common.Utils; + +import java.util.List; +import java.util.Map; + +public class SentenceScorer { + public Map<TSentence, Double> calculateSentenceScores(String optimalSummary, TText preprocessedText) { + Multiset<String> tokenCounts = HashMultiset.create(Utils.tokenize(optimalSummary.toLowerCase())); + + Map<TSentence, Double> sentence2score = Maps.newHashMap(); + for (TParagraph paragraph : preprocessedText.getParagraphs()) + for (TSentence sentence : paragraph.getSentences()) { + double score = 0.0; + + String orth = Utils.loadSentence2Orth(sentence); + List<String> tokens = Utils.tokenize(orth); + for (String token : tokens) { + score += tokenCounts.contains(token.toLowerCase()) ? 1.0 : 0.0; + } + sentence2score.put(sentence, score / tokens.size()); + } + return sentence2score; + } +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java new file mode 100644 index 0000000..d73c945 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/Settings.java @@ -0,0 +1,39 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.model; + +import weka.classifiers.Classifier; +import weka.classifiers.trees.RandomForest; + +public class Settings { + + private static final int NUM_ITERATIONS = 20; + private static final int NUM_EXECUTION_SLOTS = 8; + private static final int SEED = 0; + + private Settings() { + } + + public static Classifier getMentionClassifier() { + RandomForest classifier = new RandomForest(); + classifier.setNumIterations(NUM_ITERATIONS); + classifier.setSeed(SEED); + classifier.setNumExecutionSlots(NUM_EXECUTION_SLOTS); + return classifier; + } + + public static Classifier getSentenceClassifier() { + RandomForest classifier = new RandomForest(); + classifier.setNumIterations(NUM_ITERATIONS); + classifier.setSeed(SEED); + classifier.setNumExecutionSlots(NUM_EXECUTION_SLOTS); + return classifier; + } + + public static Classifier getZeroClassifier() { + RandomForest classifier = new RandomForest(); + classifier.setNumIterations(NUM_ITERATIONS); + classifier.setSeed(SEED); + classifier.setNumExecutionSlots(NUM_EXECUTION_SLOTS); + return classifier; + } + +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/ZeroScorer.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/ZeroScorer.java new file mode 100644 index 0000000..241874e --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/ZeroScorer.java @@ -0,0 +1,50 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.model; + +import com.google.common.collect.Maps; +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVParser; +import org.apache.commons.csv.CSVRecord; +import org.apache.commons.csv.QuoteMode; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; +import pl.waw.ipipan.zil.summ.nicolas.features.FeatureHelper; +import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate; + +import java.io.*; +import java.util.List; +import java.util.Map; + +public class ZeroScorer { + + private static final char DELIMITER = '\t'; + + private final Map<String, Boolean> candidateEncoding2Decision = Maps.newHashMap(); + + public ZeroScorer(File zeroTrainingCorpusFile) throws IOException { + try (InputStream stream = new FileInputStream(zeroTrainingCorpusFile); + InputStreamReader reader = new InputStreamReader(stream, Constants.ENCODING); + CSVParser parser = new CSVParser(reader, CSVFormat.DEFAULT.withDelimiter(DELIMITER).withEscape('|').withQuoteMode(QuoteMode.NONE).withQuote('~'))) { + List<CSVRecord> records = parser.getRecords(); + for (CSVRecord record : records) { + String key = encode(record.get(2), record.get(3), record.get(4)); + boolean isValid = "C".equalsIgnoreCase(record.get(0)); + candidateEncoding2Decision.put(key, isValid); + } + } + } + + private String encode(String mentionOrth, String firstSentenceOrth, String secondSentenceOrth) { + return mentionOrth + DELIMITER + firstSentenceOrth + DELIMITER + secondSentenceOrth; + } + + private String encode(ZeroSubjectCandidate candidate, FeatureHelper helper) { + String mentionOrth = helper.getMentionOrth(candidate.getZeroCandidateMention()); + String firstSentenceOrth = helper.getSentenceOrth(candidate.getPreviousSentence()); + String secondSentenceOrth = helper.getSentenceOrth(candidate.getSentence()); + return encode(mentionOrth, firstSentenceOrth, secondSentenceOrth); + } + + public boolean isValidCandidate(ZeroSubjectCandidate candidate, FeatureHelper helper) { + return candidateEncoding2Decision.get(encode(candidate, helper)); + } + +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/ModelConstants.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/ModelConstants.java deleted file mode 100644 index dae10ae..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/ModelConstants.java +++ /dev/null @@ -1,43 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train.model.common; - -import weka.classifiers.Classifier; -import weka.classifiers.trees.RandomForest; - -public class ModelConstants { - - public static final String MENTION_DATASET_PATH = "data/arff/mentions_train.arff"; - public static final String SENTENCE_DATASET_PATH = "data/arff/sentences_train.arff"; - public static final String ZERO_DATASET_PATH = "data/arff/zeros_train.arff"; - - private static final int NUM_ITERATIONS = 250; - private static final int NUM_EXECUTION_SLOTS = 8; - private static final int SEED = 0; - - private ModelConstants() { - } - - public static Classifier getMentionClassifier() { - RandomForest classifier = new RandomForest(); - classifier.setNumIterations(NUM_ITERATIONS); - classifier.setSeed(SEED); - classifier.setNumExecutionSlots(NUM_EXECUTION_SLOTS); - return classifier; - } - - public static Classifier getSentenceClassifier() { - RandomForest classifier = new RandomForest(); - classifier.setNumIterations(NUM_ITERATIONS); - classifier.setSeed(SEED); - classifier.setNumExecutionSlots(NUM_EXECUTION_SLOTS); - return classifier; - } - - public static Classifier getZeroClassifier() { - RandomForest classifier = new RandomForest(); - classifier.setNumIterations(NUM_ITERATIONS); - classifier.setSeed(SEED); - classifier.setNumExecutionSlots(NUM_EXECUTION_SLOTS); - return classifier; - } - -} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/TrainModelCommon.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/TrainModelCommon.java deleted file mode 100644 index d8c1a87..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/common/TrainModelCommon.java +++ /dev/null @@ -1,52 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train.model.common; - -import org.apache.commons.lang3.time.StopWatch; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import weka.classifiers.Classifier; -import weka.core.Instances; -import weka.core.converters.ArffLoader; - -import java.io.File; -import java.io.FileOutputStream; -import java.io.ObjectOutputStream; -import java.util.logging.LogManager; - -@SuppressWarnings("squid:S2118") -public class TrainModelCommon { - - private static final Logger LOG = LoggerFactory.getLogger(TrainModelCommon.class); - - private static final String TARGET_MODEL_DIR = "nicolas-model/src/main/resources"; - - private TrainModelCommon() { - } - - public static void trainAndSaveModel(String datasetPath, Classifier classifier, String targetPath) throws Exception { - LogManager.getLogManager().reset(); // disable WEKA logging - - ArffLoader loader = new ArffLoader(); - loader.setFile(new File(datasetPath)); - Instances instances = loader.getDataSet(); - instances.setClassIndex(0); - LOG.info("{} instances loaded.", instances.size()); - LOG.info("{} attributes for each instance.", instances.numAttributes()); - - StopWatch watch = new StopWatch(); - watch.start(); - - LOG.info("Building classifier..."); - classifier.buildClassifier(instances); - LOG.info("...done. Build classifier: {}", classifier); - - String target = TARGET_MODEL_DIR + targetPath; - LOG.info("Saving classifier at: {}", target); - try (ObjectOutputStream oos = new ObjectOutputStream( - new FileOutputStream(target))) { - oos.writeObject(classifier); - } - - watch.stop(); - LOG.info("Elapsed time: {}", watch); - } -} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/MentionScorer.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/MentionScorer.java deleted file mode 100644 index 29eaa5f..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/MentionScorer.java +++ /dev/null @@ -1,41 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train.model.mention; - -import com.google.common.collect.HashMultiset; -import com.google.common.collect.Maps; -import com.google.common.collect.Multiset; -import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; -import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; -import pl.waw.ipipan.zil.multiservice.thrift.types.TText; -import pl.waw.ipipan.zil.summ.nicolas.common.Utils; - -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -public class MentionScorer { - - public Map<TMention, Double> calculateMentionScores(String optimalSummary, TText text) { - Multiset<String> tokenCounts = HashMultiset.create(Utils.tokenize(optimalSummary.toLowerCase())); - - List<TSentence> sentences = text.getParagraphs().stream().flatMap(p -> p.getSentences().stream()).collect(Collectors.toList()); - Map<TMention, String> mention2Orth = Utils.loadMention2Orth(sentences, true); - - return booleanTokenIntersection(mention2Orth, tokenCounts); - } - - private static Map<TMention, Double> booleanTokenIntersection(Map<TMention, String> mention2Orth, Multiset<String> tokenCounts) { - Map<TMention, Double> mention2score = Maps.newHashMap(); - for (Map.Entry<TMention, String> entry : mention2Orth.entrySet()) { - TMention mention = entry.getKey(); - String mentionOrth = mention2Orth.get(mention); - for (String token : Utils.tokenize(mentionOrth)) { - if (tokenCounts.contains(token.toLowerCase())) { - mention2score.put(mention, 1.0); - break; - } - } - mention2score.putIfAbsent(mention, 0.0); - } - return mention2score; - } -} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/TrainMentionModel.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/TrainMentionModel.java deleted file mode 100644 index c9104c5..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/mention/TrainMentionModel.java +++ /dev/null @@ -1,20 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train.model.mention; - -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.TrainModelCommon; -import weka.classifiers.Classifier; - -public class TrainMentionModel { - - private TrainMentionModel() { - } - - public static void main(String[] args) throws Exception { - Classifier classifier = ModelConstants.getMentionClassifier(); - String datasetPath = ModelConstants.MENTION_DATASET_PATH; - String targetPath = Constants.MENTION_MODEL_RESOURCE_PATH; - TrainModelCommon.trainAndSaveModel(datasetPath, classifier, targetPath); - } - -} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/SentenceScorer.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/SentenceScorer.java deleted file mode 100644 index ef985a5..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/SentenceScorer.java +++ /dev/null @@ -1,33 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train.model.sentence; - -import com.google.common.collect.HashMultiset; -import com.google.common.collect.Maps; -import com.google.common.collect.Multiset; -import com.google.common.collect.Sets; -import pl.waw.ipipan.zil.multiservice.thrift.types.TParagraph; -import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; -import pl.waw.ipipan.zil.multiservice.thrift.types.TText; -import pl.waw.ipipan.zil.summ.nicolas.common.Utils; - -import java.util.List; -import java.util.Map; - -public class SentenceScorer { - public Map<TSentence, Double> calculateSentenceScores(String optimalSummary, TText preprocessedText) { - Multiset<String> tokenCounts = HashMultiset.create(Utils.tokenize(optimalSummary.toLowerCase())); - - Map<TSentence, Double> sentence2score = Maps.newHashMap(); - for (TParagraph paragraph : preprocessedText.getParagraphs()) - for (TSentence sentence : paragraph.getSentences()) { - double score = 0.0; - - String orth = Utils.loadSentence2Orth(sentence); - List<String> tokens = Utils.tokenize(orth); - for (String token : tokens) { - score += tokenCounts.contains(token.toLowerCase()) ? 1.0 : 0.0; - } - sentence2score.put(sentence, score / tokens.size()); - } - return sentence2score; - } -} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/TrainSentenceModel.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/TrainSentenceModel.java deleted file mode 100644 index b79e6ca..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/sentence/TrainSentenceModel.java +++ /dev/null @@ -1,20 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train.model.sentence; - -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.TrainModelCommon; -import weka.classifiers.Classifier; - -public class TrainSentenceModel { - - private TrainSentenceModel() { - } - - public static void main(String[] args) throws Exception { - Classifier classifier = ModelConstants.getSentenceClassifier(); - String datasetPath = ModelConstants.SENTENCE_DATASET_PATH; - String targetPath = Constants.SENTENCE_MODEL_RESOURCE_PATH; - TrainModelCommon.trainAndSaveModel(datasetPath, classifier, targetPath); - } - -} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/TrainZeroModel.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/TrainZeroModel.java deleted file mode 100644 index 7770a29..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/TrainZeroModel.java +++ /dev/null @@ -1,20 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train.model.zero; - -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.TrainModelCommon; -import weka.classifiers.Classifier; - -public class TrainZeroModel { - - private TrainZeroModel() { - } - - public static void main(String[] args) throws Exception { - Classifier classifier = ModelConstants.getZeroClassifier(); - String datasetPath = ModelConstants.ZERO_DATASET_PATH; - String targetPath = Constants.ZERO_MODEL_RESOURCE_PATH; - TrainModelCommon.trainAndSaveModel(datasetPath, classifier, targetPath); - } - -} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/ZeroScorer.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/ZeroScorer.java deleted file mode 100644 index c88eb31..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/model/zero/ZeroScorer.java +++ /dev/null @@ -1,52 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train.model.zero; - -import com.google.common.collect.Maps; -import org.apache.commons.csv.CSVFormat; -import org.apache.commons.csv.CSVParser; -import org.apache.commons.csv.CSVRecord; -import org.apache.commons.csv.QuoteMode; -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; -import pl.waw.ipipan.zil.summ.nicolas.features.FeatureHelper; -import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate; - -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.util.List; -import java.util.Map; - -public class ZeroScorer { - - private static final char DELIMITER = '\t'; - - private final Map<String, Boolean> candidateEncoding2Decision = Maps.newHashMap(); - - public ZeroScorer(String goldZerosResourcePath) throws IOException { - try (InputStream stream = ZeroScorer.class.getResourceAsStream(goldZerosResourcePath); - InputStreamReader reader = new InputStreamReader(stream, Constants.ENCODING); - CSVParser parser = new CSVParser(reader, CSVFormat.DEFAULT.withDelimiter(DELIMITER).withEscape('|').withQuoteMode(QuoteMode.NONE).withQuote('~'))) { - List<CSVRecord> records = parser.getRecords(); - for (CSVRecord record : records) { - String key = encode(record.get(2), record.get(3), record.get(4)); - boolean isValid = "C".equalsIgnoreCase(record.get(0)); - candidateEncoding2Decision.put(key, isValid); - } - } - } - - private String encode(String mentionOrth, String firstSentenceOrth, String secondSentenceOrth) { - return mentionOrth + DELIMITER + firstSentenceOrth + DELIMITER + secondSentenceOrth; - } - - private String encode(ZeroSubjectCandidate candidate, FeatureHelper helper) { - String mentionOrth = helper.getMentionOrth(candidate.getZeroCandidateMention()); - String firstSentenceOrth = helper.getSentenceOrth(candidate.getPreviousSentence()); - String secondSentenceOrth = helper.getSentenceOrth(candidate.getSentence()); - return encode(mentionOrth, firstSentenceOrth, secondSentenceOrth); - } - - public boolean isValidCandidate(ZeroSubjectCandidate candidate, FeatureHelper helper) { - return candidateEncoding2Decision.get(encode(candidate, helper)); - } - -} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java new file mode 100644 index 0000000..aeba701 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/CreateOptimalSummaries.java @@ -0,0 +1,88 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; + + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.collect.Multiset; +import org.apache.commons.io.FileUtils; +import pl.waw.ipipan.zil.summ.eval.Main; +import pl.waw.ipipan.zil.summ.eval.rouge.RougeN; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; + +import java.io.File; +import java.io.IOException; +import java.util.*; +import java.util.stream.Collectors; + +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; + +public class CreateOptimalSummaries { + + private static final int ROUGE_N = 1; + private static final String SUMMARY_SUFFIX = "_optimal.txt"; + + private CreateOptimalSummaries() { + } + + public static void main(String[] args) throws IOException { + createFolder(OPTIMAL_SUMMARIES_DIR); + Map<String, Set<File>> id2goldSummaryFiles = Main.loadFiles(GOLD_TRAIN_SUMMARIES_DIR); + createUpperBoundSummaries(id2goldSummaryFiles, OPTIMAL_SUMMARIES_DIR); + } + + private static void createUpperBoundSummaries(Map<String, Set<File>> id2goldSummaryFiles, File targetDir) + throws IOException { + for (Map.Entry<String, Set<File>> entry : id2goldSummaryFiles.entrySet()) { + String id = entry.getKey(); + Set<File> goldSummaryFiles = entry.getValue(); + Set<String> goldSummaries = goldSummaryFiles.stream().map(Main::loadSummaryFromFile) + .collect(Collectors.toSet()); + int averageGoldWordCount = (int) goldSummaries.stream().mapToDouble(s -> RougeN.tokenize(s).size()) + .average().orElse(0); + String optimalSummary = createOptimalSummary(goldSummaries, ROUGE_N, averageGoldWordCount); + File targetFile = new File(targetDir, id + SUMMARY_SUFFIX); + FileUtils.writeStringToFile(targetFile, optimalSummary, Constants.ENCODING); + } + } + + private static String createOptimalSummary(Set<String> goldSummaries, int i, int averageGoldWordCount) { + Map<List<String>, List<Integer>> ngram2counts = Maps.newHashMap(); + for (String goldSummary : goldSummaries) { + Multiset<List<String>> goldNgrams = HashMultiset.create(); + RougeN.countNgrams(goldNgrams, i, goldSummary); + for (List<String> ngram : goldNgrams.elementSet()) { + ngram2counts.putIfAbsent(ngram, Lists.newArrayList()); + ngram2counts.get(ngram).add(goldNgrams.count(ngram)); + } + } + + int summaryWordCount = 0; + StringBuilder summary = new StringBuilder(); + while (averageGoldWordCount >= summaryWordCount) { + List<String> ngram = pickBestNgram(ngram2counts); + summary.append(" ").append(String.join(" ", ngram)); + summaryWordCount += ngram.size(); + } + return summary.toString().trim(); + } + + private static List<String> pickBestNgram(Map<List<String>, List<Integer>> ngram2counts) { + Optional<List<String>> optional = ngram2counts.keySet().stream() + .sorted(Comparator.comparing((List<String> ngram) -> ngram2counts.get(ngram).size()).reversed()).findFirst(); + if (!optional.isPresent()) { + throw new IllegalArgumentException("No more ngrams to pick!"); + } + List<String> optimalNgram = optional.get(); + List<Integer> counts = ngram2counts.get(optimalNgram); + List<Integer> newCounts = Lists.newArrayList(); + for (Integer c : counts) + if (c > 1) + newCounts.add(c - 1); + if (newCounts.isEmpty()) + ngram2counts.remove(optimalNgram); + else + ngram2counts.put(optimalNgram, newCounts); + return optimalNgram; + } +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadAndPreprocessCorpus.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadAndPreprocessCorpus.java deleted file mode 100644 index 5ca1991..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadAndPreprocessCorpus.java +++ /dev/null @@ -1,59 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; - -import net.lingala.zip4j.core.ZipFile; -import org.apache.commons.io.FileUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.net.URL; - -public class DownloadAndPreprocessCorpus { - - private static final Logger LOG = LoggerFactory.getLogger(DownloadAndPreprocessCorpus.class); - - private static final String WORKING_DIR = "data"; - private static final String CORPUS_DOWNLOAD_URL = "http://zil.ipipan.waw.pl/PolishSummariesCorpus?action=AttachFile&do=get&target=PSC_1.0.zip"; - - private DownloadAndPreprocessCorpus() { - } - - public static void main(String[] args) throws Exception { - File workDir = createFolder(WORKING_DIR); - - File corpusFile = new File(workDir, "corpus.zip"); - if (!corpusFile.exists()) { - LOG.info("Downloading corpus file..."); - FileUtils.copyURLToFile(new URL(CORPUS_DOWNLOAD_URL), corpusFile); - LOG.info("done."); - } else { - LOG.info("Corpus file already downloaded."); - } - - File extractedCorpusDir = new File(workDir, "corpus"); - if (extractedCorpusDir.exists()) { - LOG.info("Corpus file already extracted."); - } else { - ZipFile zipFile = new ZipFile(corpusFile); - zipFile.extractAll(extractedCorpusDir.getPath()); - LOG.info("Extracted corpus file."); - } - - File pscDir = new File(extractedCorpusDir, "PSC_1.0"); - File dataDir = new File(pscDir, "data"); - - File preprocessed = new File(WORKING_DIR, "preprocessed"); - createFolder(preprocessed.getPath()); - Preprocess.main(new String[]{dataDir.getPath(), preprocessed.getPath()}); - } - - private static File createFolder(String path) { - File folder = new File(path); - if (folder.mkdir()) { - LOG.info("Created directory at: {}.", path); - } else { - LOG.info("Directory already present at: {}.", path); - } - return folder; - } -} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadCorpus.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadCorpus.java new file mode 100644 index 0000000..7e4b548 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadCorpus.java @@ -0,0 +1,15 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; + +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; + +public class DownloadCorpus { + + private DownloadCorpus() { + } + + public static void main(String[] args) throws Exception { + createFolder(WORKING_DIR); + downloadFileAndExtract(CORPUS_DOWNLOAD_URL, ZIPPED_CORPUS_FILE, EXTRACTED_CORPUS_DIR); + } + +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadTrainingResources.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadTrainingResources.java new file mode 100644 index 0000000..980fa5c --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/DownloadTrainingResources.java @@ -0,0 +1,18 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; + +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; + +public class DownloadTrainingResources { + + private DownloadTrainingResources() { + } + + public static void main(String[] args) throws Exception { + createFolder(WORKING_DIR); + downloadFileAndExtract(PREPROCESSED_CORPUS_DOWNLOAD_URL, ZIPPED_PREPROCESSED_CORPUS_FILE, PREPROCESSED_CORPUS_DIR); + downloadFileAndExtract(SUMMARY_SENTENCE_IDS_DOWNLOAD_URL, ZIPPED_SUMMARY_SENTENCE_IDS_FILE, SUMMARY_SENTENCE_IDS_DIR); + downloadFile(ZERO_TRAINING_CORPUS_URL, ZERO_TRAINING_CORPUS); + } + + +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractGoldSummaries.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractGoldSummaries.java new file mode 100644 index 0000000..19d5171 --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/ExtractGoldSummaries.java @@ -0,0 +1,56 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; + +import pl.waw.ipipan.zil.summ.nicolas.common.Utils; +import pl.waw.ipipan.zil.summ.pscapi.io.PSC_IO; +import pl.waw.ipipan.zil.summ.pscapi.xml.Summary; +import pl.waw.ipipan.zil.summ.pscapi.xml.Text; + +import javax.xml.bind.JAXBException; +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; + +public class ExtractGoldSummaries { + + private static final String ABSTRACT_SUMMARY_TYPE = "abstract"; + private static final int SUMMARY_RATIO = 20; + + private static final Predicate<Text> IS_TEST = text -> text.getSummaries().getSummary().stream().anyMatch(summary -> summary.getType().equals(ABSTRACT_SUMMARY_TYPE)); + + + private ExtractGoldSummaries() { + } + + public static void main(String[] args) throws IOException, JAXBException { + createFolder(GOLD_TEST_SUMMARIES_DIR); + createFolder(GOLD_TRAIN_SUMMARIES_DIR); + + File[] files = EXTRACTED_CORPUS_DATA_DIR.listFiles(); + if (files != null) { + for (File file : files) { + Text text = PSC_IO.readText(file); + + List<Summary> goldSummaries; + Stream<Summary> stream = text.getSummaries().getSummary().stream(); + boolean isTest = IS_TEST.test(text); + if (isTest) { + goldSummaries = stream.filter(summary -> summary.getType().equals(ABSTRACT_SUMMARY_TYPE) && summary.getRatio().equals(SUMMARY_RATIO)).collect(Collectors.toList()); + } else { + goldSummaries = stream.filter(summary -> !summary.getType().equals(ABSTRACT_SUMMARY_TYPE) && summary.getRatio().equals(SUMMARY_RATIO)).collect(Collectors.toList()); + } + + for (Summary summary : goldSummaries) { + File targetDir = isTest ? GOLD_TEST_SUMMARIES_DIR : GOLD_TRAIN_SUMMARIES_DIR; + File targetFile = new File(targetDir, text.getId() + "_" + summary.getAuthor() + ".txt"); + Utils.writeStringToFile(summary.getBody(), targetFile); + } + } + } + } + +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/PrepareTrainingData.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/PrepareTrainingData.java index 33865d5..b15a291 100644 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/PrepareTrainingData.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/PrepareTrainingData.java @@ -11,22 +11,18 @@ import pl.waw.ipipan.zil.multiservice.thrift.types.TMention; import pl.waw.ipipan.zil.multiservice.thrift.types.TSentence; import pl.waw.ipipan.zil.multiservice.thrift.types.TText; import pl.waw.ipipan.zil.summ.nicolas.InstanceUtils; -import pl.waw.ipipan.zil.summ.nicolas.common.Constants; import pl.waw.ipipan.zil.summ.nicolas.common.ThriftUtils; import pl.waw.ipipan.zil.summ.nicolas.common.Utils; import pl.waw.ipipan.zil.summ.nicolas.features.FeatureHelper; import pl.waw.ipipan.zil.summ.nicolas.mention.MentionFeatureExtractor; -import pl.waw.ipipan.zil.summ.nicolas.mention.MentionModel; import pl.waw.ipipan.zil.summ.nicolas.sentence.SentenceFeatureExtractor; -import pl.waw.ipipan.zil.summ.nicolas.train.model.common.ModelConstants; -import pl.waw.ipipan.zil.summ.nicolas.train.model.mention.MentionScorer; -import pl.waw.ipipan.zil.summ.nicolas.train.model.sentence.SentenceScorer; -import pl.waw.ipipan.zil.summ.nicolas.train.model.zero.ZeroScorer; +import pl.waw.ipipan.zil.summ.nicolas.train.model.MentionScorer; +import pl.waw.ipipan.zil.summ.nicolas.train.model.SentenceScorer; +import pl.waw.ipipan.zil.summ.nicolas.train.model.ZeroScorer; import pl.waw.ipipan.zil.summ.nicolas.zero.CandidateFinder; import pl.waw.ipipan.zil.summ.nicolas.zero.InstanceCreator; import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroFeatureExtractor; import pl.waw.ipipan.zil.summ.nicolas.zero.ZeroSubjectCandidate; -import weka.classifiers.Classifier; import weka.core.Instance; import weka.core.Instances; import weka.core.converters.ArffSaver; @@ -34,31 +30,26 @@ import weka.core.converters.ArffSaver; import java.io.File; import java.io.FileReader; import java.io.IOException; -import java.io.InputStream; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.Predicate; import java.util.stream.Collectors; +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; + public class PrepareTrainingData { private static final Logger LOG = LoggerFactory.getLogger(PrepareTrainingData.class); - private static final String THRIFT_TEXTS_PATH = "data/preprocessed"; - private static final String OPTIMAL_SUMMARIES_DIR_PATH = "data/summaries-optimal"; - private static final String SUMMARY_SENTENCE_IDS = "data/summaries-sentence-ids"; - - private static final String ZERO_TRAINING_DATA_RESOURCE_PATH = "/pl/waw/ipipan/zil/summ/nicolas/train/train_zero.tsv"; - private static final String TRAIN_TEXT_IDS_RESOURCE_PATH = "/pl/waw/ipipan/zil/summ/nicolas/train/train_text_ids.txt"; - private PrepareTrainingData() { } public static void main(String[] args) throws Exception { Set<String> trainTextIds = loadTrainTextIds(); - Map<String, TText> id2preprocessedText = ThriftUtils.loadThriftTextsFromFolder(new File(THRIFT_TEXTS_PATH), trainTextIds::contains); + Map<String, TText> id2preprocessedText = ThriftUtils.loadThriftTextsFromFolder(PREPROCESSED_CORPUS_DIR, trainTextIds::contains); Map<String, String> id2optimalSummary = loadOptimalSummaries(trainTextIds::contains); prepareMentionsDataset(id2preprocessedText, id2optimalSummary); @@ -66,7 +57,7 @@ public class PrepareTrainingData { prepareZerosDataset(id2preprocessedText); } - public static void prepareMentionsDataset(Map<String, TText> id2preprocessedText, Map<String, String> id2optimalSummary) throws IOException { + private static void prepareMentionsDataset(Map<String, TText> id2preprocessedText, Map<String, String> id2optimalSummary) throws IOException { MentionScorer mentionScorer = new MentionScorer(); MentionFeatureExtractor featureExtractor = new MentionFeatureExtractor(); @@ -74,7 +65,7 @@ public class PrepareTrainingData { int i = 1; for (Map.Entry<String, TText> entry : id2preprocessedText.entrySet()) { - LOG.info("{}/{}", i++, id2preprocessedText.size()); + logProgress(id2preprocessedText, i++); String id = entry.getKey(); TText preprocessedText = entry.getValue(); @@ -92,29 +83,33 @@ public class PrepareTrainingData { instances.add(instance); } } - saveInstancesToFile(instances, new File(ModelConstants.MENTION_DATASET_PATH)); + saveInstancesToFile(instances, MENTION_ARFF); } - private static Set<String> loadTrainTextIds() throws IOException { - try (InputStream inputStream = PrepareTrainingData.class.getResourceAsStream(TRAIN_TEXT_IDS_RESOURCE_PATH)) { - List<String> testTextIds = IOUtils.readLines(inputStream, Constants.ENCODING); - return testTextIds.stream().map(String::trim).collect(Collectors.toSet()); + private static void logProgress(Map<String, TText> id2preprocessedText, int i) { + if (i % 10 == 0) { + LOG.info("{}/{}", i, id2preprocessedText.size()); } } - public static void prepareSentencesDataset(Map<String, TText> id2preprocessedText, Map<String, String> id2optimalSummary) throws Exception { + private static Set<String> loadTrainTextIds() throws IOException { + File[] optimalSummaries = OPTIMAL_SUMMARIES_DIR.listFiles(); + if (optimalSummaries == null) + throw new IOException("No optimal summaries at " + OPTIMAL_SUMMARIES_DIR); + + return Arrays.stream(optimalSummaries).map(file -> file.getName().split("_")[0]).collect(Collectors.toSet()); + } + + private static void prepareSentencesDataset(Map<String, TText> id2preprocessedText, Map<String, String> id2optimalSummary) throws Exception { SentenceScorer sentenceScorer = new SentenceScorer(); SentenceFeatureExtractor featureExtractor = new SentenceFeatureExtractor(); Instances instances = Utils.createNewInstances(featureExtractor.getAttributesList()); - Classifier classifier = Utils.loadClassifierFromResource(Constants.MENTION_MODEL_RESOURCE_PATH); - MentionFeatureExtractor mentionFeatureExtractor = new MentionFeatureExtractor(); - int i = 1; for (String textId : id2preprocessedText.keySet()) { - LOG.info("{}/{}", i++, id2preprocessedText.size()); + logProgress(id2preprocessedText, i++); TText preprocessedText = id2preprocessedText.get(textId); String optimalSummary = id2optimalSummary.get(textId); @@ -123,9 +118,7 @@ public class PrepareTrainingData { Map<TSentence, Double> sentence2score = sentenceScorer.calculateSentenceScores(optimalSummary, preprocessedText); Set<TMention> goodMentions - = MentionModel.detectGoodMentions(classifier, mentionFeatureExtractor, preprocessedText); -// Set<TMention> goodMentions -// = Utils.loadGoldGoodMentions(textId, preprocessedText, true); + = loadGoldGoodMentions(textId, preprocessedText, id2optimalSummary); Map<TSentence, Instance> sentence2instance = InstanceUtils.extractInstancesFromSentences(preprocessedText, featureExtractor, goodMentions); for (Map.Entry<TSentence, Instance> entry : sentence2instance.entrySet()) { @@ -136,21 +129,31 @@ public class PrepareTrainingData { instances.add(instance); } } - saveInstancesToFile(instances, new File(ModelConstants.SENTENCE_DATASET_PATH)); + saveInstancesToFile(instances, SENTENCE_ARFF); + } + + private static Set<TMention> loadGoldGoodMentions(String id, TText text, Map<String, String> id2optimalSummary) throws IOException { + String optimalSummary = id2optimalSummary.get(id); + + MentionScorer scorer = new MentionScorer(); + Map<TMention, Double> mention2score = scorer.calculateMentionScores(optimalSummary, text); + + mention2score.keySet().removeIf(tMention -> mention2score.get(tMention) < 1); + return mention2score.keySet(); } - public static void prepareZerosDataset(Map<String, TText> id2preprocessedText) throws IOException { + private static void prepareZerosDataset(Map<String, TText> id2preprocessedText) throws IOException { - Map<String, Set<String>> id2sentIds = loadSentenceIds(SUMMARY_SENTENCE_IDS); + Map<String, Set<String>> id2sentIds = loadSentenceIds(SUMMARY_SENTENCE_IDS_DIR); - ZeroScorer zeroScorer = new ZeroScorer(ZERO_TRAINING_DATA_RESOURCE_PATH); + ZeroScorer zeroScorer = new ZeroScorer(ZERO_TRAINING_CORPUS); ZeroFeatureExtractor featureExtractor = new ZeroFeatureExtractor(); Instances instances = Utils.createNewInstances(featureExtractor.getAttributesList()); int i = 1; for (Map.Entry<String, TText> entry : id2preprocessedText.entrySet()) { - LOG.info(i++ + "/" + id2preprocessedText.size()); + logProgress(id2preprocessedText, i++); String textId = entry.getKey(); @@ -170,12 +173,12 @@ public class PrepareTrainingData { } } - saveInstancesToFile(instances, new File(ModelConstants.ZERO_DATASET_PATH)); + saveInstancesToFile(instances, ZERO_ARFF); } - private static Map<String, Set<String>> loadSentenceIds(String idsPath) throws IOException { + private static Map<String, Set<String>> loadSentenceIds(File idsFolder) throws IOException { Map<String, Set<String>> result = Maps.newHashMap(); - File[] files = new File(idsPath).listFiles(); + File[] files = idsFolder.listFiles(); if (files != null) for (File f : files) { String id = f.getName().split("_")[0]; @@ -194,7 +197,7 @@ public class PrepareTrainingData { private static Map<String, String> loadOptimalSummaries(Predicate<String> idFilter) throws IOException { Map<String, String> id2optimalSummary = Maps.newHashMap(); - File[] files = new File(OPTIMAL_SUMMARIES_DIR_PATH).listFiles(); + File[] files = OPTIMAL_SUMMARIES_DIR.listFiles(); if (files != null) for (File optimalSummaryFile : files) { String textId = optimalSummaryFile.getName().split("_")[0]; diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/Preprocess.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/Preprocess.java deleted file mode 100644 index 753753e..0000000 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/Preprocess.java +++ /dev/null @@ -1,63 +0,0 @@ -package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import pl.waw.ipipan.zil.summ.nicolas.multiservice.Preprocessor; -import pl.waw.ipipan.zil.summ.pscapi.io.PSC_IO; -import pl.waw.ipipan.zil.summ.pscapi.xml.Text; - -import java.io.File; -import java.util.Arrays; - -public class Preprocess { - - private static final Logger LOG = LoggerFactory.getLogger(Preprocess.class); - - private static final String CORPUS_FILE_SUFFIX = ".xml"; - private static final String OUTPUT_FILE_SUFFIX = ".thrift"; - - private Preprocess() { - } - - public static void main(String[] args) { - if (args.length != 2) { - LOG.error("Wrong usage! Try " + Preprocess.class.getSimpleName() + " dirWithCorpusFiles targetDir"); - return; - } - File corpusDir = new File(args[0]); - if (!corpusDir.isDirectory()) { - LOG.error("Corpus directory does not exist: {}", corpusDir); - return; - } - File targetDir = new File(args[1]); - if (!targetDir.isDirectory()) { - LOG.error("Target directory does not exist: {}", targetDir); - return; - } - - int ok = 0; - int err = 0; - File[] files = corpusDir.listFiles(f -> f.getName().endsWith(CORPUS_FILE_SUFFIX)); - if (files == null || files.length == 0) { - LOG.error("No corpus files found at: {}", corpusDir); - return; - } - Arrays.sort(files); - - Preprocessor processor = new Preprocessor(); - - for (File file : files) { - try { - Text text = PSC_IO.readText(file); - File targetFile = new File(targetDir, file.getName().replaceFirst(CORPUS_FILE_SUFFIX + "$", OUTPUT_FILE_SUFFIX)); - processor.preprocessToFile(text.getBody(), targetFile); - ok++; - } catch (Exception e) { - err++; - LOG.error("Problem with text in " + file + ", " + e); - } - } - LOG.info("{} texts processed successfully.", ok); - LOG.info("{} texts with errors.", err); - } -} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/PreprocessCorpus.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/PreprocessCorpus.java new file mode 100644 index 0000000..449454b --- /dev/null +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/PreprocessCorpus.java @@ -0,0 +1,53 @@ +package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import pl.waw.ipipan.zil.summ.nicolas.multiservice.Preprocessor; +import pl.waw.ipipan.zil.summ.pscapi.io.PSC_IO; +import pl.waw.ipipan.zil.summ.pscapi.xml.Text; + +import java.io.File; +import java.util.Arrays; + +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; + +public class PreprocessCorpus { + + private static final Logger LOG = LoggerFactory.getLogger(PreprocessCorpus.class); + + private static final String CORPUS_FILE_SUFFIX = ".xml"; + private static final String OUTPUT_FILE_SUFFIX = ".thrift"; + + private PreprocessCorpus() { + } + + public static void main(String[] args) { + + createFolder(PREPROCESSED_CORPUS_DIR); + + int ok = 0; + int err = 0; + File[] files = EXTRACTED_CORPUS_DATA_DIR.listFiles(f -> f.getName().endsWith(CORPUS_FILE_SUFFIX)); + if (files == null || files.length == 0) { + LOG.error("No corpus files found at: {}", EXTRACTED_CORPUS_DATA_DIR); + return; + } + Arrays.sort(files); + + Preprocessor processor = new Preprocessor(); + + for (File file : files) { + try { + Text text = PSC_IO.readText(file); + File targetFile = new File(PREPROCESSED_CORPUS_DIR, file.getName().replaceFirst(CORPUS_FILE_SUFFIX + "$", OUTPUT_FILE_SUFFIX)); + processor.preprocessToFile(text.getBody(), targetFile); + ok++; + } catch (Exception e) { + err++; + LOG.error("Problem with text in " + file + ", " + e); + } + } + LOG.info("{} texts processed successfully.", ok); + LOG.info("{} texts with errors.", err); + } +} diff --git a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java index ea1158a..10dfa40 100644 --- a/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java +++ b/nicolas-train/src/main/java/pl/waw/ipipan/zil/summ/nicolas/train/pipeline/TrainAllModels.java @@ -1,17 +1,61 @@ package pl.waw.ipipan.zil.summ.nicolas.train.pipeline; -import pl.waw.ipipan.zil.summ.nicolas.train.model.mention.TrainMentionModel; -import pl.waw.ipipan.zil.summ.nicolas.train.model.sentence.TrainSentenceModel; -import pl.waw.ipipan.zil.summ.nicolas.train.model.zero.TrainZeroModel; +import org.apache.commons.lang3.time.StopWatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import pl.waw.ipipan.zil.summ.nicolas.common.Constants; +import pl.waw.ipipan.zil.summ.nicolas.train.model.Settings; +import weka.classifiers.Classifier; +import weka.core.Instances; +import weka.core.converters.ArffLoader; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.ObjectOutputStream; +import java.util.logging.LogManager; + +import static pl.waw.ipipan.zil.summ.nicolas.train.PathConstants.*; public class TrainAllModels { + private static final Logger LOG = LoggerFactory.getLogger(TrainAllModels.class); + + private static final String TARGET_MODEL_DIR = "nicolas-model/src/main/resources"; + private TrainAllModels() { } public static void main(String[] args) throws Exception { - TrainMentionModel.main(args); - TrainSentenceModel.main(args); - TrainZeroModel.main(args); + trainAndSaveModel(MENTION_ARFF, Settings.getMentionClassifier(), Constants.MENTION_MODEL_RESOURCE_PATH); + trainAndSaveModel(SENTENCE_ARFF, Settings.getSentenceClassifier(), Constants.SENTENCE_MODEL_RESOURCE_PATH); + trainAndSaveModel(ZERO_ARFF, Settings.getZeroClassifier(), Constants.ZERO_MODEL_RESOURCE_PATH); + } + + private static void trainAndSaveModel(File dataset, Classifier classifier, String targetPath) throws Exception { + LogManager.getLogManager().reset(); // disable WEKA logging + + ArffLoader loader = new ArffLoader(); + loader.setFile(dataset); + Instances instances = loader.getDataSet(); + instances.setClassIndex(0); + LOG.info("{} instances loaded.", instances.size()); + LOG.info("{} attributes for each instance.", instances.numAttributes()); + + StopWatch watch = new StopWatch(); + watch.start(); + + LOG.info("Building classifier..."); + classifier.buildClassifier(instances); + LOG.info("...done. Build classifier: {}", classifier); + + String target = TARGET_MODEL_DIR + targetPath; + LOG.info("Saving classifier at: {}", target); + try (ObjectOutputStream oos = new ObjectOutputStream( + new FileOutputStream(target))) { + oos.writeObject(classifier); + } + + watch.stop(); + LOG.info("Elapsed time: {}", watch); } } diff --git a/nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/train_text_ids.txt b/nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/train_text_ids.txt deleted file mode 100644 index 6b0ff86..0000000 --- a/nicolas-train/src/main/resources/pl/waw/ipipan/zil/summ/nicolas/train/train_text_ids.txt +++ /dev/null @@ -1,415 +0,0 @@ -199704210011 -199704210013 -199704250031 -199704260017 -199801030156 -199801100009 -199801150038 -199801150133 -199801170001 -199801170129 -199801170130 -199801200002 -199801200132 -199801210007 -199801220030 -199801220127 -199801230001 -199801230095 -199801240116 -199801240123 -199801260113 -199801270108 -199801280128 -199801290020 -199801310032 -199802040201 -199901180149 -199901190049 -199901230088 -199901250006 -199901250008 -199901250111 -199901250113 -199901300064 -199901300098 -199902240123 -199906220027 -199906220037 -199906220038 -199906220056 -199906220065 -199906230040 -199906230052 -199906240040 -199906240088 -199906250007 -199906250091 -199906260015 -199906260018 -199906260038 -199907030016 -199907030018 -199907030042 -199907030059 -199907050032 -199907050040 -199907050047 -199907050071 -199907270095 -199907270137 -199907270145 -199909210045 -199909250054 -199909300064 -199909300065 -199909300066 -199910020049 -199910020050 -199910090047 -199910090049 -199910090051 -199910110055 -199910110057 -199910210058 -199910210059 -199910270041 -199910280054 -199910280055 -199910280057 -199910300026 -199911030039 -199911030040 -199911030041 -199911060031 -199911060042 -199911060043 -199911080054 -199911080055 -199911080056 -199911100061 -199911100062 -199911100063 -199911130036 -199911130037 -199911130038 -199911180042 -199911180043 -199911180044 -199911220059 -199911220061 -199911220066 -199911230041 -199911240035 -199911240037 -199911240038 -199911250055 -199911250057 -199912020059 -199912090045 -199912090047 -199912090061 -199912110041 -199912110042 -199912130055 -199912130057 -199912170065 -199912180052 -199912210018 -199912210037 -199912210040 -199912220045 -199912220046 -199912220047 -199912230058 -199912230059 -199912230097 -199912280028 -199912280044 -199912280045 -199912310085 -199912310087 -200001030047 -200001030106 -200001040030 -200001040031 -200001060052 -200001060053 -200001060055 -200001070062 -200001070066 -200001080040 -200001080041 -200001140061 -200001140064 -200001170049 -200001170051 -200001170052 -200001170053 -200001180040 -200001200056 -200001220023 -200001220118 -200001240016 -200001290042 -200001310048 -200001310049 -200001310050 -200001310054 -200002090042 -200002090043 -200002120045 -200002120046 -200002160046 -200002160047 -200002250063 -200002250065 -200002250066 -200002290044 -200002290045 -200002290046 -200002290047 -200002290048 -200003010058 -200003010059 -200003060054 -200003060055 -200003060057 -200003110047 -200003110048 -200003110049 -200003210044 -200003210045 -200004120021 -200004120022 -200004120023 -200004150048 -200004150049 -200004150050 -200004170026 -200004170065 -200004220044 -200004220045 -200004220046 -200004220047 -200004220048 -200005060030 -200005150055 -200005150059 -200005300045 -200005300047 -200005300048 -200006010065 -200006010066 -200006010067 -200006050056 -200006050057 -200006050058 -200006050059 -200006050061 -200006050068 -200006070056 -200006080033 -200006120031 -200006130055 -200006130057 -200006130059 -200006260069 -200006260071 -200006270059 -200007120068 -200007120070 -200007120072 -200007170026 -200007180051 -200007240034 -200007270050 -200007280033 -200008040071 -200008040073 -200008250077 -200008250079 -200008260055 -200008310046 -200010120066 -200010120074 -200010130063 -200010140048 -200010140049 -200010160039 -200010160048 -200010160049 -200010180059 -200010180063 -200010190066 -200010190068 -200011210063 -200011210064 -200011210066 -200012050066 -200012050067 -200012050068 -200012050069 -200012050070 -200012050071 -200012080134 -200012080137 -200012110069 -200012110070 -200012110071 -200012110075 -200012120028 -200012120068 -200012120072 -200012130056 -200012130100 -200012130102 -200012130103 -200012140095 -200012140096 -200012140097 -200012140098 -200012140099 -200012140100 -200012150076 -200012160048 -200012160049 -200012180083 -200012180084 -200012180088 -200012230028 -200012230045 -200012230046 -200012230047 -200012230048 -200012230050 -200012270055 -200012270056 -200101020059 -200101020062 -200101020063 -200101020075 -200101130048 -200101130050 -200101130051 -200101130055 -200101150043 -200101150045 -200101180050 -200101180051 -200101180052 -200101200048 -200101220047 -200101220053 -200102070011 -200102070016 -200102120034 -200102120057 -200102130014 -200102150001 -200102150014 -200102160011 -200102190016 -200102220001 -200102220013 -200102270041 -200102270062 -200102280169 -200103010049 -200103060022 -200103060032 -200103060057 -200103080026 -200103080030 -200103080036 -200103100019 -200103100021 -200103100058 -200103100062 -200103130008 -200103130023 -200103130069 -200103200066 -200103200080 -200103270069 -200103310092 -200104020007 -200104050011 -200104100021 -200104100023 -200104170015 -200104170040 -200104170055 -200104170057 -200104190039 -200104190066 -200104230031 -200104230069 -200104260051 -200104260053 -200104300213 -200104300215 -200104300217 -200105020092 -200105050042 -200105050043 -200105050046 -200105050048 -200105070017 -200105140050 -200105140052 -200105220096 -200105290074 -200105290075 -200106120068 -200106120069 -200106180051 -200106180053 -200106200064 -200106220086 -200106220087 -200106220088 -200106220090 -200106250050 -200107120071 -200107120073 -200107210129 -200107240070 -200107250080 -200108060051 -200108060155 -200108060156 -200108060157 -200108070038 -200108160040 -200108180123 -200108200033 -200108210066 -200108210074 -200108270077 -200108280064 -200109060061 -200109130091 -200109250092 -200109260097 -200109270116 -200110020075 -200110150056 -200110150062 -200110200070 -200110200071 -200110220068 -200111080086 -200111140055 -200111210078 -200111240060 -200112040031 -200112040077 -200112050063 -200112100041 -200112190067 -200201280011 -200201290029 -200202280078 -200203280057 -200203290107 diff --git a/train.sh b/train.sh new file mode 100755 index 0000000..46b5624 --- /dev/null +++ b/train.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +mvn clean install -Dmaven.test.skip=true +mvn -pl nicolas-train exec:java -Dexec.mainClass="pl.waw.ipipan.zil.summ.nicolas.train.Main" +mvn install -Dmaven.test.skip=true -- libgit2 0.22.2