diff --git a/server/expression/array.go b/server/expression/array.go index d443c63c7..4b71b6ffc 100644 --- a/server/expression/array.go +++ b/server/expression/array.go @@ -79,7 +79,7 @@ func (array *Array) Eval(ctx *sql.Context, row sql.Row) (any, error) { } // We always cast the element, as there may be parameter restrictions in place - castFunc := framework.GetImplicitCast(doltgresType, resultTyp) + castFunc := framework.GetImplicitCast(doltgresType.OID, resultTyp.OID) if castFunc == nil { if doltgresType.OID == uint32(oid.T_unknown) { castFunc = framework.UnknownLiteralCast @@ -163,7 +163,7 @@ func (array *Array) WithResolvedChildren(children []any) (any, error) { // getTargetType returns the evaluated type for this expression. // Returns the "anyarray" type if the type combination is invalid. func (array *Array) getTargetType(children ...sql.Expression) (pgtypes.DoltgresType, error) { - var childrenTypes []pgtypes.DoltgresType + var childrenTypeOids []uint32 for _, child := range children { if child != nil { childType, ok := child.Type().(pgtypes.DoltgresType) @@ -171,12 +171,13 @@ func (array *Array) getTargetType(children ...sql.Expression) (pgtypes.DoltgresT // We use "anyarray" as the indeterminate/invalid type return pgtypes.AnyArray, nil } - childrenTypes = append(childrenTypes, childType) + childrenTypeOids = append(childrenTypeOids, childType.OID) } } - targetType, err := framework.FindCommonType(childrenTypes) + targetTypeOid, err := framework.FindCommonType(childrenTypeOids) if err != nil { return pgtypes.DoltgresType{}, fmt.Errorf("ARRAY %s", err.Error()) } + targetType := pgtypes.OidToBuiltInDoltgresType[targetTypeOid] return targetType.ToArrayType(), nil } diff --git a/server/expression/assignment_cast.go b/server/expression/assignment_cast.go index 1f3f22a49..320e664bb 100644 --- a/server/expression/assignment_cast.go +++ b/server/expression/assignment_cast.go @@ -55,7 +55,7 @@ func (ac *AssignmentCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil || val == nil { return val, err } - castFunc := framework.GetAssignmentCast(ac.fromType, ac.toType) + castFunc := framework.GetAssignmentCast(ac.fromType.OID, ac.toType.OID) if castFunc == nil { if ac.fromType.OID == uint32(oid.T_unknown) { castFunc = framework.UnknownLiteralCast diff --git a/server/expression/explicit_cast.go b/server/expression/explicit_cast.go index 909672772..4bc47f5b7 100644 --- a/server/expression/explicit_cast.go +++ b/server/expression/explicit_cast.go @@ -88,7 +88,7 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { return nil, nil } - castFunction := framework.GetExplicitCast(fromType, c.castToType) + castFunction := framework.GetExplicitCast(fromType.OID, c.castToType.OID) if castFunction == nil { if fromType.OID == uint32(oid.T_unknown) { castFunction = framework.UnknownLiteralCast diff --git a/server/expression/implicit_cast.go b/server/expression/implicit_cast.go index 73957ec75..3be1df097 100644 --- a/server/expression/implicit_cast.go +++ b/server/expression/implicit_cast.go @@ -54,7 +54,7 @@ func (ic *ImplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if err != nil || val == nil { return val, err } - castFunc := framework.GetImplicitCast(ic.fromType, ic.toType) + castFunc := framework.GetImplicitCast(ic.fromType.OID, ic.toType.OID) if castFunc == nil { return nil, fmt.Errorf("target is of type %s but expression is of type %s", ic.toType.String(), ic.fromType.String()) } diff --git a/server/functions/dolt_procedures.go b/server/functions/dolt_procedures.go index 71658a083..bd4a1e973 100755 --- a/server/functions/dolt_procedures.go +++ b/server/functions/dolt_procedures.go @@ -119,7 +119,7 @@ func drainRowIter(ctx *sql.Context, rowIter sql.RowIter) (any, error) { return nil, err } - castFn := framework.GetExplicitCast(fromType, pgtypes.Text) + castFn := framework.GetExplicitCast(fromType.OID, pgtypes.Text.OID) textVal, err := castFn(ctx, row[i], pgtypes.Text) if err != nil { return nil, err diff --git a/server/functions/framework/cast.go b/server/functions/framework/cast.go index 013e8d690..0e06274be 100644 --- a/server/functions/framework/cast.go +++ b/server/functions/framework/cast.go @@ -31,7 +31,7 @@ type TypeCastFunction func(ctx *sql.Context, val any, targetType pgtypes.Doltgre // getCastFunction is used to recursively call the cast function for when the inner logic sees that it has two array // types. This sidesteps providing -type getCastFunction func(fromType pgtypes.DoltgresType, toType pgtypes.DoltgresType) TypeCastFunction +type getCastFunction func(fromType uint32, toType uint32) TypeCastFunction // TypeCast is used to cast from one type to another. type TypeCast struct { @@ -109,9 +109,9 @@ func GetPotentialExplicitCasts(fromType uint32) []pgtypes.DoltgresType { } // GetPotentialAssignmentCasts returns all registered assignment and implicit type casts from the given type. -func GetPotentialAssignmentCasts(fromType uint32) []pgtypes.DoltgresType { - assignment := getPotentialCasts(assignmentTypeCastMutex, assignmentTypeCastsArray, fromType) - implicit := GetPotentialImplicitCasts(fromType) +func GetPotentialAssignmentCasts(fromTypeOid uint32) []pgtypes.DoltgresType { + assignment := getPotentialCasts(assignmentTypeCastMutex, assignmentTypeCastsArray, fromTypeOid) + implicit := GetPotentialImplicitCasts(fromTypeOid) both := make([]pgtypes.DoltgresType, len(assignment)+len(implicit)) copy(both, assignment) copy(both[len(assignment):], implicit) @@ -119,24 +119,28 @@ func GetPotentialAssignmentCasts(fromType uint32) []pgtypes.DoltgresType { } // GetPotentialImplicitCasts returns all registered implicit type casts from the given type. -func GetPotentialImplicitCasts(fromType uint32) []pgtypes.DoltgresType { - return getPotentialCasts(implicitTypeCastMutex, implicitTypeCastsArray, fromType) +func GetPotentialImplicitCasts(toTypeOid uint32) []pgtypes.DoltgresType { + return getPotentialCasts(implicitTypeCastMutex, implicitTypeCastsArray, toTypeOid) } // GetExplicitCast returns the explicit type cast function that will cast the "from" type to the "to" type. Returns nil // if such a cast is not valid. -func GetExplicitCast(fromType pgtypes.DoltgresType, toType pgtypes.DoltgresType) TypeCastFunction { - if tcf := getCast(explicitTypeCastMutex, explicitTypeCastsMap, fromType, toType, GetExplicitCast); tcf != nil { +func GetExplicitCast(fromTypeOid, toTypeOid uint32) TypeCastFunction { + if tcf := getCast(explicitTypeCastMutex, explicitTypeCastsMap, fromTypeOid, toTypeOid, GetExplicitCast); tcf != nil { return tcf - } else if tcf = getCast(assignmentTypeCastMutex, assignmentTypeCastsMap, fromType, toType, GetExplicitCast); tcf != nil { + } else if tcf = getCast(assignmentTypeCastMutex, assignmentTypeCastsMap, fromTypeOid, toTypeOid, GetExplicitCast); tcf != nil { return tcf - } else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetExplicitCast); tcf != nil { + } else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromTypeOid, toTypeOid, GetExplicitCast); tcf != nil { return tcf } + + fromType := pgtypes.OidToBuiltInDoltgresType[fromTypeOid] + toType := pgtypes.OidToBuiltInDoltgresType[toTypeOid] + // We check for the identity after checking the maps, as the identity may be overridden (such as for types that have // parameters). If one of the types are a string type, then we do not use the identity, and use the I/O conversions // below. - if fromType.OID == toType.OID && toType.TypCategory != pgtypes.TypeCategory_StringTypes && fromType.TypCategory != pgtypes.TypeCategory_StringTypes { + if fromTypeOid == toTypeOid && toType.TypCategory != pgtypes.TypeCategory_StringTypes && fromType.TypCategory != pgtypes.TypeCategory_StringTypes { return identityCast } // All types have a built-in explicit cast from string types: https://www.postgresql.org/docs/15/sql-createcast.html @@ -169,12 +173,16 @@ func GetExplicitCast(fromType pgtypes.DoltgresType, toType pgtypes.DoltgresType) // GetAssignmentCast returns the assignment type cast function that will cast the "from" type to the "to" type. Returns // nil if such a cast is not valid. -func GetAssignmentCast(fromType pgtypes.DoltgresType, toType pgtypes.DoltgresType) TypeCastFunction { - if tcf := getCast(assignmentTypeCastMutex, assignmentTypeCastsMap, fromType, toType, GetAssignmentCast); tcf != nil { +func GetAssignmentCast(fromTypeOid, toTypeOid uint32) TypeCastFunction { + if tcf := getCast(assignmentTypeCastMutex, assignmentTypeCastsMap, fromTypeOid, toTypeOid, GetAssignmentCast); tcf != nil { return tcf - } else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetAssignmentCast); tcf != nil { + } else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromTypeOid, toTypeOid, GetAssignmentCast); tcf != nil { return tcf } + + fromType := pgtypes.OidToBuiltInDoltgresType[fromTypeOid] + toType := pgtypes.OidToBuiltInDoltgresType[toTypeOid] + // We check for the identity after checking the maps, as the identity may be overridden (such as for types that have // parameters). If the "to" type is a string type, then we do not use the identity, and use the I/O conversion below. if fromType.OID == toType.OID && fromType.TypCategory != pgtypes.TypeCategory_StringTypes { @@ -198,13 +206,13 @@ func GetAssignmentCast(fromType pgtypes.DoltgresType, toType pgtypes.DoltgresTyp // GetImplicitCast returns the implicit type cast function that will cast the "from" type to the "to" type. Returns nil // if such a cast is not valid. -func GetImplicitCast(fromType pgtypes.DoltgresType, toType pgtypes.DoltgresType) TypeCastFunction { - if tcf := getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetImplicitCast); tcf != nil { +func GetImplicitCast(fromTypeOid, toTypeOid uint32) TypeCastFunction { + if tcf := getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromTypeOid, toTypeOid, GetImplicitCast); tcf != nil { return tcf } // We check for the identity after checking the maps, as the identity may be overridden (such as for types that have // parameters). - if fromType.OID == toType.OID { + if fromTypeOid == toTypeOid { return identityCast } return nil @@ -244,21 +252,25 @@ func getPotentialCasts(mutex *sync.RWMutex, castArray map[uint32][]pgtypes.Doltg // not valid. func getCast(mutex *sync.RWMutex, castMap map[uint32]map[uint32]TypeCastFunction, - fromType pgtypes.DoltgresType, toType pgtypes.DoltgresType, outerFunc getCastFunction) TypeCastFunction { + fromTypeOid, toTypeOid uint32, outerFunc getCastFunction) TypeCastFunction { mutex.RLock() defer mutex.RUnlock() - if toMap, ok := castMap[fromType.OID]; ok { - if f, ok := toMap[toType.OID]; ok { + if toMap, ok := castMap[fromTypeOid]; ok { + if f, ok := toMap[toTypeOid]; ok { return f } } + + fromType := pgtypes.OidToBuiltInDoltgresType[fromTypeOid] + toType := pgtypes.OidToBuiltInDoltgresType[toTypeOid] + // If there isn't a direct mapping, then we need to check if the types are array variants. // As long as the base types are convertable, the array variants are also convertable. if fromType.IsArrayType() && toType.IsArrayType() { fromBaseType := fromType.ArrayBaseType() toBaseType := toType.ArrayBaseType() - if baseCast := outerFunc(fromBaseType, toBaseType); baseCast != nil { + if baseCast := outerFunc(fromBaseType.OID, toBaseType.OID); baseCast != nil { // We use a closure that can unwrap the slice, since conversion functions expect a singular non-nil value return func(ctx *sql.Context, vals any, targetType pgtypes.DoltgresType) (any, error) { var err error diff --git a/server/functions/framework/common_type.go b/server/functions/framework/common_type.go index d93d50618..f3dce15c7 100644 --- a/server/functions/framework/common_type.go +++ b/server/functions/framework/common_type.go @@ -24,50 +24,55 @@ import ( // FindCommonType returns the common type that given types can convert to. // https://www.postgresql.org/docs/15/typeconv-union-case.html -func FindCommonType(types []pgtypes.DoltgresType) (pgtypes.DoltgresType, error) { - var candidateType = pgtypes.Unknown +func FindCommonType(typOids []uint32) (uint32, error) { + var candidateTypeOid = pgtypes.Unknown.OID var fail = false - for _, typ := range types { - if typ.OID == candidateType.OID { + for _, typOid := range typOids { + if typOid == candidateTypeOid { continue - } else if candidateType.OID == uint32(oid.T_unknown) { - candidateType = typ + } else if candidateTypeOid == uint32(oid.T_unknown) { + candidateTypeOid = typOid } else { - candidateType = pgtypes.Unknown + candidateTypeOid = pgtypes.Unknown.OID fail = true } } if !fail { - if candidateType.OID == uint32(oid.T_unknown) { - return pgtypes.Text, nil + if candidateTypeOid == uint32(oid.T_unknown) { + return pgtypes.Text.OID, nil } - return candidateType, nil + return candidateTypeOid, nil } - for _, typ := range types { - if candidateType.OID == uint32(oid.T_unknown) { - candidateType = typ + for _, typOid := range typOids { + if candidateTypeOid == uint32(oid.T_unknown) { + candidateTypeOid = typOid } - if typ.OID != uint32(oid.T_unknown) && candidateType.TypCategory != typ.TypCategory { - return pgtypes.DoltgresType{}, fmt.Errorf("types %s and %s cannot be matched", candidateType.String(), typ.String()) + candidateType := pgtypes.OidToBuiltInDoltgresType[candidateTypeOid] + typ := pgtypes.OidToBuiltInDoltgresType[typOid] + if typOid != uint32(oid.T_unknown) && candidateType.TypCategory != typ.TypCategory { + return 0, fmt.Errorf("types %s and %s cannot be matched", candidateType.String(), typ.String()) } } var preferredTypeFound = false - for _, typ := range types { - if typ.OID == uint32(oid.T_unknown) { + for _, typOid := range typOids { + if typOid == uint32(oid.T_unknown) { continue - } else if GetImplicitCast(typ, candidateType) != nil { + } else if GetImplicitCast(typOid, candidateTypeOid) != nil { continue - } else if GetImplicitCast(candidateType, typ) == nil { - return pgtypes.DoltgresType{}, fmt.Errorf("cannot find implicit cast function from %s to %s", candidateType.String(), typ.String()) + } else if GetImplicitCast(candidateTypeOid, typOid) == nil { + candidateType := pgtypes.OidToBuiltInDoltgresType[candidateTypeOid] + typ := pgtypes.OidToBuiltInDoltgresType[typOid] + return 0, fmt.Errorf("cannot find implicit cast function from %s to %s", candidateType.String(), typ.String()) } else if !preferredTypeFound { + candidateType := pgtypes.OidToBuiltInDoltgresType[candidateTypeOid] if candidateType.IsPreferred { - candidateType = typ + candidateTypeOid = typOid preferredTypeFound = true } } else { - return pgtypes.DoltgresType{}, fmt.Errorf("found another preferred candidate type") + return 0, fmt.Errorf("found another preferred candidate type") } } - return candidateType, nil + return candidateTypeOid, nil } diff --git a/server/functions/framework/compiled_function.go b/server/functions/framework/compiled_function.go index 68a5e30cf..33e3bfade 100644 --- a/server/functions/framework/compiled_function.go +++ b/server/functions/framework/compiled_function.go @@ -80,7 +80,7 @@ func newCompiledFunctionInternal( return c } // Next we'll resolve the overload based on the parameters given. - overload, err := c.resolve(overloads, fnOverloads, originalTypes) + overload, err := c.resolve(overloads, fnOverloads, overloads.oidsForTypes(originalTypes)) if err != nil { c.stashedErr = err return c @@ -356,12 +356,7 @@ func (c *CompiledFunction) GetQuickFunction() QuickFunction { // resolve returns an overloadMatch that either matches the given parameters exactly, or is a viable match after casting. // Returns an invalid overloadMatch if a viable match is not found. -func (c *CompiledFunction) resolve( - overloads *Overloads, - fnOverloads []Overload, - argTypes []pgtypes.DoltgresType, -) (overloadMatch, error) { - +func (c *CompiledFunction) resolve(overloads *Overloads, fnOverloads []Overload, argTypes []uint32) (overloadMatch, error) { // First check for an exact match exactMatch, found := overloads.ExactMatchForTypes(argTypes...) if found { @@ -385,28 +380,28 @@ func (c *CompiledFunction) resolve( // resolveOperator resolves an operator according to the rules defined by Postgres. // https://www.postgresql.org/docs/15/typeconv-oper.html -func (c *CompiledFunction) resolveOperator(argTypes []pgtypes.DoltgresType, overloads *Overloads, fnOverloads []Overload) (overloadMatch, error) { +func (c *CompiledFunction) resolveOperator(argTypeOids []uint32, overloads *Overloads, fnOverloads []Overload) (overloadMatch, error) { // Binary operators treat unknown literals as the other type, so we'll account for that here to see if we can find // an "exact" match. - if len(argTypes) == 2 { - leftUnknownType := argTypes[0].OID == uint32(oid.T_unknown) - rightUnknownType := argTypes[1].OID == uint32(oid.T_unknown) + if len(argTypeOids) == 2 { + leftUnknownType := argTypeOids[0] == uint32(oid.T_unknown) + rightUnknownType := argTypeOids[1] == uint32(oid.T_unknown) if (leftUnknownType && !rightUnknownType) || (!leftUnknownType && rightUnknownType) { - var typ pgtypes.DoltgresType + var typ uint32 casts := []TypeCastFunction{identityCast, identityCast} if leftUnknownType { casts[0] = UnknownLiteralCast - typ = argTypes[1] + typ = argTypeOids[1] } else { casts[1] = UnknownLiteralCast - typ = argTypes[0] + typ = argTypeOids[0] } if exactMatch, ok := overloads.ExactMatchForTypes(typ, typ); ok { return overloadMatch{ params: Overload{ function: exactMatch, - paramTypes: []pgtypes.DoltgresType{typ, typ}, - argTypes: []pgtypes.DoltgresType{typ, typ}, + paramTypes: []uint32{typ, typ}, + argTypes: []uint32{typ, typ}, variadic: -1, }, casts: casts, @@ -415,12 +410,12 @@ func (c *CompiledFunction) resolveOperator(argTypes []pgtypes.DoltgresType, over } } // From this point, the steps appear to be the same for functions and operators - return c.resolveFunction(argTypes, fnOverloads) + return c.resolveFunction(argTypeOids, fnOverloads) } // resolveFunction resolves a function according to the rules defined by Postgres. // https://www.postgresql.org/docs/15/typeconv-func.html -func (c *CompiledFunction) resolveFunction(argTypes []pgtypes.DoltgresType, overloads []Overload) (overloadMatch, error) { +func (c *CompiledFunction) resolveFunction(argTypes []uint32, overloads []Overload) (overloadMatch, error) { // First we'll discard all overloads that do not have implicitly-convertible param types compatibleOverloads := c.typeCompatibleOverloads(overloads, argTypes) @@ -470,23 +465,25 @@ func (c *CompiledFunction) resolveFunction(argTypes []pgtypes.DoltgresType, over // typeCompatibleOverloads returns all overloads that have a matching number of params whose types can be // implicitly converted to the ones provided. This is the set of all possible overloads that could be used with the // param types provided. -func (c *CompiledFunction) typeCompatibleOverloads(fnOverloads []Overload, argTypes []pgtypes.DoltgresType) []overloadMatch { +func (c *CompiledFunction) typeCompatibleOverloads(fnOverloads []Overload, argTypeOids []uint32) []overloadMatch { var compatible []overloadMatch for _, overload := range fnOverloads { isConvertible := true - overloadCasts := make([]TypeCastFunction, len(argTypes)) + overloadCasts := make([]TypeCastFunction, len(argTypeOids)) // Polymorphic parameters must be gathered so that we can later verify that they all have matching base types - var polymorphicParameters []pgtypes.DoltgresType - var polymorphicTargets []pgtypes.DoltgresType - for i := range argTypes { - paramType := overload.argTypes[i] - if paramType.IsValidForPolymorphicType(argTypes[i]) { + var polymorphicParameters []uint32 + var polymorphicTargets []uint32 + for i := range argTypeOids { + paramTypeOid := overload.argTypes[i] + paramType := pgtypes.OidToBuiltInDoltgresType[paramTypeOid] + argType := pgtypes.OidToBuiltInDoltgresType[argTypeOids[i]] + if paramType.IsValidForPolymorphicType(argType) { overloadCasts[i] = identityCast - polymorphicParameters = append(polymorphicParameters, paramType) - polymorphicTargets = append(polymorphicTargets, argTypes[i]) + polymorphicParameters = append(polymorphicParameters, paramTypeOid) + polymorphicTargets = append(polymorphicTargets, argTypeOids[i]) } else { - if overloadCasts[i] = GetImplicitCast(argTypes[i], paramType); overloadCasts[i] == nil { - if argTypes[i].OID == uint32(oid.T_unknown) { + if overloadCasts[i] = GetImplicitCast(argTypeOids[i], paramTypeOid); overloadCasts[i] == nil { + if argTypeOids[i] == uint32(oid.T_unknown) { overloadCasts[i] = UnknownLiteralCast } else { isConvertible = false @@ -505,14 +502,14 @@ func (c *CompiledFunction) typeCompatibleOverloads(fnOverloads []Overload, argTy // closestTypeMatches returns the set of overload candidates that have the most exact type matches for the arg types // provided. -func (*CompiledFunction) closestTypeMatches(argTypes []pgtypes.DoltgresType, candidates []overloadMatch) []overloadMatch { +func (*CompiledFunction) closestTypeMatches(argTypes []uint32, candidates []overloadMatch) []overloadMatch { matchCount := 0 var matches []overloadMatch for _, cand := range candidates { currentMatchCount := 0 for argIdx := range argTypes { argType := cand.params.argTypes[argIdx] - if argTypes[argIdx].OID == argType.OID || argTypes[argIdx].OID == uint32(oid.T_unknown) { + if argTypes[argIdx] == argType || argTypes[argIdx] == uint32(oid.T_unknown) { currentMatchCount++ } } @@ -527,14 +524,14 @@ func (*CompiledFunction) closestTypeMatches(argTypes []pgtypes.DoltgresType, can } // preferredTypeMatches returns the overload candidates that have the most preferred types for args that require casts. -func (*CompiledFunction) preferredTypeMatches(argTypes []pgtypes.DoltgresType, candidates []overloadMatch) []overloadMatch { +func (*CompiledFunction) preferredTypeMatches(argTypeOids []uint32, candidates []overloadMatch) []overloadMatch { preferredCount := 0 var preferredOverloads []overloadMatch for _, cand := range candidates { currentPreferredCount := 0 - for argIdx := range argTypes { - argType := cand.params.argTypes[argIdx] - if argTypes[argIdx].OID != argType.OID && argType.IsPreferred { + for argIdx := range argTypeOids { + paramTypeOid := cand.params.argTypes[argIdx] + if argType := pgtypes.OidToBuiltInDoltgresType[paramTypeOid]; argTypeOids[argIdx] != paramTypeOid && argType.IsPreferred { currentPreferredCount++ } } @@ -551,18 +548,18 @@ func (*CompiledFunction) preferredTypeMatches(argTypes []pgtypes.DoltgresType, c // unknownTypeCategoryMatches checks the type categories of `unknown` types. These types have an inherent bias toward // the string category since an `unknown` literal resembles a string. Returns false if the resolution should fail. -func (c *CompiledFunction) unknownTypeCategoryMatches(argTypes []pgtypes.DoltgresType, candidates []overloadMatch) ([]overloadMatch, bool) { +func (c *CompiledFunction) unknownTypeCategoryMatches(argTypes []uint32, candidates []overloadMatch) ([]overloadMatch, bool) { matches := make([]overloadMatch, len(candidates)) copy(matches, candidates) // For our first loop, we'll filter matches based on whether they accept the string category for argIdx := range argTypes { // We're only concerned with `unknown` types - if argTypes[argIdx].OID != uint32(oid.T_unknown) { + if argTypes[argIdx] != uint32(oid.T_unknown) { continue } var newMatches []overloadMatch for _, match := range matches { - if match.params.argTypes[argIdx].TypCategory == pgtypes.TypeCategory_StringTypes { + if m := pgtypes.OidToBuiltInDoltgresType[match.params.argTypes[argIdx]]; m.TypCategory == pgtypes.TypeCategory_StringTypes { newMatches = append(newMatches, match) } } @@ -587,50 +584,50 @@ func (c *CompiledFunction) unknownTypeCategoryMatches(argTypes []pgtypes.Doltgre } // polymorphicTypesCompatible returns whether any polymorphic types given are compatible with the expression types given -func (*CompiledFunction) polymorphicTypesCompatible(paramTypes []pgtypes.DoltgresType, exprTypes []pgtypes.DoltgresType) bool { - if len(paramTypes) != len(exprTypes) { +func (*CompiledFunction) polymorphicTypesCompatible(paramTypeOids []uint32, exprTypeOids []uint32) bool { + if len(paramTypeOids) != len(exprTypeOids) { return false } // If there are less than two parameters then we don't even need to check - if len(paramTypes) < 2 { + if len(paramTypeOids) < 2 { return true } // If one of the types is anyarray, then anyelement behaves as anynonarray, so we can convert them to anynonarray - for _, paramType := range paramTypes { - if paramType.OID == uint32(oid.T_anyarray) { + for _, paramTypeOid := range paramTypeOids { + if paramTypeOid == uint32(oid.T_anyarray) { // At least one parameter is anyarray, so copy all parameters to a new slice and replace anyelement with anynonarray - newParamTypes := make([]pgtypes.DoltgresType, len(paramTypes)) - copy(newParamTypes, paramTypes) - for i := range newParamTypes { - if paramTypes[i].OID == uint32(oid.T_anyelement) { - newParamTypes[i] = pgtypes.AnyNonArray + newParamTypeOids := make([]uint32, len(paramTypeOids)) + copy(newParamTypeOids, paramTypeOids) + for i := range newParamTypeOids { + if paramTypeOids[i] == uint32(oid.T_anyelement) { + newParamTypeOids[i] = pgtypes.AnyNonArray.OID } } - paramTypes = newParamTypes + paramTypeOids = newParamTypeOids break } } // The base type is the type that must match between all polymorphic types. - var baseType pgtypes.DoltgresType - for i, paramType := range paramTypes { - if paramType.IsPolymorphicType() && exprTypes[i].OID != uint32(oid.T_unknown) { + var baseTypeOid uint32 + for i, paramTypeOid := range paramTypeOids { + if paramType := pgtypes.OidToBuiltInDoltgresType[paramTypeOid]; paramType.IsPolymorphicType() && exprTypeOids[i] != uint32(oid.T_unknown) { // Although we do this check before we ever reach this function, we do it again as we may convert anyelement // to anynonarray, which changes type validity - if !paramType.IsValidForPolymorphicType(exprTypes[i]) { + if exprType := pgtypes.OidToBuiltInDoltgresType[exprTypeOids[i]]; !paramType.IsValidForPolymorphicType(exprType) { return false } // Get the base expression type that we'll compare against - baseExprType := exprTypes[i] - if baseExprType.IsArrayType() { - baseExprType = baseExprType.ArrayBaseType() + baseExprTypeOid := exprTypeOids[i] + if baseExprType := pgtypes.OidToBuiltInDoltgresType[baseExprTypeOid]; baseExprType.IsArrayType() { + baseExprTypeOid = baseExprType.ArrayBaseType().OID } // TODO: handle range types // Check that the base expression type matches the previously-found base type - if baseType.IsEmptyType() { - baseType = baseExprType - } else if baseType.OID != baseExprType.OID { + if baseTypeOid == 0 { + baseTypeOid = baseExprTypeOid + } else if baseTypeOid != baseExprTypeOid { return false } } diff --git a/server/functions/framework/overloads.go b/server/functions/framework/overloads.go index 51b23d3f1..dce6fa306 100644 --- a/server/functions/framework/overloads.go +++ b/server/functions/framework/overloads.go @@ -40,7 +40,7 @@ func NewOverloads() *Overloads { // Add adds the given function to the overload collection. Returns an error if the there's a problem with the // function's declaration. func (o *Overloads) Add(function FunctionInterface) error { - key := keyForParamTypes(function.GetParameters()) + key := keyForParamTypes(o.oidsForTypes(function.GetParameters())) if _, ok := o.ByParamType[key]; ok { return fmt.Errorf("duplicate function overload for `%s`", function.GetName()) } @@ -58,28 +58,38 @@ func (o *Overloads) Add(function FunctionInterface) error { } // keyForParamTypes returns a string key to match an overload with the given parameter types. -func keyForParamTypes(types []pgtypes.DoltgresType) string { +func keyForParamTypes(types []uint32) string { sb := strings.Builder{} for i, typ := range types { if i > 0 { sb.WriteByte(',') } - sb.WriteString(typ.String()) + t := pgtypes.OidToBuiltInDoltgresType[typ] + sb.WriteString(t.String()) } return sb.String() } +// baseIdsForTypes returns the base IDs of the given types. +func (o *Overloads) oidsForTypes(types []pgtypes.DoltgresType) []uint32 { + baseIds := make([]uint32, len(types)) + for i, t := range types { + baseIds[i] = t.OID + } + return baseIds +} + // overloadsForParams returns all overloads matching the number of params given, without regard for types. func (o *Overloads) overloadsForParams(numParams int) []Overload { results := make([]Overload, 0, len(o.AllOverloads)) for _, overload := range o.AllOverloads { - params := overload.GetParameters() + params := o.oidsForTypes(overload.GetParameters()) variadicIndex := overload.VariadicIndex() if variadicIndex >= 0 && len(params) <= numParams { // Variadic functions may only match when the function is declared with parameters that are fewer or equal // to our target length. If our target length is less, then we cannot expand, so we do not treat it as // variadic. - extendedParams := make([]pgtypes.DoltgresType, numParams) + extendedParams := make([]uint32, numParams) copy(extendedParams, params[:variadicIndex]) // This is copying the parameters after the variadic index, so we need to add 1. We subtract the declared // parameter count from the target parameter count to obtain the additional parameter count. @@ -87,7 +97,7 @@ func (o *Overloads) overloadsForParams(numParams int) []Overload { copy(extendedParams[firstValueAfterVariadic:], params[variadicIndex+1:]) // ToArrayType immediately followed by BaseType is a way to get the base type without having to cast. // For array types, ToArrayType causes them to return themselves. - variadicBaseType := overload.GetParameters()[variadicIndex].ToArrayType().ArrayBaseType() + variadicBaseType := overload.GetParameters()[variadicIndex].ToArrayType().ArrayBaseType().OID for variadicParamIdx := 0; variadicParamIdx < 1+(numParams-len(params)); variadicParamIdx++ { extendedParams[variadicParamIdx+variadicIndex] = variadicBaseType } @@ -111,7 +121,7 @@ func (o *Overloads) overloadsForParams(numParams int) []Overload { // ExactMatchForTypes returns the function that exactly matches the given parameter types, or nil if no overload with // those types exists. -func (o *Overloads) ExactMatchForTypes(types ...pgtypes.DoltgresType) (FunctionInterface, bool) { +func (o *Overloads) ExactMatchForTypes(types ...uint32) (FunctionInterface, bool) { key := keyForParamTypes(types) fn, ok := o.ByParamType[key] return fn, ok @@ -123,10 +133,10 @@ type Overload struct { // function is the actual function to call to invoke this overload function FunctionInterface // paramTypes is the base IDs of the parameters that the function expects - paramTypes []pgtypes.DoltgresType + paramTypes []uint32 // argTypes is the base IDs of the parameters that the function expects, extended to match the number of args // provided in the case of a variadic function. - argTypes []pgtypes.DoltgresType + argTypes []uint32 // variadic is the index of the variadic parameter, or -1 if the function is not variadic variadic int }