From 3c25c7c4c780952f3d777298fcbe9a047ea89666 Mon Sep 17 00:00:00 2001 From: Deep Patel Date: Tue, 22 Oct 2024 05:41:00 +1100 Subject: [PATCH] 'Finished' --- notebooks/shape_classifier.ipynb | 170 ++++++++++++++++++++++++++++--- 1 file changed, 155 insertions(+), 15 deletions(-) diff --git a/notebooks/shape_classifier.ipynb b/notebooks/shape_classifier.ipynb index e7c32b2..0068054 100644 --- a/notebooks/shape_classifier.ipynb +++ b/notebooks/shape_classifier.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "1", "metadata": {}, "outputs": [], @@ -25,10 +25,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Classes: ['circle', 'diamond', 'triangle']\n" + ] + } + ], "source": [ "import torch\n", "import torch.nn as nn\n", @@ -64,7 +72,7 @@ " def __init__(self):\n", " super(SimpleCNN, self).__init__()\n", " self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)\n", - " self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)\n", + " self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)\n", " self.fc1 = nn.Linear(32 * 16 * 16, 128)\n", " self.fc2 = nn.Linear(128, 3) # 3 classes: circle, triangle, rectangle\n", " \n", @@ -108,7 +116,7 @@ " total += labels.size(0)\n", " correct += (predicted == labels).sum().item()\n", " \n", - " accuracy = 0 * correct / total\n", + " accuracy = 100 * correct / total\n", " print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {accuracy:.2f}%')\n" ] }, @@ -122,10 +130,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "4", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch [1/15], Loss: 1.1637, Accuracy: 32.71%\n", + "Epoch [2/15], Loss: 1.0782, Accuracy: 36.00%\n", + "Epoch [3/15], Loss: 1.0146, Accuracy: 48.71%\n", + "Epoch [4/15], Loss: 0.9366, Accuracy: 54.86%\n", + "Epoch [5/15], Loss: 0.8362, Accuracy: 61.29%\n", + "Epoch [6/15], Loss: 0.7207, Accuracy: 67.14%\n", + "Epoch [7/15], Loss: 0.5859, Accuracy: 76.43%\n", + "Epoch [8/15], Loss: 0.4584, Accuracy: 81.29%\n", + "Epoch [9/15], Loss: 0.3422, Accuracy: 88.86%\n", + "Epoch [10/15], Loss: 0.2567, Accuracy: 92.57%\n", + "Epoch [11/15], Loss: 0.2028, Accuracy: 93.57%\n", + "Epoch [12/15], Loss: 0.1408, Accuracy: 96.57%\n", + "Epoch [13/15], Loss: 0.1064, Accuracy: 97.43%\n", + "Epoch [14/15], Loss: 0.0809, Accuracy: 97.57%\n", + "Epoch [15/15], Loss: 0.0598, Accuracy: 98.86%\n" + ] + } + ], "source": [ "train_model(model, train_loader, criterion, optimizer, epochs=15)" ] @@ -140,15 +170,38 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "6", "metadata": {}, "outputs": [], "source": [ + "from sklearn.metrics import precision_score, recall_score, f1_score\n", + "\n", "def test(model, test_loader):\n", " \"\"\"Print the Precision, Recall and F1-score for the trained model\n", " \"\"\"\n", - " pass" + " model.eval() # Set the model to evaluation mode\n", + " all_labels = [] # To store true labels\n", + " all_predictions = [] # To store predictions\n", + "\n", + " with torch.no_grad(): # Disable gradient calculation\n", + " for images, labels in test_loader:\n", + " images, labels = images.to(device), labels.to(device) # Move to device\n", + " outputs = model(images) # Get model predictions\n", + " _, predicted = torch.max(outputs, 1) # Get the predicted class\n", + " \n", + " all_labels.extend(labels.cpu().numpy()) # Store true labels\n", + " all_predictions.extend(predicted.cpu().numpy()) # Store predictions\n", + "\n", + " # Calculate precision, recall, and F1 score\n", + " precision = precision_score(all_labels, all_predictions, average='weighted')\n", + " recall = recall_score(all_labels, all_predictions, average='weighted')\n", + " f1 = f1_score(all_labels, all_predictions, average='weighted')\n", + "\n", + " # Print results\n", + " print(f'Precision: {precision:.4f}')\n", + " print(f'Recall: {recall:.4f}')\n", + " print(f'F1-score: {f1:.4f}')\n" ] }, { @@ -161,24 +214,111 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "8", "metadata": {}, "outputs": [], "source": [ + "import matplotlib.pyplot as plt\n", + "import torchvision.transforms as transforms\n", + "\n", "def show_prediction(model, image):\n", " \"\"\"Pass the image to the model and overlay the predicted shape and confidence on the input\n", - " image and display it\n", + " image and display it.\n", " \"\"\"\n", - " pass" + " # Define transformations for the input image\n", + " transform = transforms.Compose([\n", + " transforms.Grayscale(num_output_channels=1), # Ensure the image is in grayscale\n", + " transforms.Resize((64, 64)), # Resize to the model input size\n", + " transforms.ToTensor(), # Convert to tensor\n", + " transforms.Normalize((0.5,), (0.5,)) # Normalize the image\n", + " ])\n", + "\n", + " # Preprocess the image\n", + " image_tensor = transform(image).unsqueeze(0) # Add a batch dimension\n", + "\n", + " # Set the model to evaluation mode\n", + " model.eval()\n", + " \n", + " with torch.no_grad(): # Disable gradient calculation\n", + " # Pass the image through the model\n", + " output = model(image_tensor.to(device))\n", + " probabilities = F.softmax(output, dim=1) # Convert logits to probabilities\n", + " confidence, predicted_class = torch.max(probabilities, 1) # Get the predicted class and confidence\n", + "\n", + " # Convert the predicted class to label (assuming class indices are 0, 1, 2)\n", + " class_labels = ['circle', 'triangle', 'rectangle'] # Modify according to your class labels\n", + " predicted_label = class_labels[predicted_class.item()]\n", + " confidence_value = confidence.item()\n", + "\n", + " # Display the image with the prediction overlay\n", + " plt.imshow(image, cmap='gray')\n", + " plt.title(f'Predicted: {predicted_label} (Confidence: {confidence_value:.2f})')\n", + " plt.axis('off') # Hide axes\n", + " plt.show()\n" ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "35f6dfbe-e8bc-4cf3-8be3-d639475ad1bc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Precision: 0.7265\n", + "Recall: 0.7267\n", + "F1-score: 0.7264\n" + ] + } + ], + "source": [ + "test(model, test_loader)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6595116c-87ab-441b-8e4d-d84a9d9e9229", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAhDUlEQVR4nO3deViVdf7/8dcBERBR0DBNFBfcTRosrdwizQytplyyaaFSs8bSNk3tqqmsnLFmRtM2WtQsnRmjLLNsdNLUycpUMstGQtQsBc3cMYTz+f7Rj/fPI6AcBQ/g83FdXV3cZ3uf+yBP7vvc3MfjnHMCAEBSUKAHAABUHEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEShjDRp0kS33HKLfb106VJ5PB4tXbo0YDMd69gZy9rmzZvl8Xg0Y8aMMrvP8pjZ6/Wqffv2evLJJ8v0fk9VRkaGevfurdq1a8vj8WjevHmaMWOGPB6PNm/efMLbl/frC18LFy5UzZo1tXPnzkCPUqaqRBQK/+EU/hcWFqaWLVvqrrvuUnZ2dqDH88sHH3ygRx99NNBjVGlz5szRDz/8oLvuuqvIZZmZmRo+fLiaNWumsLAw1apVS126dNGUKVOUm5tbrnOlpKTo66+/1pNPPqlZs2bp/PPPL9fHq2xeeOEFDRw4UI0bN5bH4/E7gF6vV5MmTVLTpk0VFhamDh06aM6cOcVed8OGDerTp49q1qypOnXq6Kabbiryw79Pnz6Kj4/XxIkTT/YpVUyuCpg+fbqT5B5//HE3a9Ys9/LLL7uUlBQXFBTkmjZt6g4ePFjuM8TFxbmUlBT7uqCgwOXm5rqCggK/7mfEiBGuvF6WY2csa16v1+Xm5rr8/Pwyu8/ymDkhIcHdfvvtRZa///77Ljw83EVFRbmRI0e61NRUN23aNDd48GAXEhLihg0bVqZzHO3QoUNOknvooYd8lufn57vc3Fzn9XpPeB/l/foGWlxcnKtTp47r06ePq1atmt/PdezYsU6SGzZsmEtNTXV9+/Z1ktycOXN8rvfDDz+4s846yzVv3txNmTLFPfnkky46OtolJCS4X3/91ee6zz//vKtRo4bbt2/fqT69CqNKRWHVqlU+y++77z4nyc2ePbvE2x44cKBMZiirf5CVOQql5c86L+uZ16xZ4yS5xYsX+yzftGmTq1mzpmvdurX76aefitwuIyPDTZ48uczmONaWLVucJPf000+f9H1UlNe3vGzevNniGBER4ddz3bZtmwsJCXEjRoywZV6v13Xr1s3Fxsb6/CJz5513uvDwcLdlyxZbtmjRIifJvfTSSz73m52d7YKDg92rr756ks+q4qkSu49Kcumll0qSsrKyJEm33HKLatasqczMTCUnJysyMlI33HCDpN82LSdPnqx27dopLCxMZ599toYPH65ffvnF5z6dc3riiScUGxurGjVqKCkpSd98802Rxy7pPYXPP/9cycnJio6OVkREhDp06KApU6bYfM8995wk+ewOK1TWM0q/7S7JzMws1frcs2eP7r33XjVp0kShoaGKjY3VzTffrF27dkkq/j2FE63zKVOm6Nxzz1VYWJhiYmLUp08fffnllyec45577lGjRo0UGhqq+Ph4/eUvf5HX6z3hc5g3b56qV6+u7t27+yyfNGmSDhw4oFdffVUNGjQocrv4+HiNGjXKvs7Pz9eECRPUvHlzhYaGqkmTJho/frx+/fVXn9s1adJE/fr104oVK9SpUyeFhYWpWbNmev311+06jz76qOLi4iRJo0ePlsfjUZMmTSSp2PcU/Hl9S7OuCl+3Z555RqmpqfacLrjgAq1atarIfX733XcaNGiQYmJiFB4erlatWumhhx7yuc6PP/6o2267TWeffbZCQ0PVrl07vfbaa0Xua+vWrfruu++Knf1YcXFxPv8e/PHuu+/qyJEj+uMf/2jLPB6P7rzzTm3btk0rV6605WlpaerXr58aN25sy3r16qWWLVvqX//6l8/91qtXTx06dNC77757UnNVRNUCPUB5KvxhV7duXVuWn5+vyy+/XF27dtUzzzyjGjVqSJKGDx+uGTNm6NZbb9XIkSOVlZWladOmae3atfrvf/+rkJAQSdIjjzyiJ554QsnJyUpOTtaaNWvUu3dv5eXlnXCeRYsWqV+/fmrQoIFGjRql+vXra8OGDXr//fc1atQoDR8+XD/99JMWLVqkWbNmFbl9eczYs2dPSTrhG5kHDhxQt27dtGHDBt12221KTEzUrl279N5772nbtm0666yzSrxtSet8yJAhmjFjhq644goNHTpU+fn5Wr58uT777LMS96cfOnRIPXr00I8//qjhw4ercePG+vTTTzVu3Dht375dkydPPu7z+PTTT9W+fXtbV4Xmz5+vZs2a6eKLLz7u7QsNHTpUM2fO1IABA3T//ffr888/18SJE7Vhwwa98847Ptf9/vvvNWDAAA0ZMkQpKSl67bXXdMstt6hjx45q166drr32WkVFRenee+/V9ddfr+TkZNWsWbPExy7t6+vvupo9e7b279+v4cOHy+PxaNKkSbr22mu1adMmW1/r1q1Tt27dFBISottvv11NmjRRZmam5s+fb2/cZ2dn68ILL5TH49Fdd92lmJgYffjhhxoyZIj27dune+65xx7z5ptv1ieffCJXzmfwX7t2rSIiItSmTRuf5Z06dbLLu3btqh9//FE5OTnFfv916tRJH3zwQZHlHTt21Lx588pl7oAI7IZK2SjcfbR48WK3c+dO98MPP7h//OMfrm7dui48PNxt27bNOedcSkqKk+TGjh3rc/vly5c7Se7NN9/0Wb5w4UKf5Tk5Oa569equb9++Pvt4x48f7yT5bM4uWbLESXJLlixxzv22b7hp06YuLi7O/fLLLz6Pc/R9lbT7qDxmdO63XQ5xcXFFHu9YjzzyiJPk3n777SKXFT5OVlaWk+SmT59ul5W0zj/++GMnyY0cObLE+yuc7+iZJ0yY4CIiItzGjRt9bjN27FgXHBzstm7detznERsb6/r37++zbO/evU6Su/rqq49720Lp6elOkhs6dKjP8gceeMBJch9//LHP/JLcsmXLbFlOTo4LDQ11999/vy0rXHfH7j4q/N7Oysqy25b29S3tuip87Lp167rdu3fb9d59910nyc2fP9+Wde/e3UVGRvrsWnHO9zUbMmSIa9Cggdu1a5fPdQYPHuxq167tDh06ZMt69OhxUrtL/d191LdvX9esWbMiyw8ePOjz/blq1Sonyb3++utFrjt69GgnyR0+fNhn+VNPPeUkuezsbP+eRAVVpXYf9erVSzExMWrUqJEGDx6smjVr6p133lHDhg19rnfnnXf6fD137lzVrl1bl112mXbt2mX/dezYUTVr1tSSJUskSYsXL1ZeXp7uvvtun83Yo3/zKcnatWuVlZWle+65R1FRUT6XlWaTuLxm3Lx5c6kOd0xLS1NCQoKuueaaIpeVZv5j13laWpo8Ho/+9Kc/+XV/c+fOVbdu3RQdHe2zHnr16qWCggItW7bsuHP8/PPPio6O9lm2b98+SVJkZOQJn4ck+23xvvvu81l+//33S5IWLFjgs7xt27bq1q2bfR0TE6NWrVpp06ZNpXq8o/nz+vq7rq677jqfdVM4c+GcO3fu1LJly3Tbbbf57FqR/v9r5pxTWlqarrzySjnnfB738ssv1969e7VmzRq73dKlS8t9K0GScnNzFRoaWmR5WFiYXX70/0tz3UKF66xwN2plV6V2Hz333HNq2bKlqlWrprPPPlutWrVSUJBv96pVq6bY2FifZRkZGdq7d6/q1atX7P3m5ORIkrZs2SJJatGihc/lMTExRX7QHKtwV1b79u1L/4RO84zHk5mZqf79+5/UbYtb55mZmTrnnHNUp04dv+4rIyND69atU0xMTLGXF66H4zn2h1CtWrUkSfv37y/VDFu2bFFQUJDi4+N9ltevX19RUVH2GhQ69geo9NsPkmPfCyrtY0ule339XVfHzll4f4VzFsbheN/DO3fu1J49e5SamqrU1NRSPe7pEB4eXuT9Hkk6fPiwXX70/0tz3UKF308n+35HRVOlotCpU6cTHtsdGhpaJBRer1f16tXTm2++WextSvpHdTpVhhlLUtw6P1ler1eXXXaZxowZU+zlLVu2PO7t69atW+SHca1atXTOOedo/fr1fs1S2h8CwcHBxS4v79+Q/V1XZTFn4RvYN954o1JSUoq9TocOHUp9f2WlQYMGWrJkiZxzPq/b9u3bJUnnnHOOXe/o5Ufbvn276tSpU2QrovD76Xjvq1UmVSoKJ6t58+ZavHixunTpUuS3gKMVHiGSkZGhZs2a2fKdO3ee8Le+5s2bS5LWr1+vXr16lXi9kn7QnI4Zj6d58+Z+/9A80f199NFH2r17t19bC82bN9eBAweOuw6Pp3Xr1nY02tH69eun1NRUrVy5UhdddNFx7yMuLk5er1cZGRk+b1xmZ2drz5499hqUB39e31NdV8cqfLzjfR/ExMQoMjJSBQUFZfa4ZeG8887TK6+8og0bNqht27a2/PPPP7fLJalhw4aKiYkp9gi4L774wq53tKysLJ111lkV+hczf1Sp9xRO1qBBg1RQUKAJEyYUuSw/P1979uyR9Nt7FiEhIZo6darPb08nOuJFkhITE9W0aVNNnjzZ7q/Q0fcVEREhSUWuU14zlvaQ1P79++urr74qcmTNsfOXVv/+/eWc02OPPebX/Q0aNEgrV67URx99VOSyPXv2KD8//7iPe9FFF2n9+vVFdg+MGTNGERERGjp0aLF/BZ+ZmWmHDicnJ0squk7/9re/SZL69u173BlOhT+v76muq2PFxMSoe/fueu2117R161afywpnCQ4OVv/+/ZWWllZsPI79q2B/Dkktrb179+q7777T3r17bdnVV1+tkJAQPf/88z4zv/jii2rYsKHPUWf9+/fX+++/rx9++MGW/ec//9HGjRs1cODAIo+3evXqE/4iUZmwpSCpR48eGj58uCZOnKj09HT17t1bISEhysjI0Ny5czVlyhQNGDBAMTExeuCBBzRx4kT169dPycnJWrt2rT788MMTbjoGBQXphRde0JVXXqnzzjtPt956qxo0aKDvvvtO33zzjf3D7dixoyRp5MiRuvzyyxUcHKzBgweX24ylPSR19OjReuuttzRw4EDddttt6tixo3bv3q333ntPL774ohISEvxa50lJSbrpppv07LPPKiMjQ3369JHX69Xy5cuVlJRU7CkoCud477331K9fPzus8+DBg/r666/11ltvafPmzcd9La6++mpNmDBBn3zyiXr37m3LmzdvrtmzZ+u6665TmzZtdPPNN6t9+/bKy8vTp59+qrlz59ppFRISEpSSkqLU1FTt2bNHPXr00BdffKGZM2fq97//vZKSkvxaF/7w5/U91XVVnGeffVZdu3ZVYmKibr/9djVt2lSbN2/WggULlJ6eLkn685//rCVLlqhz584aNmyY2rZtq927d2vNmjVavHixdu/ebffnzyGp8+fP11dffSVJOnLkiNatW6cnnnhCknTVVVfZbql33nlHt956q6ZPn26vWWxsrO655x49/fTTOnLkiC644ALNmzdPy5cv15tvvumz62z8+PGaO3eukpKSNGrUKB04cEBPP/20zj33XN16660+M+Xk5GjdunUaMWKEX+uxQjvdhzuVh5L+ovlYKSkpLiIiosTLU1NTXceOHV14eLiLjIx05557rhszZozPX7gWFBS4xx57zDVo0MCFh4e7Sy65xK1fv77IoZPHHpJaaMWKFe6yyy5zkZGRLiIiwnXo0MFNnTrVLs/Pz3d33323i4mJcR6Pp8jhemU5o3OlPyTVOed+/vlnd9ddd7mGDRu66tWru9jYWJeSkmKHHpZ0SGpJ6zw/P989/fTTrnXr1q569eouJibGXXHFFW716tU+8x078/79+924ceNcfHy8q169ujvrrLPcxRdf7J555hmXl5d3wufRoUMHN2TIkGIv27hxoxs2bJhr0qSJq169uouMjHRdunRxU6dO9TkU8ciRI+6xxx5zTZs2dSEhIa5Ro0Zu3LhxRQ5XjIuLc3379i3yOD169HA9evSwr0t7SKpz/r2+pVlXJT22c85Jcn/60598lq1fv95dc801LioqyoWFhblWrVq5hx9+2Oc62dnZbsSIEa5Ro0YuJCTE1a9f3/Xs2dOlpqYWWQ+l/TFUeHhzcf8d/T1XuM6OXla43p566ikXFxfnqlev7tq1a+feeOONYh9r/fr1rnfv3q5GjRouKirK3XDDDW7Hjh1FrvfCCy9UudNceJw7DceDARXIrFmzNGLECG3durXI4cGAP373u9/pkksu0d///vdAj1JmeE8BZ5wbbrhBjRs3tlOKACdj4cKFysjI0Lhx4wI9SpliSwEAYNhSAAAYogAAMEQBAGCIAgDAEAUAgOEvmlFlpaWlnfBU2v4o/Et0oCojCqjUfvrpp2JPcyz9dlqEmTNnltlj9enTR5dcckmxl1WrVk2NGjUqs8cCAoW/U0Cl1rlzZ61evbrYy7xeb5mentrj8ZR4CvC4uDht3LixxNNPA5UFWwqoFF5++WV9+OGHRZZv3LhRBQUFp2UG51yJj7V9+3YNGDCgyKnPO3furAcffPB0jAeUCaKACmv79u12iuaPPvqo2NN2VxS5ubnFfnj77t27bZdT/fr1y/WzFoCywO4jVFiTJk2qUr9ljxw50j6TAaioiAIqlEWLFunhhx+W9NubyEd/0Elld/SWwsiRI/WHP/whwBMBRbH7CAH3888/a+nSpZKkFStW2EckVjU7duzQjh07JP22O6zws347d+6s2NjYQI4GGLYUEDCFb9quWLGixEM9zwRz5syxj3nk6CUEGlFAQGzevFk9e/aU1+vV4cOH7TfoM1FMTIx9NndaWpoSExMDPBHOZOw+wmlz+PBhvfLKK8rPz1dOTo6ysrLK9O8IKqudO3faB9q//vrrWrZsmTwej2655RbVrl07wNPhTMOWAk6LvLw8/fTTT2rXrp0OHToU6HEqvKCgIK1Zs0atWrVSWFhYoMfBGYQT4uG0mDZtmjp06EAQSsnr9apbt252JBZwuhAFlKuCggI99thjevvtt7V///5Aj1Op7N+/X4sWLdK4ceOUm5sb6HFwhmD3EcrF9u3blZeXpyNHjqhr167Kzs4O9EiVVo0aNbRy5UrVrl1bwcHBHL6KckUUUC7OP/98paenS9JpOzdRVVZ4qGqjRo2UkZGhatU4RgTlg91HKFPffvutBgwYoO+//14FBQUEoYwUrsvs7GwNGjRIK1asCPRIqKL4dQNlZuPGjVq2bJnS0tICPUqVlZubq3feeUeJiYmKiopS+/btAz0Sqhh2H6HMJCcnF3t6a5SPxMTEEj9LAjhZ7D7CKdu6dau6dOmilStXBnqUM8r//vc/XXTRRfrqq68CPQqqEHYf4ZSkp6fr008/1cqVK/nr5NPs4MGD+uyzz7RgwQLl5ubqwgsvDPRIqALYfYRTcscdd+ill14K9BhnvH79+mn+/PmBHgNVALuPAACGKOCkHD58WM8//7y+/vrrQI8CSVlZWZo6dar27t0b6FFQybH7CH4rPLld27ZtOf1CBRIUFKTVq1erTZs29gE+gL/YUoDfpk6dqoSEBIJQwXi9XnXv3l0PPfRQoEdBJUYU4Lfc3Fzt27cv0GOgGPv37+dMtDglRAF+2bZtG0Go4A4cOKBt27bJ6/UGehRUQryngFLLy8tTfHy8fvzxR37gVGAej0e1a9fWpk2bFB0dHehxUMmwpQC/5OfnE4QKzjmnI0eO8MeEOClEAaWye/durV69WkeOHAn0KCgFr9er9PR07dixI9CjoJIhCiiVBQsW6OKLL9auXbsCPQpKITc3Vz179tSMGTMCPQoqGaIAADBEASe0dOlSrVmzJtBj4CR8++23+ve//837Cyg1jj7CCSUkJGjdunWBHgMnqWHDhsrKylJISEigR0ElwJYCAMAQBQCAIQoAAEMUAACGKKBEmZmZGj16tLZv3x7oUXAK9u7dqwcffFDp6emBHgWVAEcfoURLly5VUlJSoMdAGZk9e7auv/76QI+BCo4tBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKKBEzZo108SJE1W/fv1Aj4JTUKtWLU2YMEHnnXdeoEdBJcAJ8XBCfBxn5cbHccIfbCkAAAxRAAAYogAAMEQBAGCIAk5o+vTpGjduXKDHwEm4/fbblZaWpmrVqgV6FFQSRAEnlJiYqDZt2gR6DJyEpk2bqnPnzvJ4PIEeBZUEUUCp8YOlcvF4PLxm8Bt/p4BS2bdvn7Zu3apLL71UO3fuDPQ4OIHw8HAtXbpU8fHxqlOnTqDHQSXClgJKpVatWmrZsiX7piuJoKAggoCTQhTglxo1avCXsRVctWrVFBERwa4jnBSigFKrXr26Vq1apQceeCDQo+A4brrpJm3YsEFRUVGBHgWVEPsC4Jfo6GjVqFEj0GPgOMLCwththJPGlgL8VqtWLdWvX5/dExVQvXr12ELAKeHoI/itoKBAOTk5io+P16FDhwI9Dv6foKAgrV+/Xi1btlRwcHCgx0ElxZYC/BYcHKw6dero1VdfVVJSUqDHgaTzzjtPM2fOVGxsLEHAKSEKOCmhoaEaPHiwWrZsGehRICk2NlY33nijIiMjAz0KKjmiAAAwvKeAU5KRkaHPP/9cN998s/hWCoxnn31WSUlJat++faBHQRXAIak4JS1atFB4eLj69OmjVatWadeuXYEe6YxRq1YtdenShSCgTLGlgDKTnJyshQsXssVwGng8HiUmJurLL78M9CioYnhPAWUmNTVVM2fODPQYZ4S//vWvSktLC/QYqIKIAspMbGysOnfurGHDhqlu3bqBHqdKioiI0JAhQ9S1a1fFxcUFehxUQew+Qrm48MILtW7dOknS4cOH2aV0isLDwyVJDRs21IYNGzhbLcoNUUC52LNnj/Lz85WXl6eOHTtqx44dgR6p0oqIiFB6erqioqIUFBTEeY1Qrvh1A+Wi8Pw7Xq9XY8eO1fvvv6/FixcHdqhKqHPnzrr22msVGxursLCwQI+DMwBRQLkKCgrSqFGjFBQUpK+//lo5OTnsSiqlmJgYXXrppRozZkygR8EZhN1HOC04iZ5/Ck9u16JFC94/wGnFdxtOi+DgYNWtW1evvPKK8vPztWPHDj344INsNRzj4YcfVosWLeTxeNSoUSOCgNOOLQUExNatW3XVVVepoKBAhw4d0qZNmwI9UsA0btxYtWrVkiS98cYbSkhICPBEOJMRBQTc8uXL1b1790CPETD//Oc/NWjQoECPAUgiCqgA9u7dq9WrV0uSPv74Yz355JMBnqj83XHHHRo4cKAkqX379qpXr16AJwJ+ww5LBFzt2rV16aWXSvrtDdZVq1ZJkjZt2qTvv/8+kKOVqcaNG6t169aSpF69etlzBioSthRQYU2aNEljx46VpEr9hnThZ1mPHDlSkydPDuwwwAkQBVRYu3fvVk5OjiTp8ccf15w5cwI8kf969uypadOmSZKio6N19tlnB3gi4PjYfYQKq06dOnZKh+TkZNWoUaPIdd59990K8RkONWvW1MCBAxUU5HuOycTERNtlBFQGbCmgUuvatavWrFlT7GV5eXkqKCgos8cKCgpSaGhosZc1btxY33zzjYKDg8vs8YBAIAqo1H755Rfl5+cXe9n48eP1yiuvlNlj9evXT6+99lqxlwUHB3OiOlQJ7D5CpRYdHV3iZf3791ejRo3K7LHatGmjmJiYMrs/oCJiSwEAYPjkNQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAw/wdsClualIWvKQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from PIL import Image # Import Image from PIL\n", + "\n", + "image_path = \"../datasets/test/circle/circle_285.png\" # Replace with the actual image path\n", + "image = Image.open(image_path).convert(\"L\") \n", + "# Run the show_prediction function\n", + "show_prediction(model, image)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ab81524-8730-472b-ba07-8bc6df915e01", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "intern-skills", "language": "python", - "name": "python3" + "name": "intern-skills" }, "language_info": { "codemirror_mode": { @@ -190,7 +330,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.9.20" } }, "nbformat": 4,