diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java index 4c3fba4c1..027d1d75d 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/Config.java @@ -15,6 +15,7 @@ public final class Config { static final int DEFAULT_DEADLINE = 500; static final int DEFAULT_STREAM_DEADLINE_MS = 10 * 60 * 1000; + static final int DEFAULT_STREAM_RETRY_GRACE_PERIOD = 5; static final int DEFAULT_MAX_CACHE_SIZE = 1000; static final long DEFAULT_KEEP_ALIVE = 0; @@ -35,6 +36,7 @@ public final class Config { static final String KEEP_ALIVE_MS_ENV_VAR_NAME_OLD = "FLAGD_KEEP_ALIVE_TIME"; static final String KEEP_ALIVE_MS_ENV_VAR_NAME = "FLAGD_KEEP_ALIVE_TIME_MS"; static final String TARGET_URI_ENV_VAR_NAME = "FLAGD_TARGET_URI"; + static final String STREAM_RETRY_GRACE_PERIOD = "FLAGD_RETRY_GRACE_PERIOD"; static final String RESOLVER_RPC = "rpc"; static final String RESOLVER_IN_PROCESS = "in-process"; @@ -52,7 +54,6 @@ public final class Config { public static final String LRU_CACHE = CacheType.LRU.getValue(); static final String DEFAULT_CACHE = LRU_CACHE; - static final int DEFAULT_MAX_EVENT_STREAM_RETRIES = 5; static final int BASE_EVENT_STREAM_RETRY_BACKOFF_MS = 1000; static String fallBackToEnvOrDefault(String key, String defaultValue) { diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java index 4ba459e4d..c98effac2 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java @@ -13,82 +13,111 @@ import lombok.Builder; import lombok.Getter; -/** FlagdOptions is a builder to build flagd provider options. */ +/** + * FlagdOptions is a builder to build flagd provider options. + */ @Builder @Getter @SuppressWarnings("PMD.TooManyStaticImports") public class FlagdOptions { - /** flagd resolving type. */ + /** + * flagd resolving type. + */ private Config.EvaluatorType resolverType; - /** flagd connection host. */ + /** + * flagd connection host. + */ @Builder.Default private String host = fallBackToEnvOrDefault(Config.HOST_ENV_VAR_NAME, Config.DEFAULT_HOST); - /** flagd connection port. */ + /** + * flagd connection port. + */ private int port; - /** Use TLS connectivity. */ + /** + * Use TLS connectivity. + */ @Builder.Default private boolean tls = Boolean.parseBoolean(fallBackToEnvOrDefault(Config.TLS_ENV_VAR_NAME, Config.DEFAULT_TLS)); - /** TLS certificate overriding if TLS connectivity is used. */ + /** + * TLS certificate overriding if TLS connectivity is used. + */ @Builder.Default private String certPath = fallBackToEnvOrDefault(Config.SERVER_CERT_PATH_ENV_VAR_NAME, null); - /** Unix socket path to flagd. */ + /** + * Unix socket path to flagd. + */ @Builder.Default private String socketPath = fallBackToEnvOrDefault(Config.SOCKET_PATH_ENV_VAR_NAME, null); - /** Cache type to use. Supports - lru, disabled. */ + /** + * Cache type to use. Supports - lru, disabled. + */ @Builder.Default private String cacheType = fallBackToEnvOrDefault(Config.CACHE_ENV_VAR_NAME, Config.DEFAULT_CACHE); - /** Max cache size. */ + /** + * Max cache size. + */ @Builder.Default private int maxCacheSize = fallBackToEnvOrDefault(Config.MAX_CACHE_SIZE_ENV_VAR_NAME, Config.DEFAULT_MAX_CACHE_SIZE); - /** Max event stream connection retries. */ - @Builder.Default - private int maxEventStreamRetries = fallBackToEnvOrDefault( - Config.MAX_EVENT_STREAM_RETRIES_ENV_VAR_NAME, Config.DEFAULT_MAX_EVENT_STREAM_RETRIES); - - /** Backoff interval in milliseconds. */ + /** + * Backoff interval in milliseconds. + */ @Builder.Default private int retryBackoffMs = fallBackToEnvOrDefault( Config.BASE_EVENT_STREAM_RETRY_BACKOFF_MS_ENV_VAR_NAME, Config.BASE_EVENT_STREAM_RETRY_BACKOFF_MS); /** - * Connection deadline in milliseconds. For RPC resolving, this is the deadline to connect to - * flagd for flag evaluation. For in-process resolving, this is the deadline for sync stream - * termination. + * Connection deadline in milliseconds. + * For RPC resolving, this is the deadline to connect to flagd for flag + * evaluation. + * For in-process resolving, this is the deadline for sync stream termination. */ @Builder.Default private int deadline = fallBackToEnvOrDefault(Config.DEADLINE_MS_ENV_VAR_NAME, Config.DEFAULT_DEADLINE); /** - * Streaming connection deadline in milliseconds. Set to 0 to disable the deadline. Defaults to - * 600000 (10 minutes); recommended to prevent infrastructure from killing idle connections. + * Streaming connection deadline in milliseconds. + * Set to 0 to disable the deadline. + * Defaults to 600000 (10 minutes); recommended to prevent infrastructure from killing idle connections. */ @Builder.Default private int streamDeadlineMs = fallBackToEnvOrDefault(Config.STREAM_DEADLINE_MS_ENV_VAR_NAME, Config.DEFAULT_STREAM_DEADLINE_MS); - /** Selector to be used with flag sync gRPC contract. */ + /** + * Grace time period in seconds before provider moves from STALE to ERROR. + * Defaults to 5 + */ + @Builder.Default + private int retryGracePeriod = + fallBackToEnvOrDefault(Config.STREAM_RETRY_GRACE_PERIOD, Config.DEFAULT_STREAM_RETRY_GRACE_PERIOD); + /** + * Selector to be used with flag sync gRPC contract. + **/ @Builder.Default private String selector = fallBackToEnvOrDefault(Config.SOURCE_SELECTOR_ENV_VAR_NAME, null); - /** gRPC client KeepAlive in milliseconds. Disabled with 0. Defaults to 0 (disabled). */ + /** + * gRPC client KeepAlive in milliseconds. Disabled with 0. + * Defaults to 0 (disabled). + **/ @Builder.Default private long keepAlive = fallBackToEnvOrDefault( Config.KEEP_ALIVE_MS_ENV_VAR_NAME, fallBackToEnvOrDefault(Config.KEEP_ALIVE_MS_ENV_VAR_NAME_OLD, Config.DEFAULT_KEEP_ALIVE)); /** - * File source of flags to be used by offline mode. Setting this enables the offline mode of the - * in-process provider. + * File source of flags to be used by offline mode. + * Setting this enables the offline mode of the in-process provider. */ @Builder.Default private String offlineFlagSourcePath = fallBackToEnvOrDefault(Config.OFFLINE_SOURCE_PATH, null); @@ -96,31 +125,35 @@ public class FlagdOptions { /** * gRPC custom target string. * - *

Setting this will allow user to use custom gRPC name resolver at present we are supporting - * all core resolver along with a custom resolver for envoy proxy resolution. For more visit - * (https://grpc.io/docs/guides/custom-name-resolution/) + *

Setting this will allow user to use custom gRPC name resolver at present + * we are supporting all core resolver along with a custom resolver for envoy proxy + * resolution. For more visit (https://grpc.io/docs/guides/custom-name-resolution/) */ @Builder.Default private String targetUri = fallBackToEnvOrDefault(Config.TARGET_URI_ENV_VAR_NAME, null); /** - * Function providing an EvaluationContext to mix into every evaluations. The sync-metadata - * response + * Function providing an EvaluationContext to mix into every evaluations. + * The sync-metadata response * (https://buf.build/open-feature/flagd/docs/main:flagd.sync.v1#flagd.sync.v1.GetMetadataResponse), - * represented as a {@link dev.openfeature.sdk.Structure}, is passed as an argument. This function - * runs every time the provider (re)connects, and its result is cached and used in every - * evaluation. By default, the entire sync response (converted to a Structure) is used. + * represented as a {@link dev.openfeature.sdk.Structure}, is passed as an + * argument. + * This function runs every time the provider (re)connects, and its result is cached and used in every evaluation. + * By default, the entire sync response (converted to a Structure) is used. */ @Builder.Default private Function contextEnricher = (syncMetadata) -> new ImmutableContext(syncMetadata.asMap()); - /** Inject a Custom Connector for fetching flags. */ + /** + * Inject a Custom Connector for fetching flags. + */ private Connector customConnector; /** - * Inject OpenTelemetry for the library runtime. Providing sdk will initiate distributed tracing - * for flagd grpc connectivity. + * Inject OpenTelemetry for the library runtime. Providing sdk will initiate + * distributed tracing for flagd grpc + * connectivity. */ private OpenTelemetry openTelemetry; @@ -139,11 +172,14 @@ public FlagdOptions build() { }; } - /** Overload default lombok builder. */ + /** + * Overload default lombok builder. + */ public static class FlagdOptionsBuilder { /** - * Enable OpenTelemetry instance extraction from GlobalOpenTelemetry. Note that, this is only - * useful if global configurations are registered. + * Enable OpenTelemetry instance extraction from GlobalOpenTelemetry. Note that, + * this is only useful if global + * configurations are registered. */ public FlagdOptionsBuilder withGlobalTelemetry(final boolean b) { if (b) { diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java index 5f3d8a361..1e9c30882 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdProvider.java @@ -21,7 +21,9 @@ import java.util.function.Function; import lombok.extern.slf4j.Slf4j; -/** OpenFeature provider for flagd. */ +/** + * OpenFeature provider for flagd. + */ @Slf4j @SuppressWarnings({"PMD.TooManyStaticImports", "checkstyle:NoFinalizer"}) public class FlagdProvider extends EventProvider { @@ -38,7 +40,9 @@ protected final void finalize() { // DO NOT REMOVE, spotbugs: CT_CONSTRUCTOR_THROW } - /** Create a new FlagdProvider instance with default options. */ + /** + * Create a new FlagdProvider instance with default options. + */ public FlagdProvider() { this(FlagdOptions.builder().build()); } @@ -55,10 +59,7 @@ public FlagdProvider(final FlagdOptions options) { break; case Config.RESOLVER_RPC: this.flagResolver = new GrpcResolver( - options, - new Cache(options.getCacheType(), options.getMaxCacheSize()), - this::isConnected, - this::onConnectionEvent); + options, new Cache(options.getCacheType(), options.getMaxCacheSize()), this::onConnectionEvent); break; default: throw new IllegalStateException( @@ -80,7 +81,7 @@ public synchronized void initialize(EvaluationContext evaluationContext) throws } this.flagResolver.init(); - this.initialized = true; + this.initialized = this.connected = true; } @Override @@ -129,8 +130,10 @@ public ProviderEvaluation getObjectEvaluation(String key, Value defaultVa } /** - * An unmodifiable view of a Structure representing the latest result of the SyncMetadata. Set on - * initial connection and updated with every reconnection. see: + * An unmodifiable view of a Structure representing the latest result of the + * SyncMetadata. + * Set on initial connection and updated with every reconnection. + * see: * https://buf.build/open-feature/flagd/docs/main:flagd.sync.v1#flagd.sync.v1.FlagSyncService.GetMetadata * * @return Object map representing sync metadata @@ -153,38 +156,42 @@ private boolean isConnected() { } private void onConnectionEvent(ConnectionEvent connectionEvent) { - boolean previous = connected; - boolean current = connected = connectionEvent.isConnected(); + final boolean wasConnected = connected; + final boolean isConnected = connected = connectionEvent.isConnected(); + syncMetadata = connectionEvent.getSyncMetadata(); enrichedContext = contextEnricher.apply(connectionEvent.getSyncMetadata()); - // configuration changed - if (initialized && previous && current) { - log.debug("Configuration changed"); - ProviderEventDetails details = ProviderEventDetails.builder() - .flagsChanged(connectionEvent.getFlagsChanged()) - .message("configuration changed") - .build(); - this.emitProviderConfigurationChanged(details); + if (!initialized) { return; } - // there was an error - if (initialized && previous && !current) { - log.debug("There has been an error"); + + if (!wasConnected && isConnected) { ProviderEventDetails details = ProviderEventDetails.builder() - .message("there has been an error") + .flagsChanged(connectionEvent.getFlagsChanged()) + .message("connected to flagd") .build(); - this.emitProviderError(details); + this.emitProviderReady(details); return; } - // we recovered from an error - if (initialized && !previous && current) { - log.debug("Recovered from error"); + + if (wasConnected && isConnected) { ProviderEventDetails details = ProviderEventDetails.builder() - .message("recovered from error") + .flagsChanged(connectionEvent.getFlagsChanged()) + .message("configuration changed") .build(); - this.emitProviderReady(details); this.emitProviderConfigurationChanged(details); + return; + } + + if (connectionEvent.isStale()) { + this.emitProviderStale(ProviderEventDetails.builder() + .message("there has been an error") + .build()); + } else { + this.emitProviderError(ProviderEventDetails.builder() + .message("there has been an error") + .build()); } } } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitor.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitor.java new file mode 100644 index 000000000..0878ce910 --- /dev/null +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelMonitor.java @@ -0,0 +1,98 @@ +package dev.openfeature.contrib.providers.flagd.resolver.common; + +import dev.openfeature.sdk.exceptions.GeneralError; +import io.grpc.ConnectivityState; +import io.grpc.ManagedChannel; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; + +/** + * A utility class to monitor and manage the connectivity state of a gRPC ManagedChannel. + */ +@Slf4j +public class ChannelMonitor { + + private ChannelMonitor() {} + + /** + * Monitors the state of a gRPC channel and triggers the specified callbacks based on state changes. + * + * @param expectedState the initial state to monitor. + * @param channel the ManagedChannel to monitor. + * @param onConnectionReady callback invoked when the channel transitions to a READY state. + * @param onConnectionLost callback invoked when the channel transitions to a FAILURE or SHUTDOWN state. + */ + public static void monitorChannelState( + ConnectivityState expectedState, + ManagedChannel channel, + Runnable onConnectionReady, + Runnable onConnectionLost) { + channel.notifyWhenStateChanged(expectedState, () -> { + ConnectivityState currentState = channel.getState(true); + log.info("Channel state changed to: {}", currentState); + if (currentState == ConnectivityState.READY) { + onConnectionReady.run(); + } else if (currentState == ConnectivityState.TRANSIENT_FAILURE + || currentState == ConnectivityState.SHUTDOWN) { + onConnectionLost.run(); + } + // Re-register the state monitor to watch for the next state transition. + monitorChannelState(currentState, channel, onConnectionReady, onConnectionLost); + }); + } + + /** + * Waits for the channel to reach a desired state within a specified timeout period. + * + * @param channel the ManagedChannel to monitor. + * @param desiredState the ConnectivityState to wait for. + * @param connectCallback callback invoked when the desired state is reached. + * @param timeout the maximum amount of time to wait. + * @param unit the time unit of the timeout. + * @throws InterruptedException if the current thread is interrupted while waiting. + */ + public static void waitForDesiredState( + ManagedChannel channel, + ConnectivityState desiredState, + Runnable connectCallback, + long timeout, + TimeUnit unit) + throws InterruptedException { + waitForDesiredState(channel, desiredState, connectCallback, new CountDownLatch(1), timeout, unit); + } + + private static void waitForDesiredState( + ManagedChannel channel, + ConnectivityState desiredState, + Runnable connectCallback, + CountDownLatch latch, + long timeout, + TimeUnit unit) + throws InterruptedException { + channel.notifyWhenStateChanged(ConnectivityState.SHUTDOWN, () -> { + try { + ConnectivityState state = channel.getState(true); + log.debug("Channel state changed to: {}", state); + + if (state == desiredState) { + connectCallback.run(); + latch.countDown(); + return; + } + waitForDesiredState(channel, desiredState, connectCallback, latch, timeout, unit); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + log.error("Thread interrupted while waiting for desired state", e); + } catch (Exception e) { + log.error("Error occurred while waiting for desired state", e); + } + }); + + // Await the latch or timeout for the state change + if (!latch.await(timeout, unit)) { + throw new GeneralError(String.format( + "Deadline exceeded. Condition did not complete within the %d " + "deadline", timeout)); + } + } +} diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionEvent.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionEvent.java index 994ccdc9c..0e8ff4c6b 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionEvent.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionEvent.java @@ -4,65 +4,121 @@ import dev.openfeature.sdk.Structure; import java.util.Collections; import java.util.List; -import lombok.AllArgsConstructor; -import lombok.Getter; /** - * Event payload for a {@link dev.openfeature.contrib.providers.flagd.resolver.Resolver} connection - * state change event. + * Represents an event payload for a connection state change in a + * {@link dev.openfeature.contrib.providers.flagd.resolver.Resolver}. + * The event includes information about the connection status, any flags that have changed, + * and metadata associated with the synchronization process. */ -@AllArgsConstructor public class ConnectionEvent { - @Getter - private final boolean connected; + /** + * The current state of the connection. + */ + private final ConnectionState connected; + + /** + * A list of flags that have changed due to this connection event. + */ private final List flagsChanged; + + /** + * Metadata associated with synchronization in this connection event. + */ private final Structure syncMetadata; /** - * Construct a new ConnectionEvent. + * Constructs a new {@code ConnectionEvent} with the connection status only. * - * @param connected status of the connection + * @param connected {@code true} if the connection is established, otherwise {@code false}. */ public ConnectionEvent(boolean connected) { + this( + connected ? ConnectionState.CONNECTED : ConnectionState.DISCONNECTED, + Collections.emptyList(), + new ImmutableStructure()); + } + + /** + * Constructs a new {@code ConnectionEvent} with the specified connection state. + * + * @param connected the connection state indicating if the connection is established or not. + */ + public ConnectionEvent(ConnectionState connected) { this(connected, Collections.emptyList(), new ImmutableStructure()); } /** - * Construct a new ConnectionEvent. + * Constructs a new {@code ConnectionEvent} with the specified connection state and changed flags. * - * @param connected status of the connection - * @param flagsChanged list of flags changed + * @param connected the connection state indicating if the connection is established or not. + * @param flagsChanged a list of flags that have changed due to this connection event. */ - public ConnectionEvent(boolean connected, List flagsChanged) { + public ConnectionEvent(ConnectionState connected, List flagsChanged) { this(connected, flagsChanged, new ImmutableStructure()); } /** - * Construct a new ConnectionEvent. + * Constructs a new {@code ConnectionEvent} with the specified connection state and synchronization metadata. * - * @param connected status of the connection - * @param syncMetadata sync.getMetadata + * @param connected the connection state indicating if the connection is established or not. + * @param syncMetadata metadata related to the synchronization process of this event. */ - public ConnectionEvent(boolean connected, Structure syncMetadata) { + public ConnectionEvent(ConnectionState connected, Structure syncMetadata) { this(connected, Collections.emptyList(), new ImmutableStructure(syncMetadata.asMap())); } /** - * Get changed flags. + * Constructs a new {@code ConnectionEvent} with the specified connection state, changed flags, and + * synchronization metadata. + * + * @param connectionState the state of the connection. + * @param flagsChanged a list of flags that have changed due to this connection event. + * @param syncMetadata metadata related to the synchronization process of this event. + */ + public ConnectionEvent(ConnectionState connectionState, List flagsChanged, Structure syncMetadata) { + this.connected = connectionState; + this.flagsChanged = flagsChanged != null ? flagsChanged : Collections.emptyList(); // Ensure non-null list + this.syncMetadata = syncMetadata != null + ? new ImmutableStructure(syncMetadata.asMap()) + : new ImmutableStructure(); // Ensure valid syncMetadata + } + + /** + * Retrieves an unmodifiable view of the list of changed flags. * - * @return an unmodifiable view of the changed flags + * @return an unmodifiable list of changed flags. */ public List getFlagsChanged() { return Collections.unmodifiableList(flagsChanged); } /** - * Get changed sync metadata represented as SDK structure type. + * Retrieves the synchronization metadata represented as an immutable SDK structure type. * - * @return an unmodifiable view of the sync metadata + * @return an immutable structure containing the synchronization metadata. */ public Structure getSyncMetadata() { return new ImmutableStructure(syncMetadata.asMap()); } + + /** + * Indicates whether the current connection state is connected. + * + * @return {@code true} if connected, otherwise {@code false}. + */ + public boolean isConnected() { + return this.connected == ConnectionState.CONNECTED; + } + + /** + * Indicates + * whether the current connection state is stale. + * + * @return {@code true} if stale, otherwise {@code false}. + */ + public boolean isStale() { + return this.connected == ConnectionState.STALE; + } } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionState.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionState.java new file mode 100644 index 000000000..6dbd388a0 --- /dev/null +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ConnectionState.java @@ -0,0 +1,27 @@ +package dev.openfeature.contrib.providers.flagd.resolver.common; + +/** + * Represents the possible states of a connection. + */ +public enum ConnectionState { + + /** + * The connection is active and functioning as expected. + */ + CONNECTED, + + /** + * The connection is not active and has been fully disconnected. + */ + DISCONNECTED, + + /** + * The connection is inactive or degraded but may still recover. + */ + STALE, + + /** + * The connection has encountered an error and cannot function correctly. + */ + ERROR, +} diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/Util.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/Util.java index d634f0745..394716415 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/Util.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/Util.java @@ -2,18 +2,26 @@ import dev.openfeature.sdk.exceptions.GeneralError; import java.util.function.Supplier; +import lombok.extern.slf4j.Slf4j; -/** Utils for flagd resolvers. */ +/** + * Utility class for managing gRPC connection states and handling synchronization operations. + */ +@Slf4j public class Util { + /** + * Private constructor to prevent instantiation of utility class. + */ private Util() {} /** - * A helper to block the caller for given conditions. + * A helper method to block the caller until a condition is met or a timeout occurs. * - * @param deadline number of milliseconds to block - * @param connectedSupplier func to check for status true - * @throws InterruptedException if interrupted + * @param deadline the maximum number of milliseconds to block + * @param connectedSupplier a function that evaluates to {@code true} when the desired condition is met + * @throws InterruptedException if the thread is interrupted during the waiting process + * @throws GeneralError if the deadline is exceeded before the condition is met */ public static void busyWaitAndCheck(final Long deadline, final Supplier connectedSupplier) throws InterruptedException { @@ -22,7 +30,7 @@ public static void busyWaitAndCheck(final Long deadline, final Supplier do { if (deadline <= System.currentTimeMillis() - start) { throw new GeneralError(String.format( - "Deadline exceeded. Condition did not complete within the %d deadline", deadline)); + "Deadline exceeded. Condition did not complete within the %d " + "deadline", deadline)); } Thread.sleep(50L); diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserver.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserver.java index 7d60a704a..db89931e5 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserver.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserver.java @@ -10,38 +10,42 @@ import java.util.List; import java.util.Map; import java.util.function.BiConsumer; -import java.util.function.Supplier; import lombok.extern.slf4j.Slf4j; -/** EventStreamObserver handles events emitted by flagd. */ +/** + * Observer for a gRPC event stream that handles notifications about flag changes and provider readiness events. + * This class updates a cache and notifies listeners via a lambda callback when events occur. + */ @Slf4j @SuppressFBWarnings(justification = "cache needs to be read and write by multiple objects") class EventStreamObserver implements StreamObserver { + + /** + * A consumer to handle connection events with a flag indicating success and a list of changed flags. + */ private final BiConsumer> onConnectionEvent; - private final Supplier shouldRetrySilently; - private final Object sync; + + /** + * The cache to update based on received events. + */ private final Cache cache; /** - * Create a gRPC stream that get notified about flag changes. + * Constructs a new {@code EventStreamObserver} instance. * - * @param sync synchronization object from caller - * @param cache cache to update - * @param onConnectionEvent lambda to call to handle the response - * @param shouldRetrySilently Boolean supplier indicating if the GRPC connector will try to - * recover silently + * @param cache the cache to update based on received events + * @param onConnectionEvent a consumer to handle connection events with a boolean and a list of changed flags */ - EventStreamObserver( - Object sync, - Cache cache, - BiConsumer> onConnectionEvent, - Supplier shouldRetrySilently) { - this.sync = sync; + EventStreamObserver(Cache cache, BiConsumer> onConnectionEvent) { this.cache = cache; this.onConnectionEvent = onConnectionEvent; - this.shouldRetrySilently = shouldRetrySilently; } + /** + * Called when a new event is received from the stream. + * + * @param value the event stream response containing event data + */ @Override public void onNext(EventStreamResponse value) { switch (value.getType()) { @@ -52,37 +56,38 @@ public void onNext(EventStreamResponse value) { this.handleProviderReadyEvent(); break; default: - log.debug("unhandled event type {}", value.getType()); + log.debug("Unhandled event type {}", value.getType()); } } + /** + * Called when an error occurs in the stream. + * + * @param throwable the error that occurred + */ @Override public void onError(Throwable throwable) { - if (Boolean.TRUE.equals(shouldRetrySilently.get())) { - log.debug("Event stream error, trying to recover", throwable); - } else { - log.error("Event stream error", throwable); - if (this.cache.getEnabled()) { - this.cache.clear(); - } - this.onConnectionEvent.accept(false, Collections.emptyList()); + if (this.cache.getEnabled().equals(Boolean.TRUE)) { + this.cache.clear(); } - - // handle last call of this stream - handleEndOfStream(); } + /** + * Called when the stream is completed. + */ @Override public void onCompleted() { - if (this.cache.getEnabled()) { + if (this.cache.getEnabled().equals(Boolean.TRUE)) { this.cache.clear(); } this.onConnectionEvent.accept(false, Collections.emptyList()); - - // handle last call of this stream - handleEndOfStream(); } + /** + * Handles configuration change events by updating the cache and notifying listeners about changed flags. + * + * @param value the event stream response containing configuration change data + */ private void handleConfigurationChangeEvent(EventStreamResponse value) { List changedFlags = new ArrayList<>(); boolean cachingEnabled = this.cache.getEnabled(); @@ -95,7 +100,6 @@ private void handleConfigurationChangeEvent(EventStreamResponse value) { } } else { Map flags = flagsValue.getStructValue().getFieldsMap(); - this.cache.getEnabled(); for (String flagKey : flags.keySet()) { changedFlags.add(flagKey); if (cachingEnabled) { @@ -107,16 +111,12 @@ private void handleConfigurationChangeEvent(EventStreamResponse value) { this.onConnectionEvent.accept(true, changedFlags); } + /** + * Handles provider readiness events by clearing the cache (if enabled) and notifying listeners of readiness. + */ private void handleProviderReadyEvent() { - this.onConnectionEvent.accept(true, Collections.emptyList()); - if (this.cache.getEnabled()) { + if (this.cache.getEnabled().equals(Boolean.TRUE)) { this.cache.clear(); } } - - private void handleEndOfStream() { - synchronized (this.sync) { - this.sync.notifyAll(); - } - } } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnector.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnector.java index 8fabb5d8b..9508c521b 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnector.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnector.java @@ -1,169 +1,246 @@ package dev.openfeature.contrib.providers.flagd.resolver.grpc; -import static dev.openfeature.contrib.providers.flagd.resolver.common.backoff.BackoffStrategies.maxRetriesWithExponentialTimeBackoffStrategy; - +import com.google.common.annotations.VisibleForTesting; import dev.openfeature.contrib.providers.flagd.FlagdOptions; import dev.openfeature.contrib.providers.flagd.resolver.common.ChannelBuilder; +import dev.openfeature.contrib.providers.flagd.resolver.common.ChannelMonitor; import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionEvent; -import dev.openfeature.contrib.providers.flagd.resolver.common.Util; -import dev.openfeature.contrib.providers.flagd.resolver.common.backoff.GrpcStreamConnectorBackoffService; -import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; -import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamRequest; -import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamResponse; -import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc; -import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionState; +import dev.openfeature.sdk.ImmutableStructure; +import io.grpc.ConnectivityState; import io.grpc.ManagedChannel; -import io.grpc.stub.StreamObserver; +import io.grpc.stub.AbstractBlockingStub; +import io.grpc.stub.AbstractStub; import java.util.Collections; -import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; -import java.util.function.Supplier; +import java.util.function.Function; +import lombok.Getter; import lombok.extern.slf4j.Slf4j; -/** Class that abstracts the gRPC communication with flagd. */ +/** + * A generic GRPC connector that manages connection states, reconnection logic, and event streaming for + * GRPC services. + * + * @param the type of the asynchronous stub for the GRPC service + * @param the type of the blocking stub for the GRPC service + */ @Slf4j -@SuppressFBWarnings(justification = "cache needs to be read and write by multiple objects") -public class GrpcConnector { - private final Object sync = new Object(); +public class GrpcConnector, K extends AbstractBlockingStub> { - private final ServiceGrpc.ServiceBlockingStub serviceBlockingStub; - private final ServiceGrpc.ServiceStub serviceStub; + /** + * The asynchronous service stub for making non-blocking GRPC calls. + */ + private final T serviceStub; + + /** + * The blocking service stub for making blocking GRPC calls. + */ + private final K blockingStub; + + /** + * The GRPC managed channel for managing the underlying GRPC connection. + */ private final ManagedChannel channel; + /** + * The deadline in milliseconds for GRPC operations. + */ private final long deadline; + + /** + * The deadline in milliseconds for event streaming operations. + */ private final long streamDeadlineMs; - private final Cache cache; + /** + * A consumer that handles connection events such as connection loss or reconnection. + */ private final Consumer onConnectionEvent; - private final Supplier connectedSupplier; - private final GrpcStreamConnectorBackoffService backoff; - // Thread responsible for event observation - private Thread eventObserverThread; + /** + * A consumer that handles GRPC service stubs for event stream handling. + */ + private final Consumer streamObserver; + + /** + * An executor service responsible for scheduling reconnection attempts. + */ + private final ScheduledExecutorService reconnectExecutor; + + /** + * The grace period in milliseconds to wait for reconnection before emitting an error event. + */ + private final long gracePeriod; + + /** + * Indicates whether the connector is currently connected to the GRPC service. + */ + @Getter + private boolean connected = false; + + /** + * A scheduled task for managing reconnection attempts. + */ + private ScheduledFuture reconnectTask; /** - * GrpcConnector creates an abstraction over gRPC communication. + * Constructs a new {@code GrpcConnector} instance with the specified options and parameters. * - * @param options flagd options - * @param cache cache to use - * @param connectedSupplier lambda providing current connection status from caller - * @param onConnectionEvent lambda which handles changes in the connection/stream + * @param options the configuration options for the GRPC connection + * @param stub a function to create the asynchronous service stub from a {@link ManagedChannel} + * @param blockingStub a function to create the blocking service stub from a {@link ManagedChannel} + * @param onConnectionEvent a consumer to handle connection events + * @param eventStreamObserver a consumer to handle the event stream + * @param channel the managed channel for the GRPC connection */ public GrpcConnector( final FlagdOptions options, - final Cache cache, - final Supplier connectedSupplier, - Consumer onConnectionEvent) { - this.channel = ChannelBuilder.nettyChannel(options); - this.serviceStub = ServiceGrpc.newStub(channel); - this.serviceBlockingStub = ServiceGrpc.newBlockingStub(channel); + final Function stub, + final Function blockingStub, + final Consumer onConnectionEvent, + final Consumer eventStreamObserver, + ManagedChannel channel) { + + this.channel = channel; + this.serviceStub = stub.apply(channel); + this.blockingStub = blockingStub.apply(channel); this.deadline = options.getDeadline(); this.streamDeadlineMs = options.getStreamDeadlineMs(); - this.cache = cache; this.onConnectionEvent = onConnectionEvent; - this.connectedSupplier = connectedSupplier; - this.backoff = new GrpcStreamConnectorBackoffService(maxRetriesWithExponentialTimeBackoffStrategy( - options.getMaxEventStreamRetries(), options.getRetryBackoffMs())); + this.streamObserver = eventStreamObserver; + this.gracePeriod = options.getRetryGracePeriod(); + this.reconnectExecutor = Executors.newSingleThreadScheduledExecutor(); } - /** Initialize the gRPC stream. */ + /** + * Constructs a {@code GrpcConnector} instance for testing purposes. + * + * @param options the configuration options for the GRPC connection + * @param stub a function to create the asynchronous service stub from a {@link ManagedChannel} + * @param blockingStub a function to create the blocking service stub from a {@link ManagedChannel} + * @param onConnectionEvent a consumer to handle connection events + * @param eventStreamObserver a consumer to handle the event stream + */ + @VisibleForTesting + GrpcConnector( + final FlagdOptions options, + final Function stub, + final Function blockingStub, + final Consumer onConnectionEvent, + final Consumer eventStreamObserver) { + this(options, stub, blockingStub, onConnectionEvent, eventStreamObserver, ChannelBuilder.nettyChannel(options)); + } + + /** + * Initializes the GRPC connection by waiting for the channel to be ready and monitoring its state. + * + * @throws Exception if the channel does not reach the desired state within the deadline + */ public void initialize() throws Exception { - eventObserverThread = new Thread(this::observeEventStream); - eventObserverThread.setDaemon(true); - eventObserverThread.start(); + log.info("Initializing GRPC connection..."); + ChannelMonitor.waitForDesiredState( + channel, ConnectivityState.READY, this::onInitialConnect, deadline, TimeUnit.MILLISECONDS); + ChannelMonitor.monitorChannelState(ConnectivityState.READY, channel, this::onReady, this::onConnectionLost); + } - // block till ready - Util.busyWaitAndCheck(this.deadline, this.connectedSupplier); + /** + * Returns the blocking service stub for making blocking GRPC calls. + * + * @return the blocking service stub + */ + public K getResolver() { + return blockingStub; } /** - * Shuts down all gRPC resources. + * Shuts down the GRPC connection and cleans up associated resources. * - * @throws Exception is something goes wrong while terminating the communication. + * @throws InterruptedException if interrupted while waiting for termination */ - public void shutdown() throws Exception { - // first shutdown the event listener - if (this.eventObserverThread != null) { - this.eventObserverThread.interrupt(); + public void shutdown() throws InterruptedException { + log.info("Shutting down GRPC connection..."); + if (reconnectExecutor != null) { + reconnectExecutor.shutdownNow(); + reconnectExecutor.awaitTermination(deadline, TimeUnit.MILLISECONDS); } - try { - if (this.channel != null && !this.channel.isShutdown()) { - this.channel.shutdown(); - this.channel.awaitTermination(this.deadline, TimeUnit.MILLISECONDS); - } - } finally { - this.cache.clear(); - if (this.channel != null && !this.channel.isShutdown()) { - this.channel.shutdownNow(); - this.channel.awaitTermination(this.deadline, TimeUnit.MILLISECONDS); - log.warn(String.format("Unable to shut down channel by %d deadline", this.deadline)); - } + if (!channel.isShutdown()) { + channel.shutdownNow(); + channel.awaitTermination(deadline, TimeUnit.MILLISECONDS); + } + + if (connected) { this.onConnectionEvent.accept(new ConnectionEvent(false)); + connected = false; } } - /** - * Provide the object that can be used to resolve Feature Flag values. - * - * @return a {@link ServiceGrpc.ServiceBlockingStub} for running FF resolution. - */ - public ServiceGrpc.ServiceBlockingStub getResolver() { - return serviceBlockingStub.withDeadlineAfter(this.deadline, TimeUnit.MILLISECONDS); + private synchronized void onInitialConnect() { + connected = true; + restartStream(); } /** - * Event stream observer logic. This contains blocking mechanisms, hence must be run in a - * dedicated thread. + * Handles the event when the GRPC channel becomes ready, marking the connection as established. + * Cancels any pending reconnection task and restarts the event stream. */ - private void observeEventStream() { - while (backoff.shouldRetry()) { - final StreamObserver responseObserver = - new EventStreamObserver(sync, this.cache, this::onConnectionEvent, backoff::shouldRetrySilently); + private synchronized void onReady() { + connected = true; - ServiceGrpc.ServiceStub localServiceStub = this.serviceStub; + if (reconnectTask != null && !reconnectTask.isCancelled()) { + reconnectTask.cancel(false); + log.debug("Reconnection task cancelled as connection became READY."); + } + restartStream(); + this.onConnectionEvent.accept(new ConnectionEvent(true)); + } - if (this.streamDeadlineMs > 0) { - localServiceStub = localServiceStub.withDeadlineAfter(this.streamDeadlineMs, TimeUnit.MILLISECONDS); - } + /** + * Handles the event when the GRPC channel loses its connection, marking the connection as lost. + * Schedules a reconnection task after a grace period and emits a stale connection event. + */ + private synchronized void onConnectionLost() { + log.debug("Connection lost. Emit STALE event..."); + log.debug("Waiting {}s for connection to become available...", gracePeriod); + connected = false; - localServiceStub.eventStream(EventStreamRequest.getDefaultInstance(), responseObserver); - - try { - synchronized (sync) { - sync.wait(); - } - } catch (InterruptedException e) { - // Interruptions are considered end calls for this observer, hence log and - // return - // Note - this is the most common interruption when shutdown, hence the log - // level debug - log.debug("interruption while waiting for condition", e); - Thread.currentThread().interrupt(); - } + this.onConnectionEvent.accept( + new ConnectionEvent(ConnectionState.STALE, Collections.emptyList(), new ImmutableStructure())); - try { - backoff.waitUntilNextAttempt(); - } catch (InterruptedException e) { - // Interruptions are considered end calls for this observer, hence log and - // return - log.warn("interrupted while restarting gRPC Event Stream"); - Thread.currentThread().interrupt(); - } + if (reconnectTask != null && !reconnectTask.isCancelled()) { + reconnectTask.cancel(false); } - log.error("failed to connect to event stream, exhausted retries"); - this.onConnectionEvent(false, Collections.emptyList()); + if (!reconnectExecutor.isShutdown()) { + reconnectTask = reconnectExecutor.schedule( + () -> { + log.debug( + "Provider did not reconnect successfully within {}s. Emit ERROR event...", gracePeriod); + this.onConnectionEvent.accept(new ConnectionEvent(false)); + }, + gracePeriod, + TimeUnit.SECONDS); + } } - private void onConnectionEvent(final boolean connected, final List changedFlags) { - // reset reconnection states + /** + * Restarts the event stream using the asynchronous service stub, applying a deadline if configured. + * Emits a connection event if the restart is successful. + */ + private synchronized void restartStream() { if (connected) { - backoff.reset(); + log.debug("(Re)initializing event stream."); + T localServiceStub = this.serviceStub; + if (streamDeadlineMs > 0) { + localServiceStub = localServiceStub.withDeadlineAfter(this.streamDeadlineMs, TimeUnit.MILLISECONDS); + } + streamObserver.accept(localServiceStub); + return; } - - // chain to initiator - this.onConnectionEvent.accept(new ConnectionEvent(connected, changedFlags)); + log.debug("Stream restart skipped. Not connected."); } } diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcResolver.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcResolver.java index 9d8c3a9f2..a64275c2b 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcResolver.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcResolver.java @@ -11,14 +11,17 @@ import dev.openfeature.contrib.providers.flagd.FlagdOptions; import dev.openfeature.contrib.providers.flagd.resolver.Resolver; import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionEvent; +import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionState; import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; import dev.openfeature.contrib.providers.flagd.resolver.grpc.strategy.ResolveFactory; import dev.openfeature.contrib.providers.flagd.resolver.grpc.strategy.ResolveStrategy; +import dev.openfeature.flagd.grpc.evaluation.Evaluation; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveBooleanRequest; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveFloatRequest; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveIntRequest; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveObjectRequest; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveStringRequest; +import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc; import dev.openfeature.sdk.EvaluationContext; import dev.openfeature.sdk.ImmutableMetadata; import dev.openfeature.sdk.ProviderEvaluation; @@ -34,7 +37,6 @@ import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; -import java.util.function.Supplier; /** * Resolves flag values using https://buf.build/open-feature/flagd/docs/main:flagd.evaluation.v1. @@ -44,72 +46,89 @@ @SuppressFBWarnings(justification = "cache needs to be read and write by multiple objects") public final class GrpcResolver implements Resolver { - private final GrpcConnector connector; + private final GrpcConnector connector; private final Cache cache; private final ResolveStrategy strategy; - private final Supplier connectedSupplier; /** * Resolves flag values using https://buf.build/open-feature/flagd/docs/main:flagd.evaluation.v1. * Flags are evaluated remotely. * - * @param options flagd options - * @param cache cache to use - * @param connectedSupplier lambda providing current connection status from caller + * @param options flagd options + * @param cache cache to use * @param onConnectionEvent lambda which handles changes in the connection/stream */ public GrpcResolver( - final FlagdOptions options, - final Cache cache, - final Supplier connectedSupplier, - final Consumer onConnectionEvent) { + final FlagdOptions options, final Cache cache, final Consumer onConnectionEvent) { this.cache = cache; - this.connectedSupplier = connectedSupplier; this.strategy = ResolveFactory.getStrategy(options); - this.connector = new GrpcConnector(options, cache, connectedSupplier, onConnectionEvent); + this.connector = new GrpcConnector<>( + options, + ServiceGrpc::newStub, + ServiceGrpc::newBlockingStub, + onConnectionEvent, + stub -> stub.eventStream( + Evaluation.EventStreamRequest.getDefaultInstance(), + new EventStreamObserver( + cache, + (k, e) -> + onConnectionEvent.accept(new ConnectionEvent(ConnectionState.CONNECTED, e))))); } - /** Initialize Grpc resolver. */ + /** + * Initialize Grpc resolver. + */ public void init() throws Exception { this.connector.initialize(); } - /** Shutdown Grpc resolver. */ + /** + * Shutdown Grpc resolver. + */ public void shutdown() throws Exception { this.connector.shutdown(); } - /** Boolean evaluation from grpc resolver. */ + /** + * Boolean evaluation from grpc resolver. + */ public ProviderEvaluation booleanEvaluation(String key, Boolean defaultValue, EvaluationContext ctx) { ResolveBooleanRequest request = ResolveBooleanRequest.newBuilder().buildPartial(); - return resolve(key, ctx, request, this.connector.getResolver()::resolveBoolean, null); + return resolve(key, ctx, request, connector.getResolver()::resolveBoolean, null); } - /** String evaluation from grpc resolver. */ + /** + * String evaluation from grpc resolver. + */ public ProviderEvaluation stringEvaluation(String key, String defaultValue, EvaluationContext ctx) { ResolveStringRequest request = ResolveStringRequest.newBuilder().buildPartial(); - - return resolve(key, ctx, request, this.connector.getResolver()::resolveString, null); + return resolve(key, ctx, request, connector.getResolver()::resolveString, null); } - /** Double evaluation from grpc resolver. */ + /** + * Double evaluation from grpc resolver. + */ public ProviderEvaluation doubleEvaluation(String key, Double defaultValue, EvaluationContext ctx) { ResolveFloatRequest request = ResolveFloatRequest.newBuilder().buildPartial(); - return resolve(key, ctx, request, this.connector.getResolver()::resolveFloat, null); + return resolve(key, ctx, request, connector.getResolver()::resolveFloat, null); } - /** Integer evaluation from grpc resolver. */ + /** + * Integer evaluation from grpc resolver. + */ public ProviderEvaluation integerEvaluation(String key, Integer defaultValue, EvaluationContext ctx) { ResolveIntRequest request = ResolveIntRequest.newBuilder().buildPartial(); - return resolve(key, ctx, request, this.connector.getResolver()::resolveInt, (Object value) -> ((Long) value) - .intValue()); + return resolve( + key, ctx, request, connector.getResolver()::resolveInt, (Object value) -> ((Long) value).intValue()); } - /** Object evaluation from grpc resolver. */ + /** + * Object evaluation from grpc resolver. + */ public ProviderEvaluation objectEvaluation(String key, Value defaultValue, EvaluationContext ctx) { ResolveObjectRequest request = ResolveObjectRequest.newBuilder().buildPartial(); @@ -118,13 +137,13 @@ public ProviderEvaluation objectEvaluation(String key, Value defaultValue key, ctx, request, - this.connector.getResolver()::resolveObject, + connector.getResolver()::resolveObject, (Object value) -> convertObjectResponse((Struct) value)); } /** - * A generic resolve method that takes a resolverRef and an optional converter lambda to transform - * the result. + * A generic resolve method that takes a resolverRef and an optional converter + * lambda to transform the result. */ private ProviderEvaluation resolve( String key, @@ -187,7 +206,7 @@ private Boolean isEvaluationCacheable(ProviderEvaluation evaluation) { } private Boolean cacheAvailable() { - return this.cache.getEnabled() && this.connectedSupplier.get(); + return this.cache.getEnabled() && this.connector.isConnected(); } private static ImmutableMetadata metadataFromResponse(Message response) { diff --git a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java index 663d4bb0d..fd617af1f 100644 --- a/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java +++ b/providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/process/InProcessResolver.java @@ -5,6 +5,7 @@ import dev.openfeature.contrib.providers.flagd.FlagdOptions; import dev.openfeature.contrib.providers.flagd.resolver.Resolver; import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionEvent; +import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionState; import dev.openfeature.contrib.providers.flagd.resolver.common.Util; import dev.openfeature.contrib.providers.flagd.resolver.process.model.FeatureFlag; import dev.openfeature.contrib.providers.flagd.resolver.process.storage.FlagStore; @@ -28,8 +29,9 @@ import lombok.extern.slf4j.Slf4j; /** - * Resolves flag values using https://buf.build/open-feature/flagd/docs/main:flagd.sync.v1. Flags - * are evaluated locally. + * Resolves flag values using + * https://buf.build/open-feature/flagd/docs/main:flagd.sync.v1. + * Flags are evaluated locally. */ @Slf4j public class InProcessResolver implements Resolver { @@ -41,12 +43,15 @@ public class InProcessResolver implements Resolver { private final Supplier connectedSupplier; /** - * Resolves flag values using https://buf.build/open-feature/flagd/docs/main:flagd.sync.v1. Flags - * are evaluated locally. + * Resolves flag values using + * https://buf.build/open-feature/flagd/docs/main:flagd.sync.v1. + * Flags are evaluated locally. * - * @param options flagd options - * @param connectedSupplier lambda providing current connection status from caller - * @param onConnectionEvent lambda which handles changes in the connection/stream + * @param options flagd options + * @param connectedSupplier lambda providing current connection status from + * caller + * @param onConnectionEvent lambda which handles changes in the + * connection/stream */ public InProcessResolver( FlagdOptions options, @@ -64,7 +69,9 @@ public InProcessResolver( .build(); } - /** Initialize in-process resolver. */ + /** + * Initialize in-process resolver. + */ public void init() throws Exception { flagStore.init(); final Thread stateWatcher = new Thread(() -> { @@ -75,7 +82,7 @@ public void init() throws Exception { switch (storageStateChange.getStorageState()) { case OK: onConnectionEvent.accept(new ConnectionEvent( - true, + ConnectionState.CONNECTED, storageStateChange.getChangedFlagsKeys(), storageStateChange.getSyncMetadata())); break; @@ -109,27 +116,37 @@ public void shutdown() throws InterruptedException { onConnectionEvent.accept(new ConnectionEvent(false)); } - /** Resolve a boolean flag. */ + /** + * Resolve a boolean flag. + */ public ProviderEvaluation booleanEvaluation(String key, Boolean defaultValue, EvaluationContext ctx) { return resolve(Boolean.class, key, ctx); } - /** Resolve a string flag. */ + /** + * Resolve a string flag. + */ public ProviderEvaluation stringEvaluation(String key, String defaultValue, EvaluationContext ctx) { return resolve(String.class, key, ctx); } - /** Resolve a double flag. */ + /** + * Resolve a double flag. + */ public ProviderEvaluation doubleEvaluation(String key, Double defaultValue, EvaluationContext ctx) { return resolve(Double.class, key, ctx); } - /** Resolve an integer flag. */ + /** + * Resolve an integer flag. + */ public ProviderEvaluation integerEvaluation(String key, Integer defaultValue, EvaluationContext ctx) { return resolve(Integer.class, key, ctx); } - /** Resolve an object flag. */ + /** + * Resolve an object flag. + */ public ProviderEvaluation objectEvaluation(String key, Value defaultValue, EvaluationContext ctx) { final ProviderEvaluation evaluation = resolve(Object.class, key, ctx); diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java index c8e8aba1c..c85765223 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java @@ -1,8 +1,22 @@ package dev.openfeature.contrib.providers.flagd; -import static dev.openfeature.contrib.providers.flagd.Config.*; +import static dev.openfeature.contrib.providers.flagd.Config.DEFAULT_CACHE; +import static dev.openfeature.contrib.providers.flagd.Config.DEFAULT_HOST; +import static dev.openfeature.contrib.providers.flagd.Config.DEFAULT_IN_PROCESS_PORT; +import static dev.openfeature.contrib.providers.flagd.Config.DEFAULT_MAX_CACHE_SIZE; +import static dev.openfeature.contrib.providers.flagd.Config.DEFAULT_RPC_PORT; +import static dev.openfeature.contrib.providers.flagd.Config.KEEP_ALIVE_MS_ENV_VAR_NAME; +import static dev.openfeature.contrib.providers.flagd.Config.KEEP_ALIVE_MS_ENV_VAR_NAME_OLD; +import static dev.openfeature.contrib.providers.flagd.Config.RESOLVER_ENV_VAR; +import static dev.openfeature.contrib.providers.flagd.Config.RESOLVER_IN_PROCESS; +import static dev.openfeature.contrib.providers.flagd.Config.RESOLVER_RPC; +import static dev.openfeature.contrib.providers.flagd.Config.Resolver; +import static dev.openfeature.contrib.providers.flagd.Config.TARGET_URI_ENV_VAR_NAME; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import dev.openfeature.contrib.providers.flagd.resolver.process.storage.MockConnector; import dev.openfeature.contrib.providers.flagd.resolver.process.storage.connector.Connector; @@ -26,7 +40,6 @@ void TestDefaults() { assertNull(builder.getSocketPath()); assertEquals(DEFAULT_CACHE, builder.getCacheType()); assertEquals(DEFAULT_MAX_CACHE_SIZE, builder.getMaxCacheSize()); - assertEquals(DEFAULT_MAX_EVENT_STREAM_RETRIES, builder.getMaxEventStreamRetries()); assertNull(builder.getSelector()); assertNull(builder.getOpenTelemetry()); assertNull(builder.getCustomConnector()); @@ -47,7 +60,6 @@ void TestBuilderOptions() { .certPath("etc/cert/ca.crt") .cacheType("lru") .maxCacheSize(100) - .maxEventStreamRetries(1) .selector("app=weatherApp") .offlineFlagSourcePath("some-path") .openTelemetry(openTelemetry) @@ -63,7 +75,6 @@ void TestBuilderOptions() { assertEquals("etc/cert/ca.crt", flagdOptions.getCertPath()); assertEquals("lru", flagdOptions.getCacheType()); assertEquals(100, flagdOptions.getMaxCacheSize()); - assertEquals(1, flagdOptions.getMaxEventStreamRetries()); assertEquals("app=weatherApp", flagdOptions.getSelector()); assertEquals("some-path", flagdOptions.getOfflineFlagSourcePath()); assertEquals(openTelemetry, flagdOptions.getOpenTelemetry()); diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java index 0186ff38d..b01cead09 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdProviderTest.java @@ -9,10 +9,8 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockConstruction; -import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -20,6 +18,7 @@ import com.google.protobuf.Struct; import dev.openfeature.contrib.providers.flagd.resolver.Resolver; import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionEvent; +import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionState; import dev.openfeature.contrib.providers.flagd.resolver.grpc.GrpcConnector; import dev.openfeature.contrib.providers.flagd.resolver.grpc.GrpcResolver; import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; @@ -34,9 +33,7 @@ import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveIntResponse; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveObjectResponse; import dev.openfeature.flagd.grpc.evaluation.Evaluation.ResolveStringResponse; -import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc; import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceBlockingStub; -import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceStub; import dev.openfeature.sdk.EvaluationContext; import dev.openfeature.sdk.FlagEvaluationDetails; import dev.openfeature.sdk.FlagValueType; @@ -50,8 +47,6 @@ import dev.openfeature.sdk.Structure; import dev.openfeature.sdk.Value; import io.cucumber.java.AfterAll; -import io.grpc.Channel; -import io.grpc.Deadline; import java.lang.reflect.Field; import java.util.ArrayList; import java.util.Arrays; @@ -64,11 +59,9 @@ import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.function.Function; -import java.util.function.Supplier; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.mockito.MockedConstruction; -import org.mockito.MockedStatic; class FlagdProviderTest { private static final String FLAG_KEY = "some-key"; @@ -87,7 +80,7 @@ class FlagdProviderTest { private static final Double DOUBLE_VALUE = .5d; private static final String INNER_STRUCT_KEY = "inner_key"; private static final String INNER_STRUCT_VALUE = "inner_value"; - private static final com.google.protobuf.Struct PROTOBUF_STRUCTURE_VALUE = Struct.newBuilder() + private static final Struct PROTOBUF_STRUCTURE_VALUE = Struct.newBuilder() .putFields( INNER_STRUCT_KEY, com.google.protobuf.Value.newBuilder() @@ -466,153 +459,6 @@ void reason_mapped_correctly_if_unknown() { // UNKNOWN } - @Test - void invalidate_cache() { - ResolveBooleanResponse booleanResponse = ResolveBooleanResponse.newBuilder() - .setValue(true) - .setVariant(BOOL_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveStringResponse stringResponse = ResolveStringResponse.newBuilder() - .setValue(STRING_VALUE) - .setVariant(STRING_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveIntResponse intResponse = ResolveIntResponse.newBuilder() - .setValue(INT_VALUE) - .setVariant(INT_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveFloatResponse floatResponse = ResolveFloatResponse.newBuilder() - .setValue(DOUBLE_VALUE) - .setVariant(DOUBLE_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveObjectResponse objectResponse = ResolveObjectResponse.newBuilder() - .setValue(PROTOBUF_STRUCTURE_VALUE) - .setVariant(OBJECT_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ServiceBlockingStub serviceBlockingStubMock = mock(ServiceBlockingStub.class); - ServiceStub serviceStubMock = mock(ServiceStub.class); - when(serviceStubMock.withWaitForReady()).thenReturn(serviceStubMock); - doNothing().when(serviceStubMock).eventStream(any(), any()); - when(serviceStubMock.withDeadline(any(Deadline.class))).thenReturn(serviceStubMock); - when(serviceBlockingStubMock.withWaitForReady()).thenReturn(serviceBlockingStubMock); - when(serviceBlockingStubMock.withDeadline(any(Deadline.class))).thenReturn(serviceBlockingStubMock); - when(serviceBlockingStubMock.withDeadlineAfter(anyLong(), any(TimeUnit.class))) - .thenReturn(serviceBlockingStubMock); - when(serviceBlockingStubMock.resolveBoolean(argThat(x -> FLAG_KEY_BOOLEAN.equals(x.getFlagKey())))) - .thenReturn(booleanResponse); - when(serviceBlockingStubMock.resolveFloat(argThat(x -> FLAG_KEY_DOUBLE.equals(x.getFlagKey())))) - .thenReturn(floatResponse); - when(serviceBlockingStubMock.resolveInt(argThat(x -> FLAG_KEY_INTEGER.equals(x.getFlagKey())))) - .thenReturn(intResponse); - when(serviceBlockingStubMock.resolveString(argThat(x -> FLAG_KEY_STRING.equals(x.getFlagKey())))) - .thenReturn(stringResponse); - when(serviceBlockingStubMock.resolveObject(argThat(x -> FLAG_KEY_OBJECT.equals(x.getFlagKey())))) - .thenReturn(objectResponse); - - GrpcConnector grpc; - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService - .when(() -> ServiceGrpc.newBlockingStub(any(Channel.class))) - .thenReturn(serviceBlockingStubMock); - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(serviceStubMock); - - final Cache cache = new Cache("lru", 5); - - class NoopInitGrpcConnector extends GrpcConnector { - public NoopInitGrpcConnector( - FlagdOptions options, - Cache cache, - Supplier connectedSupplier, - Consumer onConnectionEvent) { - super(options, cache, connectedSupplier, onConnectionEvent); - } - - public void initialize() throws Exception {} - ; - } - - grpc = new NoopInitGrpcConnector( - FlagdOptions.builder().build(), cache, () -> true, (connectionEvent) -> {}); - } - - FlagdProvider provider = createProvider(grpc); - OpenFeatureAPI.getInstance().setProviderAndWait(provider); - - HashMap flagsMap = new HashMap(); - HashMap structMap = new HashMap(); - - flagsMap.put( - FLAG_KEY_BOOLEAN, - com.google.protobuf.Value.newBuilder().setStringValue("foo").build()); - flagsMap.put( - FLAG_KEY_STRING, - com.google.protobuf.Value.newBuilder().setStringValue("foo").build()); - flagsMap.put( - FLAG_KEY_INTEGER, - com.google.protobuf.Value.newBuilder().setStringValue("foo").build()); - flagsMap.put( - FLAG_KEY_DOUBLE, - com.google.protobuf.Value.newBuilder().setStringValue("foo").build()); - flagsMap.put( - FLAG_KEY_OBJECT, - com.google.protobuf.Value.newBuilder().setStringValue("foo").build()); - - structMap.put( - "flags", - com.google.protobuf.Value.newBuilder() - .setStructValue(Struct.newBuilder().putAllFields(flagsMap)) - .build()); - - // should cache results - FlagEvaluationDetails booleanDetails; - FlagEvaluationDetails stringDetails; - FlagEvaluationDetails intDetails; - FlagEvaluationDetails floatDetails; - FlagEvaluationDetails objectDetails; - - // assert cache has been invalidated - booleanDetails = api.getClient().getBooleanDetails(FLAG_KEY_BOOLEAN, false); - assertTrue(booleanDetails.getValue()); - assertEquals(BOOL_VARIANT, booleanDetails.getVariant()); - assertEquals(STATIC_REASON, booleanDetails.getReason()); - - stringDetails = api.getClient().getStringDetails(FLAG_KEY_STRING, "wrong"); - assertEquals(STRING_VALUE, stringDetails.getValue()); - assertEquals(STRING_VARIANT, stringDetails.getVariant()); - assertEquals(STATIC_REASON, stringDetails.getReason()); - - intDetails = api.getClient().getIntegerDetails(FLAG_KEY_INTEGER, 0); - assertEquals(INT_VALUE, intDetails.getValue()); - assertEquals(INT_VARIANT, intDetails.getVariant()); - assertEquals(STATIC_REASON, intDetails.getReason()); - - floatDetails = api.getClient().getDoubleDetails(FLAG_KEY_DOUBLE, 0.1); - assertEquals(DOUBLE_VALUE, floatDetails.getValue()); - assertEquals(DOUBLE_VARIANT, floatDetails.getVariant()); - assertEquals(STATIC_REASON, floatDetails.getReason()); - - objectDetails = api.getClient().getObjectDetails(FLAG_KEY_OBJECT, new Value()); - assertEquals( - INNER_STRUCT_VALUE, - objectDetails - .getValue() - .asStructure() - .asMap() - .get(INNER_STRUCT_KEY) - .asString()); - assertEquals(OBJECT_VARIANT, objectDetails.getVariant()); - assertEquals(STATIC_REASON, objectDetails.getReason()); - } - private void do_resolvers_cache_responses(String reason, Boolean eventStreamAlive, Boolean shouldCache) { String expectedReason = CACHED_REASON; if (!shouldCache) { @@ -665,7 +511,9 @@ private void do_resolvers_cache_responses(String reason, Boolean eventStreamAliv GrpcConnector grpc = mock(GrpcConnector.class); when(grpc.getResolver()).thenReturn(serviceBlockingStubMock); - FlagdProvider provider = createProvider(grpc, () -> eventStreamAlive); + when(grpc.isConnected()).thenReturn(eventStreamAlive); + FlagdProvider provider = createProvider(grpc); + // provider.setState(eventStreamAlive); // caching only available when event // stream is alive OpenFeatureAPI.getInstance().setProviderAndWait(provider); @@ -710,150 +558,6 @@ private void do_resolvers_cache_responses(String reason, Boolean eventStreamAliv assertEquals(expectedReason, objectDetails.getReason()); } - @Test - void disabled_cache() { - ResolveBooleanResponse booleanResponse = ResolveBooleanResponse.newBuilder() - .setValue(true) - .setVariant(BOOL_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveStringResponse stringResponse = ResolveStringResponse.newBuilder() - .setValue(STRING_VALUE) - .setVariant(STRING_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveIntResponse intResponse = ResolveIntResponse.newBuilder() - .setValue(INT_VALUE) - .setVariant(INT_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveFloatResponse floatResponse = ResolveFloatResponse.newBuilder() - .setValue(DOUBLE_VALUE) - .setVariant(DOUBLE_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ResolveObjectResponse objectResponse = ResolveObjectResponse.newBuilder() - .setValue(PROTOBUF_STRUCTURE_VALUE) - .setVariant(OBJECT_VARIANT) - .setReason(STATIC_REASON) - .build(); - - ServiceBlockingStub serviceBlockingStubMock = mock(ServiceBlockingStub.class); - ServiceStub serviceStubMock = mock(ServiceStub.class); - when(serviceStubMock.withWaitForReady()).thenReturn(serviceStubMock); - when(serviceStubMock.withDeadline(any(Deadline.class))).thenReturn(serviceStubMock); - when(serviceBlockingStubMock.withWaitForReady()).thenReturn(serviceBlockingStubMock); - when(serviceBlockingStubMock.withDeadline(any(Deadline.class))).thenReturn(serviceBlockingStubMock); - when(serviceBlockingStubMock.withDeadlineAfter(anyLong(), any(TimeUnit.class))) - .thenReturn(serviceBlockingStubMock); - when(serviceBlockingStubMock.resolveBoolean(argThat(x -> FLAG_KEY_BOOLEAN.equals(x.getFlagKey())))) - .thenReturn(booleanResponse); - when(serviceBlockingStubMock.resolveFloat(argThat(x -> FLAG_KEY_DOUBLE.equals(x.getFlagKey())))) - .thenReturn(floatResponse); - when(serviceBlockingStubMock.resolveInt(argThat(x -> FLAG_KEY_INTEGER.equals(x.getFlagKey())))) - .thenReturn(intResponse); - when(serviceBlockingStubMock.resolveString(argThat(x -> FLAG_KEY_STRING.equals(x.getFlagKey())))) - .thenReturn(stringResponse); - when(serviceBlockingStubMock.resolveObject(argThat(x -> FLAG_KEY_OBJECT.equals(x.getFlagKey())))) - .thenReturn(objectResponse); - - // disabled cache - final Cache cache = new Cache("disabled", 0); - - GrpcConnector grpc; - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService - .when(() -> ServiceGrpc.newBlockingStub(any(Channel.class))) - .thenReturn(serviceBlockingStubMock); - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(serviceStubMock); - - class NoopInitGrpcConnector extends GrpcConnector { - public NoopInitGrpcConnector( - FlagdOptions options, - Cache cache, - Supplier connectedSupplier, - Consumer onConnectionEvent) { - super(options, cache, connectedSupplier, onConnectionEvent); - } - - public void initialize() throws Exception {} - ; - } - - grpc = new NoopInitGrpcConnector( - FlagdOptions.builder().build(), cache, () -> true, (connectionEvent) -> {}); - } - - FlagdProvider provider = createProvider(grpc, cache, () -> true); - - try { - provider.initialize(null); - } catch (Exception e) { - // ignore exception if any - } - - OpenFeatureAPI.getInstance().setProviderAndWait(provider); - - HashMap flagsMap = new HashMap<>(); - HashMap structMap = new HashMap<>(); - - flagsMap.put( - "foo", - com.google.protobuf.Value.newBuilder() - .setStringValue("foo") - .build()); // assert that a configuration_change event works - - structMap.put( - "flags", - com.google.protobuf.Value.newBuilder() - .setStructValue(Struct.newBuilder().putAllFields(flagsMap)) - .build()); - - // should not cache results - FlagEvaluationDetails booleanDetails = api.getClient().getBooleanDetails(FLAG_KEY_BOOLEAN, false); - FlagEvaluationDetails stringDetails = api.getClient().getStringDetails(FLAG_KEY_STRING, "wrong"); - FlagEvaluationDetails intDetails = api.getClient().getIntegerDetails(FLAG_KEY_INTEGER, 0); - FlagEvaluationDetails floatDetails = api.getClient().getDoubleDetails(FLAG_KEY_DOUBLE, 0.1); - FlagEvaluationDetails objectDetails = api.getClient().getObjectDetails(FLAG_KEY_OBJECT, new Value()); - - // assert values are not cached - booleanDetails = api.getClient().getBooleanDetails(FLAG_KEY_BOOLEAN, false); - assertTrue(booleanDetails.getValue()); - assertEquals(BOOL_VARIANT, booleanDetails.getVariant()); - assertEquals(STATIC_REASON, booleanDetails.getReason()); - - stringDetails = api.getClient().getStringDetails(FLAG_KEY_STRING, "wrong"); - assertEquals(STRING_VALUE, stringDetails.getValue()); - assertEquals(STRING_VARIANT, stringDetails.getVariant()); - assertEquals(STATIC_REASON, stringDetails.getReason()); - - intDetails = api.getClient().getIntegerDetails(FLAG_KEY_INTEGER, 0); - assertEquals(INT_VALUE, intDetails.getValue()); - assertEquals(INT_VARIANT, intDetails.getVariant()); - assertEquals(STATIC_REASON, intDetails.getReason()); - - floatDetails = api.getClient().getDoubleDetails(FLAG_KEY_DOUBLE, 0.1); - assertEquals(DOUBLE_VALUE, floatDetails.getValue()); - assertEquals(DOUBLE_VARIANT, floatDetails.getVariant()); - assertEquals(STATIC_REASON, floatDetails.getReason()); - - objectDetails = api.getClient().getObjectDetails(FLAG_KEY_OBJECT, new Value()); - assertEquals( - INNER_STRUCT_VALUE, - objectDetails - .getValue() - .asStructure() - .asMap() - .get(INNER_STRUCT_KEY) - .asString()); - assertEquals(OBJECT_VARIANT, objectDetails.getVariant()); - assertEquals(STATIC_REASON, objectDetails.getReason()); - } - @Test void initializationAndShutdown() throws Exception { // given @@ -904,7 +608,7 @@ void contextEnrichment() throws Exception { // when our mock resolver initializes, it runs the passed onConnectionEvent // callback doAnswer(invocation -> { - onConnectionEvent.accept(new ConnectionEvent(true, metadata)); + onConnectionEvent.accept(new ConnectionEvent(ConnectionState.CONNECTED, metadata)); return null; }) .when(mock) @@ -944,7 +648,7 @@ void updatesSyncMetadataWithCallback() throws Exception { // when our mock resolver initializes, it runs the passed onConnectionEvent // callback doAnswer(invocation -> { - onConnectionEvent.accept(new ConnectionEvent(true, metadata)); + onConnectionEvent.accept(new ConnectionEvent(ConnectionState.CONNECTED, metadata)); return null; }) .when(mock) @@ -976,23 +680,17 @@ void updatesSyncMetadataWithCallback() throws Exception { } // test helper - - // create provider with given grpc connector - private FlagdProvider createProvider(GrpcConnector grpc) { - return createProvider(grpc, () -> true); - } - // create provider with given grpc provider and state supplier - private FlagdProvider createProvider(GrpcConnector grpc, Supplier getConnected) { + private FlagdProvider createProvider(GrpcConnector grpc) { final Cache cache = new Cache("lru", 5); - return createProvider(grpc, cache, getConnected); + return createProvider(grpc, cache); } // create provider with given grpc provider, cache and state supplier - private FlagdProvider createProvider(GrpcConnector grpc, Cache cache, Supplier getConnected) { + private FlagdProvider createProvider(GrpcConnector grpc, Cache cache) { final FlagdOptions flagdOptions = FlagdOptions.builder().build(); - final GrpcResolver grpcResolver = new GrpcResolver(flagdOptions, cache, getConnected, (connectionEvent) -> {}); + final GrpcResolver grpcResolver = new GrpcResolver(flagdOptions, cache, (connectionEvent) -> {}); final FlagdProvider provider = new FlagdProvider(); diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/reconnect/rpc/FlagdRpcSetup.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/reconnect/rpc/FlagdRpcSetup.java index 6601a5dd3..e913265d5 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/reconnect/rpc/FlagdRpcSetup.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/reconnect/rpc/FlagdRpcSetup.java @@ -31,6 +31,7 @@ public static void setupTest() throws InterruptedException { .resolverType(Config.Resolver.RPC) .port(flagdContainer.getFirstMappedPort()) .deadline(1000) + .retryGracePeriod(1) .streamDeadlineMs(0) // this makes reconnect tests more predictable .cacheType(CacheType.DISABLED.getValue()) .build()); diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/steps/config/ConfigSteps.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/steps/config/ConfigSteps.java index bee803a11..a680a9947 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/steps/config/ConfigSteps.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/e2e/steps/config/ConfigSteps.java @@ -27,7 +27,6 @@ public class ConfigSteps { public static final List IGNORED_FOR_NOW = new ArrayList() { { add("offlinePollIntervalMs"); - add("retryGraceAttempts"); add("retryBackoffMaxMs"); } }; diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java new file mode 100644 index 000000000..554310f76 --- /dev/null +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java @@ -0,0 +1,159 @@ +package dev.openfeature.contrib.providers.flagd.resolver.common; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import dev.openfeature.contrib.providers.flagd.FlagdOptions; +import io.grpc.ManagedChannel; +import io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.NettyChannelBuilder; +import io.netty.channel.epoll.Epoll; +import io.netty.channel.epoll.EpollDomainSocketChannel; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.unix.DomainSocketAddress; +import io.netty.handler.ssl.SslContextBuilder; +import java.io.File; +import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLKeyException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnOs; +import org.junit.jupiter.api.condition.OS; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.MockedStatic; + +class ChannelBuilderTest { + + @Test + @EnabledOnOs(OS.LINUX) + void testNettyChannel_withSocketPath() { + try (MockedStatic epollMock = mockStatic(Epoll.class); + MockedStatic nettyMock = mockStatic(NettyChannelBuilder.class)) { + + // Mocks + epollMock.when(Epoll::isAvailable).thenReturn(true); + NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class); + ManagedChannel mockChannel = mock(ManagedChannel.class); + + nettyMock + .when(() -> NettyChannelBuilder.forAddress(any(DomainSocketAddress.class))) + .thenReturn(mockBuilder); + + when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder); + when(mockBuilder.eventLoopGroup(any(EpollEventLoopGroup.class))).thenReturn(mockBuilder); + when(mockBuilder.channelType(EpollDomainSocketChannel.class)).thenReturn(mockBuilder); + when(mockBuilder.usePlaintext()).thenReturn(mockBuilder); + when(mockBuilder.build()).thenReturn(mockChannel); + + // Input options + FlagdOptions options = FlagdOptions.builder() + .socketPath("/path/to/socket") + .keepAlive(1000) + .build(); + + // Call method under test + ManagedChannel channel = ChannelBuilder.nettyChannel(options); + + // Assertions + assertThat(channel).isEqualTo(mockChannel); + nettyMock.verify(() -> NettyChannelBuilder.forAddress(new DomainSocketAddress("/path/to/socket"))); + verify(mockBuilder).keepAliveTime(1000, TimeUnit.MILLISECONDS); + verify(mockBuilder).eventLoopGroup(any(EpollEventLoopGroup.class)); + verify(mockBuilder).channelType(EpollDomainSocketChannel.class); + verify(mockBuilder).usePlaintext(); + verify(mockBuilder).build(); + } + } + + @Test + void testNettyChannel_withTlsAndCert() { + try (MockedStatic nettyMock = mockStatic(NettyChannelBuilder.class)) { + // Mocks + NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class); + ManagedChannel mockChannel = mock(ManagedChannel.class); + nettyMock + .when(() -> NettyChannelBuilder.forTarget("localhost:8080")) + .thenReturn(mockBuilder); + + when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder); + when(mockBuilder.sslContext(any())).thenReturn(mockBuilder); + when(mockBuilder.build()).thenReturn(mockChannel); + + File mockCert = mock(File.class); + when(mockCert.exists()).thenReturn(true); + String path = "test-harness/ssl/custom-root-cert.crt"; + + File file = new File(path); + String absolutePath = file.getAbsolutePath(); + // Input options + FlagdOptions options = FlagdOptions.builder() + .host("localhost") + .port(8080) + .keepAlive(5000) + .tls(true) + .certPath(absolutePath) + .build(); + + // Call method under test + ManagedChannel channel = ChannelBuilder.nettyChannel(options); + + // Assertions + assertThat(channel).isEqualTo(mockChannel); + nettyMock.verify(() -> NettyChannelBuilder.forTarget("localhost:8080")); + verify(mockBuilder).keepAliveTime(5000, TimeUnit.MILLISECONDS); + verify(mockBuilder).sslContext(any()); + verify(mockBuilder).build(); + } + } + + @ParameterizedTest + @ValueSource(strings = {"/incorrect/{uri}/;)"}) + void testNettyChannel_withInvalidTargetUri(String uri) { + FlagdOptions options = FlagdOptions.builder().targetUri(uri).build(); + + assertThatThrownBy(() -> ChannelBuilder.nettyChannel(options)) + .isInstanceOf(GenericConfigException.class) + .hasMessageContaining("Error with gRPC target string configuration"); + } + + @Test + void testNettyChannel_epollNotAvailable() { + try (MockedStatic epollMock = mockStatic(Epoll.class)) { + epollMock.when(Epoll::isAvailable).thenReturn(false); + + FlagdOptions options = + FlagdOptions.builder().socketPath("/path/to/socket").build(); + + assertThatThrownBy(() -> ChannelBuilder.nettyChannel(options)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("unix socket cannot be used"); + } + } + + @Test + void testNettyChannel_sslException() throws Exception { + try (MockedStatic nettyMock = mockStatic(NettyChannelBuilder.class)) { + NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class); + nettyMock.when(() -> NettyChannelBuilder.forTarget(anyString())).thenReturn(mockBuilder); + try (MockedStatic sslmock = mockStatic(GrpcSslContexts.class)) { + SslContextBuilder sslMockBuilder = mock(SslContextBuilder.class); + sslmock.when(GrpcSslContexts::forClient).thenReturn(sslMockBuilder); + when(sslMockBuilder.build()).thenThrow(new SSLKeyException("Test SSL error")); + when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder); + + FlagdOptions options = FlagdOptions.builder().tls(true).build(); + + assertThatThrownBy(() -> ChannelBuilder.nettyChannel(options)) + .isInstanceOf(SslConfigException.class) + .hasMessageContaining("Error with SSL configuration"); + } + } + } +} diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserverTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserverTest.java index 061910af4..9370f821a 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserverTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/EventStreamObserverTest.java @@ -1,13 +1,11 @@ package dev.openfeature.contrib.providers.flagd.resolver.grpc; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -16,12 +14,9 @@ import com.google.protobuf.Value; import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamResponse; -import io.grpc.Status; -import io.grpc.StatusRuntimeException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; -import java.util.function.Supplier; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -36,19 +31,14 @@ class StateChange { EventStreamObserver stream; Runnable reconnect; Object sync; - Supplier shouldRetrySilently; @BeforeEach void setUp() { states = new ArrayList<>(); - sync = new Object(); cache = mock(Cache.class); reconnect = mock(Runnable.class); when(cache.getEnabled()).thenReturn(true); - shouldRetrySilently = mock(Supplier.class); - when(shouldRetrySilently.get()) - .thenReturn(true, false); // 1st time we should retry silently, subsequent calls should not - stream = new EventStreamObserver(sync, cache, (state, changed) -> states.add(state), shouldRetrySilently); + stream = new EventStreamObserver(cache, (state, changed) -> states.add(state)); } @Test @@ -66,47 +56,6 @@ public void change() { verify(cache, atLeast(1)).clear(); } - @Test - public void ready() { - EventStreamResponse resp = mock(EventStreamResponse.class); - when(resp.getType()).thenReturn("provider_ready"); - stream.onNext(resp); - // we notify that we are ready - assertEquals(1, states.size()); - assertTrue(states.get(0)); - // cache was cleaned - verify(cache, atLeast(1)).clear(); - } - - @Test - public void noReconnectionOnFirstError() { - stream.onError(new Throwable("error")); - // we flush the cache - verify(cache, never()).clear(); - // we notify the error - assertEquals(0, states.size()); - } - - @Test - public void reconnections() { - stream.onError(new Throwable("error 1")); - stream.onError(new Throwable("error 2")); - // we flush the cache - verify(cache, atLeast(1)).clear(); - // we notify the error - assertEquals(1, states.size()); - assertFalse(states.get(0)); - } - - @Test - public void deadlineExceeded() { - stream.onError(new StatusRuntimeException(Status.DEADLINE_EXCEEDED)); - // we flush the cache - verify(cache, never()).clear(); - // we notify the error - assertEquals(0, states.size()); - } - @Test public void cacheBustingForKnownKeys() { final String key1 = "myKey1"; diff --git a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnectorTest.java b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnectorTest.java index f202b591c..c76963ad4 100644 --- a/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnectorTest.java +++ b/providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/grpc/GrpcConnectorTest.java @@ -1,497 +1,132 @@ package dev.openfeature.contrib.providers.flagd.resolver.grpc; -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.argThat; -import static org.mockito.Mockito.*; - +import com.google.common.collect.Lists; import dev.openfeature.contrib.providers.flagd.FlagdOptions; import dev.openfeature.contrib.providers.flagd.resolver.common.ConnectionEvent; -import dev.openfeature.contrib.providers.flagd.resolver.grpc.cache.Cache; -import dev.openfeature.flagd.grpc.evaluation.Evaluation.EventStreamResponse; +import dev.openfeature.flagd.grpc.evaluation.Evaluation; import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc; -import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceBlockingStub; -import dev.openfeature.flagd.grpc.evaluation.ServiceGrpc.ServiceStub; -import io.grpc.Channel; -import io.grpc.Status; -import io.grpc.StatusRuntimeException; -import io.grpc.netty.NettyChannelBuilder; -import io.netty.channel.EventLoopGroup; -import io.netty.channel.epoll.EpollEventLoopGroup; -import io.netty.channel.unix.DomainSocketAddress; -import java.lang.reflect.Field; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Server; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.util.ArrayList; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.condition.EnabledOnOs; -import org.junit.jupiter.api.condition.OS; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import org.junitpioneer.jupiter.SetEnvironmentVariable; -import org.mockito.MockedConstruction; -import org.mockito.MockedStatic; -import org.mockito.invocation.InvocationOnMock; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; class GrpcConnectorTest { - public static final String HOST = "server.com"; - public static final int PORT = 4321; - public static final String SOCKET_PATH = "/some/other/path"; - - @ParameterizedTest - @ValueSource(ints = {1, 2, 3}) - void validate_retry_calls(int retries) throws Exception { - final int backoffMs = 100; - - final FlagdOptions options = FlagdOptions.builder() - // shorter backoff for testing - .retryBackoffMs(backoffMs) - .maxEventStreamRetries(retries) - .build(); - - final Cache cache = new Cache("disabled", 0); - - final ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - doAnswer(invocation -> null).when(mockStub).eventStream(any(), any()); - - final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, (connectionEvent) -> {}); - - Field serviceStubField = GrpcConnector.class.getDeclaredField("serviceStub"); - serviceStubField.setAccessible(true); - serviceStubField.set(connector, mockStub); + private ManagedChannel testChannel; + private Server testServer; + private static final boolean CONNECTED = true; + private static final boolean DISCONNECTED = false; - final Object syncObject = new Object(); + @Mock + private EventStreamObserver mockEventStreamObserver; - Field syncField = GrpcConnector.class.getDeclaredField("sync"); - syncField.setAccessible(true); - syncField.set(connector, syncObject); - - assertDoesNotThrow(connector::initialize); - - for (int i = 1; i < retries; i++) { - // verify invocation with enough timeout value - verify(mockStub, timeout(2L * i * backoffMs).times(i)).eventStream(any(), any()); - - synchronized (syncObject) { - syncObject.notify(); - } + private final ServiceGrpc.ServiceImplBase testServiceImpl = new ServiceGrpc.ServiceImplBase() { + @Override + public void eventStream( + Evaluation.EventStreamRequest request, + StreamObserver responseObserver) { + // noop } - } - - @Test - void initialization_succeed_with_connected_status() { - final Cache cache = new Cache("disabled", 0); - final ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - Consumer onConnectionEvent = mock(Consumer.class); - doAnswer((InvocationOnMock invocation) -> { - EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1); - eventStreamObserver.onNext(EventStreamResponse.newBuilder() - .setType(Constants.PROVIDER_READY) - .build()); - return null; - }) - .when(mockStub) - .eventStream(any(), any()); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub); + }; - // pass true in connected lambda - final GrpcConnector connector = new GrpcConnector( - FlagdOptions.builder().build(), - cache, - () -> { - try { - Thread.sleep(100); - return true; - } catch (Exception e) { - } - return false; - }, - onConnectionEvent); - - assertDoesNotThrow(connector::initialize); - // assert that onConnectionEvent is connected - verify(onConnectionEvent).accept(argThat(arg -> arg.isConnected())); - } + @BeforeEach + void setUp() throws Exception { + MockitoAnnotations.openMocks(this); + setupTestGrpcServer(); } - @Test - void stream_does_not_fail_on_first_error() { - final Cache cache = new Cache("disabled", 0); - final ServiceStub mockStub = createServiceStubMock(); - Consumer onConnectionEvent = mock(Consumer.class); - doAnswer((InvocationOnMock invocation) -> { - EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1); - eventStreamObserver.onError(new Exception("fake")); - return null; - }) - .when(mockStub) - .eventStream(any(), any()); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub); + private void setupTestGrpcServer() throws IOException { + testServer = + NettyServerBuilder.forPort(8080).addService(testServiceImpl).build(); + testServer.start(); - // pass true in connected lambda - final GrpcConnector connector = new GrpcConnector( - FlagdOptions.builder().build(), - cache, - () -> { - try { - Thread.sleep(100); - return true; - } catch (Exception e) { - } - return false; - }, - onConnectionEvent); - - assertDoesNotThrow(connector::initialize); - // assert that onConnectionEvent is connected gets not called - verify(onConnectionEvent, timeout(300).times(0)).accept(any()); + if (testChannel == null) { + testChannel = ManagedChannelBuilder.forAddress("localhost", 8080) + .usePlaintext() + .build(); } } - @Test - void stream_fails_on_second_error_in_a_row() throws Exception { - final FlagdOptions options = FlagdOptions.builder() - // shorter backoff for testing - .retryBackoffMs(0) - .build(); - - final Cache cache = new Cache("disabled", 0); - Consumer onConnectionEvent = mock(Consumer.class); - - final ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - doAnswer((InvocationOnMock invocation) -> { - EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1); - eventStreamObserver.onError(new Exception("fake")); - return null; - }) - .when(mockStub) - .eventStream(any(), any()); - - final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, onConnectionEvent); - - Field serviceStubField = GrpcConnector.class.getDeclaredField("serviceStub"); - serviceStubField.setAccessible(true); - serviceStubField.set(connector, mockStub); - - final Object syncObject = new Object(); - - Field syncField = GrpcConnector.class.getDeclaredField("sync"); - syncField.setAccessible(true); - syncField.set(connector, syncObject); - - assertDoesNotThrow(connector::initialize); - - // 1st try - verify(mockStub, timeout(300).times(1)).eventStream(any(), any()); - verify(onConnectionEvent, timeout(300).times(0)).accept(any()); - synchronized (syncObject) { - syncObject.notify(); - } - - // 2nd try - verify(mockStub, timeout(300).times(2)).eventStream(any(), any()); - verify(onConnectionEvent, timeout(300).times(1)).accept(argThat(arg -> !arg.isConnected())); + @AfterEach + void tearDown() throws Exception { + tearDownGrpcServer(); } - @Test - void stream_does_not_fail_when_message_between_errors() throws Exception { - final FlagdOptions options = FlagdOptions.builder() - // shorter backoff for testing - .retryBackoffMs(0) - .build(); - - final Cache cache = new Cache("disabled", 0); - Consumer onConnectionEvent = mock(Consumer.class); - - final AtomicBoolean successMessage = new AtomicBoolean(false); - final ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - doAnswer((InvocationOnMock invocation) -> { - EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1); - - if (successMessage.get()) { - eventStreamObserver.onNext(EventStreamResponse.newBuilder() - .setType(Constants.PROVIDER_READY) - .build()); - } else { - eventStreamObserver.onError(new Exception("fake")); - } - return null; - }) - .when(mockStub) - .eventStream(any(), any()); - - final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, onConnectionEvent); - - Field serviceStubField = GrpcConnector.class.getDeclaredField("serviceStub"); - serviceStubField.setAccessible(true); - serviceStubField.set(connector, mockStub); - - final Object syncObject = new Object(); - - Field syncField = GrpcConnector.class.getDeclaredField("sync"); - syncField.setAccessible(true); - syncField.set(connector, syncObject); - - assertDoesNotThrow(connector::initialize); - - // 1st message with error - verify(mockStub, timeout(300).times(1)).eventStream(any(), any()); - verify(onConnectionEvent, timeout(300).times(0)).accept(any()); - - synchronized (syncObject) { - successMessage.set(true); - syncObject.notify(); - } - - // 2nd message with provider ready - verify(mockStub, timeout(300).times(2)).eventStream(any(), any()); - verify(onConnectionEvent, timeout(300).times(1)).accept(argThat(arg -> arg.isConnected())); - synchronized (syncObject) { - successMessage.set(false); - syncObject.notify(); + private void tearDownGrpcServer() throws InterruptedException { + if (testServer != null) { + testServer.shutdownNow(); + testServer.awaitTermination(); } - - // 3nd message with error - verify(mockStub, timeout(300).times(2)).eventStream(any(), any()); - verify(onConnectionEvent, timeout(300).times(0)).accept(argThat(arg -> !arg.isConnected())); } @Test - void stream_does_not_fail_with_deadline_error() throws Exception { - final Cache cache = new Cache("disabled", 0); - final ServiceStub mockStub = createServiceStubMock(); - Consumer onConnectionEvent = mock(Consumer.class); - doAnswer((InvocationOnMock invocation) -> { - EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1); - eventStreamObserver.onError(new StatusRuntimeException(Status.DEADLINE_EXCEEDED)); - return null; - }) - .when(mockStub) - .eventStream(any(), any()); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub); - - // pass true in connected lambda - final GrpcConnector connector = new GrpcConnector( - FlagdOptions.builder().build(), - cache, - () -> { - try { - Thread.sleep(100); - return true; - } catch (Exception e) { - } - return false; - }, - onConnectionEvent); - - assertDoesNotThrow(connector::initialize); - // this should not call the connection event - verify(onConnectionEvent, never()).accept(any()); - } + void whenShuttingDownAndRestartingGrpcServer_ConsumerReceivesDisconnectedAndConnectedEvent() throws Exception { + CountDownLatch sync = new CountDownLatch(2); + ArrayList connectionStateChanges = Lists.newArrayList(); + Consumer testConsumer = event -> { + connectionStateChanges.add(event.isConnected()); + sync.countDown(); + }; + + GrpcConnector instance = new GrpcConnector<>( + FlagdOptions.builder().build(), + ServiceGrpc::newStub, + ServiceGrpc::newBlockingStub, + testConsumer, + stub -> stub.eventStream(Evaluation.EventStreamRequest.getDefaultInstance(), mockEventStreamObserver), + testChannel); + + instance.initialize(); + + // when shutting down server + testServer.shutdown(); + testServer.awaitTermination(1, TimeUnit.SECONDS); + + // when restarting server + setupTestGrpcServer(); + + // then consumer received DISCONNECTED and CONNECTED event + boolean finished = sync.await(10, TimeUnit.SECONDS); + Assertions.assertTrue(finished); + Assertions.assertEquals(Lists.newArrayList(DISCONNECTED, CONNECTED), connectionStateChanges); } @Test - void host_and_port_arg_should_build_tcp_socket() { - final String host = "host.com"; - final int port = 1234; - final String targetUri = String.format("%s:%s", host, port); - - ServiceGrpc.ServiceBlockingStub mockBlockingStub = mock(ServiceGrpc.ServiceBlockingStub.class); - ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket(); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService - .when(() -> ServiceGrpc.newBlockingStub(any(Channel.class))) - .thenReturn(mockBlockingStub); - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub); - - try (MockedStatic mockStaticChannelBuilder = mockStatic(NettyChannelBuilder.class)) { - - mockStaticChannelBuilder - .when(() -> NettyChannelBuilder.forTarget(anyString())) - .thenReturn(mockChannelBuilder); - - final FlagdOptions flagdOptions = - FlagdOptions.builder().host(host).port(port).tls(false).build(); - new GrpcConnector(flagdOptions, null, null, null); - - // verify host/port matches - mockStaticChannelBuilder.verify( - () -> NettyChannelBuilder.forTarget(String.format(targetUri)), times(1)); - } - } - } - - @Test - @SetEnvironmentVariable(key = "FLAGD_HOST", value = HOST) - @SetEnvironmentVariable(key = "FLAGD_PORT", value = "" + PORT) - void no_args_host_and_port_env_set_should_build_tcp_socket() throws Exception { - final String targetUri = String.format("%s:%s", HOST, PORT); - - ServiceGrpc.ServiceBlockingStub mockBlockingStub = mock(ServiceGrpc.ServiceBlockingStub.class); - ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket(); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService - .when(() -> ServiceGrpc.newBlockingStub(any(Channel.class))) - .thenReturn(mockBlockingStub); - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub); - - try (MockedStatic mockStaticChannelBuilder = mockStatic(NettyChannelBuilder.class)) { - - mockStaticChannelBuilder - .when(() -> NettyChannelBuilder.forTarget(anyString())) - .thenReturn(mockChannelBuilder); - - new GrpcConnector(FlagdOptions.builder().build(), null, null, null); - - // verify host/port matches & called times(= 1 as we rely on reusable channel) - mockStaticChannelBuilder.verify(() -> NettyChannelBuilder.forTarget(targetUri), times(1)); - } - } - } - - /** - * OS Specific test - This test is valid only on Linux system as it rely on - * epoll availability - */ - @Test - @EnabledOnOs(OS.LINUX) - void path_arg_should_build_domain_socket_with_correct_path() { - final String path = "/some/path"; - - ServiceGrpc.ServiceBlockingStub mockBlockingStub = mock(ServiceGrpc.ServiceBlockingStub.class); - ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket(); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService - .when(() -> ServiceGrpc.newBlockingStub(any(Channel.class))) - .thenReturn(mockBlockingStub); - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub); - - try (MockedStatic mockStaticChannelBuilder = mockStatic(NettyChannelBuilder.class)) { - - try (MockedConstruction mockEpollEventLoopGroup = - mockConstruction(EpollEventLoopGroup.class, (mock, context) -> {})) { - when(NettyChannelBuilder.forAddress(any(DomainSocketAddress.class))) - .thenReturn(mockChannelBuilder); - - new GrpcConnector(FlagdOptions.builder().socketPath(path).build(), null, null, null); - - // verify path matches - mockStaticChannelBuilder.verify( - () -> NettyChannelBuilder.forAddress(argThat((DomainSocketAddress d) -> { - assertEquals(path, d.path()); // path should match - return true; - })), - times(1)); - } - } - } - } - - /** - * OS Specific test - This test is valid only on Linux system as it rely on - * epoll availability - */ - @Test - @EnabledOnOs(OS.LINUX) - @SetEnvironmentVariable(key = "FLAGD_SOCKET_PATH", value = SOCKET_PATH) - void no_args_socket_env_should_build_domain_socket_with_correct_path() throws Exception { - - ServiceBlockingStub mockBlockingStub = mock(ServiceBlockingStub.class); - ServiceStub mockStub = mock(ServiceStub.class); - NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket(); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService - .when(() -> ServiceGrpc.newBlockingStub(any(Channel.class))) - .thenReturn(mockBlockingStub); - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub); - - try (MockedStatic mockStaticChannelBuilder = mockStatic(NettyChannelBuilder.class)) { - - try (MockedConstruction mockEpollEventLoopGroup = - mockConstruction(EpollEventLoopGroup.class, (mock, context) -> {})) { - mockStaticChannelBuilder - .when(() -> NettyChannelBuilder.forAddress(any(DomainSocketAddress.class))) - .thenReturn(mockChannelBuilder); - - new GrpcConnector(FlagdOptions.builder().build(), null, null, null); - - // verify path matches & called times(= 1 as we rely on reusable channel) - mockStaticChannelBuilder.verify( - () -> NettyChannelBuilder.forAddress(argThat((DomainSocketAddress d) -> { - assertEquals(SOCKET_PATH, d.path()); // path should match - return true; - })), - times(1)); - } - } - } - } - - @Test - void initialization_with_stream_deadline() throws NoSuchFieldException, IllegalAccessException { - final FlagdOptions options = - FlagdOptions.builder().streamDeadlineMs(16983).build(); - - final Cache cache = new Cache("disabled", 0); - final ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub); - - final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, null); - - assertDoesNotThrow(connector::initialize); - verify(mockStub).withDeadlineAfter(16983, TimeUnit.MILLISECONDS); - } - } - - @Test - void initialization_without_stream_deadline() throws NoSuchFieldException, IllegalAccessException { - final FlagdOptions options = FlagdOptions.builder().streamDeadlineMs(0).build(); - - final Cache cache = new Cache("disabled", 0); - final ServiceGrpc.ServiceStub mockStub = createServiceStubMock(); - - try (MockedStatic mockStaticService = mockStatic(ServiceGrpc.class)) { - mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub); - - final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, null); - - assertDoesNotThrow(connector::initialize); - verify(mockStub, never()).withDeadlineAfter(16983, TimeUnit.MILLISECONDS); - } - } - - private static ServiceStub createServiceStubMock() { - final ServiceStub mockStub = mock(ServiceStub.class); - when(mockStub.withDeadlineAfter(anyLong(), any())).thenReturn(mockStub); - return mockStub; - } - - private NettyChannelBuilder getMockChannelBuilderSocket() { - NettyChannelBuilder mockChannelBuilder = mock(NettyChannelBuilder.class); - when(mockChannelBuilder.eventLoopGroup(any(EventLoopGroup.class))).thenReturn(mockChannelBuilder); - when(mockChannelBuilder.channelType(any(Class.class))).thenReturn(mockChannelBuilder); - when(mockChannelBuilder.usePlaintext()).thenReturn(mockChannelBuilder); - when(mockChannelBuilder.keepAliveTime(anyLong(), any())).thenReturn(mockChannelBuilder); - when(mockChannelBuilder.build()).thenReturn(null); - return mockChannelBuilder; + void whenShuttingDownGrpcConnector_ConsumerReceivesDisconnectedEvent() throws Exception { + CountDownLatch sync = new CountDownLatch(1); + ArrayList connectionStateChanges = Lists.newArrayList(); + Consumer testConsumer = event -> { + connectionStateChanges.add(event.isConnected()); + sync.countDown(); + }; + + GrpcConnector instance = new GrpcConnector<>( + FlagdOptions.builder().build(), + ServiceGrpc::newStub, + ServiceGrpc::newBlockingStub, + testConsumer, + stub -> stub.eventStream(Evaluation.EventStreamRequest.getDefaultInstance(), mockEventStreamObserver), + testChannel); + + instance.initialize(); + // when shutting grpc connector + instance.shutdown(); + + // then consumer received DISCONNECTED and CONNECTED event + boolean finished = sync.await(10, TimeUnit.SECONDS); + Assertions.assertTrue(finished); + Assertions.assertEquals(Lists.newArrayList(DISCONNECTED), connectionStateChanges); } } diff --git a/providers/flagd/test-harness b/providers/flagd/test-harness index fd66a39e1..8931c8645 160000 --- a/providers/flagd/test-harness +++ b/providers/flagd/test-harness @@ -1 +1 @@ -Subproject commit fd66a39e1192409f60cb388d443f821364ee9af4 +Subproject commit 8931c8645b8600e251d5e3ebbad42dff8ce4c78e