
implementation of test preprocessing into modules and pipeline #197

Merged
merged 8 commits into main from preproc_test
Oct 22, 2024

Conversation

Sllambias (Owner)

No description provided.

@asbjrnmunk (Collaborator) left a comment


Thanks! My only comment is that I think it would be cleaner to handle volumes smaller than the patch size in the predict step or dataloader (where we have a patch size available), rather than during preprocessing.

@asbjrnmunk (Collaborator)

asbjrnmunk commented Oct 21, 2024

I propose adding the following (or similar) code to the on_before_batch_transfer hook to get rid of the patch size argument:

# Imports needed by this snippet:
import numpy as np
import torch
import torch.nn.functional as F

def on_before_batch_transfer(self, batch, dataloader_idx):
    if self.trainer.predicting:
        if not self.disable_inference_preprocessing:
            batch["data"], batch["data_properties"] = self.preprocessor.preprocess_case_for_inference(
                images=batch["data_paths"],
                patch_size=self.patch_size,
                ext=batch["extension"],
                sliding_window_prediction=self.sliding_window_prediction,
            )
        else:
            batch["data"], batch["data_properties"] = ensure_batch_fits_patch_size(batch, patch_size=self.patch_size)
...
def ensure_batch_fits_patch_size(batch, patch_size):
    """
    Pads the spatial dimensions of the input tensor so that they are at least the size of the patch dimensions.
    If all spatial dimensions are already larger than or equal to the patch size, the input tensor is returned unchanged.

    Parameters:
    - batch: dict
        a dict with keys {"data": data, "data_properties": data_properties, "case_id": case_id},
        where data is a Tensor of shape (B, C, *spatial_dims)

    - patch_size: tuple of ints
        The minimum desired size for each spatial dimension.

    Returns:
    - (padded_input, image_properties): tuple of (torch.Tensor, dict)
        The input tensor padded to at least the desired spatial dimensions, and the
        properties dict updated with "padded_shape" and "padding" entries.
    """
    image = batch["data"]
    image_properties = batch["data_properties"]

    spatial_dims = image.dim() - 2  # Subtract batch and channel dimensions

    if spatial_dims != len(patch_size):
        raise ValueError("Input spatial dimensions and patch size dimensions do not match.")

    current_sizes = image.shape[2:]  # Spatial dimensions

    current_sizes_tensor = torch.tensor(current_sizes)
    patch_size_tensor = torch.tensor(patch_size)

    # Nothing to do if every spatial dimension already meets the patch size.
    if torch.all(current_sizes_tensor >= patch_size_tensor):
        return image, image_properties

    pad_sizes = torch.clamp(patch_size_tensor - current_sizes_tensor, min=0)
    pad_left = pad_sizes // 2
    pad_right = pad_sizes - pad_left

    # Construct the padding tuple in reverse dimension order (last spatial
    # dimension first), as expected by F.pad. Note that .tolist() already
    # yields plain ints, so no further conversion is needed.
    padding_reversed = []
    for left, right in zip(reversed(pad_left.tolist()), reversed(pad_right.tolist())):
        padding_reversed.extend([left, right])

    padded_input = F.pad(image, padding_reversed)

    image_properties["padded_shape"] = np.array(image.shape)
    image_properties["padding"] = list(reversed(padding_reversed))

    return padded_input, image_properties
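
For illustration, a minimal sketch of how the helper would behave on a volume smaller than the patch size (the shapes and the case_id are hypothetical, and this assumes the code as written above):

import torch

# Hypothetical batch: one single-channel 3D volume of shape (1, 1, 16, 64, 64).
batch = {
    "data": torch.zeros(1, 1, 16, 64, 64),
    "data_properties": {},
    "case_id": "case_000",  # hypothetical id
}

# With patch size (32, 64, 64), only the first spatial dimension is short,
# so it is padded symmetrically by 8 voxels on each side.
padded, props = ensure_batch_fits_patch_size(batch, patch_size=(32, 64, 64))

print(padded.shape)      # torch.Size([1, 1, 32, 64, 64])
print(props["padding"])  # [8, 8, 0, 0, 0, 0]

A volume that already meets the patch size in every spatial dimension would be returned unchanged by the early return in the function.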

@Sllambias Sllambias merged commit 799b272 into main Oct 22, 2024
3 checks passed
@Sllambias Sllambias deleted the preproc_test branch October 22, 2024 09:24