diff --git a/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java b/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java index 936663095492..13f96a6b2179 100644 --- a/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java +++ b/core/src/main/java/org/apache/iceberg/rest/HTTPClient.java @@ -26,10 +26,12 @@ import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.stream.Collectors; import org.apache.hc.client5.http.classic.methods.HttpUriRequest; import org.apache.hc.client5.http.classic.methods.HttpUriRequestBase; +import org.apache.hc.client5.http.config.ConnectionConfig; import org.apache.hc.client5.http.impl.classic.CloseableHttpClient; import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; import org.apache.hc.client5.http.impl.classic.HttpClientBuilder; @@ -79,6 +81,11 @@ public class HTTPClient implements RESTClient { private static final String REST_MAX_CONNECTIONS_PER_ROUTE = "rest.client.connections-per-route"; private static final int REST_MAX_CONNECTIONS_PER_ROUTE_DEFAULT = 100; + @VisibleForTesting + static final String REST_CONNECTION_TIMEOUT_MS = "rest.client.connection-timeout-ms"; + + @VisibleForTesting static final String REST_SOCKET_TIMEOUT_MS = "rest.client.socket-timeout-ms"; + private final String uri; private final CloseableHttpClient httpClient; private final ObjectMapper mapper; @@ -88,22 +95,13 @@ private HTTPClient( Map baseHeaders, ObjectMapper objectMapper, HttpRequestInterceptor requestInterceptor, - Map properties) { + Map properties, + HttpClientConnectionManager connectionManager) { this.uri = uri; this.mapper = objectMapper; HttpClientBuilder clientBuilder = HttpClients.custom(); - HttpClientConnectionManager connectionManager = - PoolingHttpClientConnectionManagerBuilder.create() - .useSystemProperties() - .setMaxConnTotal(Integer.getInteger(REST_MAX_CONNECTIONS, REST_MAX_CONNECTIONS_DEFAULT)) - .setMaxConnPerRoute( - PropertyUtil.propertyAsInt( - properties, - REST_MAX_CONNECTIONS_PER_ROUTE, - REST_MAX_CONNECTIONS_PER_ROUTE_DEFAULT)) - .build(); clientBuilder.setConnectionManager(connectionManager); if (baseHeaders != null) { @@ -448,6 +446,47 @@ static HttpRequestInterceptor loadInterceptorDynamically( return instance; } + static HttpClientConnectionManager configureConnectionManager(Map properties) { + PoolingHttpClientConnectionManagerBuilder connectionManagerBuilder = + PoolingHttpClientConnectionManagerBuilder.create(); + ConnectionConfig connectionConfig = configureConnectionConfig(properties); + if (connectionConfig != null) { + connectionManagerBuilder.setDefaultConnectionConfig(connectionConfig); + } + + return connectionManagerBuilder + .useSystemProperties() + .setMaxConnTotal(Integer.getInteger(REST_MAX_CONNECTIONS, REST_MAX_CONNECTIONS_DEFAULT)) + .setMaxConnPerRoute( + PropertyUtil.propertyAsInt( + properties, REST_MAX_CONNECTIONS_PER_ROUTE, REST_MAX_CONNECTIONS_PER_ROUTE_DEFAULT)) + .build(); + } + + @VisibleForTesting + static ConnectionConfig configureConnectionConfig(Map properties) { + Long connectionTimeoutMillis = + PropertyUtil.propertyAsNullableLong(properties, REST_CONNECTION_TIMEOUT_MS); + Integer socketTimeoutMillis = + PropertyUtil.propertyAsNullableInt(properties, REST_SOCKET_TIMEOUT_MS); + + if (connectionTimeoutMillis == null && socketTimeoutMillis == null) { + return null; + } + + ConnectionConfig.Builder connConfigBuilder = ConnectionConfig.custom(); + + if (connectionTimeoutMillis != null) { + connConfigBuilder.setConnectTimeout(connectionTimeoutMillis, TimeUnit.MILLISECONDS); + } + + if (socketTimeoutMillis != null) { + connConfigBuilder.setSocketTimeout(socketTimeoutMillis, TimeUnit.MILLISECONDS); + } + + return connConfigBuilder.build(); + } + public static Builder builder(Map properties) { return new Builder(properties); } @@ -493,7 +532,13 @@ public HTTPClient build() { interceptor = loadInterceptorDynamically(SIGV4_REQUEST_INTERCEPTOR_IMPL, properties); } - return new HTTPClient(uri, baseHeaders, mapper, interceptor, properties); + return new HTTPClient( + uri, + baseHeaders, + mapper, + interceptor, + properties, + configureConnectionManager(properties)); } } diff --git a/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java b/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java index e596df43e6f5..93585cdbb52e 100644 --- a/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java +++ b/core/src/test/java/org/apache/iceberg/rest/TestHTTPClient.java @@ -31,10 +31,13 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; +import java.net.SocketTimeoutException; import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.concurrent.TimeUnit; import java.util.function.Consumer; +import org.apache.hc.client5.http.config.ConnectionConfig; import org.apache.hc.core5.http.EntityDetails; import org.apache.hc.core5.http.HttpException; import org.apache.hc.core5.http.HttpRequestInterceptor; @@ -47,6 +50,8 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockserver.integration.ClientAndServer; import org.mockserver.model.HttpRequest; import org.mockserver.model.HttpResponse; @@ -133,6 +138,71 @@ public void testDynamicHttpRequestInterceptorLoading() { assertThat(((TestHttpRequestInterceptor) interceptor).properties).isEqualTo(properties); } + @Test + public void testSocketAndConnectionTimeoutSet() { + long connectionTimeoutMs = 10L; + int socketTimeoutMs = 10; + Map properties = + ImmutableMap.of( + HTTPClient.REST_CONNECTION_TIMEOUT_MS, String.valueOf(connectionTimeoutMs), + HTTPClient.REST_SOCKET_TIMEOUT_MS, String.valueOf(socketTimeoutMs)); + + ConnectionConfig connectionConfig = HTTPClient.configureConnectionConfig(properties); + assertThat(connectionConfig).isNotNull(); + assertThat(connectionConfig.getConnectTimeout().getDuration()).isEqualTo(connectionTimeoutMs); + assertThat(connectionConfig.getSocketTimeout().getDuration()).isEqualTo(socketTimeoutMs); + } + + @Test + public void testSocketTimeout() throws IOException { + long socketTimeoutMs = 2000L; + Map properties = + ImmutableMap.of(HTTPClient.REST_SOCKET_TIMEOUT_MS, String.valueOf(socketTimeoutMs)); + String path = "socket/timeout/path"; + + try (HTTPClient client = HTTPClient.builder(properties).uri(URI).build()) { + HttpRequest mockRequest = + request() + .withPath("/" + path) + .withMethod(HttpMethod.HEAD.name().toUpperCase(Locale.ROOT)); + // Setting a response delay of 5 seconds to simulate hitting the configured socket timeout of + // 2 seconds + HttpResponse mockResponse = + response() + .withStatusCode(200) + .withBody("Delayed response") + .withDelay(TimeUnit.MILLISECONDS, 5000); + mockServer.when(mockRequest).respond(mockResponse); + + Assertions.assertThatThrownBy(() -> client.head(path, ImmutableMap.of(), (unused) -> {})) + .cause() + .isInstanceOf(SocketTimeoutException.class) + .hasMessage("Read timed out"); + } + } + + @ParameterizedTest + @ValueSource(strings = {HTTPClient.REST_CONNECTION_TIMEOUT_MS, HTTPClient.REST_SOCKET_TIMEOUT_MS}) + public void testInvalidTimeout(String timeoutMsType) { + String invalidTimeoutMs = "invalidMs"; + Assertions.assertThatThrownBy( + () -> + HTTPClient.builder(ImmutableMap.of(timeoutMsType, invalidTimeoutMs)) + .uri(URI) + .build()) + .isInstanceOf(NumberFormatException.class) + .hasMessage(String.format("For input string: \"%s\"", invalidTimeoutMs)); + + String invalidNegativeTimeoutMs = "-1"; + Assertions.assertThatThrownBy( + () -> + HTTPClient.builder(ImmutableMap.of(timeoutMsType, invalidNegativeTimeoutMs)) + .uri(URI) + .build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(String.format("duration must not be negative: %s", invalidNegativeTimeoutMs)); + } + public static void testHttpMethodOnSuccess(HttpMethod method) throws JsonProcessingException { Item body = new Item(0L, "hank"); int statusCode = 200;