Skip to content

Commit

Permalink
Improving comments in the code
Browse files Browse the repository at this point in the history
  • Loading branch information
apacha committed Aug 30, 2022
1 parent 4112c6f commit 1baa70e
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions sbb_binarize/sbb_binarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def binarize_image(self, image_path: Path, save_path: Path):
padded_image = np.zeros((padded_image_height, padded_image_width, image_channels))
padded_image[0:original_image_height, 0:original_image_width, :] = img[:, :, :]

image_batch = np.expand_dims(padded_image, 0) # To create the batch information
image_batch = np.expand_dims(padded_image, 0) # Create the batch dimension
patches = tf.image.extract_patches(
images=image_batch,
sizes=[1, self.model_height, self.model_width, 1],
Expand Down Expand Up @@ -117,6 +117,7 @@ def split_list_into_worker_batches(files: List[Any], number_of_workers: int) ->
def batch_predict(input_data):
model_dir, input_images, output_images, worker_number = input_data
print(f"Setting visible cuda devices to {str(worker_number)}")
# Each worker process is restricted to a single GPU (selected by its worker number) to allow multiprocessing across GPUs
os.environ["CUDA_VISIBLE_DEVICES"] = str(worker_number)

binarizer = SbbBinarizer()
Expand Down Expand Up @@ -146,13 +147,14 @@ def batch_predict(input_data):
output_images = [output_path / (i.relative_to(input_path)) for i in input_images]
input_images = [i for i in input_images]

print(f"Starting binarization of {len(input_images)} images")
print(f"Starting batch-binarization of {len(input_images)} images")

number_of_gpus = len(tf.config.list_physical_devices('GPU'))
number_of_workers = max(1, number_of_gpus)
image_batches = split_list_into_worker_batches(input_images, number_of_workers)
output_batches = split_list_into_worker_batches(output_images, number_of_workers)

# Must use spawn to create a completely new process that has its own resources, so that work is properly multiprocessed across GPUs
with WorkerPool(n_jobs=number_of_workers, start_method='spawn') as pool:
model_dirs = itertools.repeat(model_directory, len(image_batches))
input_data = zip(model_dirs, image_batches, output_batches, range(number_of_workers))
Expand Down

0 comments on commit 1baa70e

Please sign in to comment.