From 16836adb421f1953d5e53294a858d6987a1e7a67 Mon Sep 17 00:00:00 2001
From: Jeff Heaton
Date: Sun, 3 Sep 2017 13:33:57 -0500
Subject: [PATCH] Issue #230 make TrainEA.getMethod() return the best member
 of population.

---
 src/main/java/org/encog/Test.java             | 97 +++++++++++--------
 .../org/encog/ml/ea/train/basic/TrainEA.java  |  7 +-
 .../training/propagation/GradientWorker.java  |  4 +-
 3 files changed, 66 insertions(+), 42 deletions(-)

diff --git a/src/main/java/org/encog/Test.java b/src/main/java/org/encog/Test.java
index e08d800a2..c7902e481 100644
--- a/src/main/java/org/encog/Test.java
+++ b/src/main/java/org/encog/Test.java
@@ -1,58 +1,77 @@
-/*
- * Encog(tm) Core v3.4 - Java Version
- * http://www.heatonresearch.com/encog/
- * https://github.com/encog/encog-java-core
-
- * Copyright 2008-2017 Heaton Research, Inc.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- * For more information on Heaton Research copyrights, licenses
- * and trademarks visit:
- * http://www.heatonresearch.com/copyright
- */
 package org.encog;
-
-import org.encog.Encog;
+import org.encog.engine.network.activation.ActivationLinear;
+import org.encog.engine.network.activation.ActivationReLU;
 import org.encog.engine.network.activation.ActivationSigmoid;
 import org.encog.ml.data.MLData;
 import org.encog.ml.data.MLDataPair;
 import org.encog.ml.data.MLDataSet;
 import org.encog.ml.data.basic.BasicMLDataSet;
-import org.encog.ml.importance.PerturbationFeatureImportanceCalc;
 import org.encog.neural.networks.BasicNetwork;
 import org.encog.neural.networks.layers.BasicLayer;
 import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
-import org.encog.neural.networks.training.propagation.sgd.StochasticGradientDescent;
-import org.encog.neural.networks.training.propagation.sgd.update.AdaGradUpdate;
-import org.encog.neural.networks.training.propagation.sgd.update.NesterovUpdate;
-import org.encog.neural.networks.training.propagation.sgd.update.RMSPropUpdate;
-import org.encog.neural.pattern.ElmanPattern;
-
+/**
+ * XOR: This example is essentially the "Hello World" of neural network
+ * programming. This example shows how to construct an Encog neural
+ * network to predict the output from the XOR operator. This example
+ * uses resilient propagation (RPROP) to train the neural network.
+ *
+ * This example attempts to use a minimum of Encog features to create and
+ * train the neural network. This allows you to see exactly what is going
+ * on. For a more advanced example that uses Encog factories, refer to
+ * the XORFactory example.
+ *
+ */
 public class Test {
+	/**
+	 * The input necessary for XOR.
+	 */
+	public static double XOR_INPUT[][] = { { 0.0, 0.0 }, { 1.0, 0.0 },
+			{ 0.0, 1.0 }, { 1.0, 1.0 } };
+
+	/**
+	 * The ideal data necessary for XOR.
+	 */
+	public static double XOR_IDEAL[][] = { { 0.0 }, { 1.0 }, { 1.0 }, { 0.0 } };
+
 	/**
 	 * The main method.
 	 * @param args No arguments are used.
 	 */
 	public static void main(final String args[]) {
-		ElmanPattern elmanPat = new ElmanPattern();
-		elmanPat.setInputNeurons(5);
-		elmanPat.addHiddenLayer(5);
-		elmanPat.setOutputNeurons(1);
-		BasicNetwork network = (BasicNetwork) elmanPat.generate();
-		System.out.println(network.toString());
+		// create a neural network, without using a factory
+		BasicNetwork network = new BasicNetwork();
+		network.addLayer(new BasicLayer(new ActivationReLU(),true,2));
+		network.addLayer(new BasicLayer(new ActivationSigmoid(),true,3));
+		network.addLayer(new BasicLayer(new ActivationLinear(),false,1));
+		network.getStructure().finalizeStructure();
+		network.reset();
+
+		// create training data
+		MLDataSet trainingSet = new BasicMLDataSet(XOR_INPUT, XOR_IDEAL);
+
+		// train the neural network
+		final ResilientPropagation train = new ResilientPropagation(network, trainingSet);
+
+		int epoch = 1;
+
+		do {
+			train.iteration();
+			System.out.println("Epoch #" + epoch + " Error:" + train.getError());
+			epoch++;
+		} while(train.getError() > 0.01);
+		train.finishTraining();
+
+		// test the neural network
+		System.out.println("Neural Network Results:");
+		for(MLDataPair pair: trainingSet ) {
+			final MLData output = network.compute(pair.getInput());
+			System.out.println(pair.getInput().getData(0) + "," + pair.getInput().getData(1)
+					+ ", actual=" + output.getData(0) + ",ideal=" + pair.getIdeal().getData(0));
+		}
+
+		Encog.getInstance().shutdown();
 	}
-}
+}
\ No newline at end of file
diff --git a/src/main/java/org/encog/ml/ea/train/basic/TrainEA.java b/src/main/java/org/encog/ml/ea/train/basic/TrainEA.java
index 2230d8de7..c1e49b970 100644
--- a/src/main/java/org/encog/ml/ea/train/basic/TrainEA.java
+++ b/src/main/java/org/encog/ml/ea/train/basic/TrainEA.java
@@ -30,6 +30,7 @@
 import org.encog.ml.MLMethod;
 import org.encog.ml.TrainingImplementationType;
 import org.encog.ml.data.MLDataSet;
+import org.encog.ml.ea.genome.Genome;
 import org.encog.ml.ea.population.Population;
 import org.encog.ml.train.MLTrain;
 import org.encog.ml.train.strategy.Strategy;
@@ -170,7 +171,11 @@ public void finishTraining() {
 	 */
 	@Override
 	public MLMethod getMethod() {
-		return this.getPopulation();
+		Genome g = this.getPopulation().getBestGenome();
+		if(g==null || getCODEC()==null) {
+			return null;
+		}
+		return getCODEC().decode(g);
 	}
 
 	/**
diff --git a/src/main/java/org/encog/neural/networks/training/propagation/GradientWorker.java b/src/main/java/org/encog/neural/networks/training/propagation/GradientWorker.java
index 69c5364c4..49894c3c3 100644
--- a/src/main/java/org/encog/neural/networks/training/propagation/GradientWorker.java
+++ b/src/main/java/org/encog/neural/networks/training/propagation/GradientWorker.java
@@ -216,7 +216,7 @@ public void process(final MLDataPair pair) {
 		// Calculate error for the output layer.
 		this.errorFunction.calculateError(
 				this.network.getActivationFunctions()[0], this.layerSums,this.layerOutput,
-				pair.getIdeal().getData(), this.actual, this.layerDelta, this.flatSpot[0],
+				pair.getIdeal().getData(), this.actual, this.layerDelta, this.flatSpot[0], 
 				pair.getSignificance());
 
 		// Apply regularization, if requested.
@@ -255,7 +255,7 @@ private void processLevel(final int currentLevel) {
 		final int index = this.weightIndex[currentLevel];
 
 		final ActivationFunction activation = this.network
-				.getActivationFunctions()[currentLevel];
+				.getActivationFunctions()[currentLevel + 1];
 		final double currentFlatSpot = this.flatSpot[currentLevel + 1];
 
 		// handle weights
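For reference, the caller-visible effect of the TrainEA change: before this patch, getMethod() returned the whole Population, so callers who wanted the evolved model had to decode the best genome by hand via getCODEC().decode(getBestGenome()); after it, getMethod() returns the decoded best member directly. Below is a minimal sketch of the new call pattern using Encog's NEAT trainer, since NEATUtil.constructNEATTrainer returns a TrainEA. The class name XorNeatSketch, the population size, and the stopping criteria are illustrative assumptions, not part of the patch; NEATPopulation, NEATUtil, and TrainingSetScore are existing Encog classes.

```java
import org.encog.Encog;
import org.encog.ml.CalculateScore;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.ea.train.basic.TrainEA;
import org.encog.neural.neat.NEATNetwork;
import org.encog.neural.neat.NEATPopulation;
import org.encog.neural.neat.NEATUtil;
import org.encog.neural.networks.training.TrainingSetScore;

public class XorNeatSketch {

	public static void main(String[] args) {
		// Same XOR data as Test.java above.
		double[][] input = { { 0.0, 0.0 }, { 1.0, 0.0 }, { 0.0, 1.0 }, { 1.0, 1.0 } };
		double[][] ideal = { { 0.0 }, { 1.0 }, { 1.0 }, { 0.0 } };
		MLDataSet trainingSet = new BasicMLDataSet(input, ideal);

		// Build a NEAT trainer; the returned trainer is a TrainEA,
		// the class changed by this patch.
		NEATPopulation pop = new NEATPopulation(2, 1, 1000);
		pop.setInitialConnectionDensity(1.0);
		pop.reset();
		CalculateScore score = new TrainingSetScore(trainingSet);
		TrainEA train = NEATUtil.constructNEATTrainer(pop, score);

		// Evolve until the error target is met (epoch cap is an assumption
		// added so the sketch always terminates).
		int epoch = 1;
		do {
			train.iteration();
			epoch++;
		} while (train.getError() > 0.01 && epoch <= 500);

		// Old behavior: getMethod() returned the Population, so the cast below
		// would have failed and callers decoded the winner by hand:
		//   (NEATNetwork) train.getCODEC().decode(train.getBestGenome());
		// New behavior: getMethod() returns the decoded best genome directly.
		NEATNetwork network = (NEATNetwork) train.getMethod();
		System.out.println("Best member after " + (epoch - 1) + " epochs: " + network);

		Encog.getInstance().shutdown();
	}
}
```

One consequence of the new null guard: getMethod() may now return null when called before the first iteration() (no best genome has been scored yet) or when no CODEC is set, whereas it previously always returned the population object, so callers that poll getMethod() early should account for a possible null.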