Issue #230 make TrainEA.getMethod() return the best member of population.
jeffheaton committed Sep 3, 2017
1 parent 0266b18 commit 16836ad
Showing 3 changed files with 66 additions and 42 deletions.
97 changes: 58 additions & 39 deletions src/main/java/org/encog/Test.java
@@ -1,58 +1,77 @@
/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2017 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog;

import org.encog.Encog;
import org.encog.engine.network.activation.ActivationLinear;
import org.encog.engine.network.activation.ActivationReLU;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.importance.PerturbationFeatureImportanceCalc;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;
import org.encog.neural.networks.training.propagation.sgd.StochasticGradientDescent;
import org.encog.neural.networks.training.propagation.sgd.update.AdaGradUpdate;
import org.encog.neural.networks.training.propagation.sgd.update.NesterovUpdate;
import org.encog.neural.networks.training.propagation.sgd.update.RMSPropUpdate;
import org.encog.neural.pattern.ElmanPattern;


/**
* XOR: This example is essentially the "Hello World" of neural network
* programming. This example shows how to construct an Encog neural
* network to predict the output from the XOR operator. This example
* uses backpropagation to train the neural network.
*
* This example attempts to use a minimum of Encog features to create and
* train the neural network. This allows you to see exactly what is going
* on. For a more advanced example, that uses Encog factories, refer to
* the XORFactory example.
*
*/
public class Test {

/**
* The input necessary for XOR.
*/
public static double XOR_INPUT[][] = { { 0.0, 0.0 }, { 1.0, 0.0 },
{ 0.0, 1.0 }, { 1.0, 1.0 } };

/**
* The ideal data necessary for XOR.
*/
public static double XOR_IDEAL[][] = { { 0.0 }, { 1.0 }, { 1.0 }, { 0.0 } };

/**
* The main method.
* @param args No arguments are used.
*/
public static void main(final String args[]) {

-ElmanPattern elmanPat = new ElmanPattern();
-elmanPat.setInputNeurons(5);
-elmanPat.addHiddenLayer(5);
-elmanPat.setOutputNeurons(1);
-BasicNetwork network = (BasicNetwork) elmanPat.generate();
-System.out.println(network.toString());
+// create a neural network, without using a factory
+BasicNetwork network = new BasicNetwork();
+network.addLayer(new BasicLayer(new ActivationReLU(),true,2));
+network.addLayer(new BasicLayer(new ActivationSigmoid(),true,3));
+network.addLayer(new BasicLayer(new ActivationLinear(),false,1));
+network.getStructure().finalizeStructure();
+network.reset();

// create training data
MLDataSet trainingSet = new BasicMLDataSet(XOR_INPUT, XOR_IDEAL);

// train the neural network
final ResilientPropagation train = new ResilientPropagation(network, trainingSet);

int epoch = 1;

do {
train.iteration();
System.out.println("Epoch #" + epoch + " Error:" + train.getError());
epoch++;
} while(train.getError() > 0.01);
train.finishTraining();

// test the neural network
System.out.println("Neural Network Results:");
for(MLDataPair pair: trainingSet ) {
final MLData output = network.compute(pair.getInput());
System.out.println(pair.getInput().getData(0) + "," + pair.getInput().getData(1)
+ ", actual=" + output.getData(0) + ",ideal=" + pair.getIdeal().getData(0));
}

Encog.getInstance().shutdown();
}
}
7 changes: 6 additions & 1 deletion src/main/java/org/encog/ml/ea/train/basic/TrainEA.java
@@ -30,6 +30,7 @@
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
+import org.encog.ml.ea.genome.Genome;
import org.encog.ml.ea.population.Population;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
@@ -170,7 +171,11 @@ public void finishTraining() {
*/
@Override
public MLMethod getMethod() {
-return this.getPopulation();
+Genome g = this.getPopulation().getBestGenome();
+if(g==null || getCODEC()==null) {
+    return null;
+}
+return getCODEC().decode(g);
}

/**
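For context, a minimal usage sketch of what this change means for callers. The class below is hypothetical and assumes Encog's standard NEAT utilities (NEATPopulation, NEATUtil.constructNEATTrainer, TrainingSetScore) wired to the XOR data from Test.java above; only the last statements relate to this commit.

package org.encog;

import org.encog.ml.CalculateScore;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.ea.train.basic.TrainEA;
import org.encog.neural.neat.NEATNetwork;
import org.encog.neural.neat.NEATPopulation;
import org.encog.neural.neat.NEATUtil;
import org.encog.neural.networks.training.TrainingSetScore;

public class GetMethodDemo {

    public static void main(final String args[]) {
        // Reuse the XOR truth table from Test.java.
        MLDataSet trainingSet = new BasicMLDataSet(Test.XOR_INPUT, Test.XOR_IDEAL);

        // Evolve a population of NEAT networks against the training data.
        NEATPopulation pop = new NEATPopulation(2, 1, 1000);
        pop.reset();
        CalculateScore score = new TrainingSetScore(trainingSet);
        TrainEA train = NEATUtil.constructNEATTrainer(pop, score);

        do {
            train.iteration();
        } while (train.getError() > 0.01);
        train.finishTraining();

        // Before this commit, getMethod() returned the Population itself, so a
        // usable network had to be decoded by hand:
        //   NEATNetwork network = (NEATNetwork) train.getCODEC().decode(pop.getBestGenome());
        // After this commit, getMethod() performs that decode (or returns null
        // if no best genome has been established yet):
        NEATNetwork network = (NEATNetwork) train.getMethod();
        System.out.println("Best evolved network: " + network);
    }
}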
4 changes: 2 additions & 2 deletions (file path not shown)
@@ -216,7 +216,7 @@ public void process(final MLDataPair pair) {
// Calculate error for the output layer.
this.errorFunction.calculateError(
this.network.getActivationFunctions()[0], this.layerSums,this.layerOutput,
-pair.getIdeal().getData(), this.actual, this.layerDelta, this.flatSpot[0],
+pair.getIdeal().getData(), this.actual, this.layerDelta, this.flatSpot[0],
pair.getSignificance());

// Apply regularization, if requested.
@@ -255,7 +255,7 @@ private void processLevel(final int currentLevel) {

final int index = this.weightIndex[currentLevel];
final ActivationFunction activation = this.network
-.getActivationFunctions()[currentLevel];
+.getActivationFunctions()[currentLevel + 1];
final double currentFlatSpot = this.flatSpot[currentLevel + 1];

// handle weights
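The second hunk reads as an off-by-one fix. The first hunk shows that layer indices run output-first (the output layer's error is computed with getActivationFunctions()[0]), and the flat-spot constant on the following line already uses currentLevel + 1; the deltas produced while processing level currentLevel belong to that same layer, so its activation function should be read from the same index. A standalone sketch of that indexing convention (the three-layer layout is an assumption for illustration, not part of the patch):

public class LayerIndexSketch {

    public static void main(final String args[]) {
        // Hypothetical three-layer network stored output-first, matching the
        // convention above where index 0 is the output layer.
        String[] activationPerLayer = { "linear (output)", "sigmoid (hidden)", "relu (input)" };

        int currentLevel = 0; // processing the weights between layers 0 and 1
        // The deltas computed at this level are for layer currentLevel + 1,
        // so its activation derivative and flat-spot constant share that index.
        String before = activationPerLayer[currentLevel];     // old code: output layer
        String after = activationPerLayer[currentLevel + 1];  // fixed code: hidden layer
        System.out.println("derivative taken from: " + after + " (was: " + before + ")");
    }
}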
