-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #21 from BU-Spark/poc
Final POC Delivery
- Loading branch information
Showing
16 changed files
with
7,849 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Image Cropping\n", | ||
"\n", | ||
"This notebook defines helper functions that we used to extract teabags from the Spare-it dataset to use for fine-tuning Stable Diffusion. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "wyGUnS-qkKWB" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import json\n", | ||
"from PIL import Image, ImageDraw, ImageOps\n", | ||
"\n", | ||
"# Function to check if a directory exists\n", | ||
"def check_directory(path):\n", | ||
" if not os.path.exists(path):\n", | ||
" raise FileNotFoundError(f\"Directory does not exist: {path}\")\n", | ||
"\n", | ||
"# Function to process JSON files and extract segmentation for a given category_id\n", | ||
"def process_json_files(json_source_path, category_id_to_crop):\n", | ||
" crop_info = {}\n", | ||
"\n", | ||
" # Check if the source path exists\n", | ||
" check_directory(json_source_path)\n", | ||
"\n", | ||
" # Iterate through all JSON files in the source directory\n", | ||
" for filename in os.listdir(json_source_path):\n", | ||
" if filename.endswith('.json'):\n", | ||
" json_file_path = os.path.join(json_source_path, filename)\n", | ||
" with open(json_file_path, 'r') as f:\n", | ||
" data = json.load(f)\n", | ||
"\n", | ||
" # Iterate through annotations and find matching category_id\n", | ||
" for annotation in data['annotations']:\n", | ||
" if annotation['category_id'] == category_id_to_crop:\n", | ||
" crop_info[filename] = annotation['segmentation'] # Use segmentation data\n", | ||
" break # Only need one match per file\n", | ||
"\n", | ||
" return crop_info\n", | ||
"\n", | ||
"# Function to create a mask and crop the image based on segmentation\n", | ||
"def crop_images(crop_info, image_source_path, image_dest_path):\n", | ||
" # Ensure destination path exists\n", | ||
" os.makedirs(image_dest_path, exist_ok=True)\n", | ||
"\n", | ||
" for json_filename, segmentations in crop_info.items():\n", | ||
" image_filename = json_filename.replace('.json', '.jpeg') # Assuming .jpeg images\n", | ||
" image_path = os.path.join(image_source_path, image_filename)\n", | ||
"\n", | ||
" if os.path.exists(image_path):\n", | ||
" with Image.open(image_path) as img:\n", | ||
" img = ImageOps.exif_transpose(img) # Correct orientation if needed\n", | ||
"\n", | ||
" # Create a transparent background image of same size as original\n", | ||
" object_img = Image.new('RGBA', img.size)\n", | ||
"\n", | ||
" # Create a mask from the segmentation points (same size as image)\n", | ||
" mask = Image.new('L', img.size, 0) # Create a blank grayscale mask (L mode)\n", | ||
" mask_draw = ImageDraw.Draw(mask)\n", | ||
"\n", | ||
" # Draw all polygons (segmentations) on the mask\n", | ||
" for segmentation in segmentations:\n", | ||
" polygon = [(segmentation[i], segmentation[i+1]) for i in range(0, len(segmentation), 2)]\n", | ||
" mask_draw.polygon(polygon, outline=255, fill=255) # Fill polygon with white (255)\n", | ||
"\n", | ||
" # Apply mask to keep only object and make background transparent\n", | ||
" object_img.paste(img.convert('RGBA'), (0, 0), mask=mask)\n", | ||
"\n", | ||
" # Crop to bounding box of non-zero pixels in mask (object area)\n", | ||
" bbox = mask.getbbox() # Get bounding box of non-zero region in mask\n", | ||
" if bbox:\n", | ||
" object_img_cropped = object_img.crop(bbox)\n", | ||
"\n", | ||
" # Convert RGBA to RGB before saving as JPEG\n", | ||
" object_img_cropped_rgb = object_img_cropped.convert(\"RGB\")\n", | ||
"\n", | ||
" # Save cropped image without transparency\n", | ||
" object_img_cropped_rgb.save(os.path.join(image_dest_path, image_filename))\n", | ||
" else:\n", | ||
" print(f\"Image file not found: {image_path}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "2oa8H58ikKqn" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"#Usage\n", | ||
"def main():\n", | ||
" #Parameters\n", | ||
" json_source_path = 'original_json'#Replace with you own path for images\n", | ||
" image_source_path = 'original_img'#Replace with you own path for images\n", | ||
" image_dest_path = './img_cropped'#Replace with you own path for images\n", | ||
" category_id_to_crop = 55\n", | ||
"\n", | ||
" try:\n", | ||
" # Process JSON files to get cropping information (segmentation points)\n", | ||
" crop_info = process_json_files(json_source_path, category_id_to_crop)\n", | ||
"\n", | ||
" if not crop_info:\n", | ||
" print(f\"No annotations found with category_id {category_id_to_crop}\")\n", | ||
" return\n", | ||
"\n", | ||
" # Crop images based on segmentation and save them without transparency\n", | ||
" crop_images(crop_info, image_source_path, image_dest_path)\n", | ||
"\n", | ||
" print(\"Cropping completed successfully.\")\n", | ||
"\n", | ||
" except FileNotFoundError as e:\n", | ||
" print(e)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.7" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 1 | ||
} |
Oops, something went wrong.