diff --git a/types/address.go b/types/address.go index 3b8b2b5..ed2b7f8 100644 --- a/types/address.go +++ b/types/address.go @@ -58,3 +58,18 @@ func Hex2Bytes(str string) []byte { h, _ := hex.DecodeString(str) return h } + +// MarshalText implements TextMarshaler interface for Address +func (a Address) MarshalText() (text []byte, err error) { + return []byte(a.String()), nil +} + +// UnmarshalText decodes the form generated by MarshalText +func (a *Address) UnmarshalText(text []byte) error { + if len(text) == 0 { + *a = Address{} + return nil + } + *a = BytesToAddress(FromHex(string(text))) + return nil +} diff --git a/types/hash.go b/types/hash.go index 08310fe..7ac0d79 100644 --- a/types/hash.go +++ b/types/hash.go @@ -66,3 +66,18 @@ func (h *Hash) SetBytes(b []byte) { // BigToHash sets byte representation of b to hash. // If b is larger than len(h), b will be cropped from the left. func BigToHash(b *big.Int) Hash { return BytesToHash(b.Bytes()) } + +// MarshalText implements TextMarshaler interface for Hash +func (h Hash) MarshalText() (text []byte, err error) { + return []byte(h.String()), nil +} + +// UnmarshalText decodes the form generated by MarshalText +func (h *Hash) UnmarshalText(text []byte) error { + if len(text) == 0 { + *h = Hash{} + return nil + } + *h = BytesToHash(FromHex(string(text))) + return nil +} diff --git a/types/types_test.go b/types/types_test.go index d464ad1..441972c 100644 --- a/types/types_test.go +++ b/types/types_test.go @@ -1,6 +1,7 @@ package types import ( + "encoding/json" "math/big" "testing" ) @@ -29,3 +30,77 @@ func TestHash_Compare(t *testing.T) { t.Fatal("incorrect uint64 conversion") } } + +func TestHash_MarshalText(t *testing.T) { + m := map[Hash]Hash{ + {1}: {2}, + } + + b, err := json.Marshal(m) + if err != nil { + t.Fatal(err) + } + + if string(b) != "{\"0x0100000000000000000000000000000000000000000000000000000000000000\":\"0x0200000000000000000000000000000000000000000000000000000000000000\"}" { + t.Fatal("incorrect marshalling") + } +} + +func TestHash_UnmarshalText(t *testing.T) { + b := []byte("{\"0x0100000000000000000000000000000000000000000000000000000000000000\":\"0x0200000000000000000000000000000000000000000000000000000000000000\"}") + + exp := map[Hash]Hash{ + {1}: {2}, + } + + m := make(map[Hash]Hash) + + err := json.Unmarshal(b, &m) + if err != nil { + t.Fatal(err) + } + + for k, v := range m { + if exp[k] != v { + t.Fatal("incorrect marshalling") + } + } +} + +func TestAddress_MarshalText(t *testing.T) { + addr := HexToAddress("0x9c1a711a5e31a9461f6d1f662068e0a2f9edf552") + m := map[Address]Address{ + addr: addr, + } + + b, err := json.Marshal(m) + if err != nil { + t.Fatal(err) + } + + if string(b) != "{\"0x9c1a711a5e31a9461f6d1f662068e0a2f9edf552\":\"0x9c1a711a5e31a9461f6d1f662068e0a2f9edf552\"}" { + t.Fatal("incorrect marshalling") + } +} + +func TestAddress_UnmarshalText(t *testing.T) { + b := []byte("{\"0x9c1a711a5e31a9461f6d1f662068e0a2f9edf552\":\"0x9c1a711a5e31a9461f6d1f662068e0a2f9edf552\"}") + + addr := HexToAddress("0x9c1a711a5e31a9461f6d1f662068e0a2f9edf552") + exp := map[Address]Address{ + addr: addr, + } + + m := make(map[Address]Address) + + err := json.Unmarshal(b, &m) + if err != nil { + t.Fatal(err) + } + + for k, v := range m { + if exp[k] != v { + t.Fatal("incorrect marshalling") + } + } +}