diff --git a/pkg/core/track.go b/pkg/core/track.go index 8bc65374..d3f1467d 100644 --- a/pkg/core/track.go +++ b/pkg/core/track.go @@ -140,6 +140,7 @@ func (s *Sender) Start() { s.done = make(chan struct{}) go func() { + // for range on nil chan is OK for packet := range s.buf { s.Output(packet) } @@ -148,7 +149,7 @@ func (s *Sender) Start() { } func (s *Sender) Wait() { - if done := s.done; s.done != nil { + if done := s.done; done != nil { <-done } } @@ -165,10 +166,12 @@ func (s *Sender) State() string { func (s *Sender) Close() { // close buffer if exists - if buf := s.buf; buf != nil { - s.buf = nil - defer close(buf) + s.mu.Lock() + if s.buf != nil { + close(s.buf) // exit from for range loop + s.buf = nil // prevent writing to closed chan } + s.mu.Unlock() s.Node.Close() } diff --git a/pkg/core/track_test.go b/pkg/core/track_test.go new file mode 100644 index 00000000..cf877d49 --- /dev/null +++ b/pkg/core/track_test.go @@ -0,0 +1,53 @@ +package core + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSenser(t *testing.T) { + recv := make(chan *Packet) // blocking receiver + + sender := NewSender(nil, &Codec{}) + sender.Output = func(packet *Packet) { + recv <- packet + } + require.Equal(t, "new", sender.State()) + + sender.Start() + require.Equal(t, "connected", sender.State()) + + sender.Input(&Packet{}) + sender.Input(&Packet{}) + + require.Equal(t, 2, sender.Packets) + require.Equal(t, 0, sender.Drops) + + // important to read one before close + // because goroutine in Start() can run with nil chan + // it's OK in real life, but bad for test + _, ok := <-recv + require.True(t, ok) + + sender.Close() + require.Equal(t, "closed", sender.State()) + + sender.Input(&Packet{}) + + require.Equal(t, 2, sender.Packets) + require.Equal(t, 1, sender.Drops) + + // read 2nd + _, ok = <-recv + require.True(t, ok) + + // read 3rd + select { + case <-recv: + ok = true + default: + ok = false + } + require.False(t, ok) +}