diff --git a/pkg/tuya/client.go b/pkg/tuya/client.go index fbf6c31b..277848e7 100644 --- a/pkg/tuya/client.go +++ b/pkg/tuya/client.go @@ -6,6 +6,7 @@ import ( "fmt" "net/url" "regexp" + "sync" "github.com/AlexxIT/go2rtc/pkg/core" "github.com/AlexxIT/go2rtc/pkg/webrtc" @@ -24,6 +25,7 @@ type Client struct { isHEVC bool connected core.Waiter closed bool + handlersMu sync.RWMutex handlers map[uint32]func(*rtp.Packet) } @@ -222,7 +224,7 @@ func Dial(rawURL string) (core.Producer, error) { return } - if handler, ok := client.handlers[packet.SSRC]; ok { + if handler, ok := client.getHandler(packet.SSRC); ok { handler(packet) } } @@ -368,16 +370,20 @@ func (c *Client) Start() error { } } - c.handlers[c.videoSSRC] = func(packet *rtp.Packet) { - if video != nil { - video.WriteRTP(packet) - } + if c.videoSSRC != 0 { + c.setHandler(c.videoSSRC, func(packet *rtp.Packet) { + if video != nil { + video.WriteRTP(packet) + } + }) } - c.handlers[c.audioSSRC] = func(packet *rtp.Packet) { - if audio != nil { - audio.WriteRTP(packet) - } + if c.audioSSRC != 0 { + c.setHandler(c.audioSSRC, func(packet *rtp.Packet) { + if audio != nil { + audio.WriteRTP(packet) + } + }) } return c.conn.Start() @@ -390,9 +396,7 @@ func (c *Client) Stop() error { c.closed = true - for ssrc := range c.handlers { - delete(c.handlers, ssrc) - } + c.clearHandlers() if c.conn != nil { _ = c.conn.Stop() @@ -414,6 +418,27 @@ func (c *Client) MarshalJSON() ([]byte, error) { return c.conn.MarshalJSON() } +func (c *Client) setHandler(ssrc uint32, handler func(*rtp.Packet)) { + c.handlersMu.Lock() + defer c.handlersMu.Unlock() + c.handlers[ssrc] = handler +} + +func (c *Client) getHandler(ssrc uint32) (func(*rtp.Packet), bool) { + c.handlersMu.RLock() + defer c.handlersMu.RUnlock() + handler, ok := c.handlers[ssrc] + return handler, ok +} + +func (c *Client) clearHandlers() { + c.handlersMu.Lock() + defer c.handlersMu.Unlock() + for ssrc := range c.handlers { + delete(c.handlers, ssrc) + } +} + func (c *Client) probe(msg pion.DataChannelMessage) (bool, error) { // fmt.Printf("[tuya] Received string message: %s\n", string(msg.Data)) diff --git a/pkg/tuya/mqtt.go b/pkg/tuya/mqtt.go index 52f928a5..e5565487 100644 --- a/pkg/tuya/mqtt.go +++ b/pkg/tuya/mqtt.go @@ -109,7 +109,11 @@ func (c *TuyaMqttClient) Start(hubConfig *MQTTConfig, webrtcConfig *WebRTCConfig SetUsername(hubConfig.Username). SetPassword(hubConfig.Password). SetOnConnectHandler(c.onConnect). - SetConnectTimeout(10 * time.Second) + SetAutoReconnect(true). + SetMaxReconnectInterval(30 * time.Second). + SetConnectTimeout(15 * time.Second). + SetKeepAlive(30 * time.Second). + SetPingTimeout(15 * time.Second) c.client = mqtt.NewClient(opts)