Skip to content

Commit

Permalink
Update notebook too
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Sep 11, 2024
1 parent 321f571 commit 468d2f2
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions examples/camvid_segmentation_multiclass.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@
" augmentation=get_validation_augmentation(),\n",
")\n",
"\n",
"#Change to > 0 if not on Windows machine\n",
"# Change to > 0 if not on Windows machine\n",
"train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)\n",
"valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=0)\n",
"test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)"
Expand Down Expand Up @@ -545,12 +545,10 @@
"import pytorch_lightning as pl\n",
"import segmentation_models_pytorch as smp\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from torch.optim import lr_scheduler\n",
"\n",
"\n",
"class CamVidModel(pl.LightningModule):\n",
"\n",
" def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):\n",
" super().__init__()\n",
" self.model = smp.create_model(\n",
Expand Down Expand Up @@ -591,13 +589,14 @@
" mask = mask.long()\n",
"\n",
" # Mask shape\n",
" assert mask.ndim == 3 # [batch_size, H, W]\n",
" assert mask.ndim == 3 # [batch_size, H, W]\n",
"\n",
" # Predict mask logits\n",
" logits_mask = self.forward(image)\n",
" \n",
" assert logits_mask.shape[1] == self.number_of_classes # [batch_size, number_of_classes, H, W]\n",
" \n",
"\n",
" assert (\n",
" logits_mask.shape[1] == self.number_of_classes\n",
" ) # [batch_size, number_of_classes, H, W]\n",
"\n",
" # Ensure the logits mask is contiguous\n",
" logits_mask = logits_mask.contiguous()\n",
Expand Down Expand Up @@ -1678,7 +1677,6 @@
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Fetch a batch from the test loader\n",
Expand Down

0 comments on commit 468d2f2

Please sign in to comment.