diff --git a/pkg/tuya/api.go b/pkg/tuya/api.go index 29152951..2f7cd194 100644 --- a/pkg/tuya/api.go +++ b/pkg/tuya/api.go @@ -319,15 +319,8 @@ func (c *TuyaClient) InitDevice() (err error) { _ = json.Unmarshal([]byte(webRTCConfigResponse.Result.Skill), c.skill) } - var audioDirection string - if contains(webRTCConfigResponse.Result.AudioAttributes.CallMode, 2) && - contains(webRTCConfigResponse.Result.AudioAttributes.HardwareCapability, 1) { - audioDirection = core.DirectionSendRecv - c.hasBackchannel = true - } else { - audioDirection = core.DirectionRecvonly - c.hasBackchannel = false - } + c.hasBackchannel = contains(webRTCConfigResponse.Result.AudioAttributes.CallMode, 2) && + contains(webRTCConfigResponse.Result.AudioAttributes.HardwareCapability, 1) c.medias = make([]*core.Media, 0) @@ -335,9 +328,14 @@ func (c *TuyaClient) InitDevice() (err error) { // Use the first Audio-Codec audio := c.skill.Audios[0] + direction := core.DirectionRecvonly + if c.hasBackchannel { + direction = core.DirectionSendRecv + } + c.medias = append(c.medias, &core.Media{ Kind: core.KindAudio, - Direction: audioDirection, + Direction: direction, Codecs: []*core.Codec{ { Name: "PCMU", @@ -471,7 +469,6 @@ func (c *TuyaClient) LoadHubConfig() (config *OpenIoTHubConfig, err error) { return &openIoTHubConfigResponse.Result, nil } -// Search the streamType based on the selection "main" or "sub" func (c *TuyaClient) getStreamType(streamChoice string) uint32 { // Default streamType if nothing is found defaultStreamType := uint32(1) diff --git a/pkg/tuya/client.go b/pkg/tuya/client.go index ef9cbcc2..586a8bbd 100644 --- a/pkg/tuya/client.go +++ b/pkg/tuya/client.go @@ -1,7 +1,6 @@ package tuya import ( - "encoding/json" "errors" "fmt" "net/url" @@ -10,12 +9,13 @@ import ( "github.com/AlexxIT/go2rtc/internal/streams" "github.com/AlexxIT/go2rtc/pkg/core" "github.com/AlexxIT/go2rtc/pkg/webrtc" + "github.com/pion/rtp" pion "github.com/pion/webrtc/v4" ) type Client struct { api *TuyaClient - prod core.Producer + conn *webrtc.Conn done chan struct{} } @@ -95,7 +95,7 @@ func Dial(rawURL string) (core.Producer, error) { conf := pion.Configuration{ ICEServers: client.api.iceServers, ICETransportPolicy: pion.ICETransportPolicyAll, - BundlePolicy: pion.BundlePolicyMaxBundle, + BundlePolicy: pion.BundlePolicyBalanced, } api, err := webrtc.NewAPI() @@ -119,16 +119,16 @@ func Dial(rawURL string) (core.Producer, error) { // waiter will wait PC error or WS error or nil (connection OK) var connState core.Waiter - prod := webrtc.NewConn(pc) - prod.FormatName = "tuya/webrtc" - prod.Mode = core.ModeActiveProducer - prod.Protocol = "mqtt" - prod.URL = rawURL - - client.prod = prod + client.conn = webrtc.NewConn(pc) + client.conn.FormatName = "tuya/webrtc" + client.conn.Mode = core.ModeActiveProducer + client.conn.Protocol = "mqtt" + client.conn.URL = rawURL // Set up MQTT handlers client.api.mqtt.handleAnswer = func(answer AnswerFrame) { + // fmt.Printf("tuya: answer: %s\n", answer.Sdp) + desc := pion.SessionDescription{ Type: pion.SDPTypePranswer, SDP: answer.Sdp, @@ -139,17 +139,17 @@ func Dial(rawURL string) (core.Producer, error) { return } - if err = prod.SetAnswer(answer.Sdp); err != nil { + if err = client.conn.SetAnswer(answer.Sdp); err != nil { client.Stop() return } - prod.SDP = answer.Sdp + client.conn.SDP = answer.Sdp } client.api.mqtt.handleCandidate = func(candidate CandidateFrame) { if candidate.Candidate != "" { - prod.AddCandidate(candidate.Candidate) + client.conn.AddCandidate(candidate.Candidate) if err != nil { client.Stop() } @@ -165,7 +165,7 @@ func Dial(rawURL string) (core.Producer, error) { client.Stop() } - prod.Listen(func(msg any) { + client.conn.Listen(func(msg any) { switch msg := msg.(type) { case *pion.ICECandidate: _ = sendOffer.Wait() @@ -186,13 +186,14 @@ func Dial(rawURL string) (core.Producer, error) { }) // Create offer - offer, err := prod.CreateOffer(client.api.medias) + offer, err := client.conn.CreateOffer(client.api.medias) if err != nil { client.api.Close() return nil, err } // horter sdp, remove a=extmap... line, device ONLY allow 8KB json payload + // https://github.com/tuya/webrtc-demo-go/blob/04575054f18ccccb6bc9d82939dd46d449544e20/static/js/main.js#L224 re := regexp.MustCompile(`\r\na=extmap[^\r\n]*`) offer = re.ReplaceAllString(offer, "") @@ -209,23 +210,38 @@ func Dial(rawURL string) (core.Producer, error) { } func (c *Client) GetMedias() []*core.Media { - return c.prod.GetMedias() + return c.conn.GetMedias() } func (c *Client) GetTrack(media *core.Media, codec *core.Codec) (*core.Receiver, error) { - return c.prod.GetTrack(media, codec) + return c.conn.GetTrack(media, codec) } func (c *Client) AddTrack(media *core.Media, codec *core.Codec, track *core.Receiver) error { - if prod, ok := c.prod.(*webrtc.Conn); ok { - return prod.AddTrack(media, codec, track) + // RepackG711 will not work, so add default logic without repacking + + payloadType := codec.PayloadType + + localTrack := c.conn.GetSenderTrack(media.ID) + if localTrack == nil { + return errors.New("webrtc: can't get track") } + sender := core.NewSender(media, codec) + sender.Handler = func(packet *rtp.Packet) { + c.conn.Send += packet.MarshalSize() + //important to send with remote PayloadType + _ = localTrack.WriteRTP(payloadType, packet) + } + + sender.HandleRTP(track) + c.conn.Senders = append(c.conn.Senders, sender) + return nil } func (c *Client) Start() error { - return c.prod.Start() + return c.conn.Start() } func (c *Client) Stop() error { @@ -236,8 +252,8 @@ func (c *Client) Stop() error { close(c.done) } - if c.prod != nil { - _ = c.prod.Stop() + if c.conn != nil { + _ = c.conn.Stop() } if c.api != nil { @@ -248,9 +264,5 @@ func (c *Client) Stop() error { } func (c *Client) MarshalJSON() ([]byte, error) { - if webrtcProd, ok := c.prod.(*webrtc.Conn); ok { - return webrtcProd.MarshalJSON() - } - - return json.Marshal(c.prod) + return c.conn.MarshalJSON() } diff --git a/pkg/webrtc/conn.go b/pkg/webrtc/conn.go index 092b05c8..f853bf43 100644 --- a/pkg/webrtc/conn.go +++ b/pkg/webrtc/conn.go @@ -161,16 +161,7 @@ func (c *Conn) AddCandidate(candidate string) error { return c.pc.AddICECandidate(webrtc.ICECandidateInit{Candidate: candidate}) } -func (c *Conn) getTranseiver(mid string) *webrtc.RTPTransceiver { - for _, tr := range c.pc.GetTransceivers() { - if tr.Mid() == mid { - return tr - } - } - return nil -} - -func (c *Conn) getSenderTrack(mid string) *Track { +func (c *Conn) GetSenderTrack(mid string) *Track { if tr := c.getTranseiver(mid); tr != nil { if s := tr.Sender(); s != nil { if t := s.Track().(*Track); t != nil { @@ -181,6 +172,15 @@ func (c *Conn) getSenderTrack(mid string) *Track { return nil } +func (c *Conn) getTranseiver(mid string) *webrtc.RTPTransceiver { + for _, tr := range c.pc.GetTransceivers() { + if tr.Mid() == mid { + return tr + } + } + return nil +} + func (c *Conn) getMediaCodec(remote *webrtc.TrackRemote) (*core.Media, *core.Codec) { for _, tr := range c.pc.GetTransceivers() { // search Transeiver for this TrackRemote @@ -209,7 +209,7 @@ func (c *Conn) getMediaCodec(remote *webrtc.TrackRemote) (*core.Media, *core.Cod // check GetTrack panic(core.Caller()) - return nil, nil + // return nil, nil } func sanitizeIP6(host string) string { diff --git a/pkg/webrtc/consumer.go b/pkg/webrtc/consumer.go index ebc3a008..767394df 100644 --- a/pkg/webrtc/consumer.go +++ b/pkg/webrtc/consumer.go @@ -32,7 +32,7 @@ func (c *Conn) AddTrack(media *core.Media, codec *core.Codec, track *core.Receiv panic(core.Caller()) } - localTrack := c.getSenderTrack(media.ID) + localTrack := c.GetSenderTrack(media.ID) if localTrack == nil { return errors.New("webrtc: can't get track") }