diff --git a/internal/streams/api.go b/internal/streams/api.go index 53879252..d6142eb0 100644 --- a/internal/streams/api.go +++ b/internal/streams/api.go @@ -130,21 +130,14 @@ func apiStreamsDOT(w http.ResponseWriter, r *http.Request) { } func apiPreload(w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - src := query.Get("src") - // GET - return all preloads if r.Method == "GET" { api.ResponseJSON(w, GetPreloads()) return } - // check if stream exists - stream := Get(src) - if stream == nil { - http.Error(w, "", http.StatusNotFound) - return - } + query := r.URL.Query() + src := query.Get("src") switch r.Method { case "PUT": @@ -159,7 +152,7 @@ func apiPreload(w http.ResponseWriter, r *http.Request) { rawQuery := query.Encode() - if err := AddPreload(stream, rawQuery); err != nil { + if err := AddPreload(src, rawQuery); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -169,7 +162,7 @@ func apiPreload(w http.ResponseWriter, r *http.Request) { } case "DELETE": - if err := DelPreload(stream); err != nil { + if err := DelPreload(src); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } diff --git a/internal/streams/preload.go b/internal/streams/preload.go index 313c0c73..447b5ac3 100644 --- a/internal/streams/preload.go +++ b/internal/streams/preload.go @@ -1,28 +1,24 @@ package streams import ( - "errors" + "fmt" + "maps" "net/url" "sync" "github.com/AlexxIT/go2rtc/pkg/probe" ) -type preload struct { - cons *probe.Probe - query string +type Preload struct { + stream *Stream // Don't output the stream to JSON to not worry about its secrets. + Cons *probe.Probe `json:"consumer"` + Query string `json:"query"` } -var preloads = map[*Stream]*preload{} +var preloads = map[string]*Preload{} var preloadsMu sync.Mutex -func Preload(stream *Stream, rawQuery string) { - if err := AddPreload(stream, rawQuery); err != nil { - log.Error().Err(err).Caller().Send() - } -} - -func AddPreload(stream *Stream, rawQuery string) error { +func AddPreload(name, rawQuery string) error { if rawQuery == "" { rawQuery = "video&audio" } @@ -35,51 +31,39 @@ func AddPreload(stream *Stream, rawQuery string) error { preloadsMu.Lock() defer preloadsMu.Unlock() - if p := preloads[stream]; p != nil { - stream.RemoveConsumer(p.cons) + if p := preloads[name]; p != nil { + p.stream.RemoveConsumer(p.Cons) } + stream := Get(name) + if stream == nil { + return fmt.Errorf("streams: stream not found: %s", name) + } cons := probe.Create("preload", query) if err = stream.AddConsumer(cons); err != nil { return err } - preloads[stream] = &preload{cons: cons, query: rawQuery} + preloads[name] = &Preload{stream: stream, Cons: cons, Query: rawQuery} return nil } -func DelPreload(stream *Stream) error { +func DelPreload(name string) error { preloadsMu.Lock() defer preloadsMu.Unlock() - if p := preloads[stream]; p != nil { - stream.RemoveConsumer(p.cons) - delete(preloads, stream) + if p := preloads[name]; p != nil { + p.stream.RemoveConsumer(p.Cons) + delete(preloads, name) return nil } - return errors.New("streams: preload not found") + return fmt.Errorf("streams: preload not found: %s", name) } -func GetPreloads() map[string]string { - streamsMu.Lock() - defer streamsMu.Unlock() - +func GetPreloads() map[string]*Preload { preloadsMu.Lock() defer preloadsMu.Unlock() - - // build reverse lookup: stream -> name - streamNames := make(map[*Stream]string, len(streams)) - for name, stream := range streams { - streamNames[stream] = name - } - - result := make(map[string]string, len(preloads)) - for stream, p := range preloads { - if name, ok := streamNames[stream]; ok { - result[name] = p.query - } - } - return result + return maps.Clone(preloads) } diff --git a/internal/streams/streams.go b/internal/streams/streams.go index 433f9d36..f3b8df03 100644 --- a/internal/streams/streams.go +++ b/internal/streams/streams.go @@ -43,8 +43,8 @@ func Init() { } } for name, rawQuery := range cfg.Preload { - if stream := Get(name); stream != nil { - Preload(stream, rawQuery) + if err := AddPreload(name, rawQuery); err != nil { + log.Error().Err(err).Caller().Send() } } })