Fix multiple dial on add consumer
This commit is contained in:
@@ -3,18 +3,17 @@ package streams
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"github.com/AlexxIT/go2rtc/pkg/core"
|
"github.com/AlexxIT/go2rtc/pkg/core"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
||||||
// support for multiple simultaneous requests from different consumers
|
// support for multiple simultaneous pending from different consumers
|
||||||
consN := atomic.AddInt32(&s.requests, 1) - 1
|
consN := s.pending.Add(1) - 1
|
||||||
|
|
||||||
var prodErrors []error
|
var prodErrors = make([]error, len(s.producers))
|
||||||
var prodMedias []*core.Media
|
var prodMedias []*core.Media
|
||||||
var prods []*Producer // matched producers for consumer
|
var prodStarts []*Producer
|
||||||
|
|
||||||
// Step 1. Get consumer medias
|
// Step 1. Get consumer medias
|
||||||
consMedias := cons.GetMedias()
|
consMedias := cons.GetMedias()
|
||||||
@@ -23,15 +22,20 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
|||||||
|
|
||||||
producers:
|
producers:
|
||||||
for prodN, prod := range s.producers {
|
for prodN, prod := range s.producers {
|
||||||
|
if prodErrors[prodN] != nil {
|
||||||
|
log.Trace().Msgf("[streams] skip cons=%d prod=%d", consN, prodN)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if err = prod.Dial(); err != nil {
|
if err = prod.Dial(); err != nil {
|
||||||
log.Trace().Err(err).Msgf("[streams] skip prod=%s", prod.url)
|
log.Trace().Err(err).Msgf("[streams] dial cons=%d prod=%d", consN, prodN)
|
||||||
prodErrors = append(prodErrors, err)
|
prodErrors[prodN] = err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 2. Get producer medias (not tracks yet)
|
// Step 2. Get producer medias (not tracks yet)
|
||||||
for _, prodMedia := range prod.GetMedias() {
|
for _, prodMedia := range prod.GetMedias() {
|
||||||
log.Trace().Msgf("[streams] check prod=%d media=%s", prodN, prodMedia)
|
log.Trace().Msgf("[streams] check cons=%d prod=%d media=%s", consN, prodN, prodMedia)
|
||||||
prodMedias = append(prodMedias, prodMedia)
|
prodMedias = append(prodMedias, prodMedia)
|
||||||
|
|
||||||
// Step 3. Match consumer/producer codecs list
|
// Step 3. Match consumer/producer codecs list
|
||||||
@@ -44,11 +48,12 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
|||||||
|
|
||||||
switch prodMedia.Direction {
|
switch prodMedia.Direction {
|
||||||
case core.DirectionRecvonly:
|
case core.DirectionRecvonly:
|
||||||
log.Trace().Msgf("[streams] match prod=%d => cons=%d", prodN, consN)
|
log.Trace().Msgf("[streams] match cons=%d <= prod=%d", consN, prodN)
|
||||||
|
|
||||||
// Step 4. Get recvonly track from producer
|
// Step 4. Get recvonly track from producer
|
||||||
if track, err = prod.GetTrack(prodMedia, prodCodec); err != nil {
|
if track, err = prod.GetTrack(prodMedia, prodCodec); err != nil {
|
||||||
log.Info().Err(err).Msg("[streams] can't get track")
|
log.Info().Err(err).Msg("[streams] can't get track")
|
||||||
|
prodErrors[prodN] = err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// Step 5. Add track to consumer
|
// Step 5. Add track to consumer
|
||||||
@@ -68,11 +73,12 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
|||||||
// Step 5. Add track to producer
|
// Step 5. Add track to producer
|
||||||
if err = prod.AddTrack(prodMedia, prodCodec, track); err != nil {
|
if err = prod.AddTrack(prodMedia, prodCodec, track); err != nil {
|
||||||
log.Info().Err(err).Msg("[streams] can't add track")
|
log.Info().Err(err).Msg("[streams] can't add track")
|
||||||
|
prodErrors[prodN] = err
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
prods = append(prods, prod)
|
prodStarts = append(prodStarts, prod)
|
||||||
|
|
||||||
if !consMedia.MatchAll() {
|
if !consMedia.MatchAll() {
|
||||||
break producers
|
break producers
|
||||||
@@ -82,11 +88,11 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// stop producers if they don't have readers
|
// stop producers if they don't have readers
|
||||||
if atomic.AddInt32(&s.requests, -1) == 0 {
|
if s.pending.Add(-1) == 0 {
|
||||||
s.stopProducers()
|
s.stopProducers()
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(prods) == 0 {
|
if len(prodStarts) == 0 {
|
||||||
return formatError(consMedias, prodMedias, prodErrors)
|
return formatError(consMedias, prodMedias, prodErrors)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,7 +101,7 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
|||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
// there may be duplicates, but that's not a problem
|
// there may be duplicates, but that's not a problem
|
||||||
for _, prod := range prods {
|
for _, prod := range prodStarts {
|
||||||
prod.start()
|
prod.start()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,6 +109,20 @@ func (s *Stream) AddConsumer(cons core.Consumer) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func formatError(consMedias, prodMedias []*core.Media, prodErrors []error) error {
|
func formatError(consMedias, prodMedias []*core.Media, prodErrors []error) error {
|
||||||
|
// 1. Return errors if any not nil
|
||||||
|
var text string
|
||||||
|
|
||||||
|
for _, err := range prodErrors {
|
||||||
|
if err != nil {
|
||||||
|
text = appendString(text, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(text) != 0 {
|
||||||
|
return errors.New("streams: " + text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Return "codecs not matched"
|
||||||
if prodMedias != nil {
|
if prodMedias != nil {
|
||||||
var prod, cons string
|
var prod, cons string
|
||||||
|
|
||||||
@@ -125,16 +145,7 @@ func formatError(consMedias, prodMedias []*core.Media, prodErrors []error) error
|
|||||||
return errors.New("streams: codecs not matched: " + prod + " => " + cons)
|
return errors.New("streams: codecs not matched: " + prod + " => " + cons)
|
||||||
}
|
}
|
||||||
|
|
||||||
if prodErrors != nil {
|
// 3. Return unknown error
|
||||||
var text string
|
|
||||||
|
|
||||||
for _, err := range prodErrors {
|
|
||||||
text = appendString(text, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors.New("streams: " + text)
|
|
||||||
}
|
|
||||||
|
|
||||||
return errors.New("streams: unknown error")
|
return errors.New("streams: unknown error")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -245,10 +245,10 @@ func (p *Producer) stop() {
|
|||||||
|
|
||||||
switch p.state {
|
switch p.state {
|
||||||
case stateExternal:
|
case stateExternal:
|
||||||
log.Debug().Msgf("[streams] can't stop external producer")
|
log.Trace().Msgf("[streams] skip stop external producer")
|
||||||
return
|
return
|
||||||
case stateNone:
|
case stateNone:
|
||||||
log.Debug().Msgf("[streams] can't stop none producer")
|
log.Trace().Msgf("[streams] skip stop none producer")
|
||||||
return
|
return
|
||||||
case stateStart:
|
case stateStart:
|
||||||
p.workerID++
|
p.workerID++
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package streams
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/AlexxIT/go2rtc/pkg/core"
|
"github.com/AlexxIT/go2rtc/pkg/core"
|
||||||
)
|
)
|
||||||
@@ -11,7 +12,7 @@ type Stream struct {
|
|||||||
producers []*Producer
|
producers []*Producer
|
||||||
consumers []core.Consumer
|
consumers []core.Consumer
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
requests int32
|
pending atomic.Int32
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStream(source any) *Stream {
|
func NewStream(source any) *Stream {
|
||||||
|
|||||||
Reference in New Issue
Block a user