diff --git a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java index e8136d95ad..e3d9ef57f5 100644 --- a/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/kerberos/HTTPSpnegoAuthenticator.java @@ -48,6 +48,7 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.env.Environment; import org.ietf.jgss.GSSContext; import org.ietf.jgss.GSSCredential; @@ -281,10 +282,12 @@ public GSSCredential run() throws GSSException { public Optional reRequestAuthentication(final SecurityRequest request, AuthCredentials creds) { final Map headers = new HashMap<>(); String responseBody = ""; + String contentType = null; + SecurityResponse response; final String negotiateResponseBody = getNegotiateResponseBody(); if (negotiateResponseBody != null) { responseBody = negotiateResponseBody; - headers.putAll(SecurityResponse.CONTENT_TYPE_APP_JSON); + contentType = XContentType.JSON.mediaType(); } if (creds == null || creds.getNativeCredentials() == null) { @@ -293,7 +296,12 @@ public Optional reRequestAuthentication(final SecurityRequest headers.put("WWW-Authenticate", "Negotiate " + Base64.getEncoder().encodeToString((byte[]) creds.getNativeCredentials())); } - return Optional.of(new SecurityResponse(SC_UNAUTHORIZED, headers, responseBody)); + if (contentType != null) { + response = new SecurityResponse(SC_UNAUTHORIZED, headers, responseBody, contentType); + } else { + response = new SecurityResponse(SC_UNAUTHORIZED, headers, responseBody); + } + return Optional.of(response); } @Override diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java index 6157853324..3fd858c9f0 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/AuthTokenProcessorHandler.java @@ -237,7 +237,7 @@ private Optional handleLowLevel(RestRequest restRequest) throw String responseBodyString = DefaultObjectMapper.objectMapper.writeValueAsString(responseBody); - return Optional.of(new SecurityResponse(HttpStatus.SC_OK, SecurityResponse.CONTENT_TYPE_APP_JSON, responseBodyString)); + return Optional.of(new SecurityResponse(HttpStatus.SC_OK, null, responseBodyString, XContentType.JSON.mediaType())); } catch (JsonProcessingException e) { log.warn("Error while parsing JSON for /_opendistro/_security/api/authtoken", e); return Optional.of(new SecurityResponse(HttpStatus.SC_BAD_REQUEST, null, "JSON could not be parsed")); diff --git a/src/main/java/org/opensearch/security/auth/BackendRegistry.java b/src/main/java/org/opensearch/security/auth/BackendRegistry.java index 1492e533ec..80c47fae2a 100644 --- a/src/main/java/org/opensearch/security/auth/BackendRegistry.java +++ b/src/main/java/org/opensearch/security/auth/BackendRegistry.java @@ -509,11 +509,7 @@ public boolean authenticate(final SecurityRequestChannel request) { log.error("Cannot authenticate rest user because admin user is not permitted to login via HTTP"); auditLog.logFailedLogin(authenticatedUser.getName(), true, null, request); request.queueForSending( - new SecurityResponse( - SC_FORBIDDEN, - null, - "Cannot authenticate user because admin user is not permitted to login via HTTP" - ) + new SecurityResponse(SC_FORBIDDEN, "Cannot authenticate user because admin user is not permitted to login via HTTP") ); return false; } @@ -581,7 +577,7 @@ public boolean authenticate(final SecurityRequestChannel request) { notifyIpAuthFailureListeners(request, authCredenetials); request.queueForSending( - challengeResponse.orElseGet(() -> new SecurityResponse(SC_UNAUTHORIZED, null, "Authentication finally failed")) + challengeResponse.orElseGet(() -> new SecurityResponse(SC_UNAUTHORIZED, "Authentication finally failed")) ); return false; } diff --git a/src/main/java/org/opensearch/security/filter/SecurityResponse.java b/src/main/java/org/opensearch/security/filter/SecurityResponse.java index 009a1c3769..61f1b87c4f 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityResponse.java +++ b/src/main/java/org/opensearch/security/filter/SecurityResponse.java @@ -12,13 +12,17 @@ package org.opensearch.security.filter; import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import org.apache.http.HttpHeaders; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.rest.RestStatus; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestResponse; +import org.opensearch.rest.RestStatus; import com.google.common.collect.ImmutableMap; @@ -27,26 +31,63 @@ public class SecurityResponse { public static final Map CONTENT_TYPE_APP_JSON = ImmutableMap.of(HttpHeaders.CONTENT_TYPE, "application/json"); private final int status; - private final Map headers; + private Map> headers; private final String body; + private final String contentType; public SecurityResponse(final int status, final Exception e) { this.status = status; - this.headers = CONTENT_TYPE_APP_JSON; + populateHeaders(CONTENT_TYPE_APP_JSON); this.body = generateFailureMessage(e); + this.contentType = XContentType.JSON.mediaType(); + } + + public SecurityResponse(final int status, String body) { + this.status = status; + this.body = body; + this.contentType = null; } public SecurityResponse(final int status, final Map headers, final String body) { this.status = status; - this.headers = headers; + populateHeaders(headers); + this.body = body; + this.contentType = null; + } + + public SecurityResponse(final int status, final Map headers, final String body, String contentType) { + this.status = status; this.body = body; + this.contentType = contentType; + populateHeaders(headers); + } + + private void populateHeaders(Map headers) { + if (headers != null) { + headers.entrySet().forEach(entry -> addHeader(entry.getKey(), entry.getValue())); + } + } + + /** + * Add a custom header. + */ + public void addHeader(String name, String value) { + if (headers == null) { + headers = new HashMap<>(2); + } + List header = headers.get(name); + if (header == null) { + header = new ArrayList<>(); + headers.put(name, header); + } + header.add(value); } public int getStatus() { return status; } - public Map getHeaders() { + public Map> getHeaders() { return headers; } @@ -55,9 +96,14 @@ public String getBody() { } public RestResponse asRestResponse() { - final RestResponse restResponse = new BytesRestResponse(RestStatus.fromCode(getStatus()), getBody()); + final RestResponse restResponse; + if (this.contentType != null) { + restResponse = new BytesRestResponse(RestStatus.fromCode(getStatus()), this.contentType, getBody()); + } else { + restResponse = new BytesRestResponse(RestStatus.fromCode(getStatus()), getBody()); + } if (getHeaders() != null) { - getHeaders().forEach(restResponse::addHeader); + getHeaders().entrySet().forEach(entry -> { entry.getValue().forEach(value -> restResponse.addHeader(entry.getKey(), value)); }); } return restResponse; } diff --git a/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java b/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java index 14ae972685..0093ddc686 100644 --- a/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java +++ b/src/main/java/org/opensearch/security/securityconf/impl/WhitelistingSettings.java @@ -23,6 +23,7 @@ import org.apache.http.HttpStatus; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.rest.RestStatus; import org.opensearch.security.filter.SecurityRequest; import org.opensearch.security.filter.SecurityResponse; @@ -117,7 +118,7 @@ public Optional checkRequestIsAllowed(final SecurityRequest re // if whitelisting is enabled but the request is not whitelisted, then return false, otherwise true. if (this.enabled && !requestIsWhitelisted(request)) { return Optional.of( - new SecurityResponse(HttpStatus.SC_FORBIDDEN, SecurityResponse.CONTENT_TYPE_APP_JSON, generateFailureMessage(request)) + new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, generateFailureMessage(request), XContentType.JSON.mediaType()) ); } return Optional.empty(); diff --git a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java index afa8328140..75a303665d 100644 --- a/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java +++ b/src/test/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticatorTest.java @@ -717,7 +717,7 @@ private AuthenticateHeaders getAutenticateHeaders(HTTPSamlAuthenticator samlAuth RestRequest restRequest = new FakeRestRequest(ImmutableMap.of(), new HashMap()); SecurityResponse response = sendToAuthenticator(samlAuthenticator, restRequest).orElseThrow(NoSuchElementException::new); - String wwwAuthenticateHeader = response.getHeaders().get("WWW-Authenticate"); + String wwwAuthenticateHeader = response.getHeaders().get("WWW-Authenticate").get(0); Assert.assertNotNull(wwwAuthenticateHeader); diff --git a/src/test/java/org/opensearch/security/filter/SecurityResponseTests.java b/src/test/java/org/opensearch/security/filter/SecurityResponseTests.java new file mode 100644 index 0000000000..483a4f77c5 --- /dev/null +++ b/src/test/java/org/opensearch/security/filter/SecurityResponseTests.java @@ -0,0 +1,155 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.security.filter; + +import java.util.List; +import java.util.Map; + +import org.apache.http.HttpHeaders; +import org.apache.http.HttpStatus; +import org.junit.Test; + +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestResponse; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +public class SecurityResponseTests { + + /** + * This test should check whether a basic constructor with the JSON content type is successfully converted to RestResponse + */ + @Test + public void testSecurityResponseHasSingleContentType() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, null, "foo bar", XContentType.JSON.mediaType()); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + assertThat(restResponse.contentType(), equalTo(XContentType.JSON.mediaType())); + } + + /** + * This test should check whether adding a new HTTP Header for the content type takes the argument or the added header (should take arg.) + */ + @Test + public void testSecurityResponseMultipleContentTypesUsesPassed() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, null, "foo bar", XContentType.JSON.mediaType()); + response.addHeader(HttpHeaders.CONTENT_TYPE, BytesRestResponse.TEXT_CONTENT_TYPE); + assertThat(response.getHeaders().get("Content-Type"), equalTo(List.of(BytesRestResponse.TEXT_CONTENT_TYPE))); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(XContentType.JSON.mediaType())); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test should check whether specifying no content type correctly uses plain text + */ + @Test + public void testSecurityResponseDefaultContentTypeIsText() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, null, "foo bar"); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(BytesRestResponse.TEXT_CONTENT_TYPE)); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test checks whether adding a new ContentType header actually changes the converted content type header (it should not) + */ + @Test + public void testSecurityResponseSetHeaderContentTypeDoesNothing() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, null, "foo bar"); + response.addHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(BytesRestResponse.TEXT_CONTENT_TYPE)); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test should check whether adding a multiple new HTTP Headers for the content type takes the argument or the added header (should take arg.) + */ + @Test + public void testSecurityResponseAddMultipleContentTypeHeaders() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, null, "foo bar", XContentType.JSON.mediaType()); + response.addHeader(HttpHeaders.CONTENT_TYPE, BytesRestResponse.TEXT_CONTENT_TYPE); + assertThat(response.getHeaders().get("Content-Type"), equalTo(List.of(BytesRestResponse.TEXT_CONTENT_TYPE))); + response.addHeader(HttpHeaders.CONTENT_TYPE, "newContentType"); + assertThat(response.getHeaders().get("Content-Type"), equalTo(List.of(BytesRestResponse.TEXT_CONTENT_TYPE, "newContentType"))); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test confirms that fake content types work for conversion + */ + @Test + public void testSecurityResponseFakeContentTypeArgumentPasses() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, null, "foo bar", "testType"); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo("testType")); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test checks that types passed as part of the Headers parameter in the argument do not overwrite actual Content Type + */ + @Test + public void testSecurityResponseContentTypeInConstructorHeader() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_OK, Map.of("Content-Type", "testType"), "foo bar"); + assertThat(response.getHeaders().get("Content-Type"), equalTo(List.of("testType"))); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(BytesRestResponse.TEXT_CONTENT_TYPE)); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test confirms the same as above but with a conflicting content type arg + */ + @Test + public void testSecurityResponseContentTypeInConstructorHeaderConflicts() { + final SecurityResponse response = new SecurityResponse( + HttpStatus.SC_OK, + Map.of("Content-Type", "testType"), + "foo bar", + XContentType.JSON.mediaType() + ); + assertThat(response.getHeaders().get("Content-Type"), equalTo(List.of("testType"))); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(XContentType.JSON.mediaType())); + assertThat(restResponse.status(), equalTo(RestStatus.OK)); + } + + /** + * This test should check whether unauthorized requests are converted properly + */ + @Test + public void testSecurityResponseUnauthorizedRequestWithPlainTextContentType() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, null, "foo bar"); + response.addHeader(HttpHeaders.CONTENT_TYPE, "application/json"); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(BytesRestResponse.TEXT_CONTENT_TYPE)); + assertThat(restResponse.status(), equalTo(RestStatus.UNAUTHORIZED)); + } + + /** + * This test should check whether forbidden requests are converted properly + */ + @Test + public void testSecurityResponseForbiddenRequestWithPlainTextContentType() { + final SecurityResponse response = new SecurityResponse(HttpStatus.SC_FORBIDDEN, null, "foo bar"); + response.addHeader(HttpHeaders.CONTENT_TYPE, "application/json"); + final RestResponse restResponse = response.asRestResponse(); + assertThat(restResponse.contentType(), equalTo(BytesRestResponse.TEXT_CONTENT_TYPE)); + assertThat(restResponse.status(), equalTo(RestStatus.FORBIDDEN)); + } +}