Add reconnect logic to RTSP client

This commit is contained in:
Alexey Khit
2023-04-16 13:57:16 +03:00
parent 4b4deaaaf2
commit a5c4854aeb
5 changed files with 214 additions and 205 deletions
+24 -139
View File
@@ -31,10 +31,6 @@ func (c *Conn) Dial() (err error) {
c.URL.Host += ":554" c.URL.Host += ":554"
} }
// remove UserInfo from URL
c.auth = tcp.NewAuth(c.URL.User)
c.URL.User = nil
c.conn, err = net.DialTimeout("tcp", c.URL.Host, time.Second*5) c.conn, err = net.DialTimeout("tcp", c.URL.Host, time.Second*5)
if err != nil { if err != nil {
return return
@@ -56,55 +52,24 @@ func (c *Conn) Dial() (err error) {
c.conn = tlsConn c.conn = tlsConn
} }
// remove UserInfo from URL
c.auth = tcp.NewAuth(c.URL.User)
c.URL.User = nil
c.reader = bufio.NewReader(c.conn) c.reader = bufio.NewReader(c.conn)
c.session = ""
c.state = StateConn c.state = StateConn
return nil return nil
} }
// Request sends only Request // Do send WriteRequest and receive and process WriteResponse
func (c *Conn) Request(req *tcp.Request) error {
if req.Proto == "" {
req.Proto = ProtoRTSP
}
if req.Header == nil {
req.Header = make(map[string][]string)
}
c.sequence++
// important to send case sensitive CSeq
// https://github.com/AlexxIT/go2rtc/issues/7
req.Header["CSeq"] = []string{strconv.Itoa(c.sequence)}
c.auth.Write(req)
if c.Session != "" {
req.Header.Set("Session", c.Session)
}
if req.Body != nil {
val := strconv.Itoa(len(req.Body))
req.Header.Set("Content-Length", val)
}
c.Fire(req)
return req.Write(c.conn)
}
// Do send Request and receive and process Response
func (c *Conn) Do(req *tcp.Request) (*tcp.Response, error) { func (c *Conn) Do(req *tcp.Request) (*tcp.Response, error) {
// https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/ if err := c.WriteRequest(req); err != nil {
if err := c.conn.SetDeadline(time.Now().Add(Timeout)); err != nil {
return nil, err return nil, err
} }
if err := c.Request(req); err != nil { res, err := c.ReadResponse()
return nil, err
}
res, err := tcp.ReadResponse(c.reader)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -134,40 +99,6 @@ func (c *Conn) Do(req *tcp.Request) (*tcp.Response, error) {
return res, nil return res, nil
} }
func (c *Conn) Response(res *tcp.Response) error {
if res.Proto == "" {
res.Proto = ProtoRTSP
}
if res.Status == "" {
res.Status = "200 OK"
}
if res.Header == nil {
res.Header = make(map[string][]string)
}
if res.Request != nil && res.Request.Header != nil {
seq := res.Request.Header.Get("CSeq")
if seq != "" {
res.Header.Set("CSeq", seq)
}
}
if c.Session != "" {
res.Header.Set("Session", c.Session)
}
if res.Body != nil {
val := strconv.Itoa(len(res.Body))
res.Header.Set("Content-Length", val)
}
c.Fire(res)
return res.Write(c.conn)
}
func (c *Conn) Options() error { func (c *Conn) Options() error {
req := &tcp.Request{Method: MethodOptions, URL: c.URL} req := &tcp.Request{Method: MethodOptions, URL: c.URL}
@@ -219,11 +150,18 @@ func (c *Conn) Describe() error {
} }
} }
c.Medias, err = UnmarshalSDP(res.Body) medias, err := UnmarshalSDP(res.Body)
if err != nil { if err != nil {
return err return err
} }
// TODO: rewrite more smart
if c.Medias == nil {
c.Medias = medias
} else if len(c.Medias) > len(medias) {
c.Medias = c.Medias[:len(medias)]
}
c.mode = core.ModeActiveProducer c.mode = core.ModeActiveProducer
return nil return nil
@@ -250,28 +188,7 @@ func (c *Conn) Announce() (err error) {
return return
} }
func (c *Conn) Setup() error { func (c *Conn) SetupMedia(media *core.Media) (byte, error) {
for _, media := range c.Medias {
_, err := c.SetupMedia(media, true)
if err != nil {
return err
}
}
return nil
}
func (c *Conn) SetupMedia(media *core.Media, first bool) (byte, error) {
// TODO: rewrite recoonection and first flag
if first {
c.stateMu.Lock()
defer c.stateMu.Unlock()
}
if c.state != StateConn && c.state != StateSetup {
return 0, fmt.Errorf("RTSP SETUP from wrong state: %s", c.state)
}
var transport string var transport string
// try to use media position as channel number // try to use media position as channel number
@@ -311,39 +228,28 @@ func (c *Conn) SetupMedia(media *core.Media, first bool) (byte, error) {
}, },
} }
var res *tcp.Response res, err := c.Do(req)
res, err = c.Do(req)
if err != nil { if err != nil {
// some Dahua/Amcrest cameras fail here because two simultaneous // some Dahua/Amcrest cameras fail here because two simultaneous
// backchannel connections // backchannel connections
if c.Backchannel { if c.Backchannel {
_ = c.conn.Close()
c.Backchannel = false c.Backchannel = false
if err := c.Dial(); err != nil { if err = c.Reconnect(); err != nil {
return 0, err return 0, err
} }
if err := c.Describe(); err != nil { return c.SetupMedia(media)
return 0, err
}
for _, m := range c.Medias {
if m.Equal(media) {
return c.SetupMedia(m, false)
}
}
} }
return 0, err return 0, err
} }
if c.Session == "" { if c.session == "" {
// Session: 216525287999;timeout=60 // Session: 216525287999;timeout=60
if s := res.Header.Get("Session"); s != "" { if s := res.Header.Get("Session"); s != "" {
if j := strings.IndexByte(s, ';'); j > 0 { if j := strings.IndexByte(s, ';'); j > 0 {
s = s[:j] s = s[:j]
} }
c.Session = s c.session = s
} }
} }
@@ -361,8 +267,6 @@ func (c *Conn) SetupMedia(media *core.Media, first bool) (byte, error) {
} }
} }
c.state = StateSetup
channel := core.Between(transport, "interleaved=", "-") channel := core.Between(transport, "interleaved=", "-")
i, err := strconv.Atoi(channel) i, err := strconv.Atoi(channel)
if err != nil { if err != nil {
@@ -373,36 +277,17 @@ func (c *Conn) SetupMedia(media *core.Media, first bool) (byte, error) {
} }
func (c *Conn) Play() (err error) { func (c *Conn) Play() (err error) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
if c.state != StateSetup {
return fmt.Errorf("RTSP PLAY from wrong state: %s", c.state)
}
req := &tcp.Request{Method: MethodPlay, URL: c.URL} req := &tcp.Request{Method: MethodPlay, URL: c.URL}
if err = c.Request(req); err == nil { return c.WriteRequest(req)
c.state = StatePlay
}
return
} }
func (c *Conn) Teardown() (err error) { func (c *Conn) Teardown() (err error) {
// allow TEARDOWN from any state (ex. ANNOUNCE > SETUP) // allow TEARDOWN from any state (ex. ANNOUNCE > SETUP)
req := &tcp.Request{Method: MethodTeardown, URL: c.URL} req := &tcp.Request{Method: MethodTeardown, URL: c.URL}
return c.Request(req) return c.WriteRequest(req)
} }
func (c *Conn) Close() error { func (c *Conn) Close() error {
c.stateMu.Lock()
defer c.stateMu.Unlock()
if c.state == StateNone {
return nil
}
_ = c.Teardown() _ = c.Teardown()
c.state = StateNone
return c.conn.Close() return c.conn.Close()
} }
+96 -32
View File
@@ -25,7 +25,6 @@ type Conn struct {
SessionName string SessionName string
Medias []*core.Media Medias []*core.Media
Session string
UserAgent string UserAgent string
URL *url.URL URL *url.URL
@@ -34,12 +33,14 @@ type Conn struct {
auth *tcp.Auth auth *tcp.Auth
conn net.Conn conn net.Conn
mode core.Mode mode core.Mode
state State
stateMu sync.Mutex
reader *bufio.Reader reader *bufio.Reader
sequence int sequence int
session string
uri string uri string
state State
stateMu sync.Mutex
receivers []*core.Receiver receivers []*core.Receiver
senders []*core.Sender senders []*core.Sender
@@ -68,13 +69,12 @@ func (s State) String() string {
case StateNone: case StateNone:
return "NONE" return "NONE"
case StateConn: case StateConn:
return "CONN" return "CONN"
case StateSetup: case StateSetup:
return "SETUP" return "SETUP"
case StatePlay: case StatePlay:
return "PLAY" return "PLAY"
case StateHandle:
return "HANDLE"
} }
return strconv.Itoa(int(s)) return strconv.Itoa(int(s))
} }
@@ -84,31 +84,9 @@ const (
StateConn StateConn
StateSetup StateSetup
StatePlay StatePlay
StateHandle
) )
func (c *Conn) Handle() (err error) { func (c *Conn) Handle() (err error) {
c.stateMu.Lock()
switch c.state {
case StateNone: // Close after PLAY and before Handle is OK (because SETUP after PLAY)
case StatePlay:
c.state = StateHandle
default:
err = fmt.Errorf("RTSP HANDLE from wrong state: %s", c.state)
c.state = StateNone
_ = c.conn.Close()
}
ok := c.state == StateHandle
c.stateMu.Unlock()
if !ok {
return
}
var timeout time.Duration var timeout time.Duration
switch c.mode { switch c.mode {
@@ -158,7 +136,7 @@ func (c *Conn) Handle() (err error) {
switch string(buf4) { switch string(buf4) {
case "RTSP": case "RTSP":
var res *tcp.Response var res *tcp.Response
if res, err = tcp.ReadResponse(c.reader); err != nil { if res, err = c.ReadResponse(); err != nil {
return return
} }
c.Fire(res) c.Fire(res)
@@ -166,13 +144,15 @@ func (c *Conn) Handle() (err error) {
case "OPTI", "TEAR", "DESC", "SETU", "PLAY", "PAUS", "RECO", "ANNO", "GET_", "SET_": case "OPTI", "TEAR", "DESC", "SETU", "PLAY", "PAUS", "RECO", "ANNO", "GET_", "SET_":
var req *tcp.Request var req *tcp.Request
if req, err = tcp.ReadRequest(c.reader); err != nil { if req, err = c.ReadRequest(); err != nil {
return return
} }
c.Fire(req) c.Fire(req)
continue continue
default: default:
c.Fire("RTSP wrong input")
for i := 0; ; i++ { for i := 0; ; i++ {
// search next start symbol // search next start symbol
if _, err = c.reader.ReadBytes('$'); err != nil { if _, err = c.reader.ReadBytes('$'); err != nil {
@@ -204,8 +184,6 @@ func (c *Conn) Handle() (err error) {
return fmt.Errorf("RTSP wrong input") return fmt.Errorf("RTSP wrong input")
} }
} }
c.Fire("RTSP wrong input")
} }
} else { } else {
// hope that the odd channels are always RTCP // hope that the odd channels are always RTCP
@@ -259,6 +237,92 @@ func (c *Conn) Handle() (err error) {
return return
} }
func (c *Conn) WriteRequest(req *tcp.Request) error {
if req.Proto == "" {
req.Proto = ProtoRTSP
}
if req.Header == nil {
req.Header = make(map[string][]string)
}
c.sequence++
// important to send case sensitive CSeq
// https://github.com/AlexxIT/go2rtc/issues/7
req.Header["CSeq"] = []string{strconv.Itoa(c.sequence)}
c.auth.Write(req)
if c.session != "" {
req.Header.Set("Session", c.session)
}
if req.Body != nil {
val := strconv.Itoa(len(req.Body))
req.Header.Set("Content-Length", val)
}
c.Fire(req)
if err := c.conn.SetWriteDeadline(time.Now().Add(Timeout)); err != nil {
return err
}
return req.Write(c.conn)
}
func (c *Conn) ReadRequest() (*tcp.Request, error) {
if err := c.conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
return nil, err
}
return tcp.ReadRequest(c.reader)
}
func (c *Conn) WriteResponse(res *tcp.Response) error {
if res.Proto == "" {
res.Proto = ProtoRTSP
}
if res.Status == "" {
res.Status = "200 OK"
}
if res.Header == nil {
res.Header = make(map[string][]string)
}
if res.Request != nil && res.Request.Header != nil {
seq := res.Request.Header.Get("CSeq")
if seq != "" {
res.Header.Set("CSeq", seq)
}
}
if c.session != "" {
res.Header.Set("Session", c.session)
}
if res.Body != nil {
val := strconv.Itoa(len(res.Body))
res.Header.Set("Content-Length", val)
}
c.Fire(res)
if err := c.conn.SetWriteDeadline(time.Now().Add(Timeout)); err != nil {
return err
}
return res.Write(c.conn)
}
func (c *Conn) ReadResponse() (*tcp.Response, error) {
if err := c.conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
return nil, err
}
return tcp.ReadResponse(c.reader)
}
func (c *Conn) keepalive() { func (c *Conn) keepalive() {
// TODO: rewrite to RTCP // TODO: rewrite to RTCP
req := &tcp.Request{Method: MethodOptions, URL: c.URL} req := &tcp.Request{Method: MethodOptions, URL: c.URL}
@@ -267,7 +331,7 @@ func (c *Conn) keepalive() {
if c.state == StateNone { if c.state == StateNone {
return return
} }
if err := c.Request(req); err != nil { if err := c.WriteRequest(req); err != nil {
return return
} }
} }
+10 -1
View File
@@ -28,7 +28,16 @@ func (c *Conn) AddTrack(media *core.Media, codec *core.Codec, track *core.Receiv
switch c.mode { switch c.mode {
case core.ModeActiveProducer: // backchannel case core.ModeActiveProducer: // backchannel
if channel, err = c.SetupMedia(media, true); err != nil { c.stateMu.Lock()
defer c.stateMu.Unlock()
if c.state == StatePlay {
if err = c.Reconnect(); err != nil {
return
}
}
if channel, err = c.SetupMedia(media); err != nil {
return return
} }
+74 -23
View File
@@ -2,7 +2,7 @@ package rtsp
import ( import (
"encoding/json" "encoding/json"
"fmt" "errors"
"github.com/AlexxIT/go2rtc/pkg/core" "github.com/AlexxIT/go2rtc/pkg/core"
) )
@@ -15,51 +15,78 @@ func (c *Conn) GetTrack(media *core.Media, codec *core.Codec) (*core.Receiver, e
} }
} }
switch c.state { c.stateMu.Lock()
case StateConn, StateSetup: defer c.stateMu.Unlock()
default:
return nil, fmt.Errorf("RTSP GetTrack from wrong state: %s", c.state) if c.state == StatePlay {
if err := c.Reconnect(); err != nil {
return nil, err
}
} }
channel, err := c.SetupMedia(media, true) channel, err := c.SetupMedia(media)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.state = StateSetup
track := core.NewReceiver(media, codec) track := core.NewReceiver(media, codec)
track.ID = byte(channel) track.ID = channel
c.receivers = append(c.receivers, track) c.receivers = append(c.receivers, track)
return track, nil return track, nil
} }
func (c *Conn) Start() error { func (c *Conn) Start() (err error) {
switch c.mode { core.Assert(c.mode == core.ModeActiveProducer || c.mode == core.ModePassiveProducer)
case core.ModeActiveProducer:
if err := c.Play(); err != nil { for {
return err ok := false
c.stateMu.Lock()
switch c.state {
case StateNone:
err = nil
case StateConn:
err = errors.New("start from CONN state")
case StateSetup:
if err = c.Play(); err == nil {
c.state = StatePlay
ok = true
}
case StatePlay:
} }
case core.ModePassiveProducer: c.stateMu.Unlock()
default:
return fmt.Errorf("start wrong mode: %d", c.mode)
}
if err := c.Handle(); c.state != StateNone { if !ok {
_ = c.conn.Close() return
return err }
}
return nil // Handler can return different states:
// 1. None after PLAY should exit without error
// 2. Play after PLAY should exit from Start with error
// 3. Setup after PLAY should Play once again
err = c.Handle()
}
} }
func (c *Conn) Stop() error { func (c *Conn) Stop() (err error) {
for _, receiver := range c.receivers { for _, receiver := range c.receivers {
receiver.Close() receiver.Close()
} }
for _, sender := range c.senders { for _, sender := range c.senders {
sender.Close() sender.Close()
} }
return c.Close()
c.stateMu.Lock()
if c.state != StateNone {
c.state = StateNone
err = c.Close()
}
c.stateMu.Unlock()
return
} }
func (c *Conn) MarshalJSON() ([]byte, error) { func (c *Conn) MarshalJSON() ([]byte, error) {
@@ -82,3 +109,27 @@ func (c *Conn) MarshalJSON() ([]byte, error) {
return json.Marshal(info) return json.Marshal(info)
} }
func (c *Conn) Reconnect() error {
c.Fire("RTSP reconnect")
// close current session
_ = c.Close()
// start new session
if err := c.Dial(); err != nil {
return err
}
if err := c.Describe(); err != nil {
return err
}
// restore previous medias
for _, receiver := range c.receivers {
if _, err := c.SetupMedia(receiver.Media); err != nil {
return err
}
}
return nil
}
+10 -10
View File
@@ -25,7 +25,7 @@ func (c *Conn) Auth(username, password string) {
func (c *Conn) Accept() error { func (c *Conn) Accept() error {
for { for {
req, err := tcp.ReadRequest(c.reader) req, err := c.ReadRequest()
if err != nil { if err != nil {
return err return err
} }
@@ -42,7 +42,7 @@ func (c *Conn) Accept() error {
Status: "401 Unauthorized", Status: "401 Unauthorized",
Header: map[string][]string{"Www-Authenticate": {`Basic realm="go2rtc"`}}, Header: map[string][]string{"Www-Authenticate": {`Basic realm="go2rtc"`}},
} }
if err = c.Response(res); err != nil { if err = c.WriteResponse(res); err != nil {
return err return err
} }
continue continue
@@ -58,7 +58,7 @@ func (c *Conn) Accept() error {
}, },
Request: req, Request: req,
} }
if err = c.Response(res); err != nil { if err = c.WriteResponse(res); err != nil {
return err return err
} }
@@ -83,7 +83,7 @@ func (c *Conn) Accept() error {
c.Fire(MethodAnnounce) c.Fire(MethodAnnounce)
res := &tcp.Response{Request: req} res := &tcp.Response{Request: req}
if err = c.Response(res); err != nil { if err = c.WriteResponse(res); err != nil {
return err return err
} }
@@ -96,7 +96,7 @@ func (c *Conn) Accept() error {
Status: "404 Not Found", Status: "404 Not Found",
Request: req, Request: req,
} }
return c.Response(res) return c.WriteResponse(res)
} }
res := &tcp.Response{ res := &tcp.Response{
@@ -122,7 +122,7 @@ func (c *Conn) Accept() error {
return err return err
} }
if err = c.Response(res); err != nil { if err = c.WriteResponse(res); err != nil {
return err return err
} }
@@ -136,27 +136,27 @@ func (c *Conn) Accept() error {
const transport = "RTP/AVP/TCP;unicast;interleaved=" const transport = "RTP/AVP/TCP;unicast;interleaved="
if strings.HasPrefix(tr, transport) { if strings.HasPrefix(tr, transport) {
c.Session = core.RandString(8, 10) c.session = core.RandString(8, 10)
c.state = StateSetup c.state = StateSetup
res.Header.Set("Transport", tr[:len(transport)+3]) res.Header.Set("Transport", tr[:len(transport)+3])
} else { } else {
res.Status = "461 Unsupported transport" res.Status = "461 Unsupported transport"
} }
if err = c.Response(res); err != nil { if err = c.WriteResponse(res); err != nil {
return err return err
} }
case MethodRecord, MethodPlay: case MethodRecord, MethodPlay:
res := &tcp.Response{Request: req} res := &tcp.Response{Request: req}
if err = c.Response(res); err == nil { if err = c.WriteResponse(res); err == nil {
c.state = StatePlay c.state = StatePlay
} }
return err return err
case MethodTeardown: case MethodTeardown:
res := &tcp.Response{Request: req} res := &tcp.Response{Request: req}
_ = c.Response(res) _ = c.WriteResponse(res)
c.state = StateNone c.state = StateNone
return c.conn.Close() return c.conn.Close()