Add reconnection feature

This commit is contained in:
Alexey Khit
2022-11-04 17:18:12 +03:00
parent 0231fc3a90
commit e9ea7a0b1f
6 changed files with 216 additions and 135 deletions
+22 -13
View File
@@ -14,6 +14,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
"sync"
"time" "time"
) )
@@ -23,22 +24,22 @@ func Init() {
return return
} }
rtsp.OnProducer = func(prod streamer.Producer) bool { rtsp.HandleFunc(func(conn *pkg.Conn) bool {
if conn := prod.(*pkg.Conn); conn != nil { waitersMu.Lock()
if waiter := waiters[conn.URL.Path]; waiter != nil { waiter := waiters[conn.URL.Path]
waiter <- prod waitersMu.Unlock()
return true
} if waiter == nil {
return false
} }
return false
} waiter <- conn
return true
})
streams.HandleFunc("exec", Handle) streams.HandleFunc("exec", Handle)
log = app.GetLogger("exec") log = app.GetLogger("exec")
// TODO: add sync.Mutex
waiters = map[string]chan streamer.Producer{}
} }
func Handle(url string) (streamer.Producer, error) { func Handle(url string) (streamer.Producer, error) {
@@ -60,8 +61,15 @@ func Handle(url string) (streamer.Producer, error) {
ch := make(chan streamer.Producer) ch := make(chan streamer.Producer)
waitersMu.Lock()
waiters[path] = ch waiters[path] = ch
defer delete(waiters, path) waitersMu.Unlock()
defer func() {
waitersMu.Lock()
delete(waiters, path)
waitersMu.Unlock()
}()
log.Debug().Str("url", url).Msg("[exec] run") log.Debug().Str("url", url).Msg("[exec] run")
@@ -86,4 +94,5 @@ func Handle(url string) (streamer.Producer, error) {
// internal // internal
var log zerolog.Logger var log zerolog.Logger
var waiters map[string]chan streamer.Producer var waiters = map[string]chan streamer.Producer{}
var waitersMu sync.Mutex
+98 -87
View File
@@ -32,20 +32,43 @@ func Init() {
// RTSP server support // RTSP server support
address := conf.Mod.Listen address := conf.Mod.Listen
if address != "" { if address == "" {
_, Port, _ = net.SplitHostPort(address) return
go worker(address)
} }
ln, err := net.Listen("tcp", address)
if err != nil {
log.Error().Err(err).Msg("[rtsp] listen")
return
}
_, Port, _ = net.SplitHostPort(address)
log.Info().Str("addr", address).Msg("[rtsp] listen")
go func() {
for {
conn, err := ln.Accept()
if err != nil {
return
}
go tcpHandler(conn)
}
}()
}
type Handler func(conn *rtsp.Conn) bool
func HandleFunc(handler Handler) {
handlers = append(handlers, handler)
} }
var Port string var Port string
var OnProducer func(conn streamer.Producer) bool // TODO: maybe rewrite...
// internal // internal
var log zerolog.Logger var log zerolog.Logger
var handlers []Handler
func rtspHandler(url string) (streamer.Producer, error) { func rtspHandler(url string) (streamer.Producer, error) {
backchannel := true backchannel := true
@@ -96,101 +119,89 @@ func rtspHandler(url string) (streamer.Producer, error) {
return conn, nil return conn, nil
} }
func worker(address string) { func tcpHandler(c net.Conn) {
srv, err := tcp.NewServer(address) var name string
if err != nil { var closer func()
log.Error().Err(err).Msg("[rtsp] listen")
return
}
log.Info().Str("addr", address).Msg("[rtsp] listen") trace := log.Trace().Enabled()
srv.Listen(func(msg interface{}) { conn := rtsp.NewServer(c)
switch msg.(type) { conn.Listen(func(msg interface{}) {
case net.Conn: if trace {
var name string switch msg := msg.(type) {
var onDisconnect func() case *tcp.Request:
log.Trace().Msgf("[rtsp] server request:\n%s", msg)
case *tcp.Response:
log.Trace().Msgf("[rtsp] server response:\n%s", msg)
}
}
trace := log.Trace().Enabled() switch msg {
case rtsp.MethodDescribe:
name = conn.URL.Path[1:]
conn := rtsp.NewServer(msg.(net.Conn)) stream := streams.Get(name)
conn.Listen(func(msg interface{}) { if stream == nil {
if trace {
switch msg := msg.(type) {
case *tcp.Request:
log.Trace().Msgf("[rtsp] server request:\n%s", msg)
case *tcp.Response:
log.Trace().Msgf("[rtsp] server response:\n%s", msg)
}
}
switch msg {
case rtsp.MethodDescribe:
name = conn.URL.Path[1:]
log.Debug().Str("stream", name).Msg("[rtsp] new consumer")
stream := streams.Get(name) // TODO: rewrite
if stream == nil {
return
}
initMedias(conn)
if err = stream.AddConsumer(conn); err != nil {
log.Warn().Err(err).Str("stream", name).Msg("[rtsp]")
return
}
onDisconnect = func() {
stream.RemoveConsumer(conn)
}
case rtsp.MethodAnnounce:
if OnProducer != nil {
if OnProducer(conn) {
return
}
}
name = conn.URL.Path[1:]
log.Debug().Str("stream", name).Msg("[rtsp] new producer")
stream := streams.Get(name)
if stream == nil {
return
}
stream.AddProducer(conn)
onDisconnect = func() {
stream.RemoveProducer(conn)
}
case streamer.StatePlaying:
log.Debug().Str("stream", name).Msg("[rtsp] start")
}
})
if err = conn.Accept(); err != nil {
log.Warn().Err(err).Msg("[rtsp] accept")
return return
} }
if err = conn.Handle(); err != nil { log.Debug().Str("stream", name).Msg("[rtsp] new consumer")
//log.Warn().Err(err).Msg("[rtsp] handle server")
initMedias(conn)
if err := stream.AddConsumer(conn); err != nil {
log.Warn().Err(err).Str("stream", name).Msg("[rtsp]")
return
} }
if onDisconnect != nil { closer = func() {
onDisconnect() stream.RemoveConsumer(conn)
} }
log.Debug().Str("stream", name).Msg("[rtsp] disconnect") case rtsp.MethodAnnounce:
name = conn.URL.Path[1:]
stream := streams.Get(name)
if stream == nil {
return
}
log.Debug().Str("stream", name).Msg("[rtsp] new producer")
stream.AddProducer(conn)
closer = func() {
stream.RemoveProducer(conn)
}
case streamer.StatePlaying:
log.Debug().Str("stream", name).Msg("[rtsp] start")
} }
}) })
srv.Serve() if err := conn.Accept(); err != nil {
log.Warn().Err(err).Caller().Send()
_ = conn.Close()
return
}
for _, handler := range handlers {
if handler(conn) {
return
}
}
if closer != nil {
if err := conn.Handle(); err != nil {
log.Debug().Err(err).Caller().Send()
}
closer()
log.Debug().Str("stream", name).Msg("[rtsp] disconnect")
}
_ = conn.Close()
} }
func initMedias(conn *rtsp.Conn) { func initMedias(conn *rtsp.Conn) {
+79 -20
View File
@@ -4,6 +4,7 @@ import (
"github.com/AlexxIT/go2rtc/pkg/streamer" "github.com/AlexxIT/go2rtc/pkg/streamer"
"strings" "strings"
"sync" "sync"
"time"
) )
type state byte type state byte
@@ -24,8 +25,9 @@ type Producer struct {
element streamer.Producer element streamer.Producer
tracks []*streamer.Track tracks []*streamer.Track
state state state state
mx sync.Mutex mu sync.Mutex
restart *time.Timer
} }
func (p *Producer) SetSource(s string) { func (p *Producer) SetSource(s string) {
@@ -36,16 +38,16 @@ func (p *Producer) SetSource(s string) {
} }
func (p *Producer) GetMedias() []*streamer.Media { func (p *Producer) GetMedias() []*streamer.Media {
p.mx.Lock() p.mu.Lock()
defer p.mx.Unlock() defer p.mu.Unlock()
if p.state == stateNone { if p.state == stateNone {
log.Debug().Str("url", p.url).Msg("[streams] probe producer") log.Debug().Msgf("[streams] probe producer url=%s", p.url)
var err error var err error
p.element, err = GetProducer(p.url) p.element, err = GetProducer(p.url)
if err != nil || p.element == nil { if err != nil || p.element == nil {
log.Error().Err(err).Str("url", p.url).Msg("[streams] probe producer") log.Error().Err(err).Caller().Send()
return nil return nil
} }
@@ -56,8 +58,12 @@ func (p *Producer) GetMedias() []*streamer.Media {
} }
func (p *Producer) GetTrack(media *streamer.Media, codec *streamer.Codec) *streamer.Track { func (p *Producer) GetTrack(media *streamer.Media, codec *streamer.Codec) *streamer.Track {
p.mx.Lock() p.mu.Lock()
defer p.mx.Unlock() defer p.mu.Unlock()
if p.state == stateNone {
return nil
}
track := p.element.GetTrack(media, codec) track := p.element.GetTrack(media, codec)
if track == nil { if track == nil {
@@ -82,36 +88,89 @@ func (p *Producer) GetTrack(media *streamer.Media, codec *streamer.Codec) *strea
// internals // internals
func (p *Producer) start() { func (p *Producer) start() {
p.mu.Lock()
defer p.mu.Unlock()
if p.state != stateTracks { if p.state != stateTracks {
return return
} }
p.mx.Lock() log.Debug().Msgf("[streams] start producer url=%s", p.url)
defer p.mx.Unlock()
log.Debug().Str("url", p.url).Msg("[streams] start producer")
p.state = stateStart p.state = stateStart
go func() { go func() {
// safe read element while mu locked
if err := p.element.Start(); err != nil { if err := p.element.Start(); err != nil {
log.Warn().Err(err).Str("url", p.url).Msg("[streams] start") log.Warn().Err(err).Caller().Send()
} }
p.reconnect()
}()
}
func (p *Producer) reconnect() {
p.mu.Lock()
defer p.mu.Unlock()
if p.state != stateStart {
log.Debug().Msgf("[streams] closed ...")
return
}
log.Debug().Msgf("[streams] reconnect to url=%s", p.url)
var err error
p.element, err = GetProducer(p.url)
if err != nil || p.element == nil {
log.Debug().Err(err).Caller().Send()
// TODO: dynamic timeout
p.restart = time.AfterFunc(30*time.Second, p.reconnect)
return
}
medias := p.element.GetMedias()
// convert all old producer tracks to new tracks
for i, oldTrack := range p.tracks {
// match new element medias with old track codec
for _, media := range medias {
codec := media.MatchCodec(oldTrack.Codec)
if codec == nil {
continue
}
// move sink from old track to new track
newTrack := p.element.GetTrack(media, codec)
newTrack.Sink = oldTrack.Sink
p.tracks[i] = newTrack
break
}
}
go func() {
if err = p.element.Start(); err != nil {
log.Debug().Err(err).Caller().Send()
}
p.reconnect()
}() }()
} }
func (p *Producer) stop() { func (p *Producer) stop() {
p.mx.Lock() p.mu.Lock()
log.Debug().Str("url", p.url).Msg("[streams] stop producer") log.Debug().Msgf("[streams] stop producer url=%s", p.url)
if p.element != nil { if p.element != nil {
_ = p.element.Stop() _ = p.element.Stop()
p.element = nil p.element = nil
} else {
log.Warn().Str("url", p.url).Msg("[streams] stop empty producer")
} }
p.tracks = nil if p.restart != nil {
p.state = stateNone p.restart.Stop()
p.restart = nil
}
p.mx.Unlock() p.state = stateNone
p.tracks = nil
p.mu.Unlock()
} }
+5 -7
View File
@@ -469,10 +469,6 @@ func (c *Conn) Close() error {
const transport = "RTP/AVP/TCP;unicast;interleaved=" const transport = "RTP/AVP/TCP;unicast;interleaved="
func (c *Conn) Accept() error { func (c *Conn) Accept() error {
//if c.state != StateServerInit {
// panic("wrong state")
//}
for { for {
req, err := tcp.ReadRequest(c.reader) req, err := tcp.ReadRequest(c.reader)
if err != nil { if err != nil {
@@ -600,6 +596,10 @@ func (c *Conn) Handle() (err error) {
defer func() { defer func() {
if c.closed { if c.closed {
err = nil err = nil
} else {
// may have gotten here because of the deadline
// so close the connection to stop keepalive
_ = c.conn.Close()
} }
}() }()
@@ -712,13 +712,11 @@ func (c *Conn) Handle() (err error) {
} }
} }
const KeepAlive = time.Second * 25
func (c *Conn) keepalive() { func (c *Conn) keepalive() {
// TODO: rewrite to RTCP // TODO: rewrite to RTCP
req := &tcp.Request{Method: MethodOptions, URL: c.URL} req := &tcp.Request{Method: MethodOptions, URL: c.URL}
for { for {
time.Sleep(KeepAlive) time.Sleep(time.Second * 25)
if c.closed { if c.closed {
return return
} }
+9 -5
View File
@@ -2,6 +2,7 @@ package rtsp
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/AlexxIT/go2rtc/pkg/streamer" "github.com/AlexxIT/go2rtc/pkg/streamer"
"strconv" "strconv"
) )
@@ -27,13 +28,16 @@ func (c *Conn) GetTrack(media *streamer.Media, codec *streamer.Codec) *streamer.
} }
func (c *Conn) Start() error { func (c *Conn) Start() error {
if c.mode == ModeServerProducer { switch c.mode {
return nil case ModeClientProducer:
if err := c.Play(); err != nil {
return err
}
case ModeServerProducer:
default:
return fmt.Errorf("start wrong mode: %d", c.mode)
} }
if err := c.Play(); err != nil {
return err
}
return c.Handle() return c.Handle()
} }
+3 -3
View File
@@ -75,13 +75,13 @@ func (m *Media) AV() bool {
return m.Kind == KindVideo || m.Kind == KindAudio return m.Kind == KindVideo || m.Kind == KindAudio
} }
func (m *Media) MatchCodec(codec *Codec) bool { func (m *Media) MatchCodec(codec *Codec) *Codec {
for _, c := range m.Codecs { for _, c := range m.Codecs {
if c.Match(codec) { if c.Match(codec) {
return true return c
} }
} }
return false return nil
} }
func (m *Media) MatchMedia(media *Media) *Codec { func (m *Media) MatchMedia(media *Media) *Codec {