diff --git a/core/pkg/evaluator/fractional_test.go b/core/pkg/evaluator/fractional_test.go index fc431cbd5..debdf16d0 100644 --- a/core/pkg/evaluator/fractional_test.go +++ b/core/pkg/evaluator/fractional_test.go @@ -461,7 +461,7 @@ func TestFractionalEvaluation(t *testing.T) { je := NewJSON(log, store.NewFlags()) je.store.Flags = tt.flags.Flags - value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant) + value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant, ResolverConfiguration{}) if value != tt.expectedValue { t.Errorf("expected value '%s', got '%s'", tt.expectedValue, value) @@ -590,7 +590,7 @@ func BenchmarkFractionalEvaluation(b *testing.B) { je := NewJSON(log, &store.Flags{Flags: tt.flags.Flags}) for i := 0; i < b.N; i++ { value, variant, reason, _, err := resolve[string]( - ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant) + ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant, ResolverConfiguration{}) if value != tt.expectedValue { b.Errorf("expected value '%s', got '%s'", tt.expectedValue, value) diff --git a/core/pkg/evaluator/json.go b/core/pkg/evaluator/json.go index 655eaf75c..9680d6b29 100644 --- a/core/pkg/evaluator/json.go +++ b/core/pkg/evaluator/json.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "reflect" "regexp" "strconv" "strings" @@ -33,6 +34,7 @@ const ( // evaluation if the user did not supply the optional bucketing property. targetingKeyKey = "targetingKey" Disabled = "DISABLED" + TypeMetadataKey = "type" ) var regBrace *regexp.Regexp @@ -138,11 +140,25 @@ func (je *JSON) SetState(payload sync.DataSync) (map[string]interface{}, bool, e return events, reSync, nil } +type TypeCheckingBehavior string + +const ( + ErrorOnTypeMismatch TypeCheckingBehavior = "ErrorOnTypeMismatch" + PanicOnTypeMismatch TypeCheckingBehavior = "PanicOnTypeMismatch" + VerifyAllVariantsAndErrorOnTypeMismatch TypeCheckingBehavior = "VerifyAllVariantsAndErrorOnTypeMismatch" + VerifyAllVariantsAndPanicOnTypeMismatch TypeCheckingBehavior = "VerifyAllVariantsAndPanicOnTypeMismatch" +) + +type ResolverConfiguration struct { + TypeCheckingBehavior TypeCheckingBehavior +} + // Resolver implementation for flagd flags. This resolver should be kept reusable, hence must interact with interfaces. type Resolver struct { store store.IStore Logger *logger.Logger tracer trace.Tracer + ResolverConfiguration } func NewResolver(store store.IStore, logger *logger.Logger, jsonEvalTracer trace.Tracer) Resolver { @@ -153,7 +169,7 @@ func NewResolver(store store.IStore, logger *logger.Logger, jsonEvalTracer trace jsonlogic.AddOperator(SemVerEvaluationName, NewSemVerComparison(logger).SemVerEvaluation) jsonlogic.AddOperator(LegacyFractionEvaluationName, NewLegacyFractional(logger).LegacyFractionalEvaluation) - return Resolver{store: store, Logger: logger, tracer: jsonEvalTracer} + return Resolver{store: store, Logger: logger, tracer: jsonEvalTracer, ResolverConfiguration: ResolverConfiguration{}} } func (je *Resolver) ResolveAllValues(ctx context.Context, reqID string, context map[string]any) ([]AnyValue, error) { @@ -181,13 +197,13 @@ func (je *Resolver) ResolveAllValues(ctx context.Context, reqID string, context defaultValue := flag.Variants[flag.DefaultVariant] switch defaultValue.(type) { case bool: - value, variant, reason, metadata, err = resolve[bool](ctx, reqID, flagKey, context, je.evaluateVariant) + value, variant, reason, metadata, err = resolve[bool](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration) case string: - value, variant, reason, metadata, err = resolve[string](ctx, reqID, flagKey, context, je.evaluateVariant) + value, variant, reason, metadata, err = resolve[string](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration) case float64: - value, variant, reason, metadata, err = resolve[float64](ctx, reqID, flagKey, context, je.evaluateVariant) + value, variant, reason, metadata, err = resolve[float64](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration) case map[string]any: - value, variant, reason, metadata, err = resolve[map[string]any](ctx, reqID, flagKey, context, je.evaluateVariant) + value, variant, reason, metadata, err = resolve[map[string]any](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration) } if err != nil { je.Logger.ErrorWithID(reqID, fmt.Sprintf("bulk evaluation: key: %s returned error: %s", flagKey, err.Error())) @@ -210,7 +226,7 @@ func (je *Resolver) ResolveBooleanValue( defer span.End() je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating boolean flag: %s", flagKey)) - return resolve[bool](ctx, reqID, flagKey, context, je.evaluateVariant) + return resolve[bool](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration) } func (je *Resolver) ResolveStringValue( @@ -225,7 +241,7 @@ func (je *Resolver) ResolveStringValue( defer span.End() je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating string flag: %s", flagKey)) - return resolve[string](ctx, reqID, flagKey, context, je.evaluateVariant) + return resolve[string](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration) } func (je *Resolver) ResolveFloatValue( @@ -240,7 +256,7 @@ func (je *Resolver) ResolveFloatValue( defer span.End() je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating float flag: %s", flagKey)) - value, variant, reason, metadata, err = resolve[float64](ctx, reqID, flagKey, context, je.evaluateVariant) + value, variant, reason, metadata, err = resolve[float64](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration) return } @@ -256,7 +272,7 @@ func (je *Resolver) ResolveIntValue(ctx context.Context, reqID string, flagKey s je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating int flag: %s", flagKey)) var val float64 - val, variant, reason, metadata, err = resolve[float64](ctx, reqID, flagKey, context, je.evaluateVariant) + val, variant, reason, metadata, err = resolve[float64](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration) value = int64(val) return } @@ -273,7 +289,7 @@ func (je *Resolver) ResolveObjectValue( defer span.End() je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating object flag: %s", flagKey)) - return resolve[map[string]any](ctx, reqID, flagKey, context, je.evaluateVariant) + return resolve[map[string]any](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration) } func (je *Resolver) ResolveAsAnyValue( @@ -286,19 +302,55 @@ func (je *Resolver) ResolveAsAnyValue( defer span.End() je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating flag `%s` as a generic flag", flagKey)) - value, variant, reason, meta, err := resolve[interface{}](ctx, reqID, flagKey, context, je.evaluateVariant) + value, variant, reason, meta, err := resolve[interface{}](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration) return NewAnyValue(value, variant, reason, flagKey, meta, err) } +func verifyAllVariants(expectedType reflect.Type, variants map[string]interface{}) (mismatchedVariants map[string]string, errmsg string) { + mismatchedVariants = map[string]string{} + errmsg = "" + + for name, value := range variants { + if expectedType != reflect.TypeOf(value) { + mismatchedVariants[name] = reflect.TypeOf(value).String() + } + } + if len(mismatchedVariants) > 0 { + errmsg = fmt.Sprintf("%s: Mismatched variants exist! Expected flag type: %s. Mismatched variants' types: %v", model.TypeCheckingError, expectedType, mismatchedVariants) + } + return mismatchedVariants, errmsg +} + // resolve is a helper for generic flag resolving func resolve[T constraints](ctx context.Context, reqID string, key string, context map[string]any, - variantEval variantEvaluator) (value T, variant string, reason string, metadata map[string]interface{}, err error, + variantEval variantEvaluator, resolverConfig ResolverConfiguration) (value T, variant string, reason string, metadata map[string]interface{}, err error, ) { variant, variants, reason, metadata, err := variantEval(ctx, reqID, key, context) if err != nil { return value, variant, reason, metadata, err } + switch resolverConfig.TypeCheckingBehavior { + case ErrorOnTypeMismatch: + if metadata[TypeMetadataKey] != reflect.TypeOf(*new(T)) { + return value, variant, reason, metadata, errors.New(model.TypeMismatchErrorCode) + } + case PanicOnTypeMismatch: + if metadata[TypeMetadataKey] != reflect.TypeOf(*new(T)) { + panic(model.TypeMismatchErrorCode) + } + case VerifyAllVariantsAndErrorOnTypeMismatch: + mismatchedVariants, errmsg := verifyAllVariants(metadata[TypeMetadataKey].(reflect.Type), variants) + if len(mismatchedVariants) > 0 { + return value, variant, reason, metadata, errors.New(errmsg) + } + case VerifyAllVariantsAndPanicOnTypeMismatch: + mismatchedVariants, errmsg := verifyAllVariants(metadata[TypeMetadataKey].(reflect.Type), variants) + if len(mismatchedVariants) > 0 { + panic(errmsg) + } + } + var ok bool value, ok = variants[variant].(T) if !ok { @@ -321,6 +373,13 @@ func (je *Resolver) evaluateVariant(ctx context.Context, reqID string, flagKey s return "", map[string]interface{}{}, model.ErrorReason, metadata, errors.New(model.FlagNotFoundErrorCode) } + if value, ok := flag.Variants[flag.DefaultVariant]; ok { + metadata[TypeMetadataKey] = reflect.TypeOf(value) + } else { + je.Logger.ErrorWithID(reqID, fmt.Sprintf("Error inferring type for flag %s. defaultVariant (%s) not found in variants (%v)", flagKey, flag.DefaultVariant, flag.Variants)) + return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.ErrorReason) + } + // add selector to evaluation metadata selector := je.store.SelectorForFlag(ctx, flag) if selector != "" { diff --git a/core/pkg/evaluator/legacy_fractional_test.go b/core/pkg/evaluator/legacy_fractional_test.go index 04c4792bc..59e0953e0 100644 --- a/core/pkg/evaluator/legacy_fractional_test.go +++ b/core/pkg/evaluator/legacy_fractional_test.go @@ -275,7 +275,7 @@ func TestLegacyFractionalEvaluation(t *testing.T) { je := NewJSON(log, store.NewFlags()) je.store.Flags = tt.flags.Flags - value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant) + value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant, ResolverConfiguration{}) if value != tt.expectedValue { t.Errorf("expected value '%s', got '%s'", tt.expectedValue, value) diff --git a/core/pkg/evaluator/semver_test.go b/core/pkg/evaluator/semver_test.go index 557072b98..444d6900f 100644 --- a/core/pkg/evaluator/semver_test.go +++ b/core/pkg/evaluator/semver_test.go @@ -704,7 +704,7 @@ func TestJSONEvaluator_semVerEvaluation(t *testing.T) { je := NewJSON(log, store.NewFlags()) je.store.Flags = tt.flags.Flags - value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant) + value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant, ResolverConfiguration{}) if value != tt.expectedValue { t.Errorf("expected value '%s', got '%s'", tt.expectedValue, value) diff --git a/core/pkg/evaluator/string_comparison_test.go b/core/pkg/evaluator/string_comparison_test.go index 75a678658..b9095658b 100644 --- a/core/pkg/evaluator/string_comparison_test.go +++ b/core/pkg/evaluator/string_comparison_test.go @@ -188,7 +188,7 @@ func TestJSONEvaluator_startsWithEvaluation(t *testing.T) { je := NewJSON(log, store.NewFlags()) je.store.Flags = tt.flags.Flags - value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant) + value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant, ResolverConfiguration{}) if value != tt.expectedValue { t.Errorf("expected value '%s', got '%s'", tt.expectedValue, value) @@ -386,7 +386,7 @@ func TestJSONEvaluator_endsWithEvaluation(t *testing.T) { je.store.Flags = tt.flags.Flags - value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant) + value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant, ResolverConfiguration{}) if value != tt.expectedValue { t.Errorf("expected value '%s', got '%s'", tt.expectedValue, value) diff --git a/core/pkg/model/error.go b/core/pkg/model/error.go index 9656f4f3c..4c0b99089 100644 --- a/core/pkg/model/error.go +++ b/core/pkg/model/error.go @@ -9,6 +9,7 @@ const ( GeneralErrorCode = "GENERAL" FlagDisabledErrorCode = "FLAG_DISABLED" InvalidContextCode = "INVALID_CONTEXT" + TypeCheckingError = "TYPE_CHECKING_ERROR" ) var ReadableErrorMessage = map[string]string{ @@ -18,6 +19,7 @@ var ReadableErrorMessage = map[string]string{ GeneralErrorCode: "General error", FlagDisabledErrorCode: "Flag is disabled", InvalidContextCode: "Invalid context provided", + TypeCheckingError: "Type checking error due to selected type checking behavior", } func GetErrorMessage(code string) string {