Skip to content

Commit

Permalink
Added model saving and loading with GSON
Browse files Browse the repository at this point in the history
  • Loading branch information
xEcho1337 committed Oct 15, 2024
1 parent ad75b62 commit 1755b9c
Show file tree
Hide file tree
Showing 11 changed files with 427 additions and 14 deletions.
5 changes: 5 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
// Build configuration for the brain4j library.
plugins {
id 'java'
// Shadow plugin: produces a fat jar bundling runtime dependencies (e.g. Gson).
id 'com.github.johnrengelman.shadow' version '8.1.1'
}

group = 'net.echo'
version = '2.0.0'

repositories {
mavenCentral()
}

dependencies {
// Gson is used for model serialization (save/load of layers, optimizers, weights).
implementation 'com.google.code.gson:gson:2.10.1'
}
61 changes: 61 additions & 0 deletions src/main/java/net/echo/brain4j/adapters/LayerAdapter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package net.echo.brain4j.adapters;

import com.google.gson.*;
import net.echo.brain4j.activation.Activations;
import net.echo.brain4j.layer.Layer;
import net.echo.brain4j.layer.impl.DenseLayer;
import net.echo.brain4j.layer.impl.DropoutLayer;

import java.lang.reflect.Type;

/**
 * Gson adapter that serializes and deserializes {@link Layer} instances.
 *
 * <p>Serialized form always contains {@code type} (the layer's simple class name)
 * and {@code activation}. {@code DenseLayer} additionally stores its per-neuron
 * {@code biases}; {@code DropoutLayer} stores its dropout {@code rate}.
 */
public class LayerAdapter implements JsonSerializer<Layer>, JsonDeserializer<Layer> {

    @Override
    public JsonElement serialize(Layer layer, Type type, JsonSerializationContext context) {
        JsonObject object = new JsonObject();

        object.addProperty("type", layer.getClass().getSimpleName());
        object.addProperty("activation", layer.getActivation().name());

        if (layer instanceof DenseLayer) {
            // Persist one bias per neuron, in neuron order.
            double[] biases = new double[layer.getNeurons().size()];

            for (int i = 0; i < biases.length; i++) {
                biases[i] = layer.getNeurons().get(i).getBias();
            }

            object.add("biases", context.serialize(biases));
        } else if (layer instanceof DropoutLayer dropoutLayer) {
            object.addProperty("rate", dropoutLayer.getDropout());
        }

        return object;
    }

    @Override
    public Layer deserialize(JsonElement element, Type type, JsonDeserializationContext context) throws JsonParseException {
        JsonObject object = element.getAsJsonObject();

        String layerType = object.get("type").getAsString();
        String activationType = object.get("activation").getAsString();

        Activations activations = Activations.valueOf(activationType);

        return switch (layerType) {
            case "DenseLayer" -> {
                double[] biases = context.deserialize(object.get("biases"), double[].class);

                // The bias array length defines the layer size, so neuron count matches biases.
                DenseLayer layer = new DenseLayer(biases.length, activations);

                for (int i = 0; i < layer.getNeurons().size(); i++) {
                    layer.getNeuronAt(i).setBias(biases[i]);
                }

                yield layer;
            }
            case "DropoutLayer" -> {
                double dropout = object.get("rate").getAsDouble();
                yield new DropoutLayer(dropout);
            }
            // JsonParseException is the exception type Gson deserializers declare;
            // callers catching it during model loading get a consistent failure mode.
            default -> throw new JsonParseException("Unknown layer type: " + layerType);
        };
    }
}
53 changes: 53 additions & 0 deletions src/main/java/net/echo/brain4j/adapters/OptimizerAdapter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package net.echo.brain4j.adapters;

import com.google.gson.*;
import net.echo.brain4j.training.optimizers.Optimizer;
import net.echo.brain4j.training.optimizers.impl.Adam;
import net.echo.brain4j.training.optimizers.impl.SGD;

import java.lang.reflect.Type;

/**
 * Gson adapter that serializes and deserializes {@link Optimizer} instances.
 *
 * <p>Serialized form contains {@code type} (the optimizer's simple class name)
 * and a {@code data} object with {@code learningRate}; {@code Adam} additionally
 * stores {@code beta1}, {@code beta2} and {@code epsilon}.
 */
public class OptimizerAdapter implements JsonSerializer<Optimizer>, JsonDeserializer<Optimizer> {

    @Override
    public JsonElement serialize(Optimizer optimizer, Type type, JsonSerializationContext context) {
        JsonObject object = new JsonObject();

        object.addProperty("type", optimizer.getClass().getSimpleName());

        JsonObject data = new JsonObject();

        data.addProperty("learningRate", optimizer.getLearningRate());

        if (optimizer instanceof Adam adam) {
            data.addProperty("beta1", adam.getBeta1());
            data.addProperty("beta2", adam.getBeta2());
            data.addProperty("epsilon", adam.getEpsilon());
        }

        object.add("data", data);
        return object;
    }

    @Override
    public Optimizer deserialize(JsonElement jsonElement, Type type, JsonDeserializationContext context) throws JsonParseException {
        JsonObject object = jsonElement.getAsJsonObject();
        String optimizerType = object.get("type").getAsString();

        JsonObject data = object.get("data").getAsJsonObject();
        double learningRate = data.get("learningRate").getAsDouble();

        return switch (optimizerType) {
            case "SGD" -> new SGD(learningRate);
            case "Adam" -> {
                Adam adam = new Adam(learningRate);

                adam.setBeta1(data.get("beta1").getAsDouble());
                adam.setBeta2(data.get("beta2").getAsDouble());
                adam.setEpsilon(data.get("epsilon").getAsDouble());
                yield adam;
            }
            // Previously returned null here, deferring the failure to a NullPointerException
            // far from the cause. Fail fast with the exception type Gson deserializers
            // declare, mirroring LayerAdapter's handling of unknown types.
            default -> throw new JsonParseException("Unknown optimizer type: " + optimizerType);
        };
    }
}
3 changes: 3 additions & 0 deletions src/main/java/net/echo/brain4j/layer/Layer.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package net.echo.brain4j.layer;

import com.google.gson.annotations.JsonAdapter;
import net.echo.brain4j.activation.Activation;
import net.echo.brain4j.activation.Activations;
import net.echo.brain4j.adapters.LayerAdapter;
import net.echo.brain4j.structure.Neuron;
import net.echo.brain4j.structure.Synapse;

import java.util.ArrayList;
import java.util.List;

@JsonAdapter(LayerAdapter.class)
public class Layer {

private final List<Neuron> neurons = new ArrayList<>();
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/net/echo/brain4j/model/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,8 @@ public interface Model {
List<Layer> getLayers();

String getStats();

void load(String path);

void save(String path);
}
125 changes: 114 additions & 11 deletions src/main/java/net/echo/brain4j/model/impl/FeedForwardModel.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
package net.echo.brain4j.model.impl;

import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.reflect.TypeToken;
import net.echo.brain4j.adapters.LayerAdapter;
import net.echo.brain4j.adapters.OptimizerAdapter;
import net.echo.brain4j.layer.Layer;
import net.echo.brain4j.layer.impl.DenseLayer;
import net.echo.brain4j.layer.impl.DropoutLayer;
import net.echo.brain4j.loss.LossFunction;
import net.echo.brain4j.loss.LossFunctions;
Expand All @@ -11,14 +18,33 @@
import net.echo.brain4j.training.BackPropagation;
import net.echo.brain4j.training.data.DataSet;
import net.echo.brain4j.training.optimizers.Optimizer;

import net.echo.brain4j.training.optimizers.impl.Adam;
import net.echo.brain4j.training.optimizers.impl.SGD;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class FeedForwardModel implements Model {

private final List<Layer> layers;
private static final OptimizerAdapter OPTIMIZER_ADAPTER = new OptimizerAdapter();
private static final LayerAdapter LAYER_ADAPTER = new LayerAdapter();
private static final Gson GSON = new Gson()
.newBuilder()
.setPrettyPrinting()
.excludeFieldsWithoutExposeAnnotation()
.registerTypeAdapter(DenseLayer.class, LAYER_ADAPTER)
.registerTypeAdapter(DropoutLayer.class, LAYER_ADAPTER)
.registerTypeAdapter(Adam.class, OPTIMIZER_ADAPTER)
.registerTypeAdapter(SGD.class, OPTIMIZER_ADAPTER)
.create();

private List<Layer> layers;
private LossFunctions function;
private Optimizer optimizer;
private BackPropagation propagation;
Expand All @@ -27,13 +53,7 @@ public FeedForwardModel(Layer... layers) {
this.layers = new ArrayList<>(Arrays.asList(layers));
}

@Override
public void compile(InitializationType type, LossFunctions function, Optimizer optimizer) {
this.function = function;
this.optimizer = optimizer;
this.propagation = new BackPropagation(this, optimizer);

// Ignore the output layer
private void connect(InitializationType type) {
for (int i = 0; i < layers.size() - 1; i++) {
Layer layer = layers.get(i);

Expand All @@ -54,6 +74,15 @@ public void compile(InitializationType type, LossFunctions function, Optimizer o
}
}

@Override
public void compile(InitializationType type, LossFunctions function, Optimizer optimizer) {
this.function = function;
this.optimizer = optimizer;
this.propagation = new BackPropagation(this, optimizer);

connect(type);
}

@Override
public double fit(DataSet set) {
return propagation.iterate(set, optimizer.getLearningRate());
Expand Down Expand Up @@ -84,8 +113,7 @@ public double[] predict(double ... input) {

Layer nextLayer = layers.get(l + 1);

if (nextLayer instanceof DropoutLayer dropoutLayer) {
// dropoutLayer.process(layer.getNeurons());
if (nextLayer instanceof DropoutLayer) {
nextLayer = layers.get(l + 2);
}

Expand Down Expand Up @@ -158,4 +186,79 @@ public String getStats() {
stats.append("================================================\n");
return stats.toString();
}

@Override
public void load(String path) {
File file = new File(path);

if (!file.exists()) {
throw new IllegalArgumentException("File does not exist: " + path);
}

try {
JsonObject parent = JsonParser.parseReader(new FileReader(file)).getAsJsonObject();

this.optimizer = GSON.fromJson(parent.get("optimizer"), Optimizer.class);
this.function = LossFunctions.valueOf(parent.get("lossFunction").getAsString());

Type listType = new TypeToken<ArrayList<Layer>>(){}.getType();

this.layers = GSON.fromJson(parent.get("layers"), listType);

connect(InitializationType.NORMAL);

double[][] weights = GSON.fromJson(parent.get("weights"), double[][].class);

for (int i = 0; i < weights.length; i++) {
double[] layerWeights = weights[i];
Layer layer = layers.get(i);

for (int j = 0; j < layerWeights.length; j++) {
layer.getSynapses().get(j).setWeight(layerWeights[j]);
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}

@Override
public void save(String path) {
File file = new File(path);

JsonObject parent = new JsonObject();
JsonObject optimizerObject = GSON.toJsonTree(optimizer).getAsJsonObject();

parent.addProperty("lossFunction", function.name());
parent.add("optimizer", optimizerObject);

List<JsonObject> layerObjects = new ArrayList<>();

for (Layer layer : layers) {
layerObjects.add(GSON.toJsonTree(layer).getAsJsonObject());
}

parent.add("layers", GSON.toJsonTree(layerObjects).getAsJsonArray());

double[][] weights = new double[layers.size()][];

for (int i = 0; i < layers.size(); i++) {
Layer layer = layers.get(i);
weights[i] = new double[layer.getSynapses().size()];

for (int j = 0; j < layer.getSynapses().size(); j++) {
Synapse synapse = layer.getSynapses().get(j);

weights[i][j] = synapse.getWeight();
}
}

parent.add("weights", GSON.toJsonTree(weights));

try (BufferedWriter writer = new BufferedWriter(new FileWriter(file))) {
GSON.toJson(parent, writer);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
3 changes: 3 additions & 0 deletions src/main/java/net/echo/brain4j/structure/Neuron.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package net.echo.brain4j.structure;

import com.google.gson.annotations.Expose;

public class Neuron {

private Synapse synapse;
private double delta;
private double value;
@Expose
private double bias = 2 * Math.random() - 1;

public Synapse getSynapse() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package net.echo.brain4j.training.optimizers;

import com.google.gson.annotations.JsonAdapter;
import net.echo.brain4j.adapters.OptimizerAdapter;
import net.echo.brain4j.structure.Synapse;

@JsonAdapter(OptimizerAdapter.class)
public abstract class Optimizer {

protected double learningRate;
Expand All @@ -14,5 +17,9 @@ public double getLearningRate() {
return learningRate;
}

public void setLearningRate(double learningRate) {
this.learningRate = learningRate;
}

public abstract void update(Synapse synapse, int timestep);
}
Loading

0 comments on commit 1755b9c

Please sign in to comment.