diff --git a/parallel_examples/awsbatch/do_prepare.py b/parallel_examples/awsbatch/do_prepare.py index cd2d7a2..df35f33 100755 --- a/parallel_examples/awsbatch/do_prepare.py +++ b/parallel_examples/awsbatch/do_prepare.py @@ -64,6 +64,8 @@ def getCmdargs(): help="Maximum spectral difference for segmentation (default=%(default)s)") p.add_argument("--spectDistPcntile", type=int, default=50, required=False, help="Spectral Distance Percentile for segmentation (default=%(default)s)") + p.add_argument("--noremove", action="store_true", default=False, + help="don't remove files from S3 (for debugging)") cmdargs = p.parse_args() if cmdargs.bands is not None: @@ -116,10 +118,20 @@ def main(): '--minSegmentSize', str(cmdargs.minSegmentSize), '--maxSpectDiff', cmdargs.maxSpectDiff, '--spectDistPcntile', str(cmdargs.spectDistPcntile)]} + + arrayProperties = {} + if len(colRowList) > 1: + # throws error if this is 1... + arrayProperties['size'] = len(colRowList) + else: + # must fake AWS_BATCH_JOB_ARRAY_INDEX + # can't set this as and env var as Batch overrides + containerOverrides['command'].extend(['--arrayindex', '0']) + response = batch.submit_job(jobName="pyshepseg_tiles", jobQueue=cmdargs.jobqueue, jobDefinition=cmdargs.jobdefntile, - arrayProperties={'size': len(colRowList)}, + arrayProperties=arrayProperties, containerOverrides=containerOverrides) tilesJobId = response['jobId'] print('Tiles Job Id', tilesJobId) @@ -137,6 +149,8 @@ def main(): cmd.extend(['--spatialstats', cmdargs.spatialstats]) if cmdargs.nogdalstats: cmd.append('--nogdalstats') + if cmdargs.noremove: + cmd.append('--noremove') response = batch.submit_job(jobName="pyshepseg_stitch", jobQueue=cmdargs.jobqueue, diff --git a/parallel_examples/awsbatch/do_stitch.py b/parallel_examples/awsbatch/do_stitch.py index 29a3e3c..a76a909 100755 --- a/parallel_examples/awsbatch/do_stitch.py +++ b/parallel_examples/awsbatch/do_stitch.py @@ -49,6 +49,8 @@ def getCmdargs(): p.add_argument("--nogdalstats", action="store_true", default=False, help="don't calculate GDAL's statistics or write a colour table. " + "Can't be used with --stats.") + p.add_argument("--noremove", action="store_true", default=False, + help="don't remove files from S3 (for debugging)") cmdargs = p.parse_args() @@ -93,15 +95,16 @@ def main(): cmdargs.overlapsize, tempDir) # clean up files to release space - objs = [] - for col, row in tileFilenames: - filename = '{}_{}_{}.{}'.format(cmdargs.tileprefix, col, row, 'tif') - objs.append({'Key': filename}) - - # workaround 1000 at a time limit - while len(objs) > 0: - s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs[0:1000]}) - del objs[0:1000] + if not cmdargs.noremove: + objs = [] + for col, row in tileFilenames: + filename = '{}_{}_{}.{}'.format(cmdargs.tileprefix, col, row, 'tif') + objs.append({'Key': filename}) + + # workaround 1000 at a time limit + while len(objs) > 0: + s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs[0:1000]}) + del objs[0:1000] # open for the creation of stats localDs = gdal.Open(localOutfile, gdal.GA_Update) @@ -153,13 +156,14 @@ def main(): s3.upload_file(localOutfile, cmdargs.bucket, cmdargs.outfile) # cleanup temp files from S3 - objs = [{'Key': cmdargs.pickle}] - if cmdargs.stats is not None: - objs.append({'Key': statsKey}) - if cmdargs.spatialstats is not None: - objs.append({'Key': spatialstatsKey}) - - s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs}) + if not cmdargs.noremove: + objs = [{'Key': cmdargs.pickle}] + if cmdargs.stats is not None: + objs.append({'Key': statsKey}) + if cmdargs.spatialstats is not None: + objs.append({'Key': spatialstatsKey}) + + s3.delete_objects(Bucket=cmdargs.bucket, Delete={'Objects': objs}) # cleanup shutil.rmtree(tempDir) diff --git a/parallel_examples/awsbatch/do_tile.py b/parallel_examples/awsbatch/do_tile.py index 9834623..affee72 100755 --- a/parallel_examples/awsbatch/do_tile.py +++ b/parallel_examples/awsbatch/do_tile.py @@ -21,13 +21,6 @@ gdal.UseExceptions() -# set by AWS Batch -ARRAY_INDEX = os.getenv('AWS_BATCH_JOB_ARRAY_INDEX') -if ARRAY_INDEX is None: - raise SystemExit('Must set AWS_BATCH_JOB_ARRAY_INDEX env var') - -ARRAY_INDEX = int(ARRAY_INDEX) - def getCmdargs(): """ @@ -48,9 +41,17 @@ def getCmdargs(): help="Maximum spectral difference for segmentation (default=%(default)s)") p.add_argument("--spectDistPcntile", type=int, default=50, required=False, help="Spectral Distance Percentile for segmentation (default=%(default)s)") + p.add_argument("--arrayindex", type=int, + help="Override AWS_BATCH_JOB_ARRAY_INDEX env var") cmdargs = p.parse_args() + if cmdargs.arrayindex is None: + cmdargs.arrayindex = os.getenv('AWS_BATCH_JOB_ARRAY_INDEX') + if cmdargs.arrayindex is None: + raise SystemExit('Must set AWS_BATCH_JOB_ARRAY_INDEX env var or ' + + 'specify --arrayindex') + return cmdargs @@ -75,7 +76,7 @@ def main(): tempDir = tempfile.mkdtemp() # work out which tile we are processing - col, row = dataFromPickle['colRowList'][ARRAY_INDEX] + col, row = dataFromPickle['colRowList'][cmdargs.arrayindex] # work out a filename to save with the output of this tile # Note: this filename format is repeated in do_stitch.py diff --git a/parallel_examples/awsbatch/submit-pyshepseg-job.py b/parallel_examples/awsbatch/submit-pyshepseg-job.py index 1e44579..d3ee754 100755 --- a/parallel_examples/awsbatch/submit-pyshepseg-job.py +++ b/parallel_examples/awsbatch/submit-pyshepseg-job.py @@ -59,6 +59,8 @@ def getCmdargs(): help="Maximum spectral difference for segmentation (default=%(default)s)") p.add_argument("--spectDistPcntile", type=int, default=50, required=False, help="Spectral Distance Percentile for segmentation (default=%(default)s)") + p.add_argument("--noremove", action="store_true", default=False, + help="don't remove files from S3 (for debugging)") cmdargs = p.parse_args() @@ -98,6 +100,8 @@ def main(): cmd.append('--nogdalstats') if cmdargs.tileprefix is not None: cmd.extend(['--tileprefix', cmdargs.tileprefix]) + if cmdargs.noremove: + cmd.append('--noremove') # submit the prepare job response = batch.submit_job(jobName="pyshepseg_prepare", diff --git a/pyshepseg/tilingstats.py b/pyshepseg/tilingstats.py index 7461c81..eb51f14 100644 --- a/pyshepseg/tilingstats.py +++ b/pyshepseg/tilingstats.py @@ -111,27 +111,8 @@ def calcPerSegmentStatsTiled(imgfile, imgbandnum, segfile, valid pixels (not nodata) that were used to calculate the statistics. """ - segds = segfile - if not isinstance(segds, gdal.Dataset): - segds = gdal.Open(segfile, gdal.GA_Update) - segband = segds.GetRasterBand(1) - - imgds = imgfile - if not isinstance(imgds, gdal.Dataset): - imgds = gdal.Open(imgfile, gdal.GA_ReadOnly) - imgband = imgds.GetRasterBand(imgbandnum) - if (imgband.DataType == gdal.GDT_Float32 or - imgband.DataType == gdal.GDT_Float64): - raise PyShepSegStatsError("Float image types not supported") - - if segband.XSize != imgband.XSize or segband.YSize != imgband.YSize: - raise PyShepSegStatsError("Images must be same size") - - if segds.GetGeoTransform() != imgds.GetGeoTransform(): - raise PyShepSegStatsError("Images must have same spatial extent and pixel size") - - if not equalProjection(segds.GetProjection(), imgds.GetProjection()): - raise PyShepSegStatsError("Images must be in the same projection") + segds, segband, imgds, imgband = doImageAlignmentChecks(segfile, + imgfile, imgbandnum) attrTbl = segband.GetDefaultRAT() existingColNames = [attrTbl.GetNameOfCol(i) @@ -184,6 +165,58 @@ def calcPerSegmentStatsTiled(imgfile, imgbandnum, segfile, raise PyShepSegStatsError('Not all pixels found during processing') +def doImageAlignmentChecks(segfile, imgfile, imgbandnum): + """ + Do the checks that the segment file and image file that is being used to + collect the stats actually align. We refuse to process the files if they + don't as it is not clear how they should be made to line up - this is up + to the user to get right. Also checks that imgfile is not a float image. + + Parameters + ---------- + segfile : str or gdal.Dataset + Path to segmented file or an open GDAL dataset. + imgfile : string + Path to input file for collecting statistics from + imgbandnum : int + 1-based index of the band number in imgfile to use for collecting stats + + Returns + ------- + segds: gdal.Dataset + Opened GDAL datset for the segments file + segband: gdal.Band + First Band of the segds + imgds: gdal.Dataset + Opened GDAL dataset for the image data file + imgband: gdal.Band + Requested band for the imgds + """ + segds = segfile + if not isinstance(segds, gdal.Dataset): + segds = gdal.Open(segfile, gdal.GA_Update) + segband = segds.GetRasterBand(1) + + imgds = imgfile + if not isinstance(imgds, gdal.Dataset): + imgds = gdal.Open(imgfile, gdal.GA_ReadOnly) + imgband = imgds.GetRasterBand(imgbandnum) + if (imgband.DataType == gdal.GDT_Float32 or + imgband.DataType == gdal.GDT_Float64): + raise PyShepSegStatsError("Float image types not supported") + + if segband.XSize != imgband.XSize or segband.YSize != imgband.YSize: + raise PyShepSegStatsError("Images must be same size") + + if segds.GetGeoTransform() != imgds.GetGeoTransform(): + raise PyShepSegStatsError("Images must have same spatial extent and pixel size") + + if not equalProjection(segds.GetProjection(), imgds.GetProjection()): + raise PyShepSegStatsError("Images must be in the same projection") + + return segds, segband, imgds, imgband + + @njit def accumulateSegDict(segDict, noDataDict, imgNullVal, tileSegments, tileImageData): """ @@ -1028,28 +1061,9 @@ def calcPerSegmentSpatialStatsTiled(imgfile, imgbandnum, segfile, The value to fill in for segments that have no data. """ - segds = segfile - if not isinstance(segds, gdal.Dataset): - segds = gdal.Open(segfile, gdal.GA_Update) - segband = segds.GetRasterBand(1) + segds, segband, imgds, imgband = doImageAlignmentChecks(segfile, + imgfile, imgbandnum) - imgds = imgfile - if not isinstance(imgds, gdal.Dataset): - imgds = gdal.Open(imgfile, gdal.GA_ReadOnly) - imgband = imgds.GetRasterBand(imgbandnum) - if (imgband.DataType == gdal.GDT_Float32 or - imgband.DataType == gdal.GDT_Float64): - raise PyShepSegStatsError("Float image types not supported") - - if segband.XSize != imgband.XSize or segband.YSize != imgband.YSize: - raise PyShepSegStatsError("Images must be same size") - - if segds.GetGeoTransform() != imgds.GetGeoTransform(): - raise PyShepSegStatsError("Images must have same spatial extent and pixel size") - - if not equalProjection(segds.GetProjection(), imgds.GetProjection()): - raise PyShepSegStatsError("Images must be in the same projection") - attrTbl = segband.GetDefaultRAT() existingColNames = [attrTbl.GetNameOfCol(i) for i in range(attrTbl.GetColumnCount())]