From a0e04fb70ea6598859c2f3114cd7700e606f0201 Mon Sep 17 00:00:00 2001 From: Alexey Khit Date: Sat, 25 Feb 2023 19:22:40 +0300 Subject: [PATCH] Fix WebRTC client --- cmd/webrtc/client.go | 6 ++++-- cmd/webrtc/server.go | 2 +- cmd/webrtc/webrtc.go | 14 +++++++++----- pkg/webrtc/client.go | 32 ++++++++++++++++++++++++++++++-- pkg/webrtc/conn.go | 15 +++++++++------ pkg/webrtc/producer.go | 10 ++++++++-- 6 files changed, 61 insertions(+), 18 deletions(-) diff --git a/cmd/webrtc/client.go b/cmd/webrtc/client.go index 6f1b7e39..c72d827b 100644 --- a/cmd/webrtc/client.go +++ b/cmd/webrtc/client.go @@ -39,7 +39,7 @@ func asyncClient(url string) (streamer.Producer, error) { }() // 2. Create PeerConnection - pc, err := newPeerConnection() + pc, err := newPeerConnection(false) if err != nil { log.Error().Err(err).Caller().Send() return nil, err @@ -115,7 +115,7 @@ func asyncClient(url string) (streamer.Producer, error) { // syncClient - support WebRTC-HTTP Egress Protocol (WHEP) func syncClient(url string) (streamer.Producer, error) { // 2. Create PeerConnection - pc, err := newPeerConnection() + pc, err := newPeerConnection(false) if err != nil { log.Error().Err(err).Caller().Send() return nil, err @@ -136,6 +136,8 @@ func syncClient(url string) (streamer.Producer, error) { } client := http.Client{Timeout: time.Second * 5000} + defer client.CloseIdleConnections() + res, err := client.Do(req) if err != nil { return nil, err diff --git a/cmd/webrtc/server.go b/cmd/webrtc/server.go index 0f56ca4f..1fd1885c 100644 --- a/cmd/webrtc/server.go +++ b/cmd/webrtc/server.go @@ -139,7 +139,7 @@ func inputWebRTC(w http.ResponseWriter, r *http.Request) { log.Trace().Msgf("[webrtc] WHIP offer\n%s", offer) - pc, err := newPeerConnection() + pc, err := newPeerConnection(true) if err != nil { log.Error().Err(err).Caller().Send() http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/cmd/webrtc/webrtc.go b/cmd/webrtc/webrtc.go index dd41a2a9..0c6db9dd 100644 --- a/cmd/webrtc/webrtc.go +++ b/cmd/webrtc/webrtc.go @@ -50,8 +50,12 @@ func Init() { SDPSemantics: pion.SDPSemanticsUnifiedPlanWithFallback, } - newPeerConnection = func() (*pion.PeerConnection, error) { - return pionAPI.NewPeerConnection(pionConf) + newPeerConnection = func(isServer bool) (*pion.PeerConnection, error) { + if isServer { + return pionAPI.NewPeerConnection(pionConf) + } else { + return pion.NewPeerConnection(pionConf) + } } for _, candidate := range cfg.Mod.Candidates { @@ -73,7 +77,7 @@ func Init() { var Port string var log zerolog.Logger -var newPeerConnection func() (*pion.PeerConnection, error) +var newPeerConnection func(isServer bool) (*pion.PeerConnection, error) func asyncHandler(tr *api.Transport, msg *api.Message) error { src := tr.Request.URL.Query().Get("src") @@ -85,7 +89,7 @@ func asyncHandler(tr *api.Transport, msg *api.Message) error { log.Debug().Str("url", src).Msg("[webrtc] new consumer") // create new PeerConnection instance - pc, err := newPeerConnection() + pc, err := newPeerConnection(true) if err != nil { log.Error().Err(err).Caller().Send() return err @@ -155,7 +159,7 @@ func asyncHandler(tr *api.Transport, msg *api.Message) error { } func ExchangeSDP(stream *streams.Stream, offer string, userAgent string) (answer string, err error) { - pc, err := newPeerConnection() + pc, err := newPeerConnection(true) if err != nil { log.Error().Err(err).Caller().Send() return diff --git a/pkg/webrtc/client.go b/pkg/webrtc/client.go index 946ae3f6..715b8a11 100644 --- a/pkg/webrtc/client.go +++ b/pkg/webrtc/client.go @@ -1,6 +1,10 @@ package webrtc -import "github.com/pion/webrtc/v3" +import ( + "github.com/AlexxIT/go2rtc/pkg/streamer" + "github.com/pion/sdp/v3" + "github.com/pion/webrtc/v3" +) func (c *Conn) CreateOffer() (string, error) { init := webrtc.RTPTransceiverInit{Direction: webrtc.RTPTransceiverDirectionRecvonly} @@ -30,5 +34,29 @@ func (c *Conn) CreateCompleteOffer() (string, error) { func (c *Conn) SetAnswer(answer string) (err error) { desc := webrtc.SessionDescription{SDP: answer, Type: webrtc.SDPTypeAnswer} - return c.pc.SetRemoteDescription(desc) + if err = c.pc.SetRemoteDescription(desc); err != nil { + return + } + + sd := &sdp.SessionDescription{} + if err = sd.Unmarshal([]byte(answer)); err != nil { + return + } + + medias := streamer.UnmarshalMedias(sd.MediaDescriptions) + + // sort medias, so video will always be before audio + // and ignore application media from Hass default lovelace card + for _, media := range medias { + if media.Kind == streamer.KindVideo { + c.medias = append(c.medias, media) + } + } + for _, media := range medias { + if media.Kind == streamer.KindAudio { + c.medias = append(c.medias, media) + } + } + + return nil } diff --git a/pkg/webrtc/conn.go b/pkg/webrtc/conn.go index 4473f2f9..f11f4892 100644 --- a/pkg/webrtc/conn.go +++ b/pkg/webrtc/conn.go @@ -19,10 +19,11 @@ type Conn struct { send int offer string + start chan struct{} } func NewConn(pc *webrtc.PeerConnection) *Conn { - c := &Conn{pc: pc} + c := &Conn{pc: pc, start: make(chan struct{})} pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { c.Fire(candidate) @@ -64,14 +65,11 @@ func NewConn(pc *webrtc.PeerConnection) *Conn { pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { c.Fire(state) - // TODO: rewrite? switch state { - case webrtc.PeerConnectionStateDisconnected: + case webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed: // disconnect event comes earlier, than failed // but it comes only for success connections - _ = pc.Close() - case webrtc.PeerConnectionStateFailed: - _ = pc.Close() + _ = c.Close() } }) @@ -79,6 +77,11 @@ func NewConn(pc *webrtc.PeerConnection) *Conn { } func (c *Conn) Close() error { + // unblocked write to chan + select { + case c.start <- struct{}{}: + default: + } return c.pc.Close() } diff --git a/pkg/webrtc/producer.go b/pkg/webrtc/producer.go index c0d0574d..1055f803 100644 --- a/pkg/webrtc/producer.go +++ b/pkg/webrtc/producer.go @@ -1,6 +1,8 @@ package webrtc -import "github.com/AlexxIT/go2rtc/pkg/streamer" +import ( + "github.com/AlexxIT/go2rtc/pkg/streamer" +) func (c *Conn) GetTrack(media *streamer.Media, codec *streamer.Codec) *streamer.Track { for _, track := range c.tracks { @@ -8,10 +10,14 @@ func (c *Conn) GetTrack(media *streamer.Media, codec *streamer.Codec) *streamer. return track } } - return nil + + track := streamer.NewTrack(codec, media.Direction) + c.tracks = append(c.tracks, track) + return track } func (c *Conn) Start() error { + <-c.start return nil }