diff --git a/internal/luai/decode.go b/internal/luai/decode.go index 907bdc2e..5f61ea08 100644 --- a/internal/luai/decode.go +++ b/internal/luai/decode.go @@ -88,13 +88,18 @@ func indirect(v reflect.Value) reflect.Value { func storeLiteral(value reflect.Value, lvalue lua.LValue) { value = indirect(value) - switch lvalue.Type() { - case lua.LTString: + + switch value.Kind() { + case reflect.String: value.SetString(lvalue.String()) - case lua.LTNumber: - value.SetInt(int64(lvalue.(lua.LNumber))) - case lua.LTBool: + case reflect.Bool: value.SetBool(bool(lvalue.(lua.LBool))) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + value.SetInt(int64(lvalue.(lua.LNumber))) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + value.SetUint(uint64(lvalue.(lua.LNumber))) + case reflect.Float32, reflect.Float64: + value.SetFloat(float64(lvalue.(lua.LNumber))) } } @@ -106,6 +111,15 @@ func objectInterface(lvalue *lua.LTable) any { return v } +// To unmarshal lua obj into an interface value, +// Unmarshal stores one of these in the interface value: +// +// - bool, for LTBool +// - float64, for LTNumber +// - string, for LTString +// - []interface{}, for LTTable arrays +// - map[string]interface{}, for LTTable objects +// - nil for LTNil func valueInterface(lvalue lua.LValue) any { switch lvalue.Type() { case lua.LTTable: @@ -117,7 +131,7 @@ func valueInterface(lvalue lua.LValue) any { case lua.LTString: return lvalue.String() case lua.LTNumber: - return int(lvalue.(lua.LNumber)) + return float64(lvalue.(lua.LNumber)) case lua.LTBool: return bool(lvalue.(lua.LBool)) } @@ -134,11 +148,10 @@ func arrayInterface(lvalue *lua.LTable) any { } func unmarshalWorker(value lua.LValue, reflected reflect.Value) error { + reflected = indirect(reflected) switch value.Type() { case lua.LTTable: - reflected = indirect(reflected) - tagMap := make(map[string]int) switch reflected.Kind() { case reflect.Interface: @@ -147,7 +160,7 @@ func unmarshalWorker(value lua.LValue, reflected reflect.Value) error { result := valueInterface(value) reflected.Set(reflect.ValueOf(result)) } - // map[T1]T2 where T1 is string, an integer type + // map[T1]T2 where T1 is string or an integer type case reflect.Map: t := reflected.Type() keyType := t.Key() @@ -239,6 +252,8 @@ func unmarshalWorker(value lua.LValue, reflected reflect.Value) error { reflected.Set(reflect.MakeSlice(reflected.Type(), 0, 0)) } case reflect.Struct: + tagMap := make(map[string]int) + for i := 0; i < reflected.NumField(); i++ { fieldTypeField := reflected.Type().Field(i) tag := fieldTypeField.Tag.Get("luai") diff --git a/internal/luai/encode.go b/internal/luai/encode.go index f85eaf3d..1cff0f64 100644 --- a/internal/luai/encode.go +++ b/internal/luai/encode.go @@ -98,7 +98,17 @@ func Marshal(state *lua.LState, v any) (lua.LValue, error) { return nil, err } - table.RawSetString(key.String(), value) + switch key.Kind() { + case reflect.String: + table.RawSetString(key.String(), value) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + table.RawSetInt(int(key.Int()), value) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + table.RawSetInt(int(key.Uint()), value) + default: + return nil, errors.New("marshal: unsupported type " + key.Kind().String() + " for key") + } + } return table, nil default: diff --git a/internal/luai/encoding_test.go b/internal/luai/encoding_test.go deleted file mode 100644 index ae304a5d..00000000 --- a/internal/luai/encoding_test.go +++ /dev/null @@ -1,313 +0,0 @@ -/* - * Copyright 2024 Han Li and contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package luai - -import ( - "fmt" - "reflect" - "testing" - - "github.com/version-fox/vfox/internal/logger" - lua "github.com/yuin/gopher-lua" -) - -func setupSuite(tb testing.TB) func(tb testing.TB) { - logger.SetLevel(logger.DebugLevel) - - return func(tb testing.TB) { - logger.SetLevel(logger.InfoLevel) - } -} - -type testStruct struct { - Field1 string - Field2 int - Field3 bool -} - -type testStructTag struct { - Field1 string `luai:"field1"` - Field2 int `luai:"field2"` - Field3 bool `luai:"field3"` -} - -type complexStruct struct { - Field1 string - Field2 int - Field3 bool - SimpleStruct *testStruct - Struct testStructTag - Map map[string]interface{} - Slice []any -} - -func TestEncoding(t *testing.T) { - teardownSuite := setupSuite(t) - defer teardownSuite(t) - - m := map[string]interface{}{ - "key1": "value1", - "key2": 2, - "key3": true, - } - - s := []any{"value1", 2, true} - - t.Run("Struct", func(t *testing.T) { - luaVm := lua.NewState() - defer luaVm.Close() - - test := testStruct{ - Field1: "test", - Field2: 1, - Field3: true, - } - - _table, err := Marshal(luaVm, &test) - if err != nil { - t.Fatal(err) - } - - luaVm.SetGlobal("table", _table) - - if err := luaVm.DoString(` - assert(table.Field1 == "test") - assert(table.Field2 == 1) - assert(table.Field3 == true) - print("lua Struct done") - `); err != nil { - t.Fatal(err) - } - - struct2 := testStruct{} - err = Unmarshal(_table, &struct2) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(test, struct2) { - t.Errorf("expected %+v, got %+v", test, struct2) - } - }) - - t.Run("Struct with Tag", func(t *testing.T) { - luaVm := lua.NewState() - defer luaVm.Close() - - test := testStructTag{ - Field1: "test", - Field2: 1, - Field3: true, - } - - _table, err := Marshal(luaVm, &test) - if err != nil { - t.Fatal(err) - } - - table := _table.(*lua.LTable) - - luaVm.SetGlobal("table", table) - if err := luaVm.DoString(` - assert(table.field1 == "test") - assert(table.field2 == 1) - assert(table.field3 == true) - print("lua Struct with Tag done") - `); err != nil { - t.Fatal(err) - } - - struct2 := testStructTag{} - err = Unmarshal(table, &struct2) - if err != nil { - t.Fatal(err) - } - - if !reflect.DeepEqual(test, struct2) { - t.Errorf("expected %+v, got %+v", test, struct2) - } - }) - - t.Run("Support Map, Slice and Any", func(t *testing.T) { - L := lua.NewState() - defer L.Close() - table, err := Marshal(L, m) - if err != nil { - t.Fatalf("marshal map failed: %v", err) - } - L.SetGlobal("m", table) - if err := L.DoString(` - assert(m.key1 == "value1") - assert(m.key2 == 2) - assert(m.key3 == true) - print("lua Map done") - `); err != nil { - t.Errorf("map test failed: %v", err) - } - - slice, err := Marshal(L, s) - if err != nil { - t.Fatalf("marshal slice failed: %v", err) - } - - L.SetGlobal("s", slice) - if err := L.DoString(` - assert(s[1] == "value1") - assert(s[2] == 2) - assert(s[3] == true) - print("lua Slice done") - `); err != nil { - t.Errorf("slice test failed: %v", err) - } - - // Unmarshal - - // Test case for map - m2 := map[string]any{} - - fmt.Println("==== start unmarshal ====") - - err = Unmarshal(table, &m2) - if err != nil { - t.Fatalf("unmarshal map failed: %v", err) - } - - fmt.Printf("m2: %+v\n", m2) - - if !reflect.DeepEqual(m, m2) { - t.Errorf("expected %+v, got %+v", m, m2) - } - - // Test case for slice - s2 := []any{} - - err = Unmarshal(slice, &s2) - if err != nil { - t.Fatalf("unmarshal slice failed: %v", err) - } - - fmt.Printf("s2: %+v\n", s2) - - if !reflect.DeepEqual(s, s2) { - t.Errorf("expected %+v, got %+v", s, s2) - } - - var s3 any - err = Unmarshal(slice, &s3) - if err != nil { - t.Fatalf("unmarshal slice failed: %v", err) - } - - if !reflect.DeepEqual(s, s3) { - t.Errorf("expected %+v, got %+v", s, s3) - } - }) - - t.Run("MapSliceStructUnified", func(t *testing.T) { - L := lua.NewState() - defer L.Close() - - input := complexStruct{ - Field1: "value1", - Field2: 123, - Field3: true, - Struct: testStructTag{ - Field1: "value1", - Field2: 2, - Field3: true, - }, - Map: m, - Slice: s, - } - - table, err := Marshal(L, input) - if err != nil { - t.Fatalf("marshal map failed: %v", err) - } - - L.SetGlobal("m", table) - - if err := L.DoString(` - assert(m.Field1 == "value1") - assert(m.Field2 == 123) - assert(m.Field3 == true) - assert(m.Struct.field1 == "value1") - assert(m.Struct.field2 == 2) - assert(m.Struct.field3 == true) - assert(m.Map.key1 == "value1") - assert(m.Map.key2 == 2) - assert(m.Map.key3 == true) - assert(m.Slice[1] == "value1") - assert(m.Slice[2] == 2) - assert(m.Slice[3] == true) - print("lua MapSliceStructUnified done") - `); err != nil { - t.Errorf("map test failed: %v", err) - } - - // Unmarshal - output := complexStruct{} - err = Unmarshal(table, &output) - if err != nil { - t.Fatalf("unmarshal map failed: %v", err) - } - - isEqual := reflect.DeepEqual(input, output) - if !isEqual { - t.Fatalf("expected %+v, got %+v", input, output) - } - - fmt.Printf("output: %+v\n", output) - - if !reflect.DeepEqual(input, output) { - t.Errorf("expected %+v, got %+v", input, output) - } - }) - - t.Run("TableWithEmptyField", func(t *testing.T) { - L := lua.NewState() - defer L.Close() - - output := struct { - Field1 string `luai:"field1"` - Field2 *string `luai:"field2"` - }{} - - if err := L.DoString(` - return { - field1 = "value1", - } - `); err != nil { - t.Errorf("map test failed: %v", err) - } - - table := L.ToTable(-1) // returned value - L.Pop(1) - // Unmarshal - err := Unmarshal(table, &output) - if err != nil { - t.Fatalf("unmarshal map failed: %v", err) - } - fmt.Printf("output: %+v\n", output) - if output.Field1 != "value1" { - t.Errorf("expected %+v, got %+v", "value1", output.Field1) - } - if output.Field2 != nil { - t.Errorf("expected %+v, got %+v", nil, output.Field2) - } - }) -} diff --git a/internal/luai/example_test.go b/internal/luai/example_test.go new file mode 100644 index 00000000..2ecb4b97 --- /dev/null +++ b/internal/luai/example_test.go @@ -0,0 +1,318 @@ +/* + * Copyright 2024 Han Li and contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package luai + +import ( + "fmt" + "reflect" + "testing" + + "github.com/version-fox/vfox/internal/logger" + lua "github.com/yuin/gopher-lua" +) + +func setupSuite(tb testing.TB) func(tb testing.TB) { + logger.SetLevel(logger.DebugLevel) + + return func(tb testing.TB) { + logger.SetLevel(logger.InfoLevel) + } +} + +type testStruct struct { + Field1 string + Field2 int + Field3 bool +} + +type testStructTag struct { + Field1 string `luai:"field1"` + Field2 int `luai:"field2"` + Field3 bool `luai:"field3"` +} + +type complexStruct struct { + Field1 string + Field2 int + Field3 bool + SimpleStruct *testStruct + Struct testStructTag + Map map[string]interface{} + Slice []any +} + +func TestExample(t *testing.T) { + teardownSuite := setupSuite(t) + defer teardownSuite(t) + + t.Run("TableWithEmptyFieldAndIncompatibleType", func(t *testing.T) { + L := NewLuaVM() + defer L.Close() + + output := struct { + Field1 string `luai:"field1"` + Field2 *string `luai:"field2"` + AString string `luai:"a_string"` + }{} + + if err := L.Instance.DoString(` + return { + field1 = "value1", + --- notice: here we return a number + a_string = 8, + } + `); err != nil { + t.Errorf("map test failed: %v", err) + } + + table := L.ReturnedValue() + err := Unmarshal(table, &output) + if err != nil { + t.Fatalf("unmarshal map failed: %v", err) + } + fmt.Printf("output: %+v\n", output) + if output.Field1 != "value1" { + t.Errorf("expected %+v, got %+v", "value1", output.Field1) + } + if output.Field2 != nil { + t.Errorf("expected %+v, got %+v", nil, output.Field2) + } + if output.AString != "8" { + t.Errorf("expected %+v, got %+v", "", output.AString) + } + }) +} + +func TestCases(t *testing.T) { + teardownSuite := setupSuite(t) + defer teardownSuite(t) + + var unmarshalTests = []struct { + CaseName string + in any + ptr any // new(type) + out any + luaValidationScript string + err error + }{ + { + CaseName: "Struct", + in: testStruct{ + Field1: "test", + Field2: 1, + Field3: true, + }, + ptr: new(testStruct), + out: testStruct{ + Field1: "test", + Field2: 1, + Field3: true, + }, + luaValidationScript: ` + assert(m.Field1 == "test") + assert(m.Field2 == 1) + assert(m.Field3 == true) + print("lua Struct done") + `, + }, + { + CaseName: "Struct with Tag", + in: testStructTag{ + Field1: "test", + Field2: 1, + Field3: true, + }, + ptr: new(testStructTag), + out: testStructTag{ + Field1: "test", + Field2: 1, + Field3: true, + }, + luaValidationScript: ` + assert(m.field1 == "test") + assert(m.field2 == 1) + assert(m.field3 == true) + print("lua Struct with Tag done") + `, + }, + { + CaseName: "Map", + in: map[string]interface{}{ + "key1": "value1", + "key2": 2, + "key3": true, + }, + ptr: new(map[string]any), + out: map[string]interface{}{ + "key1": "value1", + "key2": float64(2), + "key3": true, + }, + }, + { + CaseName: "Slice", + in: []any{"value1", 2, true}, + ptr: new([]any), + out: []any{"value1", float64(2), true}, + }, + { + CaseName: "Any", + in: map[string]interface{}{ + "key1": "value1", + "key2": 2, + "key3": true, + }, + ptr: new(any), + out: map[string]interface{}{ + "key1": "value1", + "key2": float64(2), + "key3": true, + }, + luaValidationScript: ` + assert(m.key1 == "value1") + assert(m.key2 == 2) + assert(m.key3 == true) + print("Any Done") + `, + }, + { + CaseName: "Map[Int]", + in: map[int]int{ + 1: 1, + 2: 2, + }, + ptr: new(map[int]int), + out: map[int]int{ + 1: 1, + 2: 2, + }, + luaValidationScript: ` + assert(m[1] == 1) + assert(m[2] == 2) + print("lua Map[Int] done") + `, + }, + { + CaseName: "MapSliceStructUnified", + in: complexStruct{ + Field1: "value1", + Field2: 123, + Field3: true, + Struct: testStructTag{ + Field1: "value1", + Field2: 2, + Field3: true, + }, + Map: map[string]interface{}{ + "key1": "value1", + "key2": float64(2), + "key3": true, + }, + Slice: []any{"value1", 2, true}, + }, + ptr: new(complexStruct), + out: complexStruct{ + Field1: "value1", + Field2: 123, + Field3: true, + Struct: testStructTag{ + Field1: "value1", + Field2: 2, + Field3: true, + }, + Map: map[string]interface{}{ + "key1": "value1", + "key2": float64(2), + "key3": true, + }, + Slice: []any{"value1", float64(2), true}, + }, + luaValidationScript: ` + assert(m.Field1 == "value1") + assert(m.Field2 == 123) + assert(m.Field3 == true) + assert(m.Struct.field1 == "value1") + assert(m.Struct.field2 == 2) + assert(m.Struct.field3 == true) + assert(m.Map.key1 == "value1") + assert(m.Map.key2 == 2) + assert(m.Map.key3 == true) + assert(m.Slice[1] == "value1") + assert(m.Slice[2] == 2) + assert(m.Slice[3] == true) + print("lua MapSliceStructUnified done") + `, + }, + } + + for _, tt := range unmarshalTests { + t.Run(tt.CaseName, func(t *testing.T) { + L := lua.NewState() + defer L.Close() + + table, err := Marshal(L, tt.in) + if err != nil { + t.Fatalf("marshal map failed: %v", err) + } + + if tt.luaValidationScript != "" { + L.SetGlobal("m", table) + + if err := L.DoString(tt.luaValidationScript); err != nil { + t.Errorf("validate %s error: %v", tt.CaseName, err) + } + } + + if tt.ptr == nil { + return + } + + typ := reflect.TypeOf(tt.ptr) + if typ.Kind() != reflect.Pointer { + t.Fatalf("%s: unmarshalTest.ptr %T is not a pointer type", tt.CaseName, tt.ptr) + } + + typ = typ.Elem() + + // equals to: v = new(right-type) + v := reflect.New(typ) + + if !reflect.DeepEqual(tt.ptr, v.Interface()) { + // There's no reason for ptr to point to non-zero data, + // as we decode into new(right-type), so the data is + // discarded. + // This can easily mean tests that silently don't test + // what they should. To test decoding into existing + // data, see TestPrefilled. + t.Fatalf("%s: unmarshalTest.ptr %#v is not a pointer to a zero value", tt.CaseName, tt.ptr) + } + + err = Unmarshal(table, v.Interface()) + + if err != tt.err { + t.Errorf("expected %+v, got %+v", tt.err, err) + } + + // get the value out of the pointer, equals to: v = *v + got := v.Elem().Interface() + + if !reflect.DeepEqual(tt.out, got) { + t.Errorf("expected %+v, got %+v", tt.out, got) + } + }) + } +}