
Commit

Added evaluation method for better model performance
xEcho1337 committed Oct 20, 2024
1 parent ef47e77 commit e7b7e1a
Showing 5 changed files with 34 additions and 15 deletions.
README.md: 4 additions & 1 deletion
@@ -44,11 +44,14 @@ DataSet training = new DataSet(first, second, third, fourth);

We have everything set up, so we can call the fit method inside a loop and wait for the network to finish.

+Tip: When training, always split the data in two: one set for training and one for testing/evaluation.
+
```java
double error;

do {
-    error = network.fit(training);
+    network.fit(training);
+    error = network.evaluate(training);
} while (error > 0.01);
```
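
A minimal sketch applying the tip above, evaluating on a held-out set instead of the training data. The extra rows (`fifth`, `sixth`) and the `testing` set are hypothetical; only `fit`, `evaluate`, and the `DataSet` row constructor come from this commit:

```java
// Hypothetical hold-out split: rows not used for training.
DataSet testing = new DataSet(fifth, sixth);

double error;

do {
    network.fit(training);             // one epoch on the training rows
    error = network.evaluate(testing); // loss measured on unseen rows
} while (error > 0.01);
```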

build.gradle: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ plugins {
}

group = 'net.echo'
-version = '2.0.12'
+version = '2.0.2'

repositories {
mavenCentral()
src/main/java/net/echo/brain4j/model/Model.java: 9 additions & 2 deletions

@@ -27,9 +27,16 @@ public interface Model {
     * Trains the model for one epoch.
     *
     * @param set dataset for training
-    * @return model error
     */
-    double fit(DataSet set);
+    void fit(DataSet set);
+
+    /**
+     * Evaluates the model on the given dataset.
+     *
+     * @param set dataset for testing
+     * @return the error of the model
+     */
+    double evaluate(DataSet set);

    /**
     * Predicts output for given input.
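For existing callers this is a breaking change: `fit` no longer returns the error, so the old one-liner splits into two calls. A hypothetical compatibility helper (not part of the commit), assuming only the interface shown above:

```java
import net.echo.brain4j.model.Model;
import net.echo.brain4j.training.data.DataSet;

// Hypothetical helper restoring a single-call feel on the new API.
public final class FitCompat {

    /**
     * Trains for one epoch, then reports the error the way the old
     * double fit(DataSet) did. Note: evaluate sums the per-row loss
     * rather than averaging it (see the implementation below).
     */
    public static double fitAndReport(Model model, DataSet set) {
        model.fit(set);             // trains for one epoch, returns nothing
        return model.evaluate(set); // loss is now a separate query
    }
}
```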
src/main/java/net/echo/brain4j/model/impl/FeedForwardModel.java: 19 additions & 2 deletions

@@ -16,6 +16,7 @@
import net.echo.brain4j.structure.Neuron;
import net.echo.brain4j.structure.Synapse;
import net.echo.brain4j.training.BackPropagation;
+import net.echo.brain4j.training.data.DataRow;
import net.echo.brain4j.training.data.DataSet;
import net.echo.brain4j.training.optimizers.Optimizer;
import net.echo.brain4j.training.optimizers.impl.Adam;
@@ -84,8 +85,24 @@ public void compile(InitializationType type, LossFunctions function, Optimizer o
    }

    @Override
-    public double fit(DataSet set) {
-        return propagation.iterate(set, optimizer.getLearningRate());
+    public void fit(DataSet set) {
+        propagation.iterate(set, optimizer.getLearningRate());
    }
+
+    @Override
+    public double evaluate(DataSet set) {
+        double totalError = 0.0;
+
+        for (DataRow row : set.getDataRows()) {
+            double[] inputs = row.inputs();
+            double[] targets = row.outputs();
+
+            double[] outputs = predict(inputs);
+
+            totalError += function.getFunction().calculate(targets, outputs);
+        }
+
+        return totalError;
+    }

    @Override
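A behavioral note on this implementation: the removed code in `BackPropagation` (below) averaged the loss over the dataset, while `evaluate` returns the plain sum, so a threshold like `error > 0.01` now scales with dataset size. A mean-based variant is a one-line change; a hypothetical sketch for `FeedForwardModel`, reusing only members shown in this diff:

```java
// Hypothetical mean-based variant of evaluate(); dividing by the row
// count keeps the threshold comparable across dataset sizes.
public double evaluateMean(DataSet set) {
    double totalError = 0.0;

    for (DataRow row : set.getDataRows()) {
        totalError += function.getFunction().calculate(row.outputs(), predict(row.inputs()));
    }

    return totalError / set.getDataRows().size();
}
```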
src/main/java/net/echo/brain4j/training/BackPropagation.java: 1 addition & 9 deletions

@@ -15,32 +15,24 @@
public class BackPropagation {

    private final Model model;
-    private final LossFunction lossFunction;
    private final Optimizer optimizer;

    private int timestep = 0;

    public BackPropagation(Model model, Optimizer optimizer) {
        this.model = model;
-        this.lossFunction = model.getLossFunction();
        this.optimizer = optimizer;
    }

-    public double iterate(DataSet dataSet, double learningRate) {
-        double totalError = 0.0;
-
+    public void iterate(DataSet dataSet, double learningRate) {
        for (DataRow row : dataSet.getDataRows()) {
            double[] inputs = row.inputs();
            double[] targets = row.outputs();

            double[] outputs = model.predict(inputs);

-            totalError += lossFunction.calculate(targets, outputs);
-
            backpropagate(targets, outputs, learningRate);
        }
-
-        return totalError / dataSet.getDataRows().size();
    }

    private void backpropagate(double[] targets, double[] outputs, double learningRate) {
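Since `iterate` no longer tracks loss, calling `evaluate` after every epoch costs one extra forward pass per row. Where that matters, the loss can be sampled every few epochs instead; a hypothetical helper, assuming only the API introduced in this commit:

```java
import net.echo.brain4j.model.Model;
import net.echo.brain4j.training.data.DataSet;

// Hypothetical training loop that logs the test error only every
// `logInterval` epochs, limiting the extra forward passes.
public final class TrainingLoop {

    public static void train(Model model, DataSet training, DataSet testing,
                             int epochs, int logInterval) {
        for (int epoch = 1; epoch <= epochs; epoch++) {
            model.fit(training); // one epoch of backpropagation

            if (epoch % logInterval == 0) {
                System.out.printf("epoch %d: test error = %.6f%n",
                        epoch, model.evaluate(testing));
            }
        }
    }
}
```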
