diff --git a/internal/hass/api.go b/internal/hass/api.go index e3de23b3..9f110fc8 100644 --- a/internal/hass/api.go +++ b/internal/hass/api.go @@ -30,10 +30,10 @@ func apiStream(w http.ResponseWriter, r *http.Request) { // 1. link to go2rtc stream: rtsp://...:8554/{stream_name} // 2. static link to Hass camera // 3. dynamic link to Hass camera - if streams.Patch(v.Name, v.Channels.First.Url) != nil { + if _, err := streams.Patch(v.Name, v.Channels.First.Url); err == nil { apiOK(w, r) } else { - http.Error(w, "", http.StatusBadRequest) + http.Error(w, err.Error(), http.StatusBadRequest) } // /stream/{id}/channel/0/webrtc diff --git a/internal/hls/ws.go b/internal/hls/ws.go index 608f515f..00eedfe2 100644 --- a/internal/hls/ws.go +++ b/internal/hls/ws.go @@ -11,7 +11,7 @@ import ( ) func handlerWSHLS(tr *ws.Transport, msg *ws.Message) error { - stream := streams.GetOrPatch(tr.Request.URL.Query()) + stream, _ := streams.GetOrPatch(tr.Request.URL.Query()) if stream == nil { return errors.New(api.StreamNotFound) } diff --git a/internal/mjpeg/init.go b/internal/mjpeg/init.go index 27c557e4..2fa9fa32 100644 --- a/internal/mjpeg/init.go +++ b/internal/mjpeg/init.go @@ -36,7 +36,7 @@ func Init() { var log zerolog.Logger func handlerKeyframe(w http.ResponseWriter, r *http.Request) { - stream := streams.GetOrPatch(r.URL.Query()) + stream, _ := streams.GetOrPatch(r.URL.Query()) if stream == nil { http.Error(w, api.StreamNotFound, http.StatusNotFound) return @@ -145,7 +145,7 @@ func inputMjpeg(w http.ResponseWriter, r *http.Request) { } func handlerWS(tr *ws.Transport, _ *ws.Message) error { - stream := streams.GetOrPatch(tr.Request.URL.Query()) + stream, _ := streams.GetOrPatch(tr.Request.URL.Query()) if stream == nil { return errors.New(api.StreamNotFound) } diff --git a/internal/mp4/mp4.go b/internal/mp4/mp4.go index cca5220c..d0a6d971 100644 --- a/internal/mp4/mp4.go +++ b/internal/mp4/mp4.go @@ -91,7 +91,7 @@ func handlerMP4(w http.ResponseWriter, r *http.Request) { return } - stream := streams.GetOrPatch(query) + stream, _ := streams.GetOrPatch(query) if stream == nil { http.Error(w, api.StreamNotFound, http.StatusNotFound) return diff --git a/internal/mp4/ws.go b/internal/mp4/ws.go index c880fb58..c1afac24 100644 --- a/internal/mp4/ws.go +++ b/internal/mp4/ws.go @@ -11,7 +11,7 @@ import ( ) func handlerWSMSE(tr *ws.Transport, msg *ws.Message) error { - stream := streams.GetOrPatch(tr.Request.URL.Query()) + stream, _ := streams.GetOrPatch(tr.Request.URL.Query()) if stream == nil { return errors.New(api.StreamNotFound) } @@ -43,7 +43,7 @@ func handlerWSMSE(tr *ws.Transport, msg *ws.Message) error { } func handlerWSMP4(tr *ws.Transport, msg *ws.Message) error { - stream := streams.GetOrPatch(tr.Request.URL.Query()) + stream, _ := streams.GetOrPatch(tr.Request.URL.Query()) if stream == nil { return errors.New(api.StreamNotFound) } diff --git a/internal/streams/api.go b/internal/streams/api.go index 28f09708..c2b93b91 100644 --- a/internal/streams/api.go +++ b/internal/streams/api.go @@ -52,8 +52,8 @@ func apiStreams(w http.ResponseWriter, r *http.Request) { name = src } - if New(name, query["src"]...) == nil { - http.Error(w, "", http.StatusBadRequest) + if _, err := New(name, query["src"]...); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) return } @@ -69,8 +69,8 @@ func apiStreams(w http.ResponseWriter, r *http.Request) { } // support {input} templates: https://github.com/AlexxIT/go2rtc#module-hass - if Patch(name, src) == nil { - http.Error(w, "", http.StatusBadRequest) + if _, err := Patch(name, src); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) } case "POST": diff --git a/internal/streams/streams.go b/internal/streams/streams.go index 633ad2d1..2bc65486 100644 --- a/internal/streams/streams.go +++ b/internal/streams/streams.go @@ -1,6 +1,7 @@ package streams import ( + "errors" "net/url" "sync" "time" @@ -48,10 +49,14 @@ func Init() { }) } -func New(name string, sources ...string) *Stream { +func New(name string, sources ...string) (*Stream, error) { for _, source := range sources { - if Validate(source) != nil { - return nil + if !HasProducer(source) { + return nil, errors.New("streams: source not supported") + } + + if err := Validate(source); err != nil { + return nil, err } } @@ -61,10 +66,10 @@ func New(name string, sources ...string) *Stream { streams[name] = stream streamsMu.Unlock() - return stream + return stream, nil } -func Patch(name string, source string) *Stream { +func Patch(name string, source string) (*Stream, error) { streamsMu.Lock() defer streamsMu.Unlock() @@ -76,7 +81,7 @@ func Patch(name string, source string) *Stream { // link (alias) streams[name] to streams[rtspName] streams[name] = stream } - return stream + return stream, nil } } @@ -85,40 +90,40 @@ func Patch(name string, source string) *Stream { // link (alias) streams[name] to streams[source] streams[name] = stream } - return stream + return stream, nil } // check if src has supported scheme if !HasProducer(source) { - return nil + return nil, errors.New("streams: source not supported") } - if Validate(source) != nil { - return nil + if err := Validate(source); err != nil { + return nil, err } // check an existing stream with this name if stream, ok := streams[name]; ok { stream.SetSource(source) - return stream + return stream, nil } // create new stream with this name stream := NewStream(source) streams[name] = stream - return stream + return stream, nil } -func GetOrPatch(query url.Values) *Stream { +func GetOrPatch(query url.Values) (*Stream, error) { // check if src param exists source := query.Get("src") if source == "" { - return nil + return nil, errors.New("streams: source empty") } // check if src is stream name if stream := Get(source); stream != nil { - return stream + return stream, nil } // check if name param provided diff --git a/internal/webrtc/webrtc.go b/internal/webrtc/webrtc.go index 11e9db89..eca1e12b 100644 --- a/internal/webrtc/webrtc.go +++ b/internal/webrtc/webrtc.go @@ -95,7 +95,7 @@ func asyncHandler(tr *ws.Transport, msg *ws.Message) (err error) { query := tr.Request.URL.Query() if name := query.Get("src"); name != "" { - stream = streams.GetOrPatch(query) + stream, _ = streams.GetOrPatch(query) mode = core.ModePassiveConsumer log.Debug().Str("src", name).Msg("[webrtc] new consumer") } else if name = query.Get("dst"); name != "" {