Skip to content

Commit

Permalink
Update SegmentationDataset class constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Feb 22, 2024
1 parent 1403896 commit 0ef8ed1
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions dataset/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,30 @@


class SegmentationDataset(Dataset):
"""Semantic segmentation dataset based on input csvs listing pairs of imagery and ground truth patches as .tif."""
"""Semantic segmentation dataset based on input csvs listing pairs of imagery and ground truth patches as .tif.
Args:
dataset_list_path (str): The path to the dataset list file.
num_bands (int): The number of bands in the imagery.
dontcare (Optional[int]): The value to be ignored in the label.
max_sample_count (Optional[int]): The maximum number of samples to load from the dataset.
radiom_transform (Optional[Callable]): The radiometric transform function to be applied to the samples.
geom_transform (Optional[Callable]): The geometric transform function to be applied to the samples.
totensor_transform (Optional[Callable]): The transform function to convert samples to tensors.
debug (bool): Whether to enable debug mode.
Attributes:
max_sample_count (int): The maximum number of samples to load from the dataset.
num_bands (int): The number of bands in the imagery.
radiom_transform (Optional[Callable]): The radiometric transform function to be applied to the samples.
geom_transform (Optional[Callable]): The geometric transform function to be applied to the samples.
totensor_transform (Optional[Callable]): The transform function to convert samples to tensors.
debug (bool): Whether debug mode is enabled.
dontcare (Optional[int]): The value to be ignored in the label.
list_path (str): The path to the dataset list file.
assets (List[Dict[str, str]]): The list of filepaths to images and labels.
"""

def __init__(self,
dataset_list_path,
Expand All @@ -37,7 +60,6 @@ def __init__(self,
geom_transform=None,
totensor_transform=None,
debug=False):
# note: if 'max_sample_count' is None, then it will be read from the dataset at runtime
self.max_sample_count = max_sample_count
self.num_bands = num_bands
self.radiom_transform = radiom_transform
Expand Down

0 comments on commit 0ef8ed1

Please sign in to comment.