From d366ee32d2b1ebaa571f01cdccffacf45fa44a7b Mon Sep 17 00:00:00 2001 From: Tobias Horsmann Date: Sat, 3 Feb 2018 13:52:18 +0100 Subject: [PATCH] #434 - drafted libsvm/liblinear with super module --- .../tc/core/task/ModelSerializationTask.java | 6 +- dkpro-tc-examples/pom.xml | 4 + ...veAndLoadModelDocumentSingleLabelTest.java | 7 +- dkpro-tc-io-libsvm/pom.xml | 19 ++ .../org/dkpro/tc/io/libsvm/AdapterFormat.java | 34 ++ .../LibsvmDataFormatOutcomeIdReport.java | 189 +++++++++++ .../tc/io/libsvm/LibsvmDataFormatWriter.java | 9 +- .../io/libsvm/LibsvmModelLoaderConnector.java | 223 +++++++++++++ .../io/libsvm/LibsvmModelSerialization.java | 96 ++++++ ...CRFSuiteModelSerializationDescription.java | 124 ++++--- dkpro-tc-ml-liblinear/pom.xml | 38 ++- .../tc/ml/liblinear/LiblinearAdapter.java | 16 +- .../report/LiblinearOutcomeIdReport.java | 196 ----------- ...iblinearModelSerializationDescription.java | 64 +--- .../LoadModelConnectorLiblinear.java | 170 ++-------- .../liblinear/writer/LiblinearDataWriter.java | 314 ------------------ .../ml/liblinear/LiblinearDataWriterTest.java | 4 +- dkpro-tc-ml-libsvm/pom.xml | 36 +- .../org/dkpro/tc/ml/libsvm/LibsvmAdapter.java | 8 +- .../org/dkpro/tc/ml/libsvm/LibsvmUtils.java | 84 ----- .../libsvm/report/LibsvmOutcomeIdReport.java | 195 ----------- .../LibsvmModelSerializationDescription.java | 71 +--- .../LoadModelConnectorLibsvm.java | 223 ++----------- .../SvmhmmModelSerializationDescription.java | 8 +- .../WekaModelSerializationDescription.java | 8 +- pom.xml | 7 +- 26 files changed, 769 insertions(+), 1384 deletions(-) create mode 100644 dkpro-tc-io-libsvm/pom.xml create mode 100644 dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/AdapterFormat.java create mode 100644 dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmDataFormatOutcomeIdReport.java rename dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/writer/LibsvmDataWriter.java => dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmDataFormatWriter.java (97%) create mode 100644 dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmModelLoaderConnector.java create mode 100644 dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmModelSerialization.java delete mode 100644 dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/report/LiblinearOutcomeIdReport.java delete mode 100644 dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/writer/LiblinearDataWriter.java delete mode 100644 dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/LibsvmUtils.java delete mode 100644 dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/report/LibsvmOutcomeIdReport.java diff --git a/dkpro-tc-core/src/main/java/org/dkpro/tc/core/task/ModelSerializationTask.java b/dkpro-tc-core/src/main/java/org/dkpro/tc/core/task/ModelSerializationTask.java index ab037f9a7..d6136a694 100644 --- a/dkpro-tc-core/src/main/java/org/dkpro/tc/core/task/ModelSerializationTask.java +++ b/dkpro-tc-core/src/main/java/org/dkpro/tc/core/task/ModelSerializationTask.java @@ -48,14 +48,16 @@ public void setOutputFolder(File outputFolder) } - public void writeModelConfiguration(TaskContext aContext, String mlAdapter) throws Exception{ + protected void writeModelConfiguration(TaskContext aContext) throws Exception{ SaveModelUtils.writeModelParameters(aContext, outputFolder, featureSet); SaveModelUtils.writeFeatureMode(outputFolder, featureMode); SaveModelUtils.writeLearningMode(outputFolder, learningMode); - SaveModelUtils.writeModelAdapterInformation(outputFolder, mlAdapter); SaveModelUtils.writeCurrentVersionOfDKProTC(outputFolder); + writeAdapter(); } + + protected abstract void writeAdapter() throws Exception; } \ No newline at end of file diff --git a/dkpro-tc-examples/pom.xml b/dkpro-tc-examples/pom.xml index afccb66cf..dcd0e069a 100644 --- a/dkpro-tc-examples/pom.xml +++ b/dkpro-tc-examples/pom.xml @@ -63,6 +63,10 @@ org.dkpro.tc dkpro-tc-features-pair-similarity + + org.dkpro.tc + dkpro-tc-io-libsvm + org.dkpro.tc dkpro-tc-ml diff --git a/dkpro-tc-examples/src/test/java/org/dkpro/tc/examples/model/liblinear/LiblinearSaveAndLoadModelDocumentSingleLabelTest.java b/dkpro-tc-examples/src/test/java/org/dkpro/tc/examples/model/liblinear/LiblinearSaveAndLoadModelDocumentSingleLabelTest.java index 39adf8a03..cf3e1f067 100644 --- a/dkpro-tc-examples/src/test/java/org/dkpro/tc/examples/model/liblinear/LiblinearSaveAndLoadModelDocumentSingleLabelTest.java +++ b/dkpro-tc-examples/src/test/java/org/dkpro/tc/examples/model/liblinear/LiblinearSaveAndLoadModelDocumentSingleLabelTest.java @@ -56,6 +56,7 @@ import org.dkpro.tc.features.length.NrOfTokens; import org.dkpro.tc.features.ngram.LuceneCharacterNGram; import org.dkpro.tc.features.ngram.LuceneNGram; +import org.dkpro.tc.io.libsvm.AdapterFormat; import org.dkpro.tc.ml.ExperimentSaveModel; import org.dkpro.tc.ml.liblinear.LiblinearAdapter; import org.dkpro.tc.ml.uima.TcAnnotator; @@ -165,7 +166,7 @@ private void documentVerifyCreatedModelFiles(File modelFolder) assertTrue(learningMode.exists()); File id2outcomeMapping = new File( - modelFolder.getAbsolutePath() + "/" + LiblinearAdapter.getOutcomeMappingFilename()); + modelFolder.getAbsolutePath() + "/" + AdapterFormat.getOutcomeMappingFilename()); assertTrue(id2outcomeMapping.exists()); } @@ -294,6 +295,7 @@ private static void unitLoadAndUseModel(File modelFolder) possibleOutcomes.add("JJ"); possibleOutcomes.add("VBD"); possibleOutcomes.add("NNS"); + possibleOutcomes.add("TO"); possibleOutcomes.add("VBN"); possibleOutcomes.add("IN"); possibleOutcomes.add("CC"); @@ -304,7 +306,6 @@ private static void unitLoadAndUseModel(File modelFolder) assertEquals(31, outcomes.size()); for(TextClassificationOutcome o : outcomes){ - System.out.println(o.getOutcome()); assertTrue(possibleOutcomes.contains(o.getOutcome())); } @@ -343,7 +344,7 @@ private void unitVerifyCreatedModelFiles(File modelFolder) assertTrue(learningMode.exists()); File id2outcomeMapping = new File( - modelFolder.getAbsolutePath() + "/" + LiblinearAdapter.getOutcomeMappingFilename()); + modelFolder.getAbsolutePath() + "/" + AdapterFormat.getOutcomeMappingFilename()); assertTrue(id2outcomeMapping.exists()); } } \ No newline at end of file diff --git a/dkpro-tc-io-libsvm/pom.xml b/dkpro-tc-io-libsvm/pom.xml new file mode 100644 index 000000000..76d072869 --- /dev/null +++ b/dkpro-tc-io-libsvm/pom.xml @@ -0,0 +1,19 @@ + + 4.0.0 + dkpro-tc-io-libsvm + + org.dkpro.tc + dkpro-tc + 1.0.0-SNAPSHOT + + + + org.dkpro.tc + dkpro-tc-core + + + org.dkpro.tc + dkpro-tc-ml + + + \ No newline at end of file diff --git a/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/AdapterFormat.java b/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/AdapterFormat.java new file mode 100644 index 000000000..675e58152 --- /dev/null +++ b/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/AdapterFormat.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright 2018 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ +package org.dkpro.tc.io.libsvm; + +public class AdapterFormat { + + public static String getOutcomeMappingFilename() { + return "outcome-mapping.txt"; + } + + public static String getFeatureNameMappingFilename() { + return "feature-name-mapping.txt"; + } + + public static String getFeatureNames() { + return "featurenames.txt"; + } + +} diff --git a/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmDataFormatOutcomeIdReport.java b/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmDataFormatOutcomeIdReport.java new file mode 100644 index 000000000..2eff5e125 --- /dev/null +++ b/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmDataFormatOutcomeIdReport.java @@ -0,0 +1,189 @@ +/******************************************************************************* + * Copyright 2018 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ +package org.dkpro.tc.io.libsvm; + +import java.io.File; +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +import org.apache.commons.io.FileUtils; +import org.apache.commons.io.output.FileWriterWithEncoding; +import org.dkpro.lab.reporting.ReportBase; +import org.dkpro.lab.storage.StorageService.AccessMode; +import org.dkpro.tc.core.Constants; +import org.dkpro.tc.core.task.InitTask; +import org.dkpro.tc.ml.report.util.SortedKeyProperties; + +/** + * Creates id 2 outcome report + */ +public class LibsvmDataFormatOutcomeIdReport extends ReportBase implements Constants { + // constant dummy value for setting as threshold which is an expected field + // in the evaluation + // module but is not needed/provided by liblinear + private static final String THRESHOLD_CONSTANT = "-1"; + + public LibsvmDataFormatOutcomeIdReport() { + // required by groovy + } + + @Override + public void execute() throws Exception { + boolean isRegression = isRegression(); + + boolean isUnit = getDiscriminators().get(InitTask.class.getName() + "|" + Constants.DIM_FEATURE_MODE) + .equals(Constants.FM_UNIT); + + Map id2label = getId2LabelMapping(isRegression); + String header = buildHeader(id2label, isRegression); + + List predictions = readPredictions(); + Map index2instanceIdMap = getMapping(isUnit); + + Properties prop = new SortedKeyProperties(); + int lineCounter = 0; + for (String line : predictions) { + if (line.startsWith("#")) { + continue; + } + String[] split = line.split(";"); + String key = index2instanceIdMap.get(lineCounter + ""); + + if (isRegression) { + prop.setProperty(key, split[0] + ";" + split[1] + ";" + THRESHOLD_CONSTANT); + } else { + int pred = Double.valueOf(split[0]).intValue(); + int gold = Double.valueOf(split[1]).intValue(); + prop.setProperty(key, pred + ";" + gold + ";" + THRESHOLD_CONSTANT); + } + lineCounter++; + } + + File targetFile = getId2OutcomeFileLocation(); + + FileWriterWithEncoding fw = new FileWriterWithEncoding(targetFile, "utf-8"); + prop.store(fw, header); + fw.close(); + + } + + private boolean isRegression() { + + Collection keys = getDiscriminators().keySet(); + for (String k : keys) { + if (k.endsWith("|" + Constants.DIM_LEARNING_MODE)) { + return getDiscriminators().get(k).equals(Constants.LM_REGRESSION); + } + } + return false; + } + + private Map getMapping(boolean isUnit) throws IOException { + + File f; + if (isUnit) { + f = new File(getContext().getFolder(TEST_TASK_INPUT_KEY_TEST_DATA, AccessMode.READONLY), + LibsvmDataFormatWriter.INDEX2INSTANCEID); + } else { + f = new File(getContext().getFolder(TEST_TASK_INPUT_KEY_TEST_DATA, AccessMode.READONLY), + Constants.FILENAME_DOCUMENT_META_DATA_LOG); + } + + Map m = new HashMap<>(); + + int idx = 0; + for (String l : FileUtils.readLines(f, "utf-8")) { + if (l.startsWith("#")) { + continue; + } + if (l.trim().isEmpty()) { + continue; + } + String[] split = l.split("\t"); + + // if (isUnit) { + m.put(idx + "", split[0]); + idx++; + // } else { + // m.put(split[0], split[1]); + // } + + } + return m; + } + + private File getId2OutcomeFileLocation() { + File evaluationFolder = getContext().getFolder("", AccessMode.READWRITE); + return new File(evaluationFolder, ID_OUTCOME_KEY); + } + + private List readPredictions() throws IOException { + File predFolder = getContext().getFolder("", AccessMode.READWRITE); + return FileUtils.readLines(new File(predFolder, Constants.FILENAME_PREDICTIONS), "utf-8"); + } + + private String buildHeader(Map id2label, boolean isRegression) + throws UnsupportedEncodingException { + StringBuilder header = new StringBuilder(); + header.append("ID=PREDICTION;GOLDSTANDARD;THRESHOLD" + "\n" + "labels" + " "); + + if (isRegression) { + // no label mapping for regression so that is all we have to do + return header.toString(); + } + + int numKeys = id2label.keySet().size(); + List keys = new ArrayList(id2label.keySet()); + for (int i = 0; i < numKeys; i++) { + Integer key = keys.get(i); + header.append(key + "=" + URLEncoder.encode(id2label.get(key), "UTF-8")); + if (i + 1 < numKeys) { + header.append(" "); + } + } + return header.toString(); + } + + private Map getId2LabelMapping(boolean isRegression) throws Exception { + if (isRegression) { + // no map for regression; + return new HashMap<>(); + } + + File folder = getContext().getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); + String fileName = AdapterFormat.getOutcomeMappingFilename(); + File file = new File(folder, fileName); + Map map = new HashMap(); + + List lines = FileUtils.readLines(file, "utf-8"); + for (String line : lines) { + String[] split = line.split("\t"); + map.put(Integer.valueOf(split[1]), split[0]); + } + + return map; + } + +} \ No newline at end of file diff --git a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/writer/LibsvmDataWriter.java b/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmDataFormatWriter.java similarity index 97% rename from dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/writer/LibsvmDataWriter.java rename to dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmDataFormatWriter.java index fc19c4c8b..cf3c4fde3 100644 --- a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/writer/LibsvmDataWriter.java +++ b/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmDataFormatWriter.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. ******************************************************************************/ -package org.dkpro.tc.ml.libsvm.writer; +package org.dkpro.tc.io.libsvm; import java.io.BufferedReader; import java.io.BufferedWriter; @@ -38,7 +38,6 @@ import org.dkpro.tc.api.features.Instance; import org.dkpro.tc.core.Constants; import org.dkpro.tc.core.io.DataWriter; -import org.dkpro.tc.ml.libsvm.LibsvmAdapter; import com.google.gson.Gson; @@ -49,7 +48,7 @@ * * For example: 1 1:1 3:1 4:1 6:1 2 2:1 3:1 5:1 7:1 1 3:1 5:1 */ -public class LibsvmDataWriter implements DataWriter { +public class LibsvmDataFormatWriter implements DataWriter { public static final String INDEX2INSTANCEID = "index2Instanceid.txt"; @@ -161,8 +160,8 @@ public void writeClassifierFormat(Collection in) throws Exception { bw = null; writeMapping(outputDirectory, INDEX2INSTANCEID, index2instanceId); - writeFeatureName2idMapping(outputDirectory, LibsvmAdapter.getFeatureNameMappingFilename(), featureNames2id); - writeOutcomeMapping(outputDirectory, LibsvmAdapter.getOutcomeMappingFilename(), outcomeMap); + writeFeatureName2idMapping(outputDirectory, AdapterFormat.getFeatureNameMappingFilename(), featureNames2id); + writeOutcomeMapping(outputDirectory, AdapterFormat.getOutcomeMappingFilename(), outcomeMap); } private void writeOutcomeMapping(File outputDirectory, String file, Map map) throws IOException { diff --git a/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmModelLoaderConnector.java b/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmModelLoaderConnector.java new file mode 100644 index 000000000..8792aa147 --- /dev/null +++ b/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmModelLoaderConnector.java @@ -0,0 +1,223 @@ +/******************************************************************************* + * Copyright 2018 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +package org.dkpro.tc.io.libsvm; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.io.FileUtils; +import org.apache.uima.UimaContext; +import org.apache.uima.analysis_engine.AnalysisEngineProcessException; +import org.apache.uima.fit.descriptor.ConfigurationParameter; +import org.apache.uima.fit.descriptor.ExternalResource; +import org.apache.uima.fit.util.JCasUtil; +import org.apache.uima.jcas.JCas; +import org.apache.uima.pear.util.FileUtil; +import org.apache.uima.resource.ResourceInitializationException; +import org.dkpro.tc.api.features.Feature; +import org.dkpro.tc.api.features.FeatureExtractorResource_ImplBase; +import org.dkpro.tc.api.features.Instance; +import org.dkpro.tc.api.type.TextClassificationOutcome; +import org.dkpro.tc.core.Constants; +import org.dkpro.tc.core.ml.ModelSerialization_ImplBase; +import org.dkpro.tc.core.util.SaveModelUtils; +import org.dkpro.tc.core.util.TaskUtils; +import org.dkpro.tc.ml.uima.TcAnnotator; + +public abstract class LibsvmModelLoaderConnector + extends ModelSerialization_ImplBase +{ + + protected static final String OUTCOME_PLACEHOLDER = "-1"; + + @ConfigurationParameter(name = TcAnnotator.PARAM_TC_MODEL_LOCATION, mandatory = true) + protected File tcModelLocation; + + @ExternalResource(key = PARAM_FEATURE_EXTRACTORS, mandatory = true) + protected FeatureExtractorResource_ImplBase[] featureExtractors; + + @ConfigurationParameter(name = PARAM_FEATURE_MODE, mandatory = true) + protected String featureMode; + + @ConfigurationParameter(name = PARAM_LEARNING_MODE, mandatory = true) + protected String learningMode; + +// private svm_model model; + + protected Map integer2OutcomeMapping; + protected Map featureMapping; + + @Override + public void initialize(UimaContext context) + throws ResourceInitializationException + { + super.initialize(context); + + try { +// model = svm +// .svm_load_model(new File(tcModelLocation, MODEL_CLASSIFIER).getAbsolutePath()); + integer2OutcomeMapping = loadInteger2OutcomeMapping(tcModelLocation); + featureMapping = loadFeature2IntegerMapping(tcModelLocation); + SaveModelUtils.verifyTcVersion(tcModelLocation, getClass()); + } + catch (Exception e) { + throw new ResourceInitializationException(e); + } + + } + + private Map loadFeature2IntegerMapping(File tcModelLocation) throws IOException { + Map map = new HashMap<>(); + List readLines = FileUtils + .readLines(new File(tcModelLocation, AdapterFormat.getFeatureNameMappingFilename()), "utf-8"); + for (String l : readLines) { + String[] split = l.split("\t"); + map.put(split[0],Integer.valueOf(split[1])); + } + return map; + } + + private Map loadInteger2OutcomeMapping(File tcModelLocation) + throws IOException + { + if(isRegression()){ + return new HashMap<>(); + } + + Map map = new HashMap<>(); + List readLines = FileUtils + .readLines(new File(tcModelLocation, AdapterFormat.getOutcomeMappingFilename()), "utf-8"); + for (String l : readLines) { + String[] split = l.split("\t"); + map.put(split[1], split[0]); + } + return map; + } + + private boolean isRegression(){ + return learningMode.equals(Constants.LM_REGRESSION); + } + + @Override + public void process(JCas jcas) + throws AnalysisEngineProcessException + { + try { + File tempFile = createInputFile(jcas); + + File prediction = runPrediction(tempFile); + + List outcomes = getOutcomeAnnotations(jcas); + List writtenPredictions = FileUtils.readLines(prediction, "utf-8"); + + checkErrorConditionNumberOfOutcomesEqualsNumberOfPredictions(outcomes, + writtenPredictions); + + for (int i = 0; i < outcomes.size(); i++) { + + if (isRegression()) { + String val = writtenPredictions.get(i); + outcomes.get(i).setOutcome(val); + } + else { + String val = writtenPredictions.get(i).replaceAll("\\.0", ""); + String pred = integer2OutcomeMapping.get(val); + outcomes.get(i).setOutcome(pred); + } + + } + + } + catch (Exception e) { + throw new AnalysisEngineProcessException(e); + } + + } + + private List getOutcomeAnnotations(JCas jcas) + { + return new ArrayList<>(JCasUtil.select(jcas, TextClassificationOutcome.class)); + } + + private void checkErrorConditionNumberOfOutcomesEqualsNumberOfPredictions( + List outcomes, List readLines) + { + if (outcomes.size() != readLines.size()) { + throw new IllegalStateException("Expected [" + outcomes.size() + + "] predictions but were [" + readLines.size() + "]"); + } + } + + protected abstract File runPrediction(File tempFile) + throws Exception; + + private File createInputFile(JCas jcas) + throws Exception + { + File tempFile = FileUtil.createTempFile("libsvm", ".txmt"); + BufferedWriter bw = new BufferedWriter( + new OutputStreamWriter(new FileOutputStream(tempFile), "utf-8")); + + List inst = TaskUtils.getMultipleInstancesUnitMode(featureExtractors, jcas, true, + true); + + for (Instance i : inst) { + bw.write(OUTCOME_PLACEHOLDER); + for (Feature f : i.getFeatures()) { + if (!sanityCheckValue(f)) { + continue; + } + bw.write("\t"); + bw.write(featureMapping.get(f.getName()) + ":" + f.getValue()); + } + bw.write("\n"); + } + bw.close(); + + return tempFile; + } + + private boolean sanityCheckValue(Feature f) + { + if (f.getValue() instanceof Number) { + return true; + } + if (f.getName().equals(Constants.ID_FEATURE_NAME)) { + return false; + } + + try { + Double.valueOf((String) f.getValue()); + } + catch (Exception e) { + throw new IllegalArgumentException( + "Feature [" + f.getName() + "] has a non-numeric value [" + f.getValue() + "]", + e); + } + return false; + } + +} \ No newline at end of file diff --git a/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmModelSerialization.java b/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmModelSerialization.java new file mode 100644 index 000000000..aec811aeb --- /dev/null +++ b/dkpro-tc-io-libsvm/src/main/java/org/dkpro/tc/io/libsvm/LibsvmModelSerialization.java @@ -0,0 +1,96 @@ +/******************************************************************************* + * Copyright 2018 + * Ubiquitous Knowledge Processing (UKP) Lab + * Technische Universität Darmstadt + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ******************************************************************************/ + +package org.dkpro.tc.io.libsvm; + +import java.io.File; +import java.io.IOException; +import java.util.List; + +import org.apache.commons.io.FileUtils; +import org.dkpro.lab.engine.TaskContext; +import org.dkpro.lab.storage.StorageService.AccessMode; +import org.dkpro.lab.task.Discriminator; +import org.dkpro.tc.api.exception.TextClassificationException; +import org.dkpro.tc.core.Constants; +import org.dkpro.tc.core.task.ModelSerializationTask; + +public abstract class LibsvmModelSerialization extends ModelSerializationTask implements Constants { + + @Discriminator(name = DIM_CLASSIFICATION_ARGS) + protected List classificationArguments; + + boolean trainModel = true; + + @Override + public void execute(TaskContext aContext) throws Exception { + trainAndStoreModel(aContext); + writeModelConfiguration(aContext); + } + + private void trainAndStoreModel(TaskContext aContext) throws Exception { + boolean multiLabel = learningMode.equals(Constants.LM_MULTI_LABEL); + + if (multiLabel) { + throw new TextClassificationException("Multi-label is not yet implemented"); + } + + File fileTrain = getTrainFile(aContext); + + trainModel(fileTrain); + + copyOutcomeMappingToThisFolder(aContext); + copyFeatureNameMappingToThisFolder(aContext); + } + + protected abstract void trainModel(File fileTrain) throws Exception; + + private void copyOutcomeMappingToThisFolder(TaskContext aContext) throws IOException { + + if (isRegression()) { + return; + } + + File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); + String mapping = AdapterFormat.getOutcomeMappingFilename(); + + FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping)); + } + + private boolean isRegression() { + return learningMode.equals(Constants.LM_REGRESSION); + } + + private void copyFeatureNameMappingToThisFolder(TaskContext aContext) throws IOException { + File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); + String mapping = AdapterFormat.getFeatureNameMappingFilename(); + + FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping)); + } + + private File getTrainFile(TaskContext aContext) { + File trainFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); + File fileTrain = new File(trainFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT); + + return fileTrain; + } + + @Override + protected abstract void writeAdapter() throws Exception; + +} \ No newline at end of file diff --git a/dkpro-tc-ml-crfsuite/src/main/java/org/dkpro/tc/ml/crfsuite/task/serialization/CRFSuiteModelSerializationDescription.java b/dkpro-tc-ml-crfsuite/src/main/java/org/dkpro/tc/ml/crfsuite/task/serialization/CRFSuiteModelSerializationDescription.java index 2c794f1c3..4b668e73c 100644 --- a/dkpro-tc-ml-crfsuite/src/main/java/org/dkpro/tc/ml/crfsuite/task/serialization/CRFSuiteModelSerializationDescription.java +++ b/dkpro-tc-ml-crfsuite/src/main/java/org/dkpro/tc/ml/crfsuite/task/serialization/CRFSuiteModelSerializationDescription.java @@ -29,74 +29,66 @@ import org.dkpro.lab.task.Discriminator; import org.dkpro.tc.core.Constants; import org.dkpro.tc.core.task.ModelSerializationTask; +import org.dkpro.tc.core.util.SaveModelUtils; import org.dkpro.tc.ml.crfsuite.CRFSuiteAdapter; import org.dkpro.tc.ml.crfsuite.task.CRFSuiteTestTask; import org.dkpro.tc.ml.crfsuite.task.CrfUtil; -public class CRFSuiteModelSerializationDescription - extends ModelSerializationTask - implements Constants -{ - - @Discriminator(name = DIM_CLASSIFICATION_ARGS) - private List classificationArguments; - - boolean trainModel = true; - - private String algoName; - - private List algoParameters; - - @Override - public void execute(TaskContext aContext) - throws Exception - { - - if (trainModel) { - processParameters(classificationArguments); - trainAndStoreModel(aContext); - } - else { - copyAlreadyTrainedModel(aContext); - } - - writeModelConfiguration(aContext, CRFSuiteAdapter.class.getName()); - } - - private void copyAlreadyTrainedModel(TaskContext aContext) - throws Exception - { - File file = aContext.getFile(MODEL_CLASSIFIER, AccessMode.READONLY); - - FileInputStream fis = new FileInputStream(file); - FileOutputStream fos = new FileOutputStream(new File(outputFolder, MODEL_CLASSIFIER)); - IOUtils.copy(fis, fos); - } - - private void trainAndStoreModel(TaskContext aContext) - throws Exception - { - File trainFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, - AccessMode.READONLY); - - String classifierPath = outputFolder.getAbsolutePath() + "/" + MODEL_CLASSIFIER; - String trainingDataPath = trainFolder.getPath() + "/" + Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT; - List commandTrainModel = CRFSuiteTestTask.getTrainCommand(classifierPath, - trainingDataPath, algoName, algoParameters); - - Process process = new ProcessBuilder().inheritIO().command(commandTrainModel).start(); - process.waitFor(); - } - - private void processParameters(List classificationArguments) - throws Exception - { - algoName = CrfUtil.getAlgorithm(classificationArguments); - algoParameters = CrfUtil.getAlgorithmConfigurationParameter(classificationArguments); - } - - public void trainModel(boolean b) - { - trainModel = b; - } +public class CRFSuiteModelSerializationDescription extends ModelSerializationTask implements Constants { + + @Discriminator(name = DIM_CLASSIFICATION_ARGS) + private List classificationArguments; + + boolean trainModel = true; + + private String algoName; + + private List algoParameters; + + @Override + public void execute(TaskContext aContext) throws Exception { + + if (trainModel) { + processParameters(classificationArguments); + trainAndStoreModel(aContext); + } else { + copyAlreadyTrainedModel(aContext); + } + + writeModelConfiguration(aContext); + } + + private void copyAlreadyTrainedModel(TaskContext aContext) throws Exception { + File file = aContext.getFile(MODEL_CLASSIFIER, AccessMode.READONLY); + + FileInputStream fis = new FileInputStream(file); + FileOutputStream fos = new FileOutputStream(new File(outputFolder, MODEL_CLASSIFIER)); + IOUtils.copy(fis, fos); + } + + private void trainAndStoreModel(TaskContext aContext) throws Exception { + File trainFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); + + String classifierPath = outputFolder.getAbsolutePath() + "/" + MODEL_CLASSIFIER; + String trainingDataPath = trainFolder.getPath() + "/" + Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT; + List commandTrainModel = CRFSuiteTestTask.getTrainCommand(classifierPath, trainingDataPath, algoName, + algoParameters); + + Process process = new ProcessBuilder().inheritIO().command(commandTrainModel).start(); + process.waitFor(); + } + + private void processParameters(List classificationArguments) throws Exception { + algoName = CrfUtil.getAlgorithm(classificationArguments); + algoParameters = CrfUtil.getAlgorithmConfigurationParameter(classificationArguments); + } + + public void trainModel(boolean b) { + trainModel = b; + } + + @Override + protected void writeAdapter() throws Exception { + SaveModelUtils.writeModelAdapterInformation(outputFolder, CRFSuiteAdapter.class.getName()); + } } \ No newline at end of file diff --git a/dkpro-tc-ml-liblinear/pom.xml b/dkpro-tc-ml-liblinear/pom.xml index 14d26b184..0ea0a32bc 100644 --- a/dkpro-tc-ml-liblinear/pom.xml +++ b/dkpro-tc-ml-liblinear/pom.xml @@ -24,42 +24,46 @@ dkpro-tc-ml-liblinear - - commons-logging - commons-logging-api - - de.bwaldvogel - liblinear + org.dkpro.tc + dkpro-tc-core + + + org.dkpro.tc + dkpro-tc-api-features - junit - junit - test + org.dkpro.tc + dkpro-tc-api org.dkpro.tc - dkpro-tc-api-features + dkpro-tc-ml org.dkpro.tc - dkpro-tc-api + dkpro-tc-io-libsvm commons-io commons-io + + commons-logging + commons-logging-api + - org.dkpro.tc - dkpro-tc-core + de.bwaldvogel + liblinear - org.dkpro.lab - dkpro-lab-core + junit + junit + test - org.dkpro.tc - dkpro-tc-ml + org.dkpro.lab + dkpro-lab-core org.apache.uima diff --git a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/LiblinearAdapter.java b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/LiblinearAdapter.java index c5feba1e1..049704847 100644 --- a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/LiblinearAdapter.java +++ b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/LiblinearAdapter.java @@ -28,10 +28,10 @@ import org.dkpro.tc.core.ml.ModelSerialization_ImplBase; import org.dkpro.tc.core.ml.TcShallowLearningAdapter; import org.dkpro.tc.core.task.ModelSerializationTask; -import org.dkpro.tc.ml.liblinear.report.LiblinearOutcomeIdReport; +import org.dkpro.tc.io.libsvm.LibsvmDataFormatOutcomeIdReport; +import org.dkpro.tc.io.libsvm.LibsvmDataFormatWriter; import org.dkpro.tc.ml.liblinear.serialization.LiblinearModelSerializationDescription; import org.dkpro.tc.ml.liblinear.serialization.LoadModelConnectorLiblinear; -import org.dkpro.tc.ml.liblinear.writer.LiblinearDataWriter; import org.dkpro.tc.ml.report.InnerBatchReport; /** @@ -73,14 +73,6 @@ public static TcShallowLearningAdapter getInstance() { return new LiblinearAdapter(); } - public static String getOutcomeMappingFilename() { - return "outcome-mapping.txt"; - } - - public static String getFeatureNameMappingFilename() { - return "feature-name-mapping.txt"; - } - @Override public ExecutableTaskBase getTestTask() { return new LiblinearTestTask(); @@ -88,7 +80,7 @@ public ExecutableTaskBase getTestTask() { @Override public Class getOutcomeIdReportClass() { - return LiblinearOutcomeIdReport.class; + return LibsvmDataFormatOutcomeIdReport.class; } @Override @@ -105,7 +97,7 @@ public DimensionBundle> getFoldDimensionBundle( @Override public Class getDataWriterClass() { - return LiblinearDataWriter.class; + return LibsvmDataFormatWriter.class; } @Override diff --git a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/report/LiblinearOutcomeIdReport.java b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/report/LiblinearOutcomeIdReport.java deleted file mode 100644 index b0710c76a..000000000 --- a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/report/LiblinearOutcomeIdReport.java +++ /dev/null @@ -1,196 +0,0 @@ -/******************************************************************************* - * Copyright 2018 - * Ubiquitous Knowledge Processing (UKP) Lab - * Technische Universität Darmstadt - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ******************************************************************************/ -package org.dkpro.tc.ml.liblinear.report; - -import java.io.File; -import java.io.IOException; -import java.io.UnsupportedEncodingException; -import java.net.URLEncoder; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -import org.apache.commons.io.FileUtils; -import org.apache.commons.io.output.FileWriterWithEncoding; -import org.dkpro.lab.reporting.ReportBase; -import org.dkpro.lab.storage.StorageService; -import org.dkpro.lab.storage.StorageService.AccessMode; -import org.dkpro.tc.core.Constants; -import org.dkpro.tc.core.task.InitTask; -import org.dkpro.tc.ml.liblinear.LiblinearAdapter; -import org.dkpro.tc.ml.liblinear.LiblinearTestTask; -import org.dkpro.tc.ml.liblinear.writer.LiblinearDataWriter; -import org.dkpro.tc.ml.report.util.SortedKeyProperties; - -/** - * Creates id 2 outcome report - */ -public class LiblinearOutcomeIdReport - extends ReportBase - implements Constants -{ - // constant dummy value for setting as threshold which is an expected field in the evaluation - // module but is not needed/provided by liblinear - private static final String THRESHOLD_CONSTANT = "-1"; - - public LiblinearOutcomeIdReport(){ - //required by groovy - } - - @Override - public void execute() - throws Exception - { - boolean isRegression = getDiscriminators() - .get(LiblinearTestTask.class.getName() + "|" + Constants.DIM_LEARNING_MODE) - .equals(Constants.LM_REGRESSION); - - boolean isUnit = getDiscriminators() - .get(InitTask.class.getName() + "|" + Constants.DIM_FEATURE_MODE) - .equals(Constants.FM_UNIT); - - Map id2label = getId2LabelMapping(isRegression); - String header = buildHeader(id2label, isRegression); - - List predictions = readPredictions(); - Map index2instanceIdMap = getMapping(isUnit); - - Properties prop = new SortedKeyProperties(); - int lineCounter = 0; - for (String line : predictions) { - if (line.startsWith("#")) { - continue; - } - String[] split = line.split(";"); - String key = index2instanceIdMap.get(lineCounter+""); - - if (isRegression){ - prop.setProperty(key, - split[0] + ";" + split[1] + ";" + THRESHOLD_CONSTANT); - }else{ - int pred = Double.valueOf(split[0]).intValue(); - int gold = Double.valueOf(split[1]).intValue(); - prop.setProperty(key, - pred + ";" + gold + ";" + THRESHOLD_CONSTANT); - } - lineCounter++; - } - - File targetFile = getId2OutcomeFileLocation(); - - FileWriterWithEncoding fw = new FileWriterWithEncoding(targetFile, "utf-8"); - prop.store(fw, header); - fw.close(); - - } - -private Map getMapping(boolean isUnit) throws IOException { - - File f; - if (isUnit) { - f = new File(getContext().getFolder(TEST_TASK_INPUT_KEY_TEST_DATA, AccessMode.READONLY), - LiblinearDataWriter.INDEX2INSTANCEID); - } else { - f = new File(getContext().getFolder(TEST_TASK_INPUT_KEY_TEST_DATA, AccessMode.READONLY), - Constants.FILENAME_DOCUMENT_META_DATA_LOG); - } - - Map m = new HashMap<>(); - - int idx=0; - for (String l : FileUtils.readLines(f, "utf-8")) { - if (l.startsWith("#")) { - continue; - } - if (l.trim().isEmpty()) { - continue; - } - String[] split = l.split("\t"); - -// if (isUnit) { - m.put(idx + "", split[0]); - idx++; -// } else { -// m.put(split[0], split[1]); -// } - - } - return m; - } - - private File getId2OutcomeFileLocation() - { - File evaluationFolder = getContext().getFolder("", AccessMode.READWRITE); - return new File(evaluationFolder, ID_OUTCOME_KEY); - } - - private List readPredictions() - throws IOException - { - File predFolder = getContext().getFolder("", AccessMode.READWRITE); - return FileUtils.readLines(new File(predFolder, Constants.FILENAME_PREDICTIONS), "utf-8"); - } - - private String buildHeader(Map id2label, boolean isRegression) - throws UnsupportedEncodingException - { - StringBuilder header = new StringBuilder(); - header.append("ID=PREDICTION;GOLDSTANDARD;THRESHOLD" + "\n" + "labels" + " "); - - if (isRegression) { - // no label mapping for regression so that is all we have to do - return header.toString(); - } - - int numKeys = id2label.keySet().size(); - List keys = new ArrayList(id2label.keySet()); - for (int i = 0; i < numKeys; i++) { - Integer key = keys.get(i); - header.append(key + "=" + URLEncoder.encode(id2label.get(key), "UTF-8")); - if (i + 1 < numKeys) { - header.append(" "); - } - } - return header.toString(); - } - - private Map getId2LabelMapping(boolean isRegression) - throws Exception - { - if(isRegression){ - //no map for regression; - return new HashMap<>(); - } - - File mappingFolder = getContext().getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, StorageService.AccessMode.READONLY); - String fileName = LiblinearAdapter.getOutcomeMappingFilename(); - File file = new File(mappingFolder, fileName); - Map map = new HashMap(); - - List lines = FileUtils.readLines(file, "utf-8"); - for (String line : lines) { - String[] split = line.split("\t"); - map.put(Integer.valueOf(split[1]), split[0]); - } - - return map; - } - -} \ No newline at end of file diff --git a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LiblinearModelSerializationDescription.java b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LiblinearModelSerializationDescription.java index ea535787e..bbffbce2a 100644 --- a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LiblinearModelSerializationDescription.java +++ b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LiblinearModelSerializationDescription.java @@ -19,15 +19,10 @@ package org.dkpro.tc.ml.liblinear.serialization; import java.io.File; -import java.io.IOException; -import java.util.List; -import org.apache.commons.io.FileUtils; -import org.dkpro.lab.engine.TaskContext; -import org.dkpro.lab.storage.StorageService.AccessMode; -import org.dkpro.lab.task.Discriminator; import org.dkpro.tc.core.Constants; -import org.dkpro.tc.core.task.ModelSerializationTask; +import org.dkpro.tc.core.util.SaveModelUtils; +import org.dkpro.tc.io.libsvm.LibsvmModelSerialization; import org.dkpro.tc.ml.liblinear.LiblinearAdapter; import org.dkpro.tc.ml.liblinear.util.LiblinearUtils; @@ -37,27 +32,16 @@ import de.bwaldvogel.liblinear.Problem; import de.bwaldvogel.liblinear.SolverType; -public class LiblinearModelSerializationDescription extends ModelSerializationTask implements Constants { +public class LiblinearModelSerializationDescription extends LibsvmModelSerialization implements Constants { - @Discriminator(name = DIM_CLASSIFICATION_ARGS) - private List classificationArguments; - boolean trainModel = true; - @Override - public void execute(TaskContext aContext) throws Exception { - trainAndStoreModel(aContext); - - writeModelConfiguration(aContext, LiblinearAdapter.class.getName()); + public void trainModel(boolean b) { + trainModel = b; } - private void trainAndStoreModel(TaskContext aContext) throws Exception { - // create mapping and persist mapping - File fileTrain = getTrainFile(aContext); - - copyFeatureNameMappingToThisFolder(aContext); - copyOutcomeMappingToThisFolder(aContext); - + @Override + protected void trainModel(File fileTrain) throws Exception { SolverType solver = LiblinearUtils.getSolver(classificationArguments); double C = LiblinearUtils.getParameterC(classificationArguments); double eps = LiblinearUtils.getParameterEpsilon(classificationArguments); @@ -70,36 +54,8 @@ private void trainAndStoreModel(TaskContext aContext) throws Exception { model.save(new File(outputFolder, MODEL_CLASSIFIER)); } - private void copyOutcomeMappingToThisFolder(TaskContext aContext) throws IOException { - if(isRegression()){ - return; - } - - File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); - String mapping = LiblinearAdapter.getOutcomeMappingFilename(); - - FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping)); - } - - private boolean isRegression() { - return learningMode.equals(Constants.LM_REGRESSION); - } - - private void copyFeatureNameMappingToThisFolder(TaskContext aContext) throws IOException { - File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); - String mapping = LiblinearAdapter.getFeatureNameMappingFilename(); - - FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping)); - } - - private File getTrainFile(TaskContext aContext) { - File trainFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); - File fileTrain = new File(trainFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT); - - return fileTrain; - } - - public void trainModel(boolean b) { - trainModel = b; + @Override + protected void writeAdapter() throws Exception { + SaveModelUtils.writeModelAdapterInformation(outputFolder, LiblinearAdapter.class.getName()); } } \ No newline at end of file diff --git a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LoadModelConnectorLiblinear.java b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LoadModelConnectorLiblinear.java index a388b2c8d..d8a4145f5 100644 --- a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LoadModelConnectorLiblinear.java +++ b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/serialization/LoadModelConnectorLiblinear.java @@ -20,182 +20,54 @@ import static org.dkpro.tc.core.Constants.MODEL_CLASSIFIER; +import java.io.BufferedWriter; import java.io.File; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.io.FileOutputStream; +import java.io.OutputStreamWriter; -import org.apache.commons.io.FileUtils; import org.apache.uima.UimaContext; -import org.apache.uima.analysis_engine.AnalysisEngineProcessException; -import org.apache.uima.fit.descriptor.ConfigurationParameter; -import org.apache.uima.fit.descriptor.ExternalResource; -import org.apache.uima.fit.util.JCasUtil; -import org.apache.uima.jcas.JCas; import org.apache.uima.resource.ResourceInitializationException; -import org.dkpro.tc.api.features.FeatureExtractorResource_ImplBase; -import org.dkpro.tc.api.features.Instance; -import org.dkpro.tc.api.type.TextClassificationOutcome; -import org.dkpro.tc.core.Constants; -import org.dkpro.tc.core.ml.ModelSerialization_ImplBase; -import org.dkpro.tc.core.util.SaveModelUtils; -import org.dkpro.tc.core.util.TaskUtils; -import org.dkpro.tc.ml.liblinear.LiblinearAdapter; -import org.dkpro.tc.ml.uima.TcAnnotator; +import org.dkpro.tc.io.libsvm.LibsvmModelLoaderConnector; import de.bwaldvogel.liblinear.Feature; import de.bwaldvogel.liblinear.Linear; import de.bwaldvogel.liblinear.Model; import de.bwaldvogel.liblinear.Problem; -public class LoadModelConnectorLiblinear extends ModelSerialization_ImplBase { - - @ConfigurationParameter(name = TcAnnotator.PARAM_TC_MODEL_LOCATION, mandatory = true) - private File tcModelLocation; - - @ExternalResource(key = PARAM_FEATURE_EXTRACTORS, mandatory = true) - protected FeatureExtractorResource_ImplBase[] featureExtractors; - - @ConfigurationParameter(name = PARAM_FEATURE_MODE, mandatory = true) - private String featureMode; - - @ConfigurationParameter(name = PARAM_LEARNING_MODE, mandatory = true) - private String learningMode; +public class LoadModelConnectorLiblinear extends LibsvmModelLoaderConnector { private Model liblinearModel; - private Map outcomeMapping; - - private Map featureMapping; - @Override public void initialize(UimaContext context) throws ResourceInitializationException { super.initialize(context); try { liblinearModel = Linear.loadModel(new File(tcModelLocation, MODEL_CLASSIFIER)); - outcomeMapping = loadOutcome2IntegerMapping(tcModelLocation); - featureMapping = loadFeature2IntegerMapping(tcModelLocation); - SaveModelUtils.verifyTcVersion(tcModelLocation, getClass()); } catch (Exception e) { throw new ResourceInitializationException(e); } - - } - - - private Map loadFeature2IntegerMapping(File tcModelLocation) throws IOException { - Map map = new HashMap<>(); - List readLines = FileUtils - .readLines(new File(tcModelLocation, LiblinearAdapter.getFeatureNameMappingFilename()), "utf-8"); - for (String l : readLines) { - String[] split = l.split("\t"); - map.put(split[0],Integer.valueOf(split[1])); - } - return map; } - private Map loadOutcome2IntegerMapping(File tcModelLocation) throws IOException { + @Override + protected File runPrediction(File infile) throws Exception { - if (isRegression()){ - return new HashMap<>(); - } + Problem predictionProblem = Problem.readFromFile(infile, 1.0); - Map map = new HashMap<>(); - List readLines = FileUtils - .readLines(new File(tcModelLocation, LiblinearAdapter.getOutcomeMappingFilename()), "utf-8"); - for (String l : readLines) { - String[] split = l.split("\t"); - map.put(Integer.valueOf(split[1]), split[0]); - } - return map; - } - - private boolean isRegression() { - return learningMode.equals(Constants.LM_REGRESSION); - } - - - private Double toValue(Object value) - { - double v; - if (value instanceof Number) { - v = ((Number) value).doubleValue(); - } - else { - v = 1.0; - } - - return v; - } - - @Override - public void process(JCas jcas) throws AnalysisEngineProcessException { - try { - List instances = TaskUtils.getMultipleInstancesUnitMode(featureExtractors, jcas, true, - new LiblinearAdapter().useSparseFeatures()); - - StringBuilder sb = new StringBuilder(); - for (Instance inst : instances) { - Map entry = new HashMap<>(); - for (org.dkpro.tc.api.features.Feature f : inst.getFeatures()) { - Integer id = featureMapping.get(f.getName()); - Double val = toValue(f.getValue()); - - if (Math.abs(val) < 0.00000001) { - // skip zero values - continue; - } - - entry.put(id, val); - } - List keys = new ArrayList(entry.keySet()); - Collections.sort(keys); - - sb.append("-1\t"); // dummy label - - sb.append("1:1.0\t"); //bias entry - - for (int i = 0; i < keys.size(); i++) { - Integer key = keys.get(i); - Double value = entry.get(key); - sb.append("" + key.toString() + ":" + value.toString()); - if (i + 1 < keys.size()) { - sb.append("\t"); - } - } - sb.append("\n"); - } - - File inputData = File.createTempFile("libLinearePrediction", - Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT); - inputData.deleteOnExit(); - FileUtils.writeStringToFile(inputData, sb.toString(), "utf-8"); - - Problem predictionProblem = Problem.readFromFile(inputData, 1.0); - - List outcomes = new ArrayList<>( - JCasUtil.select(jcas, TextClassificationOutcome.class)); - Feature[][] testInstances = predictionProblem.x; - for (int i = 0; i < testInstances.length; i++) { - Feature[] instance = testInstances[i]; - Double prediction = Linear.predict(liblinearModel, instance); - - if (isRegression()) { - outcomes.get(i).setOutcome(prediction.toString()); - } else { - String predictedLabel = outcomeMapping.get(prediction.intValue()); - outcomes.get(i).setOutcome(predictedLabel); - } - } - - } catch (Exception e) { - throw new AnalysisEngineProcessException(e); + File tmp = File.createTempFile("libLinearePrediction",".txt"); + + BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(tmp), "utf-8")); + Feature[][] testInstances = predictionProblem.x; + for (int i = 0; i < testInstances.length; i++) { + Feature[] instance = testInstances[i]; + Double prediction = Linear.predict(liblinearModel, instance); + writer.write(prediction.toString() + "\n"); } - + + writer.close(); + + tmp.deleteOnExit(); + return tmp; } } \ No newline at end of file diff --git a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/writer/LiblinearDataWriter.java b/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/writer/LiblinearDataWriter.java deleted file mode 100644 index a97afbcdf..000000000 --- a/dkpro-tc-ml-liblinear/src/main/java/org/dkpro/tc/ml/liblinear/writer/LiblinearDataWriter.java +++ /dev/null @@ -1,314 +0,0 @@ -/******************************************************************************* - * Copyright 2018 - * Ubiquitous Knowledge Processing (UKP) Lab - * Technische Universität Darmstadt - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ******************************************************************************/ -package org.dkpro.tc.ml.liblinear.writer; - -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStreamReader; -import java.io.OutputStreamWriter; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.apache.commons.io.FileUtils; -import org.dkpro.tc.api.features.Feature; -import org.dkpro.tc.api.features.Instance; -import org.dkpro.tc.core.Constants; -import org.dkpro.tc.core.io.DataWriter; -import org.dkpro.tc.ml.liblinear.LiblinearAdapter; - -import com.google.gson.Gson; - -/** - * Format is outcome TAB index:value TAB index:value TAB ... - * - * Zeros are omitted. Indexes need to be sorted. - * - * For example: 1 1:1 3:1 4:1 6:1 2 2:1 3:1 5:1 7:1 1 3:1 5:1 - */ -public class LiblinearDataWriter implements DataWriter { - - public static final String INDEX2INSTANCEID = "index2Instanceid.txt"; - - File outputDirectory; - - boolean useSparse; - - String learningMode; - - boolean applyWeighting; - - File classifierFormatOutputFile; - - BufferedWriter bw = null; - - Map index2instanceId; - - Gson gson = new Gson(); - - private int maxId = 0; - - Map featureNames2id; - - Map outcomeMap; - - @Override - public void writeGenericFormat(Collection instances) throws Exception { - initGeneric(); - - // bulk-write - in sequence mode this keeps the instances together that - // belong to the same sequence! - Instance[] array = instances.toArray(new Instance[0]); - bw.write(gson.toJson(array) + System.lineSeparator()); - - bw.close(); - bw = null; - } - - private void initGeneric() throws IOException { - if (bw != null) { - return; - } - bw = new BufferedWriter(new OutputStreamWriter( - new FileOutputStream(new File(outputDirectory, Constants.GENERIC_FEATURE_FILE), true), "utf-8")); - } - - @Override - public void transformFromGeneric() throws Exception { - BufferedReader reader = new BufferedReader(new InputStreamReader( - new FileInputStream(new File(outputDirectory, Constants.GENERIC_FEATURE_FILE)), "utf-8")); - - String line = null; - while ((line = reader.readLine()) != null) { - Instance[] instance = gson.fromJson(line, Instance[].class); - List ins = new ArrayList<>(Arrays.asList(instance)); - writeClassifierFormat(ins); - } - - reader.close(); - FileUtils.deleteQuietly(new File(outputDirectory, Constants.GENERIC_FEATURE_FILE)); - } - - @Override - public void writeClassifierFormat(Collection in) throws Exception { - - if (featureNames2id == null) { - createFeatureNameMap(); - } - - initClassifierFormat(); - - List instances = new ArrayList<>(in); - - for (Instance inst : instances) { - Map entry = new HashMap<>(); - recordInstanceId(inst, maxId++, index2instanceId); - for (Feature f : inst.getFeatures()) { - Integer id = featureNames2id.get(f.getName()); - Double val = toValue(f.getValue()); - - if (Math.abs(val) < 0.00000001) { - // skip zero values - continue; - } - - entry.put(id, val); - } - List keys = new ArrayList(entry.keySet()); - Collections.sort(keys); - - - if (isRegression()) { - bw.append(inst.getOutcome() + "\t"); - } else { - bw.append(outcomeMap.get(inst.getOutcome()) + "\t"); - } - for (int i = 0; i < keys.size(); i++) { - Integer key = keys.get(i); - Double value = entry.get(key); - bw.append("" + key.toString() + ":" + value.toString()); - if (i + 1 < keys.size()) { - bw.append("\t"); - } - } - bw.append("\n"); - } - - bw.close(); - bw = null; - - writeMapping(outputDirectory, INDEX2INSTANCEID, index2instanceId); - writeFeatureName2idMapping(outputDirectory, LiblinearAdapter.getFeatureNameMappingFilename(), featureNames2id); - writeOutcomeMapping(outputDirectory, LiblinearAdapter.getOutcomeMappingFilename(), outcomeMap); - } - - private void writeOutcomeMapping(File outputDirectory, String file, Map map) throws IOException { - - if(isRegression()){ - return; - } - - StringBuilder sb = new StringBuilder(); - for (String k : map.keySet()) { - sb.append(k + "\t" + map.get(k) + "\n"); - } - - FileUtils.writeStringToFile(new File(outputDirectory, file), sb.toString(), "utf-8"); - } - - private Double toValue(Object value) { - double v; - if (value instanceof Number) { - v = ((Number) value).doubleValue(); - } else { - v = 1.0; - } - - return v; - } - - private void createFeatureNameMap() throws IOException { - featureNames2id = new HashMap<>(); - List readLines = FileUtils.readLines(new File(outputDirectory, Constants.FILENAME_FEATURES), "utf-8"); - - // add a "bias" feature node; otherwise LIBLINEAR is unable to predict - // the majority class for - // instances consisting entirely of features never seen during training - featureNames2id.put("x.BIAS", 1); - - Integer i = 2; - for (String l : readLines) { - if (l.isEmpty()) { - continue; - } - featureNames2id.put(l, i++); - } - } - - private void writeFeatureName2idMapping(File outputDirectory2, String featurename2instanceid2, - Map stringToInt) throws IOException { - StringBuilder sb = new StringBuilder(); - for (String k : stringToInt.keySet()) { - sb.append(k + "\t" + stringToInt.get(k) + "\n"); - } - FileUtils.writeStringToFile(new File(outputDirectory, featurename2instanceid2), sb.toString(), "utf-8"); - } - - private void initClassifierFormat() throws Exception { - if (bw != null) { - return; - } - - bw = new BufferedWriter( - new OutputStreamWriter(new FileOutputStream(classifierFormatOutputFile, true), "utf-8")); - } - - @Override - public void init(File outputDirectory, boolean useSparse, String learningMode, boolean applyWeighting, - String[] outcomes) throws Exception { - this.outputDirectory = outputDirectory; - this.useSparse = useSparse; - this.learningMode = learningMode; - this.applyWeighting = applyWeighting; - classifierFormatOutputFile = new File(outputDirectory, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT); - - index2instanceId = new HashMap<>(); - - // Caution: DKPro Lab imports (aka copies!) the data of the train task - // as test task. We use - // appending mode for streaming. We might append the old training file - // with - // testing data! - // Force delete the old training file to make sure we start with a - // clean, empty file - if (classifierFormatOutputFile.exists()) { - FileUtils.forceDelete(classifierFormatOutputFile); - } - - buildOutcomeMap(outcomes); - } - - /** - * Creates a mapping from the label names to integer values to identify - * labels by integers - * - * @param outcomes - */ - private void buildOutcomeMap(String[] outcomes) { - if(isRegression()){ - return; - } - outcomeMap = new HashMap<>(); - Integer i = 0; - List outcomesSorted = new ArrayList<>(Arrays.asList(outcomes)); - Collections.sort(outcomesSorted); - for (String o : outcomesSorted) { - outcomeMap.put(o, i++); - } - } - - @Override - public boolean canStream() { - return true; - } - - @Override - public String getGenericFileName() { - return Constants.GENERIC_FEATURE_FILE; - } - - private void writeMapping(File outputDirectory, String fileName, Map index2instanceId) - throws IOException { - StringBuilder sb = new StringBuilder(); - sb.append("#Index\tDkProInstanceId\n"); - for (String k : index2instanceId.keySet()) { - sb.append(k + "\t" + index2instanceId.get(k) + "\n"); - } - FileUtils.writeStringToFile(new File(outputDirectory, fileName), sb.toString(), "utf-8"); - } - - // build a map between the dkpro instance id and the index in the file - private void recordInstanceId(Instance instance, int i, Map index2instanceId) { - Collection features = instance.getFeatures(); - for (Feature f : features) { - if (!f.getName().equals(Constants.ID_FEATURE_NAME)) { - continue; - } - index2instanceId.put(i + "", f.getValue() + ""); - return; - } - } - - private boolean isRegression(){ - return learningMode.equals(Constants.LM_REGRESSION); - } - - @Override - public void close() throws Exception { - - } - -} \ No newline at end of file diff --git a/dkpro-tc-ml-liblinear/src/test/java/org/dkpro/tc/ml/liblinear/LiblinearDataWriterTest.java b/dkpro-tc-ml-liblinear/src/test/java/org/dkpro/tc/ml/liblinear/LiblinearDataWriterTest.java index da9ac0734..1ca5887ae 100644 --- a/dkpro-tc-ml-liblinear/src/test/java/org/dkpro/tc/ml/liblinear/LiblinearDataWriterTest.java +++ b/dkpro-tc-ml-liblinear/src/test/java/org/dkpro/tc/ml/liblinear/LiblinearDataWriterTest.java @@ -27,7 +27,7 @@ import org.dkpro.tc.api.features.Feature; import org.dkpro.tc.api.features.Instance; import org.dkpro.tc.core.Constants; -import org.dkpro.tc.ml.liblinear.writer.LiblinearDataWriter; +import org.dkpro.tc.io.libsvm.LibsvmDataFormatWriter; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -64,7 +64,7 @@ public void dataWriterTest() throws Exception { sb.append("feature2\n"); FileUtils.writeStringToFile(new File(outputDirectory, Constants.FILENAME_FEATURES), sb.toString(), "utf-8"); File outputFile = new File(outputDirectory, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT); - LiblinearDataWriter writer = new LiblinearDataWriter(); + LibsvmDataFormatWriter writer = new LibsvmDataFormatWriter(); writer.init(outputDirectory, false, Constants.LM_SINGLE_LABEL, false, new String[]{"0", "1"}); writer.writeClassifierFormat(fs); diff --git a/dkpro-tc-ml-libsvm/pom.xml b/dkpro-tc-ml-libsvm/pom.xml index 939b64c07..f53289e0f 100644 --- a/dkpro-tc-ml-libsvm/pom.xml +++ b/dkpro-tc-ml-libsvm/pom.xml @@ -7,25 +7,13 @@ 1.0.0-SNAPSHOT - - commons-logging - commons-logging-api - - - com.datumbox - libsvm - - - commons-io - commons-io - org.dkpro.tc - dkpro-tc-ml + dkpro-tc-core - + org.dkpro.tc - dkpro-tc-core + dkpro-tc-api org.dkpro.tc @@ -33,8 +21,12 @@ org.dkpro.tc - dkpro-tc-api + dkpro-tc-io-libsvm + + org.dkpro.tc + dkpro-tc-ml + org.apache.uima uimafit-core @@ -43,6 +35,18 @@ org.dkpro.lab dkpro-lab-core + + commons-logging + commons-logging-api + + + com.datumbox + libsvm + + + commons-io + commons-io + org.apache.uima uimaj-core diff --git a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/LibsvmAdapter.java b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/LibsvmAdapter.java index e4335311f..63d73d308 100644 --- a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/LibsvmAdapter.java +++ b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/LibsvmAdapter.java @@ -28,10 +28,10 @@ import org.dkpro.tc.core.ml.ModelSerialization_ImplBase; import org.dkpro.tc.core.ml.TcShallowLearningAdapter; import org.dkpro.tc.core.task.ModelSerializationTask; -import org.dkpro.tc.ml.libsvm.report.LibsvmOutcomeIdReport; +import org.dkpro.tc.io.libsvm.LibsvmDataFormatWriter; +import org.dkpro.tc.io.libsvm.LibsvmDataFormatOutcomeIdReport; import org.dkpro.tc.ml.libsvm.serialization.LibsvmModelSerializationDescription; import org.dkpro.tc.ml.libsvm.serialization.LoadModelConnectorLibsvm; -import org.dkpro.tc.ml.libsvm.writer.LibsvmDataWriter; import org.dkpro.tc.ml.report.InnerBatchReport; /** @@ -86,7 +86,7 @@ public ExecutableTaskBase getTestTask() @Override public Class getOutcomeIdReportClass() { - return LibsvmOutcomeIdReport.class; + return LibsvmDataFormatOutcomeIdReport.class; } @Override @@ -105,7 +105,7 @@ public DimensionBundle> getFoldDimensionBundle(String[] files @Override public Class getDataWriterClass() { - return LibsvmDataWriter.class; + return LibsvmDataFormatWriter.class; } @Override diff --git a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/LibsvmUtils.java b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/LibsvmUtils.java deleted file mode 100644 index 86790ed6d..000000000 --- a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/LibsvmUtils.java +++ /dev/null @@ -1,84 +0,0 @@ -/******************************************************************************* - * Copyright 2018 - * Ubiquitous Knowledge Processing (UKP) Lab - * Technische Universität Darmstadt - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ******************************************************************************/ -package org.dkpro.tc.ml.libsvm; - -import java.io.BufferedReader; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStreamReader; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Set; - -public class LibsvmUtils -{ - - public static String outcomeMap2String(Map map) - { - StringBuilder sb = new StringBuilder(); - for (Entry entry : map.entrySet()) { - sb.append(entry.getKey()); - sb.append("\t"); - sb.append(entry.getValue()); - sb.append("\n"); - } - - return sb.toString(); - } - - public static Map createMapping(File... files) - throws IOException - { - Set uniqueOutcomes = new HashSet<>(); - for (File f : files) { - uniqueOutcomes.addAll(pickOutcomes(f)); - } - - Map mapping = new HashMap<>(); - int id = 0; - for (String o : uniqueOutcomes) { - mapping.put(o, id++); - } - - return mapping; - } - - private static Collection pickOutcomes(File file) - throws IOException - { - Set outcomes = new HashSet<>(); - - BufferedReader br = new BufferedReader( - new InputStreamReader(new FileInputStream(file), "utf-8")); - - String line = null; - while (((line = br.readLine()) != null)) { - if (line.isEmpty()) { - continue; - } - int firstTabIdx = line.indexOf("\t"); - outcomes.add(line.substring(0, firstTabIdx)); - } - br.close(); - return outcomes; - } -} diff --git a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/report/LibsvmOutcomeIdReport.java b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/report/LibsvmOutcomeIdReport.java deleted file mode 100644 index 172698a98..000000000 --- a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/report/LibsvmOutcomeIdReport.java +++ /dev/null @@ -1,195 +0,0 @@ -/******************************************************************************* - * Copyright 2018 - * Ubiquitous Knowledge Processing (UKP) Lab - * Technische Universität Darmstadt - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ******************************************************************************/ -package org.dkpro.tc.ml.libsvm.report; - -import java.io.File; -import java.io.IOException; -import java.io.UnsupportedEncodingException; -import java.net.URLEncoder; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -import org.apache.commons.io.FileUtils; -import org.apache.commons.io.output.FileWriterWithEncoding; -import org.dkpro.lab.reporting.ReportBase; -import org.dkpro.lab.storage.StorageService.AccessMode; -import org.dkpro.tc.core.Constants; -import org.dkpro.tc.core.task.InitTask; -import org.dkpro.tc.ml.libsvm.LibsvmAdapter; -import org.dkpro.tc.ml.libsvm.LibsvmTestTask; -import org.dkpro.tc.ml.libsvm.writer.LibsvmDataWriter; -import org.dkpro.tc.ml.report.util.SortedKeyProperties; - -/** - * Creates id 2 outcome report - */ -public class LibsvmOutcomeIdReport - extends ReportBase - implements Constants -{ - // constant dummy value for setting as threshold which is an expected field in the evaluation - // module but is not needed/provided by liblinear - private static final String THRESHOLD_CONSTANT = "-1"; - - public LibsvmOutcomeIdReport(){ - //required by groovy - } - - @Override - public void execute() - throws Exception - { - boolean isRegression = getDiscriminators() - .get(LibsvmTestTask.class.getName() + "|" + Constants.DIM_LEARNING_MODE) - .equals(Constants.LM_REGRESSION); - - boolean isUnit = getDiscriminators() - .get(InitTask.class.getName() + "|" + Constants.DIM_FEATURE_MODE) - .equals(Constants.FM_UNIT); - - Map id2label = getId2LabelMapping(isRegression); - String header = buildHeader(id2label, isRegression); - - List predictions = readPredictions(); - Map index2instanceIdMap = getMapping(isUnit); - - Properties prop = new SortedKeyProperties(); - int lineCounter = 0; - for (String line : predictions) { - if (line.startsWith("#")) { - continue; - } - String[] split = line.split(";"); - String key = index2instanceIdMap.get(lineCounter+""); - - if (isRegression){ - prop.setProperty(key, - split[0] + ";" + split[1] + ";" + THRESHOLD_CONSTANT); - }else{ - int pred = Double.valueOf(split[0]).intValue(); - int gold = Double.valueOf(split[1]).intValue(); - prop.setProperty(key, - pred + ";" + gold + ";" + THRESHOLD_CONSTANT); - } - lineCounter++; - } - - File targetFile = getId2OutcomeFileLocation(); - - FileWriterWithEncoding fw = new FileWriterWithEncoding(targetFile, "utf-8"); - prop.store(fw, header); - fw.close(); - - } - -private Map getMapping(boolean isUnit) throws IOException { - - File f; - if (isUnit) { - f = new File(getContext().getFolder(TEST_TASK_INPUT_KEY_TEST_DATA, AccessMode.READONLY), - LibsvmDataWriter.INDEX2INSTANCEID); - } else { - f = new File(getContext().getFolder(TEST_TASK_INPUT_KEY_TEST_DATA, AccessMode.READONLY), - Constants.FILENAME_DOCUMENT_META_DATA_LOG); - } - - Map m = new HashMap<>(); - - int idx=0; - for (String l : FileUtils.readLines(f, "utf-8")) { - if (l.startsWith("#")) { - continue; - } - if (l.trim().isEmpty()) { - continue; - } - String[] split = l.split("\t"); - -// if (isUnit) { - m.put(idx + "", split[0]); - idx++; -// } else { -// m.put(split[0], split[1]); -// } - - } - return m; - } - - private File getId2OutcomeFileLocation() - { - File evaluationFolder = getContext().getFolder("", AccessMode.READWRITE); - return new File(evaluationFolder, ID_OUTCOME_KEY); - } - - private List readPredictions() - throws IOException - { - File predFolder = getContext().getFolder("", AccessMode.READWRITE); - return FileUtils.readLines(new File(predFolder, Constants.FILENAME_PREDICTIONS), "utf-8"); - } - - private String buildHeader(Map id2label, boolean isRegression) - throws UnsupportedEncodingException - { - StringBuilder header = new StringBuilder(); - header.append("ID=PREDICTION;GOLDSTANDARD;THRESHOLD" + "\n" + "labels" + " "); - - if (isRegression) { - // no label mapping for regression so that is all we have to do - return header.toString(); - } - - int numKeys = id2label.keySet().size(); - List keys = new ArrayList(id2label.keySet()); - for (int i = 0; i < numKeys; i++) { - Integer key = keys.get(i); - header.append(key + "=" + URLEncoder.encode(id2label.get(key), "UTF-8")); - if (i + 1 < numKeys) { - header.append(" "); - } - } - return header.toString(); - } - - private Map getId2LabelMapping(boolean isRegression) - throws Exception - { - if(isRegression){ - //no map for regression; - return new HashMap<>(); - } - - File folder = getContext().getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); - String fileName = LibsvmAdapter.getOutcomeMappingFilename(); - File file = new File(folder, fileName); - Map map = new HashMap(); - - List lines = FileUtils.readLines(file, "utf-8"); - for (String line : lines) { - String[] split = line.split("\t"); - map.put(Integer.valueOf(split[1]), split[0]); - } - - return map; - } - -} \ No newline at end of file diff --git a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LibsvmModelSerializationDescription.java b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LibsvmModelSerializationDescription.java index 47a1f5deb..349da6cc4 100644 --- a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LibsvmModelSerializationDescription.java +++ b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LibsvmModelSerializationDescription.java @@ -19,78 +19,27 @@ package org.dkpro.tc.ml.libsvm.serialization; import java.io.File; -import java.io.IOException; import java.util.ArrayList; import java.util.List; -import org.apache.commons.io.FileUtils; -import org.dkpro.lab.engine.TaskContext; -import org.dkpro.lab.storage.StorageService.AccessMode; -import org.dkpro.lab.task.Discriminator; -import org.dkpro.tc.api.exception.TextClassificationException; import org.dkpro.tc.core.Constants; -import org.dkpro.tc.core.task.ModelSerializationTask; +import org.dkpro.tc.core.util.SaveModelUtils; +import org.dkpro.tc.io.libsvm.LibsvmModelSerialization; import org.dkpro.tc.ml.libsvm.LibsvmAdapter; import org.dkpro.tc.ml.libsvm.api.LibsvmTrainModel; -public class LibsvmModelSerializationDescription extends ModelSerializationTask implements Constants { - - @Discriminator(name = DIM_CLASSIFICATION_ARGS) - private List classificationArguments; - - boolean trainModel = true; +public class LibsvmModelSerializationDescription extends LibsvmModelSerialization implements Constants { @Override - public void execute(TaskContext aContext) throws Exception { - trainAndStoreModel(aContext); - - writeModelConfiguration(aContext, LibsvmAdapter.class.getName()); - } - - private void trainAndStoreModel(TaskContext aContext) throws Exception { - boolean multiLabel = learningMode.equals(Constants.LM_MULTI_LABEL); - if (multiLabel) { - throw new TextClassificationException("Multi-label is not yet implemented"); - } - - File fileTrain = getTrainFile(aContext); - - File model = new File(outputFolder, Constants.MODEL_CLASSIFIER); - - LibsvmTrainModel ltm = new LibsvmTrainModel(); - ltm.run(buildParameters(fileTrain, model)); - copyOutcomeMappingToThisFolder(aContext); - copyFeatureNameMappingToThisFolder(aContext); + protected void trainModel(File fileTrain) throws Exception { + LibsvmTrainModel ltm = new LibsvmTrainModel(); + File model = new File(outputFolder, Constants.MODEL_CLASSIFIER); + ltm.run(buildParameters(fileTrain, model)); } - private void copyOutcomeMappingToThisFolder(TaskContext aContext) throws IOException { - - if(isRegression()){ - return; - } - - File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); - String mapping = LibsvmAdapter.getOutcomeMappingFilename(); - - FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping)); - } - - private boolean isRegression() { - return learningMode.equals(Constants.LM_REGRESSION); - } - - private void copyFeatureNameMappingToThisFolder(TaskContext aContext) throws IOException { - File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); - String mapping = LibsvmAdapter.getFeatureNameMappingFilename(); - - FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping)); - } - - private File getTrainFile(TaskContext aContext) { - File trainFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY); - File fileTrain = new File(trainFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT); - - return fileTrain; + @Override + protected void writeAdapter() throws Exception{ + SaveModelUtils.writeModelAdapterInformation(outputFolder, LibsvmAdapter.class.getName()); } private String[] buildParameters(File fileTrain, File model) { diff --git a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LoadModelConnectorLibsvm.java b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LoadModelConnectorLibsvm.java index a5c7ab176..e55134caa 100644 --- a/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LoadModelConnectorLibsvm.java +++ b/dkpro-tc-ml-libsvm/src/main/java/org/dkpro/tc/ml/libsvm/serialization/LoadModelConnectorLibsvm.java @@ -21,226 +21,47 @@ import static org.dkpro.tc.core.Constants.MODEL_CLASSIFIER; import java.io.BufferedReader; -import java.io.BufferedWriter; import java.io.DataOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; -import java.io.IOException; import java.io.InputStreamReader; -import java.io.OutputStreamWriter; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.apache.commons.io.FileUtils; import org.apache.uima.UimaContext; -import org.apache.uima.analysis_engine.AnalysisEngineProcessException; -import org.apache.uima.fit.descriptor.ConfigurationParameter; -import org.apache.uima.fit.descriptor.ExternalResource; -import org.apache.uima.fit.util.JCasUtil; -import org.apache.uima.jcas.JCas; import org.apache.uima.pear.util.FileUtil; import org.apache.uima.resource.ResourceInitializationException; -import org.dkpro.tc.api.features.Feature; -import org.dkpro.tc.api.features.FeatureExtractorResource_ImplBase; -import org.dkpro.tc.api.features.Instance; -import org.dkpro.tc.api.type.TextClassificationOutcome; -import org.dkpro.tc.core.Constants; -import org.dkpro.tc.core.ml.ModelSerialization_ImplBase; -import org.dkpro.tc.core.util.SaveModelUtils; -import org.dkpro.tc.core.util.TaskUtils; -import org.dkpro.tc.ml.libsvm.LibsvmAdapter; +import org.dkpro.tc.io.libsvm.LibsvmModelLoaderConnector; import org.dkpro.tc.ml.libsvm.api.LibsvmPredict; -import org.dkpro.tc.ml.uima.TcAnnotator; import libsvm.svm; import libsvm.svm_model; -public class LoadModelConnectorLibsvm - extends ModelSerialization_ImplBase -{ +public class LoadModelConnectorLibsvm extends LibsvmModelLoaderConnector { - private static final String OUTCOME_PLACEHOLDER = "-1"; + private svm_model model; - @ConfigurationParameter(name = TcAnnotator.PARAM_TC_MODEL_LOCATION, mandatory = true) - private File tcModelLocation; - - @ExternalResource(key = PARAM_FEATURE_EXTRACTORS, mandatory = true) - protected FeatureExtractorResource_ImplBase[] featureExtractors; - - @ConfigurationParameter(name = PARAM_FEATURE_MODE, mandatory = true) - private String featureMode; - - @ConfigurationParameter(name = PARAM_LEARNING_MODE, mandatory = true) - private String learningMode; - - private svm_model model; - - private Map integer2OutcomeMapping; - private Map featureMapping; - - @Override - public void initialize(UimaContext context) - throws ResourceInitializationException - { - super.initialize(context); - - try { - model = svm - .svm_load_model(new File(tcModelLocation, MODEL_CLASSIFIER).getAbsolutePath()); - integer2OutcomeMapping = loadInteger2OutcomeMapping(tcModelLocation); - featureMapping = loadFeature2IntegerMapping(tcModelLocation); - SaveModelUtils.verifyTcVersion(tcModelLocation, getClass()); - } - catch (Exception e) { - throw new ResourceInitializationException(e); - } - - } - - private Map loadFeature2IntegerMapping(File tcModelLocation) throws IOException { - Map map = new HashMap<>(); - List readLines = FileUtils - .readLines(new File(tcModelLocation, LibsvmAdapter.getFeatureNameMappingFilename()), "utf-8"); - for (String l : readLines) { - String[] split = l.split("\t"); - map.put(split[0],Integer.valueOf(split[1])); + @Override + public void initialize(UimaContext context) throws ResourceInitializationException { + super.initialize(context); + + try { + model = svm.svm_load_model(new File(tcModelLocation, MODEL_CLASSIFIER).getAbsolutePath()); + } catch (Exception e) { + throw new ResourceInitializationException(e); } - return map; - } - - private Map loadInteger2OutcomeMapping(File tcModelLocation) - throws IOException - { - if(isRegression()){ - return new HashMap<>(); - } - - Map map = new HashMap<>(); - List readLines = FileUtils - .readLines(new File(tcModelLocation, LibsvmAdapter.getOutcomeMappingFilename()), "utf-8"); - for (String l : readLines) { - String[] split = l.split("\t"); - map.put(split[1], split[0]); - } - return map; - } - - private boolean isRegression(){ - return learningMode.equals(Constants.LM_REGRESSION); - } - - @Override - public void process(JCas jcas) - throws AnalysisEngineProcessException - { - try { - File tempFile = createInputFile(jcas); - - File prediction = runPrediction(tempFile); - - List outcomes = getOutcomeAnnotations(jcas); - List writtenPredictions = FileUtils.readLines(prediction, "utf-8"); - - checkErrorConditionNumberOfOutcomesEqualsNumberOfPredictions(outcomes, - writtenPredictions); - - for (int i = 0; i < outcomes.size(); i++) { - if (isRegression()) { - String val = writtenPredictions.get(i); - outcomes.get(i).setOutcome(val); - } - else { - String val = writtenPredictions.get(i).replaceAll("\\.0", ""); - String pred = integer2OutcomeMapping.get(val); - outcomes.get(i).setOutcome(pred); - } - - } - - } - catch (Exception e) { - throw new AnalysisEngineProcessException(e); - } - - } - - private List getOutcomeAnnotations(JCas jcas) - { - return new ArrayList<>(JCasUtil.select(jcas, TextClassificationOutcome.class)); - } - - private void checkErrorConditionNumberOfOutcomesEqualsNumberOfPredictions( - List outcomes, List readLines) - { - if (outcomes.size() != readLines.size()) { - throw new IllegalStateException("Expected [" + outcomes.size() - + "] predictions but were [" + readLines.size() + "]"); - } - } - - private File runPrediction(File tempFile) - throws Exception - { - File prediction = FileUtil.createTempFile("libsvmPrediction", "libsvm"); - LibsvmPredict predictor = new LibsvmPredict(); - BufferedReader r = new BufferedReader( - new InputStreamReader(new FileInputStream(tempFile), "utf-8")); - - DataOutputStream output = new DataOutputStream(new FileOutputStream(prediction)); - predictor.predict(r, output, model, 0); - output.close(); - - return prediction; - } - - private File createInputFile(JCas jcas) - throws Exception - { - File tempFile = FileUtil.createTempFile("libsvm", ".txmt"); - BufferedWriter bw = new BufferedWriter( - new OutputStreamWriter(new FileOutputStream(tempFile), "utf-8")); - - List inst = TaskUtils.getMultipleInstancesUnitMode(featureExtractors, jcas, true, - new LibsvmAdapter().useSparseFeatures()); - - for (Instance i : inst) { - bw.write(OUTCOME_PLACEHOLDER); - for (Feature f : i.getFeatures()) { - if (!sanityCheckValue(f)) { - continue; - } - bw.write("\t"); - bw.write(featureMapping.get(f.getName()) + ":" + f.getValue()); - } - bw.write("\n"); - } - bw.close(); - - return tempFile; - } + } - private boolean sanityCheckValue(Feature f) - { - if (f.getValue() instanceof Number) { - return true; - } - if (f.getName().equals(Constants.ID_FEATURE_NAME)) { - return false; - } + @Override + protected File runPrediction(File tempFile) throws Exception { + File prediction = FileUtil.createTempFile("libsvmPrediction", "libsvm"); + LibsvmPredict predictor = new LibsvmPredict(); + BufferedReader r = new BufferedReader(new InputStreamReader(new FileInputStream(tempFile), "utf-8")); - try { - Double.valueOf((String) f.getValue()); - } - catch (Exception e) { - throw new IllegalArgumentException( - "Feature [" + f.getName() + "] has a non-numeric value [" + f.getValue() + "]", - e); - } - return false; - } + DataOutputStream output = new DataOutputStream(new FileOutputStream(prediction)); + predictor.predict(r, output, model, 0); + output.close(); + return prediction; + } } \ No newline at end of file diff --git a/dkpro-tc-ml-svmhmm/src/main/java/org/dkpro/tc/ml/svmhmm/task/serialization/SvmhmmModelSerializationDescription.java b/dkpro-tc-ml-svmhmm/src/main/java/org/dkpro/tc/ml/svmhmm/task/serialization/SvmhmmModelSerializationDescription.java index a95f75c91..a412e29e8 100644 --- a/dkpro-tc-ml-svmhmm/src/main/java/org/dkpro/tc/ml/svmhmm/task/serialization/SvmhmmModelSerializationDescription.java +++ b/dkpro-tc-ml-svmhmm/src/main/java/org/dkpro/tc/ml/svmhmm/task/serialization/SvmhmmModelSerializationDescription.java @@ -28,6 +28,7 @@ import org.dkpro.lab.task.Discriminator; import org.dkpro.tc.core.Constants; import org.dkpro.tc.core.task.ModelSerializationTask; +import org.dkpro.tc.core.util.SaveModelUtils; import org.dkpro.tc.ml.svmhmm.SVMHMMAdapter; import org.dkpro.tc.ml.svmhmm.task.SVMHMMTestTask; import org.dkpro.tc.ml.svmhmm.util.SVMHMMUtils; @@ -51,7 +52,7 @@ public void execute(TaskContext aContext) throws Exception { trainAndStoreModel(aContext); - writeModelConfiguration(aContext, SVMHMMAdapter.class.getName()); + writeModelConfiguration(aContext); } private void trainAndStoreModel(TaskContext aContext) @@ -105,4 +106,9 @@ private void processParameters(List classificationArguments) paramB = SVMHMMUtils.getParameterBeamWidth(classificationArguments); } + @Override + protected void writeAdapter() throws Exception { + SaveModelUtils.writeModelAdapterInformation(outputFolder, SVMHMMAdapter.class.getName()); + } + } diff --git a/dkpro-tc-ml-weka/src/main/java/org/dkpro/tc/ml/weka/task/serialization/WekaModelSerializationDescription.java b/dkpro-tc-ml-weka/src/main/java/org/dkpro/tc/ml/weka/task/serialization/WekaModelSerializationDescription.java index 88efe8d8e..d4d3ab80f 100644 --- a/dkpro-tc-ml-weka/src/main/java/org/dkpro/tc/ml/weka/task/serialization/WekaModelSerializationDescription.java +++ b/dkpro-tc-ml-weka/src/main/java/org/dkpro/tc/ml/weka/task/serialization/WekaModelSerializationDescription.java @@ -32,6 +32,7 @@ import org.dkpro.lab.task.Discriminator; import org.dkpro.tc.core.Constants; import org.dkpro.tc.core.task.ModelSerializationTask; +import org.dkpro.tc.core.util.SaveModelUtils; import org.dkpro.tc.ml.weka.WekaClassificationAdapter; import org.dkpro.tc.ml.weka.util.WekaUtils; @@ -83,7 +84,7 @@ public void execute(TaskContext aContext) throws Exception { writeWekaSpecificInformation(aContext); - writeModelConfiguration(aContext, WekaClassificationAdapter.class.getName()); + writeModelConfiguration(aContext); writeBipartitionThreshold(outputFolder, threshold); } @@ -165,4 +166,9 @@ private void writeWekaSpecificInformation(TaskContext aContext) } } + + @Override + protected void writeAdapter() throws Exception { + SaveModelUtils.writeModelAdapterInformation(outputFolder, WekaClassificationAdapter.class.getName()); + } } diff --git a/pom.xml b/pom.xml index 9aa79707d..3edf95f5b 100644 --- a/pom.xml +++ b/pom.xml @@ -114,7 +114,7 @@ dkpro-tc-features-pair dkpro-tc-features-pair-similarity dkpro-tc-features-spelling - + dkpro-tc-io-libsvm dkpro-tc-integrationtest dkpro-tc-ml dkpro-tc-ml-crfsuite @@ -185,6 +185,11 @@ dkpro-tc-integrationtest 1.0.0-SNAPSHOT + + org.dkpro.tc + dkpro-tc-io-libsvm + 1.0.0-SNAPSHOT + org.dkpro.tc dkpro-tc-ml