Skip to content

Commit

Permalink
feat: support stricter type checking behavior as a Resolver configura…
Browse files Browse the repository at this point in the history
…tion

Signed-off-by: Maks Osowski <[email protected]>
  • Loading branch information
cupofcat committed Dec 18, 2024
1 parent dbc1da4 commit 3980d02
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 18 deletions.
4 changes: 2 additions & 2 deletions core/pkg/evaluator/fractional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ func TestFractionalEvaluation(t *testing.T) {
je := NewJSON(log, store.NewFlags())
je.store.Flags = tt.flags.Flags

value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant)
value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant, ResolverConfiguration{})

if value != tt.expectedValue {
t.Errorf("expected value '%s', got '%s'", tt.expectedValue, value)
Expand Down Expand Up @@ -590,7 +590,7 @@ func BenchmarkFractionalEvaluation(b *testing.B) {
je := NewJSON(log, &store.Flags{Flags: tt.flags.Flags})
for i := 0; i < b.N; i++ {
value, variant, reason, _, err := resolve[string](
ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant)
ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant, ResolverConfiguration{})

if value != tt.expectedValue {
b.Errorf("expected value '%s', got '%s'", tt.expectedValue, value)
Expand Down
83 changes: 71 additions & 12 deletions core/pkg/evaluator/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/json"
"errors"
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
Expand Down Expand Up @@ -33,6 +34,7 @@ const (
// evaluation if the user did not supply the optional bucketing property.
targetingKeyKey = "targetingKey"
Disabled = "DISABLED"
TypeMetadataKey = "type"
)

var regBrace *regexp.Regexp
Expand Down Expand Up @@ -138,11 +140,25 @@ func (je *JSON) SetState(payload sync.DataSync) (map[string]interface{}, bool, e
return events, reSync, nil
}

type TypeCheckingBehavior string

const (
ErrorOnTypeMismatch TypeCheckingBehavior = "ErrorOnTypeMismatch"
PanicOnTypeMismatch TypeCheckingBehavior = "PanicOnTypeMismatch"
VerifyAllVariantsAndErrorOnTypeMismatch TypeCheckingBehavior = "VerifyAllVariantsAndErrorOnTypeMismatch"
VerifyAllVariantsAndPanicOnTypeMismatch TypeCheckingBehavior = "VerifyAllVariantsAndPanicOnTypeMismatch"
)

type ResolverConfiguration struct {
TypeCheckingBehavior TypeCheckingBehavior
}

// Resolver implementation for flagd flags. This resolver should be kept reusable, hence must interact with interfaces.
type Resolver struct {
store store.IStore
Logger *logger.Logger
tracer trace.Tracer
ResolverConfiguration
}

func NewResolver(store store.IStore, logger *logger.Logger, jsonEvalTracer trace.Tracer) Resolver {
Expand All @@ -153,7 +169,7 @@ func NewResolver(store store.IStore, logger *logger.Logger, jsonEvalTracer trace
jsonlogic.AddOperator(SemVerEvaluationName, NewSemVerComparison(logger).SemVerEvaluation)
jsonlogic.AddOperator(LegacyFractionEvaluationName, NewLegacyFractional(logger).LegacyFractionalEvaluation)

return Resolver{store: store, Logger: logger, tracer: jsonEvalTracer}
return Resolver{store: store, Logger: logger, tracer: jsonEvalTracer, ResolverConfiguration: ResolverConfiguration{}}
}

func (je *Resolver) ResolveAllValues(ctx context.Context, reqID string, context map[string]any) ([]AnyValue, error) {
Expand Down Expand Up @@ -181,13 +197,13 @@ func (je *Resolver) ResolveAllValues(ctx context.Context, reqID string, context
defaultValue := flag.Variants[flag.DefaultVariant]
switch defaultValue.(type) {
case bool:
value, variant, reason, metadata, err = resolve[bool](ctx, reqID, flagKey, context, je.evaluateVariant)
value, variant, reason, metadata, err = resolve[bool](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration)
case string:
value, variant, reason, metadata, err = resolve[string](ctx, reqID, flagKey, context, je.evaluateVariant)
value, variant, reason, metadata, err = resolve[string](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration)
case float64:
value, variant, reason, metadata, err = resolve[float64](ctx, reqID, flagKey, context, je.evaluateVariant)
value, variant, reason, metadata, err = resolve[float64](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration)
case map[string]any:
value, variant, reason, metadata, err = resolve[map[string]any](ctx, reqID, flagKey, context, je.evaluateVariant)
value, variant, reason, metadata, err = resolve[map[string]any](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration)
}
if err != nil {
je.Logger.ErrorWithID(reqID, fmt.Sprintf("bulk evaluation: key: %s returned error: %s", flagKey, err.Error()))
Expand All @@ -210,7 +226,7 @@ func (je *Resolver) ResolveBooleanValue(
defer span.End()

je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating boolean flag: %s", flagKey))
return resolve[bool](ctx, reqID, flagKey, context, je.evaluateVariant)
return resolve[bool](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration)
}

func (je *Resolver) ResolveStringValue(
Expand All @@ -225,7 +241,7 @@ func (je *Resolver) ResolveStringValue(
defer span.End()

je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating string flag: %s", flagKey))
return resolve[string](ctx, reqID, flagKey, context, je.evaluateVariant)
return resolve[string](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration)
}

func (je *Resolver) ResolveFloatValue(
Expand All @@ -240,7 +256,7 @@ func (je *Resolver) ResolveFloatValue(
defer span.End()

je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating float flag: %s", flagKey))
value, variant, reason, metadata, err = resolve[float64](ctx, reqID, flagKey, context, je.evaluateVariant)
value, variant, reason, metadata, err = resolve[float64](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration)
return
}

Expand All @@ -256,7 +272,7 @@ func (je *Resolver) ResolveIntValue(ctx context.Context, reqID string, flagKey s

je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating int flag: %s", flagKey))
var val float64
val, variant, reason, metadata, err = resolve[float64](ctx, reqID, flagKey, context, je.evaluateVariant)
val, variant, reason, metadata, err = resolve[float64](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration)
value = int64(val)
return
}
Expand All @@ -273,7 +289,7 @@ func (je *Resolver) ResolveObjectValue(
defer span.End()

je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating object flag: %s", flagKey))
return resolve[map[string]any](ctx, reqID, flagKey, context, je.evaluateVariant)
return resolve[map[string]any](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration)
}

func (je *Resolver) ResolveAsAnyValue(
Expand All @@ -286,19 +302,55 @@ func (je *Resolver) ResolveAsAnyValue(
defer span.End()

je.Logger.DebugWithID(reqID, fmt.Sprintf("evaluating flag `%s` as a generic flag", flagKey))
value, variant, reason, meta, err := resolve[interface{}](ctx, reqID, flagKey, context, je.evaluateVariant)
value, variant, reason, meta, err := resolve[interface{}](ctx, reqID, flagKey, context, je.evaluateVariant, je.ResolverConfiguration)
return NewAnyValue(value, variant, reason, flagKey, meta, err)
}

func verifyAllVariants(expectedType reflect.Type, variants map[string]interface{}) (mismatchedVariants map[string]string, errmsg string) {
mismatchedVariants = map[string]string{}
errmsg = ""

for name, value := range variants {
if expectedType != reflect.TypeOf(value) {
mismatchedVariants[name] = reflect.TypeOf(value).String()
}
}
if len(mismatchedVariants) > 0 {
errmsg = fmt.Sprintf("%s: Mismatched variants exist! Expected flag type: %s. Mismatched variants' types: %v", model.TypeCheckingError, expectedType, mismatchedVariants)
}
return mismatchedVariants, errmsg
}

// resolve is a helper for generic flag resolving
func resolve[T constraints](ctx context.Context, reqID string, key string, context map[string]any,
variantEval variantEvaluator) (value T, variant string, reason string, metadata map[string]interface{}, err error,
variantEval variantEvaluator, resolverConfig ResolverConfiguration) (value T, variant string, reason string, metadata map[string]interface{}, err error,
) {
variant, variants, reason, metadata, err := variantEval(ctx, reqID, key, context)
if err != nil {
return value, variant, reason, metadata, err
}

switch resolverConfig.TypeCheckingBehavior {
case ErrorOnTypeMismatch:
if metadata[TypeMetadataKey] != reflect.TypeOf(*new(T)) {
return value, variant, reason, metadata, errors.New(model.TypeMismatchErrorCode)
}
case PanicOnTypeMismatch:
if metadata[TypeMetadataKey] != reflect.TypeOf(*new(T)) {
panic(model.TypeMismatchErrorCode)
}
case VerifyAllVariantsAndErrorOnTypeMismatch:
mismatchedVariants, errmsg := verifyAllVariants(metadata[TypeMetadataKey].(reflect.Type), variants)
if len(mismatchedVariants) > 0 {
return value, variant, reason, metadata, errors.New(errmsg)
}
case VerifyAllVariantsAndPanicOnTypeMismatch:
mismatchedVariants, errmsg := verifyAllVariants(metadata[TypeMetadataKey].(reflect.Type), variants)
if len(mismatchedVariants) > 0 {
panic(errmsg)
}
}

var ok bool
value, ok = variants[variant].(T)
if !ok {
Expand All @@ -321,6 +373,13 @@ func (je *Resolver) evaluateVariant(ctx context.Context, reqID string, flagKey s
return "", map[string]interface{}{}, model.ErrorReason, metadata, errors.New(model.FlagNotFoundErrorCode)
}

if value, ok := flag.Variants[flag.DefaultVariant]; ok {
metadata[TypeMetadataKey] = reflect.TypeOf(value)
} else {
je.Logger.ErrorWithID(reqID, fmt.Sprintf("Error inferring type for flag %s. defaultVariant (%s) not found in variants (%v)", flagKey, flag.DefaultVariant, flag.Variants))
return "", flag.Variants, model.ErrorReason, metadata, errors.New(model.ErrorReason)
}

// add selector to evaluation metadata
selector := je.store.SelectorForFlag(ctx, flag)
if selector != "" {
Expand Down
2 changes: 1 addition & 1 deletion core/pkg/evaluator/legacy_fractional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ func TestLegacyFractionalEvaluation(t *testing.T) {
je := NewJSON(log, store.NewFlags())
je.store.Flags = tt.flags.Flags

value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant)
value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant, ResolverConfiguration{})

if value != tt.expectedValue {
t.Errorf("expected value '%s', got '%s'", tt.expectedValue, value)
Expand Down
2 changes: 1 addition & 1 deletion core/pkg/evaluator/semver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ func TestJSONEvaluator_semVerEvaluation(t *testing.T) {
je := NewJSON(log, store.NewFlags())
je.store.Flags = tt.flags.Flags

value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant)
value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant, ResolverConfiguration{})

if value != tt.expectedValue {
t.Errorf("expected value '%s', got '%s'", tt.expectedValue, value)
Expand Down
4 changes: 2 additions & 2 deletions core/pkg/evaluator/string_comparison_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func TestJSONEvaluator_startsWithEvaluation(t *testing.T) {
je := NewJSON(log, store.NewFlags())
je.store.Flags = tt.flags.Flags

value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant)
value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant, ResolverConfiguration{})

if value != tt.expectedValue {
t.Errorf("expected value '%s', got '%s'", tt.expectedValue, value)
Expand Down Expand Up @@ -386,7 +386,7 @@ func TestJSONEvaluator_endsWithEvaluation(t *testing.T) {

je.store.Flags = tt.flags.Flags

value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant)
value, variant, reason, _, err := resolve[string](ctx, reqID, tt.flagKey, tt.context, je.evaluateVariant, ResolverConfiguration{})

if value != tt.expectedValue {
t.Errorf("expected value '%s', got '%s'", tt.expectedValue, value)
Expand Down
2 changes: 2 additions & 0 deletions core/pkg/model/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ const (
GeneralErrorCode = "GENERAL"
FlagDisabledErrorCode = "FLAG_DISABLED"
InvalidContextCode = "INVALID_CONTEXT"
TypeCheckingError = "TYPE_CHECKING_ERROR"
)

var ReadableErrorMessage = map[string]string{
Expand All @@ -18,6 +19,7 @@ var ReadableErrorMessage = map[string]string{
GeneralErrorCode: "General error",
FlagDisabledErrorCode: "Flag is disabled",
InvalidContextCode: "Invalid context provided",
TypeCheckingError: "Type checking error due to selected type checking behavior",
}

func GetErrorMessage(code string) string {
Expand Down

0 comments on commit 3980d02

Please sign in to comment.