From 2ea66deb0817f471a401bf76436a714c23c23b7f Mon Sep 17 00:00:00 2001 From: Alex X Date: Fri, 3 May 2024 14:30:05 +0300 Subject: [PATCH] Fix multiple dial on add consumer --- internal/streams/add_consumer.go | 57 +++++++++++++++++++------------- internal/streams/producer.go | 4 +-- internal/streams/stream.go | 3 +- 3 files changed, 38 insertions(+), 26 deletions(-) diff --git a/internal/streams/add_consumer.go b/internal/streams/add_consumer.go index d97a4266..eb767691 100644 --- a/internal/streams/add_consumer.go +++ b/internal/streams/add_consumer.go @@ -3,18 +3,17 @@ package streams import ( "errors" "strings" - "sync/atomic" "github.com/AlexxIT/go2rtc/pkg/core" ) func (s *Stream) AddConsumer(cons core.Consumer) (err error) { - // support for multiple simultaneous requests from different consumers - consN := atomic.AddInt32(&s.requests, 1) - 1 + // support for multiple simultaneous pending from different consumers + consN := s.pending.Add(1) - 1 - var prodErrors []error + var prodErrors = make([]error, len(s.producers)) var prodMedias []*core.Media - var prods []*Producer // matched producers for consumer + var prodStarts []*Producer // Step 1. Get consumer medias consMedias := cons.GetMedias() @@ -23,15 +22,20 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) { producers: for prodN, prod := range s.producers { + if prodErrors[prodN] != nil { + log.Trace().Msgf("[streams] skip cons=%d prod=%d", consN, prodN) + continue + } + if err = prod.Dial(); err != nil { - log.Trace().Err(err).Msgf("[streams] skip prod=%s", prod.url) - prodErrors = append(prodErrors, err) + log.Trace().Err(err).Msgf("[streams] dial cons=%d prod=%d", consN, prodN) + prodErrors[prodN] = err continue } // Step 2. Get producer medias (not tracks yet) for _, prodMedia := range prod.GetMedias() { - log.Trace().Msgf("[streams] check prod=%d media=%s", prodN, prodMedia) + log.Trace().Msgf("[streams] check cons=%d prod=%d media=%s", consN, prodN, prodMedia) prodMedias = append(prodMedias, prodMedia) // Step 3. Match consumer/producer codecs list @@ -44,11 +48,12 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) { switch prodMedia.Direction { case core.DirectionRecvonly: - log.Trace().Msgf("[streams] match prod=%d => cons=%d", prodN, consN) + log.Trace().Msgf("[streams] match cons=%d <= prod=%d", consN, prodN) // Step 4. Get recvonly track from producer if track, err = prod.GetTrack(prodMedia, prodCodec); err != nil { log.Info().Err(err).Msg("[streams] can't get track") + prodErrors[prodN] = err continue } // Step 5. Add track to consumer @@ -68,11 +73,12 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) { // Step 5. Add track to producer if err = prod.AddTrack(prodMedia, prodCodec, track); err != nil { log.Info().Err(err).Msg("[streams] can't add track") + prodErrors[prodN] = err continue } } - prods = append(prods, prod) + prodStarts = append(prodStarts, prod) if !consMedia.MatchAll() { break producers @@ -82,11 +88,11 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) { } // stop producers if they don't have readers - if atomic.AddInt32(&s.requests, -1) == 0 { + if s.pending.Add(-1) == 0 { s.stopProducers() } - if len(prods) == 0 { + if len(prodStarts) == 0 { return formatError(consMedias, prodMedias, prodErrors) } @@ -95,7 +101,7 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) { s.mu.Unlock() // there may be duplicates, but that's not a problem - for _, prod := range prods { + for _, prod := range prodStarts { prod.start() } @@ -103,6 +109,20 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) { } func formatError(consMedias, prodMedias []*core.Media, prodErrors []error) error { + // 1. Return errors if any not nil + var text string + + for _, err := range prodErrors { + if err != nil { + text = appendString(text, err.Error()) + } + } + + if len(text) != 0 { + return errors.New("streams: " + text) + } + + // 2. Return "codecs not matched" if prodMedias != nil { var prod, cons string @@ -125,16 +145,7 @@ func formatError(consMedias, prodMedias []*core.Media, prodErrors []error) error return errors.New("streams: codecs not matched: " + prod + " => " + cons) } - if prodErrors != nil { - var text string - - for _, err := range prodErrors { - text = appendString(text, err.Error()) - } - - return errors.New("streams: " + text) - } - + // 3. Return unknown error return errors.New("streams: unknown error") } diff --git a/internal/streams/producer.go b/internal/streams/producer.go index 6d3cf2b9..5a25dba5 100644 --- a/internal/streams/producer.go +++ b/internal/streams/producer.go @@ -245,10 +245,10 @@ func (p *Producer) stop() { switch p.state { case stateExternal: - log.Debug().Msgf("[streams] can't stop external producer") + log.Trace().Msgf("[streams] skip stop external producer") return case stateNone: - log.Debug().Msgf("[streams] can't stop none producer") + log.Trace().Msgf("[streams] skip stop none producer") return case stateStart: p.workerID++ diff --git a/internal/streams/stream.go b/internal/streams/stream.go index 0a8108e2..49c58e77 100644 --- a/internal/streams/stream.go +++ b/internal/streams/stream.go @@ -3,6 +3,7 @@ package streams import ( "encoding/json" "sync" + "sync/atomic" "github.com/AlexxIT/go2rtc/pkg/core" ) @@ -11,7 +12,7 @@ type Stream struct { producers []*Producer consumers []core.Consumer mu sync.Mutex - requests int32 + pending atomic.Int32 } func NewStream(source any) *Stream {