MAHOUT-702 add passive-aggressive learner
git-svn-id: https://svn.apache.org/repos/asf/mahout/trunk@1131476 13f79535-47bb-0310-9956-ffa450edef68
srowen committed Jun 4, 2011
1 parent 0338e6a commit d79e97a
Showing 4 changed files with 403 additions and 131 deletions.
@@ -0,0 +1,202 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.classifier.sgd;

import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

/**
* Online passive-aggressive learner that tries to minimize the label ranking hinge loss.
* Implements a multi-class linear classifier minimizing rank loss, based on
* "Online Passive-Aggressive Algorithms" by Crammer et al., 2006.
* Note: it is better to use classifyNoLink because the loss function is based
* on ensuring that the score of the good label is larger than the score of the
* next highest label by some margin. The conversion to probability is done
* by exponentiating and dividing by the sum and is empirical at best.
* Your features should be pre-normalized into some sensible range, for example
* by subtracting the mean and dividing by the standard deviation, if they are
* very different in magnitude from each other.
*/
public class PassiveAggressive extends AbstractVectorClassifier implements OnlineLearner, Writable {

private static final Logger log = LoggerFactory.getLogger(PassiveAggressive.class);

public static final int WRITABLE_VERSION = 1;

// the learning rate of the algorithm
private double learningRate = 0.1;

// loss statistics.
private int lossCount = 0;
private double lossSum = 0;

// coefficients for the classification. This is a dense matrix
// that is numCategories x numFeatures
private Matrix weights;

// number of categories we are classifying.
private int numCategories;

public PassiveAggressive(int numCategories, int numFeatures) {
this.numCategories = numCategories;
weights = new DenseMatrix(numCategories, numFeatures);
weights.assign(0.0);
}

/**
* Chainable configuration option.
*
* @param learningRate New value of initial learning rate.
* @return This, so other configurations can be chained.
*/
public PassiveAggressive learningRate(double learningRate) {
this.learningRate = learningRate;
return this;
}

public void copyFrom(PassiveAggressive other) {
learningRate = other.learningRate;
numCategories = other.numCategories;
// deep-copy the weights so that training this model and its copy stay independent
weights = other.weights.clone();
}

@Override
public int numCategories() {
return numCategories;
}

@Override
public Vector classify(Vector instance) {
Vector result = classifyNoLink(instance);
// Convert scores to probabilities via softmax; subtracting the max
// first avoids overflow in the exponentials.
double max = result.maxValue();
result.assign(Functions.minus(max)).assign(Functions.EXP);
result = result.divide(result.norm(1));

return result.viewPart(1, result.size() - 1);
}

@Override
public Vector classifyNoLink(Vector instance) {
Vector result = new DenseVector(weights.numRows());
result.assign(0);
for (int i = 0; i < weights.numRows(); i++) {
result.setQuick(i, weights.viewRow(i).dot(instance));
}
return result;
}

@Override
public double classifyScalar(Vector instance) {
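// binary case: the probability of category 1 is exp(v2) / (exp(v1) + exp(v2)),
// i.e. the logistic function applied to the score difference v2 - v1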
double v1 = weights.viewRow(0).dot(instance);
double v2 = weights.viewRow(1).dot(instance);
v1 = Math.exp(v1);
v2 = Math.exp(v2);
return v2 / (v1 + v2);
}

public int numFeatures() {
return weights.numCols();
}

public PassiveAggressive copy() {
close();
PassiveAggressive r = new PassiveAggressive(numCategories(), numFeatures());
r.copyFrom(this);
return r;
}

@Override
public void write(DataOutput out) throws IOException {
out.writeInt(WRITABLE_VERSION);
out.writeDouble(learningRate);
out.writeInt(numCategories);
MatrixWritable.writeMatrix(out, weights);
}

@Override
public void readFields(DataInput in) throws IOException {
int version = in.readInt();
if (version == WRITABLE_VERSION) {
learningRate = in.readDouble();
numCategories = in.readInt();
weights = MatrixWritable.readMatrix(in);
} else {
throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
}
}

@Override
public void close() {
// This is an online classifier, nothing to do.
}

@Override
public void train(long trackingKey, String groupKey, int actual, Vector instance) {
if (lossCount > 1000) {
log.info("Avg. Loss = {}", lossSum / lossCount);
lossCount = 0;
lossSum = 0;
}
Vector result = classifyNoLink(instance);
double myScore = result.get(actual);
// Find the highest-scoring label that is not the actual label.
int otherIdx = result.maxValueIndex();
double otherValue = result.get(otherIdx);
if (otherIdx == actual) {
result.setQuick(otherIdx, Double.NEGATIVE_INFINITY);
otherIdx = result.maxValueIndex();
otherValue = result.get(otherIdx);
}
double loss = 1.0 - myScore + otherValue;
lossCount += 1;
if (loss >= 0) {
lossSum += loss;
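// PA-II update of Crammer et al. (2006): tau = loss / (||x||^2 + 1/(2C)),
// where learningRate plays the role of the aggressiveness parameter C.
// tau * x is then added to the correct label's weights and subtracted
// from the closest competing label's weights.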
double tau = loss / (instance.dot(instance) + 0.5 / learningRate);
Vector delta = instance.clone();
delta.assign(Functions.mult(tau));
delta.addTo(weights.getRow(actual));
delta.assign(Functions.mult(-1));
delta.addTo(weights.getRow(otherIdx));
}
}

@Override
public void train(long trackingKey, int actual, Vector instance) {
train(trackingKey, null, actual, instance);
}

@Override
public void train(int actual, Vector instance) {
train(0, null, actual, instance);
}

}
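For orientation, here is a minimal usage sketch of the learner added above. It is hypothetical and not part of this commit; the number of classes, the feature dimension, and the labels/features inputs are all assumed.

// Hypothetical sketch: train a 3-class model on pre-encoded feature vectors,
// then rank labels with classifyNoLink (preferred over classify, per the javadoc).
PassiveAggressive learner = new PassiveAggressive(3, 1000).learningRate(1.0);
for (int i = 0; i < labels.length; i++) {   // labels and features are assumed inputs
  learner.train(labels[i], features[i]);    // labels[i] in [0, 3), features[i] a Vector
}
Vector scores = learner.classifyNoLink(testInstance);
int predicted = scores.maxValueIndex();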
@@ -0,0 +1,160 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.mahout.classifier.sgd;

import com.google.common.base.CharMatcher;
import com.google.common.base.Charsets;
import com.google.common.base.Splitter;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.CharStreams;
import com.google.common.io.Resources;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

import java.io.IOException;
import java.io.InputStreamReader;
import java.util.List;
import java.util.Map;
import java.util.Random;

public abstract class OnlineBaseTest extends MahoutTestCase {

private Matrix input;

protected Matrix getInput() {
return input;
}

protected Vector readStandardData() throws IOException {
// 60 test samples. First column is constant. Second and third are normally distributed from
// either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59). The first 30 rows have a
// target variable of 0, the last 30 a target of 1. The remaining columns are random noise.
input = readCsv("sgd.csv");

// regenerate the target variable
Vector target = new DenseVector(60);
target.assign(0);
target.viewPart(30, 30).assign(1);
return target;
}

protected static void train(Matrix input, Vector target, OnlineLearner lr) {
RandomUtils.useTestSeed();
Random gen = RandomUtils.getRandom();

// train on samples in random order (but only one pass)
for (int row : permute(gen, 60)) {
lr.train((int) target.get(row), input.getRow(row));
}
lr.close();
}

protected static void test(Matrix input, Vector target, AbstractVectorClassifier lr,
double expectedMeanError, double expectedAbsoluteError) {
// now test the accuracy
Matrix tmp = lr.classify(input);
// mean(abs(tmp - target))
double meanAbsoluteError = tmp.getColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60;

// max(abs(tmp - target))
double maxAbsoluteError = tmp.getColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS);

System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
assertEquals(0, meanAbsoluteError, expectedMeanError);
assertEquals(0, maxAbsoluteError, expectedAbsoluteError);

// convenience methods should give the same results
Vector v = lr.classifyScalar(input);
assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-5);
v = lr.classifyFull(input).getColumn(1);
assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-4);
}

/**
* Permute the integers from 0 ... max-1
*
* @param gen The random number generator to use.
* @param max The number of integers to permute
* @return An array of jumbled integer values
*/
protected static int[] permute(Random gen, int max) {
int[] permutation = new int[max];
permutation[0] = 0;
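// "inside-out" Fisher-Yates shuffle: each new index either stays in place or
// displaces a randomly chosen earlier entry, giving a uniform random permutation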
for (int i = 1; i < max; i++) {
int n = gen.nextInt(i + 1);
if (n == i) {
permutation[i] = i;
} else {
permutation[i] = permutation[n];
permutation[n] = i;
}
}
return permutation;
}


/**
* Reads a file containing CSV data. This isn't implemented quite the way you might like for a
* real program, but does the job for reading test data. Most notably, it will only read numbers,
* not quoted strings.
*
* @param resourceName Where to get the data.
* @return A matrix of the results.
* @throws IOException If there is an error reading the data
*/
protected static Matrix readCsv(String resourceName) throws IOException {
Splitter onCommas = Splitter.on(",").trimResults(CharMatcher.anyOf(" \""));

Readable isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8);
List<String> data = CharStreams.readLines(isr);
String first = data.get(0);
data = data.subList(1, data.size());

List<String> values = Lists.newArrayList(onCommas.split(first));
Matrix r = new DenseMatrix(data.size(), values.size());

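// bind each header token to its column index so tests can look up columns by name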
int column = 0;
Map<String, Integer> labels = Maps.newHashMap();
for (String value : values) {
labels.put(value, column);
column++;
}
r.setColumnLabelBindings(labels);

int row = 0;
for (String line : data) {
column = 0;
values = Lists.newArrayList(onCommas.split(line));
for (String value : values) {
r.set(row, column, Double.parseDouble(value));
column++;
}
row++;
}

return r;
}
}
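A concrete test built on this base class might look roughly like the sketch below. This is illustrative only, not one of the elided files from this commit; the class and method names and the error tolerances are assumptions, and JUnit 4 is assumed via MahoutTestCase.

import org.junit.Test;

// Sketch of a concrete subclass following the conventions of OnlineBaseTest.
public class PassiveAggressiveSketchTest extends OnlineBaseTest {
  @Test
  public void testTrainAndClassify() throws IOException {
    Vector target = readStandardData();
    PassiveAggressive pa = new PassiveAggressive(2, getInput().numCols()).learningRate(1);
    train(getInput(), target, pa);
    // tolerances here are assumed, not taken from the commit
    test(getInput(), target, pa, 0.1, 0.3);
  }
}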