From fa8d4e4807c044718b0d7e6436da2c4ea7c4f031 Mon Sep 17 00:00:00 2001 From: Alexey Khit Date: Thu, 29 Jun 2023 22:52:59 +0300 Subject: [PATCH] Remove on the fly stream creation for security reason --- internal/hls/hls.go | 2 +- internal/mjpeg/init.go | 6 +++--- internal/mp4/mp4.go | 4 ++-- internal/mp4/ws.go | 4 ++-- internal/streams/init.go | 14 -------------- internal/webrtc/init.go | 2 +- 6 files changed, 9 insertions(+), 23 deletions(-) diff --git a/internal/hls/hls.go b/internal/hls/hls.go index 06e841cd..a92b46a7 100644 --- a/internal/hls/hls.go +++ b/internal/hls/hls.go @@ -63,7 +63,7 @@ func handlerStream(w http.ResponseWriter, r *http.Request) { } src := r.URL.Query().Get("src") - stream := streams.GetOrNew(src) + stream := streams.Get(src) if stream == nil { http.Error(w, api.StreamNotFound, http.StatusNotFound) return diff --git a/internal/mjpeg/init.go b/internal/mjpeg/init.go index b804cffb..975e2a59 100644 --- a/internal/mjpeg/init.go +++ b/internal/mjpeg/init.go @@ -26,7 +26,7 @@ func Init() { func handlerKeyframe(w http.ResponseWriter, r *http.Request) { src := r.URL.Query().Get("src") - stream := streams.GetOrNew(src) + stream := streams.Get(src) if stream == nil { http.Error(w, api.StreamNotFound, http.StatusNotFound) return @@ -91,7 +91,7 @@ func handlerStream(w http.ResponseWriter, r *http.Request) { func outputMjpeg(w http.ResponseWriter, r *http.Request) { src := r.URL.Query().Get("src") - stream := streams.GetOrNew(src) + stream := streams.Get(src) if stream == nil { http.Error(w, api.StreamNotFound, http.StatusNotFound) return @@ -159,7 +159,7 @@ func inputMjpeg(w http.ResponseWriter, r *http.Request) { func handlerWS(tr *ws.Transport, _ *ws.Message) error { src := tr.Request.URL.Query().Get("src") - stream := streams.GetOrNew(src) + stream := streams.Get(src) if stream == nil { return errors.New(api.StreamNotFound) } diff --git a/internal/mp4/mp4.go b/internal/mp4/mp4.go index 70dccd42..8b006ad4 100644 --- a/internal/mp4/mp4.go +++ b/internal/mp4/mp4.go @@ -40,7 +40,7 @@ func handlerKeyframe(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() src := query.Get("src") - stream := streams.GetOrNew(src) + stream := streams.Get(src) if stream == nil { http.Error(w, api.StreamNotFound, http.StatusNotFound) return @@ -101,7 +101,7 @@ func handlerMP4(w http.ResponseWriter, r *http.Request) { } src := query.Get("src") - stream := streams.GetOrNew(src) + stream := streams.Get(src) if stream == nil { http.Error(w, api.StreamNotFound, http.StatusNotFound) return diff --git a/internal/mp4/ws.go b/internal/mp4/ws.go index 1ef3f02f..08d7da02 100644 --- a/internal/mp4/ws.go +++ b/internal/mp4/ws.go @@ -13,7 +13,7 @@ import ( func handlerWSMSE(tr *ws.Transport, msg *ws.Message) error { src := tr.Request.URL.Query().Get("src") - stream := streams.GetOrNew(src) + stream := streams.Get(src) if stream == nil { return errors.New(api.StreamNotFound) } @@ -60,7 +60,7 @@ func handlerWSMSE(tr *ws.Transport, msg *ws.Message) error { func handlerWSMP4(tr *ws.Transport, msg *ws.Message) error { src := tr.Request.URL.Query().Get("src") - stream := streams.GetOrNew(src) + stream := streams.Get(src) if stream == nil { return errors.New(api.StreamNotFound) } diff --git a/internal/streams/init.go b/internal/streams/init.go index 0cb8822b..8cf31ec8 100644 --- a/internal/streams/init.go +++ b/internal/streams/init.go @@ -53,20 +53,6 @@ func NewTemplate(name string, source any) *Stream { return New(name, "{input}") } -func GetOrNew(src string) *Stream { - if stream, ok := streams[src]; ok { - return stream - } - - if !HasProducer(src) { - return nil - } - - log.Info().Str("url", src).Msg("[streams] create new stream") - - return New(src, src) -} - func GetAll() (names []string) { for name := range streams { names = append(names, name) diff --git a/internal/webrtc/init.go b/internal/webrtc/init.go index 07f219b7..6720e984 100644 --- a/internal/webrtc/init.go +++ b/internal/webrtc/init.go @@ -91,7 +91,7 @@ func asyncHandler(tr *ws.Transport, msg *ws.Message) error { query := tr.Request.URL.Query() if name := query.Get("src"); name != "" { - stream = streams.GetOrNew(name) + stream = streams.Get(name) mode = core.ModePassiveConsumer log.Debug().Str("src", name).Msg("[webrtc] new consumer") } else if name = query.Get("dst"); name != "" {