From a5c4854aeb8ee5ec9c3d7f36bedc562dbefd8f41 Mon Sep 17 00:00:00 2001 From: Alexey Khit Date: Sun, 16 Apr 2023 13:57:16 +0300 Subject: [PATCH] Add reconnect logic to RTSP client --- pkg/rtsp/client.go | 163 +++++++------------------------------------ pkg/rtsp/conn.go | 128 ++++++++++++++++++++++++--------- pkg/rtsp/consumer.go | 11 ++- pkg/rtsp/producer.go | 97 +++++++++++++++++++------ pkg/rtsp/server.go | 20 +++--- 5 files changed, 214 insertions(+), 205 deletions(-) diff --git a/pkg/rtsp/client.go b/pkg/rtsp/client.go index 304203ae..4ed228f0 100644 --- a/pkg/rtsp/client.go +++ b/pkg/rtsp/client.go @@ -31,10 +31,6 @@ func (c *Conn) Dial() (err error) { c.URL.Host += ":554" } - // remove UserInfo from URL - c.auth = tcp.NewAuth(c.URL.User) - c.URL.User = nil - c.conn, err = net.DialTimeout("tcp", c.URL.Host, time.Second*5) if err != nil { return @@ -56,55 +52,24 @@ func (c *Conn) Dial() (err error) { c.conn = tlsConn } + // remove UserInfo from URL + c.auth = tcp.NewAuth(c.URL.User) + c.URL.User = nil + c.reader = bufio.NewReader(c.conn) + c.session = "" c.state = StateConn return nil } -// Request sends only Request -func (c *Conn) Request(req *tcp.Request) error { - if req.Proto == "" { - req.Proto = ProtoRTSP - } - - if req.Header == nil { - req.Header = make(map[string][]string) - } - - c.sequence++ - // important to send case sensitive CSeq - // https://github.com/AlexxIT/go2rtc/issues/7 - req.Header["CSeq"] = []string{strconv.Itoa(c.sequence)} - - c.auth.Write(req) - - if c.Session != "" { - req.Header.Set("Session", c.Session) - } - - if req.Body != nil { - val := strconv.Itoa(len(req.Body)) - req.Header.Set("Content-Length", val) - } - - c.Fire(req) - - return req.Write(c.conn) -} - -// Do send Request and receive and process Response +// Do send WriteRequest and receive and process WriteResponse func (c *Conn) Do(req *tcp.Request) (*tcp.Response, error) { - // https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/ - if err := c.conn.SetDeadline(time.Now().Add(Timeout)); err != nil { + if err := c.WriteRequest(req); err != nil { return nil, err } - if err := c.Request(req); err != nil { - return nil, err - } - - res, err := tcp.ReadResponse(c.reader) + res, err := c.ReadResponse() if err != nil { return nil, err } @@ -134,40 +99,6 @@ func (c *Conn) Do(req *tcp.Request) (*tcp.Response, error) { return res, nil } -func (c *Conn) Response(res *tcp.Response) error { - if res.Proto == "" { - res.Proto = ProtoRTSP - } - - if res.Status == "" { - res.Status = "200 OK" - } - - if res.Header == nil { - res.Header = make(map[string][]string) - } - - if res.Request != nil && res.Request.Header != nil { - seq := res.Request.Header.Get("CSeq") - if seq != "" { - res.Header.Set("CSeq", seq) - } - } - - if c.Session != "" { - res.Header.Set("Session", c.Session) - } - - if res.Body != nil { - val := strconv.Itoa(len(res.Body)) - res.Header.Set("Content-Length", val) - } - - c.Fire(res) - - return res.Write(c.conn) -} - func (c *Conn) Options() error { req := &tcp.Request{Method: MethodOptions, URL: c.URL} @@ -219,11 +150,18 @@ func (c *Conn) Describe() error { } } - c.Medias, err = UnmarshalSDP(res.Body) + medias, err := UnmarshalSDP(res.Body) if err != nil { return err } + // TODO: rewrite more smart + if c.Medias == nil { + c.Medias = medias + } else if len(c.Medias) > len(medias) { + c.Medias = c.Medias[:len(medias)] + } + c.mode = core.ModeActiveProducer return nil @@ -250,28 +188,7 @@ func (c *Conn) Announce() (err error) { return } -func (c *Conn) Setup() error { - for _, media := range c.Medias { - _, err := c.SetupMedia(media, true) - if err != nil { - return err - } - } - - return nil -} - -func (c *Conn) SetupMedia(media *core.Media, first bool) (byte, error) { - // TODO: rewrite recoonection and first flag - if first { - c.stateMu.Lock() - defer c.stateMu.Unlock() - } - - if c.state != StateConn && c.state != StateSetup { - return 0, fmt.Errorf("RTSP SETUP from wrong state: %s", c.state) - } - +func (c *Conn) SetupMedia(media *core.Media) (byte, error) { var transport string // try to use media position as channel number @@ -311,39 +228,28 @@ func (c *Conn) SetupMedia(media *core.Media, first bool) (byte, error) { }, } - var res *tcp.Response - res, err = c.Do(req) + res, err := c.Do(req) if err != nil { // some Dahua/Amcrest cameras fail here because two simultaneous // backchannel connections if c.Backchannel { - _ = c.conn.Close() - c.Backchannel = false - if err := c.Dial(); err != nil { + if err = c.Reconnect(); err != nil { return 0, err } - if err := c.Describe(); err != nil { - return 0, err - } - - for _, m := range c.Medias { - if m.Equal(media) { - return c.SetupMedia(m, false) - } - } + return c.SetupMedia(media) } return 0, err } - if c.Session == "" { + if c.session == "" { // Session: 216525287999;timeout=60 if s := res.Header.Get("Session"); s != "" { if j := strings.IndexByte(s, ';'); j > 0 { s = s[:j] } - c.Session = s + c.session = s } } @@ -361,8 +267,6 @@ func (c *Conn) SetupMedia(media *core.Media, first bool) (byte, error) { } } - c.state = StateSetup - channel := core.Between(transport, "interleaved=", "-") i, err := strconv.Atoi(channel) if err != nil { @@ -373,36 +277,17 @@ func (c *Conn) SetupMedia(media *core.Media, first bool) (byte, error) { } func (c *Conn) Play() (err error) { - c.stateMu.Lock() - defer c.stateMu.Unlock() - - if c.state != StateSetup { - return fmt.Errorf("RTSP PLAY from wrong state: %s", c.state) - } - req := &tcp.Request{Method: MethodPlay, URL: c.URL} - if err = c.Request(req); err == nil { - c.state = StatePlay - } - - return + return c.WriteRequest(req) } func (c *Conn) Teardown() (err error) { // allow TEARDOWN from any state (ex. ANNOUNCE > SETUP) req := &tcp.Request{Method: MethodTeardown, URL: c.URL} - return c.Request(req) + return c.WriteRequest(req) } func (c *Conn) Close() error { - c.stateMu.Lock() - defer c.stateMu.Unlock() - - if c.state == StateNone { - return nil - } - _ = c.Teardown() - c.state = StateNone return c.conn.Close() } diff --git a/pkg/rtsp/conn.go b/pkg/rtsp/conn.go index 2a0add62..ebc127bc 100644 --- a/pkg/rtsp/conn.go +++ b/pkg/rtsp/conn.go @@ -25,7 +25,6 @@ type Conn struct { SessionName string Medias []*core.Media - Session string UserAgent string URL *url.URL @@ -34,12 +33,14 @@ type Conn struct { auth *tcp.Auth conn net.Conn mode core.Mode - state State - stateMu sync.Mutex reader *bufio.Reader sequence int + session string uri string + state State + stateMu sync.Mutex + receivers []*core.Receiver senders []*core.Sender @@ -68,13 +69,12 @@ func (s State) String() string { case StateNone: return "NONE" case StateConn: + return "CONN" case StateSetup: return "SETUP" case StatePlay: return "PLAY" - case StateHandle: - return "HANDLE" } return strconv.Itoa(int(s)) } @@ -84,31 +84,9 @@ const ( StateConn StateSetup StatePlay - StateHandle ) func (c *Conn) Handle() (err error) { - c.stateMu.Lock() - - switch c.state { - case StateNone: // Close after PLAY and before Handle is OK (because SETUP after PLAY) - case StatePlay: - c.state = StateHandle - default: - err = fmt.Errorf("RTSP HANDLE from wrong state: %s", c.state) - - c.state = StateNone - _ = c.conn.Close() - } - - ok := c.state == StateHandle - - c.stateMu.Unlock() - - if !ok { - return - } - var timeout time.Duration switch c.mode { @@ -158,7 +136,7 @@ func (c *Conn) Handle() (err error) { switch string(buf4) { case "RTSP": var res *tcp.Response - if res, err = tcp.ReadResponse(c.reader); err != nil { + if res, err = c.ReadResponse(); err != nil { return } c.Fire(res) @@ -166,13 +144,15 @@ func (c *Conn) Handle() (err error) { case "OPTI", "TEAR", "DESC", "SETU", "PLAY", "PAUS", "RECO", "ANNO", "GET_", "SET_": var req *tcp.Request - if req, err = tcp.ReadRequest(c.reader); err != nil { + if req, err = c.ReadRequest(); err != nil { return } c.Fire(req) continue default: + c.Fire("RTSP wrong input") + for i := 0; ; i++ { // search next start symbol if _, err = c.reader.ReadBytes('$'); err != nil { @@ -204,8 +184,6 @@ func (c *Conn) Handle() (err error) { return fmt.Errorf("RTSP wrong input") } } - - c.Fire("RTSP wrong input") } } else { // hope that the odd channels are always RTCP @@ -259,6 +237,92 @@ func (c *Conn) Handle() (err error) { return } +func (c *Conn) WriteRequest(req *tcp.Request) error { + if req.Proto == "" { + req.Proto = ProtoRTSP + } + + if req.Header == nil { + req.Header = make(map[string][]string) + } + + c.sequence++ + // important to send case sensitive CSeq + // https://github.com/AlexxIT/go2rtc/issues/7 + req.Header["CSeq"] = []string{strconv.Itoa(c.sequence)} + + c.auth.Write(req) + + if c.session != "" { + req.Header.Set("Session", c.session) + } + + if req.Body != nil { + val := strconv.Itoa(len(req.Body)) + req.Header.Set("Content-Length", val) + } + + c.Fire(req) + + if err := c.conn.SetWriteDeadline(time.Now().Add(Timeout)); err != nil { + return err + } + + return req.Write(c.conn) +} + +func (c *Conn) ReadRequest() (*tcp.Request, error) { + if err := c.conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil { + return nil, err + } + return tcp.ReadRequest(c.reader) +} + +func (c *Conn) WriteResponse(res *tcp.Response) error { + if res.Proto == "" { + res.Proto = ProtoRTSP + } + + if res.Status == "" { + res.Status = "200 OK" + } + + if res.Header == nil { + res.Header = make(map[string][]string) + } + + if res.Request != nil && res.Request.Header != nil { + seq := res.Request.Header.Get("CSeq") + if seq != "" { + res.Header.Set("CSeq", seq) + } + } + + if c.session != "" { + res.Header.Set("Session", c.session) + } + + if res.Body != nil { + val := strconv.Itoa(len(res.Body)) + res.Header.Set("Content-Length", val) + } + + c.Fire(res) + + if err := c.conn.SetWriteDeadline(time.Now().Add(Timeout)); err != nil { + return err + } + + return res.Write(c.conn) +} + +func (c *Conn) ReadResponse() (*tcp.Response, error) { + if err := c.conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil { + return nil, err + } + return tcp.ReadResponse(c.reader) +} + func (c *Conn) keepalive() { // TODO: rewrite to RTCP req := &tcp.Request{Method: MethodOptions, URL: c.URL} @@ -267,7 +331,7 @@ func (c *Conn) keepalive() { if c.state == StateNone { return } - if err := c.Request(req); err != nil { + if err := c.WriteRequest(req); err != nil { return } } diff --git a/pkg/rtsp/consumer.go b/pkg/rtsp/consumer.go index b0eaf7ce..7f7fece2 100644 --- a/pkg/rtsp/consumer.go +++ b/pkg/rtsp/consumer.go @@ -28,7 +28,16 @@ func (c *Conn) AddTrack(media *core.Media, codec *core.Codec, track *core.Receiv switch c.mode { case core.ModeActiveProducer: // backchannel - if channel, err = c.SetupMedia(media, true); err != nil { + c.stateMu.Lock() + defer c.stateMu.Unlock() + + if c.state == StatePlay { + if err = c.Reconnect(); err != nil { + return + } + } + + if channel, err = c.SetupMedia(media); err != nil { return } diff --git a/pkg/rtsp/producer.go b/pkg/rtsp/producer.go index ea7aa3ea..f7772c54 100644 --- a/pkg/rtsp/producer.go +++ b/pkg/rtsp/producer.go @@ -2,7 +2,7 @@ package rtsp import ( "encoding/json" - "fmt" + "errors" "github.com/AlexxIT/go2rtc/pkg/core" ) @@ -15,51 +15,78 @@ func (c *Conn) GetTrack(media *core.Media, codec *core.Codec) (*core.Receiver, e } } - switch c.state { - case StateConn, StateSetup: - default: - return nil, fmt.Errorf("RTSP GetTrack from wrong state: %s", c.state) + c.stateMu.Lock() + defer c.stateMu.Unlock() + + if c.state == StatePlay { + if err := c.Reconnect(); err != nil { + return nil, err + } } - channel, err := c.SetupMedia(media, true) + channel, err := c.SetupMedia(media) if err != nil { return nil, err } + c.state = StateSetup + track := core.NewReceiver(media, codec) - track.ID = byte(channel) + track.ID = channel c.receivers = append(c.receivers, track) return track, nil } -func (c *Conn) Start() error { - switch c.mode { - case core.ModeActiveProducer: - if err := c.Play(); err != nil { - return err +func (c *Conn) Start() (err error) { + core.Assert(c.mode == core.ModeActiveProducer || c.mode == core.ModePassiveProducer) + + for { + ok := false + + c.stateMu.Lock() + switch c.state { + case StateNone: + err = nil + case StateConn: + err = errors.New("start from CONN state") + case StateSetup: + if err = c.Play(); err == nil { + c.state = StatePlay + ok = true + } + case StatePlay: } - case core.ModePassiveProducer: - default: - return fmt.Errorf("start wrong mode: %d", c.mode) - } + c.stateMu.Unlock() - if err := c.Handle(); c.state != StateNone { - _ = c.conn.Close() - return err - } + if !ok { + return + } - return nil + // Handler can return different states: + // 1. None after PLAY should exit without error + // 2. Play after PLAY should exit from Start with error + // 3. Setup after PLAY should Play once again + err = c.Handle() + } } -func (c *Conn) Stop() error { +func (c *Conn) Stop() (err error) { for _, receiver := range c.receivers { receiver.Close() } for _, sender := range c.senders { sender.Close() } - return c.Close() + + c.stateMu.Lock() + if c.state != StateNone { + c.state = StateNone + err = c.Close() + } + c.stateMu.Unlock() + + return } func (c *Conn) MarshalJSON() ([]byte, error) { @@ -82,3 +109,27 @@ func (c *Conn) MarshalJSON() ([]byte, error) { return json.Marshal(info) } + +func (c *Conn) Reconnect() error { + c.Fire("RTSP reconnect") + + // close current session + _ = c.Close() + + // start new session + if err := c.Dial(); err != nil { + return err + } + if err := c.Describe(); err != nil { + return err + } + + // restore previous medias + for _, receiver := range c.receivers { + if _, err := c.SetupMedia(receiver.Media); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/rtsp/server.go b/pkg/rtsp/server.go index d2be609c..f707f728 100644 --- a/pkg/rtsp/server.go +++ b/pkg/rtsp/server.go @@ -25,7 +25,7 @@ func (c *Conn) Auth(username, password string) { func (c *Conn) Accept() error { for { - req, err := tcp.ReadRequest(c.reader) + req, err := c.ReadRequest() if err != nil { return err } @@ -42,7 +42,7 @@ func (c *Conn) Accept() error { Status: "401 Unauthorized", Header: map[string][]string{"Www-Authenticate": {`Basic realm="go2rtc"`}}, } - if err = c.Response(res); err != nil { + if err = c.WriteResponse(res); err != nil { return err } continue @@ -58,7 +58,7 @@ func (c *Conn) Accept() error { }, Request: req, } - if err = c.Response(res); err != nil { + if err = c.WriteResponse(res); err != nil { return err } @@ -83,7 +83,7 @@ func (c *Conn) Accept() error { c.Fire(MethodAnnounce) res := &tcp.Response{Request: req} - if err = c.Response(res); err != nil { + if err = c.WriteResponse(res); err != nil { return err } @@ -96,7 +96,7 @@ func (c *Conn) Accept() error { Status: "404 Not Found", Request: req, } - return c.Response(res) + return c.WriteResponse(res) } res := &tcp.Response{ @@ -122,7 +122,7 @@ func (c *Conn) Accept() error { return err } - if err = c.Response(res); err != nil { + if err = c.WriteResponse(res); err != nil { return err } @@ -136,27 +136,27 @@ func (c *Conn) Accept() error { const transport = "RTP/AVP/TCP;unicast;interleaved=" if strings.HasPrefix(tr, transport) { - c.Session = core.RandString(8, 10) + c.session = core.RandString(8, 10) c.state = StateSetup res.Header.Set("Transport", tr[:len(transport)+3]) } else { res.Status = "461 Unsupported transport" } - if err = c.Response(res); err != nil { + if err = c.WriteResponse(res); err != nil { return err } case MethodRecord, MethodPlay: res := &tcp.Response{Request: req} - if err = c.Response(res); err == nil { + if err = c.WriteResponse(res); err == nil { c.state = StatePlay } return err case MethodTeardown: res := &tcp.Response{Request: req} - _ = c.Response(res) + _ = c.WriteResponse(res) c.state = StateNone return c.conn.Close()