diff --git a/src/main/java/net/echo/brain4j/model/Model.java b/src/main/java/net/echo/brain4j/model/Model.java index 614aede..a713d5f 100644 --- a/src/main/java/net/echo/brain4j/model/Model.java +++ b/src/main/java/net/echo/brain4j/model/Model.java @@ -80,4 +80,11 @@ public interface Model { * @param path path to save model */ void save(String path); + + /** + * Adds a layer to the network. + * + * @param layer the layer to add + */ + void add(Layer layer); } \ No newline at end of file diff --git a/src/main/java/net/echo/brain4j/model/impl/FeedForwardModel.java b/src/main/java/net/echo/brain4j/model/impl/FeedForwardModel.java index 3e786fe..9776657 100644 --- a/src/main/java/net/echo/brain4j/model/impl/FeedForwardModel.java +++ b/src/main/java/net/echo/brain4j/model/impl/FeedForwardModel.java @@ -30,6 +30,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.atomic.AtomicReference; public class FeedForwardModel implements Model { @@ -110,7 +111,8 @@ public double[] predict(double ... input) { Layer inputLayer = layers.get(0); if (input.length != inputLayer.getNeurons().size()) { - throw new IllegalArgumentException("Input size does not match model's input dimension!"); + throw new IllegalArgumentException("Input size does not match model's input dimension! " + + input.length + " != " + inputLayer.getNeurons().size()); } for (Layer layer : layers) { @@ -278,4 +280,9 @@ public void save(String path) { throw new RuntimeException(e); } } + + @Override + public void add(Layer layer) { + layers.add(layer); + } }