Skip to content

Commit

Permalink
#434 - test tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
Horsmann committed Feb 3, 2018
1 parent 90a55bc commit 154b0aa
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 110 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*******************************************************************************
* Copyright 2018
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
******************************************************************************/
package org.dkpro.tc.io.libsvm;

import java.io.File;
import java.util.List;

import org.dkpro.lab.engine.TaskContext;
import org.dkpro.lab.storage.StorageService.AccessMode;
import org.dkpro.lab.task.Discriminator;
import org.dkpro.lab.task.impl.ExecutableTaskBase;
import org.dkpro.tc.api.exception.TextClassificationException;
import org.dkpro.tc.core.Constants;

public abstract class LibsvmDataFormatTestTask extends ExecutableTaskBase implements Constants {

@Discriminator(name = DIM_CLASSIFICATION_ARGS)
protected List<String> classificationArguments;

@Discriminator(name = DIM_LEARNING_MODE)
protected String learningMode;

@Discriminator(name = DIM_FEATURE_MODE)
protected String featureMode;

protected abstract void runPrediction(TaskContext aContext) throws Exception;

protected File getPredictionFile(TaskContext aContext) {
File folder = aContext.getFolder("", AccessMode.READWRITE);
return new File(folder, Constants.FILENAME_PREDICTIONS);
}

protected File getTestFile(TaskContext aContext) {
File testFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TEST_DATA, AccessMode.READONLY);
File fileTest = new File(testFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT);
return fileTest;
}

protected File getTrainFile(TaskContext aContext) {
File trainFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY);
File fileTrain = new File(trainFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT);

return fileTrain;
}

protected void throwExceptionIfMultiLabelMode() throws TextClassificationException {
boolean multiLabel = learningMode.equals(Constants.LM_MULTI_LABEL);
if (multiLabel) {
throw new TextClassificationException("Multi-label is not supported");
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
import org.dkpro.lab.engine.TaskContext;
import org.dkpro.lab.storage.StorageService.AccessMode;
import org.dkpro.lab.task.Discriminator;
import org.dkpro.lab.task.impl.ExecutableTaskBase;
import org.dkpro.tc.api.exception.TextClassificationException;
import org.dkpro.tc.core.Constants;
import org.dkpro.tc.io.libsvm.LibsvmDataFormatTestTask;
import org.dkpro.tc.ml.liblinear.util.LiblinearUtils;

import de.bwaldvogel.liblinear.Feature;
Expand All @@ -38,14 +37,14 @@
import de.bwaldvogel.liblinear.Problem;
import de.bwaldvogel.liblinear.SolverType;

public class LiblinearTestTask extends ExecutableTaskBase implements Constants {
public class LiblinearTestTask extends LibsvmDataFormatTestTask implements Constants {

@Discriminator(name = DIM_CLASSIFICATION_ARGS)
private List<String> classificationArguments;

@Discriminator(name = DIM_FEATURE_MODE)
private String featureMode;

@Discriminator(name = DIM_LEARNING_MODE)
private String learningMode;

Expand All @@ -55,54 +54,35 @@ public class LiblinearTestTask extends ExecutableTaskBase implements Constants {

@Override
public void execute(TaskContext aContext) throws Exception {
boolean multiLabel = learningMode.equals(Constants.LM_MULTI_LABEL);
if (multiLabel) {
throw new TextClassificationException(
"Multi-label requested, but LIBLINEAR only supports single label setups.");
}
throwExceptionIfMultiLabelMode();

runPrediction(aContext);
}

@Override
protected void runPrediction(TaskContext aContext) throws Exception {

File fileTrain = getTrainFile(aContext);
File fileTest = getTestFile(aContext);

File predFolder = aContext.getFolder("", AccessMode.READWRITE);
File predictionsFile = new File(predFolder, Constants.FILENAME_PREDICTIONS);

// default for bias is -1, documentation says to set it to 1 in order to
// get results closer
// to libsvm
// writer adds bias, so if we de-activate that here for some reason, we
// need to also
// deactivate it there
Problem train = Problem.readFromFile(fileTrain, 1.0);

SolverType solver = LiblinearUtils.getSolver(classificationArguments);
double C = LiblinearUtils.getParameterC(classificationArguments);
double eps = LiblinearUtils.getParameterEpsilon(classificationArguments);

Linear.setDebugOutput(null);

Parameter parameter = new Parameter(solver, C, eps);
Model model = Linear.train(train, parameter);

Problem test = Problem.readFromFile(fileTest, 1.0);

predict(aContext, model, test);
}

private File getTestFile(TaskContext aContext) {
File testFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TEST_DATA, AccessMode.READONLY);
File fileTest = new File(testFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT);
return fileTest;
}

private File getTrainFile(TaskContext aContext) {
File trainFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY);
File fileTrain = new File(trainFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT);

return fileTrain;
}

private void predict(TaskContext aContext, Model model, Problem test) throws Exception {
File predFolder = aContext.getFolder("", AccessMode.READWRITE);
File predictionsFile = new File(predFolder, Constants.FILENAME_PREDICTIONS);

BufferedWriter writer = new BufferedWriter(
new OutputStreamWriter(new FileOutputStream(predictionsFile), "utf-8"));
writer.append("#PREDICTION;GOLD" + "\n");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,87 +37,51 @@
import org.apache.uima.pear.util.FileUtil;
import org.dkpro.lab.engine.TaskContext;
import org.dkpro.lab.storage.StorageService.AccessMode;
import org.dkpro.lab.task.Discriminator;
import org.dkpro.lab.task.impl.ExecutableTaskBase;
import org.dkpro.tc.api.exception.TextClassificationException;
import org.dkpro.tc.core.Constants;
import org.dkpro.tc.io.libsvm.LibsvmDataFormatTestTask;
import org.dkpro.tc.ml.libsvm.api.LibsvmPredict;
import org.dkpro.tc.ml.libsvm.api.LibsvmTrainModel;

import libsvm.svm;
import libsvm.svm_model;

public class LibsvmTestTask
extends ExecutableTaskBase
extends LibsvmDataFormatTestTask
implements Constants
{
@Discriminator(name = DIM_CLASSIFICATION_ARGS)
private List<String> classificationArguments;
@Discriminator(name = DIM_LEARNING_MODE)
private String learningMode;

@Override
public void execute(TaskContext aContext)
throws Exception
{
exceptMultiLabelMode();
throwExceptionIfMultiLabelMode();

File fileTrain = getTrainFile(aContext);
runPrediction(aContext);
}

@Override
protected void runPrediction(TaskContext aContext) throws Exception {
File fileTrain = getTrainFile(aContext);
File fileTest = getTestFile(aContext);

BufferedReader r = new BufferedReader(
new InputStreamReader(new FileInputStream(fileTest), "utf-8"));

File model = new File(aContext.getFolder("", AccessMode.READWRITE),
Constants.MODEL_CLASSIFIER);

LibsvmTrainModel ltm = new LibsvmTrainModel();
ltm.run(buildParameters(fileTrain, model));
prediction(model, aContext);
}

private void exceptMultiLabelMode()
throws TextClassificationException
{
boolean multiLabel = learningMode.equals(Constants.LM_MULTI_LABEL);
if (multiLabel) {
throw new TextClassificationException("Multi-label is not supported");
}
}

private String[] buildParameters(File fileTrain, File model)
{
List<String> parameters = new ArrayList<>();
if (classificationArguments != null) {
for (String a : classificationArguments) {
parameters.add(a);
}
}
parameters.add(fileTrain.getAbsolutePath());
parameters.add(model.getAbsolutePath());
return parameters.toArray(new String[0]);
}

private void prediction(File model, TaskContext aContext)
throws Exception
{
File fileTest = getTestFile(aContext);

LibsvmPredict predictor = new LibsvmPredict();

BufferedReader r = new BufferedReader(
new InputStreamReader(new FileInputStream(fileTest), "utf-8"));
File prediction = getPredictionFile(aContext);
File predTmp = createTemporaryPredictionFile();

DataOutputStream output = new DataOutputStream(new FileOutputStream(predTmp));
svm_model trainedModel = svm.svm_load_model(model.getAbsolutePath());
predictor.predict(r, output, trainedModel, 0);
svm_model svmModel = svm.svm_load_model(model.getAbsolutePath());
predictor.predict(r, output, svmModel, 0);
output.close();

mergePredictedValuesWithExpected(fileTest, predTmp, prediction);
}

// We only get the predicted values but we loose the information which value was expected - we
// thus use the test file and restore the expected values from there
private void mergePredictedValuesWithExpected(File fileTest, File predTmp, File prediction)
throws IOException
{

File prediction = getPredictionFile(aContext);
BufferedWriter bw = new BufferedWriter(
new OutputStreamWriter(new FileOutputStream(prediction), "utf-8"));

Expand All @@ -132,7 +96,22 @@ private void mergePredictedValuesWithExpected(File fileTest, File predTmp, File
}
bw.close();
}


private String[] buildParameters(File fileTrain, File model)
{
List<String> parameters = new ArrayList<>();
if (classificationArguments != null) {
for (String a : classificationArguments) {
parameters.add(a);
}
}
parameters.add(fileTrain.getAbsolutePath());
parameters.add(model.getAbsolutePath());
return parameters.toArray(new String[0]);
}


private List<String> pickGold(List<String> readLines)
{
List<String> gold = new ArrayList<>();
Expand All @@ -159,26 +138,5 @@ private File createTemporaryPredictionFile()
return createTempFile;
}

private File getPredictionFile(TaskContext aContext)
{
File folder = aContext.getFolder("", AccessMode.READWRITE);
return new File(folder, Constants.FILENAME_PREDICTIONS);
}

private File getTestFile(TaskContext aContext)
{
File testFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TEST_DATA, AccessMode.READONLY);
File fileTest = new File(testFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT);
return fileTest;
}

private File getTrainFile(TaskContext aContext)
{
File trainFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA,
AccessMode.READONLY);
File fileTrain = new File(trainFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT);

return fileTrain;
}

}

0 comments on commit 154b0aa

Please sign in to comment.