Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --noremove option to tiling jobs. Pull out alignment code check into a separate function #63

Merged
merged 9 commits into from
Jun 27, 2024
11 changes: 10 additions & 1 deletion parallel_examples/awsbatch/do_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def getCmdargs():
help="Maximum spectral difference for segmentation (default=%(default)s)")
p.add_argument("--spectDistPcntile", type=int, default=50, required=False,
help="Spectral Distance Percentile for segmentation (default=%(default)s)")
p.add_argument("--noremove", action="store_true", default=False,
help="don't remove files from S3 (for debugging)")

cmdargs = p.parse_args()
if cmdargs.bands is not None:
Expand Down Expand Up @@ -109,6 +111,11 @@ def main():

# now submit an array job with all the tiles
# (can't do this before now because we don't know how many tiles)
arrayProperties = {}
if len(colRowList) > 1:
# throws error if this is 1...
arrayProperties = {'size': len(colRowList)}

containerOverrides = {
"command": ['/usr/bin/python3', '/ubarscsw/bin/do_tile.py',
'--bucket', cmdargs.bucket, '--pickle', cmdargs.pickle,
Expand All @@ -119,7 +126,7 @@ def main():
response = batch.submit_job(jobName="pyshepseg_tiles",
jobQueue=cmdargs.jobqueue,
jobDefinition=cmdargs.jobdefntile,
arrayProperties={'size': len(colRowList)},
arrayProperties=arrayProperties,
containerOverrides=containerOverrides)
tilesJobId = response['jobId']
print('Tiles Job Id', tilesJobId)
Expand All @@ -137,6 +144,8 @@ def main():
cmd.extend(['--spatialstats', cmdargs.spatialstats])
if cmdargs.nogdalstats:
cmd.append('--nogdalstats')
if cmdargs.noremove:
cmd.append('--noremove')

response = batch.submit_job(jobName="pyshepseg_stitch",
jobQueue=cmdargs.jobqueue,
Expand Down
36 changes: 20 additions & 16 deletions parallel_examples/awsbatch/do_stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def getCmdargs():
p.add_argument("--nogdalstats", action="store_true", default=False,
help="don't calculate GDAL's statistics or write a colour table. " +
"Can't be used with --stats.")
p.add_argument("--noremove", action="store_true", default=False,
help="don't remove files from S3 (for debugging)")

cmdargs = p.parse_args()

Expand Down Expand Up @@ -93,15 +95,16 @@ def main():
cmdargs.overlapsize, tempDir)

# clean up files to release space
objs = []
for col, row in tileFilenames:
filename = '{}_{}_{}.{}'.format(cmdargs.tileprefix, col, row, 'tif')
objs.append({'Key': filename})

# workaround 1000 at a time limit
while len(objs) > 0:
s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs[0:1000]})
del objs[0:1000]
if not cmdargs.noremove:
objs = []
for col, row in tileFilenames:
filename = '{}_{}_{}.{}'.format(cmdargs.tileprefix, col, row, 'tif')
objs.append({'Key': filename})

# workaround 1000 at a time limit
while len(objs) > 0:
s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs[0:1000]})
del objs[0:1000]

# open for the creation of stats
localDs = gdal.Open(localOutfile, gdal.GA_Update)
Expand Down Expand Up @@ -153,13 +156,14 @@ def main():
s3.upload_file(localOutfile, cmdargs.bucket, cmdargs.outfile)

# cleanup temp files from S3
objs = [{'Key': cmdargs.pickle}]
if cmdargs.stats is not None:
objs.append({'Key': statsKey})
if cmdargs.spatialstats is not None:
objs.append({'Key': spatialstatsKey})

s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs})
if not cmdargs.noremove:
objs = [{'Key': cmdargs.pickle}]
if cmdargs.stats is not None:
objs.append({'Key': statsKey})
if cmdargs.spatialstats is not None:
objs.append({'Key': spatialstatsKey})

s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs})

# cleanup
shutil.rmtree(tempDir)
Expand Down
4 changes: 4 additions & 0 deletions parallel_examples/awsbatch/submit-pyshepseg-job.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def getCmdargs():
help="Maximum spectral difference for segmentation (default=%(default)s)")
p.add_argument("--spectDistPcntile", type=int, default=50, required=False,
help="Spectral Distance Percentile for segmentation (default=%(default)s)")
p.add_argument("--noremove", action="store_true", default=False,
help="don't remove files from S3 (for debugging)")

cmdargs = p.parse_args()

Expand Down Expand Up @@ -98,6 +100,8 @@ def main():
cmd.append('--nogdalstats')
if cmdargs.tileprefix is not None:
cmd.extend(['--tileprefix', cmdargs.tileprefix])
if cmdargs.noremove:
cmd.append('--noremove')

# submit the prepare job
response = batch.submit_job(jobName="pyshepseg_prepare",
Expand Down
98 changes: 56 additions & 42 deletions pyshepseg/tilingstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,27 +111,8 @@ def calcPerSegmentStatsTiled(imgfile, imgbandnum, segfile,
valid pixels (not nodata) that were used to calculate the statistics.

"""
segds = segfile
if not isinstance(segds, gdal.Dataset):
segds = gdal.Open(segfile, gdal.GA_Update)
segband = segds.GetRasterBand(1)

imgds = imgfile
if not isinstance(imgds, gdal.Dataset):
imgds = gdal.Open(imgfile, gdal.GA_ReadOnly)
imgband = imgds.GetRasterBand(imgbandnum)
if (imgband.DataType == gdal.GDT_Float32 or
imgband.DataType == gdal.GDT_Float64):
raise PyShepSegStatsError("Float image types not supported")

if segband.XSize != imgband.XSize or segband.YSize != imgband.YSize:
raise PyShepSegStatsError("Images must be same size")

if segds.GetGeoTransform() != imgds.GetGeoTransform():
raise PyShepSegStatsError("Images must have same spatial extent and pixel size")

if not equalProjection(segds.GetProjection(), imgds.GetProjection()):
raise PyShepSegStatsError("Images must be in the same projection")
segds, segband, imgds, imgband = doImageAlignmentChecks(segfile,
imgfile, imgbandnum)

attrTbl = segband.GetDefaultRAT()
existingColNames = [attrTbl.GetNameOfCol(i)
Expand Down Expand Up @@ -184,6 +165,58 @@ def calcPerSegmentStatsTiled(imgfile, imgbandnum, segfile,
raise PyShepSegStatsError('Not all pixels found during processing')


def doImageAlignmentChecks(segfile, imgfile, imgbandnum):
"""
Do the checks that the segment file and image file that is being used to
collect the stats actually align. We refuse to process the files if they
don't as it is not clear how they should be made to line up - this is up
to the user to get right. Also checks that imgfile is not a float image.

Parameters
----------
segfile : str or gdal.Dataset
Path to segmented file or an open GDAL dataset.
imgfile : string
Path to input file for collecting statistics from
imgbandnum : int
1-based index of the band number in imgfile to use for collecting stats

Returns
-------
segds: gdal.Dataset
Opened GDAL datset for the segments file
segband: gdal.Band
First Band of the segds
imgds: gdal.Dataset
Opened GDAL dataset for the image data file
imgband: gdal.Band
Requested band for the imgds
"""
segds = segfile
if not isinstance(segds, gdal.Dataset):
segds = gdal.Open(segfile, gdal.GA_Update)
segband = segds.GetRasterBand(1)

imgds = imgfile
if not isinstance(imgds, gdal.Dataset):
imgds = gdal.Open(imgfile, gdal.GA_ReadOnly)
imgband = imgds.GetRasterBand(imgbandnum)
if (imgband.DataType == gdal.GDT_Float32 or
imgband.DataType == gdal.GDT_Float64):
raise PyShepSegStatsError("Float image types not supported")

if segband.XSize != imgband.XSize or segband.YSize != imgband.YSize:
raise PyShepSegStatsError("Images must be same size")

if segds.GetGeoTransform() != imgds.GetGeoTransform():
raise PyShepSegStatsError("Images must have same spatial extent and pixel size")

if not equalProjection(segds.GetProjection(), imgds.GetProjection()):
raise PyShepSegStatsError("Images must be in the same projection")

return segds, segband, imgds, imgband


@njit
def accumulateSegDict(segDict, noDataDict, imgNullVal, tileSegments, tileImageData):
"""
Expand Down Expand Up @@ -1028,28 +1061,9 @@ def calcPerSegmentSpatialStatsTiled(imgfile, imgbandnum, segfile,
The value to fill in for segments that have no data.

"""
segds = segfile
if not isinstance(segds, gdal.Dataset):
segds = gdal.Open(segfile, gdal.GA_Update)
segband = segds.GetRasterBand(1)
segds, segband, imgds, imgband = doImageAlignmentChecks(segfile,
imgfile, imgbandnum)

imgds = imgfile
if not isinstance(imgds, gdal.Dataset):
imgds = gdal.Open(imgfile, gdal.GA_ReadOnly)
imgband = imgds.GetRasterBand(imgbandnum)
if (imgband.DataType == gdal.GDT_Float32 or
imgband.DataType == gdal.GDT_Float64):
raise PyShepSegStatsError("Float image types not supported")

if segband.XSize != imgband.XSize or segband.YSize != imgband.YSize:
raise PyShepSegStatsError("Images must be same size")

if segds.GetGeoTransform() != imgds.GetGeoTransform():
raise PyShepSegStatsError("Images must have same spatial extent and pixel size")

if not equalProjection(segds.GetProjection(), imgds.GetProjection()):
raise PyShepSegStatsError("Images must be in the same projection")

attrTbl = segband.GetDefaultRAT()
existingColNames = [attrTbl.GetNameOfCol(i)
for i in range(attrTbl.GetColumnCount())]
Expand Down
Loading