Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use OID instead of DoltgresType struct for functions and casts #1025

Open
wants to merge 1 commit into
base: jennifer/type-changes
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions server/expression/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (array *Array) Eval(ctx *sql.Context, row sql.Row) (any, error) {
}

// We always cast the element, as there may be parameter restrictions in place
castFunc := framework.GetImplicitCast(doltgresType, resultTyp)
castFunc := framework.GetImplicitCast(doltgresType.OID, resultTyp.OID)
if castFunc == nil {
if doltgresType.OID == uint32(oid.T_unknown) {
castFunc = framework.UnknownLiteralCast
Expand Down Expand Up @@ -163,20 +163,21 @@ func (array *Array) WithResolvedChildren(children []any) (any, error) {
// getTargetType returns the evaluated type for this expression.
// Returns the "anyarray" type if the type combination is invalid.
func (array *Array) getTargetType(children ...sql.Expression) (pgtypes.DoltgresType, error) {
var childrenTypes []pgtypes.DoltgresType
var childrenTypeOids []uint32
for _, child := range children {
if child != nil {
childType, ok := child.Type().(pgtypes.DoltgresType)
if !ok {
// We use "anyarray" as the indeterminate/invalid type
return pgtypes.AnyArray, nil
}
childrenTypes = append(childrenTypes, childType)
childrenTypeOids = append(childrenTypeOids, childType.OID)
}
}
targetType, err := framework.FindCommonType(childrenTypes)
targetTypeOid, err := framework.FindCommonType(childrenTypeOids)
if err != nil {
return pgtypes.DoltgresType{}, fmt.Errorf("ARRAY %s", err.Error())
}
targetType := pgtypes.OidToBuiltInDoltgresType[targetTypeOid]
return targetType.ToArrayType(), nil
}
2 changes: 1 addition & 1 deletion server/expression/assignment_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (ac *AssignmentCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
if err != nil || val == nil {
return val, err
}
castFunc := framework.GetAssignmentCast(ac.fromType, ac.toType)
castFunc := framework.GetAssignmentCast(ac.fromType.OID, ac.toType.OID)
if castFunc == nil {
if ac.fromType.OID == uint32(oid.T_unknown) {
castFunc = framework.UnknownLiteralCast
Expand Down
2 changes: 1 addition & 1 deletion server/expression/explicit_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
return nil, nil
}

castFunction := framework.GetExplicitCast(fromType, c.castToType)
castFunction := framework.GetExplicitCast(fromType.OID, c.castToType.OID)
if castFunction == nil {
if fromType.OID == uint32(oid.T_unknown) {
castFunction = framework.UnknownLiteralCast
Expand Down
2 changes: 1 addition & 1 deletion server/expression/implicit_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (ic *ImplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) {
if err != nil || val == nil {
return val, err
}
castFunc := framework.GetImplicitCast(ic.fromType, ic.toType)
castFunc := framework.GetImplicitCast(ic.fromType.OID, ic.toType.OID)
if castFunc == nil {
return nil, fmt.Errorf("target is of type %s but expression is of type %s", ic.toType.String(), ic.fromType.String())
}
Expand Down
2 changes: 1 addition & 1 deletion server/functions/dolt_procedures.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func drainRowIter(ctx *sql.Context, rowIter sql.RowIter) (any, error) {
return nil, err
}

castFn := framework.GetExplicitCast(fromType, pgtypes.Text)
castFn := framework.GetExplicitCast(fromType.OID, pgtypes.Text.OID)
textVal, err := castFn(ctx, row[i], pgtypes.Text)
if err != nil {
return nil, err
Expand Down
54 changes: 33 additions & 21 deletions server/functions/framework/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type TypeCastFunction func(ctx *sql.Context, val any, targetType pgtypes.Doltgre

// getCastFunction is used to recursively call the cast function for when the inner logic sees that it has two array
// types. This sidesteps providing
type getCastFunction func(fromType pgtypes.DoltgresType, toType pgtypes.DoltgresType) TypeCastFunction
type getCastFunction func(fromType uint32, toType uint32) TypeCastFunction

// TypeCast is used to cast from one type to another.
type TypeCast struct {
Expand Down Expand Up @@ -109,34 +109,38 @@ func GetPotentialExplicitCasts(fromType uint32) []pgtypes.DoltgresType {
}

// GetPotentialAssignmentCasts returns all registered assignment and implicit type casts from the given type.
func GetPotentialAssignmentCasts(fromType uint32) []pgtypes.DoltgresType {
assignment := getPotentialCasts(assignmentTypeCastMutex, assignmentTypeCastsArray, fromType)
implicit := GetPotentialImplicitCasts(fromType)
func GetPotentialAssignmentCasts(fromTypeOid uint32) []pgtypes.DoltgresType {
assignment := getPotentialCasts(assignmentTypeCastMutex, assignmentTypeCastsArray, fromTypeOid)
implicit := GetPotentialImplicitCasts(fromTypeOid)
both := make([]pgtypes.DoltgresType, len(assignment)+len(implicit))
copy(both, assignment)
copy(both[len(assignment):], implicit)
return both
}

// GetPotentialImplicitCasts returns all registered implicit type casts from the given type.
func GetPotentialImplicitCasts(fromType uint32) []pgtypes.DoltgresType {
return getPotentialCasts(implicitTypeCastMutex, implicitTypeCastsArray, fromType)
func GetPotentialImplicitCasts(toTypeOid uint32) []pgtypes.DoltgresType {
return getPotentialCasts(implicitTypeCastMutex, implicitTypeCastsArray, toTypeOid)
}

// GetExplicitCast returns the explicit type cast function that will cast the "from" type to the "to" type. Returns nil
// if such a cast is not valid.
func GetExplicitCast(fromType pgtypes.DoltgresType, toType pgtypes.DoltgresType) TypeCastFunction {
if tcf := getCast(explicitTypeCastMutex, explicitTypeCastsMap, fromType, toType, GetExplicitCast); tcf != nil {
func GetExplicitCast(fromTypeOid, toTypeOid uint32) TypeCastFunction {
if tcf := getCast(explicitTypeCastMutex, explicitTypeCastsMap, fromTypeOid, toTypeOid, GetExplicitCast); tcf != nil {
return tcf
} else if tcf = getCast(assignmentTypeCastMutex, assignmentTypeCastsMap, fromType, toType, GetExplicitCast); tcf != nil {
} else if tcf = getCast(assignmentTypeCastMutex, assignmentTypeCastsMap, fromTypeOid, toTypeOid, GetExplicitCast); tcf != nil {
return tcf
} else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetExplicitCast); tcf != nil {
} else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromTypeOid, toTypeOid, GetExplicitCast); tcf != nil {
return tcf
}

fromType := pgtypes.OidToBuiltInDoltgresType[fromTypeOid]
toType := pgtypes.OidToBuiltInDoltgresType[toTypeOid]

// We check for the identity after checking the maps, as the identity may be overridden (such as for types that have
// parameters). If one of the types are a string type, then we do not use the identity, and use the I/O conversions
// below.
if fromType.OID == toType.OID && toType.TypCategory != pgtypes.TypeCategory_StringTypes && fromType.TypCategory != pgtypes.TypeCategory_StringTypes {
if fromTypeOid == toTypeOid && toType.TypCategory != pgtypes.TypeCategory_StringTypes && fromType.TypCategory != pgtypes.TypeCategory_StringTypes {
return identityCast
}
// All types have a built-in explicit cast from string types: https://www.postgresql.org/docs/15/sql-createcast.html
Expand Down Expand Up @@ -169,12 +173,16 @@ func GetExplicitCast(fromType pgtypes.DoltgresType, toType pgtypes.DoltgresType)

// GetAssignmentCast returns the assignment type cast function that will cast the "from" type to the "to" type. Returns
// nil if such a cast is not valid.
func GetAssignmentCast(fromType pgtypes.DoltgresType, toType pgtypes.DoltgresType) TypeCastFunction {
if tcf := getCast(assignmentTypeCastMutex, assignmentTypeCastsMap, fromType, toType, GetAssignmentCast); tcf != nil {
func GetAssignmentCast(fromTypeOid, toTypeOid uint32) TypeCastFunction {
if tcf := getCast(assignmentTypeCastMutex, assignmentTypeCastsMap, fromTypeOid, toTypeOid, GetAssignmentCast); tcf != nil {
return tcf
} else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetAssignmentCast); tcf != nil {
} else if tcf = getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromTypeOid, toTypeOid, GetAssignmentCast); tcf != nil {
return tcf
}

fromType := pgtypes.OidToBuiltInDoltgresType[fromTypeOid]
toType := pgtypes.OidToBuiltInDoltgresType[toTypeOid]

// We check for the identity after checking the maps, as the identity may be overridden (such as for types that have
// parameters). If the "to" type is a string type, then we do not use the identity, and use the I/O conversion below.
if fromType.OID == toType.OID && fromType.TypCategory != pgtypes.TypeCategory_StringTypes {
Expand All @@ -198,13 +206,13 @@ func GetAssignmentCast(fromType pgtypes.DoltgresType, toType pgtypes.DoltgresTyp

// GetImplicitCast returns the implicit type cast function that will cast the "from" type to the "to" type. Returns nil
// if such a cast is not valid.
func GetImplicitCast(fromType pgtypes.DoltgresType, toType pgtypes.DoltgresType) TypeCastFunction {
if tcf := getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromType, toType, GetImplicitCast); tcf != nil {
func GetImplicitCast(fromTypeOid, toTypeOid uint32) TypeCastFunction {
if tcf := getCast(implicitTypeCastMutex, implicitTypeCastsMap, fromTypeOid, toTypeOid, GetImplicitCast); tcf != nil {
return tcf
}
// We check for the identity after checking the maps, as the identity may be overridden (such as for types that have
// parameters).
if fromType.OID == toType.OID {
if fromTypeOid == toTypeOid {
return identityCast
}
return nil
Expand Down Expand Up @@ -244,21 +252,25 @@ func getPotentialCasts(mutex *sync.RWMutex, castArray map[uint32][]pgtypes.Doltg
// not valid.
func getCast(mutex *sync.RWMutex,
castMap map[uint32]map[uint32]TypeCastFunction,
fromType pgtypes.DoltgresType, toType pgtypes.DoltgresType, outerFunc getCastFunction) TypeCastFunction {
fromTypeOid, toTypeOid uint32, outerFunc getCastFunction) TypeCastFunction {
mutex.RLock()
defer mutex.RUnlock()

if toMap, ok := castMap[fromType.OID]; ok {
if f, ok := toMap[toType.OID]; ok {
if toMap, ok := castMap[fromTypeOid]; ok {
if f, ok := toMap[toTypeOid]; ok {
return f
}
}

fromType := pgtypes.OidToBuiltInDoltgresType[fromTypeOid]
toType := pgtypes.OidToBuiltInDoltgresType[toTypeOid]

// If there isn't a direct mapping, then we need to check if the types are array variants.
// As long as the base types are convertable, the array variants are also convertable.
if fromType.IsArrayType() && toType.IsArrayType() {
fromBaseType := fromType.ArrayBaseType()
toBaseType := toType.ArrayBaseType()
if baseCast := outerFunc(fromBaseType, toBaseType); baseCast != nil {
if baseCast := outerFunc(fromBaseType.OID, toBaseType.OID); baseCast != nil {
// We use a closure that can unwrap the slice, since conversion functions expect a singular non-nil value
return func(ctx *sql.Context, vals any, targetType pgtypes.DoltgresType) (any, error) {
var err error
Expand Down
51 changes: 28 additions & 23 deletions server/functions/framework/common_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,50 +24,55 @@ import (

// FindCommonType returns the common type that given types can convert to.
// https://www.postgresql.org/docs/15/typeconv-union-case.html
func FindCommonType(types []pgtypes.DoltgresType) (pgtypes.DoltgresType, error) {
var candidateType = pgtypes.Unknown
func FindCommonType(typOids []uint32) (uint32, error) {
var candidateTypeOid = pgtypes.Unknown.OID
var fail = false
for _, typ := range types {
if typ.OID == candidateType.OID {
for _, typOid := range typOids {
if typOid == candidateTypeOid {
continue
} else if candidateType.OID == uint32(oid.T_unknown) {
candidateType = typ
} else if candidateTypeOid == uint32(oid.T_unknown) {
candidateTypeOid = typOid
} else {
candidateType = pgtypes.Unknown
candidateTypeOid = pgtypes.Unknown.OID
fail = true
}
}
if !fail {
if candidateType.OID == uint32(oid.T_unknown) {
return pgtypes.Text, nil
if candidateTypeOid == uint32(oid.T_unknown) {
return pgtypes.Text.OID, nil
}
return candidateType, nil
return candidateTypeOid, nil
}
for _, typ := range types {
if candidateType.OID == uint32(oid.T_unknown) {
candidateType = typ
for _, typOid := range typOids {
if candidateTypeOid == uint32(oid.T_unknown) {
candidateTypeOid = typOid
}
if typ.OID != uint32(oid.T_unknown) && candidateType.TypCategory != typ.TypCategory {
return pgtypes.DoltgresType{}, fmt.Errorf("types %s and %s cannot be matched", candidateType.String(), typ.String())
candidateType := pgtypes.OidToBuiltInDoltgresType[candidateTypeOid]
typ := pgtypes.OidToBuiltInDoltgresType[typOid]
if typOid != uint32(oid.T_unknown) && candidateType.TypCategory != typ.TypCategory {
return 0, fmt.Errorf("types %s and %s cannot be matched", candidateType.String(), typ.String())
}
}

var preferredTypeFound = false
for _, typ := range types {
if typ.OID == uint32(oid.T_unknown) {
for _, typOid := range typOids {
if typOid == uint32(oid.T_unknown) {
continue
} else if GetImplicitCast(typ, candidateType) != nil {
} else if GetImplicitCast(typOid, candidateTypeOid) != nil {
continue
} else if GetImplicitCast(candidateType, typ) == nil {
return pgtypes.DoltgresType{}, fmt.Errorf("cannot find implicit cast function from %s to %s", candidateType.String(), typ.String())
} else if GetImplicitCast(candidateTypeOid, typOid) == nil {
candidateType := pgtypes.OidToBuiltInDoltgresType[candidateTypeOid]
typ := pgtypes.OidToBuiltInDoltgresType[typOid]
return 0, fmt.Errorf("cannot find implicit cast function from %s to %s", candidateType.String(), typ.String())
} else if !preferredTypeFound {
candidateType := pgtypes.OidToBuiltInDoltgresType[candidateTypeOid]
if candidateType.IsPreferred {
candidateType = typ
candidateTypeOid = typOid
preferredTypeFound = true
}
} else {
return pgtypes.DoltgresType{}, fmt.Errorf("found another preferred candidate type")
return 0, fmt.Errorf("found another preferred candidate type")
}
}
return candidateType, nil
return candidateTypeOid, nil
}
Loading
Loading