diff --git a/protocol.go b/protocol.go index cba19970..28608eb8 100644 --- a/protocol.go +++ b/protocol.go @@ -97,6 +97,11 @@ func (p *Protocol) Hash() string { return p.hash } +// Types returns the types of the protocol. +func (p *Protocol) Types() []NamedSchema { + return p.types +} + // String returns the canonical form of the protocol. func (p *Protocol) String() string { types := "" diff --git a/protocol_test.go b/protocol_test.go index c9cc0a34..4eb2b970 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -190,3 +190,17 @@ func TestParseProtocolFile_InvalidPath(t *testing.T) { assert.Error(t, err) } + +func TestParseProtocol_Types(t *testing.T) { + protocol, err := avro.ParseProtocolFile("testdata/echo.avpr") + + wantPing := `{"name":"org.hamba.avro.Ping","type":"record","fields":[{"name":"timestamp","type":"long"},{"name":"text","type":"string"}]}` + wantPong := `{"name":"org.hamba.avro.Pong","type":"record","fields":[{"name":"timestamp","type":"long"},{"name":"ping","type":"org.hamba.avro.Ping"}]}` + wantPongError := `{"name":"org.hamba.avro.PongError","type":"error","fields":[{"name":"timestamp","type":"long"},{"name":"reason","type":"string"}]}` + wantLen := 3 + require.NoError(t, err) + assert.Equal(t, wantLen, len(protocol.Types())) + assert.Equal(t, wantPing, protocol.Types()[0].String()) + assert.Equal(t, wantPong, protocol.Types()[1].String()) + assert.Equal(t, wantPongError, protocol.Types()[2].String()) +}