diff --git a/NALU.ipynb b/NALU.ipynb new file mode 100644 index 0000000..5313ccb --- /dev/null +++ b/NALU.ipynb @@ -0,0 +1,434 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Neural Arithmetic Logic Units\n", + "Google DeepMind's research paper: https://arxiv.org/abs/1808.00508" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![title](images/naluandnac.png)\n", + "**The Neural Accumulator (NAC)** is a linear transformation of its inputs.\n", + "The transformation matrix is the elementwise product of **tanh(W)** and **sigmoid(M)**.\n", + "\n", + "**The Neural Arithmetic Logic Unit (NALU)** uses two NACs with tied weights: an additive path and a multiplicative (log-space) path, combined by a learned gate." + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [], + "source": [ + "# The Neural Arithmetic Logic Unit\n", + "def NALU(in_dim, out_dim):\n", + "\n", + " shape = (int(in_dim.shape[-1]), out_dim)\n", + " epsilon = 1e-7 \n", + " \n", + " # NAC\n", + " W_hat = tf.Variable(tf.truncated_normal(shape, stddev=0.02))\n", + " M_hat = tf.Variable(tf.truncated_normal(shape, stddev=0.02))\n", + " G = tf.Variable(tf.truncated_normal(shape, stddev=0.02))\n", + " \n", + " W = tf.tanh(W_hat) * tf.sigmoid(M_hat)\n", + " # Forward propagation\n", + " a = tf.matmul(in_dim, W)\n", + " \n", + " # NALU: gate between the additive path and a multiplicative (log-space) path\n", + " m = tf.exp(tf.matmul(tf.log(tf.abs(in_dim) + epsilon), W))\n", + " g = tf.sigmoid(tf.matmul(in_dim, G))\n", + " y = g * a + (1 - g) * m\n", + " \n", + " return y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Helper Function" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_dataset(size=10000):\n", + " # input data\n", + " X = np.random.randint(9, size=(size,2))\n", + " # output data (labels) \n", + " Y = np.prod(X, axis=1, keepdims=True)\n", + "\n", + " \n", + " return X, Y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train NALU on generated data" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [], + "source": [ + "# Hyperparameters\n", + "EPOCHS = 200\n", + "LEARNING_RATE = 1e-3\n", + "BATCH_SIZE = 10" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [], + "source": [ + "# create dataset\n", + "X_data, Y_data = generate_dataset()" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [], + "source": [ + "# define placeholders and network\n", + "X = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 2])\n", + "\n", + "Y_true = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 1])\n", + "\n", + "Y_pred = NALU(X, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [], + "source": [ + "loss = tf.nn.l2_loss(Y_pred - Y_true) \n", + " \n", + "optimizer = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 0, loss: 2732.2294921875, accuracy: 0.0062\n", + "epoch 1, loss: 530.1744384765625, accuracy: 0.0104\n", + "epoch 2, loss: 195.937744140625, accuracy: 0.0101\n", + "epoch 3, loss: 92.95536041259766, accuracy: 0.0101\n", + "epoch 4, loss: 49.38507080078125, 
accuracy: 0.0101\n", + "epoch 5, loss: 27.861385345458984, accuracy: 0.0101\n", + "epoch 6, loss: 16.27603530883789, accuracy: 0.0101\n", + "epoch 7, loss: 9.71354866027832, accuracy: 0.014\n", + "epoch 8, loss: 5.8773369789123535, accuracy: 0.0101\n", + "epoch 9, loss: 3.588534355163574, accuracy: 0.0101\n", + "epoch 10, loss: 2.204646587371826, accuracy: 0.0101\n", + "epoch 11, loss: 1.360314965248108, accuracy: 0.0101\n", + "epoch 12, loss: 0.8419667482376099, accuracy: 0.0101\n", + "epoch 13, loss: 0.5223231315612793, accuracy: 0.0101\n", + "epoch 14, loss: 0.3245672583580017, accuracy: 0.017\n", + "epoch 15, loss: 0.20196343958377838, accuracy: 0.0226\n", + "epoch 16, loss: 0.1257719099521637, accuracy: 0.0163\n", + "epoch 17, loss: 0.07839271426200867, accuracy: 0.0336\n", + "epoch 18, loss: 0.04889478161931038, accuracy: 0.0538\n", + "epoch 19, loss: 0.030501829460263252, accuracy: 0.0662\n", + "epoch 20, loss: 0.01904493384063244, accuracy: 0.0848\n", + "epoch 21, loss: 0.011892171576619148, accuracy: 0.0927\n", + "epoch 22, loss: 0.007434078957885504, accuracy: 0.1135\n", + "epoch 23, loss: 0.004642682150006294, accuracy: 0.1332\n", + "epoch 24, loss: 0.0029009534046053886, accuracy: 0.1494\n", + "epoch 25, loss: 0.0018110544187948108, accuracy: 0.1586\n", + "epoch 26, loss: 0.0011344996746629477, accuracy: 0.1586\n", + "epoch 27, loss: 0.0007090618018992245, accuracy: 0.1715\n", + "epoch 28, loss: 0.0004433385329321027, accuracy: 0.2058\n", + "epoch 29, loss: 0.0002778613124974072, accuracy: 0.2598\n", + "epoch 30, loss: 0.00017349905101582408, accuracy: 0.3286\n", + "epoch 31, loss: 0.00010843736527021974, accuracy: 0.425\n", + "epoch 32, loss: 6.789401959395036e-05, accuracy: 0.5576\n", + "epoch 33, loss: 4.224254371365532e-05, accuracy: 0.7862\n", + "epoch 34, loss: 2.6566793167148717e-05, accuracy: 0.9441\n", + "epoch 35, loss: 1.6610187230980955e-05, accuracy: 0.9516\n", + "epoch 36, loss: 1.0384383131167851e-05, accuracy: 0.9565\n", + "epoch 37, loss: 6.4911127992672846e-06, accuracy: 0.9745\n", + "epoch 38, loss: 4.0671784518053755e-06, accuracy: 0.9758\n", + "epoch 39, loss: 2.543429218349047e-06, accuracy: 0.9758\n", + "epoch 40, loss: 1.5800842447788455e-06, accuracy: 0.9758\n", + "epoch 41, loss: 1.0215846941719064e-06, accuracy: 0.9758\n", + "epoch 42, loss: 5.984767312838812e-07, accuracy: 0.9758\n", + "epoch 43, loss: 4.0423867631034227e-07, accuracy: 0.9758\n", + "epoch 44, loss: 2.3378470359602943e-07, accuracy: 0.9758\n", + "epoch 45, loss: 1.3873339810288599e-07, accuracy: 0.9758\n", + "epoch 46, loss: 9.241715304142417e-08, accuracy: 0.9758\n", + "epoch 47, loss: 6.142135333675469e-08, accuracy: 0.9758\n", + "epoch 48, loss: 3.927708647211148e-08, accuracy: 0.9758\n", + "epoch 49, loss: 2.327270820501326e-08, accuracy: 0.9758\n", + "epoch 50, loss: 1.744568933759183e-08, accuracy: 0.9758\n", + "epoch 51, loss: 1.0950079065707996e-08, accuracy: 0.9758\n", + "epoch 52, loss: 1.284910400300987e-08, accuracy: 0.9758\n", + "epoch 53, loss: 9.246936549800466e-09, accuracy: 0.9758\n", + "epoch 54, loss: 5.2956368179479796e-09, accuracy: 0.9758\n", + "epoch 55, loss: 4.111474716239627e-09, accuracy: 0.9758\n", + "epoch 56, loss: 1.9377823790023285e-09, accuracy: 0.9758\n", + "epoch 57, loss: 7.369083121488984e-10, accuracy: 0.9758\n", + "epoch 58, loss: 7.369083121488984e-10, accuracy: 0.9758\n", + "epoch 59, loss: 7.369083121488984e-10, accuracy: 0.9758\n", + "epoch 60, loss: 7.369083121488984e-10, accuracy: 0.9758\n", + "epoch 61, loss: 7.369083121488984e-10, 
accuracy: 0.9758\n", + "epoch 62, loss: 7.369083121488984e-10, accuracy: 0.9789\n", + "epoch 63, loss: 7.369083121488984e-10, accuracy: 0.9886\n", + "epoch 64, loss: 7.369083121488984e-10, accuracy: 0.9956\n", + "epoch 65, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 66, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 67, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 68, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 69, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 70, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 71, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 72, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 73, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 74, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 75, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 76, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 77, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 78, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 79, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 80, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 81, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 82, loss: 7.369083121488984e-10, accuracy: 1.0\n", + "epoch 83, loss: 1.3198067638775512e-10, accuracy: 1.0\n", + "epoch 84, loss: 1.3198067638775512e-10, accuracy: 1.0\n", + "epoch 85, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 86, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 87, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 88, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 89, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 90, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 91, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 92, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 93, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 94, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 95, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 96, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 97, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 98, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 99, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 100, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 101, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 102, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 103, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 104, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 105, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 106, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 107, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 108, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 109, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 110, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 111, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 112, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 113, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 114, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 115, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 116, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 117, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 118, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 119, loss: 3.7393220464476684e-11, 
accuracy: 1.0\n", + "epoch 120, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 121, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 122, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 123, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 124, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 125, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 126, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 127, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 128, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 129, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 130, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 131, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 132, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 133, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 134, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 135, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 136, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 137, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 138, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 139, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 140, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 141, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 142, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 143, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 144, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 145, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 146, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 147, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 148, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 149, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 150, loss: 3.7393220464476684e-11, accuracy: 1.0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 151, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 152, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 153, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 154, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 155, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 156, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 157, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 158, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 159, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 160, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 161, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 162, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 163, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 164, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 165, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 166, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 167, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 168, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 169, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 170, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 171, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 172, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 173, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 174, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 175, loss: 3.7393220464476684e-11, accuracy: 
1.0\n", + "epoch 176, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 177, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 178, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 179, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 180, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 181, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 182, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 183, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 184, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 185, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 186, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 187, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 188, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 189, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 190, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 191, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 192, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 193, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 194, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 195, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 196, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 197, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 198, loss: 3.7393220464476684e-11, accuracy: 1.0\n", + "epoch 199, loss: 3.7393220464476684e-11, accuracy: 1.0\n" + ] + } + ], + "source": [ + "# create session\n", + "sess = tf.Session()\n", + "# create writer to store tensorboard graph \n", + "writer = tf.summary.FileWriter('/tmp', sess.graph)\n", + " \n", + "init = tf.global_variables_initializer()\n", + " \n", + "sess.run(init)\n", + "\n", + "# Run training loop\n", + "for i in range(EPOCHS):\n", + " j = 0\n", + " g = 0\n", + " \n", + " while j < len(X_data):\n", + " xs, ys = X_data[j:j + BATCH_SIZE], Y_data[j:j + BATCH_SIZE]\n", + "\n", + " _, ys_pred, l = sess.run([optimizer, Y_pred, loss], \n", + " feed_dict={X: xs, Y_true: ys})\n", + " \n", + " # calculate number of correct predictions from batch\n", + " g += np.sum(np.isclose(ys, ys_pred, atol=1e-4, rtol=1e-4)) \n", + "\n", + " j += BATCH_SIZE\n", + "\n", + " acc = g / len(Y_data)\n", + " \n", + " print(f'epoch {i}, loss: {l}, accuracy: {acc}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Uncomment to run TensorBoard" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TensorBoard 1.9.0 at http://Akils-Air-2.home:6006 (Press CTRL+C to quit)\n", + "^C\n" + ] + } + ], + "source": [ + "# !tensorboard --logdir /tmp" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/NALU.py b/NALU.py new file mode 100644 index 0000000..f1fba61 --- /dev/null +++ b/NALU.py @@ -0,0 +1,106 @@ + +# coding: utf-8 + +# ## Neural Arithmatic Logic Units +# Google DeepMind's research paper: https://arxiv.org/abs/1808.00508 + +# In[73]: + + +import tensorflow as tf +import numpy as np + + + +# The Neural Arithmetic Logic Unit +def 
NALU(in_dim, out_dim): + + shape = (int(in_dim.shape[-1]), out_dim) + epsilon = 1e-7 + + # NAC + W_hat = tf.Variable(tf.truncated_normal(shape, stddev=0.02)) + M_hat = tf.Variable(tf.truncated_normal(shape, stddev=0.02)) + G = tf.Variable(tf.truncated_normal(shape, stddev=0.02)) + + W = tf.tanh(W_hat) * tf.sigmoid(M_hat) + # Forward propagation + a = tf.matmul(in_dim, W) + + # NALU: gate between the additive path and a multiplicative (log-space) path + m = tf.exp(tf.matmul(tf.log(tf.abs(in_dim) + epsilon), W)) + g = tf.sigmoid(tf.matmul(in_dim, G)) + y = g * a + (1 - g) * m + + return y + + +### Helper Function + +def generate_dataset(size=10000): + # input data + X = np.random.randint(9, size=(size,2)) + # output data (labels) + Y = np.prod(X, axis=1, keepdims=True) + + + return X, Y + + +### Train NALU on generated data + +# Hyperparameters +EPOCHS = 200 +LEARNING_RATE = 1e-3 +BATCH_SIZE = 10 + + + +# create dataset +X_data, Y_data = generate_dataset() + + +# define placeholders and network +X = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 2]) + +Y_true = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 1]) + +Y_pred = NALU(X, 1) + +loss = tf.nn.l2_loss(Y_pred - Y_true) + +optimizer = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss) + + +# create session
sess = tf.Session() +# create writer to store tensorboard graph +writer = tf.summary.FileWriter('/tmp', sess.graph) + +init = tf.global_variables_initializer() + +sess.run(init) + +# Run training loop +for i in range(EPOCHS): + j = 0 + g = 0 + + while j < len(X_data): + xs, ys = X_data[j:j + BATCH_SIZE], Y_data[j:j + BATCH_SIZE] + + _, ys_pred, l = sess.run([optimizer, Y_pred, loss], + feed_dict={X: xs, Y_true: ys}) + + # calculate number of correct predictions from batch + g += np.sum(np.isclose(ys, ys_pred, atol=1e-4, rtol=1e-4)) + + j += BATCH_SIZE + + acc = g / len(Y_data) + + print(f'epoch {i}, loss: {l}, accuracy: {acc}') + + +# !tensorboard --logdir /tmp + diff --git a/images/naluandnac.png b/images/naluandnac.png new file mode 100644 index 0000000..80795be Binary files /dev/null and b/images/naluandnac.png differ diff --git a/images/tensorboard.png b/images/tensorboard.png new file mode 100644 index 0000000..6379526 Binary files /dev/null and b/images/tensorboard.png differ
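
Note on the NALU cell added in both NALU.ipynb and NALU.py: the output is y = g * a + (1 - g) * m, where a = x W is the additive NAC path, m = exp(W log(|x| + epsilon)) is the multiplicative path computed in log space, and g = sigmoid(x G) gates between the two. The plain-NumPy sketch below is not part of the patch; it mirrors those equations with hand-set (ideal) weights to illustrate why the trained cell above can represent exact multiplication of its two inputs. The helper names (nalu_forward, sigmoid) and the demo values are illustrative only.

# Not from the patch: NumPy sketch of the NALU forward pass with hand-set weights.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def nalu_forward(x, W, G, epsilon=1e-7):
    """Mirror of the patch's NALU cell: gated mix of an additive NAC path
    and a multiplicative path computed in log space."""
    a = x @ W                                       # additive path: x W
    m = np.exp(np.log(np.abs(x) + epsilon) @ W)     # multiplicative path: exp(W log|x|)
    g = sigmoid(x @ G)                              # gate between the two paths
    return g * a + (1 - g) * m

x = np.array([[3.0, 7.0]])        # one (a, b) pair, like a row of X_data
W = np.ones((2, 1))               # ideal NAC weights: both inputs weighted 1
G_mul = np.full((2, 1), -10.0)    # strongly negative gate logits -> gate ~0 -> multiplicative path
G_add = np.full((2, 1), 10.0)     # strongly positive gate logits -> gate ~1 -> additive path

print(nalu_forward(x, W, G_mul))  # ~[[21.]]  (3 * 7, the task trained above)
print(nalu_forward(x, W, G_add))  # ~[[10.]]  (3 + 7)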
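
A separate caveat: both files target TensorFlow 1.x graph mode (tf.placeholder, tf.Session, tf.truncated_normal, tf.log, tf.train.AdamOptimizer), none of which exist in TensorFlow 2. If the cell ever needs to run under TF2, one possible port is to wrap it as a Keras layer. The sketch below is an assumption about how that could look, not code from this repository; the class name NALULayer is hypothetical.

# Hypothetical TF2/Keras port of the patch's NALU() function (not part of the patch).
import tensorflow as tf

class NALULayer(tf.keras.layers.Layer):
    """Keras-layer version of the NALU cell defined in NALU.py above."""

    def __init__(self, out_dim, epsilon=1e-7, **kwargs):
        super().__init__(**kwargs)
        self.out_dim = out_dim
        self.epsilon = epsilon

    def build(self, input_shape):
        init = tf.keras.initializers.TruncatedNormal(stddev=0.02)
        shape = (int(input_shape[-1]), self.out_dim)
        # same three parameter matrices as the original code
        self.W_hat = self.add_weight(name="W_hat", shape=shape, initializer=init)
        self.M_hat = self.add_weight(name="M_hat", shape=shape, initializer=init)
        self.G = self.add_weight(name="G", shape=shape, initializer=init)

    def call(self, x):
        W = tf.tanh(self.W_hat) * tf.sigmoid(self.M_hat)                  # NAC weight matrix
        a = tf.matmul(x, W)                                               # additive path
        m = tf.exp(tf.matmul(tf.math.log(tf.abs(x) + self.epsilon), W))   # multiplicative path
        g = tf.sigmoid(tf.matmul(x, self.G))                              # gate
        return g * a + (1 - g) * m

Such a layer could then be trained with, for example, tf.keras.Sequential([NALULayer(1)]) compiled with Adam and an MSE loss, roughly matching the training loop in the patch.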