Skip to content

Commit

Permalink
sampling_segmentation.py: remove unused mask_image() function (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
remtav authored Jun 7, 2022
1 parent 682bdc7 commit 406771c
Showing 1 changed file with 0 additions and 34 deletions.
34 changes: 0 additions & 34 deletions sampling_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,35 +27,6 @@
np.random.seed(1234)


def mask_image(arrayA, arrayB):
"""Function to mask values of arrayB, based on 0 values from arrayA.
>>> x1 = np.array([0, 2, 4, 6, 0, 3, 9, 8], dtype=np.uint8).reshape(2,2,2)
>>> x2 = np.array([1.5, 1.2, 1.6, 1.2, 11., 1.1, 25.9, 0.1], dtype=np.float32).reshape(2,2,2)
>>> mask_image(x1, x2)
array([[[ 0. , 0. ],
[ 1.6, 1.2]],
<BLANKLINE>
[[ 0. , 0. ],
[25.9, 0.1]]], dtype=float32)
"""

# Handle arrayA of shapes (h,w,c) and (h,w)
if len(arrayA.shape) == 3:
mask = arrayA[:, :, 0] != 0
else:
mask = arrayA != 0

ma_array = np.zeros(arrayB.shape, dtype=arrayB.dtype)
# Handle arrayB of shapes (h,w,c) and (h,w)
if len(arrayB.shape) == 3:
for i in range(0, arrayB.shape[2]):
ma_array[:, :, i] = mask * arrayB[:, :, i]
else:
ma_array = arrayB * mask
return ma_array


def validate_class_prop_dict(actual_classes_dict, config_dict):
"""
Populate dictionary containing class values found in vector data with values (thresholds) from sample/class_prop
Expand Down Expand Up @@ -400,7 +371,6 @@ def main(cfg: DictConfig) -> None:
# OTHER PARAMETERS
# TODO class_prop get_key_def('class_proportion', params['sample']['sampling_method'], None, expected_type=dict)
class_prop = None
mask_reference = False # TODO get_key_def('mask_reference', params['sample'], default=False, expected_type=bool)
# set dontcare (aka ignore_index) value
dontcare = cfg.dataset.ignore_index if cfg.dataset.ignore_index is not None else -1
if dontcare == 0:
Expand Down Expand Up @@ -520,10 +490,6 @@ def main(cfg: DictConfig) -> None:
with rasterio.open(out_tif, "w", **out_meta) as dest:
dest.write(np_label_debug)

# Mask the zeros from input image into label raster.
if mask_reference:
np_label_raster = mask_image(np_input_image, np_label_raster)

if aoi.split == 'trn':
out_file = trn_hdf5
elif aoi.split == 'tst':
Expand Down

0 comments on commit 406771c

Please sign in to comment.