Skip to content

Commit

Permalink
bugfix: do not report violation on blocked request (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
pboos authored Nov 24, 2023
1 parent 0055dd4 commit 70f07bd
Show file tree
Hide file tree
Showing 14 changed files with 364 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
public class OpenApiRequestValidationConfiguration {
private double sampleRate;
private int validationReportThrottleWaitSeconds;
private boolean shouldFailOnRequestViolation;
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public class OpenApiRequestValidator {
private final ThreadPoolExecutor threadPoolExecutor;
private final OpenApiInteractionValidatorWrapper validator;
private final ValidationReportHandler validationReportHandler;
private final OpenApiRequestValidationConfiguration configuration;

public OpenApiRequestValidator(
ThreadPoolExecutor threadPoolExecutor,
Expand All @@ -34,6 +35,7 @@ public OpenApiRequestValidator(
this.threadPoolExecutor = threadPoolExecutor;
this.validator = validator;
this.validationReportHandler = validationReportHandler;
this.configuration = configuration;

metricsReporter.reportStartup(
validator != null,
Expand Down Expand Up @@ -74,7 +76,12 @@ public ValidationResult validateRequestObject(
try {
var simpleRequest = buildSimpleRequest(request, requestBody);
var result = validator.validateRequest(simpleRequest);
validationReportHandler.handleValidationReport(request, response, Direction.REQUEST, requestBody, result);
// TODO this should not be done here, but currently the only way to do it -> Refactor this so that logging
// is actually done in the interceptor/filter where logging can easily be skipped then.
if (!configuration.isShouldFailOnRequestViolation()) {
validationReportHandler
.handleValidationReport(request, response, Direction.REQUEST, requestBody, result);
}
return buildValidationResult(result);
} catch (Exception e) {
log.error("Could not validate request", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ public OpenApiRequestValidationConfiguration toOpenApiRequestValidationConfigura
return OpenApiRequestValidationConfiguration.builder()
.sampleRate(getSampleRate())
.validationReportThrottleWaitSeconds(getValidationReportThrottleWaitSeconds())
.shouldFailOnRequestViolation(getShouldFailOnRequestViolation() != null && getShouldFailOnRequestViolation())
.build();
}
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
package com.getyourguide.openapi.validation.factory;

import com.getyourguide.openapi.validation.filter.MultiReadContentCachingRequestWrapper;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import javax.annotation.Nullable;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper;
import org.springframework.web.util.WebUtils;

public class ContentCachingWrapperFactory {
public ContentCachingRequestWrapper buildContentCachingRequestWrapper(HttpServletRequest request) {
if (request instanceof ContentCachingRequestWrapper) {
return (ContentCachingRequestWrapper) request;
public MultiReadContentCachingRequestWrapper buildContentCachingRequestWrapper(HttpServletRequest request) {
if (request instanceof MultiReadContentCachingRequestWrapper) {
return (MultiReadContentCachingRequestWrapper) request;
}

return new ContentCachingRequestWrapper(request);
return new MultiReadContentCachingRequestWrapper(request);
}

public ContentCachingResponseWrapper buildContentCachingResponseWrapper(HttpServletResponse response) {
Expand All @@ -26,12 +26,12 @@ public ContentCachingResponseWrapper buildContentCachingResponseWrapper(HttpServ
}

@Nullable
public ContentCachingResponseWrapper getCachingResponse(final HttpServletResponse response) {
return WebUtils.getNativeResponse(response, ContentCachingResponseWrapper.class);
public MultiReadContentCachingRequestWrapper getCachingRequest(HttpServletRequest request) {
return request instanceof MultiReadContentCachingRequestWrapper ? (MultiReadContentCachingRequestWrapper) request : null;
}

@Nullable
public ContentCachingRequestWrapper getCachingRequest(HttpServletRequest request) {
return request instanceof ContentCachingRequestWrapper ? (ContentCachingRequestWrapper) request : null;
public ContentCachingResponseWrapper getCachingResponse(final HttpServletResponse response) {
return WebUtils.getNativeResponse(response, ContentCachingResponseWrapper.class);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package com.getyourguide.openapi.validation.filter;

import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.http.HttpServletRequest;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import org.springframework.web.util.ContentCachingRequestWrapper;

public class MultiReadContentCachingRequestWrapper extends ContentCachingRequestWrapper {

public MultiReadContentCachingRequestWrapper(HttpServletRequest request) {
super(request);
}

public MultiReadContentCachingRequestWrapper(HttpServletRequest request, int contentCacheLimit) {
super(request, contentCacheLimit);
}

@Override
public ServletInputStream getInputStream() throws IOException {
var inputStream = super.getInputStream();
if (inputStream.isFinished()) {
return new CachedServletInputStream(getContentAsByteArray());
}

return inputStream;
}

@Override
public BufferedReader getReader() throws IOException {
return new BufferedReader(new InputStreamReader(getInputStream()));
}

private static class CachedServletInputStream extends ServletInputStream {
private final ByteArrayInputStream buffer;

public CachedServletInputStream(byte[] contents) {
this.buffer = new ByteArrayInputStream(contents);
}

@Override
public int read() throws IOException {
return buffer.read();
}

@Override
public boolean isFinished() {
return buffer.available() == 0;
}

@Override
public boolean isReady() {
return true;
}

@Override
public void setReadListener(ReadListener listener) {
throw new UnsupportedOperationException("Not implemented");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
import com.getyourguide.openapi.validation.factory.ServletMetaDataFactory;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import javax.annotation.Nullable;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatusCode;
import org.springframework.util.StreamUtils;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.servlet.AsyncHandlerInterceptor;
import org.springframework.web.servlet.ModelAndView;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.web.util.ContentCachingResponseWrapper;

@Slf4j
Expand Down Expand Up @@ -114,6 +115,8 @@ private void validateResponse(
);
// Note: validateResponseResult will always be null on ASYNC
if (validateResponseResult == ValidationResult.INVALID) {
response.reset();
response.setStatus(500);
throw new ResponseStatusException(HttpStatusCode.valueOf(500), "Response validation failed");
}
}
Expand All @@ -126,7 +129,7 @@ private static RequestMetaData getRequestMetaData(HttpServletRequest request) {
}

private ValidationResult validateRequest(
ContentCachingRequestWrapper request,
MultiReadContentCachingRequestWrapper request,
RequestMetaData requestMetaData,
@Nullable ResponseMetaData responseMetaData,
RunType runType
Expand All @@ -137,9 +140,7 @@ private ValidationResult validateRequest(
return ValidationResult.NOT_APPLICABLE;
}

var requestBody = request.getContentType() != null
? new String(request.getContentAsByteArray(), StandardCharsets.UTF_8)
: null;
var requestBody = request.getContentType() != null ? readBodyCatchingException(request) : null;

if (runType == RunType.ASYNC) {
validator.validateRequestObjectAsync(requestMetaData, responseMetaData, requestBody);
Expand All @@ -149,6 +150,14 @@ private ValidationResult validateRequest(
}
}

private static String readBodyCatchingException(MultiReadContentCachingRequestWrapper request) {
try {
return StreamUtils.copyToString(request.getInputStream(), StandardCharsets.UTF_8);
} catch (IOException e) {
return null;
}
}

private ValidationResult validateResponse(
HttpServletRequest request,
ContentCachingResponseWrapper response,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import lombok.Builder;
import org.mockito.Mockito;
import org.springframework.web.util.ContentCachingRequestWrapper;
import org.springframework.mock.web.DelegatingServletInputStream;
import org.springframework.web.util.ContentCachingResponseWrapper;

public class BaseFilterTest {
Expand Down Expand Up @@ -48,15 +50,12 @@ private static void mockRequestAttributes(ServletRequest request, HashMap<String
}

protected MockSetupData mockSetup(MockConfiguration configuration) {
var request = mock(ContentCachingRequestWrapper.class);
var request = mock(MultiReadContentCachingRequestWrapper.class);
var response = mock(ContentCachingResponseWrapper.class);
var cachingRequest = mockContentCachingRequest(request, configuration);
var cachingResponse = mockContentCachingResponse(response, configuration);
mockRequestAttributes(request, cachingRequest);

when(request.getContentType()).thenReturn("application/json");
when(request.getContentAsByteArray()).thenReturn(configuration.requestBody.getBytes(StandardCharsets.UTF_8));

when(response.getContentType()).thenReturn("application/json");
when(response.getContentAsByteArray()).thenReturn(configuration.responseBody.getBytes(StandardCharsets.UTF_8));

Expand Down Expand Up @@ -102,16 +101,24 @@ private ContentCachingResponseWrapper mockContentCachingResponse(
return cachingResponse;
}

private ContentCachingRequestWrapper mockContentCachingRequest(
private MultiReadContentCachingRequestWrapper mockContentCachingRequest(
HttpServletRequest request,
MockConfiguration configuration
) {
var cachingRequest = mock(ContentCachingRequestWrapper.class);
var cachingRequest = mock(MultiReadContentCachingRequestWrapper.class);
when(contentCachingWrapperFactory.buildContentCachingRequestWrapper(request)).thenReturn(cachingRequest);
if (configuration.responseBody != null) {
when(cachingRequest.getContentType()).thenReturn("application/json");
when(cachingRequest.getContentAsByteArray())
.thenReturn(configuration.requestBody.getBytes(StandardCharsets.UTF_8));
if (configuration.requestBody != null) {
try {
var sourceStream = new ByteArrayInputStream(configuration.requestBody.getBytes(StandardCharsets.UTF_8));
when(request.getContentType()).thenReturn("application/json");
when(request.getInputStream()).thenReturn(new DelegatingServletInputStream(sourceStream));

sourceStream = new ByteArrayInputStream(configuration.requestBody.getBytes(StandardCharsets.UTF_8));
when(cachingRequest.getContentType()).thenReturn("application/json");
when(cachingRequest.getInputStream()).thenReturn(new DelegatingServletInputStream(sourceStream));
} catch (IOException e) {
throw new IllegalStateException(e);
}
}
return cachingRequest;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package com.getyourguide.openapi.validation.integration;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

import com.getyourguide.openapi.validation.integration.controller.DefaultRestController;
import com.getyourguide.openapi.validation.test.TestViolationLogger;
import java.util.Optional;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.SpyBean;
import org.springframework.http.MediaType;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.springframework.test.web.servlet.MockMvc;

@SpringBootTest(properties = {
"openapi.validation.should-fail-on-request-violation=true",
"openapi.validation.should-fail-on-response-violation=true",
})
@AutoConfigureMockMvc
@ExtendWith(SpringExtension.class)
public class FailOnViolationIntegrationTest {

@Autowired
private MockMvc mockMvc;

@Autowired
private TestViolationLogger openApiViolationLogger;

@SpyBean
private DefaultRestController defaultRestController;

@BeforeEach
public void setup() {
openApiViolationLogger.clearViolations();
}

@Test
public void whenValidRequestThenReturnSuccessfully() throws Exception {
mockMvc.perform(post("/test")
.content("{ \"value\": \"testing\", \"responseStatusCode\": 200 }").contentType(MediaType.APPLICATION_JSON))
.andExpectAll(
status().isOk(),
jsonPath("$.value").value("testing")
);
Thread.sleep(100);

assertEquals(0, openApiViolationLogger.getViolations().size());
verify(defaultRestController).postTest(any());
}

@Test
public void whenInvalidRequestThenReturn400AndNoViolationLogged() throws Exception {
mockMvc.perform(post("/test").content("{ \"value\": 1 }").contentType(MediaType.APPLICATION_JSON))
.andExpectAll(
status().is4xxClientError(),
content().string(Matchers.blankOrNullString())
);
Thread.sleep(100);

assertEquals(0, openApiViolationLogger.getViolations().size());
verify(defaultRestController, never()).postTest(any());
// TODO check that something else gets logged?
}

@Test
public void whenInvalidResponseThenReturn500AndViolationLogged() throws Exception {
mockMvc.perform(get("/test").queryParam("value", "invalid-response-value!"))
.andExpectAll(
status().is5xxServerError(),
content().string(Matchers.blankOrNullString())
);
Thread.sleep(100);

assertEquals(1, openApiViolationLogger.getViolations().size());
var violation = openApiViolationLogger.getViolations().get(0);
assertEquals("validation.response.body.schema.pattern", violation.getRule());
assertEquals(Optional.of(200), violation.getResponseStatus());
assertEquals(Optional.of("/value"), violation.getInstance());
verify(defaultRestController).getTest(any(), any(), any());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,6 @@ public void whenTestOptionsCallThenShouldNotValidate() throws Exception {
assertEquals(0, openApiViolationLogger.getViolations().size());
}

// TODO Add test that fails on request violation immediately (maybe needs separate test class & setup) should not log violation

@Nullable
private OpenApiViolation getViolationByRule(List<OpenApiViolation> violations, String rule) {
return violations.stream()
Expand Down
Loading

0 comments on commit 70f07bd

Please sign in to comment.