diff --git a/proxy/env.go b/proxy/env.go index 7a77516e320..dfe582014cb 100644 --- a/proxy/env.go +++ b/proxy/env.go @@ -46,6 +46,8 @@ func buildDefaultEnvironmentVariables(ctx context.Context) { setEnvDefault("PROXY_RTMP_SERVER", "11935") // The WebRTC media server, via UDP protocol. setEnvDefault("PROXY_WEBRTC_SERVER", "18000") + // The SRT media server, via UDP protocol. + setEnvDefault("PROXY_SRT_SERVER", "20080") // The API server of proxy itself. setEnvDefault("PROXY_SYSTEM_API", "12025") @@ -70,30 +72,36 @@ func buildDefaultEnvironmentVariables(ctx context.Context) { setEnvDefault("PROXY_DEFAULT_BACKEND_API", "1985") // Default backend udp rtc port, for debugging. setEnvDefault("PROXY_DEFAULT_BACKEND_RTC", "8000") + // Default backend udp srt port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_SRT", "10080") logger.Df(ctx, "load .env as GO_PPROF=%v, "+ "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ - "PROXY_WEBRTC_SERVER=%v, "+ + "PROXY_WEBRTC_SERVER=%v, PROXY_SRT_SERVER=%v, "+ "PROXY_SYSTEM_API=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+ "PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+ "PROXY_DEFAULT_BACKEND_HTTP=%v, PROXY_DEFAULT_BACKEND_API=%v, "+ - "PROXY_DEFAULT_BACKEND_RTC=%v, "+ + "PROXY_DEFAULT_BACKEND_RTC=%v, PROXY_DEFAULT_BACKEND_SRT=%v, "+ "PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+ "PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v", envGoPprof(), envForceQuitTimeout(), envGraceQuitTimeout(), envHttpAPI(), envHttpServer(), envRtmpServer(), - envWebRTCServer(), + envWebRTCServer(), envSRTServer(), envSystemAPI(), envDefaultBackendEnabled(), envDefaultBackendIP(), envDefaultBackendRTMP(), envDefaultBackendHttp(), envDefaultBackendAPI(), - envDefaultBackendRTC(), + envDefaultBackendRTC(), envDefaultBackendSRT(), envLoadBalancerType(), envRedisHost(), envRedisPort(), envRedisPassword(), envRedisDB(), ) } +func envDefaultBackendSRT() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_SRT") +} + func envDefaultBackendRTC() string { return os.Getenv("PROXY_DEFAULT_BACKEND_RTC") } @@ -102,6 +110,10 @@ func envDefaultBackendAPI() string { return os.Getenv("PROXY_DEFAULT_BACKEND_API") } +func envSRTServer() string { + return os.Getenv("PROXY_SRT_SERVER") +} + func envWebRTCServer() string { return os.Getenv("PROXY_WEBRTC_SERVER") } diff --git a/proxy/http.go b/proxy/http.go index e46664f8ecd..7f66c8ee164 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -106,10 +106,9 @@ func (v *httpServer) Run(ctx context.Context) error { stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) { s.SRSProxyBackendHLSID = logger.GenerateContextID() s.StreamURL, s.FullURL = streamURL, fullURL - s.Initialize(ctx) })) - stream.ServeHTTP(w, r) + stream.Initialize(ctx).ServeHTTP(w, r) return } @@ -262,7 +261,7 @@ func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.Respons } // HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS -// clients will share this object, and they use the same ctx among proxy servers. +// clients will share this object, and they do not use the same ctx among proxy servers. // // Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections. // Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create @@ -289,7 +288,9 @@ func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream { } func (v *HLSPlayStream) Initialize(ctx context.Context) *HLSPlayStream { - v.ctx = logger.WithContext(ctx) + if v.ctx == nil { + v.ctx = logger.WithContext(ctx) + } return v } diff --git a/proxy/main.go b/proxy/main.go index 307cbf1de9e..ea87484744c 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -90,6 +90,13 @@ func doMain(ctx context.Context) error { return errors.Wrapf(err, "http api server") } + // Start the SRT server. + srtServer := newSRTServer() + defer srtServer.Close() + if err := srtServer.Run(ctx); err != nil { + return errors.Wrapf(err, "srt server") + } + // Start the System API server. systemAPI := NewSystemAPI(func(server *systemAPI) { server.gracefulQuitTimeout = gracefulQuitTimeout diff --git a/proxy/rtc.go b/proxy/rtc.go index b8cb82df5dc..65bf033989a 100644 --- a/proxy/rtc.go +++ b/proxy/rtc.go @@ -228,17 +228,17 @@ func (v *rtcServer) Run(ctx context.Context) error { endpoint = fmt.Sprintf(":%v", endpoint) } - addr, err := net.ResolveUDPAddr("udp", endpoint) + saddr, err := net.ResolveUDPAddr("udp", endpoint) if err != nil { return errors.Wrapf(err, "resolve udp addr %v", endpoint) } - listener, err := net.ListenUDP("udp", addr) + listener, err := net.ListenUDP("udp", saddr) if err != nil { - return errors.Wrapf(err, "listen udp %v", addr) + return errors.Wrapf(err, "listen udp %v", saddr) } v.listener = listener - logger.Df(ctx, "WebRTC server listen at %v", addr) + logger.Df(ctx, "WebRTC server listen at %v", saddr) // Consume all messages from UDP media transport. v.wg.Add(1) @@ -247,15 +247,15 @@ func (v *rtcServer) Run(ctx context.Context) error { for ctx.Err() == nil { buf := make([]byte, 4096) - n, addr, err := listener.ReadFromUDP(buf) + n, caddr, err := listener.ReadFromUDP(buf) if err != nil { // TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit. logger.Wf(ctx, "read from udp failed, err=%+v", err) continue } - if err := v.handleClientUDP(ctx, addr, buf[:n]); err != nil { - logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, addr, err) + if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) } } }() @@ -268,7 +268,7 @@ func (v *rtcServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data // If STUN binding request, parse the ufrag and identify the connection. if err := func() error { - if rtc_is_rtp_or_rtcp(data) || !rtc_is_stun(data) { + if rtcIsRTPOrRTCP(data) || !rtcIsSTUN(data) { return nil } @@ -358,7 +358,9 @@ func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection { } func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection { - v.ctx = logger.WithContext(ctx) + if v.ctx == nil { + v.ctx = logger.WithContext(ctx) + } if listener != nil { v.listenerUDP = listener } @@ -431,7 +433,7 @@ func (v *RTCConnection) connectBackend(ctx context.Context) error { } // Connect to backend SRS server via UDP client. - // TODO: Support close the connection when timeout or DTLS alert. + // TODO: FIXME: Support close the connection when timeout or DTLS alert. backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: udpPort} if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { return errors.Wrapf(err, "dial udp to %v", backendAddr) diff --git a/proxy/srs.go b/proxy/srs.go index 46cf513d8d2..e4c62008117 100644 --- a/proxy/srs.go +++ b/proxy/srs.go @@ -136,6 +136,9 @@ func NewDefaultSRSForDebugging() (*SRSServer, error) { if envDefaultBackendRTC() != "" { server.RTC = []string{envDefaultBackendRTC()} } + if envDefaultBackendSRT() != "" { + server.SRT = []string{envDefaultBackendSRT()} + } return server, nil } diff --git a/proxy/srt.go b/proxy/srt.go new file mode 100644 index 00000000000..3e2b6514872 --- /dev/null +++ b/proxy/srt.go @@ -0,0 +1,574 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "net" + "strconv" + "strings" + stdSync "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/sync" +) + +type srtServer struct { + // The UDP listener for SRT server. + listener *net.UDPConn + + // The SRT connections, identify by the socket ID. + sockets sync.Map[uint32, *SRTConnection] + // The system start time. + start time.Time + + // The wait group for server. + wg stdSync.WaitGroup +} + +func newSRTServer(opts ...func(*srtServer)) *srtServer { + v := &srtServer{ + start: time.Now(), + } + + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srtServer) Close() error { + if v.listener != nil { + v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *srtServer) Run(ctx context.Context) error { + // Parse address to listen. + endpoint := envSRTServer() + if !strings.Contains(endpoint, ":") { + endpoint = ":" + endpoint + } + + saddr, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve udp addr %v", endpoint) + } + + listener, err := net.ListenUDP("udp", saddr) + if err != nil { + return errors.Wrapf(err, "listen udp %v", saddr) + } + v.listener = listener + logger.Df(ctx, "SRT server listen at %v", saddr) + + // Consume all messages from UDP media transport. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, caddr, err := v.listener.ReadFromUDP(buf) + if err != nil { + // TODO: If SRT server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "read from udp failed, err=%+v", err) + continue + } + + if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) + } + } + }() + + return nil +} + +func (v *srtServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { + socketID := srtParseSocketID(data) + + var pkt *SRTHandshakePacket + if srtIsHandshake(data) { + pkt = &SRTHandshakePacket{} + if err := pkt.UnmarshalBinary(data); err != nil { + return err + } + + if socketID == 0 { + socketID = pkt.SRTSocketID + } + } + + conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) { + c.ctx = logger.WithContext(ctx) + c.listenerUDP, c.socketID = v.listener, socketID + c.start = v.start + })) + + ctx = conn.ctx + if !ok { + logger.Df(ctx, "Create new SRT connection skt=%v", socketID) + } + + if newSocketID, err := conn.HandlePacket(pkt, addr, data); err != nil { + return errors.Wrapf(err, "handle packet") + } else if newSocketID != 0 && newSocketID != socketID { + // The connection may use a new socket ID. + // TODO: FIXME: Should cleanup the dead SRT connection. + v.sockets.Store(newSocketID, conn) + } + + return nil +} + +// SRTConnection is an SRT connection proxy, for both caller and listener. It represents an SRT +// connection, identify by the socket ID. +// +// It's similar to RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is in +// the client request. The SRTConnection is stateless, and no need to sync between proxy servers. +// +// Unlike the WebRTC connection, SRTConnection does not support address changes. This means the +// client should never switch to another network or port. If this occurs, the client may be served +// by a different proxy server and fail because the other proxy server cannot identify the client. +type SRTConnection struct { + // The stream context for SRT connection. + ctx context.Context + + // The current socket ID. + socketID uint32 + + // The UDP connection proxy to backend. + backendUDP *net.UDPConn + // The listener UDP connection, used to send messages to client. + listenerUDP *net.UDPConn + + // Listener start time. + start time.Time + + // Handshake packets with client. + handshake0 *SRTHandshakePacket + handshake1 *SRTHandshakePacket + handshake2 *SRTHandshakePacket + handshake3 *SRTHandshakePacket +} + +func NewSRTConnection(opts ...func(*SRTConnection)) *SRTConnection { + v := &SRTConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) (uint32, error) { + ctx := v.ctx + + // If not handshake, try to proxy to backend directly. + if pkt == nil { + // Proxy client message to backend. + if v.backendUDP != nil { + if _, err := v.backendUDP.Write(data); err != nil { + return v.socketID, errors.Wrapf(err, "write to backend") + } + } + + return v.socketID, nil + } + + // Handle handshake messages. + if err := v.handleHandshake(ctx, pkt, addr, data); err != nil { + return v.socketID, errors.Wrapf(err, "handle handshake %v", pkt) + } + + return v.socketID, nil +} + +func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) error { + // Handle handshake 0 and 1 messages. + if pkt.SynCookie == 0 { + // Save handshake 0 packet. + v.handshake0 = pkt + logger.Df(ctx, "SRT Handshake 0: %v", v.handshake0) + + // Response handshake 1. + v.handshake1 = &SRTHandshakePacket{ + ControlFlag: pkt.ControlFlag, + ControlType: 0, + SubType: 0, + AdditionalInfo: 0, + Timestamp: uint32(time.Since(v.start).Microseconds()), + SocketID: pkt.SRTSocketID, + Version: 5, + EncryptionField: 0, + ExtensionField: 0x4A17, + InitSequence: pkt.InitSequence, + MTU: pkt.MTU, + FlowWindow: pkt.FlowWindow, + HandshakeType: 1, + SRTSocketID: pkt.SRTSocketID, + SynCookie: 0x418d5e4e, + PeerIP: net.ParseIP("127.0.0.1"), + } + logger.Df(ctx, "SRT Handshake 1: %v", v.handshake1) + + if b, err := v.handshake1.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 1") + } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { + return errors.Wrapf(err, "write handshake 1") + } + + return nil + } + + // Handle handshake 2 and 3 messages. + // Parse stream id from packet. + streamID, err := pkt.StreamID() + if err != nil { + return errors.Wrapf(err, "parse stream id") + } + + // Save handshake packet. + v.handshake2 = pkt + logger.Df(ctx, "SRT Handshake 2: %v, sid=%v", v.handshake2, streamID) + + // Start the UDP proxy to backend. + if err := v.connectBackend(ctx, streamID); err != nil { + return errors.Wrapf(err, "connect backend for %v", streamID) + } + + // Proxy client message to backend. + if v.backendUDP == nil { + return errors.Errorf("no backend for %v", streamID) + } + + // Proxy handshake 0 to backend server. + if b, err := v.handshake0.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 0") + } else if _, err = v.backendUDP.Write(b); err != nil { + return errors.Wrapf(err, "write handshake 0") + } + logger.Df(ctx, "Proxy send handshake 0: %v", v.handshake0) + + // Read handshake 1 from backend server. + b := make([]byte, 4096) + handshake1p := &SRTHandshakePacket{} + if nn, err := v.backendUDP.Read(b); err != nil { + return errors.Wrapf(err, "read handshake 1") + } else if err := handshake1p.UnmarshalBinary(b[:nn]); err != nil { + return errors.Wrapf(err, "unmarshal handshake 1") + } + logger.Df(ctx, "Proxy got handshake 1: %v", handshake1p) + + // Proxy handshake 2 to backend server. + handshake2p := *v.handshake2 + handshake2p.SynCookie = handshake1p.SynCookie + if b, err := handshake2p.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 2") + } else if _, err = v.backendUDP.Write(b); err != nil { + return errors.Wrapf(err, "write handshake 2") + } + logger.Df(ctx, "Proxy send handshake 2: %v", handshake2p) + + // Read handshake 3 from backend server. + handshake3p := &SRTHandshakePacket{} + if nn, err := v.backendUDP.Read(b); err != nil { + return errors.Wrapf(err, "read handshake 3") + } else if err := handshake3p.UnmarshalBinary(b[:nn]); err != nil { + return errors.Wrapf(err, "unmarshal handshake 3") + } + logger.Df(ctx, "Proxy got handshake 3: %v", handshake3p) + + // Response handshake 3 to client. + v.handshake3 = &*handshake3p + v.handshake3.SynCookie = v.handshake1.SynCookie + v.socketID = handshake3p.SRTSocketID + logger.Df(ctx, "Handshake 3: %v", v.handshake3) + + if b, err := v.handshake3.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 3") + } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { + return errors.Wrapf(err, "write handshake 3") + } + + // Start a goroutine to proxy message from backend to client. + // TODO: FIXME: Support close the connection when timeout or client disconnected. + go func() { + for ctx.Err() == nil { + nn, err := v.backendUDP.Read(b) + if err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "read from backend failed, err=%v", err) + return + } + if _, err = v.listenerUDP.WriteToUDP(b[:nn], addr); err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "write to client failed, err=%v", err) + return + } + } + }() + return nil +} + +func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) error { + if v.backendUDP != nil { + return nil + } + + // Parse stream id to host and resource. + host, resource, err := parseSRTStreamID(streamID) + if err != nil { + return errors.Wrapf(err, "parse stream id %v", streamID) + } + + if host == "" { + host = "localhost" + } + + streamURL, err := buildStreamURL(fmt.Sprintf("srt://%v/%v", host, resource)) + if err != nil { + return errors.Wrapf(err, "build stream url %v", streamID) + } + + // Pick a backend SRS server to proxy the SRT stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + // Parse UDP port from backend. + if len(backend.SRT) == 0 { + return errors.Errorf("no udp server %v for %v", backend, streamURL) + } + + var udpPort int + if iv, err := strconv.ParseInt(backend.SRT[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.SRT[0], backend, streamURL) + } else { + udpPort = int(iv) + } + + // Connect to backend SRS server via UDP client. + // TODO: FIXME: Support close the connection when timeout or DTLS alert. + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: udpPort} + if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { + return errors.Wrapf(err, "dial udp to %v of %v for %v", backendAddr, backend, streamURL) + } else { + v.backendUDP = backendUDP + } + + return nil +} + +// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2 +// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2.1 +type SRTHandshakePacket struct { + // F: 1 bit. Packet Type Flag. The control packet has this flag set to + // "1". The data packet has this flag set to "0". + ControlFlag uint8 + // Control Type: 15 bits. Control Packet Type. The use of these bits + // is determined by the control packet type definition. + // Handshake control packets (Control Type = 0x0000) are used to + // exchange peer configurations, to agree on connection parameters, and + // to establish a connection. + ControlType uint16 + // Subtype: 16 bits. This field specifies an additional subtype for + // specific packets. + SubType uint16 + // Type-specific Information: 32 bits. The use of this field depends on + // the particular control packet type. Handshake packets do not use + // this field. + AdditionalInfo uint32 + // Timestamp: 32 bits. + Timestamp uint32 + // Destination Socket ID: 32 bits. + SocketID uint32 + + // Version: 32 bits. A base protocol version number. Currently used + // values are 4 and 5. Values greater than 5 are reserved for future + // use. + Version uint32 + // Encryption Field: 16 bits. Block cipher family and key size. The + // values of this field are described in Table 2. The default value + // is AES-128. + // 0 | No Encryption Advertised + // 2 | AES-128 + // 3 | AES-192 + // 4 | AES-256 + EncryptionField uint16 + // Extension Field: 16 bits. This field is message specific extension + // related to Handshake Type field. The value MUST be set to 0 + // except for the following cases. (1) If the handshake control + // packet is the INDUCTION message, this field is sent back by the + // Listener. (2) In the case of a CONCLUSION message, this field + // value should contain a combination of Extension Type values. + // 0x00000001 | HSREQ + // 0x00000002 | KMREQ + // 0x00000004 | CONFIG + // 0x4A17 if HandshakeType is INDUCTION, see https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-4.3.1.1 + ExtensionField uint16 + // Initial Packet Sequence Number: 32 bits. The sequence number of the + // very first data packet to be sent. + InitSequence uint32 + // Maximum Transmission Unit Size: 32 bits. This value is typically set + // to 1500, which is the default Maximum Transmission Unit (MTU) size + // for Ethernet, but can be less. + MTU uint32 + // Maximum Flow Window Size: 32 bits. The value of this field is the + // maximum number of data packets allowed to be "in flight" (i.e. the + // number of sent packets for which an ACK control packet has not yet + // been received). + FlowWindow uint32 + // Handshake Type: 32 bits. This field indicates the handshake packet + // type. + // 0xFFFFFFFD | DONE + // 0xFFFFFFFE | AGREEMENT + // 0xFFFFFFFF | CONCLUSION + // 0x00000000 | WAVEHAND + // 0x00000001 | INDUCTION + HandshakeType uint32 + // SRT Socket ID: 32 bits. This field holds the ID of the source SRT + // socket from which a handshake packet is issued. + SRTSocketID uint32 + // SYN Cookie: 32 bits. Randomized value for processing a handshake. + // The value of this field is specified by the handshake message + // type. + SynCookie uint32 + // Peer IP Address: 128 bits. IPv4 or IPv6 address of the packet's + // sender. The value consists of four 32-bit fields. + PeerIP net.IP + // Extensions. + // Extension Type: 16 bits. The value of this field is used to process + // an integrated handshake. Each extension can have a pair of + // request and response types. + // Extension Length: 16 bits. The length of the Extension Contents + // field in four-byte blocks. + // Extension Contents: variable length. The payload of the extension. + ExtraData []byte +} + +func (v *SRTHandshakePacket) IsData() bool { + return v.ControlFlag == 0x00 +} + +func (v *SRTHandshakePacket) IsControl() bool { + return v.ControlFlag == 0x80 +} + +func (v *SRTHandshakePacket) IsHandshake() bool { + return v.IsControl() && v.ControlType == 0x00 && v.SubType == 0x00 +} + +func (v *SRTHandshakePacket) StreamID() (string, error) { + p := v.ExtraData + for { + if len(p) < 2 { + return "", errors.Errorf("Require 2 bytes, actual=%v, extra=%v", len(p), len(v.ExtraData)) + } + + extType := binary.BigEndian.Uint16(p) + extSize := binary.BigEndian.Uint16(p[2:]) + p = p[4:] + + if len(p) < int(extSize*4) { + return "", errors.Errorf("Require %v bytes, actual=%v, extra=%v", extSize*4, len(p), len(v.ExtraData)) + } + + // Ignore other packets except stream id. + if extType != 0x05 { + p = p[extSize*4:] + continue + } + + // We must copy it, because we will decode the stream id. + data := append([]byte{}, p[:extSize*4]...) + + // Reverse the stream id encoded in little-endian to big-endian. + for i := 0; i < len(data); i += 4 { + value := binary.LittleEndian.Uint32(data[i:]) + binary.BigEndian.PutUint32(data[i:], value) + } + + // Trim the trailing zero bytes. + data = bytes.TrimRight(data, "\x00") + return string(data), nil + } +} + +func (v *SRTHandshakePacket) String() string { + return fmt.Sprintf("Control=%v, CType=%v, SType=%v, Timestamp=%v, SocketID=%v, Version=%v, Encrypt=%v, Extension=%v, InitSequence=%v, MTU=%v, FlowWnd=%v, HSType=%v, SRTSocketID=%v, Cookie=%v, Peer=%vB, Extra=%vB", + v.IsControl(), v.ControlType, v.SubType, v.Timestamp, v.SocketID, v.Version, v.EncryptionField, v.ExtensionField, v.InitSequence, v.MTU, v.FlowWindow, v.HandshakeType, v.SRTSocketID, v.SynCookie, len(v.PeerIP), len(v.ExtraData)) +} + +func (v *SRTHandshakePacket) UnmarshalBinary(b []byte) error { + if len(b) < 4 { + return errors.Errorf("Invalid packet length %v", len(b)) + } + v.ControlFlag = b[0] & 0x80 + v.ControlType = binary.BigEndian.Uint16(b[0:2]) & 0x7fff + v.SubType = binary.BigEndian.Uint16(b[2:4]) + + if len(b) < 64 { + return errors.Errorf("Invalid packet length %v", len(b)) + } + v.AdditionalInfo = binary.BigEndian.Uint32(b[4:]) + v.Timestamp = binary.BigEndian.Uint32(b[8:]) + v.SocketID = binary.BigEndian.Uint32(b[12:]) + v.Version = binary.BigEndian.Uint32(b[16:]) + v.EncryptionField = binary.BigEndian.Uint16(b[20:]) + v.ExtensionField = binary.BigEndian.Uint16(b[22:]) + v.InitSequence = binary.BigEndian.Uint32(b[24:]) + v.MTU = binary.BigEndian.Uint32(b[28:]) + v.FlowWindow = binary.BigEndian.Uint32(b[32:]) + v.HandshakeType = binary.BigEndian.Uint32(b[36:]) + v.SRTSocketID = binary.BigEndian.Uint32(b[40:]) + v.SynCookie = binary.BigEndian.Uint32(b[44:]) + + // Only support IPv4. + v.PeerIP = net.IPv4(b[51], b[50], b[49], b[48]) + + v.ExtraData = b[64:] + + return nil +} + +func (v *SRTHandshakePacket) MarshalBinary() ([]byte, error) { + b := make([]byte, 64+len(v.ExtraData)) + binary.BigEndian.PutUint16(b, uint16(v.ControlFlag)<<8|v.ControlType) + binary.BigEndian.PutUint16(b[2:], v.SubType) + binary.BigEndian.PutUint32(b[4:], v.AdditionalInfo) + binary.BigEndian.PutUint32(b[8:], v.Timestamp) + binary.BigEndian.PutUint32(b[12:], v.SocketID) + binary.BigEndian.PutUint32(b[16:], v.Version) + binary.BigEndian.PutUint16(b[20:], v.EncryptionField) + binary.BigEndian.PutUint16(b[22:], v.ExtensionField) + binary.BigEndian.PutUint32(b[24:], v.InitSequence) + binary.BigEndian.PutUint32(b[28:], v.MTU) + binary.BigEndian.PutUint32(b[32:], v.FlowWindow) + binary.BigEndian.PutUint32(b[36:], v.HandshakeType) + binary.BigEndian.PutUint32(b[40:], v.SRTSocketID) + binary.BigEndian.PutUint32(b[44:], v.SynCookie) + + // Only support IPv4. + ip := v.PeerIP.To4() + b[48] = ip[3] + b[49] = ip[2] + b[50] = ip[1] + b[51] = ip[0] + + if len(v.ExtraData) > 0 { + copy(b[64:], v.ExtraData) + } + + return b, nil +} diff --git a/proxy/utils.go b/proxy/utils.go index c2f41ed1fb2..9aa9cdbef76 100644 --- a/proxy/utils.go +++ b/proxy/utils.go @@ -5,6 +5,7 @@ package main import ( "context" + "encoding/binary" "encoding/json" stdErr "errors" "fmt" @@ -178,35 +179,71 @@ func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) { return } -// rtc_is_stun returns true if data of UDP payload is a STUN packet. -func rtc_is_stun(data []byte) bool { +// rtcIsSTUN returns true if data of UDP payload is a STUN packet. +func rtcIsSTUN(data []byte) bool { return len(data) > 0 && (data[0] == 0 || data[0] == 1) } -// rtc_is_rtp_or_rtcp returns true if data of UDP payload is a RTP or RTCP packet. -func rtc_is_rtp_or_rtcp(data []byte) bool { +// rtcIsRTPOrRTCP returns true if data of UDP payload is a RTP or RTCP packet. +func rtcIsRTPOrRTCP(data []byte) bool { return len(data) >= 12 && (data[0]&0xC0) == 0x80 } +// srtIsHandshake returns true if data of UDP payload is a SRT handshake packet. +func srtIsHandshake(data []byte) bool { + return len(data) >= 4 && binary.BigEndian.Uint32(data) == 0x80000000 +} + +// srtParseSocketID parse the socket id from the SRT packet. +func srtParseSocketID(data []byte) uint32 { + if len(data) >= 16 { + return binary.BigEndian.Uint32(data[12:]) + } + return 0 +} + // parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP. func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) { - var iceUfrag, icePwd string if true { ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`) ufragMatch := ufragRe.FindStringSubmatch(sdp) if len(ufragMatch) <= 1 { return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp) } - iceUfrag = ufragMatch[1] + ufrag = ufragMatch[1] } + if true { pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`) pwdMatch := pwdRe.FindStringSubmatch(sdp) if len(pwdMatch) <= 1 { return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp) } - icePwd = pwdMatch[1] + pwd = pwdMatch[1] + } + + return ufrag, pwd, nil +} + +// parseSRTStreamID parse the SRT stream id to host(optional) and resource(required). +// See https://ossrs.io/lts/en-us/docs/v7/doc/srt#srt-url +func parseSRTStreamID(sid string) (host, resource string, err error) { + if true { + hostRe := regexp.MustCompile(`h=([^,]+)`) + hostMatch := hostRe.FindStringSubmatch(sid) + if len(hostMatch) > 1 { + host = hostMatch[1] + } + } + + if true { + resourceRe := regexp.MustCompile(`r=([^,]+)`) + resourceMatch := resourceRe.FindStringSubmatch(sid) + if len(resourceMatch) <= 1 { + return "", "", errors.Errorf("no resource in sid %v", sid) + } + resource = resourceMatch[1] } - return iceUfrag, icePwd, nil + return host, resource, nil } diff --git a/trunk/conf/origin1-for-proxy.conf b/trunk/conf/origin1-for-proxy.conf index 51627a92d2c..baca5c9f407 100644 --- a/trunk/conf/origin1-for-proxy.conf +++ b/trunk/conf/origin1-for-proxy.conf @@ -19,6 +19,12 @@ rtc_server { # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate candidate $CANDIDATE; } +srt_server { + enabled on; + listen 10081; + tsbpdmode off; + tlpktdrop off; +} heartbeat { enabled on; interval 9; @@ -44,4 +50,8 @@ vhost __defaultVhost__ { # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp rtc_to_rtmp on; } + srt { + enabled on; + srt_to_rtmp on; + } } diff --git a/trunk/conf/origin2-for-proxy.conf b/trunk/conf/origin2-for-proxy.conf index ab418833c57..48f6398930f 100644 --- a/trunk/conf/origin2-for-proxy.conf +++ b/trunk/conf/origin2-for-proxy.conf @@ -19,6 +19,12 @@ rtc_server { # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate candidate $CANDIDATE; } +srt_server { + enabled on; + listen 10082; + tsbpdmode off; + tlpktdrop off; +} heartbeat { enabled on; interval 9; @@ -44,4 +50,8 @@ vhost __defaultVhost__ { # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp rtc_to_rtmp on; } + srt { + enabled on; + srt_to_rtmp on; + } } diff --git a/trunk/conf/origin3-for-proxy.conf b/trunk/conf/origin3-for-proxy.conf index 43dd214bd15..95624fb7736 100644 --- a/trunk/conf/origin3-for-proxy.conf +++ b/trunk/conf/origin3-for-proxy.conf @@ -19,6 +19,12 @@ rtc_server { # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate candidate $CANDIDATE; } +srt_server { + enabled on; + listen 10083; + tsbpdmode off; + tlpktdrop off; +} heartbeat { enabled on; interval 9; @@ -44,4 +50,8 @@ vhost __defaultVhost__ { # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp rtc_to_rtmp on; } + srt { + enabled on; + srt_to_rtmp on; + } }