Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

'Finished' #19

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 155 additions & 15 deletions notebooks/shape_classifier.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "1",
"metadata": {},
"outputs": [],
Expand All @@ -25,10 +25,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "2",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Classes: ['circle', 'diamond', 'triangle']\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
Expand Down Expand Up @@ -64,7 +72,7 @@
" def __init__(self):\n",
" super(SimpleCNN, self).__init__()\n",
" self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)\n",
" self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)\n",
" self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)\n",
" self.fc1 = nn.Linear(32 * 16 * 16, 128)\n",
" self.fc2 = nn.Linear(128, 3) # 3 classes: circle, triangle, rectangle\n",
" \n",
Expand Down Expand Up @@ -108,7 +116,7 @@
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
" \n",
" accuracy = 0 * correct / total\n",
" accuracy = 100 * correct / total\n",
" print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {accuracy:.2f}%')\n"
]
},
Expand All @@ -122,10 +130,32 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "4",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/15], Loss: 1.1637, Accuracy: 32.71%\n",
"Epoch [2/15], Loss: 1.0782, Accuracy: 36.00%\n",
"Epoch [3/15], Loss: 1.0146, Accuracy: 48.71%\n",
"Epoch [4/15], Loss: 0.9366, Accuracy: 54.86%\n",
"Epoch [5/15], Loss: 0.8362, Accuracy: 61.29%\n",
"Epoch [6/15], Loss: 0.7207, Accuracy: 67.14%\n",
"Epoch [7/15], Loss: 0.5859, Accuracy: 76.43%\n",
"Epoch [8/15], Loss: 0.4584, Accuracy: 81.29%\n",
"Epoch [9/15], Loss: 0.3422, Accuracy: 88.86%\n",
"Epoch [10/15], Loss: 0.2567, Accuracy: 92.57%\n",
"Epoch [11/15], Loss: 0.2028, Accuracy: 93.57%\n",
"Epoch [12/15], Loss: 0.1408, Accuracy: 96.57%\n",
"Epoch [13/15], Loss: 0.1064, Accuracy: 97.43%\n",
"Epoch [14/15], Loss: 0.0809, Accuracy: 97.57%\n",
"Epoch [15/15], Loss: 0.0598, Accuracy: 98.86%\n"
]
}
],
"source": [
"train_model(model, train_loader, criterion, optimizer, epochs=15)"
]
Expand All @@ -140,15 +170,38 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "6",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import precision_score, recall_score, f1_score\n",
"\n",
"def test(model, test_loader):\n",
" \"\"\"Print the Precision, Recall and F1-score for the trained model\n",
" \"\"\"\n",
" pass"
" model.eval() # Set the model to evaluation mode\n",
" all_labels = [] # To store true labels\n",
" all_predictions = [] # To store predictions\n",
"\n",
" with torch.no_grad(): # Disable gradient calculation\n",
" for images, labels in test_loader:\n",
" images, labels = images.to(device), labels.to(device) # Move to device\n",
" outputs = model(images) # Get model predictions\n",
" _, predicted = torch.max(outputs, 1) # Get the predicted class\n",
" \n",
" all_labels.extend(labels.cpu().numpy()) # Store true labels\n",
" all_predictions.extend(predicted.cpu().numpy()) # Store predictions\n",
"\n",
" # Calculate precision, recall, and F1 score\n",
" precision = precision_score(all_labels, all_predictions, average='weighted')\n",
" recall = recall_score(all_labels, all_predictions, average='weighted')\n",
" f1 = f1_score(all_labels, all_predictions, average='weighted')\n",
"\n",
" # Print results\n",
" print(f'Precision: {precision:.4f}')\n",
" print(f'Recall: {recall:.4f}')\n",
" print(f'F1-score: {f1:.4f}')\n"
]
},
{
Expand All @@ -161,24 +214,111 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "8",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import torchvision.transforms as transforms\n",
"\n",
"def show_prediction(model, image):\n",
" \"\"\"Pass the image to the model and overlay the predicted shape and confidence on the input\n",
" image and display it\n",
" image and display it.\n",
" \"\"\"\n",
" pass"
" # Define transformations for the input image\n",
" transform = transforms.Compose([\n",
" transforms.Grayscale(num_output_channels=1), # Ensure the image is in grayscale\n",
" transforms.Resize((64, 64)), # Resize to the model input size\n",
" transforms.ToTensor(), # Convert to tensor\n",
" transforms.Normalize((0.5,), (0.5,)) # Normalize the image\n",
" ])\n",
"\n",
" # Preprocess the image\n",
" image_tensor = transform(image).unsqueeze(0) # Add a batch dimension\n",
"\n",
" # Set the model to evaluation mode\n",
" model.eval()\n",
" \n",
" with torch.no_grad(): # Disable gradient calculation\n",
" # Pass the image through the model\n",
" output = model(image_tensor.to(device))\n",
" probabilities = F.softmax(output, dim=1) # Convert logits to probabilities\n",
" confidence, predicted_class = torch.max(probabilities, 1) # Get the predicted class and confidence\n",
"\n",
" # Convert the predicted class to label (assuming class indices are 0, 1, 2)\n",
" class_labels = ['circle', 'triangle', 'rectangle'] # Modify according to your class labels\n",
" predicted_label = class_labels[predicted_class.item()]\n",
" confidence_value = confidence.item()\n",
"\n",
" # Display the image with the prediction overlay\n",
" plt.imshow(image, cmap='gray')\n",
" plt.title(f'Predicted: {predicted_label} (Confidence: {confidence_value:.2f})')\n",
" plt.axis('off') # Hide axes\n",
" plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "35f6dfbe-e8bc-4cf3-8be3-d639475ad1bc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Precision: 0.7265\n",
"Recall: 0.7267\n",
"F1-score: 0.7264\n"
]
}
],
"source": [
"test(model, test_loader)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6595116c-87ab-441b-8e4d-d84a9d9e9229",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGbCAYAAAAr/4yjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAhDUlEQVR4nO3deViVdf7/8dcBERBR0DBNFBfcTRosrdwizQytplyyaaFSs8bSNk3tqqmsnLFmRtM2WtQsnRmjLLNsdNLUycpUMstGQtQsBc3cMYTz+f7Rj/fPI6AcBQ/g83FdXV3cZ3uf+yBP7vvc3MfjnHMCAEBSUKAHAABUHEQBAGCIAgDAEAUAgCEKAABDFAAAhigAAAxRAAAYogAAMEShjDRp0kS33HKLfb106VJ5PB4tXbo0YDMd69gZy9rmzZvl8Xg0Y8aMMrvP8pjZ6/Wqffv2evLJJ8v0fk9VRkaGevfurdq1a8vj8WjevHmaMWOGPB6PNm/efMLbl/frC18LFy5UzZo1tXPnzkCPUqaqRBQK/+EU/hcWFqaWLVvqrrvuUnZ2dqDH88sHH3ygRx99NNBjVGlz5szRDz/8oLvuuqvIZZmZmRo+fLiaNWumsLAw1apVS126dNGUKVOUm5tbrnOlpKTo66+/1pNPPqlZs2bp/PPPL9fHq2xeeOEFDRw4UI0bN5bH4/E7gF6vV5MmTVLTpk0VFhamDh06aM6cOcVed8OGDerTp49q1qypOnXq6Kabbiryw79Pnz6Kj4/XxIkTT/YpVUyuCpg+fbqT5B5//HE3a9Ys9/LLL7uUlBQXFBTkmjZt6g4ePFjuM8TFxbmUlBT7uqCgwOXm5rqCggK/7mfEiBGuvF6WY2csa16v1+Xm5rr8/Pwyu8/ymDkhIcHdfvvtRZa///77Ljw83EVFRbmRI0e61NRUN23aNDd48GAXEhLihg0bVqZzHO3QoUNOknvooYd8lufn57vc3Fzn9XpPeB/l/foGWlxcnKtTp47r06ePq1atmt/PdezYsU6SGzZsmEtNTXV9+/Z1ktycOXN8rvfDDz+4s846yzVv3txNmTLFPfnkky46OtolJCS4X3/91ee6zz//vKtRo4bbt2/fqT69CqNKRWHVqlU+y++77z4nyc2ePbvE2x44cKBMZiirf5CVOQql5c86L+uZ16xZ4yS5xYsX+yzftGmTq1mzpmvdurX76aefitwuIyPDTZ48uczmONaWLVucJPf000+f9H1UlNe3vGzevNniGBER4ddz3bZtmwsJCXEjRoywZV6v13Xr1s3Fxsb6/CJz5513uvDwcLdlyxZbtmjRIifJvfTSSz73m52d7YKDg92rr756ks+q4qkSu49Kcumll0qSsrKyJEm33HKLatasqczMTCUnJysyMlI33HCDpN82LSdPnqx27dopLCxMZ599toYPH65ffvnF5z6dc3riiScUGxurGjVqKCkpSd98802Rxy7pPYXPP/9cycnJio6OVkREhDp06KApU6bYfM8995wk+ewOK1TWM0q/7S7JzMws1frcs2eP7r33XjVp0kShoaGKjY3VzTffrF27dkkq/j2FE63zKVOm6Nxzz1VYWJhiYmLUp08fffnllyec45577lGjRo0UGhqq+Ph4/eUvf5HX6z3hc5g3b56qV6+u7t27+yyfNGmSDhw4oFdffVUNGjQocrv4+HiNGjXKvs7Pz9eECRPUvHlzhYaGqkmTJho/frx+/fVXn9s1adJE/fr104oVK9SpUyeFhYWpWbNmev311+06jz76qOLi4iRJo0ePlsfjUZMmTSSp2PcU/Hl9S7OuCl+3Z555RqmpqfacLrjgAq1atarIfX733XcaNGiQYmJiFB4erlatWumhhx7yuc6PP/6o2267TWeffbZCQ0PVrl07vfbaa0Xua+vWrfruu++Knf1YcXFxPv8e/PHuu+/qyJEj+uMf/2jLPB6P7rzzTm3btk0rV6605WlpaerXr58aN25sy3r16qWWLVvqX//6l8/91qtXTx06dNC77757UnNVRNUCPUB5KvxhV7duXVuWn5+vyy+/XF27dtUzzzyjGjVqSJKGDx+uGTNm6NZbb9XIkSOVlZWladOmae3atfrvf/+rkJAQSdIjjzyiJ554QsnJyUpOTtaaNWvUu3dv5eXlnXCeRYsWqV+/fmrQoIFGjRql+vXra8OGDXr//fc1atQoDR8+XD/99JMWLVqkWbNmFbl9eczYs2dPSTrhG5kHDhxQt27dtGHDBt12221KTEzUrl279N5772nbtm0666yzSrxtSet8yJAhmjFjhq644goNHTpU+fn5Wr58uT777LMS96cfOnRIPXr00I8//qjhw4ercePG+vTTTzVu3Dht375dkydPPu7z+PTTT9W+fXtbV4Xmz5+vZs2a6eKLLz7u7QsNHTpUM2fO1IABA3T//ffr888/18SJE7Vhwwa98847Ptf9/vvvNWDAAA0ZMkQpKSl67bXXdMstt6hjx45q166drr32WkVFRenee+/V9ddfr+TkZNWsWbPExy7t6+vvupo9e7b279+v4cOHy+PxaNKkSbr22mu1adMmW1/r1q1Tt27dFBISottvv11NmjRRZmam5s+fb2/cZ2dn68ILL5TH49Fdd92lmJgYffjhhxoyZIj27dune+65xx7z5ptv1ieffCJXzmfwX7t2rSIiItSmTRuf5Z06dbLLu3btqh9//FE5OTnFfv916tRJH3zwQZHlHTt21Lx588pl7oAI7IZK2SjcfbR48WK3c+dO98MPP7h//OMfrm7dui48PNxt27bNOedcSkqKk+TGjh3rc/vly5c7Se7NN9/0Wb5w4UKf5Tk5Oa569equb9++Pvt4x48f7yT5bM4uWbLESXJLlixxzv22b7hp06YuLi7O/fLLLz6Pc/R9lbT7qDxmdO63XQ5xcXFFHu9YjzzyiJPk3n777SKXFT5OVlaWk+SmT59ul5W0zj/++GMnyY0cObLE+yuc7+iZJ0yY4CIiItzGjRt9bjN27FgXHBzstm7detznERsb6/r37++zbO/evU6Su/rqq49720Lp6elOkhs6dKjP8gceeMBJch9//LHP/JLcsmXLbFlOTo4LDQ11999/vy0rXHfH7j4q/N7Oysqy25b29S3tuip87Lp167rdu3fb9d59910nyc2fP9+Wde/e3UVGRvrsWnHO9zUbMmSIa9Cggdu1a5fPdQYPHuxq167tDh06ZMt69OhxUrtL/d191LdvX9esWbMiyw8ePOjz/blq1Sonyb3++utFrjt69GgnyR0+fNhn+VNPPeUkuezsbP+eRAVVpXYf9erVSzExMWrUqJEGDx6smjVr6p133lHDhg19rnfnnXf6fD137lzVrl1bl112mXbt2mX/dezYUTVr1tSSJUskSYsXL1ZeXp7uvvtun83Yo3/zKcnatWuVlZWle+65R1FRUT6XlWaTuLxm3Lx5c6kOd0xLS1NCQoKuueaaIpeVZv5j13laWpo8Ho/+9Kc/+XV/c+fOVbdu3RQdHe2zHnr16qWCggItW7bsuHP8/PPPio6O9lm2b98+SVJkZOQJn4ck+23xvvvu81l+//33S5IWLFjgs7xt27bq1q2bfR0TE6NWrVpp06ZNpXq8o/nz+vq7rq677jqfdVM4c+GcO3fu1LJly3Tbbbf57FqR/v9r5pxTWlqarrzySjnnfB738ssv1969e7VmzRq73dKlS8t9K0GScnNzFRoaWmR5WFiYXX70/0tz3UKF66xwN2plV6V2Hz333HNq2bKlqlWrprPPPlutWrVSUJBv96pVq6bY2FifZRkZGdq7d6/q1atX7P3m5ORIkrZs2SJJatGihc/lMTExRX7QHKtwV1b79u1L/4RO84zHk5mZqf79+5/UbYtb55mZmTrnnHNUp04dv+4rIyND69atU0xMTLGXF66H4zn2h1CtWrUkSfv37y/VDFu2bFFQUJDi4+N9ltevX19RUVH2GhQ69geo9NsPkmPfCyrtY0ule339XVfHzll4f4VzFsbheN/DO3fu1J49e5SamqrU1NRSPe7pEB4eXuT9Hkk6fPiwXX70/0tz3UKF308n+35HRVOlotCpU6cTHtsdGhpaJBRer1f16tXTm2++WextSvpHdTpVhhlLUtw6P1ler1eXXXaZxowZU+zlLVu2PO7t69atW+SHca1atXTOOedo/fr1fs1S2h8CwcHBxS4v79+Q/V1XZTFn4RvYN954o1JSUoq9TocOHUp9f2WlQYMGWrJkiZxzPq/b9u3bJUnnnHOOXe/o5Ufbvn276tSpU2QrovD76Xjvq1UmVSoKJ6t58+ZavHixunTpUuS3gKMVHiGSkZGhZs2a2fKdO3ee8Le+5s2bS5LWr1+vXr16lXi9kn7QnI4Zj6d58+Z+/9A80f199NFH2r17t19bC82bN9eBAweOuw6Pp3Xr1nY02tH69eun1NRUrVy5UhdddNFx7yMuLk5er1cZGRk+b1xmZ2drz5499hqUB39e31NdV8cqfLzjfR/ExMQoMjJSBQUFZfa4ZeG8887TK6+8og0bNqht27a2/PPPP7fLJalhw4aKiYkp9gi4L774wq53tKysLJ111lkV+hczf1Sp9xRO1qBBg1RQUKAJEyYUuSw/P1979uyR9Nt7FiEhIZo6darPb08nOuJFkhITE9W0aVNNnjzZ7q/Q0fcVEREhSUWuU14zlvaQ1P79++urr74qcmTNsfOXVv/+/eWc02OPPebX/Q0aNEgrV67URx99VOSyPXv2KD8//7iPe9FFF2n9+vVFdg+MGTNGERERGjp0aLF/BZ+ZmWmHDicnJ0squk7/9re/SZL69u173BlOhT+v76muq2PFxMSoe/fueu2117R161afywpnCQ4OVv/+/ZWWllZsPI79q2B/Dkktrb179+q7777T3r17bdnVV1+tkJAQPf/88z4zv/jii2rYsKHPUWf9+/fX+++/rx9++MGW/ec//9HGjRs1cODAIo+3evXqE/4iUZmwpSCpR48eGj58uCZOnKj09HT17t1bISEhysjI0Ny5czVlyhQNGDBAMTExeuCBBzRx4kT169dPycnJWrt2rT788MMTbjoGBQXphRde0JVXXqnzzjtPt956qxo0aKDvvvtO33zzjf3D7dixoyRp5MiRuvzyyxUcHKzBgweX24ylPSR19OjReuuttzRw4EDddttt6tixo3bv3q333ntPL774ohISEvxa50lJSbrpppv07LPPKiMjQ3369JHX69Xy5cuVlJRU7CkoCud477331K9fPzus8+DBg/r666/11ltvafPmzcd9La6++mpNmDBBn3zyiXr37m3LmzdvrtmzZ+u6665TmzZtdPPNN6t9+/bKy8vTp59+qrlz59ppFRISEpSSkqLU1FTt2bNHPXr00BdffKGZM2fq97//vZKSkvxaF/7w5/U91XVVnGeffVZdu3ZVYmKibr/9djVt2lSbN2/WggULlJ6eLkn685//rCVLlqhz584aNmyY2rZtq927d2vNmjVavHixdu/ebffnzyGp8+fP11dffSVJOnLkiNatW6cnnnhCknTVVVfZbql33nlHt956q6ZPn26vWWxsrO655x49/fTTOnLkiC644ALNmzdPy5cv15tvvumz62z8+PGaO3eukpKSNGrUKB04cEBPP/20zj33XN16660+M+Xk5GjdunUaMWKEX+uxQjvdhzuVh5L+ovlYKSkpLiIiosTLU1NTXceOHV14eLiLjIx05557rhszZozPX7gWFBS4xx57zDVo0MCFh4e7Sy65xK1fv77IoZPHHpJaaMWKFe6yyy5zkZGRLiIiwnXo0MFNnTrVLs/Pz3d33323i4mJcR6Pp8jhemU5o3OlPyTVOed+/vlnd9ddd7mGDRu66tWru9jYWJeSkmKHHpZ0SGpJ6zw/P989/fTTrnXr1q569eouJibGXXHFFW716tU+8x078/79+924ceNcfHy8q169ujvrrLPcxRdf7J555hmXl5d3wufRoUMHN2TIkGIv27hxoxs2bJhr0qSJq169uouMjHRdunRxU6dO9TkU8ciRI+6xxx5zTZs2dSEhIa5Ro0Zu3LhxRQ5XjIuLc3379i3yOD169HA9evSwr0t7SKpz/r2+pVlXJT22c85Jcn/60598lq1fv95dc801LioqyoWFhblWrVq5hx9+2Oc62dnZbsSIEa5Ro0YuJCTE1a9f3/Xs2dOlpqYWWQ+l/TFUeHhzcf8d/T1XuM6OXla43p566ikXFxfnqlev7tq1a+feeOONYh9r/fr1rnfv3q5GjRouKirK3XDDDW7Hjh1FrvfCCy9UudNceJw7DceDARXIrFmzNGLECG3durXI4cGAP373u9/pkksu0d///vdAj1JmeE8BZ5wbbrhBjRs3tlOKACdj4cKFysjI0Lhx4wI9SpliSwEAYNhSAAAYogAAMEQBAGCIAgDAEAUAgOEvmlFlpaWlnfBU2v4o/Et0oCojCqjUfvrpp2JPcyz9dlqEmTNnltlj9enTR5dcckmxl1WrVk2NGjUqs8cCAoW/U0Cl1rlzZ61evbrYy7xeb5mentrj8ZR4CvC4uDht3LixxNNPA5UFWwqoFF5++WV9+OGHRZZv3LhRBQUFp2UG51yJj7V9+3YNGDCgyKnPO3furAcffPB0jAeUCaKACmv79u12iuaPPvqo2NN2VxS5ubnFfnj77t27bZdT/fr1y/WzFoCywO4jVFiTJk2qUr9ljxw50j6TAaioiAIqlEWLFunhhx+W9NubyEd/0Elld/SWwsiRI/WHP/whwBMBRbH7CAH3888/a+nSpZKkFStW2EckVjU7duzQjh07JP22O6zws347d+6s2NjYQI4GGLYUEDCFb9quWLGixEM9zwRz5syxj3nk6CUEGlFAQGzevFk9e/aU1+vV4cOH7TfoM1FMTIx9NndaWpoSExMDPBHOZOw+wmlz+PBhvfLKK8rPz1dOTo6ysrLK9O8IKqudO3faB9q//vrrWrZsmTwej2655RbVrl07wNPhTMOWAk6LvLw8/fTTT2rXrp0OHToU6HEqvKCgIK1Zs0atWrVSWFhYoMfBGYQT4uG0mDZtmjp06EAQSsnr9apbt252JBZwuhAFlKuCggI99thjevvtt7V///5Aj1Op7N+/X4sWLdK4ceOUm5sb6HFwhmD3EcrF9u3blZeXpyNHjqhr167Kzs4O9EiVVo0aNbRy5UrVrl1bwcHBHL6KckUUUC7OP/98paenS9JpOzdRVVZ4qGqjRo2UkZGhatU4RgTlg91HKFPffvutBgwYoO+//14FBQUEoYwUrsvs7GwNGjRIK1asCPRIqKL4dQNlZuPGjVq2bJnS0tICPUqVlZubq3feeUeJiYmKiopS+/btAz0Sqhh2H6HMJCcnF3t6a5SPxMTEEj9LAjhZ7D7CKdu6dau6dOmilStXBnqUM8r//vc/XXTRRfrqq68CPQqqEHYf4ZSkp6fr008/1cqVK/nr5NPs4MGD+uyzz7RgwQLl5ubqwgsvDPRIqALYfYRTcscdd+ill14K9BhnvH79+mn+/PmBHgNVALuPAACGKOCkHD58WM8//7y+/vrrQI8CSVlZWZo6dar27t0b6FFQybH7CH4rPLld27ZtOf1CBRIUFKTVq1erTZs29gE+gL/YUoDfpk6dqoSEBIJQwXi9XnXv3l0PPfRQoEdBJUYU4Lfc3Fzt27cv0GOgGPv37+dMtDglRAF+2bZtG0Go4A4cOKBt27bJ6/UGehRUQryngFLLy8tTfHy8fvzxR37gVGAej0e1a9fWpk2bFB0dHehxUMmwpQC/5OfnE4QKzjmnI0eO8MeEOClEAaWye/durV69WkeOHAn0KCgFr9er9PR07dixI9CjoJIhCiiVBQsW6OKLL9auXbsCPQpKITc3Vz179tSMGTMCPQoqGaIAADBEASe0dOlSrVmzJtBj4CR8++23+ve//837Cyg1jj7CCSUkJGjdunWBHgMnqWHDhsrKylJISEigR0ElwJYCAMAQBQCAIQoAAEMUAACGKKBEmZmZGj16tLZv3x7oUXAK9u7dqwcffFDp6emBHgWVAEcfoURLly5VUlJSoMdAGZk9e7auv/76QI+BCo4tBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKKBEzZo108SJE1W/fv1Aj4JTUKtWLU2YMEHnnXdeoEdBJcAJ8XBCfBxn5cbHccIfbCkAAAxRAAAYogAAMEQBAGCIAk5o+vTpGjduXKDHwEm4/fbblZaWpmrVqgV6FFQSRAEnlJiYqDZt2gR6DJyEpk2bqnPnzvJ4PIEeBZUEUUCp8YOlcvF4PLxm8Bt/p4BS2bdvn7Zu3apLL71UO3fuDPQ4OIHw8HAtXbpU8fHxqlOnTqDHQSXClgJKpVatWmrZsiX7piuJoKAggoCTQhTglxo1avCXsRVctWrVFBERwa4jnBSigFKrXr26Vq1apQceeCDQo+A4brrpJm3YsEFRUVGBHgWVEPsC4Jfo6GjVqFEj0GPgOMLCwththJPGlgL8VqtWLdWvX5/dExVQvXr12ELAKeHoI/itoKBAOTk5io+P16FDhwI9Dv6foKAgrV+/Xi1btlRwcHCgx0ElxZYC/BYcHKw6dero1VdfVVJSUqDHgaTzzjtPM2fOVGxsLEHAKSEKOCmhoaEaPHiwWrZsGehRICk2NlY33nijIiMjAz0KKjmiAAAwvKeAU5KRkaHPP/9cN998s/hWCoxnn31WSUlJat++faBHQRXAIak4JS1atFB4eLj69OmjVatWadeuXYEe6YxRq1YtdenShSCgTLGlgDKTnJyshQsXssVwGng8HiUmJurLL78M9CioYnhPAWUmNTVVM2fODPQYZ4S//vWvSktLC/QYqIKIAspMbGysOnfurGHDhqlu3bqBHqdKioiI0JAhQ9S1a1fFxcUFehxUQew+Qrm48MILtW7dOknS4cOH2aV0isLDwyVJDRs21IYNGzhbLcoNUUC52LNnj/Lz85WXl6eOHTtqx44dgR6p0oqIiFB6erqioqIUFBTEeY1Qrvh1A+Wi8Pw7Xq9XY8eO1fvvv6/FixcHdqhKqHPnzrr22msVGxursLCwQI+DMwBRQLkKCgrSqFGjFBQUpK+//lo5OTnsSiqlmJgYXXrppRozZkygR8EZhN1HOC04iZ5/Ck9u16JFC94/wGnFdxtOi+DgYNWtW1evvPKK8vPztWPHDj344INsNRzj4YcfVosWLeTxeNSoUSOCgNOOLQUExNatW3XVVVepoKBAhw4d0qZNmwI9UsA0btxYtWrVkiS98cYbSkhICPBEOJMRBQTc8uXL1b1790CPETD//Oc/NWjQoECPAUgiCqgA9u7dq9WrV0uSPv74Yz355JMBnqj83XHHHRo4cKAkqX379qpXr16AJwJ+ww5LBFzt2rV16aWXSvrtDdZVq1ZJkjZt2qTvv/8+kKOVqcaNG6t169aSpF69etlzBioSthRQYU2aNEljx46VpEr9hnThZ1mPHDlSkydPDuwwwAkQBVRYu3fvVk5OjiTp8ccf15w5cwI8kf969uypadOmSZKio6N19tlnB3gi4PjYfYQKq06dOnZKh+TkZNWoUaPIdd59990K8RkONWvW1MCBAxUU5HuOycTERNtlBFQGbCmgUuvatavWrFlT7GV5eXkqKCgos8cKCgpSaGhosZc1btxY33zzjYKDg8vs8YBAIAqo1H755Rfl5+cXe9n48eP1yiuvlNlj9evXT6+99lqxlwUHB3OiOlQJ7D5CpRYdHV3iZf3791ejRo3K7LHatGmjmJiYMrs/oCJiSwEAYPjkNQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAwRAEAYIgCAMAQBQCAIQoAAEMUAACGKAAADFEAABiiAAAw/wdsClualIWvKQAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from PIL import Image # Import Image from PIL\n",
"\n",
"image_path = \"../datasets/test/circle/circle_285.png\" # Replace with the actual image path\n",
"image = Image.open(image_path).convert(\"L\") \n",
"# Run the show_prediction function\n",
"show_prediction(model, image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ab81524-8730-472b-ba07-8bc6df915e01",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "intern-skills",
"language": "python",
"name": "python3"
"name": "intern-skills"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -190,7 +330,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.9.20"
}
},
"nbformat": 4,
Expand Down