Fix multiple dial on add consumer

This commit is contained in:
Alex X
2024-05-03 14:30:05 +03:00
parent b3c5ef8c86
commit 2ea66deb08
3 changed files with 38 additions and 26 deletions
+34 -23
View File
@@ -3,18 +3,17 @@ package streams
import (
"errors"
"strings"
"sync/atomic"
"github.com/AlexxIT/go2rtc/pkg/core"
)
func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
// support for multiple simultaneous requests from different consumers
consN := atomic.AddInt32(&s.requests, 1) - 1
// support for multiple simultaneous pending from different consumers
consN := s.pending.Add(1) - 1
var prodErrors []error
var prodErrors = make([]error, len(s.producers))
var prodMedias []*core.Media
var prods []*Producer // matched producers for consumer
var prodStarts []*Producer
// Step 1. Get consumer medias
consMedias := cons.GetMedias()
@@ -23,15 +22,20 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
producers:
for prodN, prod := range s.producers {
if prodErrors[prodN] != nil {
log.Trace().Msgf("[streams] skip cons=%d prod=%d", consN, prodN)
continue
}
if err = prod.Dial(); err != nil {
log.Trace().Err(err).Msgf("[streams] skip prod=%s", prod.url)
prodErrors = append(prodErrors, err)
log.Trace().Err(err).Msgf("[streams] dial cons=%d prod=%d", consN, prodN)
prodErrors[prodN] = err
continue
}
// Step 2. Get producer medias (not tracks yet)
for _, prodMedia := range prod.GetMedias() {
log.Trace().Msgf("[streams] check prod=%d media=%s", prodN, prodMedia)
log.Trace().Msgf("[streams] check cons=%d prod=%d media=%s", consN, prodN, prodMedia)
prodMedias = append(prodMedias, prodMedia)
// Step 3. Match consumer/producer codecs list
@@ -44,11 +48,12 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
switch prodMedia.Direction {
case core.DirectionRecvonly:
log.Trace().Msgf("[streams] match prod=%d => cons=%d", prodN, consN)
log.Trace().Msgf("[streams] match cons=%d <= prod=%d", consN, prodN)
// Step 4. Get recvonly track from producer
if track, err = prod.GetTrack(prodMedia, prodCodec); err != nil {
log.Info().Err(err).Msg("[streams] can't get track")
prodErrors[prodN] = err
continue
}
// Step 5. Add track to consumer
@@ -68,11 +73,12 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
// Step 5. Add track to producer
if err = prod.AddTrack(prodMedia, prodCodec, track); err != nil {
log.Info().Err(err).Msg("[streams] can't add track")
prodErrors[prodN] = err
continue
}
}
prods = append(prods, prod)
prodStarts = append(prodStarts, prod)
if !consMedia.MatchAll() {
break producers
@@ -82,11 +88,11 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
}
// stop producers if they don't have readers
if atomic.AddInt32(&s.requests, -1) == 0 {
if s.pending.Add(-1) == 0 {
s.stopProducers()
}
if len(prods) == 0 {
if len(prodStarts) == 0 {
return formatError(consMedias, prodMedias, prodErrors)
}
@@ -95,7 +101,7 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
s.mu.Unlock()
// there may be duplicates, but that's not a problem
for _, prod := range prods {
for _, prod := range prodStarts {
prod.start()
}
@@ -103,6 +109,20 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
}
func formatError(consMedias, prodMedias []*core.Media, prodErrors []error) error {
// 1. Return errors if any not nil
var text string
for _, err := range prodErrors {
if err != nil {
text = appendString(text, err.Error())
}
}
if len(text) != 0 {
return errors.New("streams: " + text)
}
// 2. Return "codecs not matched"
if prodMedias != nil {
var prod, cons string
@@ -125,16 +145,7 @@ func formatError(consMedias, prodMedias []*core.Media, prodErrors []error) error
return errors.New("streams: codecs not matched: " + prod + " => " + cons)
}
if prodErrors != nil {
var text string
for _, err := range prodErrors {
text = appendString(text, err.Error())
}
return errors.New("streams: " + text)
}
// 3. Return unknown error
return errors.New("streams: unknown error")
}
+2 -2
View File
@@ -245,10 +245,10 @@ func (p *Producer) stop() {
switch p.state {
case stateExternal:
log.Debug().Msgf("[streams] can't stop external producer")
log.Trace().Msgf("[streams] skip stop external producer")
return
case stateNone:
log.Debug().Msgf("[streams] can't stop none producer")
log.Trace().Msgf("[streams] skip stop none producer")
return
case stateStart:
p.workerID++
+2 -1
View File
@@ -3,6 +3,7 @@ package streams
import (
"encoding/json"
"sync"
"sync/atomic"
"github.com/AlexxIT/go2rtc/pkg/core"
)
@@ -11,7 +12,7 @@ type Stream struct {
producers []*Producer
consumers []core.Consumer
mu sync.Mutex
requests int32
pending atomic.Int32
}
func NewStream(source any) *Stream {