diff --git a/pkg/core/codec.go b/pkg/core/codec.go index f38d7965..fe813de3 100644 --- a/pkg/core/codec.go +++ b/pkg/core/codec.go @@ -2,6 +2,7 @@ package core import ( "encoding/base64" + "encoding/json" "fmt" "strconv" "strings" @@ -18,6 +19,10 @@ type Codec struct { PayloadType uint8 } +func (c *Codec) MarshalJSON() ([]byte, error) { + return json.Marshal(c.String()) +} + func (c *Codec) String() string { s := fmt.Sprintf("%d %s", c.PayloadType, c.Name) if c.ClockRate != 0 && c.ClockRate != 90000 { diff --git a/pkg/core/core_test.go b/pkg/core/core_test.go new file mode 100644 index 00000000..4a05380a --- /dev/null +++ b/pkg/core/core_test.go @@ -0,0 +1,120 @@ +package core + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +type producer struct { + Medias []*Media + Receivers []*Receiver + + id byte +} + +func (p *producer) GetMedias() []*Media { + return p.Medias +} + +func (p *producer) GetTrack(_ *Media, codec *Codec) (*Receiver, error) { + for _, receiver := range p.Receivers { + if receiver.Codec == codec { + return receiver, nil + } + } + receiver := NewReceiver(nil, codec) + p.Receivers = append(p.Receivers, receiver) + return receiver, nil +} + +func (p *producer) Start() error { + pkt := &Packet{Payload: []byte{p.id}} + p.Receivers[0].Input(pkt) + return nil +} + +func (p *producer) Stop() error { + for _, receiver := range p.Receivers { + receiver.Close() + } + return nil +} + +type consumer struct { + Medias []*Media + Senders []*Sender + + cache chan byte +} + +func (c *consumer) GetMedias() []*Media { + return c.Medias +} + +func (c *consumer) AddTrack(_ *Media, _ *Codec, track *Receiver) error { + c.cache = make(chan byte, 1) + sender := NewSender(nil, track.Codec) + sender.Output = func(packet *Packet) { + c.cache <- packet.Payload[0] + } + sender.HandleRTP(track) + c.Senders = append(c.Senders, sender) + return nil +} + +func (c *consumer) Stop() error { + for _, sender := range c.Senders { + sender.Close() + } + return nil +} + +func (c *consumer) read() byte { + return <-c.cache +} + +func TestName(t *testing.T) { + GetProducer := func(b byte) Producer { + return &producer{ + Medias: []*Media{ + { + Kind: KindVideo, + Direction: DirectionRecvonly, + Codecs: []*Codec{ + {Name: CodecH264}, + }, + }, + }, + id: b, + } + } + + // stage1 + prod1 := GetProducer(1) + cons2 := &consumer{} + + media1 := prod1.GetMedias()[0] + track1, _ := prod1.GetTrack(media1, media1.Codecs[0]) + + _ = cons2.AddTrack(nil, nil, track1) + + _ = prod1.Start() + require.Equal(t, byte(1), cons2.read()) + + // stage2 + prod2 := GetProducer(2) + media2 := prod2.GetMedias()[0] + require.NotEqual(t, fmt.Sprintf("%p", media1), fmt.Sprintf("%p", media2)) + track2, _ := prod2.GetTrack(media2, media2.Codecs[0]) + track1.Replace(track2) + + _ = prod1.Stop() + + _ = prod2.Start() + require.Equal(t, byte(2), cons2.read()) + + // stage3 + _ = prod2.Stop() +} diff --git a/pkg/core/node.go b/pkg/core/node.go new file mode 100644 index 00000000..fd58f2d7 --- /dev/null +++ b/pkg/core/node.go @@ -0,0 +1,87 @@ +package core + +import ( + "sync" + + "github.com/pion/rtp" +) + +//type Packet struct { +// Payload []byte +// Timestamp uint32 // PTS if DTS == 0 else DTS +// Composition uint32 // CTS = PTS-DTS (for support B-frames) +// Sequence uint16 +//} + +type Packet = rtp.Packet + +// HandlerFunc - process input packets (just like http.HandlerFunc) +type HandlerFunc func(packet *Packet) + +// Filter - a decorator for any HandlerFunc +type Filter func(handler HandlerFunc) HandlerFunc + +// Node - Receiver or Sender or Filter (transform) +type Node struct { + Codec *Codec `json:"codec"` + Input HandlerFunc `json:"-"` + Output HandlerFunc `json:"-"` + + childs []*Node + parent *Node + + mu sync.Mutex +} + +func (n *Node) WithParent(parent *Node) *Node { + parent.AppendChild(n) + return n +} + +func (n *Node) AppendChild(child *Node) { + n.mu.Lock() + n.childs = append(n.childs, child) + n.mu.Unlock() + + child.parent = n +} + +func (n *Node) RemoveChild(child *Node) { + n.mu.Lock() + for i, ch := range n.childs { + if ch == child { + n.childs = append(n.childs[:i], n.childs[i+1:]...) + break + } + } + n.mu.Unlock() +} + +func (n *Node) Close() { + if parent := n.parent; parent != nil { + parent.RemoveChild(n) + + if len(parent.childs) == 0 { + parent.Close() + } + } else { + for _, childs := range n.childs { + childs.Close() + } + } +} + +func MoveNode(dst, src *Node) { + src.mu.Lock() + childs := src.childs + src.childs = nil + src.mu.Unlock() + + dst.mu.Lock() + dst.childs = childs + dst.mu.Unlock() + + for _, child := range childs { + child.parent = dst + } +} diff --git a/pkg/core/track.go b/pkg/core/track.go index 72e47074..83c39e01 100644 --- a/pkg/core/track.go +++ b/pkg/core/track.go @@ -1,225 +1,173 @@ package core import ( - "encoding/json" "errors" - "fmt" - "strconv" - "sync" "github.com/pion/rtp" ) -type Packet struct { - PayloadType uint8 - Sequence uint16 - Timestamp uint32 // PTS if DTS == 0 else DTS - Composition uint32 // CTS = PTS-DTS (for support B-frames) - Payload []byte -} - var ErrCantGetTrack = errors.New("can't get track") type Receiver struct { - Codec *Codec - Media *Media + Node - ID byte // Channel for RTSP, PayloadType for MPEG-TS + // Deprecated: should be removed + Media *Media `json:"-"` + // Deprecated: should be removed + ID byte `json:"-"` // Channel for RTSP, PayloadType for MPEG-TS - senders map[*Sender]chan *rtp.Packet - mu sync.RWMutex - bytes int + Bytes int `json:"bytes,omitempty"` + Packets int `json:"packets,omitempty"` } func NewReceiver(media *Media, codec *Codec) *Receiver { - Assert(codec != nil) - return &Receiver{Codec: codec, Media: media} -} - -// WriteRTP - fast and non blocking write to all readers buffers -func (t *Receiver) WriteRTP(packet *rtp.Packet) { - t.mu.Lock() - t.bytes += len(packet.Payload) - for sender, buffer := range t.senders { - select { - case buffer <- packet: - default: - sender.overflow++ + r := &Receiver{ + Node: Node{Codec: codec}, + Media: media, + } + r.Input = func(packet *Packet) { + r.Bytes += len(packet.Payload) + r.Packets++ + for _, child := range r.childs { + child.Input(packet) } } - t.mu.Unlock() + return r } -func (t *Receiver) Senders() (senders []*Sender) { - t.mu.RLock() - for sender := range t.senders { - senders = append(senders, sender) +// Deprecated: should be removed +func (r *Receiver) WriteRTP(packet *rtp.Packet) { + r.Input(packet) +} + +// Deprecated: should be removed +func (r *Receiver) Senders() []*Sender { + if len(r.childs) > 0 { + return []*Sender{{}} + } else { + return nil } - t.mu.RUnlock() - return } -func (t *Receiver) Close() { - t.mu.Lock() - // close all sender channel buffers and erase senders list - for _, buffer := range t.senders { - close(buffer) - } - t.senders = nil - t.mu.Unlock() +// Deprecated: should be removed +func (r *Receiver) Replace(target *Receiver) { + MoveNode(&target.Node, &r.Node) } -func (t *Receiver) Replace(target *Receiver) { - // move this receiver senders to new receiver - t.mu.Lock() - senders := t.senders - t.mu.Unlock() - - target.mu.Lock() - target.senders = senders - target.mu.Unlock() -} - -func (t *Receiver) String() string { - s := t.Codec.String() + ", bytes=" + strconv.Itoa(t.bytes) - t.mu.RLock() - s += fmt.Sprintf(", senders=%d", len(t.senders)) - t.mu.RUnlock() - return s -} - -func (t *Receiver) MarshalJSON() ([]byte, error) { - return json.Marshal(t.String()) +func (r *Receiver) Close() { + r.Node.Close() } type Sender struct { - Codec *Codec - Media *Media + Node - Handler HandlerFunc + // Deprecated: + Media *Media `json:"-"` + // Deprecated: + Handler HandlerFunc `json:"-"` - receivers []*Receiver - mu sync.RWMutex - bytes int + Bytes int `json:"bytes,omitempty"` + Packets int `json:"packets,omitempty"` + Drops int `json:"drops,omitempty"` - overflow int + buf chan *Packet + done chan struct{} } func NewSender(media *Media, codec *Codec) *Sender { - return &Sender{Codec: codec, Media: media} -} + var bufSize uint16 -// HandlerFunc like http.HandlerFunc -type HandlerFunc func(packet *rtp.Packet) - -func (s *Sender) HandleRTP(track *Receiver) { - s.Bind(track) - go s.worker(track) -} - -func (s *Sender) Bind(track *Receiver) { - var bufferSize uint16 - - if GetKind(track.Codec.Name) == KindVideo { - if track.Codec.IsRTP() { + if GetKind(codec.Name) == KindVideo { + if codec.IsRTP() { // in my tests 40Mbit/s 4K-video can generate up to 1500 items // for the h264.RTPDepay => RTPPay queue - bufferSize = 5000 + bufSize = 4096 } else { - bufferSize = 50 + bufSize = 64 } } else { - bufferSize = 100 + bufSize = 128 } - buffer := make(chan *rtp.Packet, bufferSize) - - track.mu.Lock() - if track.senders == nil { - track.senders = map[*Sender]chan *rtp.Packet{} + buf := make(chan *Packet, bufSize) + s := &Sender{ + Node: Node{Codec: codec}, + Media: media, + buf: buf, } - track.senders[s] = buffer - track.mu.Unlock() - - s.mu.Lock() - s.receivers = append(s.receivers, track) - s.mu.Unlock() + s.Input = func(packet *Packet) { + // writing to nil chan - OK, writing to closed chan - panic + s.mu.Lock() + select { + case s.buf <- packet: + s.Bytes += len(packet.Payload) + s.Packets++ + default: + s.Drops++ + } + s.mu.Unlock() + } + s.Output = func(packet *Packet) { + s.Handler(packet) + } + return s } -func (s *Sender) worker(track *Receiver) { - track.mu.Lock() - buffer := track.senders[s] - track.mu.Unlock() +// Deprecated: should be removed +func (s *Sender) HandleRTP(parent *Receiver) { + s.WithParent(parent) + s.Start() +} - // read packets from buffer channel until it will be closed - if buffer != nil { - for packet := range buffer { - s.bytes += len(packet.Payload) - s.Handler(packet) - } - } +// Deprecated: should be removed +func (s *Sender) Bind(parent *Receiver) { + s.WithParent(parent) +} - // remove current receiver from list - // it can only happen when receiver close buffer channel - s.mu.Lock() - for i, receiver := range s.receivers { - if receiver == track { - s.receivers = append(s.receivers[:i], s.receivers[i+1:]...) - break - } - } - s.mu.Unlock() +func (s *Sender) WithParent(parent *Receiver) *Sender { + s.Node.WithParent(&parent.Node) + return s } func (s *Sender) Start() { s.mu.Lock() - for _, track := range s.receivers { - go s.worker(track) + defer s.mu.Unlock() + + if s.buf == nil || s.done != nil { + return } - s.mu.Unlock() + s.done = make(chan struct{}) + + go func() { + for packet := range s.buf { + s.Output(packet) + } + close(s.done) + }() +} + +func (s *Sender) Wait() { + if done := s.done; s.done != nil { + <-done + } +} + +func (s *Sender) State() string { + if s.buf == nil { + return "closed" + } + if s.done == nil { + return "new" + } + return "connected" } func (s *Sender) Close() { - s.mu.Lock() - // remove this sender from all receivers list - for _, receiver := range s.receivers { - receiver.mu.Lock() - if buffer := receiver.senders[s]; buffer != nil { - // remove channel from list - delete(receiver.senders, s) - // close channel - close(buffer) - } - receiver.mu.Unlock() + // close buffer if exists + if buf := s.buf; buf != nil { + s.buf = nil + defer close(buf) } - s.receivers = nil - s.mu.Unlock() -} -func (s *Sender) String() string { - info := s.Codec.String() + ", bytes=" + strconv.Itoa(s.bytes) - s.mu.RLock() - info += ", receivers=" + strconv.Itoa(len(s.receivers)) - s.mu.RUnlock() - if s.overflow > 0 { - info += ", overflow=" + strconv.Itoa(s.overflow) - } - return info -} - -func (s *Sender) MarshalJSON() ([]byte, error) { - return json.Marshal(s.String()) -} - -// VA - helper, for extract video and audio receivers from list -func VA(receivers []*Receiver) (video, audio *Receiver) { - for _, receiver := range receivers { - switch GetKind(receiver.Codec.Name) { - case KindVideo: - video = receiver - case KindAudio: - audio = receiver - } - } - return + s.Node.Close() }