
Does SlidingWindowInferer support test time augmentation? #8218

Closed

lizhuoq opened this issue Nov 19, 2024 · 1 comment


lizhuoq commented Nov 19, 2024

No description provided.

lizhuoq (Author) commented Nov 21, 2024

The following code solves my problem:

import gc
import os

import torch
from monai.inferers import SlidingWindowInferer
from monai.transforms import Compose, Flip


def predict(data, model, model_name, batch_size=32, n_fold=0, device='cuda', tta=False):
    # Load the fold checkpoint and move the model onto the GPU
    model = model.to(device)
    model.load_state_dict(torch.load(os.path.join(model_name + str(n_fold), 'checkpoint.pth')))

    model.eval()
    with torch.no_grad():
        # Sliding-window inference: patches run on the GPU, the stitched output stays on the CPU
        inferer = SlidingWindowInferer(roi_size=(128, 128), sw_batch_size=batch_size, overlap=0.5,
                                       mode="gaussian", progress=True, sw_device=device,
                                       device=torch.device('cpu'))
        outputs = inferer(data, model)
        outputs = torch.softmax(outputs, dim=1)

    if tta:
        # Test-time augmentation: flip along each spatial axis, and both axes together
        tta_list = [Flip(spatial_axis=0), Flip(spatial_axis=1),
                    Compose([Flip(spatial_axis=0), Flip(spatial_axis=1)])]
        tta_res = [outputs]
        for aug in tta_list:
            with torch.no_grad():
                inferer = SlidingWindowInferer(roi_size=(128, 128), sw_batch_size=batch_size,
                                               overlap=0.5, mode="gaussian", progress=True,
                                               sw_device=device, device=torch.device('cpu'))
                # Flip operates on channel-first tensors, so drop and restore the batch dimension
                transformed_data = aug(data[0]).unsqueeze(0)
                outputs = inferer(transformed_data, model)
                # Invert the flip so every prediction is in the original orientation
                outputs = aug.inverse(outputs[0]).unsqueeze(0)
                outputs = torch.softmax(outputs, dim=1)
                tta_res.append(outputs)
            gc.collect()
        # Average the softmax probabilities over the original and augmented views
        outputs = torch.stack(tta_res, dim=0).mean(dim=0)

    return outputs
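For context, here is a minimal sketch of how this function might be called. The UNet configuration, input shape, and checkpoint path below are illustrative assumptions, not part of the original snippet:

import torch
from monai.networks.nets import UNet

# Hypothetical model and input, purely for illustration
model = UNet(spatial_dims=2, in_channels=1, out_channels=2,
             channels=(16, 32, 64), strides=(2, 2))
data = torch.randn(1, 1, 512, 512)  # [batch, channel, H, W], larger than the 128x128 ROI

# Assumes a trained checkpoint exists at "run0/checkpoint.pth"
# (i.e. model_name + str(n_fold) as used inside predict)
probs = predict(data, model, model_name="run", n_fold=0, device="cuda", tta=True)
print(probs.shape)  # class probabilities, e.g. torch.Size([1, 2, 512, 512])

Averaging the softmax probabilities across the flipped views, rather than the raw logits, keeps each view's contribution on the same scale; this is the usual flip-based TTA ensemble.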

Project-MONAI locked and limited conversation to collaborators Nov 21, 2024
KumoLiu converted this issue into discussion #8227 Nov 21, 2024
