diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..23cab52 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,3 @@ +linters: + enable: + - nilerr diff --git a/Makefile b/Makefile index 5c58e1d..218f378 100644 --- a/Makefile +++ b/Makefile @@ -9,8 +9,8 @@ check: fmt lint vet .PHONY: download-ci-tools download-ci-tools: go install golang.org/x/tools/cmd/goimports@latest - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.50.1 - curl -sSfL https://raw.githubusercontent.com/reviewdog/reviewdog/master/install.sh | sh -s v0.14.1 + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s v1.53.1 + curl -sSfL https://raw.githubusercontent.com/reviewdog/reviewdog/master/install.sh | sh -s v0.14.2 .PHONY: fmt fmt: @@ -19,11 +19,11 @@ fmt: .PHONY: lint lint: - ./bin/golangci-lint run ./... + ./bin/golangci-lint run .PHONY: lint-ci lint-ci: - ./bin/golangci-lint run ./... | \ + ./bin/golangci-lint run | \ ./bin/reviewdog -f=golangci-lint -reporter=github-pr-review -filter-mode=nofilter .PHONY: vet diff --git a/chunk_header_test.go b/chunk_header_test.go index 86cec61..a674241 100644 --- a/chunk_header_test.go +++ b/chunk_header_test.go @@ -11,7 +11,7 @@ import ( "bytes" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestChunkBasicHeader(t *testing.T) { @@ -116,8 +116,8 @@ func TestChunkBasicHeader(t *testing.T) { buf := new(bytes.Buffer) err := encodeChunkBasicHeader(buf, tc.value) - assert.Nil(t, err) - assert.Equal(t, tc.binary, buf.Bytes()) + require.Nil(t, err) + require.Equal(t, tc.binary, buf.Bytes()) }) } }) @@ -132,8 +132,8 @@ func TestChunkBasicHeader(t *testing.T) { r := bytes.NewReader(tc.binary) var mh chunkBasicHeader err := decodeChunkBasicHeader(r, nil, &mh) - assert.Nil(t, err) - assert.Equal(t, tc.value, &mh) + require.Nil(t, err) + require.Equal(t, tc.value, &mh) }) } }) @@ -146,7 +146,7 @@ func TestChunkBasicHeaderError(t *testing.T) { fmt: 3, chunkStreamID: 65600, }) - assert.EqualError(t, err, "Chunk stream id is out of range: 65600 must be in range [2, 65599]") + require.EqualError(t, err, "Chunk stream id is out of range: 65600 must be in range [2, 65599]") }) t.Run("Out of range(under)", func(t *testing.T) { @@ -155,7 +155,7 @@ func TestChunkBasicHeaderError(t *testing.T) { fmt: 3, chunkStreamID: 1, }) - assert.EqualError(t, err, "Chunk stream id is out of range: 1 must be in range [2, 65599]") + require.EqualError(t, err, "Chunk stream id is out of range: 1 must be in range [2, 65599]") }) } @@ -367,8 +367,8 @@ func TestChunkMessageHeader(t *testing.T) { buf := new(bytes.Buffer) err := encodeChunkMessageHeader(buf, tc.fmt, tc.value) - assert.Nil(t, err) - assert.Equal(t, tc.binary, buf.Bytes()) + require.Nil(t, err) + require.Equal(t, tc.binary, buf.Bytes()) }) } }) @@ -383,8 +383,8 @@ func TestChunkMessageHeader(t *testing.T) { r := bytes.NewReader(tc.binary) var mh chunkMessageHeader err := decodeChunkMessageHeader(r, tc.fmt, nil, &mh) - assert.Nil(t, err) - assert.Equal(t, tc.value, &mh) + require.Nil(t, err) + require.Equal(t, tc.value, &mh) }) } }) diff --git a/chunk_streamer_test.go b/chunk_streamer_test.go index fd29683..add33c1 100644 --- a/chunk_streamer_test.go +++ b/chunk_streamer_test.go @@ -18,7 +18,7 @@ import ( "time" "github.com/fortytw2/leaktest" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/yutopp/go-rtmp/message" ) @@ -39,38 +39,38 @@ func TestStreamerSingleChunk(t *testing.T) { // write a message w, err := streamer.NewChunkWriter(context.Background(), chunkStreamID) - assert.Nil(t, err) - assert.NotNil(t, w) + require.Nil(t, err) + require.NotNil(t, w) enc := message.NewEncoder(w) err = enc.Encode(msg) - assert.Nil(t, err) + require.Nil(t, err) w.messageLength = uint32(w.buf.Len()) w.messageTypeID = byte(msg.TypeID()) w.timestamp = timestamp err = streamer.Sched(w) - assert.Nil(t, err) + require.Nil(t, err) _, err = streamer.NewChunkWriter(context.Background(), chunkStreamID) // wait for writing - assert.Nil(t, err) + require.Nil(t, err) // read a message r, err := streamer.readChunk() - assert.Nil(t, err) - assert.NotNil(t, r) - assert.True(t, r.completed) + require.Nil(t, err) + require.NotNil(t, r) + require.True(t, r.completed) dec := message.NewDecoder(r) var actualMsg message.Message err = dec.Decode(message.TypeID(r.messageTypeID), &actualMsg) - assert.Nil(t, err) - assert.Equal(t, uint32(timestamp), r.timestamp) + require.Nil(t, err) + require.Equal(t, uint32(timestamp), r.timestamp) // check message - assert.Equal(t, actualMsg.TypeID(), msg.TypeID()) + require.Equal(t, actualMsg.TypeID(), msg.TypeID()) actualMsgT := actualMsg.(*message.VideoMessage) actualContent, _ := ioutil.ReadAll(actualMsgT.Payload) - assert.Equal(t, actualContent, videoContent) + require.Equal(t, actualContent, videoContent) } func TestStreamerMultipleChunk(t *testing.T) { @@ -93,40 +93,40 @@ func TestStreamerMultipleChunk(t *testing.T) { // write a message w, err := streamer.NewChunkWriter(context.Background(), chunkStreamID) - assert.Nil(t, err) - assert.NotNil(t, w) + require.Nil(t, err) + require.NotNil(t, w) enc := message.NewEncoder(w) err = enc.Encode(msg) - assert.Nil(t, err) + require.Nil(t, err) w.messageLength = uint32(w.buf.Len()) w.messageTypeID = byte(msg.TypeID()) w.timestamp = timestamp err = streamer.Sched(w) - assert.Nil(t, err) + require.Nil(t, err) _, err = streamer.NewChunkWriter(context.Background(), chunkStreamID) // wait for writing - assert.Nil(t, err) + require.Nil(t, err) // read a message var r *ChunkStreamReader for i := 0; i < len(payloadUnit); i++ { r, err = streamer.readChunk() - assert.Nil(t, err) + require.Nil(t, err) } - assert.NotNil(t, r) + require.NotNil(t, r) dec := message.NewDecoder(r) var actualMsg message.Message err = dec.Decode(message.TypeID(r.messageTypeID), &actualMsg) - assert.Nil(t, err) - assert.Equal(t, uint32(timestamp), r.timestamp) + require.Nil(t, err) + require.Equal(t, uint32(timestamp), r.timestamp) // check message - assert.Equal(t, actualMsg.TypeID(), msg.TypeID()) + require.Equal(t, actualMsg.TypeID(), msg.TypeID()) actualMsgT := actualMsg.(*message.VideoMessage) actualContent, _ := ioutil.ReadAll(actualMsgT.Payload) - assert.Equal(t, actualContent, videoContent) + require.Equal(t, actualContent, videoContent) } func TestStreamerChunkExample1(t *testing.T) { @@ -215,8 +215,8 @@ func TestStreamerChunkExample1(t *testing.T) { for i, wc := range tc.writeCases { t.Run(fmt.Sprintf("Write: %d", i), func(t *testing.T) { w, err := streamer.NewChunkWriter(context.Background(), tc.chunkStreamID) - assert.Nil(t, err) - assert.NotNil(t, w) + require.Nil(t, err) + require.NotNil(t, w) bin := make([]byte, wc.length) @@ -227,22 +227,22 @@ func TestStreamerChunkExample1(t *testing.T) { w.buf.Write(bin) err = streamer.Sched(w) - assert.Nil(t, err) + require.Nil(t, err) }) } _, err := streamer.NewChunkWriter(context.Background(), tc.chunkStreamID) // wait for writing - assert.Nil(t, err) + require.Nil(t, err) for i, rc := range tc.readCases { t.Run(fmt.Sprintf("Read: %d", i), func(t *testing.T) { r, err := streamer.readChunk() - assert.Nil(t, err) - assert.NotNil(t, r) + require.Nil(t, err) + require.NotNil(t, r) - assert.Equal(t, rc.fmt, r.basicHeader.fmt) - assert.Equal(t, uint32(rc.timestamp), r.timestamp) - assert.Equal(t, rc.isComplete, r.completed) + require.Equal(t, rc.fmt, r.basicHeader.fmt) + require.Equal(t, uint32(rc.timestamp), r.timestamp) + require.Equal(t, rc.isComplete, r.completed) }) } }) @@ -310,8 +310,8 @@ func TestStreamerChunkExample2(t *testing.T) { for i, wc := range tc.writeCases { t.Run(fmt.Sprintf("Write: %d", i), func(t *testing.T) { w, err := streamer.NewChunkWriter(context.Background(), tc.chunkStreamID) - assert.Nil(t, err) - assert.NotNil(t, w) + require.Nil(t, err) + require.NotNil(t, w) bin := make([]byte, wc.length) @@ -321,23 +321,23 @@ func TestStreamerChunkExample2(t *testing.T) { w.timestamp = wc.timestamp w.buf.Write(bin) err = streamer.Sched(w) - assert.Nil(t, err) + require.Nil(t, err) }) } _, err := streamer.NewChunkWriter(context.Background(), tc.chunkStreamID) // wait for writing - assert.Nil(t, err) + require.Nil(t, err) for i, rc := range tc.readCases { t.Run(fmt.Sprintf("Read: %d", i), func(t *testing.T) { r, err := streamer.readChunk() _ = rc _ = err - assert.Nil(t, err) - assert.NotNil(t, r) - assert.Equal(t, rc.fmt, r.basicHeader.fmt) - assert.Equal(t, uint32(rc.delta), r.messageHeader.timestampDelta) - assert.Equal(t, rc.isComplete, r.completed) + require.Nil(t, err) + require.NotNil(t, r) + require.Equal(t, rc.fmt, r.basicHeader.fmt) + require.Equal(t, uint32(rc.delta), r.messageHeader.timestampDelta) + require.Equal(t, rc.isComplete, r.completed) }) } }) @@ -357,10 +357,10 @@ func TestWriteToInvalidWriter(t *testing.T) { StreamID: 0, Message: &message.Ack{}, }) - assert.Nil(t, err) + require.Nil(t, err) <-streamer.Done() - assert.EqualErrorf(t, streamer.Err(), "Always error!", "") + require.EqualErrorf(t, streamer.Err(), "Always error!", "") } type AlwaysErrorWriter struct{} @@ -379,7 +379,7 @@ func TestChunkStreamerHasNoLeaksOfGoroutines(t *testing.T) { streamer := NewChunkStreamer(inbuf, outbuf, nil) err := streamer.Close() - assert.Nil(t, err) + require.Nil(t, err) <-streamer.Done() } @@ -396,18 +396,18 @@ func TestChunkStreamerStreamsLimitation(t *testing.T) { { _, err := streamer.prepareChunkReader(0) - assert.Nil(t, err) + require.Nil(t, err) _, err = streamer.prepareChunkReader(1) - assert.EqualError(t, err, "Creating chunk streams limit exceeded(Reader): Limit = 1") + require.EqualError(t, err, "Creating chunk streams limit exceeded(Reader): Limit = 1") } { _, err := streamer.prepareChunkWriter(0) - assert.Nil(t, err) + require.Nil(t, err) _, err = streamer.prepareChunkWriter(1) - assert.EqualError(t, err, "Creating chunk streams limit exceeded(Writer): Limit = 1") + require.EqualError(t, err, "Creating chunk streams limit exceeded(Writer): Limit = 1") } } @@ -433,16 +433,16 @@ func TestChunkStreamerDualWriter(t *testing.T) { Payload: bytes.NewReader(largePayload), }, }) - assert.Nil(t, err) + require.Nil(t, err) } streamer.waitWriters() err := streamer.Close() - assert.Nil(t, err) + require.Nil(t, err) <-streamer.Done() - assert.Equal(t, nil, streamer.Err()) + require.Equal(t, nil, streamer.Err()) } func TestChunkStreamerDualWriterWithoutWaiting(t *testing.T) { @@ -467,14 +467,14 @@ func TestChunkStreamerDualWriterWithoutWaiting(t *testing.T) { Payload: bytes.NewReader(largePayload), }, }) - assert.Nil(t, err) + require.Nil(t, err) } err := streamer.Close() - assert.Nil(t, err) + require.Nil(t, err) <-streamer.Done() - assert.Equal(t, nil, streamer.Err()) + require.Equal(t, nil, streamer.Err()) } func TestChunkStreamerNewChunkWriterTwice(t *testing.T) { @@ -487,18 +487,18 @@ func TestChunkStreamerNewChunkWriterTwice(t *testing.T) { chunkStreamID := 10 _, err := streamer.NewChunkWriter(context.Background(), chunkStreamID) - assert.Nil(t, err) + require.Nil(t, err) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() _, err = streamer.NewChunkWriter(ctx, chunkStreamID) // Try to acquire same chunk writer - assert.EqualError(t, err, "Failed to wait chunk writer: context deadline exceeded") + require.EqualError(t, err, "Failed to wait chunk writer: context deadline exceeded") err = streamer.Close() - assert.Nil(t, err) + require.Nil(t, err) <-streamer.Done() - assert.Equal(t, nil, streamer.Err()) + require.Equal(t, nil, streamer.Err()) } func BenchmarkStreamerMultipleChunkRead(b *testing.B) { diff --git a/client_conn.go b/client_conn.go index 2bd631e..04726a0 100644 --- a/client_conn.go +++ b/client_conn.go @@ -119,13 +119,11 @@ func (cc *ClientConn) DeleteStream(body *message.NetStreamDeleteStream) error { } // Check if stream id exists - _, err = cc.conn.streams.At(body.StreamID) - if err != nil { + if _, err := cc.conn.streams.At(body.StreamID); err != nil { return err } - err = ctrlStream.DeleteStream(body) - if err != nil { + if err := ctrlStream.DeleteStream(body); err != nil { return err } diff --git a/client_test.go b/client_test.go index b23c466..93dacba 100644 --- a/client_test.go +++ b/client_test.go @@ -10,23 +10,23 @@ package rtmp import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestClientValidAddr(t *testing.T) { addr, err := makeValidAddr("host:123") - assert.Equal(t, nil, err) - assert.Equal(t, "host:123", addr) + require.Equal(t, nil, err) + require.Equal(t, "host:123", addr) addr, err = makeValidAddr("host") - assert.Equal(t, nil, err) - assert.Equal(t, "host:1935", addr) + require.Equal(t, nil, err) + require.Equal(t, "host:1935", addr) addr, err = makeValidAddr("host:") - assert.Equal(t, nil, err) - assert.Equal(t, "host:", addr) + require.Equal(t, nil, err) + require.Equal(t, "host:", addr) addr, err = makeValidAddr(":1111") - assert.Equal(t, nil, err) - assert.Equal(t, ":1111", addr) + require.Equal(t, nil, err) + require.Equal(t, ":1111", addr) } diff --git a/conn_test.go b/conn_test.go index 806a062..1c4935e 100644 --- a/conn_test.go +++ b/conn_test.go @@ -11,7 +11,7 @@ import ( "bytes" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/yutopp/go-rtmp/message" ) @@ -42,24 +42,24 @@ func TestConnConfig(t *testing.T) { }, }) - assert.Equal(t, true, conn.config.SkipHandshakeVerification) + require.Equal(t, true, conn.config.SkipHandshakeVerification) - assert.Equal(t, 1234, conn.config.ReaderBufferSize) - assert.Equal(t, 1234, conn.config.WriterBufferSize) + require.Equal(t, 1234, conn.config.ReaderBufferSize) + require.Equal(t, 1234, conn.config.WriterBufferSize) - assert.Equal(t, uint32(1234), conn.config.ControlState.DefaultChunkSize) - assert.Equal(t, uint32(1234), conn.config.ControlState.MaxChunkSize) - assert.Equal(t, 1234, conn.config.ControlState.MaxChunkStreams) + require.Equal(t, uint32(1234), conn.config.ControlState.DefaultChunkSize) + require.Equal(t, uint32(1234), conn.config.ControlState.MaxChunkSize) + require.Equal(t, 1234, conn.config.ControlState.MaxChunkStreams) - assert.Equal(t, int32(1234), conn.config.ControlState.DefaultAckWindowSize) - assert.Equal(t, int32(1234), conn.config.ControlState.MaxAckWindowSize) + require.Equal(t, int32(1234), conn.config.ControlState.DefaultAckWindowSize) + require.Equal(t, int32(1234), conn.config.ControlState.MaxAckWindowSize) - assert.Equal(t, int32(1234), conn.config.ControlState.DefaultBandwidthWindowSize) - assert.Equal(t, message.LimitTypeHard, conn.config.ControlState.DefaultBandwidthLimitType) - assert.Equal(t, int32(1234), conn.config.ControlState.MaxBandwidthWindowSize) + require.Equal(t, int32(1234), conn.config.ControlState.DefaultBandwidthWindowSize) + require.Equal(t, message.LimitTypeHard, conn.config.ControlState.DefaultBandwidthLimitType) + require.Equal(t, int32(1234), conn.config.ControlState.MaxBandwidthWindowSize) - assert.Equal(t, uint32(1234), conn.config.ControlState.MaxMessageSize) - assert.Equal(t, 1234, conn.config.ControlState.MaxMessageStreams) + require.Equal(t, uint32(1234), conn.config.ControlState.MaxMessageSize) + require.Equal(t, 1234, conn.config.ControlState.MaxMessageStreams) } type rwcMock struct { diff --git a/go.mod b/go.mod index 931d6c0..b13ac06 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,6 @@ require ( github.com/sirupsen/logrus v1.7.0 github.com/stretchr/testify v1.8.1 github.com/yutopp/go-amf0 v0.1.0 - github.com/yutopp/go-flv v0.3.0 + github.com/yutopp/go-flv v0.3.1 golang.org/x/sys v0.3.0 // indirect ) diff --git a/go.sum b/go.sum index 546b642..1c80417 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,8 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yutopp/go-amf0 v0.1.0 h1:a3UeBZG7nRF0zfvmPn2iAfNo1RGzUpHz1VyJD2oGrik= github.com/yutopp/go-amf0 v0.1.0/go.mod h1:QzDOBr9RV6sQh6E5GFEJROZbU0iQKijORBmprkb3FIk= -github.com/yutopp/go-flv v0.3.0 h1:zkjsXqxfkwnrtPNvicxCKkaJnvwzf7kOfL/OUERP6LQ= -github.com/yutopp/go-flv v0.3.0/go.mod h1:pAlHPSVRMv5aCUKmGOS/dZn/ooTgnc09qOPmiUNMubs= +github.com/yutopp/go-flv v0.3.1 h1:4ILK6OgCJgUNm2WOjaucWM5lUHE0+sLNPdjq3L0Xtjk= +github.com/yutopp/go-flv v0.3.1/go.mod h1:pAlHPSVRMv5aCUKmGOS/dZn/ooTgnc09qOPmiUNMubs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/handler_test.go b/handler_test.go index 12ccfee..fbffc0e 100644 --- a/handler_test.go +++ b/handler_test.go @@ -10,7 +10,7 @@ package rtmp import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/yutopp/go-rtmp/message" ) @@ -67,10 +67,10 @@ type testHandler struct { func (h *testHandler) OnServe(conn *Conn) { for _, s := range []*StreamControlState{conn.streamer.PeerState(), conn.streamer.SelfState()} { - assert.Equal(h.t, uint32(1234), s.ChunkSize()) - assert.Equal(h.t, uint32(1234), s.AckWindowSize()) - assert.Equal(h.t, int32(1234), s.BandwidthWindowSize()) - assert.Equal(h.t, message.LimitTypeHard, s.BandwidthLimitType()) + require.Equal(h.t, uint32(1234), s.ChunkSize()) + require.Equal(h.t, uint32(1234), s.AckWindowSize()) + require.Equal(h.t, int32(1234), s.BandwidthWindowSize()) + require.Equal(h.t, message.LimitTypeHard, s.BandwidthLimitType()) } close(h.closer) // Finish testing diff --git a/message/body_decoder_test.go b/message/body_decoder_test.go index f736780..2d966a4 100644 --- a/message/body_decoder_test.go +++ b/message/body_decoder_test.go @@ -11,7 +11,7 @@ import ( "bytes" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/yutopp/go-amf0" ) @@ -22,8 +22,8 @@ func TestDecodeDataMessageAtsetDataFrame(t *testing.T) { var v AMFConvertible err := DataBodyDecoderFor("@setDataFrame")(r, d, &v) - assert.Nil(t, err) - assert.Equal(t, &NetStreamSetDataFrame{ + require.Nil(t, err) + require.Equal(t, &NetStreamSetDataFrame{ Payload: bin, }, v) } @@ -38,11 +38,11 @@ func TestDecodeDataMessageUnknown(t *testing.T) { var v AMFConvertible err := DataBodyDecoderFor("hogehoge")(r, d, &v) - assert.Equal(t, &UnknownDataBodyDecodeError{ + require.Equal(t, &UnknownDataBodyDecodeError{ Name: "hogehoge", Objs: []interface{}{nil}, }, err) - assert.Nil(t, v) + require.Nil(t, v) } func TestDecodeCmdMessageConnect(t *testing.T) { @@ -55,8 +55,8 @@ func TestDecodeCmdMessageConnect(t *testing.T) { var v AMFConvertible err := CmdBodyDecoderFor("connect", 1)(r, d, &v) // Transaction is always 1 (7.2.1.1) - assert.Nil(t, err) - assert.Equal(t, &NetConnectionConnect{}, v) + require.Nil(t, err) + require.Equal(t, &NetConnectionConnect{}, v) } func TestDecodeCmdMessageCreateStream(t *testing.T) { @@ -69,8 +69,8 @@ func TestDecodeCmdMessageCreateStream(t *testing.T) { var v AMFConvertible err := CmdBodyDecoderFor("createStream", 42)(r, d, &v) - assert.Nil(t, err) - assert.Equal(t, &NetConnectionCreateStream{}, v) + require.Nil(t, err) + require.Equal(t, &NetConnectionCreateStream{}, v) } func TestDecodeCmdMessageDeleteStream(t *testing.T) { @@ -85,8 +85,8 @@ func TestDecodeCmdMessageDeleteStream(t *testing.T) { var v AMFConvertible err := CmdBodyDecoderFor("deleteStream", 42)(r, d, &v) - assert.Nil(t, err) - assert.Equal(t, &NetStreamDeleteStream{ + require.Nil(t, err) + require.Equal(t, &NetStreamDeleteStream{ StreamID: 42, }, v) } @@ -105,8 +105,8 @@ func TestDecodeCmdMessagePublish(t *testing.T) { var v AMFConvertible err := CmdBodyDecoderFor("publish", 42)(r, d, &v) - assert.Nil(t, err) - assert.Equal(t, &NetStreamPublish{ + require.Nil(t, err) + require.Equal(t, &NetStreamPublish{ PublishingName: "abc", PublishingType: "def", }, v) @@ -126,8 +126,8 @@ func TestDecodeCmdMessagePlay(t *testing.T) { var v AMFConvertible err := CmdBodyDecoderFor("play", 42)(r, d, &v) - assert.Nil(t, err) - assert.Equal(t, &NetStreamPlay{ + require.Nil(t, err) + require.Equal(t, &NetStreamPlay{ StreamName: "abc", Start: 42, }, v) @@ -145,8 +145,8 @@ func TestDecodeCmdMessageReleaseStream(t *testing.T) { var v AMFConvertible err := CmdBodyDecoderFor("releaseStream", 42)(r, d, &v) - assert.Nil(t, err) - assert.Equal(t, &NetConnectionReleaseStream{ + require.Nil(t, err) + require.Equal(t, &NetConnectionReleaseStream{ StreamName: "abc", }, v) } @@ -163,8 +163,8 @@ func TestDecodeCmdMessageFCPublish(t *testing.T) { var v AMFConvertible err := CmdBodyDecoderFor("FCPublish", 42)(r, d, &v) - assert.Nil(t, err) - assert.Equal(t, &NetStreamFCPublish{ + require.Nil(t, err) + require.Equal(t, &NetStreamFCPublish{ StreamName: "abc", }, v) } @@ -181,8 +181,8 @@ func TestDecodeCmdMessageFCUnpublish(t *testing.T) { var v AMFConvertible err := CmdBodyDecoderFor("FCUnpublish", 42)(r, d, &v) - assert.Nil(t, err) - assert.Equal(t, &NetStreamFCUnpublish{ + require.Nil(t, err) + require.Equal(t, &NetStreamFCUnpublish{ StreamName: "abc", }, v) } @@ -199,8 +199,8 @@ func TestDecodeCmdMessageGetStreamLength(t *testing.T) { var v AMFConvertible err := CmdBodyDecoderFor("getStreamLength", 42)(r, d, &v) - assert.Nil(t, err) - assert.Equal(t, &NetStreamGetStreamLength{ + require.Nil(t, err) + require.Equal(t, &NetStreamGetStreamLength{ StreamName: "abc", }, v) } @@ -215,8 +215,8 @@ func TestDecodeCmdMessagePing(t *testing.T) { var v AMFConvertible err := CmdBodyDecoderFor("ping", 42)(r, d, &v) - assert.Nil(t, err) - assert.Equal(t, &NetStreamPing{}, v) + require.Nil(t, err) + require.Equal(t, &NetStreamPing{}, v) } func TestDecodeCmdMessageCloseStream(t *testing.T) { @@ -229,8 +229,8 @@ func TestDecodeCmdMessageCloseStream(t *testing.T) { var v AMFConvertible err := CmdBodyDecoderFor("closeStream", 42)(r, d, &v) - assert.Nil(t, err) - assert.Equal(t, &NetStreamCloseStream{}, v) + require.Nil(t, err) + require.Equal(t, &NetStreamCloseStream{}, v) } func TestDecodeCmdMessageUnknown(t *testing.T) { @@ -243,10 +243,10 @@ func TestDecodeCmdMessageUnknown(t *testing.T) { var v AMFConvertible err := CmdBodyDecoderFor("hogehoge", 42)(r, d, &v) - assert.Equal(t, &UnknownCommandBodyDecodeError{ + require.Equal(t, &UnknownCommandBodyDecodeError{ Name: "hogehoge", TransactionID: 42, Objs: []interface{}{nil}, }, err) - assert.Nil(t, v) + require.Nil(t, v) } diff --git a/message/decoder_test.go b/message/decoder_test.go index 6b7ac5e..3c21c03 100644 --- a/message/decoder_test.go +++ b/message/decoder_test.go @@ -11,7 +11,7 @@ import ( "bytes" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDecodeCommon(t *testing.T) { @@ -26,7 +26,7 @@ func TestDecodeCommon(t *testing.T) { var msg Message err := dec.Decode(tc.TypeID, &msg) - assert.Nil(t, err) + require.Nil(t, err) assertEqualMessage(t, tc.Value, msg) }) } diff --git a/message/encoder_test.go b/message/encoder_test.go index 2c17200..81288b2 100644 --- a/message/encoder_test.go +++ b/message/encoder_test.go @@ -11,7 +11,7 @@ import ( "bytes" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestEncodeCommon(t *testing.T) { @@ -25,8 +25,8 @@ func TestEncodeCommon(t *testing.T) { enc := NewEncoder(buf) err := enc.Encode(tc.Value) - assert.Nil(t, err) - assert.Equal(t, tc.Binary, buf.Bytes()) + require.Nil(t, err) + require.Equal(t, tc.Binary, buf.Bytes()) }) } } diff --git a/message/message_test.go b/message/message_test.go index d49f60b..25e31ae 100644 --- a/message/message_test.go +++ b/message/message_test.go @@ -13,50 +13,50 @@ import ( "io/ioutil" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func assertEqualMessage(t *testing.T, expected, actual Message) { - assert.Equal(t, expected.TypeID(), actual.TypeID()) + require.Equal(t, expected.TypeID(), actual.TypeID()) switch expected := expected.(type) { case *AudioMessage: actual, ok := actual.(*AudioMessage) - assert.True(t, ok) + require.True(t, ok) assertEqualPayload(t, expected.Payload, actual.Payload) case *VideoMessage: actual, ok := actual.(*VideoMessage) - assert.True(t, ok) + require.True(t, ok) assertEqualPayload(t, expected.Payload, actual.Payload) case *DataMessage: actual, ok := actual.(*DataMessage) - assert.True(t, ok) + require.True(t, ok) - assert.Equal(t, expected.Name, actual.Name) - assert.Equal(t, expected.Encoding, actual.Encoding) + require.Equal(t, expected.Name, actual.Name) + require.Equal(t, expected.Encoding, actual.Encoding) assertEqualPayload(t, expected.Body, actual.Body) case *CommandMessage: actual, ok := actual.(*CommandMessage) - assert.True(t, ok) + require.True(t, ok) - assert.Equal(t, expected.CommandName, actual.CommandName) - assert.Equal(t, expected.TransactionID, actual.TransactionID) - assert.Equal(t, expected.Encoding, actual.Encoding) + require.Equal(t, expected.CommandName, actual.CommandName) + require.Equal(t, expected.TransactionID, actual.TransactionID) + require.Equal(t, expected.Encoding, actual.Encoding) assertEqualPayload(t, expected.Body, actual.Body) default: - assert.Equal(t, expected, actual) + require.Equal(t, expected, actual) } } func assertEqualPayload(t *testing.T, expected, actual io.Reader) { expectedBin, err := ioutil.ReadAll(expected) - assert.Nil(t, err) + require.Nil(t, err) switch p := expected.(type) { case *bytes.Reader: defer func() { @@ -65,11 +65,11 @@ func assertEqualPayload(t *testing.T, expected, actual io.Reader) { default: t.FailNow() } - assert.NotZero(t, len(expectedBin)) + require.NotZero(t, len(expectedBin)) actualBin, err := ioutil.ReadAll(actual) - assert.Nil(t, err) - assert.NotZero(t, len(actualBin)) + require.Nil(t, err) + require.NotZero(t, len(actualBin)) - assert.Equal(t, expectedBin, actualBin) + require.Equal(t, expectedBin, actualBin) } diff --git a/message/net_stream_test.go b/message/net_stream_test.go index 1574cad..62a3b99 100644 --- a/message/net_stream_test.go +++ b/message/net_stream_test.go @@ -10,7 +10,7 @@ package message import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type netStreamTestCase struct { @@ -62,21 +62,21 @@ func TestConvertNetStreamMessages(t *testing.T) { // Make a message from args err := tc.Box.FromArgs(tc.Args...) - assert.Equal(t, tc.FromErr, err) + require.Equal(t, tc.FromErr, err) if err != nil { return } - assert.Equal(t, tc.ExpectedMsg, tc.Box) // Message <- Args0 + require.Equal(t, tc.ExpectedMsg, tc.Box) // Message <- Args0 // Make args from message args, err := tc.Box.ToArgs(EncodingTypeAMF0) // TODO: fix interface... - assert.Equal(t, tc.ToErr, err) + require.Equal(t, tc.ToErr, err) if err != nil { return } - assert.Equal(t, tc.Args, args) // Args0 <- Message + require.Equal(t, tc.Args, args) // Args0 <- Message }) } } diff --git a/message/user_control_event_decoder_test.go b/message/user_control_event_decoder_test.go index 3d2174e..bbaca92 100644 --- a/message/user_control_event_decoder_test.go +++ b/message/user_control_event_decoder_test.go @@ -11,7 +11,7 @@ import ( "bytes" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUserControlEventDecodeCommon(t *testing.T) { @@ -26,8 +26,8 @@ func TestUserControlEventDecodeCommon(t *testing.T) { var msg UserCtrlEvent err := dec.Decode(&msg) - assert.Nil(t, err) - assert.Equal(t, tc.Value, msg) + require.Nil(t, err) + require.Equal(t, tc.Value, msg) }) } } diff --git a/message/user_control_event_encoder_test.go b/message/user_control_event_encoder_test.go index cbe51ea..71fcc00 100644 --- a/message/user_control_event_encoder_test.go +++ b/message/user_control_event_encoder_test.go @@ -11,7 +11,7 @@ import ( "bytes" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUserControlEventEncoderCommon(t *testing.T) { @@ -25,8 +25,8 @@ func TestUserControlEventEncoderCommon(t *testing.T) { enc := NewUserControlEventEncoder(buf) err := enc.Encode(tc.Value) - assert.Nil(t, err) - assert.Equal(t, tc.Binary, buf.Bytes()) + require.Nil(t, err) + require.Equal(t, tc.Binary, buf.Bytes()) }) } } diff --git a/server_client_test.go b/server_client_test.go index 09e8bb9..afb9f6f 100644 --- a/server_client_test.go +++ b/server_client_test.go @@ -14,7 +14,7 @@ import ( "testing" "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/yutopp/go-amf0" "github.com/yutopp/go-rtmp/message" @@ -24,31 +24,39 @@ const ( chunkSize = 128 ) +type serverCanAcceptConnectHandler struct { + DefaultHandler +} + func TestServerCanAcceptConnect(t *testing.T) { config := &ConnConfig{ - Handler: &ServerCanAcceptConnectHandler{}, + Handler: &serverCanAcceptConnectHandler{}, Logger: logrus.StandardLogger(), } prepareConnection(t, config, func(c *ClientConn) { err := c.Connect(nil) - assert.Nil(t, err) + require.Nil(t, err) }) } -type ServerCanAcceptConnectHandler struct { +type serverCanRejectConnectHandler struct { DefaultHandler } +func (h *serverCanRejectConnectHandler) OnConnect(_ uint32, _ *message.NetConnectionConnect) error { + return fmt.Errorf("Reject") +} + func TestServerCanRejectConnect(t *testing.T) { config := &ConnConfig{ - Handler: &ServerCanRejectConnectHandler{}, + Handler: &serverCanRejectConnectHandler{}, Logger: logrus.StandardLogger(), } prepareConnection(t, config, func(c *ClientConn) { err := c.Connect(nil) - assert.Equal(t, &ConnectRejectedError{ + require.Equal(t, &ConnectRejectedError{ TransactionID: 1, Result: &message.NetConnectionConnectResult{ Properties: message.NetConnectionConnectResultProperties{ @@ -67,17 +75,13 @@ func TestServerCanRejectConnect(t *testing.T) { }) } -type ServerCanRejectConnectHandler struct { +type serverCanAcceptCreateStreamHandler struct { DefaultHandler } -func (h *ServerCanRejectConnectHandler) OnConnect(_ uint32, _ *message.NetConnectionConnect) error { - return fmt.Errorf("Reject") -} - func TestServerCanAcceptCreateStream(t *testing.T) { config := &ConnConfig{ - Handler: &ServerCanAcceptCreateStreamHandler{}, + Handler: &serverCanAcceptCreateStreamHandler{}, Logger: logrus.StandardLogger(), ControlState: StreamControlStateConfig{ MaxMessageStreams: 2, // Control and another 1 stream @@ -86,15 +90,15 @@ func TestServerCanAcceptCreateStream(t *testing.T) { prepareConnection(t, config, func(c *ClientConn) { err := c.Connect(nil) - assert.Nil(t, err) + require.Nil(t, err) s0, err := c.CreateStream(nil, chunkSize) - assert.Nil(t, err) + require.Nil(t, err) defer s0.Close() // Rejected because a number of message streams is exceeded the limits s1, err := c.CreateStream(nil, chunkSize) - assert.Equal(t, &CreateStreamRejectedError{ + require.Equal(t, &CreateStreamRejectedError{ TransactionID: 2, Result: &message.NetConnectionCreateStreamResult{ StreamID: 0, @@ -104,13 +108,47 @@ func TestServerCanAcceptCreateStream(t *testing.T) { }) } -type ServerCanAcceptCreateStreamHandler struct { +type serverCanAcceptDeleteStreamHandler struct { DefaultHandler } +func TestServerCanAcceptDeleteStream(t *testing.T) { + config := &ConnConfig{ + Handler: &serverCanAcceptDeleteStreamHandler{}, + Logger: logrus.StandardLogger(), + ControlState: StreamControlStateConfig{ + MaxMessageStreams: 2, // Control and another 1 stream + }, + } + + prepareConnection(t, config, func(c *ClientConn) { + err := c.Connect(nil) + require.Nil(t, err) + + s0, err := c.CreateStream(nil, chunkSize) + require.NoError(t, err) + defer s0.Close() + + t.Run("Cannot delete a stream which does not exist", func(t *testing.T) { + err = c.DeleteStream(&message.NetStreamDeleteStream{ + StreamID: 42, + }) + require.Error(t, err) + }) + + t.Run("Can delete a stream", func(t *testing.T) { + err = c.DeleteStream(&message.NetStreamDeleteStream{ + StreamID: s0.streamID, + }) + require.NoError(t, err) + }) + }) +} + func prepareConnection(t *testing.T, config *ConnConfig, f func(c *ClientConn)) { + // prepare server l, err := net.Listen("tcp", "127.0.0.1:") - assert.Nil(t, err) + require.Nil(t, err) srv := NewServer(&ServerConfig{ OnConnect: func(conn net.Conn) (io.ReadWriteCloser, *ConnConfig) { @@ -119,21 +157,22 @@ func prepareConnection(t *testing.T, config *ConnConfig, f func(c *ClientConn)) }) defer func() { err := srv.Close() - assert.Nil(t, err) + require.Nil(t, err) }() go func() { err := srv.Serve(l) - assert.Equal(t, ErrClosed, err) + require.Equal(t, ErrClosed, err) }() + // prepare client c, err := Dial("rtmp", l.Addr().String(), &ConnConfig{ Logger: logrus.StandardLogger(), }) - assert.Nil(t, err) + require.Nil(t, err) defer func() { err := c.Close() - assert.Nil(t, err) + require.Nil(t, err) }() f(c) diff --git a/server_test.go b/server_test.go index c0de2f4..315ccb6 100644 --- a/server_test.go +++ b/server_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestServerCanClose(t *testing.T) { @@ -21,12 +21,12 @@ func TestServerCanClose(t *testing.T) { go func(ch <-chan time.Time) { <-ch err := srv.Close() - assert.Nil(t, err) + require.Nil(t, err) }(time.After(1 * time.Second)) l, err := net.Listen("tcp", "127.0.0.1:") - assert.Nil(t, err) + require.Nil(t, err) err = srv.Serve(l) - assert.Equal(t, ErrClosed, err) + require.Equal(t, ErrClosed, err) } diff --git a/stream.go b/stream.go index 01f1ec4..5a5c82f 100644 --- a/stream.go +++ b/stream.go @@ -257,9 +257,14 @@ func (s *Stream) NotifyStatus( } func (s *Stream) Close() error { + s.assumeClosed() return nil // TODO: implement } +func (s *Stream) assumeClosed() { + // TODO: implement +} + func (s *Stream) writeCommandMessage( chunkStreamID int, timestamp uint32, diff --git a/stream_handler_test.go b/stream_handler_test.go index 9a62cd8..891bfdb 100644 --- a/stream_handler_test.go +++ b/stream_handler_test.go @@ -10,7 +10,7 @@ package rtmp import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestStreamHandlerChangeState(t *testing.T) { @@ -19,41 +19,41 @@ func TestStreamHandlerChangeState(t *testing.T) { s := newStream(42, c) s.handler.ChangeState(streamStateUnknown) - assert.Equal(t, s.handler.state, streamStateUnknown) - assert.Equal(t, s.handler.handler, nil) + require.Equal(t, s.handler.state, streamStateUnknown) + require.Equal(t, s.handler.handler, nil) s.handler.ChangeState(streamStateServerNotConnected) - assert.Equal(t, s.handler.state, streamStateServerNotConnected) - assert.Equal(t, s.handler.handler, &serverControlNotConnectedHandler{sh: s.handler}) + require.Equal(t, s.handler.state, streamStateServerNotConnected) + require.Equal(t, s.handler.handler, &serverControlNotConnectedHandler{sh: s.handler}) s.handler.ChangeState(streamStateServerConnected) - assert.Equal(t, s.handler.state, streamStateServerConnected) - assert.Equal(t, s.handler.handler, &serverControlConnectedHandler{sh: s.handler}) + require.Equal(t, s.handler.state, streamStateServerConnected) + require.Equal(t, s.handler.handler, &serverControlConnectedHandler{sh: s.handler}) s.handler.ChangeState(streamStateServerInactive) - assert.Equal(t, s.handler.state, streamStateServerInactive) - assert.Equal(t, s.handler.handler, &serverDataInactiveHandler{sh: s.handler}) + require.Equal(t, s.handler.state, streamStateServerInactive) + require.Equal(t, s.handler.handler, &serverDataInactiveHandler{sh: s.handler}) s.handler.ChangeState(streamStateServerPublish) - assert.Equal(t, s.handler.state, streamStateServerPublish) - assert.Equal(t, s.handler.handler, &serverDataPublishHandler{sh: s.handler}) + require.Equal(t, s.handler.state, streamStateServerPublish) + require.Equal(t, s.handler.handler, &serverDataPublishHandler{sh: s.handler}) s.handler.ChangeState(streamStateServerPlay) - assert.Equal(t, s.handler.state, streamStateServerPlay) - assert.Equal(t, s.handler.handler, &serverDataPlayHandler{sh: s.handler}) + require.Equal(t, s.handler.state, streamStateServerPlay) + require.Equal(t, s.handler.handler, &serverDataPlayHandler{sh: s.handler}) s.handler.ChangeState(streamStateClientNotConnected) - assert.Equal(t, s.handler.state, streamStateClientNotConnected) - assert.Equal(t, s.handler.handler, &clientControlNotConnectedHandler{sh: s.handler}) + require.Equal(t, s.handler.state, streamStateClientNotConnected) + require.Equal(t, s.handler.handler, &clientControlNotConnectedHandler{sh: s.handler}) } func TestStreamStateString(t *testing.T) { - assert.Equal(t, "", streamStateUnknown.String()) - assert.Equal(t, "NotConnected(Server)", streamStateServerNotConnected.String()) - assert.Equal(t, "Connected(Server)", streamStateServerConnected.String()) - assert.Equal(t, "Inactive(Server)", streamStateServerInactive.String()) - assert.Equal(t, "Publish(Server)", streamStateServerPublish.String()) - assert.Equal(t, "Play(Server)", streamStateServerPlay.String()) - assert.Equal(t, "NotConnected(Client)", streamStateClientNotConnected.String()) - assert.Equal(t, "Connected(Client)", streamStateClientConnected.String()) + require.Equal(t, "", streamStateUnknown.String()) + require.Equal(t, "NotConnected(Server)", streamStateServerNotConnected.String()) + require.Equal(t, "Connected(Server)", streamStateServerConnected.String()) + require.Equal(t, "Inactive(Server)", streamStateServerInactive.String()) + require.Equal(t, "Publish(Server)", streamStateServerPublish.String()) + require.Equal(t, "Play(Server)", streamStateServerPlay.String()) + require.Equal(t, "NotConnected(Client)", streamStateClientNotConnected.String()) + require.Equal(t, "Connected(Client)", streamStateClientConnected.String()) } diff --git a/streams.go b/streams.go index 29fc33d..af20ee3 100644 --- a/streams.go +++ b/streams.go @@ -70,12 +70,14 @@ func (ss *streams) Delete(streamID uint32) error { ss.m.Lock() defer ss.m.Unlock() - _, ok := ss.streams[streamID] + s, ok := ss.streams[streamID] if !ok { return errors.Errorf("Stream not exists: StreamID = %d", streamID) } - delete(ss.streams, streamID) + delete(ss.streams, s.streamID) + + s.assumeClosed() return nil } diff --git a/streams_test.go b/streams_test.go index e221050..3d02851 100644 --- a/streams_test.go +++ b/streams_test.go @@ -10,7 +10,7 @@ package rtmp import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestStreams(t *testing.T) { @@ -24,17 +24,17 @@ func TestStreams(t *testing.T) { streams := newStreams(conn) s, err := streams.CreateIfAvailable() - assert.Nil(t, err) - assert.Equal(t, uint32(0), s.streamID) + require.Nil(t, err) + require.Equal(t, uint32(0), s.streamID) // Becomes error because number of max streams is 1 _, err = streams.CreateIfAvailable() - assert.NotNil(t, err) + require.NotNil(t, err) err = streams.Delete(s.streamID) - assert.Nil(t, err) + require.Nil(t, err) // Becomes error because the stream is already deleted err = streams.Delete(s.streamID) - assert.NotNil(t, err) + require.NotNil(t, err) }