Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Horsmann committed Feb 2, 2018
1 parent 6af7036 commit df93dbe
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/
package org.dkpro.tc.examples.regression;

import static org.junit.Assert.assertTrue;
import static org.junit.Assert.*;

import org.dkpro.lab.task.ParameterSpace;
import org.dkpro.tc.examples.TestCaseSuperClass;
Expand Down Expand Up @@ -58,6 +58,6 @@ public void testTrainTest() throws Exception{

EvaluationData<Double> data = Tc2LtlabEvalConverter.convertRegressionModeId2Outcome(ContextMemoryReport.id2outcome);
MeanSquaredError mse = new MeanSquaredError(data);
assertTrue(mse.getResult() > 1.0);
assertEquals(3.37, mse.getResult(), 0.01);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class LiblinearModelSerializationDescription extends ModelSerializationTa

@Discriminator(name = DIM_CLASSIFICATION_ARGS)
private List<String> classificationArguments;

boolean trainModel = true;

@Override
Expand Down Expand Up @@ -71,12 +71,20 @@ private void trainAndStoreModel(TaskContext aContext) throws Exception {
}

/**
 * Copies the outcome-mapping file from the training-data folder into this task's output
 * folder. In regression mode the method returns without copying anything.
 *
 * @param aContext
 *            task context used to resolve the read-only training-data folder
 * @throws IOException
 *             if copying the mapping file fails
 */
private void copyOutcomeMappingToThisFolder(TaskContext aContext) throws IOException {
    if (!isRegression()) {
        File sourceDir = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY);
        String fileName = LiblinearAdapter.getOutcomeMappingFilename();

        File from = new File(sourceDir, fileName);
        File to = new File(outputFolder, fileName);
        FileUtils.copyFile(from, to);
    }
}

/**
 * Tells whether this task is configured for regression.
 *
 * @return {@code true} if {@code learningMode} equals {@code Constants.LM_REGRESSION};
 *         {@code false} otherwise, including when {@code learningMode} is {@code null}
 */
private boolean isRegression() {
    // Constant-first comparison: a null learningMode yields false instead of a NullPointerException.
    return Constants.LM_REGRESSION.equals(learningMode);
}

private void copyFeatureNameMappingToThisFolder(TaskContext aContext) throws IOException {
File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY);
String mapping = LiblinearAdapter.getFeatureNameMappingFilename();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ public class LoadModelConnectorLiblinear extends ModelSerialization_ImplBase {
private String learningMode;

private Model liblinearModel;

private Map<Integer, String> outcomeMapping;

private Map<String, Integer> featureMapping;
Expand Down Expand Up @@ -98,6 +99,11 @@ private Map<String, Integer> loadFeature2IntegerMapping(File tcModelLocation) th
}

private Map<Integer, String> loadOutcome2IntegerMapping(File tcModelLocation) throws IOException {

if (isRegression()){
return new HashMap<>();
}

Map<Integer, String> map = new HashMap<>();
List<String> readLines = FileUtils
.readLines(new File(tcModelLocation, LiblinearAdapter.getOutcomeMappingFilename()), "utf-8");
Expand All @@ -108,7 +114,12 @@ private Map<Integer, String> loadOutcome2IntegerMapping(File tcModelLocation) th
return map;
}

private Double toValue(Object value)
/**
 * Tells whether the loaded model was configured for regression.
 *
 * @return {@code true} if {@code learningMode} equals {@code Constants.LM_REGRESSION};
 *         {@code false} otherwise, including when {@code learningMode} is {@code null}
 */
private boolean isRegression() {
    // Constant-first comparison: a null learningMode yields false instead of a NullPointerException.
    return Constants.LM_REGRESSION.equals(learningMode);
}


private Double toValue(Object value)
{
double v;
if (value instanceof Number) {
Expand Down Expand Up @@ -173,7 +184,7 @@ public void process(JCas jcas) throws AnalysisEngineProcessException {
Feature[] instance = testInstances[i];
Double prediction = Linear.predict(liblinearModel, instance);

if (learningMode.equals(Constants.LM_REGRESSION)) {
if (isRegression()) {
outcomes.get(i).setOutcome(prediction.toString());
} else {
String predictedLabel = outcomeMapping.get(prediction.intValue());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ public void writeClassifierFormat(Collection<Instance> in) throws Exception {
List<Integer> keys = new ArrayList<Integer>(entry.keySet());
Collections.sort(keys);

bw.append(outcomeMap.get(inst.getOutcome()) + "\t");

if (isRegression()) {
bw.append(inst.getOutcome() + "\t");
} else {
bw.append(outcomeMap.get(inst.getOutcome()) + "\t");
}
for (int i = 0; i < keys.size(); i++) {
Integer key = keys.get(i);
Double value = entry.get(key);
Expand All @@ -161,6 +166,11 @@ public void writeClassifierFormat(Collection<Instance> in) throws Exception {
}

private void writeOutcomeMapping(File outputDirectory, String file, Map<String, Integer> map) throws IOException {

if(isRegression()){
return;
}

StringBuilder sb = new StringBuilder();
for (String k : map.keySet()) {
sb.append(k + "\t" + map.get(k) + "\n");
Expand Down Expand Up @@ -248,6 +258,9 @@ public void init(File outputDirectory, boolean useSparse, String learningMode, b
* @param outcomes
*/
private void buildOutcomeMap(String[] outcomes) {
if(isRegression()){
return;
}
outcomeMap = new HashMap<>();
Integer i = 0;
List<String> outcomesSorted = new ArrayList<>(Arrays.asList(outcomes));
Expand Down Expand Up @@ -288,6 +301,10 @@ private void recordInstanceId(Instance instance, int i, Map<String, String> inde
return;
}
}

/**
 * Tells whether this writer was initialised for regression.
 *
 * @return {@code true} if {@code learningMode} equals {@code Constants.LM_REGRESSION};
 *         {@code false} otherwise, including when {@code learningMode} is {@code null}
 */
private boolean isRegression() {
    // Constant-first comparison: a null learningMode yields false instead of a NullPointerException.
    return Constants.LM_REGRESSION.equals(learningMode);
}

@Override
public void close() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,83 +33,76 @@
import org.dkpro.tc.ml.libsvm.LibsvmAdapter;
import org.dkpro.tc.ml.libsvm.api.LibsvmTrainModel;

public class LibsvmModelSerializationDescription
extends ModelSerializationTask
implements Constants
{

@Discriminator(name = DIM_CLASSIFICATION_ARGS)
private List<String> classificationArguments;

boolean trainModel = true;

@Override
public void execute(TaskContext aContext)
throws Exception
{
trainAndStoreModel(aContext);

writeModelConfiguration(aContext, LibsvmAdapter.class.getName());
}

private void trainAndStoreModel(TaskContext aContext)
throws Exception
{
boolean multiLabel = learningMode.equals(Constants.LM_MULTI_LABEL);
if (multiLabel) {
throw new TextClassificationException("Multi-label is not yet implemented");
}

File fileTrain = getTrainFile(aContext);

File model = new File(outputFolder, Constants.MODEL_CLASSIFIER);

LibsvmTrainModel ltm = new LibsvmTrainModel();
ltm.run(buildParameters(fileTrain, model));
copyOutcomeMappingToThisFolder(aContext);
copyFeatureNameMappingToThisFolder(aContext);
}

private void copyOutcomeMappingToThisFolder(TaskContext aContext)
throws IOException
{
File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA,
AccessMode.READONLY);
String mapping = LibsvmAdapter.getOutcomeMappingFilename();

FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping));
}

private void copyFeatureNameMappingToThisFolder(TaskContext aContext)
throws IOException
{
File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA,
AccessMode.READONLY);
String mapping = LibsvmAdapter.getFeatureNameMappingFilename();

FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping));
}

private File getTrainFile(TaskContext aContext)
{
File trainFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA,
AccessMode.READONLY);
File fileTrain = new File(trainFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT);

return fileTrain;
}

private String[] buildParameters(File fileTrain, File model)
{
List<String> parameters = new ArrayList<>();
if (classificationArguments != null) {
for (String a : classificationArguments) {
parameters.add(a);
}
}
parameters.add(fileTrain.getAbsolutePath());
parameters.add(model.getAbsolutePath());
return parameters.toArray(new String[0]);
}
public class LibsvmModelSerializationDescription extends ModelSerializationTask implements Constants {

@Discriminator(name = DIM_CLASSIFICATION_ARGS)
private List<String> classificationArguments;

boolean trainModel = true;

@Override
public void execute(TaskContext aContext) throws Exception {
trainAndStoreModel(aContext);

writeModelConfiguration(aContext, LibsvmAdapter.class.getName());
}

private void trainAndStoreModel(TaskContext aContext) throws Exception {
boolean multiLabel = learningMode.equals(Constants.LM_MULTI_LABEL);
if (multiLabel) {
throw new TextClassificationException("Multi-label is not yet implemented");
}

File fileTrain = getTrainFile(aContext);

File model = new File(outputFolder, Constants.MODEL_CLASSIFIER);

LibsvmTrainModel ltm = new LibsvmTrainModel();
ltm.run(buildParameters(fileTrain, model));
copyOutcomeMappingToThisFolder(aContext);
copyFeatureNameMappingToThisFolder(aContext);
}

private void copyOutcomeMappingToThisFolder(TaskContext aContext) throws IOException {

if(isRegression()){
return;
}

File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY);
String mapping = LibsvmAdapter.getOutcomeMappingFilename();

FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping));
}

private boolean isRegression() {
return learningMode.equals(Constants.LM_REGRESSION);
}

private void copyFeatureNameMappingToThisFolder(TaskContext aContext) throws IOException {
File trainDataFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY);
String mapping = LibsvmAdapter.getFeatureNameMappingFilename();

FileUtils.copyFile(new File(trainDataFolder, mapping), new File(outputFolder, mapping));
}

private File getTrainFile(TaskContext aContext) {
File trainFolder = aContext.getFolder(TEST_TASK_INPUT_KEY_TRAINING_DATA, AccessMode.READONLY);
File fileTrain = new File(trainFolder, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT);

return fileTrain;
}

private String[] buildParameters(File fileTrain, File model) {
List<String> parameters = new ArrayList<>();
if (classificationArguments != null) {
for (String a : classificationArguments) {
parameters.add(a);
}
}
parameters.add(fileTrain.getAbsolutePath());
parameters.add(model.getAbsolutePath());
return parameters.toArray(new String[0]);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ private Map<String, Integer> loadFeature2IntegerMapping(File tcModelLocation) th
private Map<String, String> loadInteger2OutcomeMapping(File tcModelLocation)
throws IOException
{
if(isRegression()){
return new HashMap<>();
}

Map<String, String> map = new HashMap<>();
List<String> readLines = FileUtils
.readLines(new File(tcModelLocation, LibsvmAdapter.getOutcomeMappingFilename()), "utf-8");
Expand All @@ -123,6 +127,10 @@ private Map<String, String> loadInteger2OutcomeMapping(File tcModelLocation)
}
return map;
}

/**
 * Tells whether the loaded model was configured for regression.
 *
 * @return {@code true} if {@code learningMode} equals {@code Constants.LM_REGRESSION};
 *         {@code false} otherwise, including when {@code learningMode} is {@code null}
 */
private boolean isRegression() {
    // Constant-first comparison: a null learningMode yields false instead of a NullPointerException.
    return Constants.LM_REGRESSION.equals(learningMode);
}

@Override
public void process(JCas jcas)
Expand All @@ -141,7 +149,7 @@ public void process(JCas jcas)

for (int i = 0; i < outcomes.size(); i++) {

if (learningMode.equals(Constants.LM_REGRESSION)) {
if (isRegression()) {
String val = writtenPredictions.get(i);
outcomes.get(i).setOutcome(val);
}
Expand Down
Loading

0 comments on commit df93dbe

Please sign in to comment.