diff --git a/grpc-gcp/build.gradle b/grpc-gcp/build.gradle index 367379e4..32e37965 100644 --- a/grpc-gcp/build.gradle +++ b/grpc-gcp/build.gradle @@ -34,6 +34,7 @@ dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" implementation "io.grpc:grpc-stub:${grpcVersion}" implementation "io.opencensus:opencensus-api:${opencensusVersion}" + implementation "com.google.api:api-common:2.1.5" compileOnly "org.apache.tomcat:annotations-api:6.0.53" // necessary for Java 9+ diff --git a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java index 64c7cca2..4354cce7 100644 --- a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java +++ b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannel.java @@ -25,6 +25,7 @@ import com.google.cloud.grpc.proto.ApiConfig; import com.google.cloud.grpc.proto.MethodConfig; import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.ThreadFactoryBuilder; import com.google.common.base.Joiner; import com.google.errorprone.annotations.concurrent.GuardedBy; import com.google.protobuf.Descriptors.FieldDescriptor; @@ -48,11 +49,13 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedList; import java.util.List; import java.util.LongSummaryStatistics; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -92,6 +95,10 @@ public class GcpManagedChannel extends ManagedChannel { @VisibleForTesting final List channelRefs = new CopyOnWriteArrayList<>(); + private final ExecutorService stateNotificationExecutor = Executors.newCachedThreadPool( + new 
ThreadFactoryBuilder().setNameFormat("gcp-mc-state-notifications-%d").build()); + private List stateChangeCallbacks = Collections.synchronizedList(new LinkedList<>()); + // Metrics configuration. private MetricRegistry metricRegistry; private final List labelKeys = new ArrayList<>(); @@ -872,6 +879,15 @@ private void recordUnresponsiveDetection(long nanos, long dropCount) { } } + @Override + public void notifyWhenStateChanged(ConnectivityState source, Runnable callback) { + if (!getState(false).equals(source)) { + stateNotificationExecutor.execute(callback); + return; + } + stateChangeCallbacks.add(callback); + } + /** * ChannelStateMonitor subscribes to channel's state changes and informs {@link GcpManagedChannel} * on any new state. This monitor allows to detect when a channel is not ready and temporarily @@ -919,7 +935,14 @@ public void run() { } } + private synchronized void executeStateChangeCallbacks() { + List callbacksToTrigger = stateChangeCallbacks; + stateChangeCallbacks = new LinkedList<>(); + callbacksToTrigger.forEach(stateNotificationExecutor::execute); + } + void processChannelStateChange(int channelId, ConnectivityState state) { + executeStateChangeCallbacks(); if (!fallbackEnabled) { return; } @@ -967,10 +990,12 @@ protected ChannelRef getChannelRefForBind() { ChannelRef channelRef; if (options.getChannelPoolOptions() != null && options.getChannelPoolOptions().isUseRoundRobinOnBind()) { channelRef = getChannelRefRoundRobin(); + logger.finest(log( + "Channel %d picked for bind operation using round-robin.", channelRef.getId())); } else { channelRef = getChannelRef(null); + logger.finest(log("Channel %d picked for bind operation.", channelRef.getId())); } - logger.finest(log("Channel %d picked for bind operation.", channelRef.getId())); return channelRef; } @@ -1061,6 +1086,35 @@ private synchronized ChannelRef createNewChannel() { return channelRef; } + // Returns first newly created channel or null if there are already some channels in the pool. 
+ @Nullable + private ChannelRef createFirstChannel() { + if (!channelRefs.isEmpty()) { + return null; + } + synchronized (this) { + if (channelRefs.isEmpty()) { + return createNewChannel(); + } + } + return null; + } + + // Creates new channel if maxSize is not reached. + // Returns new channel or null. + @Nullable + private ChannelRef tryCreateNewChannel() { + if (channelRefs.size() >= maxSize) { + return null; + } + synchronized (this) { + if (channelRefs.size() < maxSize) { + return createNewChannel(); + } + } + return null; + } + /** * Pick a {@link ChannelRef} (and create a new one if necessary). If notReadyFallbackEnabled is * true in the {@link GcpResiliencyOptions} then instead of a channel in a non-READY state another @@ -1068,8 +1122,9 @@ private synchronized ChannelRef createNewChannel() { * be provided if available. */ private ChannelRef pickLeastBusyChannel(boolean forFallback) { - if (channelRefs.isEmpty()) { - return createNewChannel(); + ChannelRef first = createFirstChannel(); + if (first != null) { + return first; } // Pick the least busy channel and the least busy ready and not overloaded channel (this could @@ -1095,17 +1150,23 @@ private ChannelRef pickLeastBusyChannel(boolean forFallback) { if (!fallbackEnabled) { if (channelRefs.size() < maxSize && minStreams >= maxConcurrentStreamsLowWatermark) { - return createNewChannel(); + ChannelRef newChannel = tryCreateNewChannel(); + if (newChannel != null) { + return newChannel; + } } return channelCandidate; } if (channelRefs.size() < maxSize && readyMinStreams >= maxConcurrentStreamsLowWatermark) { - if (!forFallback && readyCandidate == null) { - logger.finest(log("Fallback to newly created channel")); - fallbacksSucceeded.incrementAndGet(); + ChannelRef newChannel = tryCreateNewChannel(); + if (newChannel != null) { + if (!forFallback && readyCandidate == null) { + logger.finest(log("Fallback to newly created channel %d", newChannel.getId())); + fallbacksSucceeded.incrementAndGet(); + } + 
return newChannel; } - return createNewChannel(); } if (readyCandidate != null) { @@ -1164,6 +1225,9 @@ public ManagedChannel shutdownNow() { if (logMetricService != null && !logMetricService.isTerminated()) { logMetricService.shutdownNow(); } + if (!stateNotificationExecutor.isTerminated()) { + stateNotificationExecutor.shutdownNow(); + } return this; } @@ -1176,6 +1240,7 @@ public ManagedChannel shutdown() { if (logMetricService != null) { logMetricService.shutdown(); } + stateNotificationExecutor.shutdown(); return this; } @@ -1197,6 +1262,11 @@ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedE //noinspection ResultOfMethodCallIgnored logMetricService.awaitTermination(awaitTimeNanos, NANOSECONDS); } + awaitTimeNanos = endTimeNanos - System.nanoTime(); + if (awaitTimeNanos > 0) { + //noinspection ResultOfMethodCallIgnored + stateNotificationExecutor.awaitTermination(awaitTimeNanos, NANOSECONDS); + } return isTerminated(); } @@ -1210,7 +1280,7 @@ public boolean isShutdown() { if (logMetricService != null) { return logMetricService.isShutdown(); } - return true; + return stateNotificationExecutor.isShutdown(); } @Override @@ -1223,12 +1293,15 @@ public boolean isTerminated() { if (logMetricService != null) { return logMetricService.isTerminated(); } - return true; + return stateNotificationExecutor.isTerminated(); } /** Get the current connectivity state of the channel pool. 
*/ @Override public ConnectivityState getState(boolean requestConnection) { + if (requestConnection && getNumberOfChannels() == 0) { + createFirstChannel(); + } int ready = 0; int idle = 0; int connecting = 0; diff --git a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannelOptions.java b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannelOptions.java index 9f66d721..f6c4c557 100644 --- a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannelOptions.java +++ b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpManagedChannelOptions.java @@ -20,6 +20,7 @@ import io.opencensus.metrics.LabelKey; import io.opencensus.metrics.LabelValue; import io.opencensus.metrics.MetricRegistry; + import java.util.ArrayList; import java.util.Iterator; import java.util.List; @@ -64,10 +65,10 @@ public GcpResiliencyOptions getResiliencyOptions() { @Override public String toString() { return String.format( - "{channelPoolOptions: %s, metricsOptions: %s, resiliencyOptions: %s}", + "{channelPoolOptions: %s, resiliencyOptions: %s, metricsOptions: %s}", getChannelPoolOptions(), - getMetricsOptions(), - getResiliencyOptions() + getResiliencyOptions(), + getMetricsOptions() ); } @@ -208,8 +209,9 @@ public boolean isUseRoundRobinOnBind() { @Override public String toString() { return String.format( - "{maxSize: %d, concurrentStreamsLowWatermark: %d, useRoundRobinOnBind: %s}", + "{maxSize: %d, minSize: %d, concurrentStreamsLowWatermark: %d, useRoundRobinOnBind: %s}", getMaxSize(), + getMinSize(), getConcurrentStreamsLowWatermark(), isUseRoundRobinOnBind() ); diff --git a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpMultiEndpointChannel.java b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpMultiEndpointChannel.java new file mode 100644 index 00000000..9d9de7e7 --- /dev/null +++ b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpMultiEndpointChannel.java @@ -0,0 +1,427 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.grpc; + +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +import com.google.cloud.grpc.GcpManagedChannelOptions.GcpChannelPoolOptions; +import com.google.cloud.grpc.GcpManagedChannelOptions.GcpMetricsOptions; +import com.google.cloud.grpc.multiendpoint.MultiEndpoint; +import com.google.cloud.grpc.proto.ApiConfig; +import com.google.common.base.Preconditions; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ClientCall.Listener; +import io.grpc.ConnectivityState; +import io.grpc.Grpc; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.opencensus.metrics.LabelKey; +import io.opencensus.metrics.LabelValue; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; + +/** + * The purpose of GcpMultiEndpointChannel is twofold: + * + *
    + *
  1. Fall back to an alternative endpoint (host:port) of a gRPC service when the original + * endpoint is completely unavailable. + *
  2. Be able to route an RPC call to a specific group of endpoints. + *
+ * + *

A group of endpoints is called a {@link MultiEndpoint} and is essentially a list of endpoints + * where priority is defined by the position in the list with the first endpoint having top + * priority. A MultiEndpoint tracks endpoints' availability. When a MultiEndpoint is picked for an + * RPC call, it picks the top priority endpoint that is currently available. See the + * {@link MultiEndpoint} class for more information. + * + *

GcpMultiEndpointChannel can have one or more MultiEndpoints, each identified by its + * name -- an arbitrary string provided in the {@link GcpMultiEndpointOptions} when configuring + * MultiEndpoints. This name can be used to route an RPC call to this MultiEndpoint by setting the + * {@link #ME_KEY} key value of the RPC {@link CallOptions}. + * + *

GcpMultiEndpointChannel receives a list of GcpMultiEndpointOptions for initial configuration. + * An updated configuration can be provided at any time later using + * {@link GcpMultiEndpointChannel#setMultiEndpoints(List)}. The first item in the + * GcpMultiEndpointOptions list defines the default MultiEndpoint that will be used when no + * MultiEndpoint name is provided with an RPC call. + * + *

Example: + * + *

Let's assume we have a service with read and write operations and the following backends: + *

    + *
  • service.example.com -- the main set of backends supporting all operations
  • + *
  • service-fallback.example.com -- read-write replica supporting all operations
  • + *
  • ro-service.example.com -- read-only replica supporting only read operations
  • + *
+ * + *

Example configuration: + *

    + *
  • + * MultiEndpoint named "default" with endpoints: + *
      + *
    1. service.example.com:443
    2. + *
    3. service-fallback.example.com:443
    4. + *
    + *
  • + *
  • + * MultiEndpoint named "read" with endpoints: + *
      + *
    1. ro-service.example.com:443
    2. + *
    3. service-fallback.example.com:443
    4. + *
    5. service.example.com:443
    6. + *
    + *
  • + *
+ * + *

With the configuration above GcpMultiEndpointChannel will use the "default" MultiEndpoint by + * default. It means that RPC calls by default will use the main endpoint and if it is not available + * then the read-write replica. + * + *

To offload some read calls to the read-only replica we can specify "read" MultiEndpoint in + * the CallOptions. Then these calls will use the read-only replica endpoint and if it is not + * available then the read-write replica and if it is also not available then the main endpoint. + * + *

GcpMultiEndpointChannel creates a {@link GcpManagedChannel} channel pool for every unique + * endpoint. For the example above three channel pools will be created. + */ +public class GcpMultiEndpointChannel extends ManagedChannel { + + public static final CallOptions.Key ME_KEY = CallOptions.Key.create("MultiEndpoint"); + private final LabelKey endpointKey = + LabelKey.create("endpoint", "Endpoint address."); + private final Map multiEndpoints = new ConcurrentHashMap<>(); + private MultiEndpoint defaultMultiEndpoint; + private final ApiConfig apiConfig; + private final GcpManagedChannelOptions gcpManagedChannelOptions; + + private final Map pools = new ConcurrentHashMap<>(); + + /** + * Constructor for {@link GcpMultiEndpointChannel}. + * + * @param meOptions list of MultiEndpoint configurations. + * @param apiConfig the ApiConfig object for configuring GcpManagedChannel. + * @param gcpManagedChannelOptions the options for GcpManagedChannel. + */ + public GcpMultiEndpointChannel( + List meOptions, + ApiConfig apiConfig, + GcpManagedChannelOptions gcpManagedChannelOptions) { + this.apiConfig = apiConfig; + this.gcpManagedChannelOptions = gcpManagedChannelOptions; + setMultiEndpoints(meOptions); + } + + private class EndpointStateMonitor implements Runnable { + + private final ManagedChannel channel; + private final String endpoint; + + private EndpointStateMonitor(ManagedChannel channel, String endpoint) { + this.endpoint = endpoint; + this.channel = channel; + run(); + } + + @Override + public void run() { + if (channel == null) { + return; + } + ConnectivityState newState = checkPoolState(channel, endpoint); + if (newState != ConnectivityState.SHUTDOWN) { + channel.notifyWhenStateChanged(newState, this); + } + } + } + + // Checks and returns channel pool state. Also notifies all MultiEndpoints of the pool state. 
+ private ConnectivityState checkPoolState(ManagedChannel channel, String endpoint) { + ConnectivityState state = channel.getState(false); + // Update endpoint state in all multiendpoints. + for (MultiEndpoint me : multiEndpoints.values()) { + me.setEndpointAvailable(endpoint, state.equals(ConnectivityState.READY)); + } + return state; + } + + private GcpManagedChannelOptions prepareGcpManagedChannelConfig( + GcpManagedChannelOptions gcpOptions, String endpoint) { + final GcpMetricsOptions.Builder metricsOptions = GcpMetricsOptions.newBuilder( + gcpOptions.getMetricsOptions() + ); + + final List labelKeys = new ArrayList<>(metricsOptions.build().getLabelKeys()); + final List labelValues = new ArrayList<>(metricsOptions.build().getLabelValues()); + + labelKeys.add(endpointKey); + labelValues.add(LabelValue.create(endpoint)); + + // Make sure the pool will have at least 1 channel always connected. If maximum size > 1 then we + // want at least 2 channels or square root of maximum channels whichever is larger. + // Do not override if minSize is already specified as > 0. + final GcpChannelPoolOptions.Builder poolOptions = GcpChannelPoolOptions.newBuilder( + gcpOptions.getChannelPoolOptions() + ); + if (poolOptions.build().getMinSize() < 1) { + int minSize = Math.min(2, poolOptions.build().getMaxSize()); + minSize = Math.max(minSize, ((int) Math.sqrt(poolOptions.build().getMaxSize()))); + poolOptions.setMinSize(minSize); + } + + return GcpManagedChannelOptions.newBuilder(gcpOptions) + .withChannelPoolOptions(poolOptions.build()) + .withMetricsOptions(metricsOptions.withLabels(labelKeys, labelValues).build()) + .build(); + } + + /** + * Update the list of MultiEndpoint configurations. + * + *

MultiEndpoints are matched with the current ones by name. + *

    + *
  • If a current MultiEndpoint is missing in the updated list, the MultiEndpoint will be + * removed. + *
  • A new MultiEndpoint will be created for every new name in the list. + *
  • For an existing MultiEndpoint only its endpoints will be updated (no recovery timeout + * change). + *
+ * + *

Endpoints are matched by the endpoint address (usually in the form of address:port). + *

    + *
  • If an existing endpoint is not used by any MultiEndpoint in the updated list, then the + * channel pool for this endpoint will be shut down. + *
  • A channel pool will be created for every new endpoint. + *
  • For an existing endpoint nothing will change (the channel pool will not be re-created, thus + * no channel credentials change, nor channel configurator change). + *
+ */ + public void setMultiEndpoints(List meOptions) { + Preconditions.checkNotNull(meOptions); + Preconditions.checkArgument(!meOptions.isEmpty(), "MultiEndpoints list is empty"); + Set currentMultiEndpoints = new HashSet<>(); + Set currentEndpoints = new HashSet<>(); + + // Must have all multiendpoints before initializing the pools so that all multiendpoints + // can get status update of every pool. + meOptions.forEach(options -> { + currentMultiEndpoints.add(options.getName()); + // Create or update MultiEndpoint + if (multiEndpoints.containsKey(options.getName())) { + multiEndpoints.get(options.getName()).setEndpoints(options.getEndpoints()); + } else { + multiEndpoints.put(options.getName(), + (new MultiEndpoint.Builder(options.getEndpoints())) + .withRecoveryTimeout(options.getRecoveryTimeout()) + .build()); + } + }); + + // TODO: Support the same endpoint in different MultiEndpoint to use different channel + // credentials. + // TODO: Support different endpoints in the same MultiEndpoint to use different channel + // credentials. + meOptions.forEach(options -> { + // Create missing pools + options.getEndpoints().forEach(endpoint -> { + currentEndpoints.add(endpoint); + pools.computeIfAbsent(endpoint, e -> { + ManagedChannelBuilder managedChannelBuilder; + if (options.getChannelCredentials() != null) { + managedChannelBuilder = Grpc.newChannelBuilder(e, options.getChannelCredentials()); + } else { + String serviceAddress; + int port; + int colon = e.lastIndexOf(':'); + if (colon < 0) { + serviceAddress = e; + // Assume https by default. 
+ port = 443; + } else { + serviceAddress = e.substring(0, colon); + port = Integer.parseInt(e.substring(colon + 1)); + } + managedChannelBuilder = ManagedChannelBuilder.forAddress(serviceAddress, port); + } + if (options.getChannelConfigurator() != null) { + managedChannelBuilder = options.getChannelConfigurator().apply(managedChannelBuilder); + } + + GcpManagedChannel channel = new GcpManagedChannel( + managedChannelBuilder, + apiConfig, + // Add endpoint to metric labels. + prepareGcpManagedChannelConfig(gcpManagedChannelOptions, e)); + // Start monitoring the pool state. + new EndpointStateMonitor(channel, e); + return channel; + }); + // Communicate current state to MultiEndpoints. + checkPoolState(pools.get(endpoint), endpoint); + }); + }); + defaultMultiEndpoint = multiEndpoints.get(meOptions.get(0).getName()); + + // Remove obsolete multiendpoints. + multiEndpoints.keySet().removeIf(name -> !currentMultiEndpoints.contains(name)); + + // Shutdown and remove the pools not present in options. + for (String endpoint : pools.keySet()) { + if (!currentEndpoints.contains(endpoint)) { + pools.get(endpoint).shutdown(); + pools.remove(endpoint); + } + } + } + + /** + * Initiates an orderly shutdown in which preexisting calls continue but new calls are immediately + * cancelled. + * + * @return this + * @since 1.0.0 + */ + @Override + public ManagedChannel shutdown() { + pools.values().forEach(GcpManagedChannel::shutdown); + return this; + } + + /** + * Returns whether the channel is shutdown. Shutdown channels immediately cancel any new calls, + * but may still have some calls being processed. + * + * @see #shutdown() + * @see #isTerminated() + * @since 1.0.0 + */ + @Override + public boolean isShutdown() { + return pools.values().stream().allMatch(GcpManagedChannel::isShutdown); + } + + /** + * Returns whether the channel is terminated. Terminated channels have no running calls and + * relevant resources released (like TCP connections). 
+ * + * @see #isShutdown() + * @since 1.0.0 + */ + @Override + public boolean isTerminated() { + return pools.values().stream().allMatch(GcpManagedChannel::isTerminated); + } + + /** + * Initiates a forceful shutdown in which preexisting and new calls are cancelled. Although + * forceful, the shutdown process is still not instantaneous; {@link #isTerminated()} will likely + * return {@code false} immediately after this method returns. + * + * @return this + * @since 1.0.0 + */ + @Override + public ManagedChannel shutdownNow() { + pools.values().forEach(GcpManagedChannel::shutdownNow); + return this; + } + + /** + * Waits for the channel to become terminated, giving up if the timeout is reached. + * + * @return whether the channel is terminated, as would be done by {@link #isTerminated()}. + * @since 1.0.0 + */ + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + long endTimeNanos = System.nanoTime() + unit.toNanos(timeout); + for (GcpManagedChannel gcpManagedChannel : pools.values()) { + if (gcpManagedChannel.isTerminated()) { + continue; + } + long awaitTimeNanos = endTimeNanos - System.nanoTime(); + if (awaitTimeNanos <= 0) { + break; + } + gcpManagedChannel.awaitTermination(awaitTimeNanos, NANOSECONDS); + } + return isTerminated(); + } + + /** + * Check the value of {@link #ME_KEY} key in the {@link CallOptions} and if found use + * the MultiEndpoint with the same name for this call. + * + *

Create a {@link ClientCall} to the remote operation specified by the given {@link + * MethodDescriptor}. The returned {@link ClientCall} does not trigger any remote behavior until + * {@link ClientCall#start(Listener, Metadata)} is invoked. + * + * @param methodDescriptor describes the name and parameter types of the operation to call. + * @param callOptions runtime options to be applied to this call. + * @return a {@link ClientCall} bound to the specified method. + * @since 1.0.0 + */ + @Override + public ClientCall newCall( + MethodDescriptor methodDescriptor, CallOptions callOptions) { + final String multiEndpointKey = callOptions.getOption(ME_KEY); + MultiEndpoint me = defaultMultiEndpoint; + if (multiEndpointKey != null) { + me = multiEndpoints.getOrDefault(multiEndpointKey, defaultMultiEndpoint); + } + return pools.get(me.getCurrentId()).newCall(methodDescriptor, callOptions); + } + + /** + * The authority of the current endpoint of the default MultiEndpoint. Typically, this is in the + * format {@code host:port}. + * + * To get the authority of the current endpoint of another MultiEndpoint use {@link + * #authorityFor(String)} method. + * + * This may return different values over time because MultiEndpoint may switch between endpoints. + * + * @since 1.0.0 + */ + @Override + public String authority() { + return pools.get(defaultMultiEndpoint.getCurrentId()).authority(); + } + + /** + * The authority of the current endpoint of the specified MultiEndpoint. Typically, this is in the + * format {@code host:port}. + * + * This may return different values over time because MultiEndpoint may switch between endpoints. 
+ */ + public String authorityFor(String multiEndpointName) { + MultiEndpoint multiEndpoint = multiEndpoints.get(multiEndpointName); + if (multiEndpoint == null) { + return null; + } + return pools.get(multiEndpoint.getCurrentId()).authority(); + } +} diff --git a/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpMultiEndpointOptions.java b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpMultiEndpointOptions.java new file mode 100644 index 00000000..c1c30bac --- /dev/null +++ b/grpc-gcp/src/main/java/com/google/cloud/grpc/GcpMultiEndpointOptions.java @@ -0,0 +1,170 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.grpc; + +import com.google.api.core.ApiFunction; +import com.google.cloud.grpc.multiendpoint.MultiEndpoint; +import com.google.common.base.Preconditions; +import io.grpc.ChannelCredentials; +import io.grpc.ManagedChannelBuilder; +import java.time.Duration; +import java.util.List; + +/** + * {@link MultiEndpoint} configuration for the {@link GcpMultiEndpointChannel}. 
+ */ +public class GcpMultiEndpointOptions { + + private final String name; + private final List endpoints; + private final ApiFunction, ManagedChannelBuilder> channelConfigurator; + private final ChannelCredentials channelCredentials; + private final Duration recoveryTimeout; + + public static String DEFAULT_NAME = "default"; + + public GcpMultiEndpointOptions(Builder builder) { + this.name = builder.name; + this.endpoints = builder.endpoints; + this.channelConfigurator = builder.channelConfigurator; + this.channelCredentials = builder.channelCredentials; + this.recoveryTimeout = builder.recoveryTimeout; + } + + /** + * Creates a new GcpMultiEndpointOptions.Builder. + * + * @param endpoints list of endpoints for the MultiEndpoint. + */ + public static Builder newBuilder(List endpoints) { + return new Builder(endpoints); + } + + /** + * Creates a new GcpMultiEndpointOptions.Builder from GcpMultiEndpointOptions. + */ + public static Builder newBuilder(GcpMultiEndpointOptions options) { + return new Builder(options); + } + + public String getName() { + return name; + } + + public List getEndpoints() { + return endpoints; + } + + public ApiFunction, ManagedChannelBuilder> getChannelConfigurator() { + return channelConfigurator; + } + + public ChannelCredentials getChannelCredentials() { + return channelCredentials; + } + + public Duration getRecoveryTimeout() { + return recoveryTimeout; + } + + public static class Builder { + + private String name = GcpMultiEndpointOptions.DEFAULT_NAME; + private List endpoints; + private ApiFunction, ManagedChannelBuilder> channelConfigurator; + private ChannelCredentials channelCredentials; + private Duration recoveryTimeout = Duration.ZERO; + + public Builder(List endpoints) { + setEndpoints(endpoints); + } + + public Builder(GcpMultiEndpointOptions options) { + this.name = options.getName(); + this.endpoints = options.getEndpoints(); + this.channelConfigurator = options.getChannelConfigurator(); + this.channelCredentials = 
options.getChannelCredentials(); + this.recoveryTimeout = options.getRecoveryTimeout(); + } + + public GcpMultiEndpointOptions build() { + return new GcpMultiEndpointOptions(this); + } + + private void setEndpoints(List endpoints) { + Preconditions.checkNotNull(endpoints); + Preconditions.checkArgument( + !endpoints.isEmpty(), "At least one endpoint must be specified."); + Preconditions.checkArgument( + endpoints.stream().noneMatch(s -> s.trim().isEmpty()), "No empty endpoints allowed."); + this.endpoints = endpoints; + } + + /** + * Sets the name of the MultiEndpoint. + * + * @param name MultiEndpoint name. + */ + public GcpMultiEndpointOptions.Builder withName(String name) { + this.name = name; + return this; + } + + /** + * Sets the endpoints of the MultiEndpoint. + * + * @param endpoints List of endpoints in the form of host:port in descending priority order. + */ + public GcpMultiEndpointOptions.Builder withEndpoints(List endpoints) { + this.setEndpoints(endpoints); + return this; + } + + /** + * Sets the channel configurator for the MultiEndpoint channel pool. + * + * @param channelConfigurator function to perform on the ManagedChannelBuilder in the channel + * pool. + */ + public GcpMultiEndpointOptions.Builder withChannelConfigurator( + ApiFunction, ManagedChannelBuilder> channelConfigurator) { + this.channelConfigurator = channelConfigurator; + return this; + } + + /** + * Sets the channel credentials to use in the MultiEndpoint channel pool. + * + * @param channelCredentials channel credentials. + */ + public GcpMultiEndpointOptions.Builder withChannelCredentials( + ChannelCredentials channelCredentials) { + this.channelCredentials = channelCredentials; + return this; + } + + /** + * Sets the recovery timeout for the MultiEndpoint. See more info in the {@link MultiEndpoint}. + * + * @param recoveryTimeout recovery timeout. 
+ */ + public GcpMultiEndpointOptions.Builder withRecoveryTimeout(Duration recoveryTimeout) { + this.recoveryTimeout = recoveryTimeout; + return this; + } + } +} diff --git a/grpc-gcp/src/main/java/com/google/cloud/grpc/multiendpoint/Endpoint.java b/grpc-gcp/src/main/java/com/google/cloud/grpc/multiendpoint/Endpoint.java new file mode 100644 index 00000000..fc6d08eb --- /dev/null +++ b/grpc-gcp/src/main/java/com/google/cloud/grpc/multiendpoint/Endpoint.java @@ -0,0 +1,78 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.grpc.multiendpoint; + +import com.google.errorprone.annotations.CheckReturnValue; +import java.util.concurrent.ScheduledFuture; + +/** + * Endpoint holds an endpoint's state, priority and a future of upcoming state change. + */ +@CheckReturnValue +final class Endpoint { + + /** + * Holds a state of an endpoint. 
+ */ + public enum EndpointState { + UNAVAILABLE, + AVAILABLE, + RECOVERING, + } + + private final String id; + private EndpointState state; + private int priority; + private ScheduledFuture changeStateFuture; + + public Endpoint(String id, EndpointState state, int priority) { + this.id = id; + this.priority = priority; + this.state = state; + } + + public String getId() { + return id; + } + + public EndpointState getState() { + return state; + } + + public int getPriority() { + return priority; + } + + public void setState(EndpointState state) { + this.state = state; + } + + public void setPriority(int priority) { + this.priority = priority; + } + + public synchronized void setChangeStateFuture(ScheduledFuture future) { + resetStateChangeFuture(); + changeStateFuture = future; + } + + public synchronized void resetStateChangeFuture() { + if (changeStateFuture != null) { + changeStateFuture.cancel(true); + } + } +} diff --git a/grpc-gcp/src/main/java/com/google/cloud/grpc/multiendpoint/MultiEndpoint.java b/grpc-gcp/src/main/java/com/google/cloud/grpc/multiendpoint/MultiEndpoint.java new file mode 100644 index 00000000..18b9abfd --- /dev/null +++ b/grpc-gcp/src/main/java/com/google/cloud/grpc/multiendpoint/MultiEndpoint.java @@ -0,0 +1,204 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.google.cloud.grpc.multiendpoint; + +import static java.util.Comparator.comparingInt; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import com.google.cloud.grpc.multiendpoint.Endpoint.EndpointState; +import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.CheckReturnValue; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.ScheduledThreadPoolExecutor; + +/** + * MultiEndpoint holds a list of endpoints, tracks their availability and defines the current + * endpoint. An endpoint has a priority defined by its position in the list (first item has top + * priority). MultiEndpoint returns top priority endpoint that is available as current. If no + * endpoint is available, MultiEndpoint returns the top priority endpoint. + * + *

Sometimes switching between endpoints can be costly, and it is worth waiting for some time + * after current endpoint becomes unavailable. For this case, use {@link + * Builder#withRecoveryTimeout} to set the recovery timeout. MultiEndpoint will keep the current + * endpoint for up to recovery timeout after it became unavailable to give it some time to recover. + * + *

The list of endpoints can be changed at any time with {@link #setEndpoints} method. + * MultiEndpoint will preserve endpoints' state and update their priority according to their new + * positions. + * + *

The initial state of endpoint is "unavailable" or "recovering" if using recovery timeout. + */ +@CheckReturnValue +public final class MultiEndpoint { + @GuardedBy("this") + private final Map endpointsMap = new HashMap<>(); + + @GuardedBy("this") + private String currentId; + + private final Duration recoveryTimeout; + + private final ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(1); + + private MultiEndpoint(Builder builder) { + this.recoveryTimeout = builder.recoveryTimeout; + this.setEndpoints(builder.endpoints); + } + + /** Builder for MultiEndpoint. */ + public static final class Builder { + private final List endpoints; + private Duration recoveryTimeout = Duration.ZERO; + + public Builder(List endpoints) { + Preconditions.checkNotNull(endpoints); + Preconditions.checkArgument(!endpoints.isEmpty(), "Endpoints list must not be empty."); + this.endpoints = endpoints; + } + + /** + * MultiEndpoint will keep the current endpoint for up to recovery timeout after it became + * unavailable to give it some time to recover. + */ + public Builder withRecoveryTimeout(Duration timeout) { + Preconditions.checkNotNull(timeout); + this.recoveryTimeout = timeout; + return this; + } + + public MultiEndpoint build() { + return new MultiEndpoint(this); + } + } + + /** + * Returns current endpoint id. + * + *

Note that the read is not synchronized and in case of a race condition there is a chance of + * getting an outdated current id. + */ + @SuppressWarnings("GuardedBy") + public String getCurrentId() { + return currentId; + } + + private synchronized void setEndpointStateInternal(String endpointId, EndpointState state) { + Endpoint endpoint = endpointsMap.get(endpointId); + if (endpoint != null) { + endpoint.setState(state); + maybeUpdateCurrentEndpoint(); + } + } + + private boolean isRecoveryEnabled() { + return !recoveryTimeout.isNegative() && !recoveryTimeout.isZero(); + } + + /** Inform MultiEndpoint when an endpoint becomes available or unavailable. */ + public synchronized void setEndpointAvailable(String endpointId, boolean available) { + setEndpointState(endpointId, available ? EndpointState.AVAILABLE : EndpointState.UNAVAILABLE); + } + + private synchronized void setEndpointState(String endpointId, EndpointState state) { + Preconditions.checkNotNull(state); + Endpoint endpoint = endpointsMap.get(endpointId); + if (endpoint == null) { + return; + } + // If we allow some recovery time. + if (EndpointState.UNAVAILABLE.equals(state) && isRecoveryEnabled()) { + endpoint.setState(EndpointState.RECOVERING); + ScheduledFuture future = + executor.schedule( + () -> setEndpointStateInternal(endpointId, EndpointState.UNAVAILABLE), + recoveryTimeout.toMillis(), + MILLISECONDS); + endpoint.setChangeStateFuture(future); + return; + } + endpoint.resetStateChangeFuture(); + endpoint.setState(state); + maybeUpdateCurrentEndpoint(); + } + + /** + * Provide an updated list of endpoints to MultiEndpoint. + * + *

MultiEndpoint will preserve current endpoints' state and update their priority according to + * their new positions. + */ + public synchronized void setEndpoints(List endpoints) { + Preconditions.checkNotNull(endpoints); + Preconditions.checkArgument(!endpoints.isEmpty(), "Endpoints list must not be empty."); + + // Remove obsolete endpoints. + endpointsMap.keySet().retainAll(endpoints); + + // Add new endpoints and update priority. + int priority = 0; + for (String endpointId : endpoints) { + Endpoint existingEndpoint = endpointsMap.get(endpointId); + if (existingEndpoint != null) { + existingEndpoint.setPriority(priority++); + continue; + } + EndpointState newState = + isRecoveryEnabled() ? EndpointState.RECOVERING : EndpointState.UNAVAILABLE; + Endpoint newEndpoint = new Endpoint(endpointId, newState, priority++); + if (isRecoveryEnabled()) { + ScheduledFuture future = + executor.schedule( + () -> setEndpointStateInternal(endpointId, EndpointState.UNAVAILABLE), + recoveryTimeout.toMillis(), + MILLISECONDS); + newEndpoint.setChangeStateFuture(future); + } + endpointsMap.put(endpointId, newEndpoint); + } + + maybeUpdateCurrentEndpoint(); + } + + // Updates currentId to the top-priority available endpoint unless the current endpoint is + // recovering. + private synchronized void maybeUpdateCurrentEndpoint() { + Optional topEndpoint = + endpointsMap.values().stream() + .filter((c) -> c.getState().equals(EndpointState.AVAILABLE)) + .min(comparingInt(Endpoint::getPriority)); + + Endpoint current = endpointsMap.get(currentId); + if (current != null && current.getState().equals(EndpointState.RECOVERING)) { + // Keep recovering endpoint as current unless a higher priority endpoint became available. 
+ if (!topEndpoint.isPresent() || topEndpoint.get().getPriority() >= current.getPriority()) { + return; + } + } + + if (!topEndpoint.isPresent() && current == null) { + topEndpoint = endpointsMap.values().stream().min(comparingInt(Endpoint::getPriority)); + } + + topEndpoint.ifPresent(endpoint -> currentId = endpoint.getId()); + } +} diff --git a/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpManagedChannelTest.java b/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpManagedChannelTest.java index 417ced1e..9af0e88e 100644 --- a/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpManagedChannelTest.java +++ b/grpc-gcp/src/test/java/com/google/cloud/grpc/GcpManagedChannelTest.java @@ -55,7 +55,9 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Handler; import java.util.logging.Level; import java.util.logging.LogRecord; @@ -100,7 +102,7 @@ private Level lastLogLevel(int nthFromLast) { private final Handler testLogHandler = new Handler() { @Override - public void publish(LogRecord record) { + public synchronized void publish(LogRecord record) { logRecords.add(record); } @@ -377,9 +379,10 @@ public void testGetChannelRefWithFallback() { assertEquals(2, pool.getNumberOfChannels()); // This was a fallback from non-ready channel 0 to the newly created channel 1. 
assertThat(logRecords.size()).isEqualTo(logCount + 3); - assertThat(lastLogMessage(3)).isEqualTo( - poolIndex + ": Fallback to newly created channel"); - assertThat(lastLogLevel(3)).isEqualTo(Level.FINEST); + logRecords.forEach(logRecord -> System.out.println(logRecord.getMessage())); + assertThat(lastLogMessage()).isEqualTo( + poolIndex + ": Fallback to newly created channel 1"); + assertThat(lastLogLevel()).isEqualTo(Level.FINEST); assertFallbacksMetric(fakeRegistry, 1, 0); // Adding one active stream to channel 1. @@ -1225,6 +1228,114 @@ public void testUnresponsiveDetection() throws InterruptedException { poolIndex + ": stat: " + GcpMetricsConstants.METRIC_NUM_UNRESPONSIVE_DETECTIONS + " = 0"); } + @Test + public void testStateNotifications() throws InterruptedException { + final AtomicBoolean immediateCallbackCalled = new AtomicBoolean(); + // Test callback is called when state doesn't match. + gcpChannel.notifyWhenStateChanged(ConnectivityState.SHUTDOWN, () -> + immediateCallbackCalled.set(true)); + + TimeUnit.MILLISECONDS.sleep(1); + + assertThat(immediateCallbackCalled.get()).isTrue(); + + // Subscribe for notification when leaving IDLE state. + final AtomicReference newState = new AtomicReference<>(); + + final Runnable callback = new Runnable() { + @Override + public void run() { + ConnectivityState state = gcpChannel.getState(false); + newState.set(state); + if (state.equals(ConnectivityState.IDLE)) { + gcpChannel.notifyWhenStateChanged(ConnectivityState.IDLE, this); + } + } + }; + + gcpChannel.notifyWhenStateChanged(ConnectivityState.IDLE, callback); + + // Init connection to move out of the IDLE state. 
+ ConnectivityState currentState = gcpChannel.getState(true); + // Make sure it was IDLE; + assertThat(currentState).isEqualTo(ConnectivityState.IDLE); + + TimeUnit.MILLISECONDS.sleep(5); + + assertThat(newState.get()) + .isAnyOf(ConnectivityState.CONNECTING, ConnectivityState.TRANSIENT_FAILURE); + } + + @Test + public void testParallelGetChannelRefWontExceedMaxSize() throws InterruptedException { + resetGcpChannel(); + GcpChannelPoolOptions poolOptions = GcpChannelPoolOptions.newBuilder() + .setMaxSize(2) + .setConcurrentStreamsLowWatermark(0) + .build(); + GcpManagedChannelOptions options = GcpManagedChannelOptions.newBuilder() + .withChannelPoolOptions(poolOptions) + .build(); + gcpChannel = + (GcpManagedChannel) + GcpManagedChannelBuilder.forDelegateBuilder(builder) + .withOptions(options) + .build(); + + assertThat(gcpChannel.getNumberOfChannels()).isEqualTo(0); + assertThat(gcpChannel.getStreamsLowWatermark()).isEqualTo(0); + + for (int i = 0; i < gcpChannel.getMaxSize() - 1; i++) { + gcpChannel.getChannelRef(null); + } + + assertThat(gcpChannel.getNumberOfChannels()).isEqualTo(gcpChannel.getMaxSize() - 1); + + Runnable requestChannel = () -> gcpChannel.getChannelRef(null); + + int requestCount = gcpChannel.getMaxSize() * 3; + ExecutorService exec = Executors.newFixedThreadPool(requestCount); + for (int i = 0; i < requestCount; i++) { + exec.execute(requestChannel); + } + exec.shutdown(); + exec.awaitTermination(100, TimeUnit.MILLISECONDS); + + assertThat(gcpChannel.getNumberOfChannels()).isEqualTo(gcpChannel.getMaxSize()); + } + + @Test + public void testParallelGetChannelRefWontExceedMaxSizeFromTheStart() throws InterruptedException { + resetGcpChannel(); + GcpChannelPoolOptions poolOptions = GcpChannelPoolOptions.newBuilder() + .setMaxSize(2) + .setConcurrentStreamsLowWatermark(0) + .build(); + GcpManagedChannelOptions options = GcpManagedChannelOptions.newBuilder() + .withChannelPoolOptions(poolOptions) + .build(); + gcpChannel = + (GcpManagedChannel) + 
GcpManagedChannelBuilder.forDelegateBuilder(builder) + .withOptions(options) + .build(); + + assertThat(gcpChannel.getNumberOfChannels()).isEqualTo(0); + assertThat(gcpChannel.getStreamsLowWatermark()).isEqualTo(0); + + Runnable requestChannel = () -> gcpChannel.getChannelRef(null); + + int requestCount = gcpChannel.getMaxSize() * 3; + ExecutorService exec = Executors.newFixedThreadPool(requestCount); + for (int i = 0; i < requestCount; i++) { + exec.execute(requestChannel); + } + exec.shutdown(); + exec.awaitTermination(100, TimeUnit.MILLISECONDS); + + assertThat(gcpChannel.getNumberOfChannels()).isEqualTo(gcpChannel.getMaxSize()); + } + static class FakeManagedChannel extends ManagedChannel { private ConnectivityState state = ConnectivityState.IDLE; private Runnable stateCallback; diff --git a/grpc-gcp/src/test/java/com/google/cloud/grpc/SpannerIntegrationTest.java b/grpc-gcp/src/test/java/com/google/cloud/grpc/SpannerIntegrationTest.java index 0d3d3414..780c6f58 100644 --- a/grpc-gcp/src/test/java/com/google/cloud/grpc/SpannerIntegrationTest.java +++ b/grpc-gcp/src/test/java/com/google/cloud/grpc/SpannerIntegrationTest.java @@ -16,12 +16,29 @@ package com.google.cloud.grpc; +import static com.google.cloud.grpc.GcpMultiEndpointChannel.ME_KEY; +import static com.google.cloud.spanner.SpannerOptions.CALL_CONTEXT_CONFIGURATOR_KEY; import static com.google.common.base.Preconditions.checkState; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import com.google.api.core.ApiFunction; +import com.google.api.gax.grpc.GrpcCallContext; +import com.google.api.gax.grpc.GrpcTransportChannel; import com.google.api.gax.longrunning.OperationFuture; +import com.google.api.gax.rpc.ApiCallContext; +import com.google.api.gax.rpc.FixedTransportChannelProvider; +import com.google.api.gax.rpc.TransportChannelProvider; import 
com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.grpc.GcpManagedChannel.ChannelRef; +import com.google.cloud.grpc.GcpManagedChannelOptions.GcpChannelPoolOptions; +import com.google.cloud.grpc.GcpManagedChannelOptions.GcpMetricsOptions; +import com.google.cloud.grpc.MetricRegistryTestUtils.FakeMetricRegistry; +import com.google.cloud.grpc.MetricRegistryTestUtils.MetricsRecord; +import com.google.cloud.grpc.MetricRegistryTestUtils.PointWithFunction; +import com.google.cloud.grpc.proto.ApiConfig; import com.google.cloud.spanner.Database; import com.google.cloud.spanner.DatabaseAdminClient; import com.google.cloud.spanner.DatabaseClient; @@ -32,13 +49,17 @@ import com.google.cloud.spanner.InstanceConfigId; import com.google.cloud.spanner.InstanceId; import com.google.cloud.spanner.InstanceInfo; +import com.google.cloud.spanner.KeySet; import com.google.cloud.spanner.Mutation; import com.google.cloud.spanner.Spanner; import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.SpannerOptions; +import com.google.cloud.spanner.SpannerOptions.CallContextConfigurator; import com.google.common.collect.Iterators; import com.google.common.util.concurrent.ListenableFuture; import com.google.protobuf.Empty; +import com.google.protobuf.util.JsonFormat; +import com.google.protobuf.util.JsonFormat.Parser; import com.google.spanner.admin.database.v1.CreateDatabaseMetadata; import com.google.spanner.admin.instance.v1.CreateInstanceMetadata; import com.google.spanner.v1.BatchCreateSessionsRequest; @@ -62,14 +83,24 @@ import com.google.spanner.v1.SpannerGrpc.SpannerFutureStub; import com.google.spanner.v1.SpannerGrpc.SpannerStub; import com.google.spanner.v1.TransactionOptions; +import com.google.spanner.v1.TransactionOptions.ReadOnly; +import com.google.spanner.v1.TransactionOptions.ReadWrite; import com.google.spanner.v1.TransactionSelector; +import io.grpc.CallOptions; import io.grpc.ConnectivityState; +import io.grpc.Context; import 
io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.grpc.MethodDescriptor; import io.grpc.StatusRuntimeException; import io.grpc.auth.MoreCallCredentials; import io.grpc.stub.StreamObserver; +import io.opencensus.metrics.LabelValue; import java.io.File; +import java.io.IOException; +import java.io.Reader; +import java.nio.file.Files; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -77,13 +108,14 @@ import java.util.List; import java.util.Set; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import java.util.function.Function; import java.util.logging.Handler; import java.util.logging.Level; import java.util.logging.LogRecord; import java.util.logging.Logger; +import java.util.stream.Collectors; +import javax.annotation.Nullable; import org.junit.After; import org.junit.AfterClass; import org.junit.Assume; @@ -114,7 +146,7 @@ public final class SpannerIntegrationTest { private static final int MAX_CHANNEL = 3; private static final int MAX_STREAM = 2; - private static final ManagedChannelBuilder builder = + private static final ManagedChannelBuilder builder = ManagedChannelBuilder.forAddress(SPANNER_TARGET, 443); private GcpManagedChannel gcpChannel; private GcpManagedChannel gcpChannelBRR; @@ -149,7 +181,7 @@ private Level lastLogLevel() { private final Handler testLogHandler = new Handler() { @Override - public void publish(LogRecord record) { + public synchronized void publish(LogRecord record) { logRecords.add(record); } @@ -259,7 +291,7 @@ private SpannerBlockingStub getSpannerBlockingStub() { return stub; } - private static void deleteSession(SpannerGrpc.SpannerBlockingStub stub, Session session) { + private static void deleteSession(SpannerBlockingStub stub, Session session) { if (session != null) { 
stub.deleteSession(DeleteSessionRequest.newBuilder().setName(session.getName()).build()); } @@ -316,7 +348,7 @@ private List createAsyncSessions(SpannerStub stub) throws Exception { // Check CreateSession with multiple channels and streams, CreateSessionRequest req = CreateSessionRequest.newBuilder().setDatabase(DATABASE_PATH).build(); for (int i = 0; i < MAX_CHANNEL * MAX_STREAM; i++) { - AsyncResponseObserver resp = new AsyncResponseObserver(); + AsyncResponseObserver resp = new AsyncResponseObserver<>(); stub.createSession(req, resp); resps.add(resp); } @@ -337,7 +369,7 @@ private void deleteAsyncSessions(SpannerStub stub, List respNames) throw AsyncResponseObserver resp = new AsyncResponseObserver<>(); stub.deleteSession(DeleteSessionRequest.newBuilder().setName(respName).build(), resp); // The ChannelRef which is bound with the current affinity key. - GcpManagedChannel.ChannelRef currentChannel = + ChannelRef currentChannel = gcpChannel.affinityKeyToChannelRef.get(respName); // Verify the channel is in use. assertEquals(1, currentChannel.getActiveStreamsCount()); @@ -387,7 +419,7 @@ private void deleteFutureSessions(SpannerFutureStub stub, List futureNam ListenableFuture future = stub.deleteSession(DeleteSessionRequest.newBuilder().setName(futureName).build()); // The ChannelRef which is bound with the current affinity key. - GcpManagedChannel.ChannelRef currentChannel = + ChannelRef currentChannel = gcpChannel.affinityKeyToChannelRef.get(futureName); // Verify the channel is in use. 
assertEquals(1, currentChannel.getActiveStreamsCount()); @@ -416,7 +448,7 @@ public void setupChannels() { .withApiConfigJsonFile(configFile) .withOptions(GcpManagedChannelOptions.newBuilder() .withChannelPoolOptions( - GcpManagedChannelOptions.GcpChannelPoolOptions.newBuilder() + GcpChannelPoolOptions.newBuilder() .setMaxSize(MAX_CHANNEL) .setConcurrentStreamsLowWatermark(MAX_STREAM) .setUseRoundRobinOnBind(true) @@ -429,12 +461,268 @@ public void setupChannels() { public void shutdownChannels() { testLogger.removeHandler(testLogHandler); testLogger.setLevel(Level.INFO); + logRecords.clear(); gcpChannel.shutdownNow(); gcpChannelBRR.shutdownNow(); } + private long getOkCallsCount( + FakeMetricRegistry fakeRegistry, String endpoint) { + MetricsRecord record = fakeRegistry.pollRecord(); + List> metric = + record.getMetrics().get(GcpMetricsConstants.METRIC_NUM_CALLS_COMPLETED); + for (PointWithFunction m : metric) { + assertThat(m.keys().get(0).getKey()).isEqualTo("result"); + assertThat(m.keys().get(1).getKey()).isEqualTo("endpoint"); + if (!m.values().get(0).equals(LabelValue.create(GcpMetricsConstants.RESULT_SUCCESS))) { + continue; + } + if (!m.values().get(1).equals(LabelValue.create(endpoint))) { + continue; + } + return m.value(); + } + fail("Success calls metric is not found for endpoint: " + endpoint); + return 0; + } + + // For this test we'll create a Spanner client with gRPC-GCP MultiEndpoint feature. + // + // Imagine we have a multi-region Spanner instance with leader in the us-east4 and follower in the + // us-east1 regions. + // + // We will provide two MultiEndpoint configs: "leader" (having leader region endpoint first and + // follower second) and "follower" (having follower region endpoint first and leader second). + // + // Then we'll make sure the Spanner client uses leader MultiEndpoint as a default one and creates + // its sessions there. Then we'll make sure a read request will also use the leader MultiEndpoint + // by default. 
+ // + // Then we'll verify we can use the follower MultiEndpoint when needed by specifying that in + // the Spanner context. + // + // Then we'll update MultiEndpoints configuration by replacing the leader endpoint and renaming + // the follower MultiEndpoint. And make sure the new leader endpoint and the previous follower + // endpoint are still working as expected when using different MultiEndpoints. + @Test + public void testSpannerMultiEndpointClient() throws IOException, InterruptedException { + // Watch debug messages. + testLogger.setLevel(Level.FINEST); + + final FakeMetricRegistry fakeRegistry = new FakeMetricRegistry(); + + File configFile = + new File(SpannerIntegrationTest.class.getClassLoader().getResource(API_FILE).getFile()); + + // Leader-first multi-endpoint endpoints. + final List leaderEndpoints = new ArrayList<>(); + // Follower-first multi-endpoint endpoints. + final List followerEndpoints = new ArrayList<>(); + final String leaderEndpoint = "us-east4.googleapis.com:443"; + final String followerEndpoint = "us-east1.googleapis.com:443"; + leaderEndpoints.add(leaderEndpoint); + leaderEndpoints.add(followerEndpoint); + followerEndpoints.add(followerEndpoint); + followerEndpoints.add(leaderEndpoint); + + ApiFunction, ManagedChannelBuilder> configurator = input -> input.overrideAuthority( + SPANNER_TARGET); + + GcpMultiEndpointOptions leaderOpts = GcpMultiEndpointOptions.newBuilder(leaderEndpoints) + .withName("leader") + .withChannelConfigurator(configurator) + .withRecoveryTimeout(Duration.ofSeconds(3)) + .build(); + + GcpMultiEndpointOptions followerOpts = GcpMultiEndpointOptions.newBuilder(followerEndpoints) + .withName("follower") + .withChannelConfigurator(configurator) + .withRecoveryTimeout(Duration.ofSeconds(3)) + .build(); + + List opts = new ArrayList<>(); + opts.add(leaderOpts); + opts.add(followerOpts); + + Parser parser = JsonFormat.parser(); + ApiConfig.Builder apiConfig = ApiConfig.newBuilder(); + Reader reader = 
Files.newBufferedReader(configFile.toPath(), UTF_8); + parser.merge(reader, apiConfig); + + GcpMultiEndpointChannel gcpMultiEndpointChannel = new GcpMultiEndpointChannel( + opts, + apiConfig.build(), + GcpManagedChannelOptions.newBuilder() + .withChannelPoolOptions(GcpChannelPoolOptions.newBuilder() + .setConcurrentStreamsLowWatermark(0) + .setMaxSize(3) + .build()) + .withMetricsOptions(GcpMetricsOptions.newBuilder() + .withMetricRegistry(fakeRegistry) + .build()) + .build()); + + final int currentIndex = GcpManagedChannel.channelPoolIndex.get(); + final String followerPoolIndex = String.format("pool-%d", currentIndex); + final String leaderPoolIndex = String.format("pool-%d", currentIndex - 1); + + // Make sure authorities are overridden by channel configurator. + assertThat(gcpMultiEndpointChannel.authority()).isEqualTo(SPANNER_TARGET); + assertThat(gcpMultiEndpointChannel.authorityFor("leader")) + .isEqualTo(SPANNER_TARGET); + assertThat(gcpMultiEndpointChannel.authorityFor("follower")) + .isEqualTo(SPANNER_TARGET); + assertThat(gcpMultiEndpointChannel.authorityFor("no-such-name")).isNull(); + + TimeUnit.MILLISECONDS.sleep(200); + + List logMessages = logRecords.stream() + .map(LogRecord::getMessage).collect(Collectors.toList()); + + // Make sure min channels are created and connections are established right away in both pools. + for (String poolIndex : Arrays.asList(leaderPoolIndex, followerPoolIndex)) { + for (int i = 0; i < 2; i++) { + assertThat(logMessages).contains( + poolIndex + ": Channel " + i + " state change detected: null -> IDLE"); + + assertThat(logMessages).contains( + poolIndex + ": Channel " + i + " state change detected: IDLE -> CONNECTING"); + } + } + + // Make sure endpoint is set as a metric label for each pool. 
+ assertThat(logRecords.stream().filter(logRecord -> + logRecord.getMessage().matches( + leaderPoolIndex + ": Metrics options: \\{namePrefix: \"\", labels: \\[endpoint: " + + "\"" + leaderEndpoint + "\"], metricRegistry: .*" + )).count()).isEqualTo(1); + + assertThat(logRecords.stream().filter(logRecord -> + logRecord.getMessage().matches( + followerPoolIndex + ": Metrics options: \\{namePrefix: \"\", labels: \\[endpoint: " + + "\"" + followerEndpoint + "\"], metricRegistry: .*" + )).count()).isEqualTo(1); + + logRecords.clear(); + + TransportChannelProvider channelProvider = FixedTransportChannelProvider.create( + GrpcTransportChannel.create(gcpMultiEndpointChannel)); + + SpannerOptions.Builder options = SpannerOptions.newBuilder().setProjectId(GCP_PROJECT_ID); + + options.setChannelProvider(channelProvider); + // Match channel pool size. + options.setNumChannels(3); + + Spanner spanner = options.build().getService(); + InstanceId instanceId = InstanceId.of(GCP_PROJECT_ID, INSTANCE_ID); + DatabaseId databaseId = DatabaseId.of(instanceId, DB_NAME); + DatabaseClient databaseClient = spanner.getDatabaseClient(databaseId); + + Runnable readQuery = () -> { + try (com.google.cloud.spanner.ResultSet read = databaseClient.singleUse() + .read("Users", KeySet.all(), Arrays.asList("UserId", "UserName"))) { + int readRows = 0; + while (read.next()) { + readRows++; + assertEquals(USERNAME, read.getCurrentRowAsStruct().getString("UserName")); + } + assertEquals(1, readRows); + } + }; + + // Make sure leader endpoint is used by default. + assertThat(getOkCallsCount(fakeRegistry, leaderEndpoint)).isEqualTo(0); + readQuery.run(); + + // Wait for sessions creation requests to be completed but no more than 10 seconds. + for (int i = 0; i < 20; i++) { + TimeUnit.MILLISECONDS.sleep(500); + if (getOkCallsCount(fakeRegistry, leaderEndpoint) == 4) { + break; + } + } + + // 3 session creation requests + 1 our read request to the leader endpoint. 
+ assertThat(getOkCallsCount(fakeRegistry, leaderEndpoint)).isEqualTo(4); + + // Make sure there were 3 session creation requests in the leader pool only. + assertThat(logRecords.stream().filter(logRecord -> + logRecord.getMessage().matches( + leaderPoolIndex + ": Binding \\d+ key\\(s\\) to channel \\d:.*" + )).count()).isEqualTo(3); + + assertThat(logRecords.stream().filter(logRecord -> + logRecord.getMessage().matches( + followerPoolIndex + ": Binding \\d+ key\\(s\\) to channel \\d:.*" + )).count()).isEqualTo(0); + + // Create context for using follower-first multi-endpoint. + Function contextFor = meName -> Context.current() + .withValue(CALL_CONTEXT_CONFIGURATOR_KEY, + new CallContextConfigurator() { + @Nullable + @Override + public ApiCallContext configure(ApiCallContext context, ReqT request, + MethodDescriptor method) { + return context.merge(GrpcCallContext.createDefault().withCallOptions( + CallOptions.DEFAULT.withOption(ME_KEY, meName))); + } + }); + + assertThat(getOkCallsCount(fakeRegistry, followerEndpoint)).isEqualTo(0); + // Use follower, make sure it is used. + contextFor.apply("follower").run(readQuery); + assertThat(getOkCallsCount(fakeRegistry, followerEndpoint)).isEqualTo(1); + + // Replace leader endpoints. + final String newLeaderEndpoint = "us-west1.googleapis.com:443"; + leaderEndpoints.clear(); + leaderEndpoints.add(newLeaderEndpoint); + leaderEndpoints.add(followerEndpoint); + leaderOpts = GcpMultiEndpointOptions.newBuilder(leaderEndpoints) + .withName("leader") + .withChannelConfigurator(configurator) + .build(); + + followerEndpoints.clear(); + followerEndpoints.add(followerEndpoint); + followerEndpoints.add(newLeaderEndpoint); + + // Rename follower MultiEndpoint. 
+ followerOpts = GcpMultiEndpointOptions.newBuilder(followerEndpoints) + .withName("follower-2") + .withChannelConfigurator(configurator) + .build(); + + opts.clear(); + opts.add(leaderOpts); + opts.add(followerOpts); + + gcpMultiEndpointChannel.setMultiEndpoints(opts); + + // As it takes some time to connect to the new leader endpoint, RPC will fall back to the + // follower until we connect to leader. + assertThat(getOkCallsCount(fakeRegistry, followerEndpoint)).isEqualTo(1); + readQuery.run(); + assertThat(getOkCallsCount(fakeRegistry, followerEndpoint)).isEqualTo(2); + + TimeUnit.MILLISECONDS.sleep(500); + + // Make sure the new leader endpoint is used by default after it is connected. + assertThat(getOkCallsCount(fakeRegistry, newLeaderEndpoint)).isEqualTo(0); + readQuery.run(); + assertThat(getOkCallsCount(fakeRegistry, newLeaderEndpoint)).isEqualTo(1); + + // Make sure that the follower endpoint still works if specified. + assertThat(getOkCallsCount(fakeRegistry, followerEndpoint)).isEqualTo(2); + // Use follower, make sure it is used. + contextFor.apply("follower-2").run(readQuery); + assertThat(getOkCallsCount(fakeRegistry, followerEndpoint)).isEqualTo(3); + } + @Test - public void testCreateAndGetSessionBlocking() throws Exception { + public void testCreateAndGetSessionBlocking() { SpannerBlockingStub stub = getSpannerBlockingStub(); CreateSessionRequest req = CreateSessionRequest.newBuilder().setDatabase(DATABASE_PATH).build(); // The first MAX_CHANNEL requests (without affinity) should be distributed 1 per channel. 
@@ -457,7 +745,7 @@ public void testCreateAndGetSessionBlocking() throws Exception { } @Test - public void testBatchCreateSessionsBlocking() throws Exception { + public void testBatchCreateSessionsBlocking() { int sessionCount = 10; SpannerBlockingStub stub = getSpannerBlockingStub(); BatchCreateSessionsRequest req = @@ -513,7 +801,7 @@ public void testSessionsCreatedUsingRoundRobin() throws Exception { .setSql("select * FROM Users") .build()); // The ChannelRef which is bound with the lastSession. - GcpManagedChannel.ChannelRef currentChannel = + ChannelRef currentChannel = gcpChannelBRR.affinityKeyToChannelRef.get(lastSession); // Verify the channel is in use. assertEquals(1, currentChannel.getActiveStreamsCount()); @@ -581,7 +869,7 @@ public void testSessionsCreatedWithoutRoundRobin() throws Exception { .setSql("select * FROM Users") .build()); // The ChannelRef which is bound with the lastSession. - GcpManagedChannel.ChannelRef currentChannel = + ChannelRef currentChannel = gcpChannel.affinityKeyToChannelRef.get(lastSession); // Verify the channel is in use. assertEquals(1, currentChannel.getActiveStreamsCount()); @@ -638,7 +926,7 @@ public void testExecuteSqlFuture() throws Exception { .setSql("select * FROM Users") .build()); // The ChannelRef which is bound with the current affinity key. - GcpManagedChannel.ChannelRef currentChannel = + ChannelRef currentChannel = gcpChannel.affinityKeyToChannelRef.get(futureName); // Verify the channel is in use. assertEquals(1, currentChannel.getActiveStreamsCount()); @@ -660,7 +948,7 @@ public void testExecuteStreamingSqlAsync() throws Exception { ExecuteSqlRequest.newBuilder().setSession(respName).setSql("select * FROM Users").build(), resp); // The ChannelRef which is bound with the current affinity key. - GcpManagedChannel.ChannelRef currentChannel = + ChannelRef currentChannel = gcpChannel.affinityKeyToChannelRef.get(respName); // Verify the channel is in use. 
assertEquals(1, currentChannel.getActiveStreamsCount()); @@ -678,7 +966,7 @@ public void testPartitionQueryAsync() throws Exception { for (String respName : respNames) { TransactionOptions options = TransactionOptions.newBuilder() - .setReadOnly(TransactionOptions.ReadOnly.getDefaultInstance()) + .setReadOnly(ReadOnly.getDefaultInstance()) .build(); TransactionSelector selector = TransactionSelector.newBuilder().setBegin(options).build(); AsyncResponseObserver resp = new AsyncResponseObserver<>(); @@ -690,7 +978,7 @@ public void testPartitionQueryAsync() throws Exception { .build(), resp); // The ChannelRef which is bound with the current affinity key. - GcpManagedChannel.ChannelRef currentChannel = + ChannelRef currentChannel = gcpChannel.affinityKeyToChannelRef.get(respName); // Verify the channel is in use. assertEquals(1, currentChannel.getActiveStreamsCount()); @@ -707,7 +995,7 @@ public void testExecuteBatchDmlFuture() throws Exception { for (String futureName : futureNames) { TransactionOptions options = TransactionOptions.newBuilder() - .setReadWrite(TransactionOptions.ReadWrite.getDefaultInstance()) + .setReadWrite(ReadWrite.getDefaultInstance()) .build(); TransactionSelector selector = TransactionSelector.newBuilder().setBegin(options).build(); // Will use only one session for the whole batch. @@ -719,7 +1007,7 @@ public void testExecuteBatchDmlFuture() throws Exception { .addStatements(Statement.newBuilder().setSql("select * FROM Users").build()) .build()); // The ChannelRef which is bound with the current affinity key. - GcpManagedChannel.ChannelRef currentChannel = + ChannelRef currentChannel = gcpChannel.affinityKeyToChannelRef.get(futureName); // Verify the channel is in use. 
assertEquals(1, currentChannel.getActiveStreamsCount()); @@ -769,7 +1057,7 @@ private static class AsyncResponseObserver implements StreamObserver threeEndpoints = + new ArrayList<>(ImmutableList.of("first", "second", "third")); + + private final List fourEndpoints = + new ArrayList<>(ImmutableList.of("four", "first", "third", "second")); + + private static final long RECOVERY_MS = 1000; + + @Rule + public ExpectedException expectedEx = ExpectedException.none(); + + private MultiEndpoint initPlain(List endpoints) { + return new MultiEndpoint.Builder(endpoints).build(); + } + + private MultiEndpoint initWithRecovery(List endpoints, long recoveryTimeOut) { + return new MultiEndpoint.Builder(endpoints) + .withRecoveryTimeout(Duration.ofMillis(recoveryTimeOut)) + .build(); + } + + @Test + public void initPlain_raisesErrorWhenEmptyEndpoints() { + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage("Endpoints list must not be empty."); + initPlain(ImmutableList.of()); + } + + @Test + public void initWithRecovery_raisesErrorWhenEmptyEndpoints() { + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage("Endpoints list must not be empty."); + initWithRecovery(ImmutableList.of(), RECOVERY_MS); + } + + @Test + public void getCurrent_returnsTopPriorityAvailableEndpointWithoutRecovery() { + MultiEndpoint multiEndpoint = initPlain(threeEndpoints); + + // Returns first after creation. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(0)); + + // Second becomes available. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(1), true); + + // Second is the current as the only available. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(1)); + + // Third becomes available. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(2), true); + + // Second is still the current because it has higher priority. 
+ assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(1)); + + // First becomes available. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(0), true); + + // First becomes the current because it has higher priority. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(0)); + + // Second becomes unavailable. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(1), false); + + // Second becoming unavailable should not affect the current first. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(0)); + + // First becomes unavailable. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(0), false); + + // Third becomes the current as the only remaining available. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(2)); + + // Third becomes unavailable. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(2), false); + + // After all endpoints became unavailable the multiEndpoint sticks to the last used endpoint. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(2)); + } + + @Test + public void getCurrent_returnsTopPriorityAvailableEndpointWithRecovery() + throws InterruptedException { + MultiEndpoint multiEndpoint = initWithRecovery(threeEndpoints, RECOVERY_MS); + + // Returns first after creation. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(0)); + + // Second becomes available. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(1), true); + + // First is still the current to allow it to become available within recovery timeout. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(0)); + + // After recovery timeout has passed. + Sleeper.DEFAULT.sleep(RECOVERY_MS + 100); + + // Second becomes current as an available endpoint with top priority. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(1)); + + // Third becomes available. 
+ multiEndpoint.setEndpointAvailable(threeEndpoints.get(2), true); + + // Second is still the current because it has higher priority. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(1)); + + // Second becomes unavailable. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(1), false); + + // Second is still current, allowing upto recoveryTimeout to recover. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(1)); + + // Halfway through recovery timeout the second recovers. + Sleeper.DEFAULT.sleep(RECOVERY_MS / 2); + multiEndpoint.setEndpointAvailable(threeEndpoints.get(1), true); + + // Second is the current. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(1)); + + // After the initial recovery timeout, the second is still current. + Sleeper.DEFAULT.sleep(RECOVERY_MS / 2 + 100); + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(1)); + + // Second becomes unavailable. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(1), false); + + // After recovery timeout has passed. + Sleeper.DEFAULT.sleep(RECOVERY_MS + 100); + + // Changes to an available endpoint -- third. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(2)); + + // First becomes available. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(0), true); + + // First becomes current immediately. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(0)); + + // First becomes unavailable. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(0), false); + + // First is still current, allowing upto recoveryTimeout to recover. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(0)); + + // After recovery timeout has passed. + Sleeper.DEFAULT.sleep(RECOVERY_MS + 100); + + // Changes to an available endpoint -- third. 
+ assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(2)); + + // Third becomes unavailable + multiEndpoint.setEndpointAvailable(threeEndpoints.get(2), false); + + // Third is still current, allowing upto recoveryTimeout to recover. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(2)); + + // Halfway through recovery timeout the second becomes available. + // Sleeper.defaultSleeper().sleep(Duration.ofMillis(RECOVERY_MS - 100)); + multiEndpoint.setEndpointAvailable(threeEndpoints.get(1), true); + + // Second becomes current immediately. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(1)); + + // Second becomes unavailable. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(1), false); + + // Second is still current, allowing upto recoveryTimeout to recover. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(1)); + + // After recovery timeout has passed. + Sleeper.DEFAULT.sleep(RECOVERY_MS + 100); + + // After all endpoints became unavailable the multiEndpoint sticks to the last used endpoint. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(threeEndpoints.get(1)); + } + + @Test + public void setEndpoints_raisesErrorWhenEmptyEndpoints() { + MultiEndpoint multiEndpoint = initPlain(threeEndpoints); + expectedEx.expect(IllegalArgumentException.class); + multiEndpoint.setEndpoints(ImmutableList.of()); + } + + @Test + public void setEndpoints_updatesEndpoints() { + MultiEndpoint multiEndpoint = initPlain(threeEndpoints); + multiEndpoint.setEndpoints(fourEndpoints); + + // "first" which is now under index 1 still current because no other available. 
+ assertThat(multiEndpoint.getCurrentId()).isEqualTo(fourEndpoints.get(1)); + } + + @Test + public void setEndpoints_updatesEndpointsWithRecovery() { + MultiEndpoint multiEndpoint = initWithRecovery(threeEndpoints, RECOVERY_MS); + multiEndpoint.setEndpoints(fourEndpoints); + + // "first" which is now under index 1 still current because no other available. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(fourEndpoints.get(1)); + } + + @Test + public void setEndpoints_updatesEndpointsPreservingStates() { + MultiEndpoint multiEndpoint = initPlain(threeEndpoints); + + // Second is available. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(1), true); + multiEndpoint.setEndpoints(fourEndpoints); + + // "second" which is now under index 3 still must remain available. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(fourEndpoints.get(3)); + } + + @Test + public void setEndpoints_updatesEndpointsPreservingStatesWithRecovery() + throws InterruptedException { + MultiEndpoint multiEndpoint = initWithRecovery(threeEndpoints, RECOVERY_MS); + + // After recovery timeout has passed. + Sleeper.DEFAULT.sleep(RECOVERY_MS + 100); + + // Second is available. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(1), true); + multiEndpoint.setEndpoints(fourEndpoints); + + // "second" which is now under index 3 still must remain available. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(fourEndpoints.get(3)); + } + + @Test + public void setEndpoints_updatesEndpointsSwitchToTopPriorityAvailable() { + MultiEndpoint multiEndpoint = initPlain(threeEndpoints); + + // Second and third is available. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(1), true); + multiEndpoint.setEndpointAvailable(threeEndpoints.get(2), true); + + multiEndpoint.setEndpoints(fourEndpoints); + + // "third" which is now under index 2 must become current, because "second" has lower priority. 
+ assertThat(multiEndpoint.getCurrentId()).isEqualTo(fourEndpoints.get(2)); + } + + @Test + public void setEndpoints_updatesEndpointsSwitchToTopPriorityAvailableWithRecovery() + throws InterruptedException { + MultiEndpoint multiEndpoint = initWithRecovery(threeEndpoints, RECOVERY_MS); + + // After recovery timeout has passed. + Sleeper.DEFAULT.sleep(RECOVERY_MS + 100); + + // Second and third is available. + multiEndpoint.setEndpointAvailable(threeEndpoints.get(1), true); + multiEndpoint.setEndpointAvailable(threeEndpoints.get(2), true); + + multiEndpoint.setEndpoints(fourEndpoints); + + // "third" which is now under index 2 must become current, because "second" has lower priority. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(fourEndpoints.get(2)); + } + + @Test + public void setEndpoints_updatesEndpointsRemovesOnlyActiveEndpoint() { + List extraEndpoints = new ArrayList<>(threeEndpoints); + extraEndpoints.add("extra"); + MultiEndpoint multiEndpoint = initPlain(extraEndpoints); + + // Extra is available. + multiEndpoint.setEndpointAvailable("extra", true); + + // Extra is removed. + multiEndpoint.setEndpoints(fourEndpoints); + + // "four" which is under index 0 must become current, because no endpoints available. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(fourEndpoints.get(0)); + } + + @Test + public void setEndpoints_updatesEndpointsRemovesOnlyActiveEndpointWithRecovery() + throws InterruptedException { + List extraEndpoints = new ArrayList<>(threeEndpoints); + extraEndpoints.add("extra"); + MultiEndpoint multiEndpoint = initWithRecovery(extraEndpoints, RECOVERY_MS); + + // After recovery timeout has passed. + Sleeper.DEFAULT.sleep(RECOVERY_MS + 100); + + // Extra is available. + multiEndpoint.setEndpointAvailable("extra", true); + + // Extra is removed. + multiEndpoint.setEndpoints(fourEndpoints); + + // "four" which is under index 0 must become current, because no endpoints available. 
+ assertThat(multiEndpoint.getCurrentId()).isEqualTo(fourEndpoints.get(0)); + } + + @Test + public void setEndpoints_recoveringEndpointGetsRemoved() throws InterruptedException { + List extraEndpoints = new ArrayList<>(threeEndpoints); + extraEndpoints.add("extra"); + MultiEndpoint multiEndpoint = initWithRecovery(extraEndpoints, RECOVERY_MS); + + // After recovery timeout has passed. + Sleeper.DEFAULT.sleep(RECOVERY_MS + 100); + + // Extra is available. + multiEndpoint.setEndpointAvailable("extra", true); + + // Extra is recovering. + multiEndpoint.setEndpointAvailable("extra", false); + + // Extra is removed. + multiEndpoint.setEndpoints(fourEndpoints); + + // "four" which is under index 0 must become current, because no endpoints available. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(fourEndpoints.get(0)); + + // After recovery timeout has passed. + Sleeper.DEFAULT.sleep(RECOVERY_MS + 100); + + // "four" is still current. + assertThat(multiEndpoint.getCurrentId()).isEqualTo(fourEndpoints.get(0)); + } +}