From bd1e279069881a5f3e8527573ffec22b7562c551 Mon Sep 17 00:00:00 2001 From: Sergey Vilgelm Date: Tue, 17 Dec 2024 20:42:23 -0800 Subject: [PATCH] parse objects --- README.md | 2 + bool_or_schema.go | 2 + components.go | 84 +++++---- parser.go | 281 ++++++++++++++++++++++++++++++ parser_test.go | 433 ++++++++++++++++++++++++++++++++++++++++++++++ schema.go | 273 ++++++++++++++++++++++++++++- types.go | 8 +- validation.go | 21 ++- 8 files changed, 1063 insertions(+), 41 deletions(-) create mode 100644 parser.go create mode 100644 parser_test.go diff --git a/README.md b/README.md index 0af16d9..a21b309 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,8 @@ The implementation of OpenAPI v3.1 Specification for Go using generics. * `Validator.ValidateData()` method validates the data. * `Validator.ValidateDataAsJSON()` method validates the data by converting it into `map[string]any` type first using `json.Marshal` and `json.Unmarshal`. **WARNING**: the function is slow due to double conversion. + * Added `ParseObject` function to create `SchemaBuilder` by parsing an object. + The function supports `json`, `yaml` and `openapi` field tags for the structs. * Use OpenAPI `v3.1.1` by default. ## Features diff --git a/bool_or_schema.go b/bool_or_schema.go index b18e231..9866d7e 100644 --- a/bool_or_schema.go +++ b/bool_or_schema.go @@ -78,6 +78,8 @@ func NewBoolOrSchema(v any) *BoolOrSchema { return &BoolOrSchema{Allowed: v} case *RefOrSpec[Schema]: return &BoolOrSchema{Schema: v} + case *SchemaBulder: + return &BoolOrSchema{Schema: v.Build()} default: return nil } diff --git a/components.go b/components.go index 68b5085..3f157df 100644 --- a/components.go +++ b/components.go @@ -1,5 +1,9 @@ package openapi +import ( + "regexp" +) + // Components holds a set of reusable objects for different aspects of the OAS. // All objects defined within the components object will have no effect on the API unless they are explicitly referenced // from properties outside the components object. @@ -160,57 +164,77 @@ func (o *Components) Add(name string, v any) *Components { return o } +var namePattern = regexp.MustCompile(`^[a-zA-Z0-9\.\-_]+$`) + func (o *Components) validateSpec(location string, validator *Validator) []*validationError { var errs []*validationError - if o.Schemas != nil { - for k, v := range o.Schemas { - errs = append(errs, v.validateSpec(joinLoc(location, "schemas", k), validator)...) + for k, v := range o.Schemas { + if !namePattern.MatchString(k) { + errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String())) } + errs = append(errs, v.validateSpec(joinLoc(location, "schemas", k), validator)...) } - if o.Responses != nil { - for k, v := range o.Responses { - errs = append(errs, v.validateSpec(joinLoc(location, "responses", k), validator)...) + + for k, v := range o.Responses { + if !namePattern.MatchString(k) { + errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String())) } + errs = append(errs, v.validateSpec(joinLoc(location, "responses", k), validator)...) } - if o.Parameters != nil { - for k, v := range o.Parameters { - errs = append(errs, v.validateSpec(joinLoc(location, "parameters", k), validator)...) + for k, v := range o.Parameters { + if !namePattern.MatchString(k) { + errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String())) } + errs = append(errs, v.validateSpec(joinLoc(location, "parameters", k), validator)...) } - if o.Examples != nil { - for k, v := range o.Examples { - errs = append(errs, v.validateSpec(joinLoc(location, "examples", k), validator)...) + + for k, v := range o.Examples { + if !namePattern.MatchString(k) { + errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String())) } + errs = append(errs, v.validateSpec(joinLoc(location, "examples", k), validator)...) } - if o.RequestBodies != nil { - for k, v := range o.RequestBodies { - errs = append(errs, v.validateSpec(joinLoc(location, "requestBodies", k), validator)...) + + for k, v := range o.RequestBodies { + if !namePattern.MatchString(k) { + errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String())) } + errs = append(errs, v.validateSpec(joinLoc(location, "requestBodies", k), validator)...) } - if o.Headers != nil { - for k, v := range o.Headers { - errs = append(errs, v.validateSpec(joinLoc(location, "headers", k), validator)...) + + for k, v := range o.Headers { + if !namePattern.MatchString(k) { + errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String())) } + errs = append(errs, v.validateSpec(joinLoc(location, "headers", k), validator)...) } - if o.SecuritySchemes != nil { - for k, v := range o.SecuritySchemes { - errs = append(errs, v.validateSpec(joinLoc(location, "securitySchemes", k), validator)...) + + for k, v := range o.SecuritySchemes { + if !namePattern.MatchString(k) { + errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String())) } + errs = append(errs, v.validateSpec(joinLoc(location, "securitySchemes", k), validator)...) } - if o.Links != nil { - for k, v := range o.Links { - errs = append(errs, v.validateSpec(joinLoc(location, "links", k), validator)...) + + for k, v := range o.Links { + if !namePattern.MatchString(k) { + errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String())) } + errs = append(errs, v.validateSpec(joinLoc(location, "links", k), validator)...) } - if o.Callbacks != nil { - for k, v := range o.Callbacks { - errs = append(errs, v.validateSpec(joinLoc(location, "callbacks", k), validator)...) + + for k, v := range o.Callbacks { + if !namePattern.MatchString(k) { + errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String())) } + errs = append(errs, v.validateSpec(joinLoc(location, "callbacks", k), validator)...) } - if o.Paths != nil { - for k, v := range o.Paths { - errs = append(errs, v.validateSpec(joinLoc(location, "paths", k), validator)...) + + for k, v := range o.Paths { + if !namePattern.MatchString(k) { + errs = append(errs, newValidationError(joinLoc(location, "schemas", k), "invalid name %q, must match %q", k, namePattern.String())) } + errs = append(errs, v.validateSpec(joinLoc(location, "paths", k), validator)...) } return errs diff --git a/parser.go b/parser.go new file mode 100644 index 0000000..6971200 --- /dev/null +++ b/parser.go @@ -0,0 +1,281 @@ +package openapi + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" +) + +const is64Bit = uint64(^uintptr(0)) == ^uint64(0) + +// ParseObject parses the object and returns the schema or the reference to the schema. +// +// The object can be a struct, pointer to struct, map, slice, pointer to map or slice, or any other type. +// The object can contain fields with `json`, `yaml` or `openapi` tags. +// +// `opanapi:"[,ref: || any other tags]"` tag: +// - is the name of the field in the schema, can be "-" to skip the field or empty to use the name from json, yaml tags or original field name. +// json schema fields: +// - ref: is a reference to the schema, can not be used with jsonschema fields. +// - required, marks the field as required by adding it to the required list of the parent schema. +// - deprecated, marks the field as deprecated. +// - title:, sets the title of the field or summary for the fereference. +// - summary:<summary>, sets the summary of the reference. +// - description:<description>, sets the description of the field. +// - type:<type> (boolean, integer, number, string, array, object), may be used multiple times. +// The first usage overrides the default type, all other types are added. +// - addtype:<type>, adds additional type, may be used multiple times. +// - format:<format>, sets the format of the type. +// +// The `components` parameter is needed to store the schemas of the structs, and to avoid the circular references. +// In case of the given object is struct, the function will return a reference to the schema stored in the components +// Otherwise, the function will return the schema itself. +func ParseObject(obj any, components *Extendable[Components]) (*SchemaBulder, error) { + t := reflect.TypeOf(obj) + if t == nil { + return NewSchemaBuilder().Type(NullType).GoType("nil"), nil + } + value := reflect.ValueOf(obj) + return parseObject(joinLoc("", t.String()), value, components) +} + +func parseObject(location string, obj reflect.Value, components *Extendable[Components]) (*SchemaBulder, error) { + t := obj.Type() + if t == nil { + return NewSchemaBuilder().Type(NullType).GoType("nil"), nil + } + kind := t.Kind() + if kind == reflect.Ptr { + builder, err := parseObject(location, obj.Elem(), components) + if err != nil { + return nil, err + } + if builder.IsRef() { + builder = NewSchemaBuilder().OneOf( + builder.Build(), + NewSchemaBuilder().Type(NullType).Build(), + ) + } else { + builder.AddType(NullType) + } + return builder, nil + } + if kind == reflect.Interface { + return NewSchemaBuilder().GoType("any"), nil + } + builder := NewSchemaBuilder().GoType(fmt.Sprintf("%T", obj.Interface())) + switch obj.Interface().(type) { + case bool: + builder.Type(BooleanType) + case int, uint: + if is64Bit { + builder.Type(IntegerType).Format(Int64Format) + } else { + builder.Type(IntegerType).Format(Int32Format) + } + case int8, int16, int32, uint8, uint16, uint32: + builder.Type(IntegerType).Format(Int32Format) + case int64, uint64: + builder.Type(IntegerType).Format(Int64Format) + case float32: + builder.Type(NumberType).Format(FloatFormat) + case float64: + builder.Type(NumberType).Format(DoubleFormat) + case string: + builder.Type(StringType) + case []byte: + builder.Type(StringType).ContentEncoding(Base64Encoding).GoType("[]byte") // TODO: create an option for default ContentEncoding + case json.Number: + builder.Type(NumberType).GoPackage(t.PkgPath()) + case json.RawMessage: + builder.Type(StringType).ContentMediaType("application/json").GoPackage(t.PkgPath()) + default: + switch kind { + case reflect.Array, reflect.Slice: + var elemSchema any + if t.Elem().Kind() == reflect.Interface { + elemSchema = true + } else { + var ( + err error + newElem reflect.Value + ) + if t.Elem().Kind() == reflect.Ptr { + newElem = reflect.New(t.Elem()) + } else { + newElem = reflect.New(t.Elem()).Elem() + } + elemSchema, err = parseObject(location, newElem, components) + if err != nil { + return nil, err + } + } + builder.Type(ArrayType).Items(NewBoolOrSchema(elemSchema)).GoType("") + case reflect.Map: + if k := t.Key().Kind(); k != reflect.String { + return nil, fmt.Errorf("%s: unsupported map key type %s, expected string", location, k) + } + var elemSchema any + if t.Elem().Kind() == reflect.Interface { + elemSchema = true + } else { + var ( + err error + newElem reflect.Value + ) + if t.Elem().Kind() == reflect.Ptr { + newElem = reflect.New(t.Elem().Elem()) + } else { + newElem = reflect.New(t.Elem()).Elem() + } + elemSchema, err = parseObject(location, newElem, components) + if err != nil { + return nil, err + } + } + builder.Type(ObjectType).AdditionalProperties(NewBoolOrSchema(elemSchema)).GoType("") + case reflect.Struct: + objName := strings.ReplaceAll(t.PkgPath()+"."+t.Name(), "/", ".") + if components.Spec.Schemas[objName] != nil { + return NewSchemaBuilder().Ref("#/components/schemas/" + objName), nil + } + // add a temporary schema to avoid circular references + if components.Spec.Schemas == nil { + components.Spec.Schemas = make(map[string]*RefOrSpec[Schema], 1) + } + // reserve the name of the schema + components.Spec.Schemas[objName] = NewSchemaBuilder().Ref("to be deleted").Build() + var allOf []*RefOrSpec[Schema] + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + // skip unexported fields + if !field.IsExported() { + continue + } + fieldSchema, err := parseObject(joinLoc(location, field.Name), obj.Field(i), components) + if err != nil { + // remove the temporary schema + delete(components.Spec.Schemas, objName) + return nil, err + } + if field.Anonymous { + allOf = append(allOf, fieldSchema.Build()) + continue + } + name := applyTag(field, fieldSchema, builder) + // skip the field if it's marked as "-" + if name == "-" { + continue + } + builder.AddProperty(name, fieldSchema.Build()) + } + if len(allOf) > 0 { + allOf = append(allOf, builder.Type(ObjectType).GoType("").Build()) + builder = NewSchemaBuilder().AllOf(allOf...).GoType(t.String()) + } else { + builder.Type(ObjectType) + } + builder.GoPackage(t.PkgPath()) + components.Spec.Schemas[objName] = builder.Build() + builder = NewSchemaBuilder().Ref("#/components/schemas/" + objName) + } + } + + return builder, nil +} + +func applyTag(field reflect.StructField, schema *SchemaBulder, parent *SchemaBulder) (name string) { + name = field.Name + + for _, tagName := range []string{"json", "yaml"} { + if tag, ok := field.Tag.Lookup(tagName); ok { + parts := strings.SplitN(tag, ",", 2) + if len(parts) > 0 { + part := strings.TrimSpace(parts[0]) + if part != "" { + name = part + break + } + } + } + } + + tag, ok := field.Tag.Lookup("openapi") + if !ok { + return + } + parts := strings.Split(tag, ",") + if len(parts) == 0 { + return + } + + if parts[0] != "" { + name = parts[0] + } + if name == "-" { + return parts[0] + } + parts = parts[1:] + if len(parts) == 0 { + return + } + + if strings.HasPrefix("ref:", parts[0]) { + schema.Ref(parts[0][4:]) + } + + var isTypeOverriden bool + + for _, part := range parts { + prefixIndex := strings.Index(part, ":") + var prefix string + if prefixIndex == -1 { + prefix = part + } else { + prefix = part[:prefixIndex] + if prefixIndex == len(part)-1 { + part = "" + } + part = part[prefixIndex+1:] + } + + // the tags for the references only + if schema.IsRef() { + switch prefix { + case "required": + parent.AddRequired(name) + case "description": + schema.Description(part) + case "title", "summary": + schema.Title(part) + } + continue + } + + switch prefix { + case "required": + parent.AddRequired(name) + case "deprecated": + schema.Deprecated(true) + case "title": + schema.Title(part) + case "description": + schema.Description(part) + case "type": + // first type overrides the default type, all other types are added + if !isTypeOverriden { + schema.Type(part) + isTypeOverriden = true + } else { + schema.AddType(part) + } + case "addtype": + schema.AddType(part) + case "format": + schema.Format(part) + } + } + + return +} diff --git a/parser_test.go b/parser_test.go new file mode 100644 index 0000000..c32ddf8 --- /dev/null +++ b/parser_test.go @@ -0,0 +1,433 @@ +package openapi_test + +import ( + "encoding/json" + "github.com/stretchr/testify/require" + "github.com/sv-tools/openapi" + "testing" +) + +type Simple struct { + Fs string `json:"fs,omitempty" yaml:"FS,omitempty" openapi:",format:password"` // json name should be used + Fi int `yaml:"FI,omitempty"` // yaml name shuld be used + Fb *bool + Fbs []byte `json:"fbs,omitempty" yaml:"FS,omitempty" openapi:"fBS"` // openapi name should be used + Fm map[string]string `openapi:",required,title:Map of strings,addtype:null"` // default field name should be used + Excluded1 map[string]string `openapi:"-"` + Excluded2 map[string]string `json:"-"` + Excluded3 map[string]string `yaml:"-"` + Fa any `openapi:",deprecated"` + + fp string +} + +type Complex struct { + Simple // anonymous field + Next *Complex `json:"Next"` // circular references +} + +func TestParseObject(t *testing.T) { + trueVar := true + strVar := "foo" + + for _, tt := range []struct { + name string + obj any + expected *openapi.RefOrSpec[openapi.Schema] + expectedComponents *openapi.Components + err string + }{ + { + name: "nil", + obj: nil, + expected: openapi.NewSchemaBuilder(). + Type(openapi.NullType). + GoType("nil").Build(), + }, + { + name: "bool true", + obj: true, + expected: openapi.NewSchemaBuilder(). + Type(openapi.BooleanType). + GoType("bool").Build(), + }, + { + name: "bool false", + obj: false, + expected: openapi.NewSchemaBuilder(). + Type(openapi.BooleanType). + GoType("bool").Build(), + }, + { + name: "ptr to bool true", + obj: &trueVar, + expected: openapi.NewSchemaBuilder(). + Type(openapi.BooleanType, openapi.NullType). + GoType("bool").Build(), + }, + { + name: "int", + obj: 42, + expected: openapi.NewSchemaBuilder(). + Type(openapi.IntegerType). + Format(openapi.Int64Format). + GoType("int").Build(), + }, + { + name: "int8", + obj: int8(42), + expected: openapi.NewSchemaBuilder(). + Type(openapi.IntegerType). + Format(openapi.Int32Format). + GoType("int8").Build(), + }, + { + name: "int32", + obj: int32(42), + expected: openapi.NewSchemaBuilder(). + Type(openapi.IntegerType). + Format(openapi.Int32Format). + GoType("int32").Build(), + }, + { + name: "int64", + obj: int64(42), + expected: openapi.NewSchemaBuilder(). + Type(openapi.IntegerType). + Format(openapi.Int64Format). + GoType("int64").Build(), + }, + { + name: "uint", + obj: uint(42), + expected: openapi.NewSchemaBuilder(). + Type(openapi.IntegerType). + Format(openapi.Int64Format). + GoType("uint").Build(), + }, + { + name: "uint8", + obj: uint8(42), + expected: openapi.NewSchemaBuilder(). + Type(openapi.IntegerType). + Format(openapi.Int32Format). + GoType("uint8").Build(), + }, + { + name: "uint32", + obj: uint32(42), + expected: openapi.NewSchemaBuilder(). + Type(openapi.IntegerType). + Format(openapi.Int32Format). + GoType("uint32").Build(), + }, + { + name: "uint64", + obj: uint64(42), + expected: openapi.NewSchemaBuilder(). + Type(openapi.IntegerType). + Format(openapi.Int64Format). + GoType("uint64").Build(), + }, + { + name: "float32", + obj: float32(42), + expected: openapi.NewSchemaBuilder(). + Type(openapi.NumberType). + Format(openapi.FloatFormat). + GoType("float32").Build(), + }, + { + name: "float64", + obj: float64(42), + expected: openapi.NewSchemaBuilder(). + Type(openapi.NumberType). + Format(openapi.DoubleFormat). + GoType("float64").Build(), + }, + { + name: "string", + obj: "foo", + expected: openapi.NewSchemaBuilder(). + Type(openapi.StringType). + GoType("string").Build(), + }, + { + name: "bytes", + obj: []byte("foo"), + expected: openapi.NewSchemaBuilder(). + Type(openapi.StringType). + ContentEncoding(openapi.Base64Encoding). + GoType("[]byte").Build(), + }, + { + name: "map string string", + obj: map[string]string{"foo": "bar"}, + expected: openapi.NewSchemaBuilder().Type(openapi.ObjectType). + AdditionalProperties(openapi.NewBoolOrSchema(openapi.NewSchemaBuilder(). + Type(openapi.StringType). + GoType("string").Build(), + )).Build(), + }, + { + name: "map string ref string", + obj: map[string]*string{"foo": &strVar}, + expected: openapi.NewSchemaBuilder().Type(openapi.ObjectType). + AdditionalProperties(openapi.NewBoolOrSchema(openapi.NewSchemaBuilder(). + Type(openapi.StringType, openapi.NullType). + GoType("string").Build(), + )).Build(), + }, + { + name: "map string int", + obj: map[string]int{"foo": 42}, + expected: openapi.NewSchemaBuilder(). + Type(openapi.ObjectType). + AdditionalProperties(openapi.NewBoolOrSchema(openapi.NewSchemaBuilder(). + Type(openapi.IntegerType). + Format(openapi.Int64Format). + GoType("int").Build(), + )).Build(), + }, + { + name: "map string any", + obj: map[string]any{"foo": 42, "bar": "baz"}, + expected: openapi.NewSchemaBuilder(). + Type(openapi.ObjectType). + AdditionalProperties(openapi.NewBoolOrSchema(true)). + Build(), + }, + { + name: "slice int", + obj: []int{42}, + expected: openapi.NewSchemaBuilder(). + Type(openapi.ArrayType). + Items(openapi.NewBoolOrSchema(openapi.NewSchemaBuilder(). + Type(openapi.IntegerType). + Format(openapi.Int64Format). + GoType("int").Build(), + )).Build(), + }, + { + name: "slice string", + obj: []string{"foo"}, + expected: openapi.NewSchemaBuilder(). + Type(openapi.ArrayType). + Items(openapi.NewBoolOrSchema(openapi.NewSchemaBuilder(). + Type(openapi.StringType). + GoType("string").Build(), + )).Build(), + }, + { + name: "slice any", + obj: []any{"foo", 42}, + expected: openapi.NewSchemaBuilder().Type(openapi.ArrayType). + Items(openapi.NewBoolOrSchema(true)). + Build(), + }, + { + name: "double slice any", + obj: [][]any{{"foo", 42}}, + expected: openapi.NewSchemaBuilder(). + Type(openapi.ArrayType). + Items(openapi.NewBoolOrSchema(openapi.NewSchemaBuilder(). + Type(openapi.ArrayType). + Items(openapi.NewBoolOrSchema(true)). + Build(), + )).Build(), + }, + { + name: "triple slice any", + obj: [][][]any{{{"foo", 42}}}, + expected: openapi.NewSchemaBuilder(). + Type(openapi.ArrayType). + Items(openapi.NewBoolOrSchema(openapi.NewSchemaBuilder(). + Type(openapi.ArrayType). + Items(openapi.NewBoolOrSchema( + openapi.NewSchemaBuilder(). + Type(openapi.ArrayType). + Items(openapi.NewBoolOrSchema(true)). + Build(), + )).Build(), + )).Build(), + }, + { + name: "map map string any", + obj: map[string]map[string]any{"xyz": {"foo": 42, "bar": "baz"}}, + expected: openapi.NewSchemaBuilder(). + Type(openapi.ObjectType). + AdditionalProperties(openapi.NewBoolOrSchema(openapi.NewSchemaBuilder(). + Type(openapi.ObjectType). + AdditionalProperties(openapi.NewBoolOrSchema(true)). + Build(), + )).Build(), + }, + { + name: "slice map string any", + obj: []map[string]any{{"foo": 42, "bar": "baz"}}, + expected: openapi.NewSchemaBuilder(). + Type(openapi.ArrayType). + Items(openapi.NewBoolOrSchema(openapi.NewSchemaBuilder(). + Type(openapi.ObjectType). + AdditionalProperties(openapi.NewBoolOrSchema(true)). + Build(), + )).Build(), + }, + { + name: "json number", + obj: json.Number("42"), + expected: openapi.NewSchemaBuilder(). + Type(openapi.NumberType). + GoPackage("encoding/json").GoType("json.Number").Build(), + }, + { + name: "json raw", + obj: json.RawMessage(`"foo"`), + expected: openapi.NewSchemaBuilder(). + Type(openapi.StringType). + ContentMediaType("application/json"). + GoPackage("encoding/json").GoType("json.RawMessage").Build(), + }, + { + name: "simple struct", + obj: Simple{ + Fs: "foo", + Fi: 42, + Fb: &trueVar, + Fbs: []byte("bar"), + Fm: map[string]string{"baz": "qux"}, + Fa: []any{"435", 42, false}, + fp: "baz", + }, + expected: openapi.NewSchemaBuilder().Ref("#/components/schemas/github.com.sv-tools.openapi_test.Simple").Build(), + expectedComponents: openapi.NewComponents().Spec.Add( + "github.com.sv-tools.openapi_test.Simple", + openapi.NewSchemaBuilder(). + Type(openapi.ObjectType). + AddProperty("fs", openapi.NewSchemaBuilder().Type(openapi.StringType).GoType("string").Format("password").Build()). + AddProperty("FI", openapi.NewSchemaBuilder().Type(openapi.IntegerType).Format(openapi.Int64Format).GoType("int").Build()). + AddProperty("Fb", openapi.NewSchemaBuilder().Type(openapi.BooleanType, openapi.NullType).GoType("bool").Build()). + AddProperty("fBS", openapi.NewSchemaBuilder().Type(openapi.StringType).ContentEncoding(openapi.Base64Encoding).GoType("[]byte").Build()). + AddProperty("Fm", openapi.NewSchemaBuilder(). + Type(openapi.ObjectType, openapi.NullType). + Title("Map of strings"). + AdditionalProperties(openapi.NewBoolOrSchema(openapi.NewSchemaBuilder(). + Type(openapi.StringType). + GoType("string").Build(), + )).Build(), + ). + AddProperty("Fa", openapi.NewSchemaBuilder(). + Deprecated(true). + GoType("any").Build(), + ). + AddRequired("Fm"). + GoPackage("github.com/sv-tools/openapi_test").GoType("openapi_test.Simple").Build(), + ), + }, + { + name: "complex struct", + obj: Complex{ + Simple: Simple{ + Fs: "foo", + Fi: 42, + Fb: &trueVar, + Fbs: []byte("bar"), + Fm: map[string]string{"baz": "qux"}, + Fa: []any{"435", 42, false}, + fp: "baz", + }, + Next: &Complex{}, + }, + expected: openapi.NewSchemaBuilder().Ref("#/components/schemas/github.com.sv-tools.openapi_test.Complex").Build(), + expectedComponents: openapi.NewComponents().Spec.Add( + "github.com.sv-tools.openapi_test.Complex", + openapi.NewSchemaBuilder(). + AllOf( + openapi.NewSchemaBuilder().Ref("#/components/schemas/github.com.sv-tools.openapi_test.Simple").Build(), + openapi.NewSchemaBuilder(). + Type(openapi.ObjectType). + AddProperty("Next", openapi.NewSchemaBuilder(). + OneOf( + openapi.NewSchemaBuilder().Ref("#/components/schemas/github.com.sv-tools.openapi_test.Complex").Build(), + openapi.NewSchemaBuilder().Type(openapi.NullType).Build(), + ). + Build(), + ). + Build(), + ). + GoPackage("github.com/sv-tools/openapi_test").GoType("openapi_test.Complex").Build(), + ).Add( + "github.com.sv-tools.openapi_test.Simple", + openapi.NewSchemaBuilder(). + Type(openapi.ObjectType). + AddProperty("fs", openapi.NewSchemaBuilder().Type(openapi.StringType).GoType("string").Format("password").Build()). + AddProperty("FI", openapi.NewSchemaBuilder().Type(openapi.IntegerType).Format(openapi.Int64Format).GoType("int").Build()). + AddProperty("Fb", openapi.NewSchemaBuilder().Type(openapi.BooleanType, openapi.NullType).GoType("bool").Build()). + AddProperty("fBS", openapi.NewSchemaBuilder().Type(openapi.StringType).ContentEncoding(openapi.Base64Encoding).GoType("[]byte").Build()). + AddProperty("Fm", openapi.NewSchemaBuilder(). + Type(openapi.ObjectType, openapi.NullType). + Title("Map of strings"). + AdditionalProperties(openapi.NewBoolOrSchema(openapi.NewSchemaBuilder(). + Type(openapi.StringType). + GoType("string").Build(), + )).Build(), + ). + AddProperty("Fa", openapi.NewSchemaBuilder(). + Deprecated(true). + GoType("any").Build(), + ). + AddRequired("Fm"). + GoPackage("github.com/sv-tools/openapi_test").GoType("openapi_test.Simple").Build(), + ), + }, + } { + t.Run(tt.name, func(t *testing.T) { + spec := openapi.NewOpenAPIBuilder(). + Info(openapi.NewInfoBuilder().Title("Test").Version("1.0").Build()). + Components(openapi.NewComponents()). + Build() + schema, err := openapi.ParseObject(tt.obj, spec.Spec.Components) + if tt.err != "" { + require.ErrorContains(t, err, tt.err) + return + } + require.NoError(t, err) + require.NotNil(t, schema) + + actual, err := schema.Build().MarshalJSON() + require.NoError(t, err) + + expected, err := tt.expected.MarshalJSON() + require.NoError(t, err) + + require.JSONEq(t, string(expected), string(actual)) + + if tt.expectedComponents != nil { + actualComponents, err := spec.Spec.Components.MarshalJSON() + require.NoError(t, err) + + expectedComponents, err := json.Marshal(tt.expectedComponents) + require.NoError(t, err) + + require.JSONEq(t, string(expectedComponents), string(actualComponents)) + } + + spec.Spec.Components.Spec.Add("test", schema.Build()) + validator, err := openapi.NewValidator( + spec, + openapi.AllowUnusedComponents(), + ) + require.NoError(t, err) + + require.NoError(t, validator.ValidateSpec()) + + value, err := openapi.ConvertToJSON(tt.obj) + require.NoError(t, err) + + pretty, _ := json.MarshalIndent(tt.obj, "", " ") + t.Logf("obj: %s", pretty) + + require.NoError(t, validator.ValidateData("#/components/schemas/test", value)) + }) + } +} diff --git a/schema.go b/schema.go index ffd9ce1..130ce0c 100644 --- a/schema.go +++ b/schema.go @@ -313,6 +313,13 @@ type Schema struct { Example any `json:"example,omitempty" yaml:"example,omitempty"` Extensions map[string]any `json:"-" yaml:"-"` + + // *** Go Fields *** + + // GoPackage is a custom field to store the Go package of the schema. + GoPackage string `json:"x-go-package,omitempty" yaml:"x-go-package,omitempty"` + // GoType is a custom field to store the Go type of the schema. + GoType string `json:"x-go-type,omitempty" yaml:"x-go-type,omitempty"` } // AddExt sets the extension and returns the current object (self|this). @@ -719,42 +726,77 @@ type SchemaBulder struct { func NewSchemaBuilder() *SchemaBulder { return &SchemaBulder{ - spec: NewRefOrSpec[Schema](&Schema{ - Schema: Draft202012, - }), + spec: NewRefOrSpec[Schema](&Schema{}), } } func (b *SchemaBulder) Build() *RefOrSpec[Schema] { + if b.spec.Ref != nil { + b.spec.Spec = nil + } return b.spec } func (b *SchemaBulder) Extensions(v map[string]any) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Extensions = v return b } func (b *SchemaBulder) AddExt(name string, value any) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.AddExt(name, value) return b } +func (b *SchemaBulder) Ref(v string) *SchemaBulder { + if b.spec.Ref == nil { + b.spec.Ref = &Ref{ + Summary: b.spec.Spec.Title, + Description: b.spec.Spec.Description, + } + b.spec.Spec = nil + } + b.spec.Ref.Ref = v + return b +} + +func (b *SchemaBulder) IsRef() bool { + return b.spec.Ref != nil +} + func (b *SchemaBulder) Schema(v string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Schema = v return b } func (b *SchemaBulder) ID(v string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.ID = v return b } func (b *SchemaBulder) Defs(v map[string]*RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Defs = v return b } func (b *SchemaBulder) AddDef(name string, value *RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } if b.spec.Spec.Defs == nil { b.spec.Spec.Defs = make(map[string]*RefOrSpec[Schema], 1) } @@ -763,16 +805,25 @@ func (b *SchemaBulder) AddDef(name string, value *RefOrSpec[Schema]) *SchemaBuld } func (b *SchemaBulder) DynamicRef(v string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.DynamicRef = v return b } func (b *SchemaBulder) Vocabulary(v map[string]bool) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Vocabulary = v return b } func (b *SchemaBulder) AddVocabulary(name string, value bool) *SchemaBulder { + if b.spec.Ref != nil { + return b + } if b.spec.Spec.Vocabulary == nil { b.spec.Spec.Vocabulary = make(map[string]bool, 1) } @@ -781,16 +832,25 @@ func (b *SchemaBulder) AddVocabulary(name string, value bool) *SchemaBulder { } func (b *SchemaBulder) DynamicAnchor(v string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.DynamicAnchor = v return b } func (b *SchemaBulder) Type(v ...string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Type = NewSingleOrArray[string](v...) return b } func (b *SchemaBulder) AddType(v ...string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } if b.spec.Spec.Type == nil { b.spec.Spec.Type = NewSingleOrArray[string](v...) } else { @@ -800,121 +860,195 @@ func (b *SchemaBulder) AddType(v ...string) *SchemaBulder { } func (b *SchemaBulder) Default(v any) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Default = v return b } func (b *SchemaBulder) Title(v string) *SchemaBulder { + if b.spec.Ref != nil { + b.spec.Ref.Summary = v + return b + } b.spec.Spec.Title = v return b } func (b *SchemaBulder) Description(v string) *SchemaBulder { + if b.spec.Ref != nil { + b.spec.Ref.Description = v + return b + } b.spec.Spec.Description = v return b } func (b *SchemaBulder) Const(v string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Const = v return b } func (b *SchemaBulder) Comment(v string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Comment = v return b } func (b *SchemaBulder) Enum(v ...any) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Enum = v return b } func (b *SchemaBulder) AddEnum(v ...any) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Enum = append(b.spec.Spec.Enum, v...) return b } func (b *SchemaBulder) Examples(v ...any) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Examples = v return b } func (b *SchemaBulder) AddExamples(v ...any) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Examples = append(b.spec.Spec.Examples, v...) return b } func (b *SchemaBulder) ReadOnly(v bool) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.ReadOnly = v return b } func (b *SchemaBulder) WriteOnly(v bool) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.WriteOnly = v return b } func (b *SchemaBulder) Deprecated(v bool) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Deprecated = v return b } func (b *SchemaBulder) ContentSchema(v *RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.ContentSchema = v return b } func (b *SchemaBulder) ContentMediaType(v string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.ContentMediaType = v return b } func (b *SchemaBulder) ContentEncoding(v string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.ContentEncoding = v return b } func (b *SchemaBulder) Not(v *RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Not = v return b } func (b *SchemaBulder) AllOf(v ...*RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.AllOf = v return b } func (b *SchemaBulder) AddAllOf(v ...*RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.AllOf = append(b.spec.Spec.AllOf, v...) return b } func (b *SchemaBulder) AnyOf(v ...*RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.AnyOf = v return b } func (b *SchemaBulder) AddAnyOf(v ...*RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.AnyOf = append(b.spec.Spec.AnyOf, v...) return b } func (b *SchemaBulder) OneOf(v ...*RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.OneOf = v return b } func (b *SchemaBulder) AddOneOf(v ...*RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.OneOf = append(b.spec.Spec.OneOf, v...) return b } func (b *SchemaBulder) DependentRequired(v map[string][]string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.DependentRequired = v return b } func (b *SchemaBulder) AddDependentRequired(name string, value ...string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } if b.spec.Spec.DependentRequired == nil { b.spec.Spec.DependentRequired = make(map[string][]string, 1) } @@ -923,11 +1057,17 @@ func (b *SchemaBulder) AddDependentRequired(name string, value ...string) *Schem } func (b *SchemaBulder) DependentSchemas(v map[string]*RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.DependentSchemas = v return b } func (b *SchemaBulder) AddDependentSchema(name string, value *RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } if b.spec.Spec.DependentSchemas == nil { b.spec.Spec.DependentSchemas = make(map[string]*RefOrSpec[Schema], 1) } @@ -936,121 +1076,193 @@ func (b *SchemaBulder) AddDependentSchema(name string, value *RefOrSpec[Schema]) } func (b *SchemaBulder) If(v *RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.If = v return b } func (b *SchemaBulder) Then(v *RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Then = v return b } func (b *SchemaBulder) Else(v *RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Else = v return b } func (b *SchemaBulder) MultipleOf(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.MultipleOf = &v return b } func (b *SchemaBulder) Minimum(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Minimum = &v return b } func (b *SchemaBulder) ExclusiveMinimum(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.ExclusiveMinimum = &v return b } func (b *SchemaBulder) Maximum(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Maximum = &v return b } func (b *SchemaBulder) ExclusiveMaximum(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.ExclusiveMaximum = &v return b } func (b *SchemaBulder) MinLength(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.MinLength = &v return b } func (b *SchemaBulder) MaxLength(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.MaxLength = &v return b } func (b *SchemaBulder) Pattern(v string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Pattern = v return b } func (b *SchemaBulder) Format(v string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Format = v return b } func (b *SchemaBulder) Items(v *BoolOrSchema) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Items = v return b } func (b *SchemaBulder) MaxItems(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.MaxItems = &v return b } func (b *SchemaBulder) UnevaluatedItems(v *BoolOrSchema) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.UnevaluatedItems = v return b } func (b *SchemaBulder) Contains(v *RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Contains = v return b } func (b *SchemaBulder) MinContains(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.MinContains = &v return b } func (b *SchemaBulder) MaxContains(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.MaxContains = &v return b } func (b *SchemaBulder) MinItems(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.MinItems = &v return b } func (b *SchemaBulder) UniqueItems(v bool) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.UniqueItems = &v return b } func (b *SchemaBulder) PrefixItems(v ...*RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.PrefixItems = v return b } func (b *SchemaBulder) AddPrefixItems(v ...*RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.PrefixItems = append(b.spec.Spec.PrefixItems, v...) return b } func (b *SchemaBulder) Properties(v map[string]*RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Properties = v return b } func (b *SchemaBulder) AddProperty(name string, value *RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } if b.spec.Spec.Properties == nil { b.spec.Spec.Properties = make(map[string]*RefOrSpec[Schema], 1) } @@ -1059,11 +1271,17 @@ func (b *SchemaBulder) AddProperty(name string, value *RefOrSpec[Schema]) *Schem } func (b *SchemaBulder) PatternProperties(v map[string]*RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.PatternProperties = v return b } func (b *SchemaBulder) AddPatternProperty(name string, value *RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } if b.spec.Spec.PatternProperties == nil { b.spec.Spec.PatternProperties = make(map[string]*RefOrSpec[Schema], 1) } @@ -1072,56 +1290,105 @@ func (b *SchemaBulder) AddPatternProperty(name string, value *RefOrSpec[Schema]) } func (b *SchemaBulder) AdditionalProperties(v *BoolOrSchema) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.AdditionalProperties = v return b } func (b *SchemaBulder) UnevaluatedProperties(v *BoolOrSchema) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.UnevaluatedProperties = v return b } func (b *SchemaBulder) PropertyNames(v *RefOrSpec[Schema]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.PropertyNames = v return b } func (b *SchemaBulder) MinProperties(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.MinProperties = &v return b } func (b *SchemaBulder) MaxProperties(v int) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.MaxProperties = &v return b } func (b *SchemaBulder) Required(v ...string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Required = v return b } func (b *SchemaBulder) AddRequired(v ...string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Required = append(b.spec.Spec.Required, v...) return b } func (b *SchemaBulder) Discriminator(v *Discriminator) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Discriminator = v return b } func (b *SchemaBulder) XML(v *Extendable[XML]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.XML = v return b } func (b *SchemaBulder) ExternalDocs(v *Extendable[ExternalDocs]) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.ExternalDocs = v return b } func (b *SchemaBulder) Example(v any) *SchemaBulder { + if b.spec.Ref != nil { + return b + } b.spec.Spec.Example = v return b } + +func (b *SchemaBulder) GoType(v string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } + b.spec.Spec.GoType = v + return b +} + +func (b *SchemaBulder) GoPackage(v string) *SchemaBulder { + if b.spec.Ref != nil { + return b + } + b.spec.Spec.GoPackage = v + return b +} diff --git a/types.go b/types.go index 8fabcc4..67e99e9 100644 --- a/types.go +++ b/types.go @@ -70,9 +70,13 @@ func GetType(v any) (string, error) { } func getKind(v any) reflect.Kind { - k := reflect.TypeOf(v).Kind() + t := reflect.TypeOf(v) + if t == nil { + return reflect.Invalid + } + k := t.Kind() if k == reflect.Ptr { - k = reflect.TypeOf(v).Elem().Kind() + k = t.Elem().Kind() } return k } diff --git a/validation.go b/validation.go index 92f0c0d..42eaafe 100644 --- a/validation.go +++ b/validation.go @@ -224,13 +224,10 @@ func (v *Validator) ValidateDataAsJSON(location string, value any) error { switch getKind(value) { // marshal and unmarshal the value to JSON representation (map[any]struct). case reflect.Struct: - data, err := json.Marshal(value) - if err != nil { - return fmt.Errorf("marshaling value failed: %w", err) - } - value, err = jsonschema.UnmarshalJSON(bytes.NewReader(data)) + var err error + value, err = ConvertToJSON(value) if err != nil { - return fmt.Errorf("unmarshaling value failed: %w", err) + return err } // check if the value is already a JSON, if not keep it as is. case reflect.String: @@ -241,3 +238,15 @@ func (v *Validator) ValidateDataAsJSON(location string, value any) error { } return v.ValidateData(location, value) } + +func ConvertToJSON(value any) (any, error) { + data, err := json.Marshal(value) + if err != nil { + return nil, fmt.Errorf("marshaling value failed: %w", err) + } + value, err = jsonschema.UnmarshalJSON(bytes.NewReader(data)) + if err != nil { + return nil, fmt.Errorf("unmarshaling value failed: %w", err) + } + return value, nil +}