From 8b126c0d377623e2f9785c927a38245a690f992e Mon Sep 17 00:00:00 2001 From: Alexey Khit Date: Sat, 6 May 2023 14:31:46 +0300 Subject: [PATCH] Add support RTSP over WebSocket --- internal/rtsp/rtsp.go | 25 ++++--- pkg/rtsp/client.go | 30 ++------- pkg/rtsp/conn.go | 1 + pkg/rtsp/dial.go | 44 ++++++++++++ pkg/tcp/websocket/client.go | 130 ++++++++++++++++++++++++++++++++++++ pkg/tcp/websocket/dial.go | 64 ++++++++++++++++++ 6 files changed, 258 insertions(+), 36 deletions(-) create mode 100644 pkg/rtsp/dial.go create mode 100644 pkg/tcp/websocket/client.go create mode 100644 pkg/tcp/websocket/dial.go diff --git a/internal/rtsp/rtsp.go b/internal/rtsp/rtsp.go index f50337f4..9d1234d5 100644 --- a/internal/rtsp/rtsp.go +++ b/internal/rtsp/rtsp.go @@ -91,19 +91,19 @@ var log zerolog.Logger var handlers []Handler var defaultMedias []*core.Media -func rtspHandler(url string) (core.Producer, error) { - backchannel := true +func rtspHandler(rawURL string) (core.Producer, error) { + rawURL, rawQuery, _ := strings.Cut(rawURL, "#") - if i := strings.IndexByte(url, '#'); i > 0 { - if url[i+1:] == "backchannel=0" { - backchannel = false - } - url = url[:i] - } - - conn := rtsp.NewClient(url) + conn := rtsp.NewClient(rawURL) + conn.Backchannel = true conn.UserAgent = app.UserAgent + if rawQuery != "" { + query := streams.ParseQuery(rawQuery) + conn.Backchannel = query.Get("backchannel") == "1" + conn.Transport = query.Get("transport") + } + if log.Trace().Enabled() { conn.Listen(func(msg any) { switch msg := msg.(type) { @@ -121,12 +121,11 @@ func rtspHandler(url string) (core.Producer, error) { return nil, err } - conn.Backchannel = backchannel if err := conn.Describe(); err != nil { - if !backchannel { + if !conn.Backchannel { return nil, err } - log.Trace().Msgf("[rtsp] describe (backchannel=%t) err: %v", backchannel, err) + log.Trace().Msgf("[rtsp] describe (backchannel=%t) err: %v", conn.Backchannel, err) // second try without backchannel, we need to reconnect conn.Backchannel = false diff --git a/pkg/rtsp/client.go b/pkg/rtsp/client.go index b5f3db6b..a4a3b656 100644 --- a/pkg/rtsp/client.go +++ b/pkg/rtsp/client.go @@ -2,10 +2,9 @@ package rtsp import ( "bufio" - "crypto/tls" "errors" "fmt" - "net" + "github.com/AlexxIT/go2rtc/pkg/tcp/websocket" "net/http" "net/url" "strconv" @@ -23,33 +22,18 @@ func NewClient(uri string) *Conn { } func (c *Conn) Dial() (err error) { - if c.URL, err = url.Parse(c.uri); err != nil { - return + if c.Transport == "" { + c.conn, err = Dial(c.uri) + } else { + c.conn, err = websocket.Dial(c.Transport) } - if strings.IndexByte(c.URL.Host, ':') < 0 { - c.URL.Host += ":554" - } - - c.conn, err = net.DialTimeout("tcp", c.URL.Host, time.Second*5) if err != nil { return } - var tlsConf *tls.Config - switch c.URL.Scheme { - case "rtsps": - tlsConf = &tls.Config{ServerName: c.URL.Hostname()} - case "rtspx": - c.URL.Scheme = "rtsps" - tlsConf = &tls.Config{InsecureSkipVerify: true} - } - if tlsConf != nil { - tlsConn := tls.Client(c.conn, tlsConf) - if err = tlsConn.Handshake(); err != nil { - return err - } - c.conn = tlsConn + if c.URL, err = url.Parse(c.uri); err != nil { + return } // remove UserInfo from URL diff --git a/pkg/rtsp/conn.go b/pkg/rtsp/conn.go index 2bdb91bb..9b23087a 100644 --- a/pkg/rtsp/conn.go +++ b/pkg/rtsp/conn.go @@ -24,6 +24,7 @@ type Conn struct { Backchannel bool PacketSize uint16 SessionName string + Transport string // custom transport support, ex. RTSP over WebSocket Medias []*core.Media UserAgent string diff --git a/pkg/rtsp/dial.go b/pkg/rtsp/dial.go new file mode 100644 index 00000000..58d5dd65 --- /dev/null +++ b/pkg/rtsp/dial.go @@ -0,0 +1,44 @@ +package rtsp + +import ( + "crypto/tls" + "errors" + "net" + "net/url" + "strings" + "time" +) + +func Dial(uri string) (net.Conn, error) { + u, err := url.Parse(uri) + if err != nil { + return nil, err + } + + switch u.Scheme { + case "rtsp": + return dialTCP(u.Host, nil) + case "rtsps": + tlsConf := &tls.Config{ServerName: u.Hostname()} + return dialTCP(u.Host, tlsConf) + case "rtspx": + tlsConf := &tls.Config{InsecureSkipVerify: true} + return dialTCP(u.Host, tlsConf) + } + + return nil, errors.New("unsupported scheme: " + u.Scheme) +} + +func dialTCP(address string, tlsConf *tls.Config) (net.Conn, error) { + if strings.IndexByte(address, ':') < 0 { + address += ":554" + } + + conn, err := net.DialTimeout("tcp", address, time.Second*5) + if tlsConf == nil || err != nil { + return conn, err + } + + tlsConn := tls.Client(conn, tlsConf) + return tlsConn, tlsConn.Handshake() +} diff --git a/pkg/tcp/websocket/client.go b/pkg/tcp/websocket/client.go new file mode 100644 index 00000000..e95ce1e4 --- /dev/null +++ b/pkg/tcp/websocket/client.go @@ -0,0 +1,130 @@ +package websocket + +import ( + cryptorand "crypto/rand" + "encoding/binary" + "fmt" + "io" + "net" + "time" +) + +const BinaryMessage = 2 + +type Client struct { + conn net.Conn + remain int +} + +func NewClient(conn net.Conn) *Client { + return &Client{conn: conn} +} + +const finalBit = 0x80 +const maskBit = 0x80 + +func (w *Client) Read(b []byte) (n int, err error) { + if w.remain == 0 { + b2 := make([]byte, 2) + if _, err = io.ReadFull(w.conn, b2); err != nil { + return 0, err + } + + frameType := b2[0] & 0xF + w.remain = int(b2[1] & 0x7F) + + switch frameType { + case BinaryMessage: + default: + return 0, fmt.Errorf("unsupported frame type: %d", frameType) + } + + switch w.remain { + case 126: + if _, err = io.ReadFull(w.conn, b2); err != nil { + return 0, err + } + w.remain = int(binary.BigEndian.Uint16(b2)) + case 127: + b8 := make([]byte, 8) + if _, err = io.ReadFull(w.conn, b8); err != nil { + return 0, err + } + w.remain = int(binary.BigEndian.Uint64(b8)) + } + } + + if w.remain > len(b) { + n, err = io.ReadFull(w.conn, b) + w.remain -= n + return + } + + n, err = io.ReadFull(w.conn, b[:w.remain]) + w.remain = 0 + + return +} + +func (w *Client) Write(b []byte) (n int, err error) { + var data []byte + var start byte + + size := len(b) + + switch { + case size > 65535: + start = 10 + data = make([]byte, size+14) + data[1] = maskBit | 127 + binary.BigEndian.PutUint64(data[2:], uint64(size)) + case size > 125: + start = 4 + data = make([]byte, size+8) + data[1] = maskBit | 126 + binary.BigEndian.PutUint16(data[2:], uint16(size)) + default: + start = 2 + data = make([]byte, size+6) + data[1] = maskBit | byte(size) + } + + data[0] = BinaryMessage | finalBit + + mask := data[start : start+4] + msg := data[start+4:] + + if _, err = cryptorand.Read(mask); err != nil { + return 0, err + } + + for i := 0; i < len(b); i++ { + msg[i] = b[i] ^ mask[i%4] + } + + return w.conn.Write(data) +} + +func (w *Client) Close() error { + return w.conn.Close() +} + +func (w *Client) LocalAddr() net.Addr { + return w.conn.LocalAddr() +} + +func (w *Client) RemoteAddr() net.Addr { + return w.conn.RemoteAddr() +} + +func (w *Client) SetDeadline(t time.Time) error { + return w.conn.SetDeadline(t) +} + +func (w *Client) SetReadDeadline(t time.Time) error { + return w.conn.SetReadDeadline(t) +} + +func (w *Client) SetWriteDeadline(t time.Time) error { + return w.conn.SetWriteDeadline(t) +} diff --git a/pkg/tcp/websocket/dial.go b/pkg/tcp/websocket/dial.go new file mode 100644 index 00000000..737a5cbc --- /dev/null +++ b/pkg/tcp/websocket/dial.go @@ -0,0 +1,64 @@ +package websocket + +import ( + cryptorand "crypto/rand" + "crypto/sha1" + "encoding/base64" + "errors" + "github.com/AlexxIT/go2rtc/pkg/tcp" + "net" + "net/http" + "strings" +) + +func Dial(address string) (net.Conn, error) { + if strings.HasPrefix(address, "ws") { + address = "http" + address[2:] // support http and https + } + + // using custom client for support Digest Auth + // https://github.com/AlexxIT/go2rtc/issues/415 + ctx, pconn := tcp.WithConn() + + req, err := http.NewRequestWithContext(ctx, "GET", address, nil) + if err != nil { + return nil, err + } + + key, accept := GetKeyAccept() + + // Version, Key, Protocol important for Axis cameras + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-WebSocket-Key", key) + req.Header.Set("Sec-WebSocket-Protocol", "binary") + + res, err := tcp.Do(req) + if err != nil { + return nil, err + } + + if res.StatusCode != http.StatusSwitchingProtocols { + return nil, errors.New("wrong status: " + res.Status) + } + + if res.Header.Get("Sec-Websocket-Accept") != accept { + return nil, errors.New("wrong websocket accept") + } + + return NewClient(*pconn), nil +} + +func GetKeyAccept() (key, accept string) { + b := make([]byte, 16) + _, _ = cryptorand.Read(b) + key = base64.StdEncoding.EncodeToString(b) + + h := sha1.New() + h.Write([]byte(key)) + h.Write([]byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) + accept = base64.StdEncoding.EncodeToString(h.Sum(nil)) + + return +}