Skip to content

Commit

Permalink
fix: request body can only be read once
Browse files Browse the repository at this point in the history
  • Loading branch information
pboos committed Nov 22, 2023
1 parent 1887e83 commit 5e77bdc
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
package com.getyourguide.openapi.validation.factory;

import com.getyourguide.openapi.validation.filter.MultiReadHttpServletRequestWrapper;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import javax.annotation.Nullable;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper;
import org.springframework.web.util.WebUtils;

public class ContentCachingWrapperFactory {
public ContentCachingRequestWrapper buildContentCachingRequestWrapper(HttpServletRequest request) {
if (request instanceof ContentCachingRequestWrapper) {
return (ContentCachingRequestWrapper) request;
public MultiReadHttpServletRequestWrapper buildContentCachingRequestWrapper(HttpServletRequest request) {
if (request instanceof MultiReadHttpServletRequestWrapper) {
return (MultiReadHttpServletRequestWrapper) request;
}

return new ContentCachingRequestWrapper(request);
return new MultiReadHttpServletRequestWrapper(request);
}

public ContentCachingResponseWrapper buildContentCachingResponseWrapper(HttpServletResponse response) {
Expand All @@ -26,12 +26,12 @@ public ContentCachingResponseWrapper buildContentCachingResponseWrapper(HttpServ
}

@Nullable
public ContentCachingResponseWrapper getCachingResponse(final HttpServletResponse response) {
return WebUtils.getNativeResponse(response, ContentCachingResponseWrapper.class);
public MultiReadHttpServletRequestWrapper getCachingRequest(HttpServletRequest request) {
return request instanceof MultiReadHttpServletRequestWrapper ? (MultiReadHttpServletRequestWrapper) request : null;
}

@Nullable
public ContentCachingRequestWrapper getCachingRequest(HttpServletRequest request) {
return request instanceof ContentCachingRequestWrapper ? (ContentCachingRequestWrapper) request : null;
public ContentCachingResponseWrapper getCachingResponse(final HttpServletResponse response) {
return WebUtils.getNativeResponse(response, ContentCachingResponseWrapper.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
import com.getyourguide.openapi.validation.factory.ServletMetaDataFactory;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import javax.annotation.Nullable;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatusCode;
import org.springframework.util.StreamUtils;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.servlet.AsyncHandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper;

@Slf4j
Expand Down Expand Up @@ -126,7 +127,7 @@ private static RequestMetaData getRequestMetaData(HttpServletRequest request) {
}

private ValidationResult validateRequest(
ContentCachingRequestWrapper request,
MultiReadHttpServletRequestWrapper request,
RequestMetaData requestMetaData,
@Nullable ResponseMetaData responseMetaData,
RunType runType
Expand All @@ -137,9 +138,7 @@ private ValidationResult validateRequest(
return ValidationResult.NOT_APPLICABLE;
}

var requestBody = request.getContentType() != null
? new String(request.getContentAsByteArray(), StandardCharsets.UTF_8)
: null;
var requestBody = request.getContentType() != null ? readBodyCatchingException(request) : null;

if (runType == RunType.ASYNC) {
validator.validateRequestObjectAsync(requestMetaData, responseMetaData, requestBody);
Expand All @@ -149,6 +148,14 @@ private ValidationResult validateRequest(
}
}

private static String readBodyCatchingException(MultiReadHttpServletRequestWrapper request) {
try {
return StreamUtils.copyToString(request.getInputStream(), StandardCharsets.UTF_8);
} catch (IOException e) {
return null;
}
}

private ValidationResult validateResponse(
HttpServletRequest request,
ContentCachingResponseWrapper response,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import lombok.Builder;
import org.mockito.Mockito;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.mock.web.DelegatingServletInputStream;
import org.springframework.web.util.ContentCachingResponseWrapper;

public class BaseFilterTest {
Expand Down Expand Up @@ -48,15 +50,12 @@ private static void mockRequestAttributes(ServletRequest request, HashMap<String
}

protected MockSetupData mockSetup(MockConfiguration configuration) {
var request = mock(ContentCachingRequestWrapper.class);
var request = mock(MultiReadHttpServletRequestWrapper.class);
var response = mock(ContentCachingResponseWrapper.class);
var cachingRequest = mockContentCachingRequest(request, configuration);
var cachingResponse = mockContentCachingResponse(response, configuration);
mockRequestAttributes(request, cachingRequest);

when(request.getContentType()).thenReturn("application/json");
when(request.getContentAsByteArray()).thenReturn(configuration.requestBody.getBytes(StandardCharsets.UTF_8));

when(response.getContentType()).thenReturn("application/json");
when(response.getContentAsByteArray()).thenReturn(configuration.responseBody.getBytes(StandardCharsets.UTF_8));

Expand Down Expand Up @@ -102,16 +101,24 @@ private ContentCachingResponseWrapper mockContentCachingResponse(
return cachingResponse;
}

private ContentCachingRequestWrapper mockContentCachingRequest(
private MultiReadHttpServletRequestWrapper mockContentCachingRequest(
HttpServletRequest request,
MockConfiguration configuration
) {
var cachingRequest = mock(ContentCachingRequestWrapper.class);
var cachingRequest = mock(MultiReadHttpServletRequestWrapper.class);
when(contentCachingWrapperFactory.buildContentCachingRequestWrapper(request)).thenReturn(cachingRequest);
if (configuration.responseBody != null) {
when(cachingRequest.getContentType()).thenReturn("application/json");
when(cachingRequest.getContentAsByteArray())
.thenReturn(configuration.requestBody.getBytes(StandardCharsets.UTF_8));
if (configuration.requestBody != null) {
try {
var sourceStream = new ByteArrayInputStream(configuration.requestBody.getBytes(StandardCharsets.UTF_8));
when(request.getContentType()).thenReturn("application/json");
when(request.getInputStream()).thenReturn(new DelegatingServletInputStream(sourceStream));

sourceStream = new ByteArrayInputStream(configuration.requestBody.getBytes(StandardCharsets.UTF_8));
when(cachingRequest.getContentType()).thenReturn("application/json");
when(cachingRequest.getInputStream()).thenReturn(new DelegatingServletInputStream(sourceStream));
} catch (IOException e) {
throw new IllegalStateException(e);
}
}
return cachingRequest;
}
Expand Down

0 comments on commit 5e77bdc

Please sign in to comment.