Skip to content

Commit

Permalink
Merge pull request #923 from adamjstewart/notebooks
Browse files Browse the repository at this point in the history
Ruff: format Jupyter notebooks too
  • Loading branch information
adamjstewart authored Sep 12, 2024
2 parents ccccadd + 360ca5b commit 97e0ae8
Show file tree
Hide file tree
Showing 5 changed files with 5,511 additions and 31 deletions.
4,212 changes: 4,211 additions & 1 deletion examples/binary_segmentation_intro.ipynb

Large diffs are not rendered by default.

14 changes: 6 additions & 8 deletions examples/camvid_segmentation_multiclass.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@
" augmentation=get_validation_augmentation(),\n",
")\n",
"\n",
"#Change to > 0 if not on Windows machine\n",
"# Change to > 0 if not on Windows machine\n",
"train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)\n",
"valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=0)\n",
"test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)"
Expand Down Expand Up @@ -545,12 +545,10 @@
"import pytorch_lightning as pl\n",
"import segmentation_models_pytorch as smp\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from torch.optim import lr_scheduler\n",
"\n",
"\n",
"class CamVidModel(pl.LightningModule):\n",
"\n",
" def __init__(self, arch, encoder_name, in_channels, out_classes, **kwargs):\n",
" super().__init__()\n",
" self.model = smp.create_model(\n",
Expand Down Expand Up @@ -591,13 +589,14 @@
" mask = mask.long()\n",
"\n",
" # Mask shape\n",
" assert mask.ndim == 3 # [batch_size, H, W]\n",
" assert mask.ndim == 3 # [batch_size, H, W]\n",
"\n",
" # Predict mask logits\n",
" logits_mask = self.forward(image)\n",
" \n",
" assert logits_mask.shape[1] == self.number_of_classes # [batch_size, number_of_classes, H, W]\n",
" \n",
"\n",
" assert (\n",
" logits_mask.shape[1] == self.number_of_classes\n",
" ) # [batch_size, number_of_classes, H, W]\n",
"\n",
" # Ensure the logits mask is contiguous\n",
" logits_mask = logits_mask.contiguous()\n",
Expand Down Expand Up @@ -1678,7 +1677,6 @@
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Fetch a batch from the test loader\n",
Expand Down
1,271 changes: 1,270 additions & 1 deletion examples/cars segmentation (camvid).ipynb

Large diffs are not rendered by default.

41 changes: 20 additions & 21 deletions examples/save_load_model_and_share_with_hf_hub.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,13 @@
"# save the model\n",
"model.save_pretrained(\n",
" \"saved-model-dir/unet-with-metadata/\",\n",
"\n",
" # additional information to be saved with the model\n",
" # only \"dataset\" and \"metrics\" are supported\n",
" dataset=\"PASCAL VOC\", # only string name is supported\n",
" metrics={ # should be a dictionary with metric name as key and metric value as value\n",
" metrics={ # should be a dictionary with metric name as key and metric value as value\n",
" \"mIoU\": 0.95,\n",
" \"accuracy\": 0.96\n",
" }\n",
" \"accuracy\": 0.96,\n",
" },\n",
")"
]
},
Expand Down Expand Up @@ -222,13 +221,10 @@
"# save the model and share it on the HF Hub (https://huggingface.co/models)\n",
"model.save_pretrained(\n",
" \"qubvel-hf/unet-with-metadata/\",\n",
" push_to_hub=True, # <---------- push the model to the hub\n",
" private=False, # <---------- make the model private or or public\n",
" push_to_hub=True, # <---------- push the model to the hub\n",
" private=False, # <---------- make the model private or or public\n",
" dataset=\"PASCAL VOC\",\n",
" metrics={\n",
" \"mIoU\": 0.95,\n",
" \"accuracy\": 0.96\n",
" }\n",
" metrics={\"mIoU\": 0.95, \"accuracy\": 0.96},\n",
")\n",
"\n",
"# see result here https://huggingface.co/qubvel-hf/unet-with-metadata"
Expand Down Expand Up @@ -267,10 +263,7 @@
"outputs": [],
"source": [
"# define a preprocessing transform for image that would be used during inference\n",
"preprocessing_transform = A.Compose([\n",
" A.Resize(256, 256),\n",
" A.Normalize()\n",
"])\n",
"preprocessing_transform = A.Compose([A.Resize(256, 256), A.Normalize()])\n",
"\n",
"model = smp.Unet()"
]
Expand Down Expand Up @@ -367,15 +360,21 @@
"# You can also save training augmentations to the Hub too (and load it back)!\n",
"#! Just make sure to provide key=\"train\" when saving and loading the augmentations.\n",
"\n",
"train_augmentations = A.Compose([\n",
" A.HorizontalFlip(p=0.5),\n",
" A.RandomBrightnessContrast(p=0.2),\n",
" A.ShiftScaleRotate(p=0.5),\n",
"])\n",
"train_augmentations = A.Compose(\n",
" [\n",
" A.HorizontalFlip(p=0.5),\n",
" A.RandomBrightnessContrast(p=0.2),\n",
" A.ShiftScaleRotate(p=0.5),\n",
" ]\n",
")\n",
"\n",
"train_augmentations.save_pretrained(directory_or_repo_on_the_hub, key=\"train\", push_to_hub=True)\n",
"train_augmentations.save_pretrained(\n",
" directory_or_repo_on_the_hub, key=\"train\", push_to_hub=True\n",
")\n",
"\n",
"restored_train_augmentations = A.Compose.from_pretrained(directory_or_repo_on_the_hub, key=\"train\")\n",
"restored_train_augmentations = A.Compose.from_pretrained(\n",
" directory_or_repo_on_the_hub, key=\"train\"\n",
")\n",
"print(restored_train_augmentations)"
]
},
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ test = [
[project.urls]
Homepage = 'https://github.com/qubvel-org/segmentation_models.pytorch'

[tool.ruff]
extend-include = ['*.ipynb']
fix = true

[tool.setuptools.dynamic]
version = {attr = 'segmentation_models_pytorch.__version__.__version__'}

Expand Down

0 comments on commit 97e0ae8

Please sign in to comment.