diff --git a/yaml.go b/yaml.go index 66c5c66..66a3d01 100644 --- a/yaml.go +++ b/yaml.go @@ -22,17 +22,53 @@ import ( // Marshal the object into JSON then converts JSON to YAML and returns the // YAML. func Marshal(o interface{}) ([]byte, error) { - j, err := json.Marshal(o) - if err != nil { - return nil, fmt.Errorf("error marshaling into JSON: %v", err) + var buf bytes.Buffer + err := NewEncoder(&buf).Encode(o) + return buf.Bytes(), err +} + +// An Encoder writes YAML values to an output stream. +type Encoder struct { + encoder *yaml.Encoder +} + +// NewEncoder returns a new encoder that writes to w. The Encoder should be closed after use to flush all data to w. +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{ + encoder: yaml.NewEncoder(w), } +} - y, err := JSONToYAML(j) - if err != nil { - return nil, fmt.Errorf("error converting JSON to YAML: %v", err) +// Encode writes the YAML encoding of obj to the stream. +// If multiple items are encoded to the stream, the second and subsequent document will be preceded with a "---" document separator, +// but the first will not. +// +// See the documentation for Marshal for details about the conversion of Go values to YAML. +func (e *Encoder) Encode(obj interface{}) error { + var buf bytes.Buffer + // Convert an object to the JSON. + if err := json.NewEncoder(&buf).Encode(obj); err != nil { + return fmt.Errorf("error encode into JSON: %w", err) + } + + if err := jsonToYAML(e.encoder, &buf); err != nil { + return fmt.Errorf("error encode into YAML: %w", err) } - return y, nil + return nil +} + +// SetIndent changes the used indentation used when encoding. +func (e *Encoder) SetIndent(spaces int) { + e.encoder.SetIndent(spaces) +} + +// Close closes the encoder by writing any remaining data. It does not write a stream terminating string "...". +func (e *Encoder) Close() (err error) { + if err := e.encoder.Close(); err != nil { + return fmt.Errorf("error closing encoder: %w", err) + } + return nil } // JSONOpt is a decoding option for decoding from JSON format. @@ -41,20 +77,57 @@ type JSONOpt func(*json.Decoder) *json.Decoder // Unmarshal converts YAML to JSON then uses JSON to unmarshal into an object, // optionally configuring the behavior of the JSON unmarshal. func Unmarshal(y []byte, o interface{}, opts ...JSONOpt) error { - dec := yaml.NewDecoder(bytes.NewReader(y)) - return unmarshal(dec, o, opts) + return NewDecoder(bytes.NewReader(y), opts...).Decode(o) +} + +// A Decoder reads and decodes YAML values from an input stream. +type Decoder struct { + opts []JSONOpt + decoder *yaml.Decoder +} + +// NewDecoder returns a new decoder that reads from r. +// +// The decoder introduces its own buffering and may read data from r beyond the YAML values requested. +// +// Options for the standard library json.Decoder can be optionally specified, e.g. to decode untyped numbers into json.Number instead of float64, +// or to disallow unknown fields (but for that purpose, see also UnmarshalStrict) +func NewDecoder(r io.Reader, opts ...JSONOpt) *Decoder { + return &Decoder{ + opts: opts, + decoder: yaml.NewDecoder(r), + } +} + +// Decode reads the next YAML-encoded value from its input and stores it in the value pointed to by o. +// +// See the documentation for Unmarshal for details about the conversion of YAML into a Go value. +func (dec *Decoder) Decode(o interface{}) error { + return unmarshal(dec.decoder, o, dec.opts) +} + +func disallowUnknownFields(d *json.Decoder) *json.Decoder { + d.DisallowUnknownFields() + return d +} + +// KnownFields ensures that the keys in decoded mappings to +// exist as fields in the struct being decoded into. +func (dec *Decoder) KnownFields() { + dec.decoder.KnownFields(true) + dec.opts = append(dec.opts, disallowUnknownFields) } func unmarshal(dec *yaml.Decoder, o interface{}, opts []JSONOpt) error { vo := reflect.ValueOf(o) j, err := yamlToJSON(dec, &vo) if err != nil { - return fmt.Errorf("error converting YAML to JSON: %v", err) + return fmt.Errorf("error converting YAML to JSON: %w", err) } err = jsonUnmarshal(bytes.NewReader(j), o, opts...) if err != nil { - return fmt.Errorf("error unmarshaling JSON: %v", err) + return fmt.Errorf("error unmarshaling JSON: %w", err) } return nil @@ -70,13 +143,19 @@ func jsonUnmarshal(r io.Reader, o interface{}, opts ...JSONOpt) error { d = opt(d) } if err := d.Decode(&o); err != nil { - return fmt.Errorf("while decoding JSON: %v", err) + return fmt.Errorf("while decoding JSON: %w", err) } return nil } // JSONToYAML converts JSON to YAML. func JSONToYAML(j []byte) ([]byte, error) { + var buf bytes.Buffer + err := jsonToYAML(yaml.NewEncoder(&buf), bytes.NewReader(j)) + return buf.Bytes(), err +} + +func jsonToYAML(e *yaml.Encoder, r io.Reader) error { // Convert the JSON to an object. var jsonObj interface{} // We are using yaml.Unmarshal here (instead of json.Unmarshal) because the @@ -84,13 +163,12 @@ func JSONToYAML(j []byte) ([]byte, error) { // etc.) when unmarshalling to interface{}, it just picks float64 // universally. go-yaml does go through the effort of picking the right // number type, so we can preserve number type throughout this process. - err := yaml.Unmarshal(j, &jsonObj) - if err != nil { - return nil, err + if err := yaml.NewDecoder(r).Decode(&jsonObj); err != nil { + return err } // Marshal this object into YAML. - return yaml.Marshal(jsonObj) + return e.Encode(jsonObj) } // YAMLToJSON converts YAML to JSON. Since JSON is a subset of YAML, diff --git a/yaml_test.go b/yaml_test.go index 749aca4..d7e1cdd 100644 --- a/yaml_test.go +++ b/yaml_test.go @@ -1,10 +1,12 @@ package yaml import ( + "bytes" + "errors" "fmt" + "io" "math" "reflect" - "runtime" "strconv" "strings" "testing" @@ -23,15 +25,29 @@ func TestMarshal(t *testing.T) { s := MarshalTest{"a", math.MaxInt64, math.MaxFloat32, math.MaxFloat64} e := []byte(fmt.Sprintf("A: a\nB: %d\nC: %s\nD: %s\n", int64(math.MaxInt64), f32String, f64String)) - y, err := Marshal(s) - if err != nil { - t.Errorf("error marshaling YAML: %v", err) - } + t.Run("Marshal", func(t *testing.T) { + y, err := Marshal(s) + if err != nil { + t.Errorf("error marshaling YAML: %v", err) + } + + if !reflect.DeepEqual(y, e) { + t.Errorf("marshal YAML was unsuccessful, expected: %#v, got: %#v", + string(e), string(y)) + } + }) + t.Run("Encode", func(t *testing.T) { + var buf bytes.Buffer + if err := NewEncoder(&buf).Encode(s); err != nil { + t.Errorf("error encoding YAML: %v", err) + } + + if y := buf.Bytes(); !reflect.DeepEqual(y, e) { + t.Errorf("encode YAML was unsuccessful, expected: %#v, got: %#v", + string(e), string(y)) + } + }) - if !reflect.DeepEqual(y, e) { - t.Errorf("marshal YAML was unsuccessful, expected: %#v, got: %#v", - string(e), string(y)) - } } type UnmarshalString struct { @@ -60,41 +76,41 @@ type NestedSlice struct { C *string } -func TestUnmarshal(t *testing.T) { +func TestUnmarshalStrict(t *testing.T) { y := []byte("a: 1") s1 := UnmarshalString{} e1 := UnmarshalString{A: "1"} - unmarshalEqual(t, y, &s1, &e1) + unmarshalEqual(t, y, &s1, &e1, true) y = []byte(`a: "1"`) s1 = UnmarshalString{} e1 = UnmarshalString{A: "1"} - unmarshalEqual(t, y, &s1, &e1) + unmarshalEqual(t, y, &s1, &e1, true) y = []byte("a: true") s1 = UnmarshalString{} e1 = UnmarshalString{A: "true"} - unmarshalEqual(t, y, &s1, &e1) + unmarshalEqual(t, y, &s1, &e1, true) y = []byte("a: 1") s1 = UnmarshalString{} e1 = UnmarshalString{A: "1"} - unmarshalEqual(t, y, &s1, &e1) + unmarshalEqual(t, y, &s1, &e1, true) y = []byte("a:\n a: 1") s2 := UnmarshalNestedString{} e2 := UnmarshalNestedString{NestedString{"1"}} - unmarshalEqual(t, y, &s2, &e2) + unmarshalEqual(t, y, &s2, &e2, true) y = []byte("a:\n - b: abc\n c: def\n - b: 123\n c: 456\n") s3 := UnmarshalSlice{} e3 := UnmarshalSlice{[]NestedSlice{{"abc", strPtr("def")}, {"123", strPtr("456")}}} - unmarshalEqual(t, y, &s3, &e3) + unmarshalEqual(t, y, &s3, &e3, true) y = []byte("a:\n b: 1") s4 := UnmarshalStringMap{} e4 := UnmarshalStringMap{map[string]string{"b": "1"}} - unmarshalEqual(t, y, &s4, &e4) + unmarshalEqual(t, y, &s4, &e4, true) y = []byte(` a: @@ -110,7 +126,7 @@ b: "a": {Name: "TestA"}, "b": {Name: "TestB"}, } - unmarshalEqual(t, y, &s5, &e5) + unmarshalEqual(t, y, &s5, &e5, true) } // TestUnmarshalNonStrict tests that we parse ambiguous YAML without error. @@ -145,69 +161,107 @@ func TestUnmarshalNonStrict(t *testing.T) { }, } { s := UnmarshalString{} - unmarshalEqual(t, tc.yaml, &s, &tc.want) + unmarshalEqual(t, tc.yaml, &s, &tc.want, false) } } -// prettyFunctionName converts a slice of JSONOpt function pointers to a human -// readable string representation. -func prettyFunctionName(opts []JSONOpt) []string { - var r []string - for _, o := range opts { - r = append(r, runtime.FuncForPC(reflect.ValueOf(o).Pointer()).Name()) - } - return r -} +func unmarshalEqual(t *testing.T, y []byte, s, e interface{}, knowFields bool) { + t.Run("Unmarshal", func(t *testing.T) { + if err := Unmarshal(y, s); err != nil { + t.Errorf("Unmarshal(%#q, s) = %v", string(y), err) + return + } -func unmarshalEqual(t *testing.T, y []byte, s, e interface{}, opts ...JSONOpt) { //nolint:unparam - t.Helper() - err := Unmarshal(y, s, opts...) - if err != nil { - t.Errorf("Unmarshal(%#q, s, %v) = %v", string(y), prettyFunctionName(opts), err) - return - } + if !reflect.DeepEqual(s, e) { + t.Errorf("Unmarshal(%#q, s) = %+#v; want %+#v", string(y), s, e) + } + }) + t.Run("Decode", func(t *testing.T) { + d := NewDecoder(bytes.NewReader(y)) + if knowFields { + d.KnownFields() + } + if err := d.Decode(s); err != nil { + t.Errorf("Decode(%#q, s) = %v", string(y), err) + return + } - if !reflect.DeepEqual(s, e) { - t.Errorf("Unmarshal(%#q, s, %v) = %+#v; want %+#v", string(y), prettyFunctionName(opts), s, e) - } + if !reflect.DeepEqual(s, e) { + t.Errorf("Decode(%#q, s) = %+#v; want %+#v", string(y), s, e) + } + }) } // TestUnmarshalErrors tests that we return an error on ambiguous YAML. func TestUnmarshalErrors(t *testing.T) { for _, tc := range []struct { - yaml []byte - wantErr string + name string + yaml []byte + knowFields bool + wantErr string }{ { - // Declaring `a` twice produces an error. + name: "Declaring `a` twice produces an error", yaml: []byte("a: 1\na: 2"), wantErr: `key "a" already defined`, }, { - // Not ignoring first declaration of A with wrong type. + name: "Not ignoring first declaration of A with wrong type", yaml: []byte("a: [1,2,3]\na: value-of-a"), wantErr: `key "a" already defined`, }, { - // Declaring field `true` twice. + name: "Declaring field `true` twice", yaml: []byte("true: string-value-of-yes\ntrue: 1"), wantErr: `key "true" already defined`, }, + { + name: "Declaring unknown C field", + yaml: []byte("C: 1"), + knowFields: true, + wantErr: `json: unknown field "C"`, + }, } { - s := UnmarshalString{} - err := Unmarshal(tc.yaml, &s) - if tc.wantErr != "" && err == nil { - t.Errorf("Unmarshal(%#q, &s) = nil; want error", string(tc.yaml)) - continue - } - if tc.wantErr == "" && err != nil { - t.Errorf("Unmarshal(%#q, &s) = %v; want no error", string(tc.yaml), err) - continue - } - // We only expect errors during unmarshalling YAML. - if tc.wantErr != "" && !strings.Contains(err.Error(), tc.wantErr) { - t.Errorf("Unmarshal(%#q, &s) = %v; want err contains %#q", string(tc.yaml), err, tc.wantErr) - } + t.Run(tc.name+": unmarshal", func(t *testing.T) { + if tc.knowFields { + t.Skip("decoder only tc") + return + } + s := UnmarshalString{} + err := Unmarshal(tc.yaml, &s) + if tc.wantErr != "" && err == nil { + t.Errorf("Unmarshal(%#q, &s) = nil; want error", string(tc.yaml)) + return + } + if tc.wantErr == "" && err != nil { + t.Errorf("Unmarshal(%#q, &s) = %v; want no error", string(tc.yaml), err) + return + } + // We only expect errors during unmarshalling YAML. + if tc.wantErr != "" && !strings.Contains(err.Error(), tc.wantErr) { + t.Errorf("Unmarshal(%#q, &s) = %v; want err contains %#q", string(tc.yaml), err, tc.wantErr) + } + }) + t.Run(tc.name+": decode", func(t *testing.T) { + s := UnmarshalString{} + d := NewDecoder(bytes.NewReader(tc.yaml)) + if tc.knowFields { + d.KnownFields() + } + err := d.Decode(&s) + if tc.wantErr != "" && err == nil { + t.Errorf("Unmarshal(%#q, &s) = nil; want error", string(tc.yaml)) + return + } + if tc.wantErr == "" && err != nil { + t.Errorf("Unmarshal(%#q, &s) = %v; want no error", string(tc.yaml), err) + return + } + // We only expect errors during unmarshalling YAML. + if tc.wantErr != "" && !strings.Contains(err.Error(), tc.wantErr) { + t.Errorf("Unmarshal(%#q, &s) = %v; want err contains %#q", string(tc.yaml), err, tc.wantErr) + } + }) } } @@ -396,3 +450,64 @@ func TestYAMLToJSONDuplicateFields(t *testing.T) { t.Error("expected YAMLtoJSON to fail on duplicate field names") } } + +type MultiDoc struct { + Test int `json:"test"` +} + +func TestMultiDocDecode(t *testing.T) { + data := `--- +test: 1 +--- +test: 2 +--- +test: 3 +` + decoder := NewDecoder(strings.NewReader(data)) + + for i := 1; i < 4; i++ { + var obj MultiDoc + if err := decoder.Decode(&obj); err != nil { + t.Errorf("decode #%d failed: %s", i, err) + } + if obj.Test != i { + t.Errorf("decoded #%d has incorrect value %#v", i, obj) + } + } + var obj MultiDoc + if err := decoder.Decode(&obj); !errors.Is(err, io.EOF) { + t.Errorf("decode should return EOF but got: %s", err) + } +} + +func TestMultiDocEncode(t *testing.T) { + docs := []MultiDoc{ + { + Test: 1, + }, + { + Test: 2, + }, + { + Test: 3, + }, + } + expected := `test: 1 +--- +test: 2 +--- +test: 3 +` + + var buf bytes.Buffer + encoder := NewEncoder(&buf) + for _, obj := range docs { + if err := encoder.Encode(obj); err != nil { + t.Errorf("encode object %#v failed: %s", obj, err) + } + } + if encoded := buf.String(); encoded != expected { + t.Errorf("expected %s, but got %s", expected, encoded) + + } +}