Rewrite Receiver/Sender classes

This commit is contained in:
Alex X
2024-06-05 20:01:47 +03:00
parent e0b1a50356
commit 31e4ba2722
4 changed files with 327 additions and 167 deletions
+5
View File
@@ -2,6 +2,7 @@ package core
import (
"encoding/base64"
"encoding/json"
"fmt"
"strconv"
"strings"
@@ -18,6 +19,10 @@ type Codec struct {
PayloadType uint8
}
func (c *Codec) MarshalJSON() ([]byte, error) {
return json.Marshal(c.String())
}
func (c *Codec) String() string {
s := fmt.Sprintf("%d %s", c.PayloadType, c.Name)
if c.ClockRate != 0 && c.ClockRate != 90000 {
+120
View File
@@ -0,0 +1,120 @@
package core
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
type producer struct {
Medias []*Media
Receivers []*Receiver
id byte
}
func (p *producer) GetMedias() []*Media {
return p.Medias
}
func (p *producer) GetTrack(_ *Media, codec *Codec) (*Receiver, error) {
for _, receiver := range p.Receivers {
if receiver.Codec == codec {
return receiver, nil
}
}
receiver := NewReceiver(nil, codec)
p.Receivers = append(p.Receivers, receiver)
return receiver, nil
}
func (p *producer) Start() error {
pkt := &Packet{Payload: []byte{p.id}}
p.Receivers[0].Input(pkt)
return nil
}
func (p *producer) Stop() error {
for _, receiver := range p.Receivers {
receiver.Close()
}
return nil
}
type consumer struct {
Medias []*Media
Senders []*Sender
cache chan byte
}
func (c *consumer) GetMedias() []*Media {
return c.Medias
}
func (c *consumer) AddTrack(_ *Media, _ *Codec, track *Receiver) error {
c.cache = make(chan byte, 1)
sender := NewSender(nil, track.Codec)
sender.Output = func(packet *Packet) {
c.cache <- packet.Payload[0]
}
sender.HandleRTP(track)
c.Senders = append(c.Senders, sender)
return nil
}
func (c *consumer) Stop() error {
for _, sender := range c.Senders {
sender.Close()
}
return nil
}
func (c *consumer) read() byte {
return <-c.cache
}
func TestName(t *testing.T) {
GetProducer := func(b byte) Producer {
return &producer{
Medias: []*Media{
{
Kind: KindVideo,
Direction: DirectionRecvonly,
Codecs: []*Codec{
{Name: CodecH264},
},
},
},
id: b,
}
}
// stage1
prod1 := GetProducer(1)
cons2 := &consumer{}
media1 := prod1.GetMedias()[0]
track1, _ := prod1.GetTrack(media1, media1.Codecs[0])
_ = cons2.AddTrack(nil, nil, track1)
_ = prod1.Start()
require.Equal(t, byte(1), cons2.read())
// stage2
prod2 := GetProducer(2)
media2 := prod2.GetMedias()[0]
require.NotEqual(t, fmt.Sprintf("%p", media1), fmt.Sprintf("%p", media2))
track2, _ := prod2.GetTrack(media2, media2.Codecs[0])
track1.Replace(track2)
_ = prod1.Stop()
_ = prod2.Start()
require.Equal(t, byte(2), cons2.read())
// stage3
_ = prod2.Stop()
}
+87
View File
@@ -0,0 +1,87 @@
package core
import (
"sync"
"github.com/pion/rtp"
)
//type Packet struct {
// Payload []byte
// Timestamp uint32 // PTS if DTS == 0 else DTS
// Composition uint32 // CTS = PTS-DTS (for support B-frames)
// Sequence uint16
//}
type Packet = rtp.Packet
// HandlerFunc - process input packets (just like http.HandlerFunc)
type HandlerFunc func(packet *Packet)
// Filter - a decorator for any HandlerFunc
type Filter func(handler HandlerFunc) HandlerFunc
// Node - Receiver or Sender or Filter (transform)
type Node struct {
Codec *Codec `json:"codec"`
Input HandlerFunc `json:"-"`
Output HandlerFunc `json:"-"`
childs []*Node
parent *Node
mu sync.Mutex
}
func (n *Node) WithParent(parent *Node) *Node {
parent.AppendChild(n)
return n
}
func (n *Node) AppendChild(child *Node) {
n.mu.Lock()
n.childs = append(n.childs, child)
n.mu.Unlock()
child.parent = n
}
func (n *Node) RemoveChild(child *Node) {
n.mu.Lock()
for i, ch := range n.childs {
if ch == child {
n.childs = append(n.childs[:i], n.childs[i+1:]...)
break
}
}
n.mu.Unlock()
}
func (n *Node) Close() {
if parent := n.parent; parent != nil {
parent.RemoveChild(n)
if len(parent.childs) == 0 {
parent.Close()
}
} else {
for _, childs := range n.childs {
childs.Close()
}
}
}
func MoveNode(dst, src *Node) {
src.mu.Lock()
childs := src.childs
src.childs = nil
src.mu.Unlock()
dst.mu.Lock()
dst.childs = childs
dst.mu.Unlock()
for _, child := range childs {
child.parent = dst
}
}
+115 -167
View File
@@ -1,225 +1,173 @@
package core
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"sync"
"github.com/pion/rtp"
)
type Packet struct {
PayloadType uint8
Sequence uint16
Timestamp uint32 // PTS if DTS == 0 else DTS
Composition uint32 // CTS = PTS-DTS (for support B-frames)
Payload []byte
}
var ErrCantGetTrack = errors.New("can't get track")
type Receiver struct {
Codec *Codec
Media *Media
Node
ID byte // Channel for RTSP, PayloadType for MPEG-TS
// Deprecated: should be removed
Media *Media `json:"-"`
// Deprecated: should be removed
ID byte `json:"-"` // Channel for RTSP, PayloadType for MPEG-TS
senders map[*Sender]chan *rtp.Packet
mu sync.RWMutex
bytes int
Bytes int `json:"bytes,omitempty"`
Packets int `json:"packets,omitempty"`
}
func NewReceiver(media *Media, codec *Codec) *Receiver {
Assert(codec != nil)
return &Receiver{Codec: codec, Media: media}
}
// WriteRTP - fast and non blocking write to all readers buffers
func (t *Receiver) WriteRTP(packet *rtp.Packet) {
t.mu.Lock()
t.bytes += len(packet.Payload)
for sender, buffer := range t.senders {
select {
case buffer <- packet:
default:
sender.overflow++
r := &Receiver{
Node: Node{Codec: codec},
Media: media,
}
r.Input = func(packet *Packet) {
r.Bytes += len(packet.Payload)
r.Packets++
for _, child := range r.childs {
child.Input(packet)
}
}
t.mu.Unlock()
return r
}
func (t *Receiver) Senders() (senders []*Sender) {
t.mu.RLock()
for sender := range t.senders {
senders = append(senders, sender)
// Deprecated: should be removed
func (r *Receiver) WriteRTP(packet *rtp.Packet) {
r.Input(packet)
}
// Deprecated: should be removed
func (r *Receiver) Senders() []*Sender {
if len(r.childs) > 0 {
return []*Sender{{}}
} else {
return nil
}
t.mu.RUnlock()
return
}
func (t *Receiver) Close() {
t.mu.Lock()
// close all sender channel buffers and erase senders list
for _, buffer := range t.senders {
close(buffer)
}
t.senders = nil
t.mu.Unlock()
// Deprecated: should be removed
func (r *Receiver) Replace(target *Receiver) {
MoveNode(&target.Node, &r.Node)
}
func (t *Receiver) Replace(target *Receiver) {
// move this receiver senders to new receiver
t.mu.Lock()
senders := t.senders
t.mu.Unlock()
target.mu.Lock()
target.senders = senders
target.mu.Unlock()
}
func (t *Receiver) String() string {
s := t.Codec.String() + ", bytes=" + strconv.Itoa(t.bytes)
t.mu.RLock()
s += fmt.Sprintf(", senders=%d", len(t.senders))
t.mu.RUnlock()
return s
}
func (t *Receiver) MarshalJSON() ([]byte, error) {
return json.Marshal(t.String())
func (r *Receiver) Close() {
r.Node.Close()
}
type Sender struct {
Codec *Codec
Media *Media
Node
Handler HandlerFunc
// Deprecated:
Media *Media `json:"-"`
// Deprecated:
Handler HandlerFunc `json:"-"`
receivers []*Receiver
mu sync.RWMutex
bytes int
Bytes int `json:"bytes,omitempty"`
Packets int `json:"packets,omitempty"`
Drops int `json:"drops,omitempty"`
overflow int
buf chan *Packet
done chan struct{}
}
func NewSender(media *Media, codec *Codec) *Sender {
return &Sender{Codec: codec, Media: media}
}
var bufSize uint16
// HandlerFunc like http.HandlerFunc
type HandlerFunc func(packet *rtp.Packet)
func (s *Sender) HandleRTP(track *Receiver) {
s.Bind(track)
go s.worker(track)
}
func (s *Sender) Bind(track *Receiver) {
var bufferSize uint16
if GetKind(track.Codec.Name) == KindVideo {
if track.Codec.IsRTP() {
if GetKind(codec.Name) == KindVideo {
if codec.IsRTP() {
// in my tests 40Mbit/s 4K-video can generate up to 1500 items
// for the h264.RTPDepay => RTPPay queue
bufferSize = 5000
bufSize = 4096
} else {
bufferSize = 50
bufSize = 64
}
} else {
bufferSize = 100
bufSize = 128
}
buffer := make(chan *rtp.Packet, bufferSize)
track.mu.Lock()
if track.senders == nil {
track.senders = map[*Sender]chan *rtp.Packet{}
buf := make(chan *Packet, bufSize)
s := &Sender{
Node: Node{Codec: codec},
Media: media,
buf: buf,
}
track.senders[s] = buffer
track.mu.Unlock()
s.mu.Lock()
s.receivers = append(s.receivers, track)
s.mu.Unlock()
s.Input = func(packet *Packet) {
// writing to nil chan - OK, writing to closed chan - panic
s.mu.Lock()
select {
case s.buf <- packet:
s.Bytes += len(packet.Payload)
s.Packets++
default:
s.Drops++
}
s.mu.Unlock()
}
s.Output = func(packet *Packet) {
s.Handler(packet)
}
return s
}
func (s *Sender) worker(track *Receiver) {
track.mu.Lock()
buffer := track.senders[s]
track.mu.Unlock()
// Deprecated: should be removed
func (s *Sender) HandleRTP(parent *Receiver) {
s.WithParent(parent)
s.Start()
}
// read packets from buffer channel until it will be closed
if buffer != nil {
for packet := range buffer {
s.bytes += len(packet.Payload)
s.Handler(packet)
}
}
// Deprecated: should be removed
func (s *Sender) Bind(parent *Receiver) {
s.WithParent(parent)
}
// remove current receiver from list
// it can only happen when receiver close buffer channel
s.mu.Lock()
for i, receiver := range s.receivers {
if receiver == track {
s.receivers = append(s.receivers[:i], s.receivers[i+1:]...)
break
}
}
s.mu.Unlock()
func (s *Sender) WithParent(parent *Receiver) *Sender {
s.Node.WithParent(&parent.Node)
return s
}
func (s *Sender) Start() {
s.mu.Lock()
for _, track := range s.receivers {
go s.worker(track)
defer s.mu.Unlock()
if s.buf == nil || s.done != nil {
return
}
s.mu.Unlock()
s.done = make(chan struct{})
go func() {
for packet := range s.buf {
s.Output(packet)
}
close(s.done)
}()
}
func (s *Sender) Wait() {
if done := s.done; s.done != nil {
<-done
}
}
func (s *Sender) State() string {
if s.buf == nil {
return "closed"
}
if s.done == nil {
return "new"
}
return "connected"
}
func (s *Sender) Close() {
s.mu.Lock()
// remove this sender from all receivers list
for _, receiver := range s.receivers {
receiver.mu.Lock()
if buffer := receiver.senders[s]; buffer != nil {
// remove channel from list
delete(receiver.senders, s)
// close channel
close(buffer)
}
receiver.mu.Unlock()
// close buffer if exists
if buf := s.buf; buf != nil {
s.buf = nil
defer close(buf)
}
s.receivers = nil
s.mu.Unlock()
}
func (s *Sender) String() string {
info := s.Codec.String() + ", bytes=" + strconv.Itoa(s.bytes)
s.mu.RLock()
info += ", receivers=" + strconv.Itoa(len(s.receivers))
s.mu.RUnlock()
if s.overflow > 0 {
info += ", overflow=" + strconv.Itoa(s.overflow)
}
return info
}
func (s *Sender) MarshalJSON() ([]byte, error) {
return json.Marshal(s.String())
}
// VA - helper, for extract video and audio receivers from list
func VA(receivers []*Receiver) (video, audio *Receiver) {
for _, receiver := range receivers {
switch GetKind(receiver.Codec.Name) {
case KindVideo:
video = receiver
case KindAudio:
audio = receiver
}
}
return
s.Node.Close()
}