diff --git a/internal/hass/api.go b/internal/hass/api.go index 6aa49197..4628cc11 100644 --- a/internal/hass/api.go +++ b/internal/hass/api.go @@ -30,15 +30,12 @@ func apiStream(w http.ResponseWriter, r *http.Request) { // 1. link to go2rtc stream: rtsp://...:8554/{stream_name} // 2. static link to Hass camera // 3. dynamic link to Hass camera - stream := streams.Get(v.Name) - if stream == nil { - stream = streams.NewTemplate(v.Name, v.Channels.First.Url) + if streams.Patch(v.Name, v.Channels.First.Url) != nil { + apiOK(w, r) + } else { + http.Error(w, "", http.StatusBadRequest) } - stream.SetSource(v.Channels.First.Url) - - apiOK(w, r) - // /stream/{id}/channel/0/webrtc default: i := strings.IndexByte(r.RequestURI[8:], '/') diff --git a/internal/hls/ws.go b/internal/hls/ws.go index 96c6eb64..81a1dd72 100644 --- a/internal/hls/ws.go +++ b/internal/hls/ws.go @@ -2,18 +2,18 @@ package hls import ( "errors" + "time" + "github.com/AlexxIT/go2rtc/internal/api" "github.com/AlexxIT/go2rtc/internal/api/ws" "github.com/AlexxIT/go2rtc/internal/streams" "github.com/AlexxIT/go2rtc/pkg/core" "github.com/AlexxIT/go2rtc/pkg/mp4" "github.com/AlexxIT/go2rtc/pkg/tcp" - "time" ) func handlerWSHLS(tr *ws.Transport, msg *ws.Message) error { - src := tr.Request.URL.Query().Get("src") - stream := streams.Get(src) + stream := streams.GetOrPatch(tr.Request.URL.Query()) if stream == nil { return errors.New(api.StreamNotFound) } diff --git a/internal/mjpeg/init.go b/internal/mjpeg/init.go index 975e2a59..70ab8e00 100644 --- a/internal/mjpeg/init.go +++ b/internal/mjpeg/init.go @@ -2,6 +2,11 @@ package mjpeg import ( "errors" + "io" + "net/http" + "strconv" + "time" + "github.com/AlexxIT/go2rtc/internal/api" "github.com/AlexxIT/go2rtc/internal/api/ws" "github.com/AlexxIT/go2rtc/internal/ffmpeg" @@ -11,10 +16,6 @@ import ( "github.com/AlexxIT/go2rtc/pkg/mjpeg" "github.com/AlexxIT/go2rtc/pkg/tcp" "github.com/rs/zerolog/log" - "io" - "net/http" - "strconv" - "time" ) func Init() { @@ -158,8 +159,7 @@ func inputMjpeg(w http.ResponseWriter, r *http.Request) { } func handlerWS(tr *ws.Transport, _ *ws.Message) error { - src := tr.Request.URL.Query().Get("src") - stream := streams.Get(src) + stream := streams.GetOrPatch(tr.Request.URL.Query()) if stream == nil { return errors.New(api.StreamNotFound) } diff --git a/internal/mp4/ws.go b/internal/mp4/ws.go index eff94071..b9071f31 100644 --- a/internal/mp4/ws.go +++ b/internal/mp4/ws.go @@ -2,6 +2,7 @@ package mp4 import ( "errors" + "github.com/AlexxIT/go2rtc/internal/api" "github.com/AlexxIT/go2rtc/internal/api/ws" "github.com/AlexxIT/go2rtc/internal/streams" @@ -10,8 +11,7 @@ import ( ) func handlerWSMSE(tr *ws.Transport, msg *ws.Message) error { - src := tr.Request.URL.Query().Get("src") - stream := streams.Get(src) + stream := streams.GetOrPatch(tr.Request.URL.Query()) if stream == nil { return errors.New(api.StreamNotFound) } @@ -58,8 +58,7 @@ func handlerWSMSE(tr *ws.Transport, msg *ws.Message) error { } func handlerWSMP4(tr *ws.Transport, msg *ws.Message) error { - src := tr.Request.URL.Query().Get("src") - stream := streams.Get(src) + stream := streams.GetOrPatch(tr.Request.URL.Query()) if stream == nil { return errors.New(api.StreamNotFound) } diff --git a/internal/streams/producer.go b/internal/streams/producer.go index 6306e0af..6d3cf2b9 100644 --- a/internal/streams/producer.go +++ b/internal/streams/producer.go @@ -3,10 +3,11 @@ package streams import ( "encoding/json" "errors" - "github.com/AlexxIT/go2rtc/pkg/core" "strings" "sync" "time" + + "github.com/AlexxIT/go2rtc/pkg/core" ) type state byte @@ -35,6 +36,24 @@ type Producer struct { workerID int } +const SourceTemplate = "{input}" + +func NewProducer(source string) *Producer { + if strings.Contains(source, SourceTemplate) { + return &Producer{template: source} + } + + return &Producer{url: source} +} + +func (p *Producer) SetSource(s string) { + if p.template == "" { + p.url = s + } else { + p.url = strings.Replace(p.template, SourceTemplate, s, 1) + } +} + func (p *Producer) Dial() error { p.mu.Lock() defer p.mu.Unlock() @@ -112,13 +131,6 @@ func (p *Producer) AddTrack(media *core.Media, codec *core.Codec, track *core.Re return nil } -func (p *Producer) SetSource(s string) { - if p.template == "" { - p.template = p.url - } - p.url = strings.Replace(p.template, "{input}", s, 1) -} - func (p *Producer) MarshalJSON() ([]byte, error) { if p.conn != nil { return json.Marshal(p.conn) diff --git a/internal/streams/stream.go b/internal/streams/stream.go index 468f8889..58fe6483 100644 --- a/internal/streams/stream.go +++ b/internal/streams/stream.go @@ -3,10 +3,11 @@ package streams import ( "encoding/json" "errors" - "github.com/AlexxIT/go2rtc/pkg/core" "strings" "sync" "sync/atomic" + + "github.com/AlexxIT/go2rtc/pkg/core" ) type Stream struct { @@ -19,15 +20,13 @@ type Stream struct { func NewStream(source any) *Stream { switch source := source.(type) { case string: - s := new(Stream) - prod := &Producer{url: source} - s.producers = append(s.producers, prod) - return s + return &Stream{ + producers: []*Producer{NewProducer(source)}, + } case []any: s := new(Stream) for _, source := range source { - prod := &Producer{url: source.(string)} - s.producers = append(s.producers, prod) + s.producers = append(s.producers, NewProducer(source.(string))) } return s case map[string]any: diff --git a/internal/streams/stream_test.go b/internal/streams/stream_test.go index 86dc92c2..b2e88dc6 100644 --- a/internal/streams/stream_test.go +++ b/internal/streams/stream_test.go @@ -1,19 +1,38 @@ package streams import ( - "github.com/stretchr/testify/require" + "net/url" "testing" + + "github.com/AlexxIT/go2rtc/pkg/core" + "github.com/stretchr/testify/require" ) -func TestTemplate(t *testing.T) { - source1 := "does not matter" - - stream1 := New("from_yaml", source1) +func TestRecursion(t *testing.T) { + // create stream with some source + stream1 := New("from_yaml", "does not matter") require.Len(t, streams, 1) - stream2 := NewTemplate("camera.from_hass", "rtsp://localhost:8554/from_yaml?video") + // ask another unnamed stream that links go2rtc + query, err := url.ParseQuery("src=rtsp://localhost:8554/from_yaml?video") + require.Nil(t, err) + stream2 := GetOrPatch(query) + // check stream is same require.Equal(t, stream1, stream2) - require.Equal(t, stream2.producers[0].url, source1) + // check stream urls is same + require.Equal(t, stream1.producers[0].url, stream2.producers[0].url) require.Len(t, streams, 2) } + +func TestTempate(t *testing.T) { + HandleFunc("rtsp", func(url string) (core.Producer, error) { return nil, nil }) // bypass HasProducer + + // config from yaml + stream1 := New("camera.from_hass", "ffmpeg:{input}#video=copy") + // request from hass + stream2 := Patch("camera.from_hass", "rtsp://example.com") + + require.Equal(t, stream1, stream2) + require.Equal(t, "ffmpeg:rtsp://example.com#video=copy", stream1.producers[0].url) +} diff --git a/internal/streams/init.go b/internal/streams/streams.go similarity index 64% rename from internal/streams/init.go rename to internal/streams/streams.go index 8cf31ec8..7aa5f58f 100644 --- a/internal/streams/init.go +++ b/internal/streams/streams.go @@ -1,12 +1,14 @@ package streams import ( + "net/http" + "net/url" + "sync" + "github.com/AlexxIT/go2rtc/internal/api" "github.com/AlexxIT/go2rtc/internal/app" "github.com/AlexxIT/go2rtc/internal/app/store" "github.com/rs/zerolog" - "net/http" - "net/url" ) func Init() { @@ -39,18 +41,56 @@ func New(name string, source any) *Stream { return stream } -func NewTemplate(name string, source any) *Stream { +func Patch(name string, source string) *Stream { + streamsMu.Lock() + defer streamsMu.Unlock() + // check if source links to some stream name from go2rtc - if rawURL, ok := source.(string); ok { - if u, err := url.Parse(rawURL); err == nil && u.Scheme == "rtsp" && len(u.Path) > 1 { - if stream, ok := streams[u.Path[1:]]; ok { - streams[name] = stream - return stream - } + if u, err := url.Parse(source); err == nil && u.Scheme == "rtsp" && len(u.Path) > 1 { + rtspName := u.Path[1:] + if stream, ok := streams[rtspName]; ok { + // link (alias) stream[name] to stream[rtspName] + streams[name] = stream + return stream } } - return New(name, "{input}") + // check if src has supported scheme + if !HasProducer(source) { + return nil + } + + // check an existing stream with this name + if stream, ok := streams[name]; ok { + stream.SetSource(source) + return stream + } + + // create new stream with this name + return New(name, source) +} + +func GetOrPatch(query url.Values) *Stream { + // check if src param exists + source := query.Get("src") + if source == "" { + return nil + } + + // check if src is stream name + if stream, ok := streams[source]; ok { + return stream + } + + // check if name param provided + if name := query.Get("name"); name == "" { + log.Info().Msgf("[streams] create new stream url=%s", source) + + return Patch(name, source) + } + + // return new stream with src as name + return Patch(source, source) } func GetAll() (names []string) { @@ -91,11 +131,7 @@ func streamsHandler(w http.ResponseWriter, r *http.Request) { } // support {input} templates: https://github.com/AlexxIT/go2rtc#module-hass - stream := Get(name) - if stream == nil { - stream = NewTemplate(name, src) - } - stream.SetSource(src) + Patch(name, src) case "POST": // with dst - redirect source to dst @@ -120,3 +156,4 @@ func streamsHandler(w http.ResponseWriter, r *http.Request) { var log zerolog.Logger var streams = map[string]*Stream{} +var streamsMu sync.Mutex diff --git a/internal/webrtc/server.go b/internal/webrtc/server.go index a3c1ff8d..c50225ee 100644 --- a/internal/webrtc/server.go +++ b/internal/webrtc/server.go @@ -2,15 +2,17 @@ package webrtc import ( "encoding/json" - "github.com/AlexxIT/go2rtc/internal/streams" - "github.com/AlexxIT/go2rtc/pkg/core" - "github.com/AlexxIT/go2rtc/pkg/webrtc" - pion "github.com/pion/webrtc/v3" "io" "net/http" "strconv" "strings" "time" + + "github.com/AlexxIT/go2rtc/internal/api" + "github.com/AlexxIT/go2rtc/internal/streams" + "github.com/AlexxIT/go2rtc/pkg/core" + "github.com/AlexxIT/go2rtc/pkg/webrtc" + pion "github.com/pion/webrtc/v3" ) const MimeSDP = "application/sdp" @@ -140,7 +142,8 @@ func inputWebRTC(w http.ResponseWriter, r *http.Request) { dst := r.URL.Query().Get("dst") stream := streams.Get(dst) if stream == nil { - stream = streams.New(dst, nil) + http.Error(w, api.StreamNotFound, http.StatusNotFound) + return } // 1. Get offer diff --git a/internal/webrtc/init.go b/internal/webrtc/webrtc.go similarity index 99% rename from internal/webrtc/init.go rename to internal/webrtc/webrtc.go index 6720e984..df92f398 100644 --- a/internal/webrtc/init.go +++ b/internal/webrtc/webrtc.go @@ -2,6 +2,8 @@ package webrtc import ( "errors" + "net" + "github.com/AlexxIT/go2rtc/internal/api" "github.com/AlexxIT/go2rtc/internal/api/ws" "github.com/AlexxIT/go2rtc/internal/app" @@ -10,7 +12,6 @@ import ( "github.com/AlexxIT/go2rtc/pkg/webrtc" pion "github.com/pion/webrtc/v3" "github.com/rs/zerolog" - "net" ) func Init() { @@ -91,7 +92,7 @@ func asyncHandler(tr *ws.Transport, msg *ws.Message) error { query := tr.Request.URL.Query() if name := query.Get("src"); name != "" { - stream = streams.Get(name) + stream = streams.GetOrPatch(query) mode = core.ModePassiveConsumer log.Debug().Str("src", name).Msg("[webrtc] new consumer") } else if name = query.Get("dst"); name != "" {