diff --git a/workshop/places365_train.ipynb b/workshop/places365_train.ipynb index b447046..79e74a0 100644 --- a/workshop/places365_train.ipynb +++ b/workshop/places365_train.ipynb @@ -17,7 +17,9 @@ }, "outputs": [], "source": [ - "# !gdown 1Hkk2HNvnh2cZqIcOGuxpxUkDSDh-QW86\n", + "# 1200:\n", + "#!gdown 1Hkk2HNvnh2cZqIcOGuxpxUkDSDh-QW86\n", + "# 300:\n", "!gdown 1y-LdQ_4dbOip6sBgZ-Ub1FI6Hh5kl3h1" ] }, @@ -189,33 +191,25 @@ }, "outputs": [], "source": [ - "class Places365Model(tf.keras.Model):\n", - " def __init__(self):\n", - " super().__init__()\n", - " self.base_model = tf.keras.applications.MobileNetV3Small(\n", + "inputs = tf.keras.Input(shape=(224, 224, 3))\n", + "x = tf.keras.applications.MobileNetV3Small(\n", " input_shape=(224, 224, 3),\n", " include_top=False,\n", " weights=None,\n", " classes=3,\n", " pooling=\"avg\",\n", " minimalistic=True\n", - " )\n", - " self.fc = tf.keras.layers.Dense(3, activation='softmax')\n", - "\n", - " def call(self, x):\n", - " x = self.base_model(x)\n", - " return self.fc(x)\n", - "\n", - "\n", - "model = Places365Model()\n", + " )(inputs)\n", + "outputs = tf.keras.layers.Dense(3, activation=\"softmax\")(x)\n", + "model = tf.keras.Model(inputs=inputs, outputs=outputs)\n", "\n", "# updating momentum of the BatchNorm layers\n", - "for layer in model.base_model.layers:\n", + "# BatchNorm layers sit inside the nested MobileNetV3 backbone (model.layers[1]), not in model.layers itself\n", + "for layer in model.layers[1].layers:\n", " if isinstance(layer, tf.keras.layers.BatchNormalization):\n", " layer.momentum = 0.5\n", "\n", - "optimizer = tf.keras.optimizers.Adam(learning_rate=0.005)\n", - "\n", + "optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)\n", "model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics='accuracy')" ] }, @@ -238,7 +232,7 @@ "source": [ "history = model.fit(\n", " train_ds,\n", - " epochs=32)" + " epochs=20)" ] }, { @@ -399,17 +393,26 @@ " img, label = gt\n", " predicted_class = np.argmax(result)\n", " if predicted_class != label:\n", - " images_to_plot.append((img, label, predicted_class))\n", - 
"\n", + " images_to_plot.append((img, label, predicted_class))" + ] + }, + { + "cell_type": "code", + "source": [ "fig = plt.figure(figsize=(128., 128.))\n", "grid = ImageGrid(fig, 111, nrows_ncols=(math.ceil(len(images_to_plot) / 4), 4), axes_pad=0.6,)\n", "\n", "for ax, im in zip(grid, images_to_plot):\n", - " ax.set_title(f\"True: {class_names[im[1]]}, predicted: {class_names[im[2]]}\", fontdict=None, loc='center', color = \"k\", fontsize=15)\n", + " ax.set_title(f\"True: {class_names[im[1]]}, predicted: {class_names[im[2]]}\", fontdict=None, loc='center', color = \"k\", fontsize=70)\n", " ax.imshow(im[0] / 255)\n", "\n", "plt.show()" - ] + ], + "metadata": { + "id": "J8plRcodbaXT" + }, + "execution_count": null, + "outputs": [] } ], "metadata": {