-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MAHOUT-702 add passive-aggressive learner
git-svn-id: https://svn.apache.org/repos/asf/mahout/trunk@1131476 13f79535-47bb-0310-9956-ffa450edef68
- Loading branch information
Showing
4 changed files
with
403 additions
and
131 deletions.
There are no files selected for viewing
202 changes: 202 additions & 0 deletions
202
core/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.mahout.classifier.sgd; | ||
|
||
import org.apache.hadoop.io.Writable; | ||
import org.apache.mahout.classifier.AbstractVectorClassifier; | ||
import org.apache.mahout.classifier.OnlineLearner; | ||
import org.apache.mahout.math.DenseMatrix; | ||
import org.apache.mahout.math.DenseVector; | ||
import org.apache.mahout.math.Matrix; | ||
import org.apache.mahout.math.MatrixWritable; | ||
import org.apache.mahout.math.Vector; | ||
import org.apache.mahout.math.function.Functions; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import java.io.DataInput; | ||
import java.io.DataOutput; | ||
import java.io.IOException; | ||
|
||
/** | ||
* Online passive aggressive learner that tries to minimize the label ranking hinge loss. | ||
* Implements a multi-class linear classifier minimizing rank loss. | ||
* based on "Online passive aggressive algorithms" by Cramer et al, 2006. | ||
* Note: Its better to use classifyNoLink because the loss function is based | ||
* on ensuring that the score of the good label is larger than the next | ||
* highest label by some margin. The conversion to probability is just done | ||
* by exponentiating and dividing by the sum and is empirical at best. | ||
* Your features should be pre-normalized in some sensible range, for example, | ||
* by subtracting the mean and standard deviation, if they are very | ||
* different in magnitude from each other. | ||
*/ | ||
public class PassiveAggressive extends AbstractVectorClassifier implements OnlineLearner, Writable { | ||
|
||
private static final Logger log = LoggerFactory.getLogger(PassiveAggressive.class); | ||
|
||
public static final int WRITABLE_VERSION = 1; | ||
|
||
// the learning rate of the algorithm | ||
private double learningRate = 0.1; | ||
|
||
// loss statistics. | ||
private int lossCount = 0; | ||
private double lossSum = 0; | ||
|
||
// coefficients for the classification. This is a dense matrix | ||
// that is (numCategories ) x numFeatures | ||
private Matrix weights; | ||
|
||
// number of categories we are classifying. | ||
private int numCategories; | ||
|
||
public PassiveAggressive(int numCategories, int numFeatures) { | ||
this.numCategories = numCategories; | ||
weights = new DenseMatrix(numCategories, numFeatures); | ||
weights.assign(0.0); | ||
} | ||
|
||
/** | ||
* Chainable configuration option. | ||
* | ||
* @param learningRate New value of initial learning rate. | ||
* @return This, so other configurations can be chained. | ||
*/ | ||
public PassiveAggressive learningRate(double learningRate) { | ||
this.learningRate = learningRate; | ||
return this; | ||
} | ||
|
||
public void copyFrom(PassiveAggressive other) { | ||
learningRate = other.learningRate; | ||
numCategories = other.numCategories; | ||
weights = other.weights; | ||
} | ||
|
||
@Override | ||
public int numCategories() { | ||
return numCategories; | ||
} | ||
|
||
@Override | ||
public Vector classify(Vector instance) { | ||
Vector result = (DenseVector) classifyNoLink(instance); | ||
// Convert to probabilities by exponentiation. | ||
double max = result.maxValue(); | ||
result.assign(Functions.minus(max)).assign(Functions.EXP); | ||
result = result.divide(result.norm(1)); | ||
|
||
return result.viewPart(1, result.size() - 1); | ||
} | ||
|
||
@Override | ||
public Vector classifyNoLink(Vector instance) { | ||
Vector result = new DenseVector(weights.numRows()); | ||
result.assign(0); | ||
for (int i = 0; i < weights.numRows(); i++) { | ||
result.setQuick(i, weights.viewRow(i).dot(instance)); | ||
} | ||
return result; | ||
} | ||
|
||
@Override | ||
public double classifyScalar(Vector instance) { | ||
double v1 = weights.viewRow(0).dot(instance); | ||
double v2 = weights.viewRow(1).dot(instance); | ||
v1 = Math.exp(v1); | ||
v2 = Math.exp(v2); | ||
return v2 / (v1 + v2); | ||
} | ||
|
||
public int numFeatures() { | ||
return weights.numCols(); | ||
} | ||
|
||
public PassiveAggressive copy() { | ||
close(); | ||
PassiveAggressive r = new PassiveAggressive(numCategories(), numFeatures()); | ||
r.copyFrom(this); | ||
return r; | ||
} | ||
|
||
@Override | ||
public void write(DataOutput out) throws IOException { | ||
out.writeInt(WRITABLE_VERSION); | ||
out.writeDouble(learningRate); | ||
out.writeInt(numCategories); | ||
MatrixWritable.writeMatrix(out, weights); | ||
} | ||
|
||
@Override | ||
public void readFields(DataInput in) throws IOException { | ||
int version = in.readInt(); | ||
if (version == WRITABLE_VERSION) { | ||
learningRate = in.readDouble(); | ||
numCategories = in.readInt(); | ||
weights = MatrixWritable.readMatrix(in); | ||
} else { | ||
throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version); | ||
} | ||
} | ||
|
||
@Override | ||
public void close() { | ||
// This is an online classifier, nothing to do. | ||
} | ||
|
||
@Override | ||
public void train(long trackingKey, String groupKey, int actual, Vector instance) { | ||
if (lossCount > 1000) { | ||
log.info("Avg. Loss = {}", lossSum / lossCount); | ||
lossCount = 0; | ||
lossSum = 0; | ||
} | ||
Vector result = classifyNoLink(instance); | ||
double my_score = result.get(actual); | ||
// Find the highest score that is not actual. | ||
int other_idx = result.maxValueIndex(); | ||
double other_value = result.get(other_idx); | ||
if (other_idx == actual) { | ||
result.setQuick(other_idx, Double.NEGATIVE_INFINITY); | ||
other_idx = result.maxValueIndex(); | ||
other_value = result.get(other_idx); | ||
} | ||
double loss = 1.0 - my_score + other_value; | ||
lossCount += 1; | ||
if (loss >= 0) { | ||
lossSum += loss; | ||
double tau = loss / (instance.dot(instance) + 0.5 / learningRate); | ||
Vector delta = instance.clone(); | ||
delta.assign(Functions.mult(tau)); | ||
delta.addTo(weights.getRow(actual)); | ||
delta.assign(Functions.mult(-1)); | ||
delta.addTo(weights.getRow(other_idx)); | ||
} | ||
} | ||
|
||
@Override | ||
public void train(long trackingKey, int actual, Vector instance) { | ||
train(trackingKey, null, actual, instance); | ||
} | ||
|
||
@Override | ||
public void train(int actual, Vector instance) { | ||
train(0, null, actual, instance); | ||
} | ||
|
||
} |
160 changes: 160 additions & 0 deletions
160
core/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.mahout.classifier.sgd; | ||
|
||
import com.google.common.base.CharMatcher; | ||
import com.google.common.base.Charsets; | ||
import com.google.common.base.Splitter; | ||
import com.google.common.collect.Lists; | ||
import com.google.common.collect.Maps; | ||
import com.google.common.io.CharStreams; | ||
import com.google.common.io.Resources; | ||
import org.apache.mahout.classifier.AbstractVectorClassifier; | ||
import org.apache.mahout.classifier.OnlineLearner; | ||
import org.apache.mahout.common.MahoutTestCase; | ||
import org.apache.mahout.common.RandomUtils; | ||
import org.apache.mahout.math.DenseMatrix; | ||
import org.apache.mahout.math.DenseVector; | ||
import org.apache.mahout.math.Matrix; | ||
import org.apache.mahout.math.Vector; | ||
import org.apache.mahout.math.function.Functions; | ||
|
||
import java.io.IOException; | ||
import java.io.InputStreamReader; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Random; | ||
|
||
public abstract class OnlineBaseTest extends MahoutTestCase { | ||
|
||
private Matrix input; | ||
|
||
protected Matrix getInput() { | ||
return input; | ||
} | ||
|
||
protected Vector readStandardData() throws IOException { | ||
// 60 test samples. First column is constant. Second and third are normally distributed from | ||
// either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59). The first 30 rows have a | ||
// target variable of 0, the last 30 a target of 1. The remaining columns are are random noise. | ||
input = readCsv("sgd.csv"); | ||
|
||
// regenerate the target variable | ||
Vector target = new DenseVector(60); | ||
target.assign(0); | ||
target.viewPart(30, 30).assign(1); | ||
return target; | ||
} | ||
|
||
protected static void train(Matrix input, Vector target, OnlineLearner lr) { | ||
RandomUtils.useTestSeed(); | ||
Random gen = RandomUtils.getRandom(); | ||
|
||
// train on samples in random order (but only one pass) | ||
for (int row : permute(gen, 60)) { | ||
lr.train((int) target.get(row), input.getRow(row)); | ||
} | ||
lr.close(); | ||
} | ||
|
||
protected static void test(Matrix input, Vector target, AbstractVectorClassifier lr, | ||
double expected_mean_error, double expected_absolute_error) { | ||
// now test the accuracy | ||
Matrix tmp = lr.classify(input); | ||
// mean(abs(tmp - target)) | ||
double meanAbsoluteError = tmp.getColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60; | ||
|
||
// max(abs(tmp - target) | ||
double maxAbsoluteError = tmp.getColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS); | ||
|
||
System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError); | ||
assertEquals(0, meanAbsoluteError , expected_mean_error); | ||
assertEquals(0, maxAbsoluteError, expected_absolute_error); | ||
|
||
// convenience methods should give the same results | ||
Vector v = lr.classifyScalar(input); | ||
assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-5); | ||
v = lr.classifyFull(input).getColumn(1); | ||
assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-4); | ||
} | ||
|
||
/** | ||
* Permute the integers from 0 ... max-1 | ||
* | ||
* @param gen The random number generator to use. | ||
* @param max The number of integers to permute | ||
* @return An array of jumbled integer values | ||
*/ | ||
protected static int[] permute(Random gen, int max) { | ||
int[] permutation = new int[max]; | ||
permutation[0] = 0; | ||
for (int i = 1; i < max; i++) { | ||
int n = gen.nextInt(i + 1); | ||
if (n == i) { | ||
permutation[i] = i; | ||
} else { | ||
permutation[i] = permutation[n]; | ||
permutation[n] = i; | ||
} | ||
} | ||
return permutation; | ||
} | ||
|
||
|
||
/** | ||
* Reads a file containing CSV data. This isn't implemented quite the way you might like for a | ||
* real program, but does the job for reading test data. Most notably, it will only read numbers, | ||
* not quoted strings. | ||
* | ||
* @param resourceName Where to get the data. | ||
* @return A matrix of the results. | ||
* @throws IOException If there is an error reading the data | ||
*/ | ||
protected static Matrix readCsv(String resourceName) throws IOException { | ||
Splitter onCommas = Splitter.on(",").trimResults(CharMatcher.anyOf(" \"")); | ||
|
||
Readable isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8); | ||
List<String> data = CharStreams.readLines(isr); | ||
String first = data.get(0); | ||
data = data.subList(1, data.size()); | ||
|
||
List<String> values = Lists.newArrayList(onCommas.split(first)); | ||
Matrix r = new DenseMatrix(data.size(), values.size()); | ||
|
||
int column = 0; | ||
Map<String, Integer> labels = Maps.newHashMap(); | ||
for (String value : values) { | ||
labels.put(value, column); | ||
column++; | ||
} | ||
r.setColumnLabelBindings(labels); | ||
|
||
int row = 0; | ||
for (String line : data) { | ||
column = 0; | ||
values = Lists.newArrayList(onCommas.split(line)); | ||
for (String value : values) { | ||
r.set(row, column, Double.parseDouble(value)); | ||
column++; | ||
} | ||
row++; | ||
} | ||
|
||
return r; | ||
} | ||
} |
Oops, something went wrong.