diff --git a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java index f91d2226685..3804c10c922 100644 --- a/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java +++ b/benchmarks/jmh/src/jmh/java/com/linecorp/armeria/server/RoutersBenchmark.java @@ -61,7 +61,7 @@ public class RoutersBenchmark { FALLBACK_SERVICE = newServiceConfig(Route.ofCatchAll()); HOST = new VirtualHost( "localhost", "localhost", 0, null, - null, SERVICES, FALLBACK_SERVICE, RejectedRouteHandler.DISABLED, + null, null, SERVICES, FALLBACK_SERVICE, RejectedRouteHandler.DISABLED, unused -> NOPLogger.NOP_LOGGER, FALLBACK_SERVICE.defaultServiceNaming(), FALLBACK_SERVICE.defaultLogName(), 0, 0, false, AccessLogWriter.disabled(), CommonPools.blockingTaskExecutor(), 0, SuccessFunction.ofDefault(), diff --git a/build.gradle b/build.gradle index 2a28be232be..ccb39c888fc 100644 --- a/build.gradle +++ b/build.gradle @@ -117,6 +117,10 @@ allprojects { doFirst { addTestOutputListener({ descriptor, event -> if (event.message.contains('LEAK: ')) { + if (isCi) { + logger.warn("Leak is detected in ${descriptor.className}.${descriptor.displayName}\n" + + "${event.message}") + } hasLeak.set(true) } }) diff --git a/core/src/main/java/com/linecorp/armeria/client/AbstractClientOptionsBuilder.java b/core/src/main/java/com/linecorp/armeria/client/AbstractClientOptionsBuilder.java index ac8d7cd7709..50cd5b966f1 100644 --- a/core/src/main/java/com/linecorp/armeria/client/AbstractClientOptionsBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/AbstractClientOptionsBuilder.java @@ -29,6 +29,8 @@ import java.util.function.Function; import java.util.function.Supplier; +import com.google.common.collect.ImmutableList; + import com.linecorp.armeria.client.endpoint.EndpointGroup; import com.linecorp.armeria.client.redirect.RedirectConfig; import com.linecorp.armeria.common.HttpHeaderNames; @@ -532,20 +534,20 @@ protected final ClientOptions buildOptions() { */ protected final ClientOptions buildOptions(@Nullable ClientOptions baseOptions) { final Collection> optVals = options.values(); - final int numOpts = optVals.size(); - final int extra = contextCustomizer == null ? 3 : 4; - final ClientOptionValue[] optValArray = optVals.toArray(new ClientOptionValue[numOpts + extra]); - optValArray[numOpts] = ClientOptions.DECORATION.newValue(decoration.build()); - optValArray[numOpts + 1] = ClientOptions.HEADERS.newValue(headers.build()); - optValArray[numOpts + 2] = ClientOptions.CONTEXT_HOOK.newValue(contextHook); + final ImmutableList.Builder> additionalValues = + ImmutableList.builder(); + additionalValues.addAll(optVals); + additionalValues.add(ClientOptions.DECORATION.newValue(decoration.build())); + additionalValues.add(ClientOptions.HEADERS.newValue(headers.build())); + additionalValues.add(ClientOptions.CONTEXT_HOOK.newValue(contextHook)); if (contextCustomizer != null) { - optValArray[numOpts + 3] = ClientOptions.CONTEXT_CUSTOMIZER.newValue(contextCustomizer); + additionalValues.add(ClientOptions.CONTEXT_CUSTOMIZER.newValue(contextCustomizer)); } if (baseOptions != null) { - return ClientOptions.of(baseOptions, optValArray); + return ClientOptions.of(baseOptions, additionalValues.build()); } else { - return ClientOptions.of(optValArray); + return ClientOptions.of(additionalValues.build()); } } } diff --git a/core/src/main/java/com/linecorp/armeria/client/Bootstraps.java b/core/src/main/java/com/linecorp/armeria/client/Bootstraps.java index 6849c11f216..ac4616ef6e4 100644 --- a/core/src/main/java/com/linecorp/armeria/client/Bootstraps.java +++ b/core/src/main/java/com/linecorp/armeria/client/Bootstraps.java @@ -26,6 +26,8 @@ import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.internal.common.SslContextFactory; +import com.linecorp.armeria.internal.common.SslContextFactory.SslContextMode; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; @@ -36,56 +38,42 @@ final class Bootstraps { - private final Bootstrap[][] inetBootstraps; - private final Bootstrap @Nullable [][] unixBootstraps; private final EventLoop eventLoop; private final SslContext sslCtxHttp1Only; private final SslContext sslCtxHttp1Or2; + @Nullable + private final SslContextFactory sslContextFactory; + + private final HttpClientFactory clientFactory; + private final Bootstrap inetBaseBootstrap; + @Nullable + private final Bootstrap unixBaseBootstrap; + private final Bootstrap[][] inetBootstraps; + private final Bootstrap @Nullable [][] unixBootstraps; - Bootstraps(HttpClientFactory clientFactory, EventLoop eventLoop, SslContext sslCtxHttp1Or2, - SslContext sslCtxHttp1Only) { + Bootstraps(HttpClientFactory clientFactory, EventLoop eventLoop, + SslContext sslCtxHttp1Or2, SslContext sslCtxHttp1Only, + @Nullable SslContextFactory sslContextFactory) { this.eventLoop = eventLoop; this.sslCtxHttp1Or2 = sslCtxHttp1Or2; this.sslCtxHttp1Only = sslCtxHttp1Only; + this.sslContextFactory = sslContextFactory; + this.clientFactory = clientFactory; + + inetBaseBootstrap = clientFactory.newInetBootstrap(); + inetBaseBootstrap.group(eventLoop); + inetBootstraps = staticBootstrapMap(inetBaseBootstrap); - final Bootstrap inetBaseBootstrap = clientFactory.newInetBootstrap(); - final Bootstrap unixBaseBootstrap = clientFactory.newUnixBootstrap(); - inetBootstraps = newBootstrapMap(inetBaseBootstrap, clientFactory, eventLoop); + unixBaseBootstrap = clientFactory.newUnixBootstrap(); if (unixBaseBootstrap != null) { - unixBootstraps = newBootstrapMap(unixBaseBootstrap, clientFactory, eventLoop); + unixBaseBootstrap.group(eventLoop); + unixBootstraps = staticBootstrapMap(unixBaseBootstrap); } else { unixBootstraps = null; } } - /** - * Returns a {@link Bootstrap} corresponding to the specified {@link SocketAddress} - * {@link SessionProtocol} and {@link SerializationFormat}. - */ - Bootstrap get(SocketAddress remoteAddress, SessionProtocol desiredProtocol, - SerializationFormat serializationFormat) { - if (!httpAndHttpsValues().contains(desiredProtocol)) { - throw new IllegalArgumentException("Unsupported session protocol: " + desiredProtocol); - } - - if (remoteAddress instanceof InetSocketAddress) { - return select(inetBootstraps, desiredProtocol, serializationFormat); - } - - assert remoteAddress instanceof DomainSocketAddress : remoteAddress; - - if (unixBootstraps == null) { - throw new IllegalArgumentException("Domain sockets are not supported by " + - eventLoop.getClass().getName()); - } - - return select(unixBootstraps, desiredProtocol, serializationFormat); - } - - private Bootstrap[][] newBootstrapMap(Bootstrap baseBootstrap, - HttpClientFactory clientFactory, - EventLoop eventLoop) { - baseBootstrap.group(eventLoop); + private Bootstrap[][] staticBootstrapMap(Bootstrap baseBootstrap) { final Set sessionProtocols = httpAndHttpsValues(); final Bootstrap[][] maps = (Bootstrap[][]) Array.newInstance( Bootstrap.class, SessionProtocol.values().length, 2); @@ -93,8 +81,8 @@ private Bootstrap[][] newBootstrapMap(Bootstrap baseBootstrap, // which will help us find a bug. for (SessionProtocol p : sessionProtocols) { final SslContext sslCtx = determineSslContext(p); - setBootstrap(baseBootstrap.clone(), clientFactory, maps, p, sslCtx, true); - setBootstrap(baseBootstrap.clone(), clientFactory, maps, p, sslCtx, false); + createAndSetBootstrap(baseBootstrap, maps, p, sslCtx, true); + createAndSetBootstrap(baseBootstrap, maps, p, sslCtx, false); } return maps; } @@ -106,22 +94,18 @@ SslContext determineSslContext(SessionProtocol desiredProtocol) { return desiredProtocol.isExplicitHttp1() ? sslCtxHttp1Only : sslCtxHttp1Or2; } - private static Bootstrap select(Bootstrap[][] bootstraps, SessionProtocol desiredProtocol, - SerializationFormat serializationFormat) { + private Bootstrap select(boolean isDomainSocket, SessionProtocol desiredProtocol, + SerializationFormat serializationFormat) { + final Bootstrap[][] bootstraps = isDomainSocket ? unixBootstraps : inetBootstraps; + assert bootstraps != null; return bootstraps[desiredProtocol.ordinal()][toIndex(serializationFormat)]; } - private static void setBootstrap(Bootstrap bootstrap, HttpClientFactory clientFactory, Bootstrap[][] maps, - SessionProtocol p, SslContext sslCtx, boolean webSocket) { - bootstrap.handler(new ChannelInitializer() { - @Override - protected void initChannel(Channel ch) throws Exception { - ch.pipeline().addLast(new HttpClientPipelineConfigurator( - clientFactory, webSocket, p, sslCtx)); - } - } - ); - maps[p.ordinal()][toIndex(webSocket)] = bootstrap; + private void createAndSetBootstrap(Bootstrap baseBootstrap, Bootstrap[][] maps, + SessionProtocol desiredProtocol, SslContext sslContext, + boolean webSocket) { + maps[desiredProtocol.ordinal()][toIndex(webSocket)] = newBootstrap(baseBootstrap, desiredProtocol, + sslContext, webSocket, false); } private static int toIndex(boolean webSocket) { @@ -131,4 +115,92 @@ private static int toIndex(boolean webSocket) { private static int toIndex(SerializationFormat serializationFormat) { return toIndex(serializationFormat == SerializationFormat.WS); } + + /** + * Returns a {@link Bootstrap} corresponding to the specified {@link SocketAddress} + * {@link SessionProtocol} and {@link SerializationFormat}. + */ + Bootstrap getOrCreate(SocketAddress remoteAddress, SessionProtocol desiredProtocol, + SerializationFormat serializationFormat) { + if (!httpAndHttpsValues().contains(desiredProtocol)) { + throw new IllegalArgumentException("Unsupported session protocol: " + desiredProtocol); + } + + final boolean isDomainSocket = remoteAddress instanceof DomainSocketAddress; + if (isDomainSocket && unixBaseBootstrap == null) { + throw new IllegalArgumentException("Domain sockets are not supported by " + + eventLoop.getClass().getName()); + } + + if (sslContextFactory == null || !desiredProtocol.isTls()) { + return select(isDomainSocket, desiredProtocol, serializationFormat); + } + + final Bootstrap baseBootstrap = isDomainSocket ? unixBaseBootstrap : inetBaseBootstrap; + assert baseBootstrap != null; + return newBootstrap(baseBootstrap, remoteAddress, desiredProtocol, serializationFormat); + } + + private Bootstrap newBootstrap(Bootstrap baseBootstrap, SocketAddress remoteAddress, + SessionProtocol desiredProtocol, + SerializationFormat serializationFormat) { + final boolean webSocket = serializationFormat == SerializationFormat.WS; + final SslContext sslContext = newSslContext(remoteAddress, desiredProtocol); + return newBootstrap(baseBootstrap, desiredProtocol, sslContext, webSocket, true); + } + + private Bootstrap newBootstrap(Bootstrap baseBootstrap, SessionProtocol desiredProtocol, + SslContext sslContext, boolean webSocket, boolean closeSslContext) { + final Bootstrap bootstrap = baseBootstrap.clone(); + bootstrap.handler(clientChannelInitializer(desiredProtocol, sslContext, webSocket, closeSslContext)); + return bootstrap; + } + + SslContext getOrCreateSslContext(SocketAddress remoteAddress, SessionProtocol desiredProtocol) { + if (sslContextFactory == null) { + return determineSslContext(desiredProtocol); + } else { + return newSslContext(remoteAddress, desiredProtocol); + } + } + + private SslContext newSslContext(SocketAddress remoteAddress, SessionProtocol desiredProtocol) { + final String hostname; + if (remoteAddress instanceof InetSocketAddress) { + hostname = ((InetSocketAddress) remoteAddress).getHostString(); + } else { + assert remoteAddress instanceof DomainSocketAddress; + hostname = "unix:" + ((DomainSocketAddress) remoteAddress).path(); + } + + final SslContextMode sslContextMode = + desiredProtocol.isExplicitHttp1() ? SslContextFactory.SslContextMode.CLIENT_HTTP1_ONLY + : SslContextFactory.SslContextMode.CLIENT; + assert sslContextFactory != null; + return sslContextFactory.getOrCreate(sslContextMode, hostname); + } + + boolean shouldReleaseSslContext(SslContext sslContext) { + return sslContext != sslCtxHttp1Only && sslContext != sslCtxHttp1Or2; + } + + void releaseSslContext(SslContext sslContext) { + if (sslContextFactory != null) { + sslContextFactory.release(sslContext); + } + } + + private ChannelInitializer clientChannelInitializer(SessionProtocol p, SslContext sslCtx, + boolean webSocket, boolean closeSslContext) { + return new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + if (closeSslContext) { + ch.closeFuture().addListener(unused -> releaseSslContext(sslCtx)); + } + ch.pipeline().addLast(new HttpClientPipelineConfigurator( + clientFactory, webSocket, p, sslCtx)); + } + }; + } } diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientFactoryBuilder.java b/core/src/main/java/com/linecorp/armeria/client/ClientFactoryBuilder.java index 92a28d5c825..029ffab983e 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientFactoryBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientFactoryBuilder.java @@ -24,10 +24,7 @@ import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_INITIAL_WINDOW_SIZE; import static java.util.Objects.requireNonNull; -import java.io.ByteArrayInputStream; import java.io.File; -import java.io.IOError; -import java.io.IOException; import java.io.InputStream; import java.net.InetSocketAddress; import java.net.ProxySelector; @@ -53,7 +50,6 @@ import com.google.common.base.MoreObjects.ToStringHelper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.io.ByteStreams; import com.google.common.primitives.Ints; import com.linecorp.armeria.client.proxy.ProxyConfig; @@ -63,12 +59,15 @@ import com.linecorp.armeria.common.Http1HeaderNaming; import com.linecorp.armeria.common.Request; import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.TlsKeyPair; +import com.linecorp.armeria.common.TlsProvider; import com.linecorp.armeria.common.TlsSetters; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.outlier.OutlierDetection; import com.linecorp.armeria.common.util.EventLoopGroups; import com.linecorp.armeria.common.util.TlsEngineType; +import com.linecorp.armeria.internal.common.IgnoreHostsTrustManager; import com.linecorp.armeria.internal.common.RequestContextUtil; import com.linecorp.armeria.internal.common.util.ChannelUtil; @@ -127,6 +126,11 @@ public final class ClientFactoryBuilder implements TlsSetters { private final List> maxNumEventLoopsFunctions = new ArrayList<>(); private boolean tlsNoVerifySet; private final Set insecureHosts = new HashSet<>(); + @Nullable + private TlsProvider tlsProvider; + @Nullable + private ClientTlsConfig tlsConfig; + private boolean staticTlsSettingsSet; ClientFactoryBuilder() { connectTimeoutMillis(Flags.defaultConnectTimeoutMillis()); @@ -286,6 +290,7 @@ private void channelOptions(Map, Object> newChannelOptions) { */ public ClientFactoryBuilder tlsNoVerify() { checkState(insecureHosts.isEmpty(), "tlsNoVerify() and tlsNoVerifyHosts() are mutually exclusive."); + ensureNoTlsProvider(); tlsNoVerifySet = true; return this; } @@ -299,6 +304,7 @@ public ClientFactoryBuilder tlsNoVerify() { */ public ClientFactoryBuilder tlsNoVerifyHosts(String... insecureHosts) { checkState(!tlsNoVerifySet, "tlsNoVerify() and tlsNoVerifyHosts() are mutually exclusive."); + ensureNoTlsProvider(); this.insecureHosts.addAll(Arrays.asList(insecureHosts)); return this; } @@ -306,7 +312,10 @@ public ClientFactoryBuilder tlsNoVerifyHosts(String... insecureHosts) { /** * Configures SSL or TLS for client certificate authentication with the specified {@code keyCertChainFile} * and cleartext {@code keyFile}. + * + * @deprecated Use {@link #tls(TlsKeyPair)} or {@link #tlsProvider(TlsProvider)} instead. */ + @Deprecated @Override public ClientFactoryBuilder tls(File keyCertChainFile, File keyFile) { return (ClientFactoryBuilder) TlsSetters.super.tls(keyCertChainFile, keyFile); @@ -315,18 +324,22 @@ public ClientFactoryBuilder tls(File keyCertChainFile, File keyFile) { /** * Configures SSL or TLS for client certificate authentication with the specified {@code keyCertChainFile}, * {@code keyFile} and {@code keyPassword}. + * + * @deprecated Use {@link #tls(TlsKeyPair)} or {@link #tlsProvider(TlsProvider)} instead. */ + @Deprecated @Override public ClientFactoryBuilder tls(File keyCertChainFile, File keyFile, @Nullable String keyPassword) { - requireNonNull(keyCertChainFile, "keyCertChainFile"); - requireNonNull(keyFile, "keyFile"); - return tlsCustomizer(customizer -> customizer.keyManager(keyCertChainFile, keyFile, keyPassword)); + return (ClientFactoryBuilder) TlsSetters.super.tls(keyCertChainFile, keyFile, keyPassword); } /** * Configures SSL or TLS for client certificate authentication with the specified * {@code keyCertChainInputStream} and cleartext {@code keyInputStream}. + * + * @deprecated Use {@link #tls(TlsKeyPair)} or {@link #tlsProvider(TlsProvider)} instead. */ + @Deprecated @Override public ClientFactoryBuilder tls(InputStream keyCertChainInputStream, InputStream keyInputStream) { return (ClientFactoryBuilder) TlsSetters.super.tls(keyCertChainInputStream, keyInputStream); @@ -335,32 +348,26 @@ public ClientFactoryBuilder tls(InputStream keyCertChainInputStream, InputStream /** * Configures SSL or TLS for client certificate authentication with the specified * {@code keyCertChainInputStream} and {@code keyInputStream} and {@code keyPassword}. + * + * @deprecated Use {@link #tls(TlsKeyPair)} or {@link #tlsProvider(TlsProvider)} instead. */ + @Deprecated @Override public ClientFactoryBuilder tls(InputStream keyCertChainInputStream, InputStream keyInputStream, @Nullable String keyPassword) { requireNonNull(keyCertChainInputStream, "keyCertChainInputStream"); requireNonNull(keyInputStream, "keyInputStream"); - - // Retrieve the content of the given streams so that they can be consumed more than once. - final byte[] keyCertChain; - final byte[] key; - try { - keyCertChain = ByteStreams.toByteArray(keyCertChainInputStream); - key = ByteStreams.toByteArray(keyInputStream); - } catch (IOException e) { - throw new IOError(e); - } - - return tlsCustomizer(customizer -> customizer.keyManager(new ByteArrayInputStream(keyCertChain), - new ByteArrayInputStream(key), - keyPassword)); + return (ClientFactoryBuilder) TlsSetters.super.tls(keyCertChainInputStream, keyInputStream, + keyPassword); } /** * Configures SSL or TLS for client certificate authentication with the specified cleartext * {@link PrivateKey} and {@link X509Certificate} chain. + * + * @deprecated Use {@link #tls(TlsKeyPair)} or {@link #tlsProvider(TlsProvider)} instead. */ + @Deprecated @Override public ClientFactoryBuilder tls(PrivateKey key, X509Certificate... keyCertChain) { return (ClientFactoryBuilder) TlsSetters.super.tls(key, keyCertChain); @@ -369,7 +376,10 @@ public ClientFactoryBuilder tls(PrivateKey key, X509Certificate... keyCertChain) /** * Configures SSL or TLS for client certificate authentication with the specified cleartext * {@link PrivateKey} and {@link X509Certificate} chain. + * + * @deprecated Use {@link #tls(TlsKeyPair)} with {@link TlsKeyPair#of(PrivateKey, Iterable)} instead. */ + @Deprecated @Override public ClientFactoryBuilder tls(PrivateKey key, Iterable keyCertChain) { return (ClientFactoryBuilder) TlsSetters.super.tls(key, keyCertChain); @@ -378,7 +388,11 @@ public ClientFactoryBuilder tls(PrivateKey key, Iterable keyCertChain) { - requireNonNull(key, "key"); - requireNonNull(keyCertChain, "keyCertChain"); - - for (X509Certificate keyCert : keyCertChain) { - requireNonNull(keyCert, "keyCertChain contains null."); - } + return (ClientFactoryBuilder) TlsSetters.super.tls(key, keyPassword, keyCertChain); + } - return tlsCustomizer(customizer -> customizer.keyManager(key, keyPassword, keyCertChain)); + /** + * Configures SSL or TLS for client certificate authentication with the specified {@link TlsKeyPair}. + */ + @Override + public ClientFactoryBuilder tls(TlsKeyPair tlsKeyPair) { + requireNonNull(tlsKeyPair, "tlsKeyPair"); + return tlsCustomizer(customizer -> customizer.keyManager(tlsKeyPair.privateKey(), + tlsKeyPair.certificateChain())); } /** @@ -420,6 +440,8 @@ public ClientFactoryBuilder tls(KeyManagerFactory keyManagerFactory) { @Override public ClientFactoryBuilder tlsCustomizer(Consumer tlsCustomizer) { requireNonNull(tlsCustomizer, "tlsCustomizer"); + ensureNoTlsProvider(); + staticTlsSettingsSet = true; @SuppressWarnings("unchecked") final ClientFactoryOptionValue> oldTlsCustomizerValue = (ClientFactoryOptionValue>) @@ -439,6 +461,44 @@ public ClientFactoryBuilder tlsCustomizer(Consumer tl return this; } + /** + * Sets the {@link TlsProvider} that provides {@link TlsKeyPair}s for client certificate authentication. + *
+     * ClientFactory
+     *   .builder()
+     *   .tlsProvider(
+     *     TlsProvider.builder()
+     *                // Set the default key pair.
+     *                .keyPair(TlsKeyPair.of(...))
+     *                // Set the key pair for "example.com".
+     *                .keyPair("example.com", TlsKeyPair.of(...))
+     *                .build())
+     * 
+ */ + @UnstableApi + public ClientFactoryBuilder tlsProvider(TlsProvider tlsProvider) { + requireNonNull(tlsProvider, "tlsProvider"); + checkState(!staticTlsSettingsSet, + "Cannot configure the TlsProvider because static TLS settings have been set already."); + this.tlsProvider = tlsProvider; + tlsConfig = null; + return this; + } + + /** + * Sets the {@link TlsProvider} that provides {@link TlsKeyPair}s for client certificate authentication. + */ + @UnstableApi + public ClientFactoryBuilder tlsProvider(TlsProvider tlsProvider, ClientTlsConfig tlsConfig) { + tlsProvider(tlsProvider); + this.tlsConfig = requireNonNull(tlsConfig, "tlsConfig"); + return this; + } + + private void ensureNoTlsProvider() { + checkState(tlsProvider == null, "Cannot configure TLS settings because a TlsProvider has been set."); + } + /** * Allows the bad cipher suites listed in * RFC7540 for TLS handshake. @@ -959,10 +1019,17 @@ private ClientFactoryOptions buildOptions() { return ClientFactoryOptions.ADDRESS_RESOLVER_GROUP_FACTORY.newValue(addressResolverGroupFactory); }); - if (tlsNoVerifySet) { - tlsCustomizer(b -> b.trustManager(InsecureTrustManagerFactory.INSTANCE)); - } else if (!insecureHosts.isEmpty()) { - tlsCustomizer(b -> b.trustManager(IgnoreHostsTrustManager.of(insecureHosts))); + if (tlsProvider != null) { + option(ClientFactoryOptions.TLS_PROVIDER, tlsProvider); + if (tlsConfig != null) { + option(ClientFactoryOptions.TLS_CONFIG, tlsConfig); + } + } else { + if (tlsNoVerifySet) { + tlsCustomizer(b -> b.trustManager(InsecureTrustManagerFactory.INSTANCE)); + } else if (!insecureHosts.isEmpty()) { + tlsCustomizer(b -> b.trustManager(IgnoreHostsTrustManager.of(insecureHosts))); + } } final ClientFactoryOptions newOptions = ClientFactoryOptions.of(options.values()); diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientFactoryOptions.java b/core/src/main/java/com/linecorp/armeria/client/ClientFactoryOptions.java index 5042a49bdd5..ae12b569dd6 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientFactoryOptions.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientFactoryOptions.java @@ -35,6 +35,8 @@ import com.linecorp.armeria.common.Flags; import com.linecorp.armeria.common.Http1HeaderNaming; import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.TlsKeyPair; +import com.linecorp.armeria.common.TlsProvider; import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.outlier.OutlierDetection; import com.linecorp.armeria.common.util.AbstractOptions; @@ -46,6 +48,7 @@ import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoop; import io.netty.channel.EventLoopGroup; +import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.resolver.AddressResolverGroup; @@ -107,6 +110,21 @@ public final class ClientFactoryOptions public static final ClientFactoryOption TLS_ENGINE_TYPE = ClientFactoryOption.define("tlsEngineType", Flags.tlsEngineType()); + /** + * The {@link TlsProvider} which provides the {@link TlsKeyPair} to create the + * {@link SslContext} for TLS handshake. + */ + @UnstableApi + public static final ClientFactoryOption TLS_PROVIDER = + ClientFactoryOption.define("TLS_PROVIDER", NullTlsProvider.INSTANCE); + + /** + * Ths {@link ClientTlsConfig} which is used to configure the client-side TLS. + */ + @UnstableApi + public static final ClientFactoryOption TLS_CONFIG = + ClientFactoryOption.define("TLS_CONFIG", ClientTlsConfig.NOOP); + /** * The factory that creates an {@link AddressResolverGroup} which resolves remote addresses into * {@link InetSocketAddress}es. @@ -654,6 +672,23 @@ public TlsEngineType tlsEngineType() { return get(TLS_ENGINE_TYPE); } + /** + * Returns the {@link TlsProvider} which provides the {@link TlsKeyPair} that is used to create the + * {@link SslContext} for TLS handshake. + */ + @UnstableApi + public TlsProvider tlsProvider() { + return get(TLS_PROVIDER); + } + + /** + * Returns the {@link ClientTlsConfig} which is used to configure the client-side {@link SslContext}. + */ + @UnstableApi + public ClientTlsConfig tlsConfig() { + return get(TLS_CONFIG); + } + /** * The {@link Consumer} that customizes the Netty {@link ChannelPipeline}. * This customizer is run right before {@link ChannelPipeline#connect(SocketAddress)} diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientTlsConfig.java b/core/src/main/java/com/linecorp/armeria/client/ClientTlsConfig.java new file mode 100644 index 00000000000..dea42bafd1c --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/ClientTlsConfig.java @@ -0,0 +1,104 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import java.util.Objects; +import java.util.Set; +import java.util.function.Consumer; + +import com.google.common.base.MoreObjects; + +import com.linecorp.armeria.common.AbstractTlsConfig; +import com.linecorp.armeria.common.TlsProvider; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.common.metric.MeterIdPrefix; + +import io.netty.handler.ssl.SslContextBuilder; + +/** + * Provides client-side TLS configuration for {@link TlsProvider}. + */ +@UnstableApi +public final class ClientTlsConfig extends AbstractTlsConfig { + + static final ClientTlsConfig NOOP = builder().build(); + + /** + * Returns a new {@link ClientTlsConfigBuilder}. + */ + public static ClientTlsConfigBuilder builder() { + return new ClientTlsConfigBuilder(); + } + + private final boolean tlsNoVerifySet; + private final Set insecureHosts; + + ClientTlsConfig(boolean allowsUnsafeCiphers, @Nullable MeterIdPrefix meterIdPrefix, + Consumer tlsCustomizer, boolean tlsNoVerifySet, + Set insecureHosts) { + super(allowsUnsafeCiphers, meterIdPrefix, tlsCustomizer); + this.tlsNoVerifySet = tlsNoVerifySet; + this.insecureHosts = insecureHosts; + } + + /** + * Returns whether the verification of server's TLS certificate chain is disabled. + */ + public boolean tlsNoVerifySet() { + return tlsNoVerifySet; + } + + /** + * Returns the hosts for which the verification of server's TLS certificate chain is disabled. + */ + public Set insecureHosts() { + return insecureHosts; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ClientTlsConfig)) { + return false; + } + if (!super.equals(o)) { + return false; + } + + final ClientTlsConfig that = (ClientTlsConfig) o; + return tlsNoVerifySet == that.tlsNoVerifySet && insecureHosts.equals(that.insecureHosts); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), tlsNoVerifySet, insecureHosts); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("allowsUnsafeCiphers", allowsUnsafeCiphers()) + .add("meterIdPrefix", meterIdPrefix()) + .add("tlsCustomizer", tlsCustomizer()) + .add("tlsNoVerifySet", tlsNoVerifySet) + .add("insecureHosts", insecureHosts) + .toString(); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientTlsConfigBuilder.java b/core/src/main/java/com/linecorp/armeria/client/ClientTlsConfigBuilder.java new file mode 100644 index 00000000000..06ec1fd8a88 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/ClientTlsConfigBuilder.java @@ -0,0 +1,95 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +import java.util.HashSet; +import java.util.Set; +import java.util.function.Consumer; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import com.linecorp.armeria.common.AbstractTlsConfigBuilder; +import com.linecorp.armeria.common.annotation.UnstableApi; + +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; + +/** + * A builder class for creating a {@link ClientTlsConfig}. + */ +@UnstableApi +public final class ClientTlsConfigBuilder extends AbstractTlsConfigBuilder { + + private boolean tlsNoVerifySet; + private final Set insecureHosts = new HashSet<>(); + + ClientTlsConfigBuilder() {} + + /** + * Disables the verification of server's TLS certificate chain. If you want to disable verification for + * only specific hosts, use {@link #tlsNoVerifyHosts(String...)}. + * + *

Note: You should never use this in production but only for a testing purpose. + * + * @see InsecureTrustManagerFactory + * @see #tlsCustomizer(Consumer) + */ + public ClientTlsConfigBuilder tlsNoVerify() { + tlsNoVerifySet = true; + checkState(insecureHosts.isEmpty(), "tlsNoVerify() and tlsNoVerifyHosts() are mutually exclusive."); + return this; + } + + /** + * Disables the verification of server's TLS certificate chain for specific hosts. If you want to disable + * all verification, use {@link #tlsNoVerify()} . + * + *

Note: You should never use this in production but only for a testing purpose. + * + * @see #tlsCustomizer(Consumer) + */ + public ClientTlsConfigBuilder tlsNoVerifyHosts(String... insecureHosts) { + requireNonNull(insecureHosts, "insecureHosts"); + return tlsNoVerifyHosts(ImmutableList.copyOf(insecureHosts)); + } + + /** + * Disables the verification of server's TLS certificate chain for specific hosts. If you want to disable + * all verification, use {@link #tlsNoVerify()} . + * + *

Note: You should never use this in production but only for a testing purpose. + * + * @see #tlsCustomizer(Consumer) + */ + public ClientTlsConfigBuilder tlsNoVerifyHosts(Iterable insecureHosts) { + requireNonNull(insecureHosts, "insecureHosts"); + checkState(!tlsNoVerifySet, "tlsNoVerify() and tlsNoVerifyHosts() are mutually exclusive."); + insecureHosts.forEach(this.insecureHosts::add); + return this; + } + + /** + * Returns a newly-created {@link ClientTlsConfig} based on the properties of this builder. + */ + public ClientTlsConfig build() { + return new ClientTlsConfig(allowsUnsafeCiphers(), meterIdPrefix(), tlsCustomizer(), + tlsNoVerifySet, ImmutableSet.copyOf(insecureHosts)); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpChannelPool.java b/core/src/main/java/com/linecorp/armeria/client/HttpChannelPool.java index 57e9fa13261..2bd867e93f5 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpChannelPool.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpChannelPool.java @@ -55,6 +55,7 @@ import com.linecorp.armeria.common.util.AsyncCloseableSupport; import com.linecorp.armeria.internal.client.HttpSession; import com.linecorp.armeria.internal.client.PooledChannel; +import com.linecorp.armeria.internal.common.SslContextFactory; import com.linecorp.armeria.internal.common.util.ChannelUtil; import com.linecorp.armeria.internal.common.util.TemporaryThreadLocals; @@ -101,6 +102,7 @@ final class HttpChannelPool implements AsyncCloseable { HttpChannelPool(HttpClientFactory clientFactory, EventLoop eventLoop, SslContext sslCtxHttp1Or2, SslContext sslCtxHttp1Only, + @Nullable SslContextFactory sslContextFactory, ConnectionPoolListener listener) { this.clientFactory = clientFactory; this.eventLoop = eventLoop; @@ -116,7 +118,8 @@ final class HttpChannelPool implements AsyncCloseable { .get(ChannelOption.CONNECT_TIMEOUT_MILLIS); assert connectTimeoutMillisBoxed != null; connectTimeoutMillis = connectTimeoutMillisBoxed; - bootstraps = new Bootstraps(clientFactory, eventLoop, sslCtxHttp1Or2, sslCtxHttp1Only); + bootstraps = new Bootstraps(clientFactory, eventLoop, sslCtxHttp1Or2, sslCtxHttp1Only, + sslContextFactory); } private void configureProxy(Channel ch, ProxyConfig proxyConfig, SessionProtocol desiredProtocol) { @@ -157,8 +160,11 @@ private void configureProxy(Channel ch, ProxyConfig proxyConfig, SessionProtocol ch.pipeline().addFirst(proxyHandler); if (proxyConfig instanceof ConnectProxyConfig && ((ConnectProxyConfig) proxyConfig).useTls()) { - final SslContext sslCtx = bootstraps.determineSslContext(desiredProtocol); + final SslContext sslCtx = bootstraps.getOrCreateSslContext(proxyAddress, desiredProtocol); ch.pipeline().addFirst(sslCtx.newHandler(ch.alloc())); + if (bootstraps.shouldReleaseSslContext(sslCtx)) { + ch.closeFuture().addListener(unused -> bootstraps.releaseSslContext(sslCtx)); + } } } @@ -382,7 +388,7 @@ void connect(SocketAddress remoteAddress, SessionProtocol desiredProtocol, @Nullable ClientConnectionTimingsBuilder timingsBuilder) { final Bootstrap bootstrap; try { - bootstrap = bootstraps.get(remoteAddress, desiredProtocol, serializationFormat); + bootstrap = bootstraps.getOrCreate(remoteAddress, desiredProtocol, serializationFormat); } catch (Exception e) { sessionPromise.tryFailure(e); return; diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java b/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java index ee65a69f054..d4d7aacb279 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpClientFactory.java @@ -36,7 +36,6 @@ import org.slf4j.LoggerFactory; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; import com.google.common.collect.MapMaker; import com.linecorp.armeria.client.endpoint.EndpointGroup; @@ -47,6 +46,7 @@ import com.linecorp.armeria.common.Scheme; import com.linecorp.armeria.common.SerializationFormat; import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.TlsProvider; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.metric.MeterIdPrefix; import com.linecorp.armeria.common.metric.MoreMeterBinders; @@ -56,6 +56,7 @@ import com.linecorp.armeria.common.util.TlsEngineType; import com.linecorp.armeria.common.util.TransportType; import com.linecorp.armeria.internal.common.RequestTargetCache; +import com.linecorp.armeria.internal.common.SslContextFactory; import com.linecorp.armeria.internal.common.util.ChannelUtil; import com.linecorp.armeria.internal.common.util.SslContextUtil; @@ -89,12 +90,12 @@ final class HttpClientFactory implements ClientFactory { private static void setupTlsMetrics(List certificates, MeterRegistry registry) { final MeterIdPrefix meterIdPrefix = new MeterIdPrefix("armeria.client"); - try { - MoreMeterBinders.certificateMetrics(certificates, meterIdPrefix) - .bindTo(registry); - } catch (Exception ex) { - logger.warn("Failed to set up TLS certificate metrics: {}", certificates, ex); - } + try { + MoreMeterBinders.certificateMetrics(certificates, meterIdPrefix) + .bindTo(registry); + } catch (Exception ex) { + logger.warn("Failed to set up TLS certificate metrics: {}", certificates, ex); + } } private final EventLoopGroup workerGroup; @@ -104,6 +105,8 @@ private static void setupTlsMetrics(List certificates, MeterReg private final Bootstrap unixBaseBootstrap; private final SslContext sslCtxHttp1Or2; private final SslContext sslCtxHttp1Only; + @Nullable + private final SslContextFactory sslContextFactory; private final AddressResolverGroup addressResolverGroup; private final int http2InitialConnectionWindowSize; private final int http2InitialStreamWindowSize; @@ -176,19 +179,31 @@ private static void setupTlsMetrics(List certificates, MeterReg unixBaseBootstrap = null; } - final ImmutableList> tlsCustomizers = - ImmutableList.of(options.tlsCustomizer()); + final Consumer tlsCustomizer = + options.tlsCustomizer(); final boolean tlsAllowUnsafeCiphers = options.tlsAllowUnsafeCiphers(); final List keyCertChainCaptor = new ArrayList<>(); final TlsEngineType tlsEngineType = options.tlsEngineType(); sslCtxHttp1Or2 = SslContextUtil .createSslContext(SslContextBuilder::forClient, false, tlsEngineType, - tlsAllowUnsafeCiphers, tlsCustomizers, keyCertChainCaptor); + tlsAllowUnsafeCiphers, tlsCustomizer, keyCertChainCaptor); sslCtxHttp1Only = SslContextUtil .createSslContext(SslContextBuilder::forClient, true, tlsEngineType, - tlsAllowUnsafeCiphers, tlsCustomizers, keyCertChainCaptor); + tlsAllowUnsafeCiphers, tlsCustomizer, keyCertChainCaptor); setupTlsMetrics(keyCertChainCaptor, options.meterRegistry()); + final TlsProvider tlsProvider = options.tlsProvider(); + if (tlsProvider != NullTlsProvider.INSTANCE) { + ClientTlsConfig clientTlsConfig = options.tlsConfig(); + if (clientTlsConfig == ClientTlsConfig.NOOP) { + clientTlsConfig = null; + } + sslContextFactory = new SslContextFactory(tlsProvider, options.tlsEngineType(), clientTlsConfig, + options.meterRegistry()); + } else { + sslContextFactory = null; + } + http2InitialConnectionWindowSize = options.http2InitialConnectionWindowSize(); http2InitialStreamWindowSize = options.http2InitialStreamWindowSize(); http2MaxFrameSize = options.http2MaxFrameSize(); @@ -495,6 +510,13 @@ HttpChannelPool pool(EventLoop eventLoop) { return pools.computeIfAbsent(eventLoop, e -> new HttpChannelPool(this, eventLoop, sslCtxHttp1Or2, sslCtxHttp1Only, + sslContextFactory, connectionPoolListener())); } + + @VisibleForTesting + @Nullable + SslContextFactory sslContextFactory() { + return sslContextFactory; + } } diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java b/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java index 4d742e06a24..7f119e28da4 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpClientPipelineConfigurator.java @@ -159,7 +159,7 @@ private enum HttpPreference { HttpClientPipelineConfigurator(HttpClientFactory clientFactory, boolean webSocket, SessionProtocol sessionProtocol, - @Nullable SslContext sslCtx) { + SslContext sslCtx) { this.clientFactory = clientFactory; this.webSocket = webSocket; diff --git a/core/src/main/java/com/linecorp/armeria/client/NullTlsProvider.java b/core/src/main/java/com/linecorp/armeria/client/NullTlsProvider.java new file mode 100644 index 00000000000..90519034de4 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/client/NullTlsProvider.java @@ -0,0 +1,30 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import com.linecorp.armeria.common.TlsKeyPair; +import com.linecorp.armeria.common.TlsProvider; +import com.linecorp.armeria.common.annotation.Nullable; + +enum NullTlsProvider implements TlsProvider { + INSTANCE; + + @Override + public @Nullable TlsKeyPair keyPair(String hostname) { + return null; + } +} diff --git a/core/src/main/java/com/linecorp/armeria/common/AbstractTlsConfig.java b/core/src/main/java/com/linecorp/armeria/common/AbstractTlsConfig.java new file mode 100644 index 00000000000..ff5dd754fd3 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/AbstractTlsConfig.java @@ -0,0 +1,92 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common; + +import java.util.Objects; +import java.util.function.Consumer; + +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.common.metric.MeterIdPrefix; + +import io.netty.handler.ssl.SslContextBuilder; + +/** + * Provides common configuration for TLS. + */ +@UnstableApi +public abstract class AbstractTlsConfig { + + private final boolean allowsUnsafeCiphers; + + @Nullable + private final MeterIdPrefix meterIdPrefix; + + private final Consumer tlsCustomizer; + + /** + * Creates a new instance. + */ + protected AbstractTlsConfig(boolean allowsUnsafeCiphers, @Nullable MeterIdPrefix meterIdPrefix, + Consumer tlsCustomizer) { + this.allowsUnsafeCiphers = allowsUnsafeCiphers; + this.meterIdPrefix = meterIdPrefix; + this.tlsCustomizer = tlsCustomizer; + } + + /** + * Returns whether to allow the bad cipher suites listed in + * RFC7540 for TLS handshake. + */ + public final boolean allowsUnsafeCiphers() { + return allowsUnsafeCiphers; + } + + /** + * Sets the {@link MeterIdPrefix} for the TLS metrics. + */ + @Nullable + public final MeterIdPrefix meterIdPrefix() { + return meterIdPrefix; + } + + /** + * Returns the {@link Consumer} which can arbitrarily configure the {@link SslContextBuilder}. + */ + public final Consumer tlsCustomizer() { + return tlsCustomizer; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof AbstractTlsConfig)) { + return false; + } + final AbstractTlsConfig that = (AbstractTlsConfig) o; + return allowsUnsafeCiphers == that.allowsUnsafeCiphers && + Objects.equals(meterIdPrefix, that.meterIdPrefix) && + tlsCustomizer.equals(that.tlsCustomizer); + } + + @Override + public int hashCode() { + return Objects.hash(allowsUnsafeCiphers, meterIdPrefix, tlsCustomizer); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/common/AbstractTlsConfigBuilder.java b/core/src/main/java/com/linecorp/armeria/common/AbstractTlsConfigBuilder.java new file mode 100644 index 00000000000..77e51eb2df0 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/AbstractTlsConfigBuilder.java @@ -0,0 +1,123 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common; + +import static java.util.Objects.requireNonNull; + +import java.util.function.Consumer; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.TrustManagerFactory; + +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.common.metric.MeterIdPrefix; + +import io.netty.handler.ssl.SslContextBuilder; + +/** + * A skeletal builder implementation for {@link TlsProvider}. + */ +@UnstableApi +public abstract class AbstractTlsConfigBuilder> { + + private static final Consumer NOOP = b -> {}; + + private boolean allowsUnsafeCiphers; + private Consumer tlsCustomizer = NOOP; + @Nullable + private MeterIdPrefix meterIdPrefix; + + /** + * Creates a new instance. + */ + protected AbstractTlsConfigBuilder() {} + + /** + * Allows the bad cipher suites listed in + * RFC7540 for TLS handshake. + * + *

Note that enabling this option increases the security risk of your connection. + * Use it only when you must communicate with a legacy system that does not support + * secure cipher suites. + * See Section 9.2.2, RFC7540 for + * more information. This option is disabled by default. + * + * @param allowsUnsafeCiphers Whether to allow the unsafe ciphers + * + * @deprecated It's not recommended to enable this option. Use it only when you have no other way to + * communicate with an insecure peer than this. + */ + @Deprecated + public SELF allowsUnsafeCiphers(boolean allowsUnsafeCiphers) { + this.allowsUnsafeCiphers = allowsUnsafeCiphers; + return self(); + } + + /** + * Returns whether to allow the bad cipher suites listed in + * RFC7540 for TLS handshake. + */ + protected final boolean allowsUnsafeCiphers() { + return allowsUnsafeCiphers; + } + + /** + * Adds the {@link Consumer} which can arbitrarily configure the {@link SslContextBuilder} that will be + * applied to the SSL session. For example, use {@link SslContextBuilder#trustManager(TrustManagerFactory)} + * to configure a custom server CA or {@link SslContextBuilder#keyManager(KeyManagerFactory)} to configure + * a client certificate for SSL authorization. + */ + public SELF tlsCustomizer(Consumer tlsCustomizer) { + requireNonNull(tlsCustomizer, "tlsCustomizer"); + if (this.tlsCustomizer == NOOP) { + //noinspection unchecked + this.tlsCustomizer = (Consumer) tlsCustomizer; + } else { + this.tlsCustomizer = this.tlsCustomizer.andThen(tlsCustomizer); + } + return self(); + } + + /** + * Returns the {@link Consumer} which can arbitrarily configure the {@link SslContextBuilder}. + */ + protected final Consumer tlsCustomizer() { + return tlsCustomizer; + } + + /** + * Sets the {@link MeterIdPrefix} for the TLS metrics. + */ + public SELF meterIdPrefix(MeterIdPrefix meterIdPrefix) { + this.meterIdPrefix = requireNonNull(meterIdPrefix, "meterIdPrefix"); + return self(); + } + + /** + * Returns the {@link MeterIdPrefix} for TLS metrics. + */ + @Nullable + protected final MeterIdPrefix meterIdPrefix() { + return meterIdPrefix; + } + + @SuppressWarnings("unchecked") + private SELF self() { + return (SELF) this; + } +} diff --git a/core/src/main/java/com/linecorp/armeria/common/Flags.java b/core/src/main/java/com/linecorp/armeria/common/Flags.java index b3327c7d06d..3fc9c0acbcd 100644 --- a/core/src/main/java/com/linecorp/armeria/common/Flags.java +++ b/core/src/main/java/com/linecorp/armeria/common/Flags.java @@ -641,7 +641,7 @@ private static void detectTlsEngineAndDumpOpenSslInfo() { /* forceHttp1 */ false, tlsEngineType, /* tlsAllowUnsafeCiphers */ false, - ImmutableList.of(), null).newEngine(ByteBufAllocator.DEFAULT); + null, null).newEngine(ByteBufAllocator.DEFAULT); logger.info("All available SSL protocols: {}", ImmutableList.copyOf(engine.getSupportedProtocols())); logger.info("Default enabled SSL protocols: {}", SslContextUtil.DEFAULT_PROTOCOLS); diff --git a/core/src/main/java/com/linecorp/armeria/common/MappedTlsProvider.java b/core/src/main/java/com/linecorp/armeria/common/MappedTlsProvider.java new file mode 100644 index 00000000000..c350d228f19 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/MappedTlsProvider.java @@ -0,0 +1,106 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common; + +import static com.google.common.base.MoreObjects.firstNonNull; +import static com.linecorp.armeria.internal.common.TlsProviderUtil.normalizeHostname; +import static java.util.Objects.requireNonNull; + +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; + +import com.linecorp.armeria.common.annotation.Nullable; + +final class MappedTlsProvider implements TlsProvider { + + private final Map tlsKeyPairs; + private final Map> trustedCertificates; + + MappedTlsProvider(Map tlsKeyPairs, + Map> trustedCertificates) { + this.tlsKeyPairs = tlsKeyPairs; + this.trustedCertificates = trustedCertificates; + } + + @Nullable + @Override + public TlsKeyPair keyPair(String hostname) { + requireNonNull(hostname, "hostname"); + return find(hostname, tlsKeyPairs); + } + + @Override + public List trustedCertificates(String hostname) { + final List certs = find(hostname, trustedCertificates); + return firstNonNull(certs, ImmutableList.of()); + } + + @Nullable + private static T find(String hostname, Map map) { + if ("*".equals(hostname)) { + return map.get("*"); + } + hostname = normalizeHostname(hostname); + + T value = map.get(hostname); + if (value != null) { + return value; + } + + // No exact match, let's try a wildcard match. + final int idx = hostname.indexOf('.'); + if (idx != -1) { + value = map.get(hostname.substring(idx)); + if (value != null) { + return value; + } + } + // Try to find the default one. + return map.get("*"); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof MappedTlsProvider)) { + return false; + } + final MappedTlsProvider that = (MappedTlsProvider) o; + return tlsKeyPairs.equals(that.tlsKeyPairs) && + trustedCertificates.equals(that.trustedCertificates); + } + + @Override + public int hashCode() { + return Objects.hash(tlsKeyPairs, trustedCertificates); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("tlsKeyPairs", tlsKeyPairs) + .add("trustedCertificates", trustedCertificates) + .toString(); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/common/StaticTlsProvider.java b/core/src/main/java/com/linecorp/armeria/common/StaticTlsProvider.java new file mode 100644 index 00000000000..b254a85e1f9 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/StaticTlsProvider.java @@ -0,0 +1,61 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common; + +import static java.util.Objects.requireNonNull; + +import com.google.common.base.MoreObjects; + +final class StaticTlsProvider implements TlsProvider { + + private final TlsKeyPair tlsKeyPair; + + StaticTlsProvider(TlsKeyPair tlsKeyPair) { + requireNonNull(tlsKeyPair, "tlsKeyPair"); + this.tlsKeyPair = tlsKeyPair; + } + + @Override + public TlsKeyPair keyPair(String hostname) { + return tlsKeyPair; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof StaticTlsProvider)) { + return false; + } + final StaticTlsProvider that = (StaticTlsProvider) o; + return tlsKeyPair.equals(that.tlsKeyPair); + } + + @Override + public int hashCode() { + return tlsKeyPair.hashCode(); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .omitNullValues() + .add("tlsKeyPair", tlsKeyPair) + .toString(); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/common/TlsKeyPair.java b/core/src/main/java/com/linecorp/armeria/common/TlsKeyPair.java new file mode 100644 index 00000000000..51f5f2a8092 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/TlsKeyPair.java @@ -0,0 +1,178 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common; + +import static com.linecorp.armeria.internal.common.util.CertificateUtil.toPrivateKey; +import static com.linecorp.armeria.internal.common.util.CertificateUtil.toX509Certificates; +import static java.util.Objects.requireNonNull; + +import java.io.File; +import java.io.InputStream; +import java.security.KeyException; +import java.security.PrivateKey; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.List; + +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; + +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.common.util.SystemInfo; +import com.linecorp.armeria.internal.common.util.SelfSignedCertificate; + +/** + * A pair of a {@link PrivateKey} and a {@link X509Certificate} chain. + */ +@UnstableApi +public final class TlsKeyPair { + + /** + * Creates a new {@link TlsKeyPair} from the specified key {@link InputStream}, and certificate chain + * {@link InputStream}. + */ + public static TlsKeyPair of(InputStream keyInputStream, InputStream certificateChainInputStream) { + return of(keyInputStream, null, certificateChainInputStream); + } + + /** + * Creates a new {@link TlsKeyPair} from the specified key {@link InputStream}, key password + * {@link InputStream} and certificate chain {@link InputStream}. + */ + public static TlsKeyPair of(InputStream keyInputStream, @Nullable String keyPassword, + InputStream certificateChainInputStream) { + requireNonNull(keyInputStream, "keyInputStream"); + requireNonNull(certificateChainInputStream, "certificateChainInputStream"); + try { + final List certs = toX509Certificates(certificateChainInputStream); + final PrivateKey key = toPrivateKey(keyInputStream, keyPassword); + return of(key, certs); + } catch (CertificateException | KeyException e) { + throw new IllegalArgumentException(e); + } + } + + /** + * Creates a new {@link TlsKeyPair} from the specified key file and certificate chain file. + */ + public static TlsKeyPair of(File keyFile, File certificateChainFile) { + return of(keyFile, null, certificateChainFile); + } + + /** + * Creates a new {@link TlsKeyPair} from the specified key file, key password and certificate chain + * file. + */ + public static TlsKeyPair of(File keyFile, @Nullable String keyPassword, File certificateChainFile) { + requireNonNull(keyFile, "keyFile"); + requireNonNull(certificateChainFile, "certificateChainFile"); + try { + final List certs = toX509Certificates(certificateChainFile); + final PrivateKey key = toPrivateKey(keyFile, keyPassword); + return of(key, certs); + } catch (CertificateException | KeyException e) { + throw new IllegalArgumentException(e); + } + } + + /** + * Creates a new {@link TlsKeyPair} from the specified {@link PrivateKey} and {@link X509Certificate}s. + */ + public static TlsKeyPair of(PrivateKey key, X509Certificate... certificateChain) { + requireNonNull(certificateChain, "certificateChain"); + return of(key, ImmutableList.copyOf(certificateChain)); + } + + /** + * Creates a new {@link TlsKeyPair} from the specified {@link PrivateKey} and {@link X509Certificate}s. + */ + public static TlsKeyPair of(PrivateKey key, Iterable certificateChain) { + requireNonNull(key, "key"); + requireNonNull(certificateChain, "certificateChain"); + return new TlsKeyPair(key, ImmutableList.copyOf(certificateChain)); + } + + /** + * Generates a self-signed certificate for the specified {@code hostname}. + */ + public static TlsKeyPair ofSelfSigned(String hostname) { + requireNonNull(hostname, "hostname"); + try { + final SelfSignedCertificate ssc = new SelfSignedCertificate(hostname); + return of(ssc.key(), ssc.cert()); + } catch (CertificateException e) { + throw new IllegalStateException("Failed to create a self-signed certificate for " + hostname, e); + } + } + + /** + * Generates a self-signed certificate for the local hostname. + */ + public static TlsKeyPair ofSelfSigned() { + return ofSelfSigned(SystemInfo.hostname()); + } + + private final PrivateKey privateKey; + private final List certificateChain; + + private TlsKeyPair(PrivateKey privateKey, List certificateChain) { + this.privateKey = privateKey; + this.certificateChain = certificateChain; + } + + /** + * Returns the private key. + */ + public PrivateKey privateKey() { + return privateKey; + } + + /** + * Returns the certificate chain. + */ + public List certificateChain() { + return certificateChain; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof TlsKeyPair)) { + return false; + } + + final TlsKeyPair that = (TlsKeyPair) o; + return privateKey.equals(that.privateKey) && certificateChain.equals(that.certificateChain); + } + + @Override + public int hashCode() { + return privateKey.hashCode() * 31 + certificateChain.hashCode(); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("privateKey", "****") + .add("certificateChain", certificateChain) + .toString(); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/common/TlsProvider.java b/core/src/main/java/com/linecorp/armeria/common/TlsProvider.java new file mode 100644 index 00000000000..d7256fcd217 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/TlsProvider.java @@ -0,0 +1,87 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common; + +import static java.util.Objects.requireNonNull; + +import java.security.cert.X509Certificate; +import java.util.List; + +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; + +/** + * Provides {@link TlsKeyPair}s for TLS handshakes. + */ +@UnstableApi +@FunctionalInterface +public interface TlsProvider { + + /** + * Returns a {@link TlsProvider} which always returns the specified {@link TlsKeyPair}. + */ + static TlsProvider of(TlsKeyPair tlsKeyPair) { + requireNonNull(tlsKeyPair, "tlsKeyPair"); + return builder().keyPair(tlsKeyPair).build(); + } + + /** + * Returns a newly created {@link TlsProviderBuilder}. + * + *

Example usage: + *

{@code
+     * TlsProvider
+     *   .builder()
+     *   // Set the default key pair.
+     *   .keyPair(TlsKeyPair.of(...))
+     *   // Set the key pair for "api.example.com".
+     *   .keyPair("api.example.com", TlsKeyPair.of(...))
+     *   // Set the key pair for "web.example.com".
+     *   .keyPair("web.example.com", TlsKeyPair.of(...))
+     *   .build();
+     * }
+ */ + static TlsProviderBuilder builder() { + return new TlsProviderBuilder(); + } + + /** + * Finds a {@link TlsKeyPair} for the specified {@code hostname}. + * + *

If no matching {@link TlsKeyPair} is found for a hostname, {@code "*"} will be specified to get the + * default {@link TlsKeyPair}. + * If no default {@link TlsKeyPair} is found, {@code null} will be returned. + * + *

Note that this operation is executed in an event loop thread, so it should not be blocked. + */ + @Nullable + TlsKeyPair keyPair(String hostname); + + /** + * Returns trusted certificates for verifying the remote endpoint's certificate. + * + *

If no matching {@link X509Certificate}s are found for a hostname, {@code "*"} will be specified to get + * the default {@link X509Certificate}s. + * The system default will be used if this method returns null. + * + *

Note that this operation is executed in an event loop thread, so it should not be blocked. + */ + @Nullable + default List trustedCertificates(String hostname) { + return null; + } +} diff --git a/core/src/main/java/com/linecorp/armeria/common/TlsProviderBuilder.java b/core/src/main/java/com/linecorp/armeria/common/TlsProviderBuilder.java new file mode 100644 index 00000000000..39af95aa721 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/TlsProviderBuilder.java @@ -0,0 +1,155 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common; + +import static java.util.Objects.requireNonNull; + +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.Map; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import com.linecorp.armeria.client.ClientFactoryBuilder; +import com.linecorp.armeria.internal.common.TlsProviderUtil; +import com.linecorp.armeria.server.ServerBuilder; + +/** + * A builder for {@link TlsProvider}. + * + * @see ClientFactoryBuilder#tlsProvider(TlsProvider) + * @see ServerBuilder#tlsProvider(TlsProvider) + */ +public final class TlsProviderBuilder { + + private final ImmutableMap.Builder tlsKeyPairsBuilder = ImmutableMap.builder(); + private final ImmutableMap.Builder> x509CertificateBuilder = + ImmutableMap.builder(); + + /** + * Creates a new instance. + */ + TlsProviderBuilder() {} + + /** + * Sets the {@link TlsKeyPair} for the specified (optionally wildcard) {@code hostname}. + * + *

DNS wildcard is supported as hostname. + * The wildcard will only match one sub-domain deep and only when wildcard is used as the most-left label. + * For example, *.armeria.dev will match foo.armeria.dev but NOT bar.foo.armeria.dev + * + *

Note that {@code "*"} is a special hostname which matches any hostname which may be used to find the + * {@link TlsKeyPair} for the {@linkplain ServerBuilder#defaultVirtualHost() default virtual host}. + * + *

The {@link TlsKeyPair} will be used for + * client certificate authentication + * when it is used for a client. + */ + public TlsProviderBuilder keyPair(String hostname, TlsKeyPair tlsKeyPair) { + requireNonNull(hostname, "hostname"); + requireNonNull(tlsKeyPair, "tlsKeyPair"); + tlsKeyPairsBuilder.put(normalize(hostname), tlsKeyPair); + return this; + } + + /** + * Sets the default {@link TlsKeyPair} which is used when no {@link TlsKeyPair} is specified for a hostname. + * + *

The {@link TlsKeyPair} will be used for + * client certificate authentication + * when it is used for a client. + */ + public TlsProviderBuilder keyPair(TlsKeyPair tlsKeyPair) { + return keyPair("*", tlsKeyPair); + } + + /** + * Sets the specified {@link X509Certificate}s to the trusted certificates that will be used for verifying + * the remote endpoint's certificate. + * + *

The system default will be used if no specific trusted certificates are set for a hostname and no + * default trusted certificates are set. + */ + public TlsProviderBuilder trustedCertificates(String hostname, X509Certificate... trustedCertificates) { + requireNonNull(trustedCertificates, "trustedCertificates"); + return trustedCertificates(hostname, ImmutableList.copyOf(trustedCertificates)); + } + + /** + * Sets the specified {@link X509Certificate}s to the trusted certificates that will be used for verifying + * the specified {@code hostname}'s certificate. + * + *

The system default will be used if no specific trusted certificates are set for a hostname and no + * default trusted certificates are set. + */ + public TlsProviderBuilder trustedCertificates(String hostname, + Iterable trustedCertificates) { + requireNonNull(hostname, "hostname"); + requireNonNull(trustedCertificates, "trustedCertificates"); + x509CertificateBuilder.put(normalize(hostname), ImmutableList.copyOf(trustedCertificates)); + return this; + } + + /** + * Sets the default {@link X509Certificate}s to the trusted certificates that is used for verifying + * the remote endpoint's certificate if no specific trusted certificates are set for a hostname. + * + *

The system default will be used if no specific trusted certificates are set for a hostname and no + * default trusted certificates are set. + */ + public TlsProviderBuilder trustedCertificates(X509Certificate... trustedCertificates) { + requireNonNull(trustedCertificates, "trustedCertificates"); + return trustedCertificates(ImmutableList.copyOf(trustedCertificates)); + } + + /** + * Sets the default {@link X509Certificate}s to the trusted certificates that is used for verifying + * the remote endpoint's certificate if no specific trusted certificates are set for a hostname. + * + *

The system default will be used if no specific trusted certificates are set for a hostname and no + * default trusted certificates are set. + */ + public TlsProviderBuilder trustedCertificates(Iterable trustedCertificates) { + return trustedCertificates("*", trustedCertificates); + } + + private static String normalize(String hostname) { + if ("*".equals(hostname)) { + return "*"; + } else { + return TlsProviderUtil.normalizeHostname(hostname); + } + } + + /** + * Returns a newly-created {@link TlsProvider} instance. + */ + public TlsProvider build() { + final Map keyPairMappings = tlsKeyPairsBuilder.build(); + if (keyPairMappings.isEmpty()) { + throw new IllegalStateException("No TLS key pair is set."); + } + + final Map> trustedCerts = x509CertificateBuilder.build(); + if (keyPairMappings.size() == 1 && keyPairMappings.containsKey("*") && trustedCerts.isEmpty()) { + return new StaticTlsProvider(keyPairMappings.get("*")); + } + + return new MappedTlsProvider(keyPairMappings, trustedCerts); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/common/TlsSetters.java b/core/src/main/java/com/linecorp/armeria/common/TlsSetters.java index f8691b12e0e..cf42957a4f6 100644 --- a/core/src/main/java/com/linecorp/armeria/common/TlsSetters.java +++ b/core/src/main/java/com/linecorp/armeria/common/TlsSetters.java @@ -16,8 +16,6 @@ package com.linecorp.armeria.common; -import static java.util.Objects.requireNonNull; - import java.io.File; import java.io.InputStream; import java.security.PrivateKey; @@ -26,8 +24,6 @@ import javax.net.ssl.KeyManagerFactory; -import com.google.common.collect.ImmutableList; - import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; @@ -42,62 +38,100 @@ public interface TlsSetters { /** * Configures SSL or TLS with the specified {@code keyCertChainFile} * and cleartext {@code keyFile}. + * + * @deprecated Use {@link #tls(TlsKeyPair)} instead. */ + @Deprecated default TlsSetters tls(File keyCertChainFile, File keyFile) { - return tls(keyCertChainFile, keyFile, null); + return tls(TlsKeyPair.of(keyFile, keyCertChainFile)); } /** * Configures SSL or TLS with the specified {@code keyCertChainFile}, * {@code keyFile} and {@code keyPassword}. + * + * @deprecated Use {@link #tls(TlsKeyPair)} instead. */ - TlsSetters tls(File keyCertChainFile, File keyFile, @Nullable String keyPassword); + @Deprecated + default TlsSetters tls(File keyCertChainFile, File keyFile, @Nullable String keyPassword) { + return tls(TlsKeyPair.of(keyFile, keyPassword, keyCertChainFile)); + } /** * Configures SSL or TLS with the specified {@code keyCertChainInputStream} and * cleartext {@code keyInputStream}. + * + * @deprecated Use {@link #tls(TlsKeyPair)} instead. */ + @Deprecated default TlsSetters tls(InputStream keyCertChainInputStream, InputStream keyInputStream) { - return tls(keyCertChainInputStream, keyInputStream, null); + return tls(TlsKeyPair.of(keyInputStream, null, keyCertChainInputStream)); } /** * Configures SSL or TLS of this with the specified {@code keyCertChainInputStream}, * {@code keyInputStream} and {@code keyPassword}. + * + * @deprecated Use {@link #tls(TlsKeyPair)} instead. */ - TlsSetters tls(InputStream keyCertChainInputStream, InputStream keyInputStream, - @Nullable String keyPassword); + @Deprecated + default TlsSetters tls(InputStream keyCertChainInputStream, InputStream keyInputStream, + @Nullable String keyPassword) { + return tls(TlsKeyPair.of(keyInputStream, keyPassword, keyCertChainInputStream)); + } /** * Configures SSL or TLS with the specified cleartext {@link PrivateKey} and * {@link X509Certificate} chain. + * + * @deprecated Use {@link #tls(TlsKeyPair)} instead. */ + @Deprecated default TlsSetters tls(PrivateKey key, X509Certificate... keyCertChain) { - return tls(key, null, keyCertChain); + return tls(TlsKeyPair.of(key, keyCertChain)); } /** * Configures SSL or TLS with the specified cleartext {@link PrivateKey} and * {@link X509Certificate} chain. + * + * @deprecated Use {@link #tls(TlsKeyPair)} instead. */ + @Deprecated default TlsSetters tls(PrivateKey key, Iterable keyCertChain) { - return tls(key, null, keyCertChain); + return tls(TlsKeyPair.of(key, keyCertChain)); } /** * Configures SSL or TLS with the specified {@link PrivateKey}, {@code keyPassword} and * {@link X509Certificate} chain. + * + * @deprecated Use {@link #tls(TlsKeyPair)} instead. */ + @Deprecated default TlsSetters tls(PrivateKey key, @Nullable String keyPassword, X509Certificate... keyCertChain) { - return tls(key, keyPassword, ImmutableList.copyOf(requireNonNull(keyCertChain, "keyCertChain"))); + // keyPassword is not required for PrivateKey since it is not encrypted. + return tls(TlsKeyPair.of(key, keyCertChain)); } /** * Configures SSL or TLS with the specified {@link PrivateKey}, {@code keyPassword} and * {@link X509Certificate} chain. + * + * @deprecated Use {@link #tls(TlsKeyPair)} instead. + */ + @Deprecated + default TlsSetters tls(PrivateKey key, @Nullable String keyPassword, + Iterable keyCertChain) { + // keyPassword is not required for PrivateKey since it is not encrypted. + return tls(TlsKeyPair.of(key, keyCertChain)); + } + + /** + * Configures SSL or TLS with the specified {@link TlsKeyPair}. */ - TlsSetters tls(PrivateKey key, @Nullable String keyPassword, - Iterable keyCertChain); + @UnstableApi + TlsSetters tls(TlsKeyPair tlsKeyPair); /** * Configures SSL or TLS with the specified {@link KeyManagerFactory}. diff --git a/core/src/main/java/com/linecorp/armeria/common/metric/AbstractCloseableMeterBinder.java b/core/src/main/java/com/linecorp/armeria/common/metric/AbstractCloseableMeterBinder.java new file mode 100644 index 00000000000..5a7659ec70e --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/metric/AbstractCloseableMeterBinder.java @@ -0,0 +1,50 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common.metric; + +import java.util.ArrayList; +import java.util.List; + +import com.linecorp.armeria.internal.common.util.ReentrantShortLock; + +abstract class AbstractCloseableMeterBinder implements CloseableMeterBinder { + + private final List closingTasks = new ArrayList<>(); + private final ReentrantShortLock lock = new ReentrantShortLock(); + + protected final void addClosingTask(Runnable closingTask) { + lock.lock(); + try { + closingTasks.add(closingTask); + } finally { + lock.unlock(); + } + } + + @Override + public void close() { + lock.lock(); + try { + for (Runnable task : closingTasks) { + task.run(); + } + closingTasks.clear(); + } finally { + lock.unlock(); + } + } +} diff --git a/core/src/main/java/com/linecorp/armeria/common/metric/CertificateMetrics.java b/core/src/main/java/com/linecorp/armeria/common/metric/CertificateMetrics.java index ed0bc1a6e1c..212526e9eff 100644 --- a/core/src/main/java/com/linecorp/armeria/common/metric/CertificateMetrics.java +++ b/core/src/main/java/com/linecorp/armeria/common/metric/CertificateMetrics.java @@ -23,15 +23,29 @@ import java.security.cert.X509Certificate; import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; import java.util.List; +import com.google.common.base.MoreObjects; + import com.linecorp.armeria.internal.common.util.CertificateUtil; import io.micrometer.core.instrument.Gauge; import io.micrometer.core.instrument.MeterRegistry; import io.micrometer.core.instrument.binder.MeterBinder; -final class CertificateMetrics implements MeterBinder { +/** + * A {@link MeterBinder} that provides metrics for TLS certificates. + * The following stats are currently exported per registered {@link MeterIdPrefix}. + * + *

    + *
  • "tls.certificate.validity" (gauge) - 1 if TLS certificate is in validity period, 0 if certificate + * is not in validity period
  • + *
  • "tls.certificate.validity.days" (gauge) - Duration in days before TLS certificate expires, which + * becomes -1 if certificate is expired
  • + *
+ */ +public final class CertificateMetrics extends AbstractCloseableMeterBinder { private final List certificates; private final MeterIdPrefix meterIdPrefix; @@ -43,33 +57,55 @@ final class CertificateMetrics implements MeterBinder { @Override public void bindTo(MeterRegistry registry) { + final List meters = new ArrayList<>(certificates.size() * 2); for (X509Certificate certificate : certificates) { final String commonName = firstNonNull(CertificateUtil.getCommonName(certificate), ""); - Gauge.builder(meterIdPrefix.name("tls.certificate.validity"), certificate, x509Cert -> { - try { - x509Cert.checkValidity(); - } catch (CertificateExpiredException | CertificateNotYetValidException e) { - return 0; - } - return 1; - }) - .description("1 if TLS certificate is in validity period, 0 if certificate is not in " + - "validity period") - .tags("common.name", commonName) - .tags(meterIdPrefix.tags()) - .register(registry); + final Gauge validityMeter = + Gauge.builder(meterIdPrefix.name("tls.certificate.validity"), certificate, x509Cert -> { + try { + x509Cert.checkValidity(); + } catch (CertificateExpiredException | CertificateNotYetValidException e) { + return 0; + } + return 1; + }) + .description( + "1 if TLS certificate is in validity period, 0 if certificate is not in " + + "validity period") + .tags("common.name", commonName) + .tags(meterIdPrefix.tags()) + .register(registry); + meters.add(validityMeter); - Gauge.builder(meterIdPrefix.name("tls.certificate.validity.days"), certificate, x509Cert -> { - final Duration diff = Duration.between(Instant.now(), - x509Cert.getNotAfter().toInstant()); - return diff.isNegative() ? -1 : diff.toDays(); - }) - .description("Duration in days before TLS certificate expires, which becomes -1 " + - "if certificate is expired") - .tags("common.name", commonName) - .tags(meterIdPrefix.tags()) - .register(registry); + final Gauge validityDaysMeter = + Gauge.builder(meterIdPrefix.name("tls.certificate.validity.days"), certificate, + x509Cert -> { + final Instant notAfter = x509Cert.getNotAfter().toInstant(); + final Duration diff = + Duration.between(Instant.now(), notAfter); + return diff.toDays(); + }) + .description("Duration in days before TLS certificate expires, which becomes -1 " + + "if certificate is expired") + .tags("common.name", commonName) + .tags(meterIdPrefix.tags()) + .register(registry); + meters.add(validityDaysMeter); } + + addClosingTask(() -> { + for (Gauge meter : meters) { + registry.remove(meter); + } + }); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("certificates", certificates) + .add("meterIdPrefix", meterIdPrefix) + .toString(); } } diff --git a/core/src/main/java/com/linecorp/armeria/common/metric/CloseableMeterBinder.java b/core/src/main/java/com/linecorp/armeria/common/metric/CloseableMeterBinder.java new file mode 100644 index 00000000000..ac0f7fb5943 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/common/metric/CloseableMeterBinder.java @@ -0,0 +1,31 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common.metric; + +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.common.util.SafeCloseable; + +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.binder.MeterBinder; + +/** + * A {@link MeterBinder} that cleans up the registered metrics by + * {@link MeterBinder#bindTo(MeterRegistry)} via {@link SafeCloseable#close()}. + */ +@UnstableApi +public interface CloseableMeterBinder extends MeterBinder, SafeCloseable { +} diff --git a/core/src/main/java/com/linecorp/armeria/common/metric/EventLoopMetrics.java b/core/src/main/java/com/linecorp/armeria/common/metric/EventLoopMetrics.java index 35eac7a8b64..760486b7b81 100644 --- a/core/src/main/java/com/linecorp/armeria/common/metric/EventLoopMetrics.java +++ b/core/src/main/java/com/linecorp/armeria/common/metric/EventLoopMetrics.java @@ -41,7 +41,7 @@ * - the total number of IO tasks waiting to be run on event loops * **/ -final class EventLoopMetrics implements MeterBinder { +public final class EventLoopMetrics extends AbstractCloseableMeterBinder { private final EventLoopGroup eventLoopGroup; private final MeterIdPrefix idPrefix; @@ -58,6 +58,8 @@ final class EventLoopMetrics implements MeterBinder { public void bindTo(MeterRegistry registry) { final Self metrics = MicrometerUtil.register(registry, idPrefix, Self.class, Self::new); metrics.add(eventLoopGroup); + + addClosingTask(() -> metrics.remove(eventLoopGroup)); } /** @@ -79,6 +81,10 @@ void add(EventLoopGroup eventLoopGroup) { registry.add(eventLoopGroup); } + void remove(EventLoopGroup eventLoopGroup) { + registry.remove(eventLoopGroup); + } + double numWorkers() { int result = 0; for (EventLoopGroup group : registry) { @@ -97,7 +103,7 @@ void add(EventLoopGroup eventLoopGroup) { for (EventLoopGroup group : registry) { // Purge event loop groups that were shutdown. if (group.isShutdown()) { - registry.remove(group); + remove(group); continue; } for (EventExecutor eventLoop : group) { diff --git a/core/src/main/java/com/linecorp/armeria/common/metric/MoreMeterBinders.java b/core/src/main/java/com/linecorp/armeria/common/metric/MoreMeterBinders.java index f687b914efc..1e569e408cc 100644 --- a/core/src/main/java/com/linecorp/armeria/common/metric/MoreMeterBinders.java +++ b/core/src/main/java/com/linecorp/armeria/common/metric/MoreMeterBinders.java @@ -32,7 +32,7 @@ import io.netty.channel.EventLoopGroup; /** - * Provides useful {@link MeterBinder}s to monitor various Armeria components. + * Provides useful {@link MeterBinder}s to monitor various Armeria components. */ public final class MoreMeterBinders { @@ -47,7 +47,7 @@ public final class MoreMeterBinders { * */ @UnstableApi - public static MeterBinder eventLoopMetrics(EventLoopGroup eventLoopGroup, String name) { + public static CloseableMeterBinder eventLoopMetrics(EventLoopGroup eventLoopGroup, String name) { requireNonNull(name, "name"); return eventLoopMetrics(eventLoopGroup, new MeterIdPrefix("armeria.netty." + name)); } @@ -63,7 +63,8 @@ public static MeterBinder eventLoopMetrics(EventLoopGroup eventLoopGroup, String * */ @UnstableApi - public static MeterBinder eventLoopMetrics(EventLoopGroup eventLoopGroup, MeterIdPrefix meterIdPrefix) { + public static CloseableMeterBinder eventLoopMetrics(EventLoopGroup eventLoopGroup, + MeterIdPrefix meterIdPrefix) { return new EventLoopMetrics(eventLoopGroup, meterIdPrefix); } @@ -82,7 +83,8 @@ public static MeterBinder eventLoopMetrics(EventLoopGroup eventLoopGroup, MeterI * @param meterIdPrefix the prefix to use for all metrics */ @UnstableApi - public static MeterBinder certificateMetrics(X509Certificate certificate, MeterIdPrefix meterIdPrefix) { + public static CloseableMeterBinder certificateMetrics(X509Certificate certificate, + MeterIdPrefix meterIdPrefix) { requireNonNull(certificate, "certificate"); return certificateMetrics(ImmutableList.of(certificate), meterIdPrefix); } @@ -102,8 +104,8 @@ public static MeterBinder certificateMetrics(X509Certificate certificate, MeterI * @param meterIdPrefix the prefix to use for all metrics */ @UnstableApi - public static MeterBinder certificateMetrics(Iterable certificates, - MeterIdPrefix meterIdPrefix) { + public static CloseableMeterBinder certificateMetrics(Iterable certificates, + MeterIdPrefix meterIdPrefix) { requireNonNull(certificates, "certificates"); requireNonNull(meterIdPrefix, "meterIdPrefix"); return new CertificateMetrics(ImmutableList.copyOf(certificates), meterIdPrefix); @@ -124,7 +126,7 @@ public static MeterBinder certificateMetrics(Iterable * @param meterIdPrefix the prefix to use for all metrics */ @UnstableApi - public static MeterBinder certificateMetrics(File keyCertChainFile, MeterIdPrefix meterIdPrefix) + public static CloseableMeterBinder certificateMetrics(File keyCertChainFile, MeterIdPrefix meterIdPrefix) throws CertificateException { requireNonNull(keyCertChainFile, "keyCertChainFile"); return certificateMetrics(CertificateUtil.toX509Certificates(keyCertChainFile), meterIdPrefix); @@ -145,7 +147,8 @@ public static MeterBinder certificateMetrics(File keyCertChainFile, MeterIdPrefi * @param meterIdPrefix the prefix to use for all metrics */ @UnstableApi - public static MeterBinder certificateMetrics(InputStream keyCertChainFile, MeterIdPrefix meterIdPrefix) + public static CloseableMeterBinder certificateMetrics(InputStream keyCertChainFile, + MeterIdPrefix meterIdPrefix) throws CertificateException { requireNonNull(keyCertChainFile, "keyCertChainFile"); return certificateMetrics(CertificateUtil.toX509Certificates(keyCertChainFile), meterIdPrefix); diff --git a/core/src/main/java/com/linecorp/armeria/client/IgnoreHostsTrustManager.java b/core/src/main/java/com/linecorp/armeria/internal/common/IgnoreHostsTrustManager.java similarity index 70% rename from core/src/main/java/com/linecorp/armeria/client/IgnoreHostsTrustManager.java rename to core/src/main/java/com/linecorp/armeria/internal/common/IgnoreHostsTrustManager.java index 276ac650db7..bb3943a409d 100644 --- a/core/src/main/java/com/linecorp/armeria/client/IgnoreHostsTrustManager.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/IgnoreHostsTrustManager.java @@ -1,35 +1,20 @@ /* - * Copyright 2020 LINE Corporation + * Copyright 2024 LINE Corporation * - * LINE Corporation licenses this file to you under the Apache License, - * version 2.0 (the "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at: + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - */ -/* - * Copyright (C) 2020 Square, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. */ -package com.linecorp.armeria.client; +package com.linecorp.armeria.internal.common; import static java.util.Objects.requireNonNull; @@ -51,7 +36,7 @@ /** * An implementation of {@link X509ExtendedTrustManager} that skips verification on the list of allowed hosts. */ -final class IgnoreHostsTrustManager extends X509ExtendedTrustManager { +public final class IgnoreHostsTrustManager extends X509ExtendedTrustManager { // Forked from okhttp-4.9.0 // https://github.com/square/okhttp/blob/1364ea44ae1f1c4b5a1cc32e4e7b51d23cb78517/okhttp-tls/src/main/kotlin/okhttp3/tls/internal/InsecureExtendedTrustManager.kt @@ -59,7 +44,7 @@ final class IgnoreHostsTrustManager extends X509ExtendedTrustManager { /** * Returns new {@link IgnoreHostsTrustManager} instance. */ - static IgnoreHostsTrustManager of(Set insecureHosts) { + public static IgnoreHostsTrustManager of(Set insecureHosts) { X509ExtendedTrustManager delegate = null; try { final TrustManagerFactory trustManagerFactory = TrustManagerFactory @@ -82,7 +67,7 @@ static IgnoreHostsTrustManager of(Set insecureHosts) { private final X509ExtendedTrustManager delegate; private final Set insecureHosts; - IgnoreHostsTrustManager(X509ExtendedTrustManager delegate, Set insecureHosts) { + public IgnoreHostsTrustManager(X509ExtendedTrustManager delegate, Set insecureHosts) { this.delegate = delegate; this.insecureHosts = insecureHosts; } diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/SslContextFactory.java b/core/src/main/java/com/linecorp/armeria/internal/common/SslContextFactory.java new file mode 100644 index 00000000000..e68ddbfaf07 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/internal/common/SslContextFactory.java @@ -0,0 +1,337 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.common; + +import static com.google.common.base.MoreObjects.firstNonNull; +import static com.linecorp.armeria.internal.common.util.SslContextUtil.createSslContext; + +import java.security.cert.X509Certificate; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; +import com.google.common.collect.ImmutableList; + +import com.linecorp.armeria.client.ClientTlsConfig; +import com.linecorp.armeria.common.AbstractTlsConfig; +import com.linecorp.armeria.common.TlsKeyPair; +import com.linecorp.armeria.common.TlsProvider; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.metric.CloseableMeterBinder; +import com.linecorp.armeria.common.metric.MeterIdPrefix; +import com.linecorp.armeria.common.metric.MoreMeterBinders; +import com.linecorp.armeria.common.util.TlsEngineType; +import com.linecorp.armeria.internal.common.util.ReentrantShortLock; +import com.linecorp.armeria.server.ServerTlsConfig; + +import io.micrometer.core.instrument.MeterRegistry; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProvider; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.util.ReferenceCountUtil; + +public final class SslContextFactory { + + private static final MeterIdPrefix SERVER_METER_ID_PREFIX = + new MeterIdPrefix("armeria.server", "hostname.pattern", "UNKNOWN"); + private static final MeterIdPrefix CLIENT_METER_ID_PREFIX = + new MeterIdPrefix("armeria.client"); + + private final Map cache = new HashMap<>(); + private final Map reverseCache = new HashMap<>(); + + private final TlsProvider tlsProvider; + private final TlsEngineType engineType; + private final MeterRegistry meterRegistry; + @Nullable + private final AbstractTlsConfig tlsConfig; + @Nullable + private final MeterIdPrefix meterIdPrefix; + private final boolean allowsUnsafeCiphers; + + private final ReentrantShortLock lock = new ReentrantShortLock(); + + public SslContextFactory(TlsProvider tlsProvider, TlsEngineType engineType, + @Nullable AbstractTlsConfig tlsConfig, MeterRegistry meterRegistry) { + // TODO(ikhoon): Support OPENSSL_REFCNT engine type. + assert engineType.sslProvider() != SslProvider.OPENSSL_REFCNT; + + this.tlsProvider = tlsProvider; + this.engineType = engineType; + this.meterRegistry = meterRegistry; + if (tlsConfig != null) { + this.tlsConfig = tlsConfig; + meterIdPrefix = tlsConfig.meterIdPrefix(); + allowsUnsafeCiphers = tlsConfig.allowsUnsafeCiphers(); + } else { + this.tlsConfig = null; + meterIdPrefix = null; + allowsUnsafeCiphers = false; + } + } + + /** + * Returns an {@link SslContext} for the specified {@link SslContextMode} and {@link TlsKeyPair}. + * Note that the returned {@link SslContext} should be released via + * {@link ReferenceCountUtil#release(Object)} when it is no longer used. + */ + public SslContext getOrCreate(SslContextMode mode, String hostname) { + lock.lock(); + try { + final TlsKeyPair tlsKeyPair = findTlsKeyPair(mode, hostname); + final List trustedCertificates = findTrustedCertificates(hostname); + final CacheKey cacheKey = new CacheKey(mode, tlsKeyPair, trustedCertificates); + final SslContextHolder contextHolder = cache.computeIfAbsent(cacheKey, this::create); + contextHolder.retain(); + reverseCache.putIfAbsent(contextHolder.sslContext(), cacheKey); + return contextHolder.sslContext(); + } finally { + lock.unlock(); + } + } + + public void release(SslContext sslContext) { + lock.lock(); + try { + final CacheKey cacheKey = reverseCache.get(sslContext); + final SslContextHolder contextHolder = cache.get(cacheKey); + assert contextHolder != null : "sslContext not found in the cache: " + sslContext; + + if (contextHolder.release()) { + final SslContextHolder removed = cache.remove(cacheKey); + assert removed == contextHolder; + reverseCache.remove(sslContext); + contextHolder.destroy(); + } + } finally { + lock.unlock(); + } + } + + @Nullable + private TlsKeyPair findTlsKeyPair(SslContextMode mode, String hostname) { + TlsKeyPair tlsKeyPair = tlsProvider.keyPair(hostname); + if (tlsKeyPair == null) { + // Try to find the default TLS key pair. + tlsKeyPair = tlsProvider.keyPair("*"); + } + if (mode == SslContextMode.SERVER && tlsKeyPair == null) { + // A TlsKeyPair must exist for a server. + throw new IllegalStateException("No TLS key pair found for " + hostname); + } + return tlsKeyPair; + } + + private List findTrustedCertificates(String hostname) { + List certs = tlsProvider.trustedCertificates(hostname); + if (certs == null) { + certs = tlsProvider.trustedCertificates("*"); + } + return firstNonNull(certs, ImmutableList.of()); + } + + private SslContextHolder create(CacheKey key) { + final MeterIdPrefix meterIdPrefix = meterIdPrefix(key.mode); + final SslContext sslContext = newSslContext(key); + final ImmutableList.Builder builder = ImmutableList.builder(); + if (key.tlsKeyPair != null) { + builder.addAll(key.tlsKeyPair.certificateChain()); + } + if (!key.trustedCertificates.isEmpty()) { + builder.addAll(key.trustedCertificates); + } + final List certs = builder.build(); + CloseableMeterBinder meterBinder = null; + if (!certs.isEmpty()) { + meterBinder = MoreMeterBinders.certificateMetrics(certs, meterIdPrefix); + meterBinder.bindTo(meterRegistry); + } + return new SslContextHolder(sslContext, meterBinder); + } + + private SslContext newSslContext(CacheKey key) { + final SslContextMode mode = key.mode(); + final TlsKeyPair tlsKeyPair = key.tlsKeyPair(); + final List trustedCerts = key.trustedCertificates(); + if (mode == SslContextMode.SERVER) { + assert tlsKeyPair != null; + return createSslContext( + () -> { + final SslContextBuilder contextBuilder = SslContextBuilder.forServer( + tlsKeyPair.privateKey(), + tlsKeyPair.certificateChain()); + if (!trustedCerts.isEmpty()) { + contextBuilder.trustManager(trustedCerts); + } + applyTlsConfig(contextBuilder); + return contextBuilder; + }, + false, engineType, allowsUnsafeCiphers, + null, null); + } else { + final boolean forceHttp1 = mode == SslContextMode.CLIENT_HTTP1_ONLY; + return createSslContext( + () -> { + final SslContextBuilder contextBuilder = SslContextBuilder.forClient(); + if (tlsKeyPair != null) { + contextBuilder.keyManager(tlsKeyPair.privateKey(), tlsKeyPair.certificateChain()); + } + if (!trustedCerts.isEmpty()) { + contextBuilder.trustManager(trustedCerts); + } + applyTlsConfig(contextBuilder); + return contextBuilder; + }, + forceHttp1, engineType, allowsUnsafeCiphers, null, null); + } + } + + private void applyTlsConfig(SslContextBuilder contextBuilder) { + if (tlsConfig == null) { + return; + } + + if (tlsConfig instanceof ServerTlsConfig) { + final ServerTlsConfig serverTlsConfig = (ServerTlsConfig) tlsConfig; + contextBuilder.clientAuth(serverTlsConfig.clientAuth()); + } else { + final ClientTlsConfig clientTlsConfig = (ClientTlsConfig) tlsConfig; + if (clientTlsConfig.tlsNoVerifySet()) { + contextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE); + } else if (!clientTlsConfig.insecureHosts().isEmpty()) { + contextBuilder.trustManager(IgnoreHostsTrustManager.of(clientTlsConfig.insecureHosts())); + } + } + tlsConfig.tlsCustomizer().accept(contextBuilder); + } + + private MeterIdPrefix meterIdPrefix(SslContextMode mode) { + MeterIdPrefix meterIdPrefix = this.meterIdPrefix; + if (meterIdPrefix == null) { + if (mode == SslContextMode.SERVER) { + meterIdPrefix = SERVER_METER_ID_PREFIX; + } else { + meterIdPrefix = CLIENT_METER_ID_PREFIX; + } + } + return meterIdPrefix; + } + + @VisibleForTesting + public int numCachedContexts() { + return cache.size(); + } + + public enum SslContextMode { + SERVER, + CLIENT_HTTP1_ONLY, + CLIENT + } + + private static final class CacheKey { + private final SslContextMode mode; + @Nullable + private final TlsKeyPair tlsKeyPair; + + private final List trustedCertificates; + + private CacheKey(SslContextMode mode, @Nullable TlsKeyPair tlsKeyPair, + List trustedCertificates) { + this.mode = mode; + this.tlsKeyPair = tlsKeyPair; + this.trustedCertificates = trustedCertificates; + } + + SslContextMode mode() { + return mode; + } + + @Nullable + TlsKeyPair tlsKeyPair() { + return tlsKeyPair; + } + + public List trustedCertificates() { + return trustedCertificates; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof CacheKey)) { + return false; + } + final CacheKey that = (CacheKey) o; + return mode == that.mode && + Objects.equals(tlsKeyPair, that.tlsKeyPair) && + trustedCertificates.equals(that.trustedCertificates); + } + + @Override + public int hashCode() { + return Objects.hash(mode, tlsKeyPair, trustedCertificates); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .omitNullValues() + .add("mode", mode) + .add("tlsKeyPair", tlsKeyPair) + .add("trustedCertificates", trustedCertificates) + .toString(); + } + } + + private static final class SslContextHolder { + private final SslContext sslContext; + @Nullable + private final CloseableMeterBinder meterBinder; + private long refCnt; + + SslContextHolder(SslContext sslContext, @Nullable CloseableMeterBinder meterBinder) { + this.sslContext = sslContext; + this.meterBinder = meterBinder; + } + + SslContext sslContext() { + return sslContext; + } + + void retain() { + refCnt++; + } + + boolean release() { + refCnt--; + assert refCnt >= 0 : "refCount: " + refCnt; + return refCnt == 0; + } + + void destroy() { + if (meterBinder != null) { + meterBinder.close(); + } + } + } +} diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/TlsProviderUtil.java b/core/src/main/java/com/linecorp/armeria/internal/common/TlsProviderUtil.java new file mode 100644 index 00000000000..0940ce5711c --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/internal/common/TlsProviderUtil.java @@ -0,0 +1,59 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.internal.common; + +import java.net.IDN; +import java.util.Locale; + +public final class TlsProviderUtil { + + // Forked from https://github.com/netty/netty/blob/60430c80e7f8718ecd07ac31e01297b42a176b87/common/src/main/java/io/netty/util/DomainWildcardMappingBuilder.java#L78 + + /** + * IDNA ASCII conversion and case normalization. + */ + public static String normalizeHostname(String hostname) { + if (hostname.isEmpty() || hostname.charAt(0) == '.') { + throw new IllegalArgumentException("Hostname '" + hostname + "' not valid"); + } + if (needsNormalization(hostname)) { + hostname = IDN.toASCII(hostname, IDN.ALLOW_UNASSIGNED); + } + hostname = hostname.toLowerCase(Locale.US); + + if (hostname.charAt(0) == '*') { + if (hostname.length() < 3 || hostname.charAt(1) != '.') { + throw new IllegalArgumentException("Wildcard Hostname '" + hostname + "'not valid"); + } + return hostname.substring(1); + } + return hostname; + } + + private static boolean needsNormalization(String hostname) { + final int length = hostname.length(); + for (int i = 0; i < length; i++) { + final int c = hostname.charAt(i); + if (c > 0x7F) { + return true; + } + } + return false; + } + + private TlsProviderUtil() {} +} diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/util/CertificateUtil.java b/core/src/main/java/com/linecorp/armeria/internal/common/util/CertificateUtil.java index f66d8002835..b90dce5d441 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/util/CertificateUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/util/CertificateUtil.java @@ -19,6 +19,8 @@ import java.io.File; import java.io.InputStream; +import java.security.KeyException; +import java.security.PrivateKey; import java.security.cert.Certificate; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; @@ -41,6 +43,7 @@ import com.google.common.collect.ImmutableList; import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.util.Exceptions; import io.netty.buffer.ByteBufAllocator; import io.netty.handler.ssl.ApplicationProtocolNegotiator; @@ -91,6 +94,29 @@ public static List toX509Certificates(InputStream in) throws Ce return ImmutableList.copyOf(SslContextProtectedAccessHack.toX509CertificateList(in)); } + public static PrivateKey toPrivateKey(File file, @Nullable String keyPassword) throws KeyException { + requireNonNull(file, "file"); + return MinifiedBouncyCastleProvider.call(() -> { + try { + return SslContextProtectedAccessHack.privateKey(file, keyPassword); + } catch (KeyException e) { + return Exceptions.throwUnsafely(e); + } + }); + } + + public static PrivateKey toPrivateKey(InputStream keyInputStream, @Nullable String keyPassword) + throws KeyException { + requireNonNull(keyInputStream, "keyInputStream"); + return MinifiedBouncyCastleProvider.call(() -> { + try { + return SslContextProtectedAccessHack.privateKey(keyInputStream, keyPassword); + } catch (KeyException e) { + return Exceptions.throwUnsafely(e); + } + }); + } + private static final class SslContextProtectedAccessHack extends SslContext { static X509Certificate[] toX509CertificateList(File file) throws CertificateException { @@ -101,6 +127,29 @@ static X509Certificate[] toX509CertificateList(InputStream in) throws Certificat return SslContext.toX509Certificates(in); } + static PrivateKey privateKey(File file, @Nullable String keyPassword) throws KeyException { + try { + return SslContext.toPrivateKey(file, keyPassword); + } catch (Exception e) { + if (e instanceof KeyException) { + throw (KeyException) e; + } + throw new KeyException("Fail to read a private key file: " + file.getName(), e); + } + } + + static PrivateKey privateKey(InputStream keyInputStream, @Nullable String keyPassword) + throws KeyException { + try { + return SslContext.toPrivateKey(keyInputStream, keyPassword); + } catch (Exception e) { + if (e instanceof KeyException) { + throw (KeyException) e; + } + throw new KeyException("Fail to parse a private key", e); + } + } + @Override public boolean isClient() { throw new UnsupportedOperationException(); diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/util/KeyStoreUtil.java b/core/src/main/java/com/linecorp/armeria/internal/common/util/KeyStoreUtil.java index 5826eb9bc25..9e50999dc78 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/util/KeyStoreUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/util/KeyStoreUtil.java @@ -33,32 +33,31 @@ import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; +import com.linecorp.armeria.common.TlsKeyPair; import com.linecorp.armeria.common.annotation.Nullable; -import io.netty.util.internal.EmptyArrays; - public final class KeyStoreUtil { - public static KeyPair load(File keyStoreFile, - @Nullable String keyStorePassword, - @Nullable String keyPassword, - @Nullable String alias) throws IOException, GeneralSecurityException { + public static TlsKeyPair load(File keyStoreFile, + @Nullable String keyStorePassword, + @Nullable String keyPassword, + @Nullable String alias) throws IOException, GeneralSecurityException { try (InputStream in = new FileInputStream(keyStoreFile)) { return load(in, keyStorePassword, keyPassword, alias, keyStoreFile); } } - public static KeyPair load(InputStream keyStoreStream, - @Nullable String keyStorePassword, - @Nullable String keyPassword, - @Nullable String alias) throws IOException, GeneralSecurityException { + public static TlsKeyPair load(InputStream keyStoreStream, + @Nullable String keyStorePassword, + @Nullable String keyPassword, + @Nullable String alias) throws IOException, GeneralSecurityException { return load(keyStoreStream, keyStorePassword, keyPassword, alias, null); } - private static KeyPair load(InputStream keyStoreStream, - @Nullable String keyStorePassword, - @Nullable String keyPassword, - @Nullable String alias, - @Nullable File keyStoreFile) + private static TlsKeyPair load(InputStream keyStoreStream, + @Nullable String keyStorePassword, + @Nullable String keyPassword, + @Nullable String alias, + @Nullable File keyStoreFile) throws IOException, GeneralSecurityException { try (InputStream in = new BufferedInputStream(keyStoreStream, 8192)) { @@ -117,7 +116,7 @@ private static KeyPair load(InputStream keyStoreStream, assert certificateChain != null; - return new KeyPair(privateKey, certificateChain); + return TlsKeyPair.of(privateKey, certificateChain); } } @@ -165,22 +164,4 @@ private static IllegalArgumentException newException(String message, @Nullable F } private KeyStoreUtil() {} - - public static final class KeyPair { - private final PrivateKey privateKey; - private final List certificateChain; - - private KeyPair(PrivateKey privateKey, Iterable certificateChain) { - this.privateKey = privateKey; - this.certificateChain = ImmutableList.copyOf(certificateChain); - } - - public PrivateKey privateKey() { - return privateKey; - } - - public X509Certificate[] certificateChain() { - return certificateChain.toArray(EmptyArrays.EMPTY_X509_CERTIFICATES); - } - } } diff --git a/core/src/main/java/com/linecorp/armeria/internal/common/util/SslContextUtil.java b/core/src/main/java/com/linecorp/armeria/internal/common/util/SslContextUtil.java index 2a67d63f5e4..3e41a171115 100644 --- a/core/src/main/java/com/linecorp/armeria/internal/common/util/SslContextUtil.java +++ b/core/src/main/java/com/linecorp/armeria/internal/common/util/SslContextUtil.java @@ -98,7 +98,7 @@ public final class SslContextUtil { public static SslContext createSslContext( Supplier builderSupplier, boolean forceHttp1, TlsEngineType tlsEngineType, boolean tlsAllowUnsafeCiphers, - Iterable> userCustomizers, + @Nullable Consumer userCustomizer, @Nullable List keyCertChainCaptor) { return MinifiedBouncyCastleProvider.call(() -> { @@ -127,7 +127,9 @@ public static SslContext createSslContext( builder.protocols(protocols.toArray(EmptyArrays.EMPTY_STRINGS)) .ciphers(DEFAULT_CIPHERS, SupportedCipherSuiteFilter.INSTANCE); - userCustomizers.forEach(customizer -> customizer.accept(builder)); + if (userCustomizer != null) { + userCustomizer.accept(builder); + } // We called user customization logic before setting ALPN to make sure they don't break // compatibility with HTTP/2. diff --git a/core/src/main/java/com/linecorp/armeria/server/HttpServerPipelineConfigurator.java b/core/src/main/java/com/linecorp/armeria/server/HttpServerPipelineConfigurator.java index 2c23dc5f357..284eff8ca1f 100644 --- a/core/src/main/java/com/linecorp/armeria/server/HttpServerPipelineConfigurator.java +++ b/core/src/main/java/com/linecorp/armeria/server/HttpServerPipelineConfigurator.java @@ -230,13 +230,27 @@ private Timer newKeepAliveTimer(SessionProtocol protocol) { } private void configureHttps(ChannelPipeline p, @Nullable ProxiedAddresses proxiedAddresses) { - final Mapping sslContexts = - requireNonNull(config.sslContextMapping(), "config.sslContextMapping() returned null"); - p.addLast(new SniHandler(sslContexts, Flags.defaultMaxClientHelloLength(), config.idleTimeoutMillis())); + p.addLast(newSniHandler(p)); p.addLast(TrafficLoggingHandler.SERVER); p.addLast(new Http2OrHttpHandler(proxiedAddresses)); } + private SniHandler newSniHandler(ChannelPipeline p) { + final Mapping sslContexts = + requireNonNull(config.sslContextMapping(), "config.sslContextMapping() returned null"); + final SniHandler sniHandler = new SniHandler(sslContexts, Flags.defaultMaxClientHelloLength(), + config.idleTimeoutMillis()); + if (sslContexts instanceof TlsProviderMapping) { + p.channel().closeFuture().addListener(future -> { + final SslContext sslContext = sniHandler.sslContext(); + if (sslContext != null) { + ((TlsProviderMapping) sslContexts).release(sslContext); + } + }); + } + return sniHandler; + } + private Http2ConnectionHandler newHttp2ConnectionHandler(ChannelPipeline pipeline, AsciiString scheme) { final Timer keepAliveTimer = newKeepAliveTimer(scheme == SCHEME_HTTP ? H2C : H2); diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java index 5674d209f67..9988ddb6976 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerBuilder.java @@ -83,6 +83,8 @@ import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.SessionProtocol; import com.linecorp.armeria.common.SuccessFunction; +import com.linecorp.armeria.common.TlsKeyPair; +import com.linecorp.armeria.common.TlsProvider; import com.linecorp.armeria.common.TlsSetters; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; @@ -237,6 +239,10 @@ public final class ServerBuilder implements TlsSetters, ServiceConfigsBuilder shutdownSupports = new ArrayList<>(); private int http2MaxResetFramesPerWindow = Flags.defaultServerHttp2MaxResetFramesPerMinute(); private int http2MaxResetFramesWindowSeconds = 60; + @Nullable + private TlsProvider tlsProvider; + @Nullable + private ServerTlsConfig tlsConfig; ServerBuilder() { // Set the default host-level properties. @@ -1074,49 +1080,65 @@ public ServerBuilder proxyProtocolMaxTlvSize(int proxyProtocolMaxTlvSize) { return this; } + @Deprecated @Override public ServerBuilder tls(File keyCertChainFile, File keyFile) { return (ServerBuilder) TlsSetters.super.tls(keyCertChainFile, keyFile); } + @Deprecated @Override public ServerBuilder tls( File keyCertChainFile, File keyFile, @Nullable String keyPassword) { - virtualHostTemplate.tls(keyCertChainFile, keyFile, keyPassword); - return this; + return (ServerBuilder) TlsSetters.super.tls(keyCertChainFile, keyFile, keyPassword); } + @Deprecated @Override public ServerBuilder tls(InputStream keyCertChainInputStream, InputStream keyInputStream) { return (ServerBuilder) TlsSetters.super.tls(keyCertChainInputStream, keyInputStream); } + @Deprecated @Override public ServerBuilder tls(InputStream keyCertChainInputStream, InputStream keyInputStream, @Nullable String keyPassword) { - virtualHostTemplate.tls(keyCertChainInputStream, keyInputStream, keyPassword); - return this; + return (ServerBuilder) TlsSetters.super.tls(keyCertChainInputStream, keyInputStream, keyPassword); } + @Deprecated @Override public ServerBuilder tls(PrivateKey key, X509Certificate... keyCertChain) { return (ServerBuilder) TlsSetters.super.tls(key, keyCertChain); } + @Deprecated @Override public ServerBuilder tls(PrivateKey key, Iterable keyCertChain) { return (ServerBuilder) TlsSetters.super.tls(key, keyCertChain); } + @Deprecated @Override public ServerBuilder tls(PrivateKey key, @Nullable String keyPassword, X509Certificate... keyCertChain) { return (ServerBuilder) TlsSetters.super.tls(key, keyPassword, keyCertChain); } + @Deprecated @Override public ServerBuilder tls(PrivateKey key, @Nullable String keyPassword, Iterable keyCertChain) { - virtualHostTemplate.tls(key, keyPassword, keyCertChain); + return (ServerBuilder) TlsSetters.super.tls(key, keyPassword, keyCertChain); + } + + /** + * Configures SSL or TLS with the specified {@link TlsKeyPair}. + * + *

Note that this method mutually exclusive with {@link #tlsProvider(TlsProvider)}. + */ + @Override + public ServerBuilder tls(TlsKeyPair tlsKeyPair) { + virtualHostTemplate.tls(tlsKeyPair); return this; } @@ -1126,9 +1148,69 @@ public ServerBuilder tls(KeyManagerFactory keyManagerFactory) { return this; } + /** + * Sets the specified {@link TlsProvider} which will be used for building an {@link SslContext} of + * a hostname. + * + *

{@code
+     * Server
+     *   .builder()
+     *   .tlsProvider(
+     *     TlsProvider.builder()
+     *                // Set the default key pair.
+     *                .keyPair(TlsKeyPair.of(...))
+     *                // Set the key pair for "example.com".
+     *                .keyPair("example.com", TlsKeyPair.of(...))
+     *                .build())
+     * }
+ * + *

Note that this method mutually exclusive with {@link #tls(TlsKeyPair)} and other static TLS settings. + */ + @UnstableApi + public ServerBuilder tlsProvider(TlsProvider tlsProvider) { + requireNonNull(tlsProvider, "tlsProvider"); + this.tlsProvider = tlsProvider; + tlsConfig = null; + return this; + } + + /** + * Sets the specified {@link TlsProvider} and {@link ServerTlsConfig} which will be used for building an + * {@link SslContext} of a hostname. + * + *

{@code
+     * TlsProvider tlsProvider =
+     *   TlsProvider
+     *     .builder()
+     *     // Set the default key pair.
+     *     .keyPair(TlsKeyPair.of(...))
+     *     // Set the key pair for "example.com".
+     *     .keyPair("example.com", TlsKeyPair.of(...))
+     *     .build();
+     *
+     * ServerTlsConfig tlsConfig =
+     *   ServerTlsConfig
+     *     .builder()
+     *     .clientAuth(ClientAuth.REQUIRED)
+     *     .meterIdPrefix(...)
+     *     .build();
+     *
+     * Server
+     *   .builder()
+     *   .tlsProvider(tlsProvider, tlsConfig)
+     * }
+ */ + @UnstableApi + public ServerBuilder tlsProvider(TlsProvider tlsProvider, ServerTlsConfig tlsConfig) { + tlsProvider(tlsProvider); + this.tlsConfig = requireNonNull(tlsConfig, "tlsConfig"); + return this; + } + /** * Configures SSL or TLS of the {@link Server} with an auto-generated self-signed certificate. - * Note: You should never use this in production but only for a testing purpose. + * + *

Note: You should never use this in production but only for a testing purpose. * * @see #tlsCustomizer(Consumer) */ @@ -1139,7 +1221,8 @@ public ServerBuilder tlsSelfSigned() { /** * Configures SSL or TLS of the {@link Server} with an auto-generated self-signed certificate. - * Note: You should never use this in production but only for a testing purpose. + * + *

Note: You should never use this in production but only for a testing purpose. * * @see #tlsCustomizer(Consumer) */ @@ -2222,11 +2305,11 @@ private DefaultServerConfig buildServerConfig(List serverPorts) { : this.errorHandler.orElse(ServerErrorHandler.ofDefault())); final VirtualHost defaultVirtualHost = defaultVirtualHostBuilder.build(virtualHostTemplate, dependencyInjector, - unloggedExceptionsReporter, errorHandler); + unloggedExceptionsReporter, errorHandler, tlsProvider); final List virtualHosts = virtualHostBuilders.stream() .map(vhb -> vhb.build(virtualHostTemplate, dependencyInjector, - unloggedExceptionsReporter, errorHandler)) + unloggedExceptionsReporter, errorHandler, tlsProvider)) .collect(toImmutableList()); // Pre-populate the domain name mapping for later matching. final Mapping sslContexts; @@ -2254,7 +2337,9 @@ private DefaultServerConfig buildServerConfig(List serverPorts) { virtualHostPort, portNumbers); } - if (defaultSslContext == null) { + checkState(defaultSslContext == null || tlsProvider == null, + "Can't set %s with a static TLS setting", TlsProvider.class.getSimpleName()); + if (defaultSslContext == null && tlsProvider == null) { sslContexts = null; if (!serverPorts.isEmpty()) { ports = resolveDistinctPorts(serverPorts); @@ -2282,21 +2367,28 @@ private DefaultServerConfig buildServerConfig(List serverPorts) { ports = ImmutableList.of(new ServerPort(0, HTTPS)); } - final DomainMappingBuilder - mappingBuilder = new DomainMappingBuilder<>(defaultSslContext); - for (VirtualHost h : virtualHosts) { - final SslContext sslCtx = h.sslContext(); - if (sslCtx != null) { - final String originalHostnamePattern = h.originalHostnamePattern(); - // The SslContext for the default virtual host was added when creating DomainMappingBuilder. - if (!"*".equals(originalHostnamePattern)) { - mappingBuilder.add(originalHostnamePattern, sslCtx); + if (defaultSslContext != null) { + final DomainMappingBuilder + mappingBuilder = new DomainMappingBuilder<>(defaultSslContext); + for (VirtualHost h : virtualHosts) { + final SslContext sslCtx = h.sslContext(); + if (sslCtx != null) { + final String originalHostnamePattern = h.originalHostnamePattern(); + // The SslContext for the default virtual host was added when creating + // DomainMappingBuilder. + if (!"*".equals(originalHostnamePattern)) { + mappingBuilder.add(originalHostnamePattern, sslCtx); + } } } + sslContexts = mappingBuilder.build(); + } else { + final TlsEngineType tlsEngineType = defaultVirtualHost.tlsEngineType(); + assert tlsEngineType != null; + assert tlsProvider != null; + sslContexts = new TlsProviderMapping(tlsProvider, tlsEngineType, tlsConfig, meterRegistry); } - sslContexts = mappingBuilder.build(); } - if (pingIntervalMillis > 0) { pingIntervalMillis = Math.max(pingIntervalMillis, MIN_PING_INTERVAL_MILLIS); if (idleTimeoutMillis > 0 && pingIntervalMillis >= idleTimeoutMillis) { diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerSslContextUtil.java b/core/src/main/java/com/linecorp/armeria/server/ServerSslContextUtil.java index 0d3c080b5dc..d73dc0e62fc 100644 --- a/core/src/main/java/com/linecorp/armeria/server/ServerSslContextUtil.java +++ b/core/src/main/java/com/linecorp/armeria/server/ServerSslContextUtil.java @@ -26,8 +26,7 @@ import javax.net.ssl.SSLException; import javax.net.ssl.SSLSession; -import com.google.common.collect.ImmutableList; - +import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.TlsEngineType; import com.linecorp.armeria.internal.common.util.SslContextUtil; @@ -65,7 +64,7 @@ static SSLSession validateSslContext(SslContext sslContext, TlsEngineType tlsEng final SslContext sslContextClient = buildSslContext(() -> SslContextBuilder.forClient() .trustManager(InsecureTrustManagerFactory.INSTANCE), - tlsEngineType, true, ImmutableList.of()); + tlsEngineType, true, null); clientEngine = sslContextClient.newEngine(ByteBufAllocator.DEFAULT); clientEngine.setUseClientMode(true); clientEngine.setEnabledProtocols(clientEngine.getSupportedProtocols()); @@ -99,10 +98,10 @@ static SslContext buildSslContext( Supplier sslContextBuilderSupplier, TlsEngineType tlsEngineType, boolean tlsAllowUnsafeCiphers, - Iterable> tlsCustomizers) { + @Nullable Consumer tlsCustomizer) { return SslContextUtil .createSslContext(sslContextBuilderSupplier,/* forceHttp1 */ false, tlsEngineType, - tlsAllowUnsafeCiphers, tlsCustomizers, null); + tlsAllowUnsafeCiphers, tlsCustomizer, null); } private static void unwrap(SSLEngine engine, ByteBuffer packetBuf) throws SSLException { diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerTlsConfig.java b/core/src/main/java/com/linecorp/armeria/server/ServerTlsConfig.java new file mode 100644 index 00000000000..d1c63db70e8 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/server/ServerTlsConfig.java @@ -0,0 +1,70 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server; + +import java.util.function.Consumer; + +import com.google.common.base.MoreObjects; + +import com.linecorp.armeria.common.AbstractTlsConfig; +import com.linecorp.armeria.common.TlsProvider; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.common.metric.MeterIdPrefix; + +import io.netty.handler.ssl.ClientAuth; +import io.netty.handler.ssl.SslContextBuilder; + +/** + * Provides server-side TLS configuration for {@link TlsProvider}. + */ +@UnstableApi +public final class ServerTlsConfig extends AbstractTlsConfig { + + /** + * Returns a new {@link ServerTlsConfigBuilder}. + */ + public static ServerTlsConfigBuilder builder() { + return new ServerTlsConfigBuilder(); + } + + private final ClientAuth clientAuth; + + ServerTlsConfig(boolean allowsUnsafeCiphers, @Nullable MeterIdPrefix meterIdPrefix, + ClientAuth clientAuth, Consumer tlsCustomizer) { + super(allowsUnsafeCiphers, meterIdPrefix, tlsCustomizer); + this.clientAuth = clientAuth; + } + + /** + * Returns the client authentication mode. + */ + public ClientAuth clientAuth() { + return clientAuth; + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .omitNullValues() + .add("allowsUnsafeCiphers", allowsUnsafeCiphers()) + .add("meterIdPrefix", meterIdPrefix()) + .add("clientAuth", clientAuth) + .add("tlsCustomizer", tlsCustomizer()) + .toString(); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/server/ServerTlsConfigBuilder.java b/core/src/main/java/com/linecorp/armeria/server/ServerTlsConfigBuilder.java new file mode 100644 index 00000000000..97c53e828aa --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/server/ServerTlsConfigBuilder.java @@ -0,0 +1,51 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server; + +import static java.util.Objects.requireNonNull; + +import com.linecorp.armeria.common.AbstractTlsConfigBuilder; +import com.linecorp.armeria.common.TlsProvider; +import com.linecorp.armeria.common.annotation.UnstableApi; + +import io.netty.handler.ssl.ClientAuth; + +/** + * A builder class for creating a {@link TlsProvider} that provides server-side TLS. + */ +@UnstableApi +public final class ServerTlsConfigBuilder extends AbstractTlsConfigBuilder { + + private ClientAuth clientAuth = ClientAuth.NONE; + + ServerTlsConfigBuilder() {} + + /** + * Sets the client authentication mode. + */ + public ServerTlsConfigBuilder clientAuth(ClientAuth clientAuth) { + this.clientAuth = requireNonNull(clientAuth, "clientAuth"); + return this; + } + + /** + * Returns a newly-created {@link ServerTlsConfig} based on the properties of this builder. + */ + public ServerTlsConfig build() { + return new ServerTlsConfig(allowsUnsafeCiphers(), meterIdPrefix(), clientAuth, tlsCustomizer()); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/server/TlsProviderMapping.java b/core/src/main/java/com/linecorp/armeria/server/TlsProviderMapping.java new file mode 100644 index 00000000000..a36b9065be2 --- /dev/null +++ b/core/src/main/java/com/linecorp/armeria/server/TlsProviderMapping.java @@ -0,0 +1,51 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server; + +import com.linecorp.armeria.common.TlsProvider; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.util.TlsEngineType; +import com.linecorp.armeria.internal.common.SslContextFactory; +import com.linecorp.armeria.internal.common.TlsProviderUtil; + +import io.micrometer.core.instrument.MeterRegistry; +import io.netty.handler.ssl.SslContext; +import io.netty.util.Mapping; + +final class TlsProviderMapping implements Mapping { + + private final SslContextFactory sslContextFactory; + + TlsProviderMapping(TlsProvider tlsProvider, TlsEngineType tlsEngineType, + @Nullable ServerTlsConfig tlsConfig, MeterRegistry meterRegistry) { + sslContextFactory = new SslContextFactory(tlsProvider, tlsEngineType, tlsConfig, meterRegistry); + } + + @Override + public SslContext map(@Nullable String hostname) { + if (hostname == null) { + hostname = "*"; + } else { + hostname = TlsProviderUtil.normalizeHostname(hostname); + } + return sslContextFactory.getOrCreate(SslContextFactory.SslContextMode.SERVER, hostname); + } + + void release(SslContext sslContext) { + sslContextFactory.release(sslContext); + } +} diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java index 89ddb678a29..297f1a55507 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHost.java @@ -39,6 +39,7 @@ import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.RequestId; import com.linecorp.armeria.common.SuccessFunction; +import com.linecorp.armeria.common.TlsProvider; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.logging.RequestLog; @@ -52,6 +53,7 @@ import io.netty.channel.EventLoopGroup; import io.netty.handler.ssl.SslContext; import io.netty.util.Mapping; +import io.netty.util.ReferenceCountUtil; /** * A name-based virtual host. @@ -84,6 +86,8 @@ public final class VirtualHost { @Nullable private final SslContext sslContext; @Nullable + private final TlsProvider tlsProvider; + @Nullable private final TlsEngineType tlsEngineType; private final Router router; private final List serviceConfigs; @@ -109,6 +113,7 @@ public final class VirtualHost { VirtualHost(String defaultHostname, String hostnamePattern, int port, @Nullable SslContext sslContext, + @Nullable TlsProvider tlsProvider, @Nullable TlsEngineType tlsEngineType, Iterable serviceConfigs, ServiceConfig fallbackServiceConfig, @@ -138,6 +143,7 @@ public final class VirtualHost { } this.port = port; this.sslContext = sslContext; + this.tlsProvider = tlsProvider; this.tlsEngineType = tlsEngineType; this.defaultServiceNaming = defaultServiceNaming; this.defaultLogName = defaultLogName; @@ -172,7 +178,11 @@ public final class VirtualHost { } VirtualHost withNewSslContext(SslContext sslContext) { - return new VirtualHost(originalDefaultHostname, originalHostnamePattern, port, sslContext, + if (tlsProvider != null) { + ReferenceCountUtil.release(sslContext); + throw new IllegalStateException("Cannot set a new SslContext when TlsProvider is set."); + } + return new VirtualHost(originalDefaultHostname, originalHostnamePattern, port, sslContext, null, tlsEngineType, serviceConfigs, fallbackServiceConfig, RejectedRouteHandler.DISABLED, host -> accessLogger, defaultServiceNaming, defaultLogName, requestTimeoutMillis, maxRequestLength, verboseResponses, @@ -590,7 +600,7 @@ VirtualHost decorate(@Nullable Function accessLogger, defaultServiceNaming, defaultLogName, requestTimeoutMillis, maxRequestLength, verboseResponses, diff --git a/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java b/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java index 323f7e04300..0868d04e9df 100644 --- a/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java +++ b/core/src/main/java/com/linecorp/armeria/server/VirtualHostBuilder.java @@ -32,10 +32,7 @@ import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.isPseudoHeader; import static java.util.Objects.requireNonNull; -import java.io.ByteArrayInputStream; import java.io.File; -import java.io.IOError; -import java.io.IOException; import java.io.InputStream; import java.nio.file.Path; import java.security.PrivateKey; @@ -60,7 +57,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.google.common.io.ByteStreams; import com.google.common.net.HostAndPort; import com.linecorp.armeria.common.CommonPools; @@ -76,6 +72,8 @@ import com.linecorp.armeria.common.RequestId; import com.linecorp.armeria.common.ResponseHeaders; import com.linecorp.armeria.common.SuccessFunction; +import com.linecorp.armeria.common.TlsKeyPair; +import com.linecorp.armeria.common.TlsProvider; import com.linecorp.armeria.common.TlsSetters; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; @@ -134,7 +132,8 @@ public final class VirtualHostBuilder implements TlsSetters, ServiceConfigsBuild private Boolean tlsSelfSigned; @Nullable private SelfSignedCertificate selfSignedCertificate; - private final List> tlsCustomizers = new ArrayList<>(); + @Nullable + private Consumer tlsCustomizer; @Nullable private Boolean tlsAllowUnsafeCiphers; @Nullable @@ -276,70 +275,61 @@ VirtualHostBuilder hostnamePattern(String hostnamePattern, int port) { return this; } + @Deprecated @Override public VirtualHostBuilder tls(File keyCertChainFile, File keyFile) { return (VirtualHostBuilder) TlsSetters.super.tls(keyCertChainFile, keyFile); } + @Deprecated @Override public VirtualHostBuilder tls(File keyCertChainFile, File keyFile, @Nullable String keyPassword) { - requireNonNull(keyCertChainFile, "keyCertChainFile"); - requireNonNull(keyFile, "keyFile"); - return tls(() -> SslContextBuilder.forServer(keyCertChainFile, keyFile, keyPassword)); + return (VirtualHostBuilder) TlsSetters.super.tls(keyCertChainFile, keyFile, keyPassword); } + @Deprecated @Override public VirtualHostBuilder tls(InputStream keyCertChainInputStream, InputStream keyInputStream) { return (VirtualHostBuilder) TlsSetters.super.tls(keyCertChainInputStream, keyInputStream); } + @Deprecated @Override public VirtualHostBuilder tls(InputStream keyCertChainInputStream, InputStream keyInputStream, @Nullable String keyPassword) { - requireNonNull(keyCertChainInputStream, "keyCertChainInputStream"); - requireNonNull(keyInputStream, "keyInputStream"); - - // Retrieve the content of the given streams so that they can be consumed more than once. - final byte[] keyCertChain; - final byte[] key; - try { - keyCertChain = ByteStreams.toByteArray(keyCertChainInputStream); - key = ByteStreams.toByteArray(keyInputStream); - } catch (IOException e) { - throw new IOError(e); - } - - return tls(() -> SslContextBuilder.forServer(new ByteArrayInputStream(keyCertChain), - new ByteArrayInputStream(key), - keyPassword)); + return (VirtualHostBuilder) TlsSetters.super.tls(keyCertChainInputStream, keyInputStream, keyPassword); } + @Deprecated @Override public VirtualHostBuilder tls(PrivateKey key, X509Certificate... keyCertChain) { return (VirtualHostBuilder) TlsSetters.super.tls(key, keyCertChain); } + @Deprecated @Override public VirtualHostBuilder tls(PrivateKey key, Iterable keyCertChain) { return (VirtualHostBuilder) TlsSetters.super.tls(key, keyCertChain); } + @Deprecated @Override public VirtualHostBuilder tls(PrivateKey key, @Nullable String keyPassword, X509Certificate... keyCertChain) { return (VirtualHostBuilder) TlsSetters.super.tls(key, keyPassword, keyCertChain); } + @Deprecated @Override public VirtualHostBuilder tls(PrivateKey key, @Nullable String keyPassword, Iterable keyCertChain) { - requireNonNull(key, "key"); - requireNonNull(keyCertChain, "keyCertChain"); - for (X509Certificate keyCert : keyCertChain) { - requireNonNull(keyCert, "keyCertChain contains null."); - } + return (VirtualHostBuilder) TlsSetters.super.tls(key, keyPassword, keyCertChain); + } - return tls(() -> SslContextBuilder.forServer(key, keyPassword, keyCertChain)); + @Override + public VirtualHostBuilder tls(TlsKeyPair tlsKeyPair) { + requireNonNull(tlsKeyPair, "tlsKeyPair"); + return tls(() -> SslContextBuilder.forServer(tlsKeyPair.privateKey(), tlsKeyPair.certificateChain())); } @Override @@ -363,7 +353,9 @@ private VirtualHostBuilder tls(Supplier sslContextBuilderSupp * Note: You should never use this in production but only for a testing purpose. * * @see #tlsCustomizer(Consumer) + * @deprecated Use {@link #tls(TlsKeyPair)} with {@link TlsKeyPair#ofSelfSigned()}. */ + @Deprecated public VirtualHostBuilder tlsSelfSigned() { return tlsSelfSigned(true); } @@ -373,7 +365,9 @@ public VirtualHostBuilder tlsSelfSigned() { * Note: You should never use this in production but only for a testing purpose. * * @see #tlsCustomizer(Consumer) + * @deprecated Use {@link #tls(TlsKeyPair)} with {@link TlsKeyPair#ofSelfSigned()}. */ + @Deprecated public VirtualHostBuilder tlsSelfSigned(boolean tlsSelfSigned) { checkState(!portBased, "Cannot configure self-signed to a port-based virtual host." + " Please configure to %s.tlsSelfSigned()", ServerBuilder.class.getSimpleName()); @@ -387,7 +381,12 @@ public VirtualHostBuilder tlsCustomizer(Consumer tlsC checkState(!portBased, "Cannot configure TLS to a port-based virtual host. Please configure to %s.tlsCustomizer()", ServerBuilder.class.getSimpleName()); - tlsCustomizers.add(tlsCustomizer); + if (this.tlsCustomizer == null) { + //noinspection unchecked + this.tlsCustomizer = (Consumer) tlsCustomizer; + } else { + this.tlsCustomizer = this.tlsCustomizer.andThen(tlsCustomizer); + } return this; } @@ -1306,7 +1305,7 @@ public VirtualHostBuilder contextHook(Supplier contextH */ VirtualHost build(VirtualHostBuilder template, DependencyInjector dependencyInjector, @Nullable UnloggedExceptionsReporter unloggedExceptionsReporter, - ServerErrorHandler serverErrorHandler) { + ServerErrorHandler serverErrorHandler, @Nullable TlsProvider tlsProvider) { requireNonNull(template, "template"); if (defaultHostname == null) { @@ -1464,9 +1463,17 @@ VirtualHost build(VirtualHostBuilder template, DependencyInjector dependencyInje final TlsEngineType tlsEngineType = this.tlsEngineType != null ? this.tlsEngineType : template.tlsEngineType; assert tlsEngineType != null; + + final SslContext sslContext = sslContext(template, tlsEngineType); + if (sslContext != null && tlsProvider != null) { + ReferenceCountUtil.release(sslContext); + throw new IllegalStateException("Cannot configure TLS settings with a TlsProvider"); + } + final VirtualHost virtualHost = - new VirtualHost(defaultHostname, hostnamePattern, port, sslContext(template, tlsEngineType), - tlsEngineType, serviceConfigs, fallbackServiceConfig, rejectedRouteHandler, + new VirtualHost(defaultHostname, hostnamePattern, port, + sslContext, tlsProvider, tlsEngineType, + serviceConfigs, fallbackServiceConfig, rejectedRouteHandler, accessLoggerMapper, defaultServiceNaming, defaultLogName, requestTimeoutMillis, maxRequestLength, verboseResponses, accessLogWriter, blockingTaskExecutor, requestAutoAbortDelayMillis, successFunction, multipartUploadsLocation, @@ -1516,27 +1523,27 @@ private SslContext sslContext(VirtualHostBuilder template, TlsEngineType tlsEngi // Build a new SslContext or use a user-specified one for backward compatibility. if (sslContextBuilderSupplier != null) { sslContext = buildSslContext(sslContextBuilderSupplier, tlsEngineType, tlsAllowUnsafeCiphers, - tlsCustomizers); + tlsCustomizer); sslContextFromThis = true; releaseSslContextOnFailure = true; } else if (template.sslContextBuilderSupplier != null) { sslContext = buildSslContext(template.sslContextBuilderSupplier, tlsEngineType, - tlsAllowUnsafeCiphers, template.tlsCustomizers); + tlsAllowUnsafeCiphers, template.tlsCustomizer); releaseSslContextOnFailure = true; } // Generate a self-signed certificate if necessary. if (sslContext == null) { final boolean tlsSelfSigned; - final List> tlsCustomizers; + final Consumer tlsCustomizer; if (this.tlsSelfSigned != null) { tlsSelfSigned = this.tlsSelfSigned; - tlsCustomizers = this.tlsCustomizers; + tlsCustomizer = this.tlsCustomizer; sslContextFromThis = true; } else { assert template.tlsSelfSigned != null; tlsSelfSigned = template.tlsSelfSigned; - tlsCustomizers = template.tlsCustomizers; + tlsCustomizer = template.tlsCustomizer; } if (tlsSelfSigned) { @@ -1551,13 +1558,13 @@ private SslContext sslContext(VirtualHostBuilder template, TlsEngineType tlsEngi ssc.privateKey()), tlsEngineType, tlsAllowUnsafeCiphers, - tlsCustomizers); + tlsCustomizer); releaseSslContextOnFailure = true; } } // Reject if a user called `tlsCustomizer()` without `tls()` or `tlsSelfSigned()`. - checkState(sslContextFromThis || tlsCustomizers.isEmpty(), + checkState(sslContextFromThis || tlsCustomizer == null, "Cannot call tlsCustomizer() without tls() or tlsSelfSigned()"); // Validate the built `SslContext`. diff --git a/core/src/test/java/com/linecorp/armeria/client/ClientTlsProviderBuilderTest.java b/core/src/test/java/com/linecorp/armeria/client/ClientTlsProviderBuilderTest.java new file mode 100644 index 00000000000..3763819d982 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/ClientTlsProviderBuilderTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.junit.jupiter.api.Test; + +import com.linecorp.armeria.common.TlsKeyPair; +import com.linecorp.armeria.common.TlsProvider; + +class ClientTlsProviderBuilderTest { + + @Test + void testBuild() { + assertThatThrownBy(() -> { + TlsProvider.builder() + .build(); + }).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("No TLS key pair is set."); + } + + @Test + void testMapping() { + final TlsKeyPair exactKeyPair = TlsKeyPair.ofSelfSigned(); + final TlsKeyPair wildcardKeyPair = TlsKeyPair.ofSelfSigned(); + final TlsKeyPair defaultKeyPair = TlsKeyPair.ofSelfSigned(); + final TlsKeyPair barKeyPair = TlsKeyPair.ofSelfSigned(); + final TlsKeyPair barWildKeyPair = TlsKeyPair.ofSelfSigned(); + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair(defaultKeyPair) + .keyPair("example.com", exactKeyPair) + .keyPair("*.foo.com", wildcardKeyPair) + .keyPair("*.bar.com", barWildKeyPair) + .keyPair("bar.com", barKeyPair) + .build(); + assertThat(tlsProvider.keyPair("any.com")).isEqualTo(defaultKeyPair); + // Exact match + assertThat(tlsProvider.keyPair("example.com")).isEqualTo(exactKeyPair); + // Wildcard match + assertThat(tlsProvider.keyPair("bar.foo.com")).isEqualTo(wildcardKeyPair); + + // Not a wildcard match + assertThat(tlsProvider.keyPair("foo.com")).isEqualTo(defaultKeyPair); + // No nested wildcard support + assertThat(tlsProvider.keyPair("baz.bar.foo.com")).isEqualTo(defaultKeyPair); + + assertThat(tlsProvider.keyPair("bar.com")).isEqualTo(barKeyPair); + assertThat(tlsProvider.keyPair("foo.bar.com")).isEqualTo(barWildKeyPair); + assertThat(tlsProvider.keyPair("foo.foo.bar.com")).isEqualTo(defaultKeyPair); + } +} diff --git a/core/src/test/java/com/linecorp/armeria/client/ClientTlsProviderTest.java b/core/src/test/java/com/linecorp/armeria/client/ClientTlsProviderTest.java new file mode 100644 index 00000000000..922e764488b --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/ClientTlsProviderTest.java @@ -0,0 +1,311 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.awaitility.Awaitility.await; + +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.TlsKeyPair; +import com.linecorp.armeria.common.TlsProvider; +import com.linecorp.armeria.common.metric.MoreMeters; +import com.linecorp.armeria.internal.common.util.CertificateUtil; +import com.linecorp.armeria.internal.testing.MockAddressResolverGroup; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServerTlsConfig; +import com.linecorp.armeria.testing.junit5.server.SelfSignedCertificateExtension; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.netty.handler.ssl.ClientAuth; + +class ClientTlsProviderTest { + + @RegisterExtension + static final SelfSignedCertificateExtension server0DefaultCert = new SelfSignedCertificateExtension(); + @RegisterExtension + static final SelfSignedCertificateExtension server0FooCert = new SelfSignedCertificateExtension( + "foo.com"); + @RegisterExtension + static final SelfSignedCertificateExtension server0SubFooCert = new SelfSignedCertificateExtension( + "sub.foo.com"); + @RegisterExtension + static final SelfSignedCertificateExtension server1DefaultCert = new SelfSignedCertificateExtension(); + @RegisterExtension + static final SelfSignedCertificateExtension server1BarCert = new SelfSignedCertificateExtension("bar.com"); + @RegisterExtension + static final SelfSignedCertificateExtension clientFooCert = new SelfSignedCertificateExtension("foo.com"); + @RegisterExtension + static final SelfSignedCertificateExtension clientSubFooCert = + new SelfSignedCertificateExtension("sub.foo.com"); + @RegisterExtension + static final SelfSignedCertificateExtension clientBarCert = new SelfSignedCertificateExtension("bar.com"); + + @RegisterExtension + static final ServerExtension server0 = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair(server0DefaultCert.tlsKeyPair()) + .keyPair("foo.com", server0FooCert.tlsKeyPair()) + .keyPair("*.foo.com", server0SubFooCert.tlsKeyPair()) + .trustedCertificates(clientFooCert.certificate(), clientSubFooCert.certificate()) + .build(); + + final ServerTlsConfig tlsConfig = ServerTlsConfig.builder() + .clientAuth(ClientAuth.REQUIRE) + .build(); + sb.https(0) + .tlsProvider(tlsProvider, tlsConfig) + .service("/", (ctx, req) -> { + final String commonName = CertificateUtil.getCommonName(ctx.sslSession()); + return HttpResponse.of("default:" + commonName); + }) + .virtualHost("foo.com") + .service("/", (ctx, req) -> { + final String commonName = CertificateUtil.getCommonName(ctx.sslSession()); + return HttpResponse.of("foo:" + commonName); + }) + .and() + .virtualHost("sub.foo.com") + .service("/", (ctx, req) -> { + final String commonName = CertificateUtil.getCommonName(ctx.sslSession()); + return HttpResponse.of("sub.foo:" + commonName); + }); + } + }; + + @RegisterExtension + static final ServerExtension server1 = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair(server1DefaultCert.tlsKeyPair()) + .keyPair("bar.com", server1BarCert.tlsKeyPair()) + .trustedCertificates(clientFooCert.certificate()) + .build(); + final ServerTlsConfig tlsConfig = ServerTlsConfig.builder() + .clientAuth(ClientAuth.REQUIRE) + .build(); + + sb.https(0) + .tlsProvider(tlsProvider, tlsConfig) + .service("/", (ctx, req) -> { + final String commonName = CertificateUtil.getCommonName(ctx.sslSession()); + return HttpResponse.of("default:" + commonName); + }) + .virtualHost("bar.com") + .service("/", (ctx, req) -> { + final String commonName = CertificateUtil.getCommonName(ctx.sslSession()); + return HttpResponse.of("virtual:" + commonName); + }); + } + }; + + @RegisterExtension + static final ServerExtension serverNoMtls = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair(server0DefaultCert.tlsKeyPair()) + .keyPair("bar.com", server1BarCert.tlsKeyPair()) + .build(); + + sb.https(0) + .tlsProvider(tlsProvider) + .service("/", (ctx, req) -> { + final String commonName = CertificateUtil.getCommonName(ctx.sslSession()); + return HttpResponse.of("default:" + commonName); + }) + .virtualHost("bar.com") + .service("/", (ctx, req) -> { + final String commonName = CertificateUtil.getCommonName(ctx.sslSession()); + return HttpResponse.of("virtual:" + commonName); + }); + } + }; + + @Test + void testExactMatch() { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair("*.foo.com", clientFooCert.tlsKeyPair()) + .keyPair("bar.com", clientBarCert.tlsKeyPair()) + .keyPair(TlsKeyPair.of(clientFooCert.privateKey(), + clientFooCert.certificate())) + .trustedCertificates("foo.com", server0FooCert.certificate()) + .trustedCertificates("bar.com", server1BarCert.certificate()) + .trustedCertificates("sub.foo.com", server0SubFooCert.certificate()) + .trustedCertificates(server0DefaultCert.certificate()) + .build(); + + final MeterRegistry meterRegistry = new SimpleMeterRegistry(); + try (ClientFactory factory = ClientFactory.builder() + .tlsProvider(tlsProvider) + .meterRegistry(meterRegistry) + .addressResolverGroupFactory( + unused -> MockAddressResolverGroup.localhost()) + .build()) { + // clientFooCert should be chosen by TlsProvider. + BlockingWebClient client = WebClient.builder("https://foo.com:" + server0.httpsPort()) + .factory(factory) + .build() + .blocking(); + assertThat(client.get("/").contentUtf8()).isEqualTo("foo:foo.com"); + client = WebClient.builder("https://sub.foo.com:" + server0.httpsPort()) + .factory(factory) + .build() + .blocking(); + assertThat(client.get("/").contentUtf8()).isEqualTo("sub.foo:sub.foo.com"); + client = WebClient.builder("https://127.0.0.1:" + server0.httpsPort()) + .factory(factory) + .build() + .blocking(); + assertThat(client.get("/").contentUtf8()).isEqualTo("default:localhost"); + + await().untilAsserted(() -> { + final Map metrics = MoreMeters.measureAll(meterRegistry); + // Make sure that the metrics for the certificates generated from TlsProvider are exported. + assertThat(metrics.get("armeria.client.tls.certificate.validity#value{common.name=foo.com}")) + .isEqualTo(1.0); + assertThat( + metrics.get("armeria.client.tls.certificate.validity#value{common.name=sub.foo.com}")) + .isEqualTo(1.0); + }); + } + + await().untilAsserted(() -> { + final Map metrics = MoreMeters.measureAll(meterRegistry); + // The metrics for the certificates should be closed when the associated connections are closed. + assertThat(metrics.get("armeria.client.tls.certificate.validity#value{common.name=foo.com}")) + .isNull(); + assertThat(metrics.get("armeria.client.tls.certificate.validity#value{common.name=sub.foo.com}")) + .isNull(); + }); + } + + @Test + void testWildcardMatch() { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair("foo.com", clientFooCert.tlsKeyPair()) + .keyPair("*.foo.com", clientFooCert.tlsKeyPair()) + .trustedCertificates(server0FooCert.certificate(), + server0SubFooCert.certificate()) + .build(); + + try ( + ClientFactory factory = ClientFactory.builder() + .tlsProvider(tlsProvider) + .addressResolverGroupFactory( + unused -> MockAddressResolverGroup.localhost()) + .build()) { + // clientFooCert should be chosen by TlsProvider. + BlockingWebClient client = WebClient.builder("https://foo.com:" + server0.httpsPort()) + .factory(factory) + .build() + .blocking(); + assertThat(client.get("/").contentUtf8()).isEqualTo("foo:foo.com"); + client = WebClient.builder("https://sub.foo.com:" + server0.httpsPort()) + .factory(factory) + .build() + .blocking(); + assertThat(client.get("/").contentUtf8()).isEqualTo("sub.foo:sub.foo.com"); + } + } + + @Test + void testNoMtls() { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair("foo.com", clientFooCert.tlsKeyPair()) + .trustedCertificates(server0DefaultCert.certificate(), + server1BarCert.certificate()) + .build(); + + try (ClientFactory factory = ClientFactory.builder() + .tlsProvider(tlsProvider) + .addressResolverGroupFactory( + unused -> MockAddressResolverGroup.localhost()) + .build()) { + // clientFooCert should be chosen by TlsProvider. + BlockingWebClient client = WebClient.builder("https://bar.com:" + serverNoMtls.httpsPort()) + .factory(factory) + .build() + .blocking(); + assertThat(client.get("/").contentUtf8()).isEqualTo("virtual:bar.com"); + + client = WebClient.builder("https://127.0.0.1:" + serverNoMtls.httpsPort()) + .factory(factory) + .build() + .blocking(); + assertThat(client.get("/").contentUtf8()).isEqualTo("default:localhost"); + } + } + + @Test + void disallowTlsProviderWhenTlsSettingsIsSet() { + final TlsProvider tlsProvider = + TlsProvider.of(TlsKeyPair.ofSelfSigned()); + + assertThatThrownBy(() -> { + ClientFactory.builder() + .tlsProvider(tlsProvider) + .tls(TlsKeyPair.ofSelfSigned()); + }).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Cannot configure TLS settings because a TlsProvider has been set."); + + assertThatThrownBy(() -> { + ClientFactory.builder() + .tlsProvider(tlsProvider) + .tlsCustomizer(b -> {}); + }).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Cannot configure TLS settings because a TlsProvider has been set."); + + assertThatThrownBy(() -> { + ClientFactory.builder() + .tlsProvider(tlsProvider) + .tlsNoVerify(); + }).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Cannot configure TLS settings because a TlsProvider has been set."); + + assertThatThrownBy(() -> { + ClientFactory.builder() + .tlsProvider(tlsProvider) + .tlsNoVerifyHosts("example.com"); + }).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Cannot configure TLS settings because a TlsProvider has been set."); + + assertThatThrownBy(() -> { + ClientFactory.builder() + .tls(TlsKeyPair.ofSelfSigned()) + .tlsProvider(tlsProvider); + }).isInstanceOf(IllegalStateException.class) + .hasMessageContaining( + "Cannot configure the TlsProvider because static TLS settings have been set already."); + } +} diff --git a/core/src/test/java/com/linecorp/armeria/client/IgnoreHostsTrustManagerTest.java b/core/src/test/java/com/linecorp/armeria/client/IgnoreHostsTrustManagerTest.java index 6966d8a580a..15f97dba438 100644 --- a/core/src/test/java/com/linecorp/armeria/client/IgnoreHostsTrustManagerTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/IgnoreHostsTrustManagerTest.java @@ -37,6 +37,7 @@ import com.google.common.collect.ImmutableSet; import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.internal.common.IgnoreHostsTrustManager; import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.testing.junit5.server.ServerExtension; diff --git a/core/src/test/java/com/linecorp/armeria/client/TlsProviderCacheTest.java b/core/src/test/java/com/linecorp/armeria/client/TlsProviderCacheTest.java new file mode 100644 index 00000000000..5ee0e319ce5 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/TlsProviderCacheTest.java @@ -0,0 +1,179 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.google.common.collect.ImmutableList; +import com.spotify.futures.CompletableFutures; + +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.HttpHeaderNames; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.TlsProvider; +import com.linecorp.armeria.common.logging.RequestLogProperty; +import com.linecorp.armeria.internal.common.SslContextFactory; +import com.linecorp.armeria.internal.testing.MockAddressResolverGroup; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServerTlsConfig; +import com.linecorp.armeria.testing.junit5.server.SelfSignedCertificateExtension; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import io.netty.channel.Channel; +import io.netty.handler.ssl.ClientAuth; + +class TlsProviderCacheTest { + + @Order(0) + @RegisterExtension + static final SelfSignedCertificateExtension clientFooCert = new SelfSignedCertificateExtension(); + + @Order(0) + @RegisterExtension + static final SelfSignedCertificateExtension clientBarCert = new SelfSignedCertificateExtension(); + + @Order(0) + @RegisterExtension + static final SelfSignedCertificateExtension serverFooCert = new SelfSignedCertificateExtension("foo.com"); + + @Order(0) + @RegisterExtension + static final SelfSignedCertificateExtension serverBarCert = new SelfSignedCertificateExtension("bar.com"); + + static CompletableFuture startFuture; + + @Order(1) + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair(serverFooCert.tlsKeyPair()) + .keyPair("bar.com", serverBarCert.tlsKeyPair()) + .trustedCertificates(clientFooCert.certificate(), + clientBarCert.certificate()) + .build(); + final ServerTlsConfig tlsConfig = ServerTlsConfig.builder() + .clientAuth(ClientAuth.REQUIRE) + .build(); + sb.tlsProvider(tlsProvider, tlsConfig); + + sb.virtualHost("bar.com") + .service("/", (ctx, req) -> { + final CompletableFuture future = + startFuture.thenApply(unused -> HttpResponse.of("Hello, Bar!")); + return HttpResponse.of(future); + }); + + sb.service("/", (ctx, req) -> { + final CompletableFuture future = + startFuture.thenApply(unused -> HttpResponse.of("Hello!")); + return HttpResponse.of(future); + }); + } + }; + + @BeforeEach + void setUp() { + startFuture = new CompletableFuture<>(); + } + + @Test + void shouldCacheSslContext() { + // This test could be broken if multiple tests are running in parallel. + final CountingConnectionPoolListener poolListener = new CountingConnectionPoolListener(); + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair("foo.com", clientFooCert.tlsKeyPair()) + .keyPair("bar.com", clientBarCert.tlsKeyPair()) + .trustedCertificates(serverFooCert.certificate(), serverBarCert.certificate()) + .build(); + + final List channels = new ArrayList<>(); + final List> responses = new ArrayList<>(); + try ( + ClientFactory factory = ClientFactory + .builder() + .addressResolverGroupFactory(eventLoopGroup -> MockAddressResolverGroup.localhost()) + .tlsProvider(tlsProvider) + .connectionPoolListener(poolListener) + .build()) { + for (String host : ImmutableList.of("foo.com", "bar.com")) { + final WebClient client = + // Use HTTP/1 to create multiple connections. + WebClient.builder("h1://" + host + ':' + server.httpsPort()) + .factory(factory) + .build(); + + for (int i = 0; i < 3; i++) { + try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) { + final CompletableFuture future = + client.prepare() + .get("/") + .header(HttpHeaderNames.CONNECTION, "close") + .execute() + .aggregate(); + responses.add(future); + channels.add(captor.get().log() + .whenAvailable(RequestLogProperty.REQUEST_HEADERS).join() + .channel()); + } + } + } + + await().untilAsserted(() -> { + assertThat(poolListener.opened()).isEqualTo(6); + }); + + final HttpClientFactory clientFactory = (HttpClientFactory) factory.unwrap(); + final SslContextFactory sslContextFactory = clientFactory.sslContextFactory(); + assertThat(sslContextFactory).isNotNull(); + // Make sure the SslContext is reused + assertThat(sslContextFactory.numCachedContexts()).isEqualTo(2); + + startFuture.complete(null); + final List responses0 = CompletableFutures.allAsList(responses).join(); + for (int i = 0; i < responses0.size(); i++) { + final AggregatedHttpResponse response = responses0.get(i); + assertThat(response.status()).isEqualTo(HttpStatus.OK); + if (i < 3) { + assertThat(response.contentUtf8()).isEqualTo("Hello!"); + } else { + assertThat(response.contentUtf8()).isEqualTo("Hello, Bar!"); + } + } + + await().untilAsserted(() -> { + assertThat(poolListener.closed()).isEqualTo(6); + }); + // Make sure a cached SslContext is released when all referenced channels are closed. + assertThat(sslContextFactory.numCachedContexts()).isEqualTo(0); + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/client/TlsProviderMTlsTest.java b/core/src/test/java/com/linecorp/armeria/client/TlsProviderMTlsTest.java new file mode 100644 index 00000000000..b3fda22fc8d --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/TlsProviderMTlsTest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.linecorp.armeria.common.AggregatedHttpResponse; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.TlsProvider; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServerTlsConfig; +import com.linecorp.armeria.testing.junit5.server.SelfSignedCertificateExtension; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import io.netty.handler.ssl.ClientAuth; + +class TlsProviderMTlsTest { + @Order(0) + @RegisterExtension + static SelfSignedCertificateExtension sscServer = new SelfSignedCertificateExtension(); + @Order(0) + @RegisterExtension + static SelfSignedCertificateExtension sscClient = new SelfSignedCertificateExtension(); + + @Order(1) + @RegisterExtension + static ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final TlsProvider tlsProvider = TlsProvider.builder() + .keyPair(sscServer.tlsKeyPair()) + .trustedCertificates(sscClient.certificate()) + .build(); + final ServerTlsConfig tlsConfig = ServerTlsConfig.builder() + .clientAuth(ClientAuth.REQUIRE) + .build(); + sb.tlsProvider(tlsProvider, tlsConfig); + + sb.service("/", (ctx, req) -> { + return HttpResponse.of(HttpStatus.OK); + }); + } + }; + + @Test + void testMTls() { + final TlsProvider tlsProvider = TlsProvider + .builder() + .keyPair(sscClient.tlsKeyPair()) + .trustedCertificates(sscServer.certificate()) + .build(); + try (ClientFactory factory = ClientFactory + .builder() + .tlsProvider(tlsProvider) + .connectTimeoutMillis(Long.MAX_VALUE) + .build()) { + final BlockingWebClient client = WebClient.builder(server.httpsUri()) + .factory(factory) + .build() + .blocking(); + final AggregatedHttpResponse res = client.get("/"); + assertThat(res.status()).isEqualTo(HttpStatus.OK); + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/client/TlsProviderTrustedCertificatesTest.java b/core/src/test/java/com/linecorp/armeria/client/TlsProviderTrustedCertificatesTest.java new file mode 100644 index 00000000000..50e999cd1d0 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/client/TlsProviderTrustedCertificatesTest.java @@ -0,0 +1,233 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.client; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.security.cert.X509Certificate; +import java.util.stream.Stream; + +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import com.google.common.collect.ImmutableList; + +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.HttpStatus; +import com.linecorp.armeria.common.TlsKeyPair; +import com.linecorp.armeria.common.TlsProvider; +import com.linecorp.armeria.internal.testing.MockAddressResolverGroup; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServerTlsConfig; +import com.linecorp.armeria.testing.junit5.server.SelfSignedCertificateExtension; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import io.netty.handler.ssl.ClientAuth; + +class TlsProviderTrustedCertificatesTest { + + @Order(0) + @RegisterExtension + static SelfSignedCertificateExtension serverCertFoo = new SelfSignedCertificateExtension("foo.com"); + + @Order(0) + @RegisterExtension + static SelfSignedCertificateExtension serverCertBar = new SelfSignedCertificateExtension("bar.com"); + + @Order(0) + + @RegisterExtension + static SelfSignedCertificateExtension serverCertDefault = new SelfSignedCertificateExtension(); + + @Order(0) + @RegisterExtension + static SelfSignedCertificateExtension clientCertFoo = new SelfSignedCertificateExtension(); + + @Order(0) + @RegisterExtension + static SelfSignedCertificateExtension clientCertBar = new SelfSignedCertificateExtension(); + + @Order(0) + @RegisterExtension + static SelfSignedCertificateExtension clientCertDefault = new SelfSignedCertificateExtension(); + + @Order(1) + @RegisterExtension + static ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair("foo.com", serverCertFoo.tlsKeyPair()) + .keyPair("bar.com", serverCertBar.tlsKeyPair()) + .keyPair(serverCertDefault.tlsKeyPair()) + .trustedCertificates("foo.com", clientCertFoo.certificate()) + .trustedCertificates("bar.com", clientCertBar.certificate()) + .trustedCertificates(clientCertDefault.certificate()) + .build(); + final ServerTlsConfig tlsConfig = ServerTlsConfig.builder() + .clientAuth(ClientAuth.REQUIRE) + .build(); + sb.https(0); + sb.tlsProvider(tlsProvider, tlsConfig); + sb.service("/", (ctx, req) -> HttpResponse.of(HttpStatus.OK)); + } + }; + + @RegisterExtension + static ServerExtension fooServer = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final TlsProvider tlsProvider = TlsProvider.builder() + .keyPair("foo.com", serverCertFoo.tlsKeyPair()) + .trustedCertificates("foo.com", + clientCertFoo.certificate()) + .build(); + final ServerTlsConfig tlsConfig = ServerTlsConfig.builder() + .clientAuth(ClientAuth.REQUIRE) + .build(); + sb.https(0); + sb.tlsProvider(tlsProvider, tlsConfig); + sb.service("/", (ctx, req) -> HttpResponse.of(HttpStatus.OK)); + } + }; + + @RegisterExtension + static ServerExtension barServer = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair("bar.com", serverCertBar.tlsKeyPair()) + .trustedCertificates("bar.com", clientCertBar.certificate()) + .build(); + final ServerTlsConfig tlsConfig = ServerTlsConfig.builder() + .clientAuth(ClientAuth.REQUIRE) + .build(); + sb.https(0); + sb.tlsProvider(tlsProvider, tlsConfig); + sb.service("/", (ctx, req) -> HttpResponse.of(HttpStatus.OK)); + } + }; + + @Test + void complexUsage() { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair("foo.com", clientCertFoo.tlsKeyPair()) + .keyPair("bar.com", clientCertBar.tlsKeyPair()) + .keyPair(clientCertDefault.tlsKeyPair()) + .trustedCertificates(serverCertDefault.certificate()) + .trustedCertificates("foo.com", serverCertFoo.certificate()) + .trustedCertificates("bar.com", serverCertBar.certificate()) + .build(); + try (ClientFactory factory = + ClientFactory.builder() + .addressResolverGroupFactory( + unused -> MockAddressResolverGroup.localhost()) + .tlsProvider(tlsProvider) + .build()) { + for (String hostname : ImmutableList.of("foo.com", "bar.com", "127.0.0.1")) { + final BlockingWebClient client = + WebClient.builder("https://" + hostname + ':' + server.httpsPort()) + .factory(factory) + .build() + .blocking(); + assertThat(client.get("/").status()).isEqualTo(HttpStatus.OK); + } + } + } + + @MethodSource("simpleParameters") + @ParameterizedTest + void simpleUsage(String hostname, int port, TlsKeyPair keyPair, X509Certificate trustedCertificate) { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair(hostname, keyPair) + .trustedCertificates(hostname, trustedCertificate) + .build(); + try (ClientFactory factory = + ClientFactory.builder() + .addressResolverGroupFactory( + unused -> MockAddressResolverGroup.localhost()) + .tlsProvider(tlsProvider) + .build()) { + final BlockingWebClient client = + WebClient.builder("https://" + hostname + ':' + port) + .factory(factory) + .build() + .blocking(); + assertThat(client.get("/").status()).isEqualTo(HttpStatus.OK); + } + } + + @Test + void defaultTrustedCertificates() { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair("foo.com", clientCertFoo.tlsKeyPair()) + .trustedCertificates(serverCertFoo.certificate()) + .build(); + try (ClientFactory factory = + ClientFactory.builder() + .addressResolverGroupFactory( + unused -> MockAddressResolverGroup.localhost()) + .tlsProvider(tlsProvider) + .build()) { + final BlockingWebClient client = + WebClient.builder("https://foo.com:" + fooServer.httpsPort()) + .factory(factory) + .build() + .blocking(); + assertThat(client.get("/").status()).isEqualTo(HttpStatus.OK); + } + } + + static Stream simpleParameters() { + return Stream.of( + Arguments.of("foo.com", fooServer.httpsPort(), + clientCertFoo.tlsKeyPair(), serverCertFoo.certificate()), + Arguments.of("bar.com", barServer.httpsPort(), + clientCertBar.tlsKeyPair(), serverCertBar.certificate())); + } + + @Test + void simpleUsage_bar() { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair("bar.com", clientCertBar.tlsKeyPair()) + .trustedCertificates("bar.com", serverCertBar.certificate()) + .build(); + try (ClientFactory factory = + ClientFactory.builder() + .addressResolverGroupFactory( + unused -> MockAddressResolverGroup.localhost()) + .tlsProvider(tlsProvider) + .build()) { + final BlockingWebClient client = + WebClient.builder("https://bar.com:" + barServer.httpsPort()) + .factory(factory) + .build() + .blocking(); + assertThat(client.get("/").status()).isEqualTo(HttpStatus.OK); + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/client/proxy/ProxyClientIntegrationTest.java b/core/src/test/java/com/linecorp/armeria/client/proxy/ProxyClientIntegrationTest.java index 932c15197a1..f1fee902870 100644 --- a/core/src/test/java/com/linecorp/armeria/client/proxy/ProxyClientIntegrationTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/proxy/ProxyClientIntegrationTest.java @@ -63,6 +63,7 @@ import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.HttpStatus; import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.TlsProvider; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.internal.testing.BlockingUtils; import com.linecorp.armeria.internal.testing.NettyServerExtension; @@ -92,6 +93,7 @@ import io.netty.handler.codec.socksx.v4.Socks4CommandStatus; import io.netty.handler.logging.LoggingHandler; import io.netty.handler.proxy.ProxyConnectException; +import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.traffic.ChannelTrafficShapingHandler; @@ -111,7 +113,15 @@ class ProxyClientIntegrationTest { @RegisterExtension @Order(0) - static final SelfSignedCertificateExtension ssc = new SelfSignedCertificateExtension(); + static final SelfSignedCertificateExtension proxySsc = new SelfSignedCertificateExtension(); + + @RegisterExtension + @Order(0) + static final SelfSignedCertificateExtension backendSsc = new SelfSignedCertificateExtension(); + + @RegisterExtension + @Order(0) + static final SelfSignedCertificateExtension clientSsc = new SelfSignedCertificateExtension(); @RegisterExtension @Order(1) @@ -120,7 +130,7 @@ class ProxyClientIntegrationTest { protected void configure(ServerBuilder sb) throws Exception { sb.port(0, SessionProtocol.HTTP); sb.port(0, SessionProtocol.HTTPS); - sb.tlsSelfSigned(); + sb.tls(backendSsc.tlsKeyPair()); sb.service(PROXY_PATH, (ctx, req) -> HttpResponse.of(SUCCESS_RESPONSE)); } }; @@ -172,7 +182,27 @@ protected void configure(Channel ch) throws Exception { protected void configure(Channel ch) throws Exception { assertThat(sslContext).isNotNull(); final SslContext sslContext = SslContextBuilder - .forServer(ssc.privateKey(), ssc.certificate()).build(); + .forServer(proxySsc.privateKey(), proxySsc.certificate()).build(); + ch.pipeline().addLast(sslContext.newHandler(ch.alloc())); + ch.pipeline().addLast(new HttpServerCodec()); + ch.pipeline().addLast(new HttpObjectAggregator(1024)); + ch.pipeline().addLast(new HttpProxyServerHandler()); + ch.pipeline().addLast(new SleepHandler()); + ch.pipeline().addLast(new IntermediaryProxyServerHandler("http", PROXY_CALLBACK)); + } + }; + + @RegisterExtension + @Order(4) + static NettyServerExtension mTlsHttpsProxyServer = new NettyServerExtension() { + @Override + protected void configure(Channel ch) throws Exception { + assertThat(sslContext).isNotNull(); + final SslContext sslContext = SslContextBuilder + .forServer(proxySsc.privateKey(), proxySsc.certificate()) + .clientAuth(ClientAuth.REQUIRE) + .trustManager(clientSsc.certificate()) + .build(); ch.pipeline().addLast(sslContext.newHandler(ch.alloc())); ch.pipeline().addLast(new HttpServerCodec()); ch.pipeline().addLast(new HttpObjectAggregator(1024)); @@ -205,7 +235,7 @@ protected void configure(Channel ch) throws Exception { @BeforeAll static void beforeAll() throws Exception { sslContext = SslContextBuilder - .forServer(ssc.privateKey(), ssc.certificate()).build(); + .forServer(proxySsc.privateKey(), proxySsc.certificate()).build(); } @BeforeEach @@ -507,6 +537,33 @@ void testHttpsProxy(SessionProtocol protocol, Endpoint endpoint) throws Exceptio clientFactory.closeAsync(); } + @ParameterizedTest + @MethodSource("sessionAndEndpointProvider") + void testMTlsHttpsProxyWithTlsProvider(SessionProtocol protocol, Endpoint endpoint) throws Exception { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair(clientSsc.tlsKeyPair()) + .trustedCertificates(proxySsc.certificate(), backendSsc.certificate()) + .build(); + + final ClientFactory clientFactory = + ClientFactory.builder() + .tlsProvider(tlsProvider) + .proxyConfig( + ProxyConfig.connect(mTlsHttpsProxyServer.address(), true)).build(); + final WebClient webClient = WebClient.builder(protocol, endpoint) + .factory(clientFactory) + .decorator(LoggingClient.newDecorator()) + .build(); + final CompletableFuture responseFuture = + webClient.get(PROXY_PATH).aggregate(); + final AggregatedHttpResponse response = responseFuture.join(); + assertThat(response.status()).isEqualTo(HttpStatus.OK); + assertThat(response.contentUtf8()).isEqualTo(SUCCESS_RESPONSE); + assertThat(numSuccessfulProxyRequests).isEqualTo(1); + clientFactory.closeAsync(); + } + @Test void testProxyWithH2C() throws Exception { final int numRequests = 5; diff --git a/core/src/test/java/com/linecorp/armeria/internal/common/util/KeyStoreUtilTest.java b/core/src/test/java/com/linecorp/armeria/internal/common/util/KeyStoreUtilTest.java index 3202a4f649a..0f4d44f9d81 100644 --- a/core/src/test/java/com/linecorp/armeria/internal/common/util/KeyStoreUtilTest.java +++ b/core/src/test/java/com/linecorp/armeria/internal/common/util/KeyStoreUtilTest.java @@ -25,8 +25,8 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import com.linecorp.armeria.common.TlsKeyPair; import com.linecorp.armeria.common.annotation.Nullable; -import com.linecorp.armeria.internal.common.util.KeyStoreUtil.KeyPair; class KeyStoreUtilTest { // The key store files used in this test case were generated with the following commands: @@ -73,10 +73,10 @@ class KeyStoreUtilTest { void shouldLoadKeyStoreWithOneKeyPair(String filename, @Nullable String keyStorePassword, @Nullable String keyPassword) throws Exception { - final KeyPair keyPair = KeyStoreUtil.load(getFile(filename), - underscoreToNull(keyStorePassword), - underscoreToNull(keyPassword), - null /* no alias */); + final TlsKeyPair keyPair = KeyStoreUtil.load(getFile(filename), + underscoreToNull(keyStorePassword), + underscoreToNull(keyPassword), + null /* no alias */); assertThat(keyPair.certificateChain()).hasSize(1).allSatisfy(cert -> { assertThat(cert.getSubjectX500Principal().getName()).isEqualTo("CN=foo.com"); }); @@ -85,7 +85,7 @@ void shouldLoadKeyStoreWithOneKeyPair(String filename, @ParameterizedTest @CsvSource({"first, foo.com", "second, bar.com"}) void shouldLoadKeyStoreWithTwoKeyPairsIfAliasIsGiven(String alias, String expectedCN) throws Exception { - final KeyPair keyPair = KeyStoreUtil.load(getFile("keystore-two-keys.p12"), + final TlsKeyPair keyPair = KeyStoreUtil.load(getFile("keystore-two-keys.p12"), "my-second-password", null, alias); diff --git a/core/src/test/java/com/linecorp/armeria/server/ServerTlsCertificateMetricsTest.java b/core/src/test/java/com/linecorp/armeria/server/ServerTlsCertificateMetricsTest.java index d78ab0e4002..7a2371976af 100644 --- a/core/src/test/java/com/linecorp/armeria/server/ServerTlsCertificateMetricsTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/ServerTlsCertificateMetricsTest.java @@ -181,7 +181,7 @@ void tlsMetricGivenCertificateChainExpired() { .build(); assertThatGauge(meterRegistry, CERT_VALIDITY_GAUGE_NAME, "localhost").isZero(); - assertThatGauge(meterRegistry, CERT_VALIDITY_DAYS_GAUGE_NAME, "localhost").isEqualTo(-1); + assertThatGauge(meterRegistry, CERT_VALIDITY_DAYS_GAUGE_NAME, "localhost").isLessThanOrEqualTo(-1); assertThatGauge(meterRegistry, CERT_VALIDITY_GAUGE_NAME, "test.root.armeria").isOne(); assertThatGauge(meterRegistry, CERT_VALIDITY_DAYS_GAUGE_NAME, "test.root.armeria").isPositive(); } diff --git a/core/src/test/java/com/linecorp/armeria/server/ServerTlsProviderTest.java b/core/src/test/java/com/linecorp/armeria/server/ServerTlsProviderTest.java new file mode 100644 index 00000000000..e8a5ec1d308 --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/server/ServerTlsProviderTest.java @@ -0,0 +1,191 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import com.google.common.collect.ImmutableList; + +import com.linecorp.armeria.client.BlockingWebClient; +import com.linecorp.armeria.client.ClientFactory; +import com.linecorp.armeria.client.WebClient; +import com.linecorp.armeria.common.HttpResponse; +import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.TlsKeyPair; +import com.linecorp.armeria.common.TlsProvider; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.internal.common.util.CertificateUtil; +import com.linecorp.armeria.internal.testing.MockAddressResolverGroup; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +class ServerTlsProviderTest { + + @RegisterExtension + static final ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + final TlsProvider tlsProvider = + TlsProvider.builder() + .keyPair("*", TlsKeyPair.ofSelfSigned("default")) + .keyPair("example.com", TlsKeyPair.ofSelfSigned("example.com")) + .keyPair("api.example.com", TlsKeyPair.ofSelfSigned("api.example.com")) + .keyPair("*.example.com", TlsKeyPair.ofSelfSigned("*.example.com")) + .build(); + + sb.https(0) + .tlsProvider(tlsProvider) + .service("/", (ctx, req) -> { + final String commonName = CertificateUtil.getCommonName(ctx.sslSession()); + return HttpResponse.of("default:" + commonName); + }) + .virtualHost("api.example.com") + .service("/", (ctx, req) -> { + final String commonName = CertificateUtil.getCommonName(ctx.sslSession()); + return HttpResponse.of("nested:" + commonName); + }) + .and() + .virtualHost("*.example.com") + .service("/", (ctx, req) -> { + final String commonName = CertificateUtil.getCommonName(ctx.sslSession()); + return HttpResponse.of("wild:" + commonName); + }); + } + }; + + @RegisterExtension + static final ServerExtension certRenewableServer = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + sb.tlsProvider(settableTlsProvider); + sb.service("/", (ctx, req) -> { + final String commonName = CertificateUtil.getCommonName(ctx.sslSession()); + return HttpResponse.of(commonName); + }); + } + }; + + private static final SettableTlsProvider settableTlsProvider = new SettableTlsProvider(); + + @BeforeEach + void setUp() { + settableTlsProvider.set(null); + } + + @Test + void testDefault() { + final BlockingWebClient client = WebClient.builder(server.uri(SessionProtocol.HTTPS)) + .factory(ClientFactory.insecure()) + .build() + .blocking(); + assertThat(client.get("/").contentUtf8()).isEqualTo("default:default"); + } + + @CsvSource({ + "example.com, wild:example.com", + "api.example.com, nested:api.example.com", + "foo.example.com, wild:*.example.com", + "example.org, default:default", + "api.example.org, default:default", + "foo.example.org, default:default", + "bar.example.org, default:default", + "baz.bar.example.org, default:default" + }) + @ParameterizedTest + void wildcardMatch(String host, String expected) { + try (ClientFactory factory = ClientFactory.builder() + .tlsNoVerify() + .addressResolverGroupFactory( + unused -> MockAddressResolverGroup.localhost()) + .build()) { + assertThat(WebClient.builder("https://" + host + ':' + server.httpsPort()) + .factory(factory) + .build() + .blocking() + .get("/") + .contentUtf8()).isEqualTo(expected); + } + } + + @Test + void shouldUseNewTlsKeyPair() { + for (String host : ImmutableList.of("foo.com", "bar.com")) { + settableTlsProvider.set(TlsKeyPair.ofSelfSigned(host)); + try (ClientFactory factory = ClientFactory.builder() + .tlsNoVerify() + .addressResolverGroupFactory( + unused -> MockAddressResolverGroup.localhost()) + .build()) { + final BlockingWebClient client = WebClient.builder(certRenewableServer.httpsUri()) + .factory(factory) + .build() + .blocking(); + assertThat(client.get("/").contentUtf8()).isEqualTo(host); + } + } + } + + @Test + void disallowTlsProviderWhenTlsSettingsIsSet() { + assertThatThrownBy(() -> { + Server.builder() + .tls(TlsKeyPair.ofSelfSigned()) + .tlsProvider(TlsProvider.of(TlsKeyPair.ofSelfSigned())) + .build(); + }).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Cannot configure TLS settings with a TlsProvider"); + + assertThatThrownBy(() -> { + Server.builder() + .tlsSelfSigned() + .tlsProvider(TlsProvider.of(TlsKeyPair.ofSelfSigned())) + .build(); + }).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Cannot configure TLS settings with a TlsProvider"); + + assertThatThrownBy(() -> { + Server.builder() + .tlsProvider(TlsProvider.of(TlsKeyPair.ofSelfSigned())) + .virtualHost("example.com") + .tls(TlsKeyPair.ofSelfSigned()) + .and() + .build(); + }).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Cannot configure TLS settings with a TlsProvider"); + } + + private static class SettableTlsProvider implements TlsProvider { + + @Nullable + private volatile TlsKeyPair keyPair; + + @Override + public TlsKeyPair keyPair(String hostname) { + return keyPair; + } + + public void set(@Nullable TlsKeyPair keyPair) { + this.keyPair = keyPair; + } + } +} diff --git a/core/src/test/java/com/linecorp/armeria/server/TlsProviderMappingTest.java b/core/src/test/java/com/linecorp/armeria/server/TlsProviderMappingTest.java new file mode 100644 index 00000000000..5c5ceacaefb --- /dev/null +++ b/core/src/test/java/com/linecorp/armeria/server/TlsProviderMappingTest.java @@ -0,0 +1,75 @@ +/* + * Copyright 2024 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.server; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.junit.jupiter.api.Test; + +import com.linecorp.armeria.common.Flags; +import com.linecorp.armeria.common.TlsKeyPair; +import com.linecorp.armeria.common.TlsProvider; +import com.linecorp.armeria.common.util.TlsEngineType; + +class TlsProviderMappingTest { + + @Test + void testNoDefault() { + final TlsProvider tlsProvider = TlsProvider.builder() + .keyPair("example.com", TlsKeyPair.ofSelfSigned()) + .keyPair("api.example.com", TlsKeyPair.ofSelfSigned()) + .keyPair("foo.com", TlsKeyPair.ofSelfSigned()) + .keyPair("*.foo.com", TlsKeyPair.ofSelfSigned()) + .build(); + final TlsProviderMapping mapping = new TlsProviderMapping(tlsProvider, + TlsEngineType.OPENSSL, + ServerTlsConfig.builder().build(), + Flags.meterRegistry()); + assertThat(mapping.map("example.com")).isNotNull(); + assertThat(mapping.map("api.example.com")).isNotNull(); + assertThatThrownBy(() -> mapping.map("web.example.com")) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("No TLS key pair found for web.example.com"); + assertThat(mapping.map("foo.com")).isNotNull(); + assertThat(mapping.map("bar.foo.com")).isNotNull(); + assertThatThrownBy(() -> mapping.map("baz.bar.foo.com")) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("No TLS key pair found for baz.bar.foo.com"); + } + + @Test + void testWithDefault() { + final TlsProvider tlsProvider = TlsProvider.builder() + .keyPair(TlsKeyPair.ofSelfSigned()) + .keyPair("example.com", TlsKeyPair.ofSelfSigned()) + .keyPair("api.example.com", TlsKeyPair.ofSelfSigned()) + .keyPair("foo.com", TlsKeyPair.ofSelfSigned()) + .keyPair("*.foo.com", TlsKeyPair.ofSelfSigned()) + .build(); + final TlsProviderMapping mapping = new TlsProviderMapping(tlsProvider, + TlsEngineType.OPENSSL, + ServerTlsConfig.builder().build(), + Flags.meterRegistry()); + assertThat(mapping.map("example.com")).isNotNull(); + assertThat(mapping.map("api.example.com")).isNotNull(); + assertThat(mapping.map("web.example.com")).isNotNull(); + assertThat(mapping.map("foo.com")).isNotNull(); + assertThat(mapping.map("bar.foo.com")).isNotNull(); + assertThat(mapping.map("baz.bar.foo.com")).isNotNull(); + } +} diff --git a/core/src/test/java/com/linecorp/armeria/server/VirtualHostAnnotatedServiceBindingBuilderTest.java b/core/src/test/java/com/linecorp/armeria/server/VirtualHostAnnotatedServiceBindingBuilderTest.java index 5d8e4a8cfdf..8cc6e3f8b1d 100644 --- a/core/src/test/java/com/linecorp/armeria/server/VirtualHostAnnotatedServiceBindingBuilderTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/VirtualHostAnnotatedServiceBindingBuilderTest.java @@ -107,7 +107,7 @@ void testAllConfigsAreSet() { .multipartUploadsLocation(multipartUploadsLocation) .requestIdGenerator(serviceRequestIdGenerator) .build(new TestService()) - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); assertThat(virtualHost.serviceConfigs()).hasSize(2); final ServiceConfig pathBar = virtualHost.serviceConfigs().get(0); diff --git a/core/src/test/java/com/linecorp/armeria/server/VirtualHostBuilderTest.java b/core/src/test/java/com/linecorp/armeria/server/VirtualHostBuilderTest.java index c211ad6a886..055b51b8bbc 100644 --- a/core/src/test/java/com/linecorp/armeria/server/VirtualHostBuilderTest.java +++ b/core/src/test/java/com/linecorp/armeria/server/VirtualHostBuilderTest.java @@ -168,7 +168,7 @@ void virtualHostWithoutPattern() { Server.builder() .virtualHost("foo.com") .defaultHostname("foo.com") - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); assertThat(h.hostnamePattern()).isEqualTo("foo.com"); assertThat(h.defaultHostname()).isEqualTo("foo.com"); } @@ -178,7 +178,7 @@ void virtualHostWithPattern() { final VirtualHost h = Server.builder().virtualHost("*.foo.com") .defaultHostname("bar.foo.com") - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); assertThat(h.hostnamePattern()).isEqualTo("*.foo.com"); assertThat(h.defaultHostname()).isEqualTo("bar.foo.com"); } @@ -189,14 +189,14 @@ void accessLoggerCustomization() { Server.builder().virtualHost("*.foo.com") .defaultHostname("bar.foo.com") .accessLogger(host -> LoggerFactory.getLogger("customize.test")) - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); assertThat(h1.accessLogger().getName()).isEqualTo("customize.test"); final VirtualHost h2 = Server.builder().virtualHost("*.foo.com") .defaultHostname("bar.foo.com") .accessLogger(LoggerFactory.getLogger("com.foo.test")) - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); assertThat(h2.accessLogger().getName()).isEqualTo("com.foo.test"); } @@ -258,13 +258,13 @@ void tlsAllowUnsafeCiphersCustomization(String templateTlsAllowUnsafeCiphers, switch (expectedOutcome) { case "success": virtualHostBuilder.build(serverBuilder.virtualHostTemplate, noopDependencyInjector, - null, ServerErrorHandler.ofDefault()); + null, ServerErrorHandler.ofDefault(), null); break; case "failure": assertThatThrownBy(() -> virtualHostBuilder.build(serverBuilder.virtualHostTemplate, noopDependencyInjector, null, - ServerErrorHandler.ofDefault())) + ServerErrorHandler.ofDefault(), null)) .isInstanceOf(IllegalStateException.class) .hasMessageContaining("TLS with a bad cipher suite"); break; @@ -304,7 +304,7 @@ void virtualHostWithMismatch() { assertThatThrownBy(() -> { Server.builder().virtualHost("foo.com") .defaultHostname("bar.com") - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); }).isInstanceOf(IllegalArgumentException.class); } @@ -313,7 +313,7 @@ void virtualHostWithMismatch2() { assertThatThrownBy(() -> { Server.builder().virtualHost("*.foo.com") .defaultHostname("bar.com") - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); }).isInstanceOf(IllegalArgumentException.class); } @@ -327,7 +327,7 @@ void precedenceOfDuplicateRoute() throws Exception { final VirtualHost virtualHost = new VirtualHostBuilder(Server.builder(), true) .service(routeA, (ctx, req) -> HttpResponse.of(200)) .service(routeB, (ctx, req) -> HttpResponse.of(201)) - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); assertThat(virtualHost.serviceConfigs().size()).isEqualTo(2); final RoutingContext routingContext = new DefaultRoutingContext(virtualHost(), "example.com", RequestHeaders.of(HttpMethod.GET, "/"), @@ -343,11 +343,11 @@ void multipartUploadsLocationCustomization() { final Path multipartUploadsLocation = FileSystems.getDefault().getPath("logs", "access.log"); final VirtualHost h1 = new VirtualHostBuilder(Server.builder(), false) .multipartUploadsLocation(multipartUploadsLocation) - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); assertThat(h1.multipartUploadsLocation()).isEqualTo(multipartUploadsLocation); final VirtualHost h2 = new VirtualHostBuilder(Server.builder(), false) - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); assertThat(h2.multipartUploadsLocation()).isEqualTo(template.multipartUploadsLocation()); } @@ -356,11 +356,11 @@ void defaultLogNameCustomization() { final String defaultLogName = "test"; final VirtualHost h1 = new VirtualHostBuilder(Server.builder(), false) .defaultLogName(defaultLogName) - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); assertThat(h1.defaultLogName()).isEqualTo(defaultLogName); final VirtualHost h2 = new VirtualHostBuilder(Server.builder(), false) - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); assertThat(h2.defaultLogName()).isEqualTo(template.defaultLogName()); } @@ -369,11 +369,11 @@ void successFunctionCustomization() { final SuccessFunction successFunction = (ctx, log) -> false; final VirtualHost h1 = new VirtualHostBuilder(Server.builder(), false) .successFunction(successFunction) - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); assertThat(h1.successFunction()).isEqualTo(successFunction); final VirtualHost h2 = new VirtualHostBuilder(Server.builder(), false) - .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault()); + .build(template, noopDependencyInjector, null, ServerErrorHandler.ofDefault(), null); assertThat(h2.successFunction()).isEqualTo(template.successFunction()); } } diff --git a/it/builders/src/test/java/com/linecorp/armeria/OverriddenBuilderMethodsReturnTypeTest.java b/it/builders/src/test/java/com/linecorp/armeria/OverriddenBuilderMethodsReturnTypeTest.java index 2c2f76a95e4..5307cf68669 100644 --- a/it/builders/src/test/java/com/linecorp/armeria/OverriddenBuilderMethodsReturnTypeTest.java +++ b/it/builders/src/test/java/com/linecorp/armeria/OverriddenBuilderMethodsReturnTypeTest.java @@ -45,56 +45,59 @@ class OverriddenBuilderMethodsReturnTypeTest { @Test void methodChaining() { - final Set excludedClasses = ImmutableSet.of("JsonLogFormatterBuilder", - "TextLogFormatterBuilder", - "PathStreamMessageBuilder", - "InputStreamStreamMessageBuilder", - "ContextPathAnnotatedServiceConfigSetters", - "ContextPathDecoratingBindingBuilder", - "ContextPathServiceBindingBuilder", - "ContextPathServicesBuilder", - "DecoratingServiceBindingBuilder", - "ServerBuilder", - "ServiceBindingBuilder", - "AnnotatedServiceBindingBuilder", - "VirtualHostAnnotatedServiceBindingBuilder", - "VirtualHostBuilder", - "VirtualHostContextPathDecoratingBindingBuilder", - "VirtualHostContextPathServiceBindingBuilder", - "VirtualHostContextPathServicesBuilder", - "VirtualHostDecoratingServiceBindingBuilder", - "VirtualHostServiceBindingBuilder", - "ChainedCorsPolicyBuilder", - "CorsPolicyBuilder", - "ConsulEndpointGroupBuilde", - "AbstractDnsResolverBuilder", - "AbstractRuleBuilder", - "AbstractRuleWithContentBuilder", - "DnsResolverGroupBuilder", - "AbstractCircuitBreakerMappingBuilder", - "CircuitBreakerMappingBuilder", - "CircuitBreakerRuleBuilder", - "CircuitBreakerRuleWithContentBuilder", - "AbstractDynamicEndpointGroupBuilder", - "DynamicEndpointGroupBuilder", - "DynamicEndpointGroupSetters", - "DnsAddressEndpointGroupBuilder", - "DnsEndpointGroupBuilder", - "DnsServiceEndpointGroupBuilder", - "DnsTextEndpointGroupBuilder", - "AbstractHealthCheckedEndpointGroupBuilder", - "HealthCheckedEndpointGroupBuilder", - "RetryRuleBuilder", - "RetryRuleWithContentBuilder", - "AbstractHeadersSanitizerBuilder", - "JsonHeadersSanitizerBuilder", - "TextHeadersSanitizerBuilder", - "EurekaEndpointGroupBuilder", - "KubernetesEndpointGroupBuilder", - "Resilience4jCircuitBreakerMappingBuilder", - "ZooKeeperEndpointGroupBuilder", - "AbstractCuratorFrameworkBuilder", - "ZooKeeperUpdatingListenerBuilder"); + final Set excludedClasses = ImmutableSet.of( + "AbstractCircuitBreakerMappingBuilder", + "AbstractCuratorFrameworkBuilder", + "AbstractDnsResolverBuilder", + "AbstractDynamicEndpointGroupBuilder", + "AbstractHeadersSanitizerBuilder", + "AbstractHealthCheckedEndpointGroupBuilder", + "AbstractRuleBuilder", + "AbstractRuleWithContentBuilder", + "AnnotatedServiceBindingBuilder", + "ChainedCorsPolicyBuilder", + "CircuitBreakerMappingBuilder", + "CircuitBreakerRuleBuilder", + "CircuitBreakerRuleWithContentBuilder", + "ClientTlsConfigBuilder", + "ConsulEndpointGroupBuilder", + "ContextPathAnnotatedServiceConfigSetters", + "ContextPathDecoratingBindingBuilder", + "ContextPathServiceBindingBuilder", + "ContextPathServicesBuilder", + "CorsPolicyBuilder", + "DecoratingServiceBindingBuilder", + "DnsAddressEndpointGroupBuilder", + "DnsEndpointGroupBuilder", + "DnsResolverGroupBuilder", + "DnsServiceEndpointGroupBuilder", + "DnsTextEndpointGroupBuilder", + "DynamicEndpointGroupBuilder", + "DynamicEndpointGroupSetters", + "EurekaEndpointGroupBuilder", + "HealthCheckedEndpointGroupBuilder", + "InputStreamStreamMessageBuilder", + "JsonHeadersSanitizerBuilder", + "JsonLogFormatterBuilder", + "KubernetesEndpointGroupBuilder", + "PathStreamMessageBuilder", + "Resilience4jCircuitBreakerMappingBuilder", + "RetryRuleBuilder", + "RetryRuleWithContentBuilder", + "ServerBuilder", + "ServerTlsConfigBuilder", + "ServiceBindingBuilder", + "TextHeadersSanitizerBuilder", + "TextLogFormatterBuilder", + "VirtualHostAnnotatedServiceBindingBuilder", + "VirtualHostBuilder", + "VirtualHostContextPathDecoratingBindingBuilder", + "VirtualHostContextPathServiceBindingBuilder", + "VirtualHostContextPathServicesBuilder", + "VirtualHostDecoratingServiceBindingBuilder", + "VirtualHostServiceBindingBuilder", + "ZooKeeperEndpointGroupBuilder", + "ZooKeeperUpdatingListenerBuilder"); final String packageName = "com.linecorp.armeria"; findAllClasses(packageName).stream() .map(ReflectionUtils::forName) diff --git a/junit5/src/main/java/com/linecorp/armeria/internal/testing/SelfSignedCertificateRuleDelegate.java b/junit5/src/main/java/com/linecorp/armeria/internal/testing/SelfSignedCertificateRuleDelegate.java index 56ecf114893..9393d208f18 100644 --- a/junit5/src/main/java/com/linecorp/armeria/internal/testing/SelfSignedCertificateRuleDelegate.java +++ b/junit5/src/main/java/com/linecorp/armeria/internal/testing/SelfSignedCertificateRuleDelegate.java @@ -28,6 +28,7 @@ import java.time.temporal.TemporalAccessor; import java.util.Date; +import com.linecorp.armeria.common.TlsKeyPair; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.internal.common.util.SelfSignedCertificate; @@ -50,6 +51,9 @@ public final class SelfSignedCertificateRuleDelegate { @Nullable private SelfSignedCertificate certificate; + @Nullable + private TlsKeyPair tlsKeyPair; + /** * Creates a new instance. */ @@ -205,6 +209,16 @@ public File privateKeyFile() { return ensureCertificate().privateKey(); } + /** + * Returns the {@link TlsKeyPair} of the self-signed certificate. + */ + public TlsKeyPair tlsKeyPair() { + if (tlsKeyPair == null) { + tlsKeyPair = TlsKeyPair.of(privateKey(), certificate()); + } + return tlsKeyPair; + } + private SelfSignedCertificate ensureCertificate() { checkState(certificate != null, "certificate not created"); return certificate; diff --git a/junit5/src/main/java/com/linecorp/armeria/testing/junit5/server/SelfSignedCertificateExtension.java b/junit5/src/main/java/com/linecorp/armeria/testing/junit5/server/SelfSignedCertificateExtension.java index e155fd62e82..d64c3e94f70 100644 --- a/junit5/src/main/java/com/linecorp/armeria/testing/junit5/server/SelfSignedCertificateExtension.java +++ b/junit5/src/main/java/com/linecorp/armeria/testing/junit5/server/SelfSignedCertificateExtension.java @@ -26,6 +26,8 @@ import org.junit.jupiter.api.extension.Extension; import org.junit.jupiter.api.extension.ExtensionContext; +import com.linecorp.armeria.common.TlsKeyPair; +import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.internal.testing.SelfSignedCertificateRuleDelegate; import com.linecorp.armeria.testing.junit5.common.AbstractAllOrEachExtension; @@ -144,4 +146,12 @@ public PrivateKey privateKey() { public File privateKeyFile() { return delegate.privateKeyFile(); } + + /** + * Returns the {@link TlsKeyPair} of the self-signed certificate. + */ + @UnstableApi + public TlsKeyPair tlsKeyPair() { + return delegate.tlsKeyPair(); + } }