From 3a51fa2397b4f924e954c0703fe4a478f22da265 Mon Sep 17 00:00:00 2001 From: Alexey Khit Date: Thu, 29 Jun 2023 10:55:43 +0300 Subject: [PATCH] Fix panic with only audio for MP4/MSE #404 --- pkg/mp4/consumer.go | 30 ++++++++-------- pkg/mp4/consumer_test.go | 77 ++++++++++++++++++++++++++++++++++++++++ pkg/mp4/helpers.go | 6 ++-- 3 files changed, 96 insertions(+), 17 deletions(-) create mode 100644 pkg/mp4/consumer_test.go diff --git a/pkg/mp4/consumer.go b/pkg/mp4/consumer.go index 069016cb..f9306b10 100644 --- a/pkg/mp4/consumer.go +++ b/pkg/mp4/consumer.go @@ -22,7 +22,7 @@ type Consumer struct { muxer *Muxer mu sync.Mutex - wait byte + state byte send int } @@ -60,18 +60,16 @@ func (c *Consumer) AddTrack(media *core.Media, _ *core.Codec, track *core.Receiv switch track.Codec.Name { case core.CodecH264: - c.wait = waitInit - handler.Handler = func(packet *rtp.Packet) { if packet.Version != h264.RTPPacketVersionAVC { return } - if c.wait != waitNone { - if c.wait == waitInit || !h264.IsKeyframe(packet.Payload) { + if c.state != stateStart { + if c.state != stateInit || !h264.IsKeyframe(packet.Payload) { return } - c.wait = waitNone + c.state = stateStart } // important to use Mutex because right fragment order @@ -89,18 +87,16 @@ func (c *Consumer) AddTrack(media *core.Media, _ *core.Codec, track *core.Receiv } case core.CodecH265: - c.wait = waitInit - handler.Handler = func(packet *rtp.Packet) { if packet.Version != h264.RTPPacketVersionAVC { return } - if c.wait != waitNone { - if c.wait == waitInit || !h265.IsKeyframe(packet.Payload) { + if c.state != stateStart { + if c.state != stateInit || !h265.IsKeyframe(packet.Payload) { return } - c.wait = waitNone + c.state = stateStart } c.mu.Lock() @@ -116,7 +112,7 @@ func (c *Consumer) AddTrack(media *core.Media, _ *core.Codec, track *core.Receiv default: handler.Handler = func(packet *rtp.Packet) { - if c.wait != waitNone { + if c.state != stateStart { return } @@ -182,9 +178,15 @@ func (c *Consumer) Init() ([]byte, error) { } func (c *Consumer) Start() { - if c.wait == waitInit { - c.wait = waitKeyframe + for _, sender := range c.senders { + switch sender.Codec.Name { + case core.CodecH264, core.CodecH265: + c.state = stateInit + return + } } + + c.state = stateStart } func (c *Consumer) MarshalJSON() ([]byte, error) { diff --git a/pkg/mp4/consumer_test.go b/pkg/mp4/consumer_test.go new file mode 100644 index 00000000..4991c044 --- /dev/null +++ b/pkg/mp4/consumer_test.go @@ -0,0 +1,77 @@ +package mp4 + +import ( + "github.com/AlexxIT/go2rtc/pkg/core" + "github.com/AlexxIT/go2rtc/pkg/h264" + "github.com/pion/rtp" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func TestStartH264(t *testing.T) { + codec := &core.Codec{Name: core.CodecH264} + track := core.NewReceiver(nil, codec) + + packetKey := &rtp.Packet{ + Header: rtp.Header{Marker: true}, + Payload: []byte{h264.NALUTypeIFrame, 0, 0}, + } + + packetNotKey := &rtp.Packet{ + Header: rtp.Header{Marker: true}, + Payload: []byte{h264.NALUTypePFrame, 0, 0}, + } + + cons := &Consumer{} + err := cons.AddTrack(nil, nil, track) + require.Nil(t, err) + + track.WriteRTP(packetKey) + time.Sleep(time.Millisecond) + + _, err = cons.Init() + require.Nil(t, err) + + cons.Start() + + track.WriteRTP(packetNotKey) + time.Sleep(time.Millisecond) + + require.Zero(t, cons.send) + + track.WriteRTP(packetKey) + time.Sleep(time.Millisecond) + + require.NotZero(t, cons.send) +} + +func TestStartOPUS(t *testing.T) { + // Test for fix this issue + // https://github.com/AlexxIT/go2rtc/issues/404 + codec := &core.Codec{Name: core.CodecOpus} + track := core.NewReceiver(nil, codec) + + cons := &Consumer{} + err := cons.AddTrack(nil, nil, track) + require.Nil(t, err) + + track.WriteRTP(&rtp.Packet{ + Payload: []byte{0}, + }) + time.Sleep(time.Millisecond) + + require.Zero(t, cons.send) + + _, err = cons.Init() + require.Nil(t, err) + + cons.Start() + + track.WriteRTP(&rtp.Packet{ + Payload: []byte{0}, + }) + time.Sleep(time.Millisecond) + + require.NotZero(t, cons.send) +} diff --git a/pkg/mp4/helpers.go b/pkg/mp4/helpers.go index c22f1220..174dc48d 100644 --- a/pkg/mp4/helpers.go +++ b/pkg/mp4/helpers.go @@ -49,7 +49,7 @@ func ParseQuery(query map[string][]string) []*core.Media { } const ( - waitNone byte = iota - waitKeyframe - waitInit + stateNone byte = iota + stateInit + stateStart )