diff --git a/core/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java b/core/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
new file mode 100644
index 000000000..1946398c5
--- /dev/null
+++ b/core/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
@@ -0,0 +1,429 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Random;
+
+/**
+ * Online gradient machine learner that tries to minimize the label ranking hinge loss.
+ * Implements a gradient machine with one sigmoid hidden layer.
+ * It tries to minimize the ranking loss of some given set of labels,
+ * so this can be used for multi-class, multi-label
+ * or auto-encoding of sparse data (e.g. text).
+ */
+public class GradientMachine extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+  public static final int WRITABLE_VERSION = 1;
+
+  // the learning rate of the algorithm
+  private double learningRate = 0.1;
+
+  // the regularization term, a positive number that controls the size of the weight vector
+  private double regularization = 0.1;
+
+  // the sparsity term, a positive number that controls the sparsity of the hidden layer. (0 - 1)
+  private double sparsity = 0.1;
+
+  // the sparsity learning rate.
+  private double sparsityLearningRate = 0.1;
+
+  // the number of features
+  private int numFeatures = 10;
+  // the number of hidden nodes
+  private int numHidden = 100;
+  // the number of output nodes
+  private int numOutput = 2;
+
+  // coefficients for the input to hidden layer.
+  // There are numHidden Vectors of dimension numFeatures.
+  private Vector[] hiddenWeights;
+
+  // coefficients for the hidden to output layer.
+  // There are numOutput Vectors of dimension numHidden.
+  private Vector[] outputWeights;
+
+  // hidden unit bias
+  private Vector hiddenBias;
+
+  // output unit bias
+  private Vector outputBias;
+
+  private final Random rnd;
+
+  /**
+   * Creates a gradient machine whose weights and biases are all zero.
+   * Call {@link #initWeights(Random)} to randomize the weights before training.
+   *
+   * @param numFeatures dimension of the input vectors.
+   * @param numHidden number of hidden units.
+   * @param numOutput number of output units.
+   */
+  public GradientMachine(int numFeatures, int numHidden, int numOutput) {
+    this.numFeatures = numFeatures;
+    this.numHidden = numHidden;
+    this.numOutput = numOutput;
+    hiddenWeights = new Vector[numHidden];
+    for (int i = 0; i < numHidden; i++) {
+      hiddenWeights[i] = new DenseVector(numFeatures);
+      hiddenWeights[i].assign(0);
+    }
+    hiddenBias = new DenseVector(numHidden);
+    hiddenBias.assign(0);
+    outputWeights = new Vector[numOutput];
+    for (int i = 0; i < numOutput; i++) {
+      outputWeights[i] = new DenseVector(numHidden);
+      outputWeights[i].assign(0);
+    }
+    outputBias = new DenseVector(numOutput);
+    outputBias.assign(0);
+    rnd = RandomUtils.getRandom();
+  }
+
+  /**
+   * Initialize weights.
+   *
+   * @param gen random number generator.
+   */
+  public void initWeights(Random gen) {
+    double hiddenFanIn = 1.0 / Math.sqrt(numFeatures);
+    for (int i = 0; i < numHidden; i++) {
+      for (int j = 0; j < numFeatures; j++) {
+        double val = (2.0 * gen.nextDouble() - 1.0) * hiddenFanIn;
+        hiddenWeights[i].setQuick(j, val);
+      }
+    }
+    double outputFanIn = 1.0 / Math.sqrt(numHidden);
+    for (int i = 0; i < numOutput; i++) {
+      for (int j = 0; j < numHidden; j++) {
+        double val = (2.0 * gen.nextDouble() - 1.0) * outputFanIn;
+        outputWeights[i].setQuick(j, val);
+      }
+    }
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of initial learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine learningRate(double learningRate) {
+    this.learningRate = learningRate;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param regularization A positive value that controls the weight vector size.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine regularization(double regularization) {
+    this.regularization = regularization;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param sparsity A value between zero and one that controls the fraction of hidden units
+   *                 that are activated on average.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine sparsity(double sparsity) {
+    this.sparsity = sparsity;
+    return this;
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param sparsityLearningRate New value of initial learning rate for sparsity.
+   * @return This, so other configurations can be chained.
+   */
+  public GradientMachine sparsityLearningRate(double sparsityLearningRate) {
+    this.sparsityLearningRate = sparsityLearningRate;
+    return this;
+  }
+
+  /**
+   * Replaces this machine's configuration, weights and biases with clones of another's.
+   *
+   * @param other the machine to copy from.
+   */
+  public void copyFrom(GradientMachine other) {
+    numFeatures = other.numFeatures;
+    numHidden = other.numHidden;
+    numOutput = other.numOutput;
+    learningRate = other.learningRate;
+    regularization = other.regularization;
+    sparsity = other.sparsity;
+    sparsityLearningRate = other.sparsityLearningRate;
+    hiddenWeights = new Vector[numHidden];
+    for (int i = 0; i < numHidden; i++) {
+      hiddenWeights[i] = other.hiddenWeights[i].clone();
+    }
+    hiddenBias = other.hiddenBias.clone();
+    outputWeights = new Vector[numOutput];
+    for (int i = 0; i < numOutput; i++) {
+      outputWeights[i] = other.outputWeights[i].clone();
+    }
+    outputBias = other.outputBias.clone();
+  }
+
+  @Override
+  public int numCategories() {
+    return numOutput;
+  }
+
+  public int numFeatures() {
+    return numFeatures;
+  }
+
+  public int numHidden() {
+    return numHidden;
+  }
+
+  /**
+   * Feeds forward from input to hidden unit.
+   *
+   * @param input the input feature vector.
+   * @return Hidden unit activations.
+   */
+  public DenseVector inputToHidden(Vector input) {
+    DenseVector activations = new DenseVector(numHidden);
+    for (int i = 0; i < numHidden; i++) {
+      activations.setQuick(i, hiddenWeights[i].dot(input));
+    }
+    hiddenBias.addTo(activations);
+    // Clamp the pre-activations to [-40, 40] before applying the sigmoid.
+    activations.assign(Functions.min(40.0)).assign(Functions.max(-40));
+    activations.assign(Functions.SIGMOID);
+    return activations;
+  }
+
+  /**
+   * Feeds forward from hidden to output.
+   *
+   * @param hiddenActivation the hidden unit activations.
+   * @return Output unit activations.
+   */
+  public DenseVector hiddenToOutput(Vector hiddenActivation) {
+    DenseVector activations = new DenseVector(numOutput);
+    for (int i = 0; i < numOutput; i++) {
+      activations.setQuick(i, outputWeights[i].dot(hiddenActivation));
+    }
+    outputBias.addTo(activations);
+    return activations;
+  }
+
+  /**
+   * Updates using ranking loss.
+   *
+   * @param hiddenActivation the hidden unit's activation
+   * @param goodLabels the labels you want ranked above others.
+   * @param numTrials how many times you want to search for the highest scoring bad label.
+   * @param gen Random number generator.
+   */
+  public void updateRanking(Vector hiddenActivation,
+                            Collection<Integer> goodLabels,
+                            int numTrials,
+                            Random gen) {
+    // All the labels are good, do nothing.
+    if (goodLabels.size() >= numOutput) {
+      return;
+    }
+    for (Integer good : goodLabels) {
+      double goodScore = outputWeights[good].dot(hiddenActivation);
+      int highestBad = -1;
+      double highestBadScore = Double.NEGATIVE_INFINITY;
+      for (int i = 0; i < numTrials; i++) {
+        int bad = gen.nextInt(numOutput);
+        while (goodLabels.contains(bad)) {
+          bad = gen.nextInt(numOutput);
+        }
+        double badScore = outputWeights[bad].dot(hiddenActivation);
+        if (badScore > highestBadScore) {
+          highestBadScore = badScore;
+          highestBad = bad;
+        }
+      }
+      int bad = highestBad;
+      double loss = 1.0 - goodScore + highestBadScore;
+      if (loss < 0.0) {
+        continue;
+      }
+      // Note from the loss above the gradient dloss/dy, y being the label, is -1 for good
+      // and +1 for bad.
+      // dy / dw is just w since y = x' * w + b.
+      // Hence by the chain rule, dloss / dw = dloss / dy * dy / dw = -w.
+      // For the regularization part, 0.5 * lambda * w' w, the gradient is lambda * w.
+      // dy / db = 1.
+      Vector gradGood = outputWeights[good].clone();
+      gradGood.assign(Functions.NEGATE);
+      Vector propHidden = gradGood.clone();
+      Vector gradBad = outputWeights[bad].clone();
+      gradBad.addTo(propHidden);
+      gradGood.assign(Functions.mult(-learningRate * (1.0 - regularization)));
+      gradGood.addTo(outputWeights[good]);
+      gradBad.assign(Functions.mult(-learningRate * (1.0 + regularization)));
+      gradBad.addTo(outputWeights[bad]);
+      outputBias.setQuick(good, outputBias.get(good) + learningRate);
+      outputBias.setQuick(bad, outputBias.get(bad) - learningRate);
+      // Gradient of sigmoid is s * (1 -s).
+      Vector gradSig = hiddenActivation.clone();
+      gradSig.assign(Functions.SIGMOIDGRADIENT);
+      // Multiply by the change caused by the ranking loss.
+      for (int i = 0; i < numHidden; i++) {
+        gradSig.setQuick(i, gradSig.get(i) * propHidden.get(i));
+      }
+      // NOTE(review): this step uses gradSig alone, without the input value, and hiddenBias
+      // is left untouched; confirm this is the intended update before changing it.
+      for (int i = 0; i < numHidden; i++) {
+        for (int j = 0; j < numFeatures; j++) {
+          double v = hiddenWeights[i].get(j);
+          v -= learningRate * (gradSig.get(i) + regularization * v);
+          hiddenWeights[i].setQuick(j, v);
+        }
+      }
+    }
+  }
+
+  /**
+   * Returns a one-hot encoding of the highest scoring category with entry 0 dropped.
+   */
+  @Override
+  public Vector classify(Vector instance) {
+    Vector result = classifyNoLink(instance);
+    // Find the max value's index.
+    int max = result.maxValueIndex();
+    result.assign(0);
+    result.setQuick(max, 1.0);
+    return result.viewPart(1, result.size() - 1);
+  }
+
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    DenseVector hidden = inputToHidden(instance);
+    return hiddenToOutput(hidden);
+  }
+
+  @Override
+  public double classifyScalar(Vector instance) {
+    Vector output = classifyNoLink(instance);
+    if (output.get(0) > output.get(1)) {
+      return 0;
+    }
+    return 1;
+  }
+
+  /**
+   * Returns a new machine holding clones of this machine's configuration and weights.
+   */
+  public GradientMachine copy() {
+    close();
+    GradientMachine r = new GradientMachine(numFeatures(), numHidden(), numCategories());
+    r.copyFrom(this);
+    return r;
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(WRITABLE_VERSION);
+    out.writeDouble(learningRate);
+    out.writeDouble(regularization);
+    out.writeDouble(sparsity);
+    out.writeDouble(sparsityLearningRate);
+    out.writeInt(numFeatures);
+    out.writeInt(numHidden);
+    out.writeInt(numOutput);
+    VectorWritable.writeVector(out, hiddenBias);
+    for (int i = 0; i < numHidden; i++) {
+      VectorWritable.writeVector(out, hiddenWeights[i]);
+    }
+    VectorWritable.writeVector(out, outputBias);
+    for (int i = 0; i < numOutput; i++) {
+      VectorWritable.writeVector(out, outputWeights[i]);
+    }
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int version = in.readInt();
+    if (version == WRITABLE_VERSION) {
+      learningRate = in.readDouble();
+      regularization = in.readDouble();
+      sparsity = in.readDouble();
+      sparsityLearningRate = in.readDouble();
+      numFeatures = in.readInt();
+      numHidden = in.readInt();
+      numOutput = in.readInt();
+      // Vector[] (not DenseVector[]) because readVector only promises a Vector.
+      hiddenWeights = new Vector[numHidden];
+      hiddenBias = VectorWritable.readVector(in);
+      for (int i = 0; i < numHidden; i++) {
+        hiddenWeights[i] = VectorWritable.readVector(in);
+      }
+      outputWeights = new Vector[numOutput];
+      outputBias = VectorWritable.readVector(in);
+      for (int i = 0; i < numOutput; i++) {
+        outputWeights[i] = VectorWritable.readVector(in);
+      }
+    } else {
+      throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
+    }
+  }
+
+  @Override
+  public void close() {
+    // This is an online classifier, nothing to do.
+  }
+
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    Vector hiddenActivation = inputToHidden(instance);
+    Collection<Integer> goodLabels = new HashSet<Integer>();
+    goodLabels.add(actual);
+    updateRanking(hiddenActivation, goodLabels, 2, rnd);
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
+}
diff --git a/core/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java b/core/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
new file mode 100644
index 000000000..2d632540e
--- /dev/null
+++ b/core/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Random;
+
+public final class GradientMachineTest extends OnlineBaseTest {
+
+  @Test
+  public void testGradientmachine() throws IOException {
+    Vector target = readStandardData();
+    RandomUtils.useTestSeed(); // seed first: the constructor draws its own Random
+    GradientMachine grad = new GradientMachine(8, 4, 2).learningRate(0.1).regularization(0.01);
+    Random gen = RandomUtils.getRandom();
+    grad.initWeights(gen);
+    train(getInput(), target, grad);
+    test(getInput(), target, grad, 0.05, 1);
+  }
+
+}
diff --git a/math/src/main/java/org/apache/mahout/math/function/Functions.java b/math/src/main/java/org/apache/mahout/math/function/Functions.java
index fded29d2c..62b0b7d55 100644
--- a/math/src/main/java/org/apache/mahout/math/function/Functions.java
+++ b/math/src/main/java/org/apache/mahout/math/function/Functions.java
@@ -273,6 +273,22 @@ public double apply(double a) {
     }
   };
 
+  /** Function that returns 1 / (1 + exp(-a)) */
+  public static final DoubleFunction SIGMOID = new DoubleFunction() {
+    @Override
+    public double apply(double a) {
+      return 1.0 / (1.0 + Math.exp(-a));
+    }
+  };
+
+  /** Function that returns a * (1-a), the sigmoid's derivative written in terms of its output a. */
+  public static final DoubleFunction SIGMOIDGRADIENT = new DoubleFunction() {
+    @Override
+    public double apply(double a) {
+      return a * (1.0 - a);
+    }
+  };
+
   /** Function that returns Math.tan(a). */
   public static final DoubleFunction TAN = new DoubleFunction() {