Skip to content

Commit

Permalink
Support for mTLS in client
Browse files Browse the repository at this point in the history
Fixes #10
  • Loading branch information
dsyer committed Nov 5, 2024
1 parent 148e1d1 commit 05e6240
Show file tree
Hide file tree
Showing 15 changed files with 257 additions and 248 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledOnOs;
Expand Down Expand Up @@ -122,12 +121,13 @@ void clientChannelWithSsl(@Autowired GrpcChannelFactory channels) {

@Nested
@SpringBootTest(properties = { "spring.grpc.server.port=0", "spring.grpc.server.ssl.client-auth=REQUIRE",
"spring.grpc.server.ssl.secure=false",
"spring.grpc.client.channels.test-channel.address=static://0.0.0.0:${local.grpc.port}",
"spring.grpc.client.channels.test-channel.ssl.bundle=ssltest",
"spring.grpc.client.channels.test-channel.negotiation-type=TLS",
"spring.grpc.client.channels.test-channel.secure=false" })
@ActiveProfiles("ssl")
@DirtiesContext
@Disabled("Requires client certificate")
class ServerWithClientAuth {

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.grpc.client;

import io.grpc.ChannelCredentials;
import io.grpc.InsecureChannelCredentials;

/**
* A provider for obtaining channel credentials for gRPC client.
*/
public interface ChannelCredentialsProvider {

static final ChannelCredentialsProvider INSECURE = path -> InsecureChannelCredentials.create();

ChannelCredentials getChannelCredentials(String path);

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public class DefaultGrpcChannelFactory implements GrpcChannelFactory, Disposable

private final List<GrpcChannelConfigurer> configurers = new ArrayList<>();

private ChannelCredentialsProvider credentials = ChannelCredentialsProvider.INSECURE;

private VirtualTargets targets = VirtualTargets.DEFAULT;

public DefaultGrpcChannelFactory() {
Expand All @@ -51,10 +53,15 @@ public void setVirtualTargets(VirtualTargets targets) {
this.targets = targets;
}

public void setCredentialsProvider(ChannelCredentialsProvider credentials) {
this.credentials = credentials;
}

@Override
public ManagedChannelBuilder<?> createChannel(String authority) {
ManagedChannelBuilder<?> target = builders.computeIfAbsent(authority, path -> {
ManagedChannelBuilder<?> builder = newChannel(targets.getTarget(path));
ManagedChannelBuilder<?> builder = newChannel(targets.getTarget(path),
credentials.getChannelCredentials(path));
for (GrpcChannelConfigurer configurer : configurers) {
configurer.configure(path, builder);
}
Expand All @@ -64,12 +71,8 @@ public ManagedChannelBuilder<?> createChannel(String authority) {

}

protected ChannelCredentials channelCredentials(String path) {
return InsecureChannelCredentials.create();
}

protected ManagedChannelBuilder<?> newChannel(String path) {
return Grpc.newChannelBuilder(path, channelCredentials(path));
protected ManagedChannelBuilder<?> newChannel(String path, ChannelCredentials creds) {
return Grpc.newChannelBuilder(path, creds);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,10 @@

import java.util.List;

import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.NettyChannelBuilder;

public class NettyGrpcChannelFactory extends DefaultGrpcChannelFactory {

public NettyGrpcChannelFactory(List<GrpcChannelConfigurer> configurers) {
super(configurers);
}

protected ManagedChannelBuilder<?> newChannel(String path) {
if (path.startsWith("unix:")) {
return super.newChannel(path);
}
return NettyChannelBuilder.forTarget(path);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,10 @@

import java.util.List;

import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;

public class ShadedNettyGrpcChannelFactory extends DefaultGrpcChannelFactory {

public ShadedNettyGrpcChannelFactory(List<GrpcChannelConfigurer> configurers) {
super(configurers);
}

protected ManagedChannelBuilder<?> newChannel(String path) {
if (path.startsWith("unix:")) {
return super.newChannel(path);
}
return NettyChannelBuilder.forTarget(path);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
|spring.grpc.server.port | `+++9090+++` | Server port to listen on. When the value is 0, a random available port is selected. The default is 9090.
|spring.grpc.server.shutdown-grace-period | `+++30s+++` | Maximum time to wait for the server to gracefully shutdown. When the value is negative, the server waits forever. When the value is 0, the server will force shutdown immediately. The default is 30 seconds.
|spring.grpc.server.ssl.bundle | | SSL bundle name.
|spring.grpc.server.ssl.client-auth | | Client authentication mode.
|spring.grpc.server.ssl.enabled | | Whether to enable SSL support. Enabled automatically if "bundle" is provided unless specified otherwise.
|spring.grpc.server.ssl.secure | `+++true+++` | Flag to indicate that client authentication is secure (i.e. certificates are checked). Do not set this to false in production.

|===
Original file line number Diff line number Diff line change
Expand Up @@ -15,133 +15,39 @@
*/
package org.springframework.grpc.autoconfigure.client;

import java.util.List;

import javax.net.ssl.SSLException;

import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.ssl.SslBundles;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.grpc.autoconfigure.client.GrpcClientProperties.NamedChannel;
import org.springframework.grpc.client.DefaultGrpcChannelFactory;
import org.springframework.grpc.client.GrpcChannelConfigurer;
import org.springframework.grpc.client.GrpcChannelFactory;
import org.springframework.grpc.client.NegotiationType;
import org.springframework.grpc.client.NettyGrpcChannelFactory;
import org.springframework.grpc.client.ShadedNettyGrpcChannelFactory;
import org.springframework.grpc.client.VirtualTargets;
import org.springframework.grpc.client.ChannelCredentialsProvider;

import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;

public class GrpcChannelFactoryConfigurations {

@Configuration(proxyBeanMethods = false)
@ConditionalOnClass(io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder.class)
@ConditionalOnMissingBean(GrpcChannelFactory.class)
@ConditionalOnMissingBean(ChannelCredentialsProvider.class)
public static class ShadedNettyChannelFactoryConfiguration {

@Bean
public DefaultGrpcChannelFactory defaultGrpcChannelFactory(final List<GrpcChannelConfigurer> configurers,
GrpcClientProperties channels) {
DefaultGrpcChannelFactory factory = new ShadedNettyGrpcChannelFactory(configurers);
factory.setVirtualTargets(new NamedChannelVirtualTargets(channels));
return factory;
}

@Bean
public GrpcChannelConfigurer secureChannelConfigurer(GrpcClientProperties channels) {

return (authority, input) -> {
NamedChannel channel = channels.getChannel(authority);
if (!authority.startsWith("unix:")
&& input instanceof io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder builder) {
builder.negotiationType(of(channel.getNegotiationType()));
try {
if (!channel.isSecure()) {
builder.sslContext(io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts.forClient()
.trustManager(
io.grpc.netty.shaded.io.netty.handler.ssl.util.InsecureTrustManagerFactory.INSTANCE)
.build());
}
}
catch (SSLException e) {
throw new IllegalStateException("Failed to create SSL context", e);
}
}
};

}

private static io.grpc.netty.shaded.io.grpc.netty.NegotiationType of(final NegotiationType negotiationType) {
return switch (negotiationType) {
case PLAINTEXT -> io.grpc.netty.shaded.io.grpc.netty.NegotiationType.PLAINTEXT;
case PLAINTEXT_UPGRADE -> io.grpc.netty.shaded.io.grpc.netty.NegotiationType.PLAINTEXT_UPGRADE;
case TLS -> io.grpc.netty.shaded.io.grpc.netty.NegotiationType.TLS;
};
public ChannelCredentialsProvider channelCredentialsProvider(GrpcClientProperties channels,
SslBundles bundles) {
return new ShadedNettyChannelCredentialsProvider(bundles, channels);
}

}

@Configuration(proxyBeanMethods = false)
@ConditionalOnClass(NettyChannelBuilder.class)
@ConditionalOnMissingBean(GrpcChannelFactory.class)
@ConditionalOnMissingBean(ChannelCredentialsProvider.class)
public static class NettyChannelFactoryConfiguration {

@Bean
public DefaultGrpcChannelFactory defaultGrpcChannelFactory(final List<GrpcChannelConfigurer> configurers,
GrpcClientProperties channels) {
DefaultGrpcChannelFactory factory = new NettyGrpcChannelFactory(configurers);
factory.setVirtualTargets(new NamedChannelVirtualTargets(channels));
return factory;
}

@Bean
public GrpcChannelConfigurer secureChannelConfigurer(GrpcClientProperties channels) {

return (authority, input) -> {
NamedChannel channel = channels.getChannel(authority);
if (!authority.startsWith("unix:") && input instanceof NettyChannelBuilder builder) {
builder.negotiationType(of(channel.getNegotiationType()));
try {
if (!channel.isSecure()) {
builder.sslContext(GrpcSslContexts.forClient()
.trustManager(InsecureTrustManagerFactory.INSTANCE)
.build());
}
}
catch (SSLException e) {
throw new IllegalStateException("Failed to create SSL context", e);
}
}
};

}

private static io.grpc.netty.NegotiationType of(final NegotiationType negotiationType) {
return switch (negotiationType) {
case PLAINTEXT -> io.grpc.netty.NegotiationType.PLAINTEXT;
case PLAINTEXT_UPGRADE -> io.grpc.netty.NegotiationType.PLAINTEXT_UPGRADE;
case TLS -> io.grpc.netty.NegotiationType.TLS;
};
}

}

static class NamedChannelVirtualTargets implements VirtualTargets {

private final GrpcClientProperties channels;

NamedChannelVirtualTargets(GrpcClientProperties channels) {
this.channels = channels;
}

@Override
public String getTarget(String authority) {
NamedChannel channel = this.channels.getChannel(authority);
return channels.getTarget(channel.getAddress());
public ChannelCredentialsProvider channelCredentialsProvider(GrpcClientProperties channels,
SslBundles bundles) {
return new NettyChannelCredentialsProvider(bundles, channels);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,26 @@
*/
package org.springframework.grpc.autoconfigure.client;

import java.util.List;
import java.util.concurrent.TimeUnit;

import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.ssl.SslBundle;
import org.springframework.boot.ssl.SslBundles;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.grpc.autoconfigure.client.GrpcClientProperties.NamedChannel;
import org.springframework.grpc.autoconfigure.common.codec.GrpcCodecConfiguration;
import org.springframework.grpc.client.ChannelCredentialsProvider;
import org.springframework.grpc.client.DefaultGrpcChannelFactory;
import org.springframework.grpc.client.GrpcChannelConfigurer;
import org.springframework.grpc.client.GrpcChannelFactory;
import org.springframework.grpc.client.VirtualTargets;

import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;

@Configuration(proxyBeanMethods = false)
@EnableConfigurationProperties(GrpcClientProperties.class)
Expand All @@ -37,26 +43,21 @@
public class GrpcClientAutoConfiguration {

@Bean
public GrpcChannelConfigurer sslGrpcChannelConfigurer(GrpcClientProperties channels, SslBundles bundles) {
@ConditionalOnMissingBean(GrpcChannelFactory.class)
public DefaultGrpcChannelFactory defaultGrpcChannelFactory(final List<GrpcChannelConfigurer> configurers,
ChannelCredentialsProvider credentials, GrpcClientProperties channels, SslBundles bundles) {
DefaultGrpcChannelFactory factory = new DefaultGrpcChannelFactory(configurers);
factory.setCredentialsProvider(credentials);
factory.setVirtualTargets(new NamedChannelVirtualTargets(channels));
return factory;
}

@Bean
public GrpcChannelConfigurer sslGrpcChannelConfigurer(GrpcClientProperties channels) {
return (authority, builder) -> {
for (String name : channels.getChannels().keySet()) {
if (authority.equals(name)) {
NamedChannel channel = channels.getChannels().get(name);
if (channel.getSsl().isEnabled() && channel.getSsl().getBundle() != null) {
SslBundle bundle = bundles.getBundle(channel.getSsl().getBundle());
if (NettyChannelFactoryHelper.isAvailable()) {
NettyChannelFactoryHelper.sslContext(builder, bundle);
}
else if (ShadedNettyChannelFactoryHelper.isAvailable()) {
ShadedNettyChannelFactoryHelper.sslContext(builder, bundle);
}
else {
throw new IllegalStateException("Netty is not available");
}
}
else {
// builder.usePlaintext();
}
if (channel.getUserAgent() != null) {
builder.userAgent(channel.getUserAgent());
}
Expand Down Expand Up @@ -96,4 +97,20 @@ GrpcChannelConfigurer decompressionClientConfigurer(DecompressorRegistry registr
return (name, builder) -> builder.decompressorRegistry(registry);
}

static class NamedChannelVirtualTargets implements VirtualTargets {

private final GrpcClientProperties channels;

NamedChannelVirtualTargets(GrpcClientProperties channels) {
this.channels = channels;
}

@Override
public String getTarget(String authority) {
NamedChannel channel = this.channels.getChannel(authority);
return channels.getTarget(channel.getAddress());
}

}

}
Loading

0 comments on commit 05e6240

Please sign in to comment.