-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from BU-Spark/project-setup
Project setup
- Loading branch information
Showing
9 changed files
with
2,091 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "fc8d75bc-5e9e-4fe6-a683-ae07437eac38", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import json\n", | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"from torchvision import transforms\n", | ||
"from torch.utils.data import DataLoader, Dataset\n", | ||
"from PIL import Image\n", | ||
"from yolov5 import YOLOv5" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "aa49ffdc-15ed-45ad-8e3b-770b727e404b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class CustomDataset(Dataset):\n", | ||
" def __init__(self, directory, transform=None):\n", | ||
" self.directory = directory\n", | ||
" self.transform = transform\n", | ||
" self.images = []\n", | ||
" self.annotations = []\n", | ||
"\n", | ||
" for filename in os.listdir(directory):\n", | ||
" if filename.endswith('.jpeg'):\n", | ||
" img_path = os.path.join(directory, filename)\n", | ||
" ann_path = os.path.join(directory, filename.replace('.jpeg', '.json'))\n", | ||
" \n", | ||
" if os.path.exists(ann_path):\n", | ||
" self.images.append(img_path)\n", | ||
" self.annotations.append(ann_path)\n", | ||
"\n", | ||
" def __len__(self):\n", | ||
" return len(self.images)\n", | ||
"\n", | ||
" def __getitem__(self, idx):\n", | ||
" img_path = self.images[idx]\n", | ||
" ann_path = self.annotations[idx]\n", | ||
"\n", | ||
" image = Image.open(img_path).convert('RGB')\n", | ||
" with open(ann_path, 'r') as f:\n", | ||
" annotation = json.load(f)\n", | ||
"\n", | ||
" categories = [ann['category_id'] for ann in annotation['annotations']]\n", | ||
" bboxes = [ann['bbox'] for ann in annotation['annotations']]\n", | ||
"\n", | ||
" image_width, image_height = image.size\n", | ||
" bboxes = [[(x + w / 2) / image_width, (y + h / 2) / image_height, w / image_width, h / image_height] \n", | ||
" for x, y, w, h in bboxes]\n", | ||
"\n", | ||
" target = {'boxes': torch.Tensor(bboxes), 'labels': torch.Tensor(categories).long()}\n", | ||
" \n", | ||
" if self.transform:\n", | ||
" image = self.transform(image)\n", | ||
"\n", | ||
" return image, target\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "acb0de3a-d9fa-444e-ae5d-090f0b3a673e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"transform = transforms.Compose([\n", | ||
" transforms.Resize((640, 640)),\n", | ||
" transforms.ToTensor(),\n", | ||
"])\n", | ||
"\n", | ||
"dataset = CustomDataset(directory=\"train\", transform=transform)\n", | ||
"data_loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)\n", | ||
"\n", | ||
"model = YOLOv5('yolov5s.pt', autoshape=True)\n", | ||
"model.train()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "cec6bbdf-b14e-405c-87d5-4bcb0edab308", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Loss function\n", | ||
"class YOLOv5Loss(nn.Module):\n", | ||
" def __init__(self):\n", | ||
" super(YOLOv5Loss, self).__init__()\n", | ||
" self.mse_loss = nn.MSELoss() # used for the box term; a GIoU loss is another option\n", | ||
" self.bce_loss = nn.BCEWithLogitsLoss()\n", | ||
" self.ce_loss = nn.CrossEntropyLoss()\n", | ||
"\n", | ||
" def forward(self, predictions, targets):\n", | ||
" # TODO: implement parsing of predictions and targets; for now both are assumed to already be 4-tuples\n", | ||
" obj_preds, no_obj_preds, class_preds, box_preds = predictions\n", | ||
" obj_targets, no_obj_targets, class_targets, box_targets = targets\n", | ||
" objectness_loss = self.bce_loss(obj_preds, obj_targets)\n", | ||
"\n", | ||
" no_objectness_loss = self.bce_loss(no_obj_preds, no_obj_targets)\n", | ||
"\n", | ||
" classification_loss = self.ce_loss(class_preds, class_targets)\n", | ||
"\n", | ||
" box_loss = self.mse_loss(box_preds, box_targets)\n", | ||
"\n", | ||
" # Combine losses\n", | ||
" total_loss = (\n", | ||
" objectness_loss + \n", | ||
" no_objectness_loss + \n", | ||
" classification_loss + \n", | ||
" box_loss\n", | ||
" )\n", | ||
"\n", | ||
" return total_loss" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2e6c45fb-87bc-430f-be57-fe68aa915242", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Optimizer\n", | ||
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ef53d453-b3fa-430a-826e-39d21e0882c2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Training Loop\n", | ||
"loss_function = YOLOv5Loss()\n", | ||
"num_epochs = 10 \n", | ||
"for epoch in range(num_epochs):\n", | ||
" for images, targets in data_loader:\n", | ||
" optimizer.zero_grad()\n", | ||
" \n", | ||
" outputs = model(images)\n", | ||
" \n", | ||
" loss = loss_function(outputs, targets)\n", | ||
" \n", | ||
" loss.backward()\n", | ||
" optimizer.step()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2476c91f-a0ea-44b7-9555-2126a1a14666", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
train: ../train/images | ||
val: ../test/images | ||
|
||
nc: 127 | ||
names: | ||
[ | ||
"Paper Cup", | ||
"Snack or Candy Bag or Wrapper", | ||
"Wipe", | ||
"Wax Paper", | ||
"Latex Gloves", | ||
"Juice or Other Pouch", | ||
"Diaper", | ||
"Padded Envelope (mixed materials)", | ||
"Blister Pack", | ||
"Pens and Pencils", | ||
"Miscellaneous Office Supplies", | ||
"Facemask and Other PPE", | ||
"Shelf Stable Carton", | ||
"Soiled Plastic", | ||
"Soiled Metal", | ||
"Soiled Glass", | ||
"Ceramics", | ||
"Unclassifiable", | ||
"Filled Bag", | ||
"Coffee Pod", | ||
"Other Trash", | ||
"Flexible container lid / seal", | ||
"Snack Food Canister", | ||
"Hard Cover Books", | ||
"Compostable Fiber Ware", | ||
"Compostable Cutlery", | ||
"Compostable Plastic Cups", | ||
"Compostable Paper Cups", | ||
"Paper Towel/Napkins/Tissue", | ||
"Wooden Coffee Stirrer or Chopstick", | ||
"Soiled Cardboard Box", | ||
"Compostable Plastic Lid", | ||
"Food Soiled Paper", | ||
"Other compostable material", | ||
"Sandwich paper wrapper", | ||
"Plastic strapping", | ||
"Batteries", | ||
"Cables", | ||
"Computers", | ||
"Monitors", | ||
"Toner and Ink Cartridges", | ||
"Miscellaneous Electronics", | ||
"LED Lightbulb", | ||
"Meat and Fish", | ||
"Bones and Shells", | ||
"Cheese and Other Fats", | ||
"Fruits And Veggies", | ||
"Other Food or Mixed Food", | ||
"Breads", | ||
"Grains", | ||
"Tea Bags", | ||
"Coffee Grounds", | ||
"Egg Shell", | ||
"Glass Bottles", | ||
"Glass Jars", | ||
"Broken Glass", | ||
"Other Clean Glass", | ||
"Drinking glass or glass ovenware", | ||
"Metal Can", | ||
"Aluminum Foil", | ||
"Aluminum Catering Tray", | ||
"Other Clean Metal", | ||
"Aerosol Can", | ||
"Metallic Bottle Cap or Lid", | ||
"Metal Strapping", | ||
"Liquids", | ||
"Leaves, Flowers, Grass Clippings", | ||
"Office Paper", | ||
"Shredded Paper", | ||
"Clean Cardboard", | ||
"Refrigerated Beverage Carton", | ||
"Magazines Newspaper", | ||
"Receipts and Thermal Paper", | ||
"Empty Paper Bag", | ||
"Cardboard Coffee Cup Sleeve", | ||
"Clean Paper Plate", | ||
"Colored Memo Note", | ||
"Office Folder", | ||
"Paper Roll", | ||
"Wrapping Paper", | ||
"Other Clean Paper", | ||
"Plastic Drink Bottle", | ||
"Plastic Milk Jug or Personal Care Bottle", | ||
"Empty Plastic Bag", | ||
"Yogurt Tub or Container", | ||
"Expanded Polystyrene (styrofoam)", | ||
"Other Clean Plastics (rigid)", | ||
"Straws", | ||
"Clear Clamshell Container", | ||
"Plastic Cutlery", | ||
"Plastic Lid except black", | ||
"Plastic Coffee Stirrer", | ||
"Clear Plastic Cup", | ||
"Colored Plastic Cup", | ||
"Black Plastic", | ||
"Plastic Wrap", | ||
"Bubble Wrap", | ||
"Incandescent Lightbulbs", | ||
"CFL Lightbulbs", | ||
"Textiles and Clothes", | ||
"Food canister", | ||
"Plastic Lid", | ||
"Other Trash", | ||
"Paper Wrapper", | ||
"Black Plastic Container", | ||
"Padded Envelope", | ||
] |
Oops, something went wrong.