From df93dbe78179d7a72d407af34116775253b24a6c Mon Sep 17 00:00:00 2001 From: Tobias Horsmann Date: Fri, 2 Feb 2018 21:38:34 +0100 Subject: [PATCH] #433 --- .../regression/LibsvmRegressionDemoTest.java | 4 +- ...iblinearModelSerializationDescription.java | 10 +- .../LoadModelConnectorLiblinear.java | 15 +- .../liblinear/writer/LiblinearDataWriter.java | 19 ++- .../LibsvmModelSerializationDescription.java | 149 +++++++++--------- .../LoadModelConnectorLibsvm.java | 10 +- .../tc/ml/libsvm/writer/LibsvmDataWriter.java | 37 +++-- 7 files changed, 149 insertions(+), 95 deletions(-) diff --git a/dkpro-tc-examples/src/test/java/org/dkpro/tc/examples/regression/LibsvmRegressionDemoTest.java b/dkpro-tc-examples/src/test/java/org/dkpro/tc/examples/regression/LibsvmRegressionDemoTest.java index 71f59b345..2f6a365e3 100644 --- a/dkpro-tc-examples/src/test/java/org/dkpro/tc/examples/regression/LibsvmRegressionDemoTest.java +++ b/dkpro-tc-examples/src/test/java/org/dkpro/tc/examples/regression/LibsvmRegressionDemoTest.java @@ -18,7 +18,7 @@ */ package org.dkpro.tc.examples.regression; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; import org.dkpro.lab.task.ParameterSpace; import org.dkpro.tc.examples.TestCaseSuperClass; @@ -58,6 +58,6 @@ public void testTrainTest() throws Exception{ EvaluationData data = Tc2LtlabEvalConverter.convertRegressionModeId2Outcome(ContextMemoryReport.id2outcome); MeanSquaredError mse = new MeanSquaredError(data); - assertTrue(mse.getResult() > 1.0); + assertEquals(3.37, mse.getResult(), 0.01); } } diff --git a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LiblinearModelSerializationDescription.java b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LiblinearModelSerializationDescription.java index 55930709d..ea535787e 100644 --- a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LiblinearModelSerializationDescription.java +++ b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LiblinearModelSerializationDescription.java @@ -41,7 +41,7 @@ public class LiblinearModelSerializationDescription extends ModelSerializationTa @Discriminator(name = DIM_CLASSIFICATION_ARGS) private List classificationArguments; - + boolean trainModel = true; @Override @@ -71,12 +71,20 @@ private void trainAndStoreModel(TaskContext aContext) throws Exception { } private void copyOutcomeMappingToThisFolder(TaskContext aContext) throws IOException { + if(isRegression()){ + return; + } + File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); String mapping = LiblinearAdapter.getOutcomeMappingFilename(); FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping)); } + private boolean isRegression() { + return learningMode.equals(Constants.LM_REGRESSION); + } + private void copyFeatureNameMappingToThisFolder(TaskContext aContext) throws IOException { File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); String mapping = LiblinearAdapter.getFeatureNameMappingFilename(); diff --git a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LoadModelConnectorLiblinear.java b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LoadModelConnectorLiblinear.java index 7200919e6..a388b2c8d 100644 --- a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LoadModelConnectorLiblinear.java +++ b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LoadModelConnectorLiblinear.java @@ -66,6 +66,7 @@ public class LoadModelConnectorLiblinear extends ModelSerialization_ImplBase { private String learningMode; private Model liblinearModel; + private Map outcomeMapping; private Map featureMapping; @@ -98,6 +99,11 @@ private Map loadFeature2IntegerMapping(File tcModelLocation) th } private Map loadOutcome2IntegerMapping(File tcModelLocation) throws IOException { + + if (isRegression()){ + return new HashMap<>(); + } + Map map = new HashMap<>(); List readLines = FileUtils .readLines(new File(tcModelLocation, LiblinearAdapter.getOutcomeMappingFilename()), "utf-8"); @@ -108,7 +114,12 @@ private Map loadOutcome2IntegerMapping(File tcModelLocation) th return map; } - private Double toValue(Object value) + private boolean isRegression() { + return learningMode.equals(Constants.LM_REGRESSION); + } + + + private Double toValue(Object value) { double v; if (value instanceof Number) { @@ -173,7 +184,7 @@ public void process(JCas jcas) throws AnalysisEngineProcessException { Feature[] instance = testInstances[i]; Double prediction = Linear.predict(liblinearModel, instance); - if (learningMode.equals(Constants.LM_REGRESSION)) { + if (isRegression()) { outcomes.get(i).setOutcome(prediction.toString()); } else { String predictedLabel = outcomeMapping.get(prediction.intValue()); diff --git a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/writer/LiblinearDataWriter.java b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/writer/LiblinearDataWriter.java index 3cbfc2d44..a97afbcdf 100644 --- a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/writer/LiblinearDataWriter.java +++ b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/writer/LiblinearDataWriter.java @@ -140,7 +140,12 @@ public void writeClassifierFormat(Collection in) throws Exception { List keys = new ArrayList(entry.keySet()); Collections.sort(keys); - bw.append(outcomeMap.get(inst.getOutcome()) + "\t"); + + if (isRegression()) { + bw.append(inst.getOutcome() + "\t"); + } else { + bw.append(outcomeMap.get(inst.getOutcome()) + "\t"); + } for (int i = 0; i < keys.size(); i++) { Integer key = keys.get(i); Double value = entry.get(key); @@ -161,6 +166,11 @@ public void writeClassifierFormat(Collection in) throws Exception { } private void writeOutcomeMapping(File outputDirectory, String file, Map map) throws IOException { + + if(isRegression()){ + return; + } + StringBuilder sb = new StringBuilder(); for (String k : map.keySet()) { sb.append(k + "\t" + map.get(k) + "\n"); @@ -248,6 +258,9 @@ public void init(File outputDirectory, boolean useSparse, String learningMode, b * @param outcomes */ private void buildOutcomeMap(String[] outcomes) { + if(isRegression()){ + return; + } outcomeMap = new HashMap<>(); Integer i = 0; List outcomesSorted = new ArrayList<>(Arrays.asList(outcomes)); @@ -288,6 +301,10 @@ private void recordInstanceId(Instance instance, int i, Map inde return; } } + + private boolean isRegression(){ + return learningMode.equals(Constants.LM_REGRESSION); + } @Override public void close() throws Exception { diff --git a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LibsvmModelSerializationDescription.java b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LibsvmModelSerializationDescription.java index 3f7f9ab35..47a1f5deb 100644 --- a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LibsvmModelSerializationDescription.java +++ b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LibsvmModelSerializationDescription.java @@ -33,83 +33,76 @@ import org.dkpro.tc.ml.libsvm.LibsvmAdapter; import org.dkpro.tc.ml.libsvm.api.LibsvmTrainModel; -public class LibsvmModelSerializationDescription - extends ModelSerializationTask - implements Constants -{ - - @Discriminator(name = DIM_CLASSIFICATION_ARGS) - private List classificationArguments; - - boolean trainModel = true; - - @Override - public void execute(TaskContext aContext) - throws Exception - { - trainAndStoreModel(aContext); - - writeModelConfiguration(aContext, LibsvmAdapter.class.getName()); - } - - private void trainAndStoreModel(TaskContext aContext) - throws Exception - { - boolean multiLabel = learningMode.equals(Constants.LM_MULTI_LABEL); - if (multiLabel) { - throw new TextClassificationException("Multi-label is not yet implemented"); - } - - File fileTrain = getTrainFile(aContext); - - File model = new File(outputFolder, Constants.MODEL_CLASSIFIER); - - LibsvmTrainModel ltm = new LibsvmTrainModel(); - ltm.run(buildParameters(fileTrain, model)); - copyOutcomeMappingToThisFolder(aContext); - copyFeatureNameMappingToThisFolder(aContext); - } - - private void copyOutcomeMappingToThisFolder(TaskContext aContext) - throws IOException - { - File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, - AccessMode.READONLY); - String mapping = LibsvmAdapter.getOutcomeMappingFilename(); - - FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping)); - } - - private void copyFeatureNameMappingToThisFolder(TaskContext aContext) - throws IOException - { - File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, - AccessMode.READONLY); - String mapping = LibsvmAdapter.getFeatureNameMappingFilename(); - - FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping)); - } - - private File getTrainFile(TaskContext aContext) - { - File trainFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, - AccessMode.READONLY); - File fileTrain = new File(trainFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT); - - return fileTrain; - } - - private String[] buildParameters(File fileTrain, File model) - { - List parameters = new ArrayList<>(); - if (classificationArguments != null) { - for (String a : classificationArguments) { - parameters.add(a); - } - } - parameters.add(fileTrain.getAbsolutePath()); - parameters.add(model.getAbsolutePath()); - return parameters.toArray(new String[0]); - } +public class LibsvmModelSerializationDescription extends ModelSerializationTask implements Constants { + + @Discriminator(name = DIM_CLASSIFICATION_ARGS) + private List classificationArguments; + + boolean trainModel = true; + + @Override + public void execute(TaskContext aContext) throws Exception { + trainAndStoreModel(aContext); + + writeModelConfiguration(aContext, LibsvmAdapter.class.getName()); + } + + private void trainAndStoreModel(TaskContext aContext) throws Exception { + boolean multiLabel = learningMode.equals(Constants.LM_MULTI_LABEL); + if (multiLabel) { + throw new TextClassificationException("Multi-label is not yet implemented"); + } + + File fileTrain = getTrainFile(aContext); + + File model = new File(outputFolder, Constants.MODEL_CLASSIFIER); + + LibsvmTrainModel ltm = new LibsvmTrainModel(); + ltm.run(buildParameters(fileTrain, model)); + copyOutcomeMappingToThisFolder(aContext); + copyFeatureNameMappingToThisFolder(aContext); + } + + private void copyOutcomeMappingToThisFolder(TaskContext aContext) throws IOException { + + if(isRegression()){ + return; + } + + File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); + String mapping = LibsvmAdapter.getOutcomeMappingFilename(); + + FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping)); + } + + private boolean isRegression() { + return learningMode.equals(Constants.LM_REGRESSION); + } + + private void copyFeatureNameMappingToThisFolder(TaskContext aContext) throws IOException { + File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); + String mapping = LibsvmAdapter.getFeatureNameMappingFilename(); + + FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping)); + } + + private File getTrainFile(TaskContext aContext) { + File trainFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); + File fileTrain = new File(trainFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT); + + return fileTrain; + } + + private String[] buildParameters(File fileTrain, File model) { + List parameters = new ArrayList<>(); + if (classificationArguments != null) { + for (String a : classificationArguments) { + parameters.add(a); + } + } + parameters.add(fileTrain.getAbsolutePath()); + parameters.add(model.getAbsolutePath()); + return parameters.toArray(new String[0]); + } } \ No newline at end of file diff --git a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LoadModelConnectorLibsvm.java b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LoadModelConnectorLibsvm.java index 27eb8d69e..a5c7ab176 100644 --- a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LoadModelConnectorLibsvm.java +++ b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LoadModelConnectorLibsvm.java @@ -114,6 +114,10 @@ private Map loadFeature2IntegerMapping(File tcModelLocation) th private Map loadInteger2OutcomeMapping(File tcModelLocation) throws IOException { + if(isRegression()){ + return new HashMap<>(); + } + Map map = new HashMap<>(); List readLines = FileUtils .readLines(new File(tcModelLocation, LibsvmAdapter.getOutcomeMappingFilename()), "utf-8"); @@ -123,6 +127,10 @@ private Map loadInteger2OutcomeMapping(File tcModelLocation) } return map; } + + private boolean isRegression(){ + return learningMode.equals(Constants.LM_REGRESSION); + } @Override public void process(JCas jcas) @@ -141,7 +149,7 @@ public void process(JCas jcas) for (int i = 0; i < outcomes.size(); i++) { - if (learningMode.equals(Constants.LM_REGRESSION)) { + if (isRegression()) { String val = writtenPredictions.get(i); outcomes.get(i).setOutcome(val); } diff --git a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/writer/LibsvmDataWriter.java b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/writer/LibsvmDataWriter.java index e9b9f1f77..fc19c4c8b 100644 --- a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/writer/LibsvmDataWriter.java +++ b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/writer/LibsvmDataWriter.java @@ -50,25 +50,25 @@ * For example: 1 1:1 3:1 4:1 6:1 2 2:1 3:1 5:1 7:1 1 3:1 5:1 */ public class LibsvmDataWriter implements DataWriter { - + public static final String INDEX2INSTANCEID = "index2Instanceid.txt"; - + File outputDirectory; - + boolean useSparse; - + String learningMode; - + boolean applyWeighting; - + File classifierFormatOutputFile; - + BufferedWriter bw = null; - + Map index2instanceId; Gson gson = new Gson(); - + private int maxId = 0; Map featureNames2id; @@ -140,7 +140,12 @@ public void writeClassifierFormat(Collection in) throws Exception { List keys = new ArrayList(entry.keySet()); Collections.sort(keys); - bw.append(outcomeMap.get(inst.getOutcome()) + "\t"); + + if (isRegression()) { + bw.append(inst.getOutcome() + "\t"); + } else { + bw.append(outcomeMap.get(inst.getOutcome()) + "\t"); + } for (int i = 0; i < keys.size(); i++) { Integer key = keys.get(i); Double value = entry.get(key); @@ -161,6 +166,11 @@ public void writeClassifierFormat(Collection in) throws Exception { } private void writeOutcomeMapping(File outputDirectory, String file, Map map) throws IOException { + + if(isRegression()){ + return; + } + StringBuilder sb = new StringBuilder(); for (String k : map.keySet()) { sb.append(k + "\t" + map.get(k) + "\n"); @@ -248,6 +258,9 @@ public void init(File outputDirectory, boolean useSparse, String learningMode, b * @param outcomes */ private void buildOutcomeMap(String[] outcomes) { + if(isRegression()){ + return; + } outcomeMap = new HashMap<>(); Integer i = 0; List outcomesSorted = new ArrayList<>(Arrays.asList(outcomes)); @@ -288,6 +301,10 @@ private void recordInstanceId(Instance instance, int i, Map inde return; } } + + private boolean isRegression(){ + return learningMode.equals(Constants.LM_REGRESSION); + } @Override public void close() throws Exception {