Fix multiple dial on add consumer
This commit is contained in:
@@ -3,18 +3,17 @@ package streams
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/AlexxIT/go2rtc/pkg/core"
|
||||
)
|
||||
|
||||
func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
||||
// support for multiple simultaneous requests from different consumers
|
||||
consN := atomic.AddInt32(&s.requests, 1) - 1
|
||||
// support for multiple simultaneous pending from different consumers
|
||||
consN := s.pending.Add(1) - 1
|
||||
|
||||
var prodErrors []error
|
||||
var prodErrors = make([]error, len(s.producers))
|
||||
var prodMedias []*core.Media
|
||||
var prods []*Producer // matched producers for consumer
|
||||
var prodStarts []*Producer
|
||||
|
||||
// Step 1. Get consumer medias
|
||||
consMedias := cons.GetMedias()
|
||||
@@ -23,15 +22,20 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
||||
|
||||
producers:
|
||||
for prodN, prod := range s.producers {
|
||||
if prodErrors[prodN] != nil {
|
||||
log.Trace().Msgf("[streams] skip cons=%d prod=%d", consN, prodN)
|
||||
continue
|
||||
}
|
||||
|
||||
if err = prod.Dial(); err != nil {
|
||||
log.Trace().Err(err).Msgf("[streams] skip prod=%s", prod.url)
|
||||
prodErrors = append(prodErrors, err)
|
||||
log.Trace().Err(err).Msgf("[streams] dial cons=%d prod=%d", consN, prodN)
|
||||
prodErrors[prodN] = err
|
||||
continue
|
||||
}
|
||||
|
||||
// Step 2. Get producer medias (not tracks yet)
|
||||
for _, prodMedia := range prod.GetMedias() {
|
||||
log.Trace().Msgf("[streams] check prod=%d media=%s", prodN, prodMedia)
|
||||
log.Trace().Msgf("[streams] check cons=%d prod=%d media=%s", consN, prodN, prodMedia)
|
||||
prodMedias = append(prodMedias, prodMedia)
|
||||
|
||||
// Step 3. Match consumer/producer codecs list
|
||||
@@ -44,11 +48,12 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
||||
|
||||
switch prodMedia.Direction {
|
||||
case core.DirectionRecvonly:
|
||||
log.Trace().Msgf("[streams] match prod=%d => cons=%d", prodN, consN)
|
||||
log.Trace().Msgf("[streams] match cons=%d <= prod=%d", consN, prodN)
|
||||
|
||||
// Step 4. Get recvonly track from producer
|
||||
if track, err = prod.GetTrack(prodMedia, prodCodec); err != nil {
|
||||
log.Info().Err(err).Msg("[streams] can't get track")
|
||||
prodErrors[prodN] = err
|
||||
continue
|
||||
}
|
||||
// Step 5. Add track to consumer
|
||||
@@ -68,11 +73,12 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
||||
// Step 5. Add track to producer
|
||||
if err = prod.AddTrack(prodMedia, prodCodec, track); err != nil {
|
||||
log.Info().Err(err).Msg("[streams] can't add track")
|
||||
prodErrors[prodN] = err
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
prods = append(prods, prod)
|
||||
prodStarts = append(prodStarts, prod)
|
||||
|
||||
if !consMedia.MatchAll() {
|
||||
break producers
|
||||
@@ -82,11 +88,11 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
||||
}
|
||||
|
||||
// stop producers if they don't have readers
|
||||
if atomic.AddInt32(&s.requests, -1) == 0 {
|
||||
if s.pending.Add(-1) == 0 {
|
||||
s.stopProducers()
|
||||
}
|
||||
|
||||
if len(prods) == 0 {
|
||||
if len(prodStarts) == 0 {
|
||||
return formatError(consMedias, prodMedias, prodErrors)
|
||||
}
|
||||
|
||||
@@ -95,7 +101,7 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
||||
s.mu.Unlock()
|
||||
|
||||
// there may be duplicates, but that's not a problem
|
||||
for _, prod := range prods {
|
||||
for _, prod := range prodStarts {
|
||||
prod.start()
|
||||
}
|
||||
|
||||
@@ -103,6 +109,20 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
||||
}
|
||||
|
||||
func formatError(consMedias, prodMedias []*core.Media, prodErrors []error) error {
|
||||
// 1. Return errors if any not nil
|
||||
var text string
|
||||
|
||||
for _, err := range prodErrors {
|
||||
if err != nil {
|
||||
text = appendString(text, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if len(text) != 0 {
|
||||
return errors.New("streams: " + text)
|
||||
}
|
||||
|
||||
// 2. Return "codecs not matched"
|
||||
if prodMedias != nil {
|
||||
var prod, cons string
|
||||
|
||||
@@ -125,16 +145,7 @@ func formatError(consMedias, prodMedias []*core.Media, prodErrors []error) error
|
||||
return errors.New("streams: codecs not matched: " + prod + " => " + cons)
|
||||
}
|
||||
|
||||
if prodErrors != nil {
|
||||
var text string
|
||||
|
||||
for _, err := range prodErrors {
|
||||
text = appendString(text, err.Error())
|
||||
}
|
||||
|
||||
return errors.New("streams: " + text)
|
||||
}
|
||||
|
||||
// 3. Return unknown error
|
||||
return errors.New("streams: unknown error")
|
||||
}
|
||||
|
||||
|
||||
@@ -245,10 +245,10 @@ func (p *Producer) stop() {
|
||||
|
||||
switch p.state {
|
||||
case stateExternal:
|
||||
log.Debug().Msgf("[streams] can't stop external producer")
|
||||
log.Trace().Msgf("[streams] skip stop external producer")
|
||||
return
|
||||
case stateNone:
|
||||
log.Debug().Msgf("[streams] can't stop none producer")
|
||||
log.Trace().Msgf("[streams] skip stop none producer")
|
||||
return
|
||||
case stateStart:
|
||||
p.workerID++
|
||||
|
||||
@@ -3,6 +3,7 @@ package streams
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/AlexxIT/go2rtc/pkg/core"
|
||||
)
|
||||
@@ -11,7 +12,7 @@ type Stream struct {
|
||||
producers []*Producer
|
||||
consumers []core.Consumer
|
||||
mu sync.Mutex
|
||||
requests int32
|
||||
pending atomic.Int32
|
||||
}
|
||||
|
||||
func NewStream(source any) *Stream {
|
||||
|
||||
Reference in New Issue
Block a user