Skip to content

Commit

Permalink
Add client auth enum to ssl configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
dsyer committed Nov 5, 2024
1 parent d24e80a commit 148e1d1
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledOnOs;
Expand Down Expand Up @@ -119,6 +120,23 @@ void clientChannelWithSsl(@Autowired GrpcChannelFactory channels) {

}

@Nested
@SpringBootTest(properties = { "spring.grpc.server.port=0", "spring.grpc.server.ssl.client-auth=REQUIRE",
"spring.grpc.client.channels.test-channel.address=static://0.0.0.0:${local.grpc.port}",
"spring.grpc.client.channels.test-channel.negotiation-type=TLS",
"spring.grpc.client.channels.test-channel.secure=false" })
@ActiveProfiles("ssl")
@DirtiesContext
@Disabled("Requires client certificate")
class ServerWithClientAuth {

@Test
void clientChannelWithSsl(@Autowired GrpcChannelFactory channels) {
assertThatResponseIsServedToChannel(channels.createChannel("test-channel").build());
}

}

private void assertThatResponseIsServedToChannel(ManagedChannel clientChannel) {
SimpleGrpc.SimpleBlockingStub client = SimpleGrpc.newBlockingStub(clientChannel);
HelloReply response = client.sayHello(HelloRequest.newBuilder().setName("Alien").build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
import java.util.Objects;
import java.util.Set;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.TrustManagerFactory;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.grpc.internal.GrpcUtils;

import com.google.common.collect.Lists;

import io.grpc.Grpc;
Expand All @@ -30,10 +37,9 @@
import io.grpc.ServerCredentials;
import io.grpc.ServerProvider;
import io.grpc.ServerServiceDefinition;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.grpc.internal.GrpcUtils;
import io.grpc.TlsServerCredentials;
import io.grpc.TlsServerCredentials.Builder;
import io.grpc.TlsServerCredentials.ClientAuth;

/**
* Default implementation for {@link GrpcServerFactory gRPC service factories}.
Expand All @@ -56,11 +62,25 @@ public class DefaultGrpcServerFactory<T extends ServerBuilder<T>> implements Grp

private final List<ServerBuilderCustomizer<T>> serverBuilderCustomizers;

private KeyManagerFactory keyManager;

private TrustManagerFactory trustManager;

private ClientAuth clientAuth;

public DefaultGrpcServerFactory(String address, List<ServerBuilderCustomizer<T>> serverBuilderCustomizers) {
this.address = address;
this.serverBuilderCustomizers = Objects.requireNonNull(serverBuilderCustomizers, "serverBuilderCustomizers");
}

public DefaultGrpcServerFactory(String address, List<ServerBuilderCustomizer<T>> serverBuilderCustomizers,
KeyManagerFactory keyManager, TrustManagerFactory trustManager, ClientAuth clientAuth) {
this(address, serverBuilderCustomizers);
this.keyManager = keyManager;
this.trustManager = trustManager;
this.clientAuth = clientAuth;
}

protected String address() {
return this.address;
}
Expand Down Expand Up @@ -99,7 +119,17 @@ protected int port() {
* @return some server credentials (default is insecure)
*/
protected ServerCredentials credentials() {
return InsecureServerCredentials.create();
if (this.keyManager == null || port() == -1) {
return InsecureServerCredentials.create();
}
Builder builder = TlsServerCredentials.newBuilder().keyManager(this.keyManager.getKeyManagers());
if (this.trustManager != null) {
builder.trustManager(this.trustManager.getTrustManagers());
}
if (this.clientAuth != null) {
builder.clientAuth(this.clientAuth);
}
return builder.build();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
import java.util.List;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.TrustManagerFactory;

import io.grpc.ServerCredentials;
import io.grpc.TlsServerCredentials;
import io.grpc.TlsServerCredentials.Builder;
import io.grpc.TlsServerCredentials.ClientAuth;
import io.grpc.netty.NettyServerBuilder;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerDomainSocketChannel;
Expand All @@ -35,12 +38,10 @@
*/
public class NettyGrpcServerFactory extends DefaultGrpcServerFactory<NettyServerBuilder> {

private KeyManagerFactory keyManager;

public NettyGrpcServerFactory(String address, KeyManagerFactory keyManager,
List<ServerBuilderCustomizer<NettyServerBuilder>> serverBuilderCustomizers) {
super(address, serverBuilderCustomizers);
this.keyManager = keyManager;
public NettyGrpcServerFactory(String address,
List<ServerBuilderCustomizer<NettyServerBuilder>> serverBuilderCustomizers, KeyManagerFactory keyManager,
TrustManagerFactory trustManager, ClientAuth clientAuth) {
super(address, serverBuilderCustomizers, keyManager, trustManager, clientAuth);
}

@Override
Expand All @@ -56,12 +57,4 @@ protected NettyServerBuilder newServerBuilder() {
return super.newServerBuilder();
}

@Override
protected ServerCredentials credentials() {
if (this.keyManager == null || port() == -1) {
return super.credentials();
}
return TlsServerCredentials.newBuilder().keyManager(this.keyManager.getKeyManagers()).build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import java.util.List;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.TrustManagerFactory;

import io.grpc.ServerCredentials;
import io.grpc.TlsServerCredentials;
import io.grpc.TlsServerCredentials.ClientAuth;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.netty.shaded.io.netty.channel.epoll.EpollEventLoopGroup;
import io.grpc.netty.shaded.io.netty.channel.epoll.EpollServerDomainSocketChannel;
Expand All @@ -35,12 +35,10 @@
*/
public class ShadedNettyGrpcServerFactory extends DefaultGrpcServerFactory<NettyServerBuilder> {

private KeyManagerFactory keyManager;

public ShadedNettyGrpcServerFactory(String address, KeyManagerFactory keyManager,
List<ServerBuilderCustomizer<NettyServerBuilder>> serverBuilderCustomizers) {
super(address, serverBuilderCustomizers);
this.keyManager = keyManager;
public ShadedNettyGrpcServerFactory(String address,
List<ServerBuilderCustomizer<NettyServerBuilder>> serverBuilderCustomizers, KeyManagerFactory keyManager,
TrustManagerFactory trustManager, ClientAuth clientAuth) {
super(address, serverBuilderCustomizers, keyManager, trustManager, clientAuth);
}

@Override
Expand All @@ -56,12 +54,4 @@ protected NettyServerBuilder newServerBuilder() {
return super.newServerBuilder();
}

@Override
protected ServerCredentials credentials() {
if (this.keyManager == null || port() == -1) {
return super.credentials();
}
return TlsServerCredentials.newBuilder().keyManager(this.keyManager.getKeyManagers()).build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.List;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.TrustManagerFactory;

import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
Expand Down Expand Up @@ -56,12 +57,14 @@ ShadedNettyGrpcServerFactory shadedNettyGrpcServerFactory(GrpcServerProperties p
List<ServerBuilderCustomizer<io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder>> builderCustomizers = List
.of(mapper::customizeServerBuilder, serverBuilderCustomizers::customize);
KeyManagerFactory keyManager = null;
TrustManagerFactory trustManager = null;
if (properties.getSsl().isEnabled()) {
SslBundle bundle = bundles.getBundle(properties.getSsl().getBundle());
keyManager = bundle.getManagers().getKeyManagerFactory();
trustManager = bundle.getManagers().getTrustManagerFactory();
}
ShadedNettyGrpcServerFactory factory = new ShadedNettyGrpcServerFactory(properties.getAddress(), keyManager,
builderCustomizers);
ShadedNettyGrpcServerFactory factory = new ShadedNettyGrpcServerFactory(properties.getAddress(),
builderCustomizers, keyManager, trustManager, properties.getSsl().getClientAuth());
grpcServicesDiscoverer.findServices().forEach(factory::addService);
return factory;
}
Expand All @@ -82,12 +85,14 @@ NettyGrpcServerFactory nettyGrpcServerFactory(GrpcServerProperties properties,
List<ServerBuilderCustomizer<NettyServerBuilder>> builderCustomizers = List
.of(mapper::customizeServerBuilder, serverBuilderCustomizers::customize);
KeyManagerFactory keyManager = null;
TrustManagerFactory trustManager = null;
if (properties.getSsl().isEnabled()) {
SslBundle bundle = bundles.getBundle(properties.getSsl().getBundle());
keyManager = bundle.getManagers().getKeyManagerFactory();
trustManager = bundle.getManagers().getTrustManagerFactory();
}
NettyGrpcServerFactory factory = new NettyGrpcServerFactory(properties.getAddress(), keyManager,
builderCustomizers);
NettyGrpcServerFactory factory = new NettyGrpcServerFactory(properties.getAddress(), builderCustomizers,
keyManager, trustManager, properties.getSsl().getClientAuth());
grpcServicesDiscoverer.findServices().forEach(factory::addService);
return factory;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.springframework.util.unit.DataSize;
import org.springframework.util.unit.DataUnit;

import io.grpc.TlsServerCredentials.ClientAuth;

@ConfigurationProperties(prefix = "spring.grpc.server")
public class GrpcServerProperties {

Expand Down Expand Up @@ -253,6 +255,11 @@ public static class Ssl {
*/
private Boolean enabled;

/**
* Client authentication mode.
*/
private ClientAuth clientAuth = ClientAuth.NONE;

/**
* SSL bundle name.
*/
Expand Down Expand Up @@ -284,6 +291,14 @@ public void setBundle(String bundle) {
this.bundle = bundle;
}

public void setClientAuth(ClientAuth clientAuth) {
this.clientAuth = clientAuth;
}

public ClientAuth getClientAuth() {
return clientAuth;
}

}

}

0 comments on commit 148e1d1

Please sign in to comment.