Async and parallel weights and biases changes
xEcho1337 committed Oct 22, 2024
1 parent b420cfb commit 6626b97
Showing 3 changed files with 18 additions and 16 deletions.

src/main/java/net/echo/brain4j/structure/Neuron.java (3 changes: 1 addition & 2 deletions)

@@ -7,8 +7,7 @@ public class Neuron {
     private Synapse synapse;
     private double delta;
     private double value;
-    @Expose
-    private double bias = 2 * Math.random() - 1;
+    @Expose private double bias = 2 * Math.random() - 1;
 
     public Synapse getSynapse() {
         return synapse;
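
If the @Expose annotation here is Gson's com.google.gson.annotations.Expose (an assumption; the import is outside the hunk), collapsing it onto the field line is purely cosmetic, and the annotation only takes effect when the serializer is built with excludeFieldsWithoutExposeAnnotation(). A minimal standalone sketch with a simplified stand-in Neuron:

```java
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.annotations.Expose;

public class ExposeSketch {

    // Simplified stand-in for the Neuron class above, not the brain4j original.
    static class Neuron {
        @Expose private double bias = 2 * Math.random() - 1; // serialized
        private double delta;                                // skipped
        private double value;                                // skipped
    }

    public static void main(String[] args) {
        Gson gson = new GsonBuilder()
                .excludeFieldsWithoutExposeAnnotation()
                .create();
        // Only the @Expose-annotated field appears, e.g. {"bias":0.4217}
        System.out.println(gson.toJson(new Neuron()));
    }
}
```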

src/main/java/net/echo/brain4j/training/BackPropagation.java (10 changes: 4 additions & 6 deletions)

@@ -2,15 +2,14 @@
 
 import net.echo.brain4j.layer.Layer;
 import net.echo.brain4j.layer.impl.DropoutLayer;
-import net.echo.brain4j.loss.LossFunction;
 import net.echo.brain4j.model.Model;
 import net.echo.brain4j.structure.Neuron;
 import net.echo.brain4j.structure.Synapse;
 import net.echo.brain4j.training.data.DataRow;
 import net.echo.brain4j.training.data.DataSet;
 import net.echo.brain4j.training.optimizers.Optimizer;
 
-import java.util.List;
+import java.util.*;
 
 public class BackPropagation {

Expand Down Expand Up @@ -69,6 +68,7 @@ private void initialDelta(List<Layer> layers, double[] targets, double[] outputs

for (int i = 0; i < outputLayer.getNeurons().size(); i++) {
Neuron neuron = outputLayer.getNeuronAt(i);

double output = outputs[i];
double error = targets[i] - output;


@@ -80,9 +80,7 @@
     private void updateWeightsAndBiases(List<Layer> layers, double learningRate) {
         timestep++;
 
-        for (int l = 0; l < layers.size() - 1; l++) {
-            Layer nextLayer = layers.get(l + 1);
-
+        layers.parallelStream().forEach(nextLayer -> {
             for (Synapse synapse : nextLayer.getSynapses()) {
                 optimizer.update(synapse, timestep);
             }

@@ -91,6 +89,6 @@ private void updateWeightsAndBiases(List<Layer> layers, double learningRate) {
                 double deltaBias = learningRate * neuron.getDelta();
                 neuron.setBias(neuron.getBias() + deltaBias);
             }
-        }
+        });
     }
 }
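
This pair of hunks is the heart of the commit: the indexed loop over layers.get(l + 1) becomes layers.parallelStream(), so each layer's synapse and bias updates run concurrently on the common ForkJoinPool. Note that the stream visits every element of layers, including the first one, which the old loop skipped; that is only equivalent if the input layer owns no synapses and its deltas stay zero. A minimal self-contained sketch of the pattern, using hypothetical stand-in types rather than the real brain4j classes:

```java
import java.util.List;

// Hypothetical stand-ins for the brain4j types, kept minimal on purpose.
interface Synapse {}

interface Neuron {
    double getBias();
    void setBias(double bias);
    double getDelta();
}

interface Layer {
    List<Synapse> getSynapses();
    List<Neuron> getNeurons();
}

interface Optimizer {
    void update(Synapse synapse, int timestep);
}

public class ParallelUpdateSketch {

    // Mirrors the change above: layers are processed concurrently instead of
    // sequentially. This is race-free as long as every synapse and neuron is
    // owned by exactly one layer, so no two threads touch the same state.
    static void updateWeightsAndBiases(List<Layer> layers, double learningRate,
                                       int timestep, Optimizer optimizer) {
        layers.parallelStream().forEach(layer -> {
            for (Synapse synapse : layer.getSynapses()) {
                optimizer.update(synapse, timestep);
            }
            for (Neuron neuron : layer.getNeurons()) {
                neuron.setBias(neuron.getBias() + learningRate * neuron.getDelta());
            }
        });
    }
}
```

Per-layer parallelism like this only pays off once layers are wide enough to amortize the fork/join overhead; for the 32-neuron layers in the XOR test below, the gain is likely modest.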

src/test/java/XorTest.java (21 changes: 13 additions & 8 deletions)

@@ -13,13 +13,13 @@
 public class XorTest {
 
     public static void main(String[] args) {
-        Model network = new FeedForwardModel(
-                new DenseLayer(2, Activations.LINEAR),
-                new DenseLayer(32, Activations.RELU),
-                new DenseLayer(32, Activations.RELU),
-                new DenseLayer(32, Activations.RELU),
-                new DenseLayer(1, Activations.SIGMOID)
-        );
+        Model network = new FeedForwardModel();
+
+        network.add(new DenseLayer(2, Activations.LINEAR));
+        network.add(new DenseLayer(32, Activations.RELU));
+        network.add(new DenseLayer(32, Activations.RELU));
+        network.add(new DenseLayer(32, Activations.RELU));
+        network.add(new DenseLayer(1, Activations.SIGMOID));
 
         network.compile(InitializationType.XAVIER, LossFunctions.BINARY_CROSS_ENTROPY, new Adam(0.001));

@@ -41,12 +41,17 @@ public static void main(String[] args) {
             epoches++;
 
             network.fit(training);
+
+            double evalStart = System.nanoTime();
             error = network.evaluate(training);
+            double evalTook = System.nanoTime() - evalStart;
 
             if (epoches % 100 == 0) {
+
                 System.out.println("Epoch #" + epoches + " has error " + error);
+                System.out.println("Eval took " + (evalTook / 1e6) + "ms");
             }
-        } while (error > 0.01);
+        } while (error > 1.0E-4);
 
         double took = (System.nanoTime() - start) / 1e6;
Expand Down
