-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implementation of test preprocessing into modules and pipeline #197
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tak!. My only comment is that i think it would be cleaner to handle volumes smaller than the patch size in the predict step or dataloader (where we have a patch size available), rather than during preprocessing.
I propose adding the following or similar code to the on_before_batch_transfer to get rid of the patch size argument: def on_before_batch_transfer(self, batch, dataloader_idx):
if self.trainer.predicting is True:
if self.disable_inference_preprocessing is False:
batch["data"], batch["data_properties"] = self.preprocessor.preprocess_case_for_inference(
images=batch["data_paths"],
patch_size=self.patch_size,
ext=batch["extension"],
sliding_window_prediction=self.sliding_window_prediction,
)
else:
batch["data"], batch["data_properties"] = ensure_batch_fits_patch_size(batch, patch_size=self.patch_size)
... def ensure_batch_fits_patch_size(batch, patch_size):
"""
Pads the spatial dimensions of the input tensor so that they are at least the size of the patch dimensions.
If all spatial dimensions are already larger than or equal to the patch size, the input tensor is returned unchanged.
Parameters:
- batch: dict
a dict with keys {"data": data, "data_properties": data_properties, "case_id": case_id},
where data is a Tensor of shape (B, C, *spatial_dims)
- patch_size: tuple of ints
The minimum desired size for each spatial dimension.
Returns:
- padded_input: torch.Tensor
The input tensor padded to the desired spatial dimensions.
"""
image = batch["data"]
image_properties = batch["data_properties"]
spatial_dims = image.dim() - 2 # Subtract batch and channel dimensions
if spatial_dims != len(patch_size):
raise ValueError("Input spatial dimensions and patch size dimensions do not match.")
current_sizes = image.shape[2:] # Spatial dimensions
current_sizes_tensor = torch.tensor(current_sizes)
patch_size_tensor = torch.tensor(patch_size)
if torch.any(current_sizes_tensor < patch_size_tensor).item():
return image, image_properties
pad_sizes = torch.clamp(patch_size_tensor - current_sizes_tensor, min=0)
pad_left = pad_sizes // 2
pad_right = pad_sizes - pad_left
# Construct padding tuple in reverse order for F.pad
padding_reversed = []
for left, right in zip(reversed(pad_left.tolist()), reversed(pad_right.tolist())):
padding_reversed.extend([left.item(), right.item()])
padded_input = F.pad(image, padding_reversed)
image_properties["padded_shape"] = np.array(image.shape)
image_properties["padding"] = list(reversed(padding_reversed))
return padded_input, image_properties |
No description provided.