diff --git a/cmd/exec/exec.go b/cmd/exec/exec.go index 5f96c374..4a4550c0 100644 --- a/cmd/exec/exec.go +++ b/cmd/exec/exec.go @@ -14,6 +14,7 @@ import ( "os" "os/exec" "strings" + "sync" "time" ) @@ -23,22 +24,22 @@ func Init() { return } - rtsp.OnProducer = func(prod streamer.Producer) bool { - if conn := prod.(*pkg.Conn); conn != nil { - if waiter := waiters[conn.URL.Path]; waiter != nil { - waiter <- prod - return true - } + rtsp.HandleFunc(func(conn *pkg.Conn) bool { + waitersMu.Lock() + waiter := waiters[conn.URL.Path] + waitersMu.Unlock() + + if waiter == nil { + return false } - return false - } + + waiter <- conn + return true + }) streams.HandleFunc("exec", Handle) log = app.GetLogger("exec") - - // TODO: add sync.Mutex - waiters = map[string]chan streamer.Producer{} } func Handle(url string) (streamer.Producer, error) { @@ -60,8 +61,15 @@ func Handle(url string) (streamer.Producer, error) { ch := make(chan streamer.Producer) + waitersMu.Lock() waiters[path] = ch - defer delete(waiters, path) + waitersMu.Unlock() + + defer func() { + waitersMu.Lock() + delete(waiters, path) + waitersMu.Unlock() + }() log.Debug().Str("url", url).Msg("[exec] run") @@ -86,4 +94,5 @@ func Handle(url string) (streamer.Producer, error) { // internal var log zerolog.Logger -var waiters map[string]chan streamer.Producer +var waiters = map[string]chan streamer.Producer{} +var waitersMu sync.Mutex diff --git a/cmd/rtsp/rtsp.go b/cmd/rtsp/rtsp.go index c0409a04..0936201a 100644 --- a/cmd/rtsp/rtsp.go +++ b/cmd/rtsp/rtsp.go @@ -32,20 +32,43 @@ func Init() { // RTSP server support address := conf.Mod.Listen - if address != "" { - _, Port, _ = net.SplitHostPort(address) - - go worker(address) + if address == "" { + return } + + ln, err := net.Listen("tcp", address) + if err != nil { + log.Error().Err(err).Msg("[rtsp] listen") + return + } + + _, Port, _ = net.SplitHostPort(address) + + log.Info().Str("addr", address).Msg("[rtsp] listen") + + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + go tcpHandler(conn) + } + }() +} + +type Handler func(conn *rtsp.Conn) bool + +func HandleFunc(handler Handler) { + handlers = append(handlers, handler) } var Port string -var OnProducer func(conn streamer.Producer) bool // TODO: maybe rewrite... - // internal var log zerolog.Logger +var handlers []Handler func rtspHandler(url string) (streamer.Producer, error) { backchannel := true @@ -96,101 +119,89 @@ func rtspHandler(url string) (streamer.Producer, error) { return conn, nil } -func worker(address string) { - srv, err := tcp.NewServer(address) - if err != nil { - log.Error().Err(err).Msg("[rtsp] listen") - return - } +func tcpHandler(c net.Conn) { + var name string + var closer func() - log.Info().Str("addr", address).Msg("[rtsp] listen") + trace := log.Trace().Enabled() - srv.Listen(func(msg interface{}) { - switch msg.(type) { - case net.Conn: - var name string - var onDisconnect func() + conn := rtsp.NewServer(c) + conn.Listen(func(msg interface{}) { + if trace { + switch msg := msg.(type) { + case *tcp.Request: + log.Trace().Msgf("[rtsp] server request:\n%s", msg) + case *tcp.Response: + log.Trace().Msgf("[rtsp] server response:\n%s", msg) + } + } - trace := log.Trace().Enabled() + switch msg { + case rtsp.MethodDescribe: + name = conn.URL.Path[1:] - conn := rtsp.NewServer(msg.(net.Conn)) - conn.Listen(func(msg interface{}) { - if trace { - switch msg := msg.(type) { - case *tcp.Request: - log.Trace().Msgf("[rtsp] server request:\n%s", msg) - case *tcp.Response: - log.Trace().Msgf("[rtsp] server response:\n%s", msg) - } - } - - switch msg { - case rtsp.MethodDescribe: - name = conn.URL.Path[1:] - - log.Debug().Str("stream", name).Msg("[rtsp] new consumer") - - stream := streams.Get(name) // TODO: rewrite - if stream == nil { - return - } - - initMedias(conn) - - if err = stream.AddConsumer(conn); err != nil { - log.Warn().Err(err).Str("stream", name).Msg("[rtsp]") - return - } - - onDisconnect = func() { - stream.RemoveConsumer(conn) - } - - case rtsp.MethodAnnounce: - if OnProducer != nil { - if OnProducer(conn) { - return - } - } - - name = conn.URL.Path[1:] - - log.Debug().Str("stream", name).Msg("[rtsp] new producer") - - stream := streams.Get(name) - if stream == nil { - return - } - - stream.AddProducer(conn) - - onDisconnect = func() { - stream.RemoveProducer(conn) - } - - case streamer.StatePlaying: - log.Debug().Str("stream", name).Msg("[rtsp] start") - } - }) - - if err = conn.Accept(); err != nil { - log.Warn().Err(err).Msg("[rtsp] accept") + stream := streams.Get(name) + if stream == nil { return } - if err = conn.Handle(); err != nil { - //log.Warn().Err(err).Msg("[rtsp] handle server") + log.Debug().Str("stream", name).Msg("[rtsp] new consumer") + + initMedias(conn) + + if err := stream.AddConsumer(conn); err != nil { + log.Warn().Err(err).Str("stream", name).Msg("[rtsp]") + return } - if onDisconnect != nil { - onDisconnect() + closer = func() { + stream.RemoveConsumer(conn) } - log.Debug().Str("stream", name).Msg("[rtsp] disconnect") + case rtsp.MethodAnnounce: + name = conn.URL.Path[1:] + + stream := streams.Get(name) + if stream == nil { + return + } + + log.Debug().Str("stream", name).Msg("[rtsp] new producer") + + stream.AddProducer(conn) + + closer = func() { + stream.RemoveProducer(conn) + } + + case streamer.StatePlaying: + log.Debug().Str("stream", name).Msg("[rtsp] start") } }) - srv.Serve() + if err := conn.Accept(); err != nil { + log.Warn().Err(err).Caller().Send() + _ = conn.Close() + return + } + + for _, handler := range handlers { + if handler(conn) { + return + } + } + + if closer != nil { + if err := conn.Handle(); err != nil { + log.Debug().Err(err).Caller().Send() + } + + closer() + + log.Debug().Str("stream", name).Msg("[rtsp] disconnect") + } + + _ = conn.Close() } func initMedias(conn *rtsp.Conn) { diff --git a/cmd/streams/producer.go b/cmd/streams/producer.go index 8c50ca22..5e32e88f 100644 --- a/cmd/streams/producer.go +++ b/cmd/streams/producer.go @@ -4,6 +4,7 @@ import ( "github.com/AlexxIT/go2rtc/pkg/streamer" "strings" "sync" + "time" ) type state byte @@ -24,8 +25,9 @@ type Producer struct { element streamer.Producer tracks []*streamer.Track - state state - mx sync.Mutex + state state + mu sync.Mutex + restart *time.Timer } func (p *Producer) SetSource(s string) { @@ -36,16 +38,16 @@ func (p *Producer) SetSource(s string) { } func (p *Producer) GetMedias() []*streamer.Media { - p.mx.Lock() - defer p.mx.Unlock() + p.mu.Lock() + defer p.mu.Unlock() if p.state == stateNone { - log.Debug().Str("url", p.url).Msg("[streams] probe producer") + log.Debug().Msgf("[streams] probe producer url=%s", p.url) var err error p.element, err = GetProducer(p.url) if err != nil || p.element == nil { - log.Error().Err(err).Str("url", p.url).Msg("[streams] probe producer") + log.Error().Err(err).Caller().Send() return nil } @@ -56,8 +58,12 @@ func (p *Producer) GetMedias() []*streamer.Media { } func (p *Producer) GetTrack(media *streamer.Media, codec *streamer.Codec) *streamer.Track { - p.mx.Lock() - defer p.mx.Unlock() + p.mu.Lock() + defer p.mu.Unlock() + + if p.state == stateNone { + return nil + } track := p.element.GetTrack(media, codec) if track == nil { @@ -82,36 +88,89 @@ func (p *Producer) GetTrack(media *streamer.Media, codec *streamer.Codec) *strea // internals func (p *Producer) start() { + p.mu.Lock() + defer p.mu.Unlock() + if p.state != stateTracks { return } - p.mx.Lock() - defer p.mx.Unlock() - - log.Debug().Str("url", p.url).Msg("[streams] start producer") + log.Debug().Msgf("[streams] start producer url=%s", p.url) p.state = stateStart go func() { + // safe read element while mu locked if err := p.element.Start(); err != nil { - log.Warn().Err(err).Str("url", p.url).Msg("[streams] start") + log.Warn().Err(err).Caller().Send() } + p.reconnect() + }() +} + +func (p *Producer) reconnect() { + p.mu.Lock() + defer p.mu.Unlock() + + if p.state != stateStart { + log.Debug().Msgf("[streams] closed ...") + return + } + + log.Debug().Msgf("[streams] reconnect to url=%s", p.url) + + var err error + p.element, err = GetProducer(p.url) + if err != nil || p.element == nil { + log.Debug().Err(err).Caller().Send() + // TODO: dynamic timeout + p.restart = time.AfterFunc(30*time.Second, p.reconnect) + return + } + + medias := p.element.GetMedias() + + // convert all old producer tracks to new tracks + for i, oldTrack := range p.tracks { + // match new element medias with old track codec + for _, media := range medias { + codec := media.MatchCodec(oldTrack.Codec) + if codec == nil { + continue + } + + // move sink from old track to new track + newTrack := p.element.GetTrack(media, codec) + newTrack.Sink = oldTrack.Sink + p.tracks[i] = newTrack + + break + } + } + + go func() { + if err = p.element.Start(); err != nil { + log.Debug().Err(err).Caller().Send() + } + p.reconnect() }() } func (p *Producer) stop() { - p.mx.Lock() + p.mu.Lock() - log.Debug().Str("url", p.url).Msg("[streams] stop producer") + log.Debug().Msgf("[streams] stop producer url=%s", p.url) if p.element != nil { _ = p.element.Stop() p.element = nil - } else { - log.Warn().Str("url", p.url).Msg("[streams] stop empty producer") } - p.tracks = nil - p.state = stateNone + if p.restart != nil { + p.restart.Stop() + p.restart = nil + } - p.mx.Unlock() + p.state = stateNone + p.tracks = nil + + p.mu.Unlock() } diff --git a/pkg/rtsp/conn.go b/pkg/rtsp/conn.go index 9cde760d..97543a21 100644 --- a/pkg/rtsp/conn.go +++ b/pkg/rtsp/conn.go @@ -469,10 +469,6 @@ func (c *Conn) Close() error { const transport = "RTP/AVP/TCP;unicast;interleaved=" func (c *Conn) Accept() error { - //if c.state != StateServerInit { - // panic("wrong state") - //} - for { req, err := tcp.ReadRequest(c.reader) if err != nil { @@ -600,6 +596,10 @@ func (c *Conn) Handle() (err error) { defer func() { if c.closed { err = nil + } else { + // may have gotten here because of the deadline + // so close the connection to stop keepalive + _ = c.conn.Close() } }() @@ -712,13 +712,11 @@ func (c *Conn) Handle() (err error) { } } -const KeepAlive = time.Second * 25 - func (c *Conn) keepalive() { // TODO: rewrite to RTCP req := &tcp.Request{Method: MethodOptions, URL: c.URL} for { - time.Sleep(KeepAlive) + time.Sleep(time.Second * 25) if c.closed { return } diff --git a/pkg/rtsp/streamer.go b/pkg/rtsp/streamer.go index 74211403..13c206aa 100644 --- a/pkg/rtsp/streamer.go +++ b/pkg/rtsp/streamer.go @@ -2,6 +2,7 @@ package rtsp import ( "encoding/json" + "fmt" "github.com/AlexxIT/go2rtc/pkg/streamer" "strconv" ) @@ -27,13 +28,16 @@ func (c *Conn) GetTrack(media *streamer.Media, codec *streamer.Codec) *streamer. } func (c *Conn) Start() error { - if c.mode == ModeServerProducer { - return nil + switch c.mode { + case ModeClientProducer: + if err := c.Play(); err != nil { + return err + } + case ModeServerProducer: + default: + return fmt.Errorf("start wrong mode: %d", c.mode) } - if err := c.Play(); err != nil { - return err - } return c.Handle() } diff --git a/pkg/streamer/media.go b/pkg/streamer/media.go index c14e3945..b4f07a04 100644 --- a/pkg/streamer/media.go +++ b/pkg/streamer/media.go @@ -75,13 +75,13 @@ func (m *Media) AV() bool { return m.Kind == KindVideo || m.Kind == KindAudio } -func (m *Media) MatchCodec(codec *Codec) bool { +func (m *Media) MatchCodec(codec *Codec) *Codec { for _, c := range m.Codecs { if c.Match(codec) { - return true + return c } } - return false + return nil } func (m *Media) MatchMedia(media *Media) *Codec {