diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java b/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java index e14c567590..f3d1b57829 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AuditMessage.java @@ -11,6 +11,7 @@ package org.opensearch.security.auditlog.impl; +import com.google.common.collect.Sets; import java.io.IOException; import java.nio.file.Files; import java.nio.file.LinkOption; @@ -22,6 +23,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Objects; +import java.util.Set; import java.util.regex.Pattern; import com.google.common.annotations.VisibleForTesting; @@ -62,6 +64,7 @@ public final class AuditMessage { // clustername and cluster uuid private static final WildcardMatcher AUTHORIZATION_HEADER = WildcardMatcher.from("Authorization", false); private static final String SENSITIVE_KEY = "password"; + private static final Set SAFE_HEADERS = Set.of("Accept", "Connection", "User-Agent", "Host", "Content-Type", "Accept-Encoding"); private static final String SENSITIVE_REPLACEMENT_VALUE = "__SENSITIVE__"; private static final Pattern SENSITIVE_PATHS = Pattern.compile( @@ -128,12 +131,10 @@ public final class AuditMessage { private static final DateTimeFormatter DEFAULT_FORMAT = DateTimeFormat.forPattern("yyyy-MM-dd'T'HH:mm:ss.SSSZZ"); private final Map auditInfo = new HashMap(50); private final AuditCategory msgCategory; - private final String customHeader; public AuditMessage(final AuditCategory msgCategory, final ClusterService clusterService, final Origin origin, final Origin layer) { this.msgCategory = Objects.requireNonNull(msgCategory); final String currentTime = currentTime(); - this.customHeader = clusterService.getSettings().get("jwt_header", HttpHeaders.AUTHORIZATION); auditInfo.put(FORMAT_VERSION, 4); auditInfo.put(CATEGORY, Objects.requireNonNull(msgCategory)); auditInfo.put(UTC_TIMESTAMP, currentTime); @@ -363,11 +364,7 @@ public void addRestHeaders(Map> headers, boolean excludeSen if (headers != null && !headers.isEmpty()) { final Map> headersClone = new HashMap<>(headers); if (excludeSensitiveHeaders) { - if (headersClone.containsKey(AUTHORIZATION_HEADER)) { //Look for default "Authorization header - headersClone.keySet().removeIf(AUTHORIZATION_HEADER); - } else { // This means it was replaced by a custom header - headersClone.keySet().remove(this.customHeader); - } + headersClone.keySet().retainAll(SAFE_HEADERS); } auditInfo.put(REST_REQUEST_HEADERS, headersClone); } @@ -424,11 +421,8 @@ public void addTransportHeaders(Map headers, boolean excludeSens if (headers != null && !headers.isEmpty()) { final Map headersClone = new HashMap<>(headers); if (excludeSensitiveHeaders) { - if (headersClone.containsKey(AUTHORIZATION_HEADER)) { //Look for default "Authorization header + if (headersClone.containsKey(AUTHORIZATION_HEADER)) { //JWT will never have transport headers so can just look for default Authorization header headersClone.keySet().removeIf(AUTHORIZATION_HEADER); - } else { // This means it was replaced by a custom header - headersClone.keySet().remove(customHeader); - } } auditInfo.put(TRANSPORT_REQUEST_HEADERS, headersClone); diff --git a/src/test/java/org/opensearch/security/auditlog/impl/AuditMessageTest.java b/src/test/java/org/opensearch/security/auditlog/impl/AuditMessageTest.java index f53872bb3a..9079def100 100644 --- a/src/test/java/org/opensearch/security/auditlog/impl/AuditMessageTest.java +++ b/src/test/java/org/opensearch/security/auditlog/impl/AuditMessageTest.java @@ -45,7 +45,9 @@ public class AuditMessageTest { "AuThOrIzAtIoN", ImmutableList.of("test-3"), "test-header", - ImmutableList.of("test-4") + ImmutableList.of("test-4"), + "Accept", // One of the safe rest headers + ImmutableList.of("test-5") ); private static final Map TEST_TRANSPORT_HEADERS = ImmutableMap.of( @@ -72,7 +74,7 @@ public void setUp() { @Test public void testRestHeadersAreFiltered() { message.addRestHeaders(TEST_REST_HEADERS, true); - assertEquals(message.getAsMap().get(AuditMessage.REST_REQUEST_HEADERS), ImmutableMap.of("test-header", ImmutableList.of("test-4"))); + assertEquals(message.getAsMap().get(AuditMessage.REST_REQUEST_HEADERS), ImmutableMap.of("Accept", ImmutableList.of("test-5"))); } @Test