From a15deedf0deeb6355a5cf334aff5916afb907d5d Mon Sep 17 00:00:00 2001 From: Alex X Date: Fri, 7 Mar 2025 21:44:23 +0300 Subject: [PATCH] Fix YAML patch in some cases #1626 --- internal/app/config.go | 4 +- internal/homekit/api.go | 4 +- internal/homekit/server.go | 2 +- internal/streams/api.go | 4 +- pkg/yaml/yaml.go | 236 ++++++++++++++++++++----------------- pkg/yaml/yaml_test.go | 235 ++++++++++++++++-------------------- 6 files changed, 237 insertions(+), 248 deletions(-) diff --git a/internal/app/config.go b/internal/app/config.go index 8ae6d460..9d4480b7 100644 --- a/internal/app/config.go +++ b/internal/app/config.go @@ -18,7 +18,7 @@ func LoadConfig(v any) { } } -func PatchConfig(key string, value any, path ...string) error { +func PatchConfig(path []string, value any) error { if ConfigPath == "" { return errors.New("config file disabled") } @@ -26,7 +26,7 @@ func PatchConfig(key string, value any, path ...string) error { // empty config is OK b, _ := os.ReadFile(ConfigPath) - b, err := yaml.Patch(b, key, value, path...) + b, err := yaml.Patch(b, path, value) if err != nil { return err } diff --git a/internal/homekit/api.go b/internal/homekit/api.go index 0ee4d057..9f76c2d6 100644 --- a/internal/homekit/api.go +++ b/internal/homekit/api.go @@ -103,7 +103,7 @@ func apiPair(id, url string) error { streams.New(id, conn.URL()) - return app.PatchConfig(id, conn.URL(), "streams") + return app.PatchConfig([]string{"streams", id}, conn.URL()) } func apiUnpair(id string) error { @@ -123,7 +123,7 @@ func apiUnpair(id string) error { streams.Delete(id) - return app.PatchConfig(id, nil, "streams") + return app.PatchConfig([]string{"streams", id}, nil) } func findHomeKitURLs() map[string]*url.URL { diff --git a/internal/homekit/server.go b/internal/homekit/server.go index cb114fea..363a7047 100644 --- a/internal/homekit/server.go +++ b/internal/homekit/server.go @@ -222,7 +222,7 @@ func (s *server) DelPair(conn net.Conn, id string) { } func (s *server) PatchConfig() { - if err := app.PatchConfig("pairings", s.pairings, "homekit", s.stream); err != nil { + if err := app.PatchConfig([]string{"homekit", s.stream, "pairings"}, s.pairings); err != nil { log.Error().Err(err).Msgf( "[homekit] can't save %s pairings=%v", s.stream, s.pairings, ) diff --git a/internal/streams/api.go b/internal/streams/api.go index d6042974..061e61c2 100644 --- a/internal/streams/api.go +++ b/internal/streams/api.go @@ -53,7 +53,7 @@ func apiStreams(w http.ResponseWriter, r *http.Request) { return } - if err := app.PatchConfig(name, query["src"], "streams"); err != nil { + if err := app.PatchConfig([]string{"streams", name}, query["src"]); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) } @@ -96,7 +96,7 @@ func apiStreams(w http.ResponseWriter, r *http.Request) { case "DELETE": delete(streams, src) - if err := app.PatchConfig(src, nil, "streams"); err != nil { + if err := app.PatchConfig([]string{"streams", src}, nil); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) } } diff --git a/pkg/yaml/yaml.go b/pkg/yaml/yaml.go index 70b3baf0..4672cb4c 100644 --- a/pkg/yaml/yaml.go +++ b/pkg/yaml/yaml.go @@ -23,149 +23,157 @@ func Encode(v any, indent int) ([]byte, error) { return b.Bytes(), nil } -// Patch - change key/value pair in YAML file without break formatting -func Patch(src []byte, key string, value any, path ...string) ([]byte, error) { - nodeParent, err := FindParent(src, path...) +func Patch(in []byte, path []string, value any) ([]byte, error) { + out, err := patch(in, path, value) if err != nil { return nil, err } - var dst []byte - - if nodeParent != nil { - dst, err = AddOrReplace(src, key, value, nodeParent) - } else { - dst, err = AddToEnd(src, key, value, path...) - } - - if err = yaml.Unmarshal(dst, map[string]any{}); err != nil { + // validate + if err = yaml.Unmarshal(out, map[string]any{}); err != nil { return nil, err } - return dst, nil + return out, nil } -// FindParent - return YAML Node from path of keys (tree) -func FindParent(src []byte, path ...string) (*yaml.Node, error) { - if len(src) == 0 { - return nil, nil - } - +func patch(in []byte, path []string, value any) ([]byte, error) { var root yaml.Node - if err := yaml.Unmarshal(src, &root); err != nil { + if err := yaml.Unmarshal(in, &root); err != nil { + // invalid yaml return nil, err } - if root.Content == nil { - return nil, nil + // empty in + if len(root.Content) != 1 { + return addToEnd(in, path, value) } - parent := root.Content[0] // yaml.DocumentNode - for _, name := range path { - if parent == nil { - break - } - _, parent = FindChild(parent, name) + // yaml is not dict + if root.Content[0].Kind != yaml.MappingNode { + return nil, errors.New("yaml: can't patch") } - return parent, nil + + // dict items list + nodes := root.Content[0].Content + + n := len(path) - 1 + + // parent node key/value + pKey, pVal := findNode(nodes, path[:n]) + if pKey == nil { + // no parent node + return addToEnd(in, path, value) + } + + var paste []byte + + if value != nil { + // nil value means delete key + var err error + v := map[string]any{path[n]: value} + if paste, err = Encode(v, 2); err != nil { + return nil, err + } + } + + iKey, _ := findNode(pVal.Content, path[n:]) + if iKey != nil { + // key item not nil (replace value) + paste = addIndent(paste, iKey.Column-1) + + i0, i1 := nodeBounds(in, iKey) + return join(in[:i0], paste, in[i1:]), nil + } + + if pVal.Content != nil { + // parent value not nil (use first child indent) + paste = addIndent(paste, pVal.Column-1) + } else { + // parent value is nil (use parent indent + 2) + paste = addIndent(paste, pKey.Column+1) + } + + _, i1 := nodeBounds(in, pKey) + return join(in[:i1], paste, in[i1:]), nil } -// FindChild - search and return YAML key/value pair for current Node -func FindChild(node *yaml.Node, name string) (key, value *yaml.Node) { - for i, child := range node.Content { - if child.Value != name { - continue +func findNode(nodes []*yaml.Node, keys []string) (key, value *yaml.Node) { + for i, name := range keys { + for j := 0; j < len(nodes); j += 2 { + if nodes[j].Value == name { + if i < len(keys)-1 { + nodes = nodes[j+1].Content + break + } + return nodes[j], nodes[j+1] + } } - return child, node.Content[i+1] } - return nil, nil } -func FirstChild(node *yaml.Node) *yaml.Node { - if node.Content == nil { - return node - } - return node.Content[0] -} +func nodeBounds(in []byte, node *yaml.Node) (offset0, offset1 int) { + // start from next line after node + offset0 = lineOffset(in, node.Line) + offset1 = lineOffset(in, node.Line+1) -func LastChild(node *yaml.Node) *yaml.Node { - if node.Content == nil { - return node - } - return LastChild(node.Content[len(node.Content)-1]) -} - -func AddOrReplace(src []byte, key string, value any, nodeParent *yaml.Node) ([]byte, error) { - v := map[string]any{key: value} - put, err := Encode(v, 2) - if err != nil { - return nil, err + if offset1 < 0 { + return offset0, len(in) } - if nodeKey, nodeValue := FindChild(nodeParent, key); nodeKey != nil { - put = AddIndent(put, nodeKey.Column-1) - - i0 := LineOffset(src, nodeKey.Line) - i1 := LineOffset(src, LastChild(nodeValue).Line+1) - - if i1 < 0 { // no new line on the end of file - if value != nil { - return append(src[:i0], put...), nil + for i := offset1; i < len(in); { + indent, length := parseLine(in[i:]) + if indent+1 != length { + if node.Column < indent+1 { + offset1 = i + length + } else { + break } - return src[:i0], nil } - - dst := make([]byte, 0, len(src)+len(put)) - dst = append(dst, src[:i0]...) - if value != nil { - dst = append(dst, put...) - } - return append(dst, src[i1:]...), nil + i += length } - put = AddIndent(put, FirstChild(nodeParent).Column-1) - - i := LineOffset(src, LastChild(nodeParent).Line+1) - - if i < 0 { // no new line on the end of file - src = append(src, '\n') - if value != nil { - src = append(src, put...) - } - return src, nil - } - - dst := make([]byte, 0, len(src)+len(put)) - dst = append(dst, src[:i]...) - if value != nil { - dst = append(dst, put...) - } - return append(dst, src[i:]...), nil + return } -func AddToEnd(src []byte, key string, value any, path ...string) ([]byte, error) { - if len(path) > 1 || value == nil { - return nil, errors.New("config: path not exist") +func addToEnd(in []byte, path []string, value any) ([]byte, error) { + if len(path) != 2 || value == nil { + return nil, errors.New("yaml: path not exist") } v := map[string]map[string]any{ - path[0]: {key: value}, + path[0]: {path[1]: value}, } - put, err := Encode(v, 2) + paste, err := Encode(v, 2) if err != nil { return nil, err } - dst := make([]byte, 0, len(src)+len(put)+10) - dst = append(dst, src...) - if l := len(src); l > 0 && src[l-1] != '\n' { - dst = append(dst, '\n') - } - return append(dst, put...), nil + return join(in, paste), nil } -func AddPrefix(src, pre []byte) (dst []byte) { +func join(items ...[]byte) []byte { + n := len(items) - 1 + for _, b := range items { + n += len(b) + } + + buf := make([]byte, 0, n) + for _, b := range items { + if len(b) == 0 { + continue + } + if n = len(buf); n > 0 && buf[n-1] != '\n' { + buf = append(buf, '\n') + } + buf = append(buf, b...) + } + + return buf +} + +func addPrefix(src, pre []byte) (dst []byte) { for len(src) > 0 { dst = append(dst, pre...) i := bytes.IndexByte(src, '\n') + 1 @@ -180,21 +188,21 @@ func AddPrefix(src, pre []byte) (dst []byte) { return } -func AddIndent(src []byte, indent int) (dst []byte) { +func addIndent(in []byte, indent int) (dst []byte) { pre := make([]byte, indent) for i := 0; i < indent; i++ { pre[i] = ' ' } - return AddPrefix(src, pre) + return addPrefix(in, pre) } -func LineOffset(b []byte, line int) (offset int) { +func lineOffset(in []byte, line int) (offset int) { for l := 1; ; l++ { if l == line { return offset } - i := bytes.IndexByte(b[offset:], '\n') + 1 + i := bytes.IndexByte(in[offset:], '\n') + 1 if i == 0 { break } @@ -202,3 +210,21 @@ func LineOffset(b []byte, line int) (offset int) { } return -1 } + +func parseLine(b []byte) (indent int, length int) { + prefix := true + for ; length < len(b); length++ { + switch b[length] { + case ' ': + if prefix { + indent++ + } + case '\n': + length++ + return + default: + prefix = false + } + } + return +} diff --git a/pkg/yaml/yaml_test.go b/pkg/yaml/yaml_test.go index 3f4c45bb..264546af 100644 --- a/pkg/yaml/yaml_test.go +++ b/pkg/yaml/yaml_test.go @@ -7,140 +7,103 @@ import ( ) func TestPatch(t *testing.T) { - b := []byte(`# prefix`) - - // 1. Add first - b, err := Patch(b, "camera1", "url1", "streams") - require.Nil(t, err) - - require.Equal(t, `# prefix -streams: - camera1: url1 -`, string(b)) - - // 2. Add second - b, err = Patch(b, "camera2", []string{"url2", "url3"}, "streams") - require.Nil(t, err) - - require.Equal(t, `# prefix -streams: - camera1: url1 - camera2: - - url2 - - url3 -`, string(b)) - - // 3. Replace first - b, err = Patch(b, "camera1", "url4", "streams") - require.Nil(t, err) - - require.Equal(t, `# prefix -streams: - camera1: url4 - camera2: - - url2 - - url3 -`, string(b)) - - // 4. Replace second - b, err = Patch(b, "camera2", "url5", "streams") - require.Nil(t, err) - - require.Equal(t, `# prefix -streams: - camera1: url4 - camera2: url5 -`, string(b)) - - // 5. Delete first - b, err = Patch(b, "camera1", nil, "streams") - require.Nil(t, err) - - require.Equal(t, `# prefix -streams: - camera2: url5 -`, string(b)) -} - -func TestPatchParings(t *testing.T) { - b := []byte(`homekit: - camera1: - pin: 123-45-678 -streams: - camera1: url1 -`) - - // 1. Add new key - pairings := []string{"client1", "client2"} - - b, err := Patch(b, "pairings", pairings, "homekit", "camera1") - require.Nil(t, err) - - require.Equal(t, `homekit: - camera1: - pin: 123-45-678 - pairings: - - client1 - - client2 -streams: - camera1: url1 -`, string(b)) -} - -func TestPatch2(t *testing.T) { - b := []byte(`streams: - camera1: - - url1 - - url2 -`) - - b, err := Patch(b, "camera2", "url3", "streams") - require.Nil(t, err) - - require.Equal(t, `streams: - camera1: - - url1 - - url2 - camera2: url3 -`, string(b)) -} - -func TestNoNewLineEnd1(t *testing.T) { - b := []byte(`streams: - camera1: url4 - camera2: - - url2 - - url3`) - - b, err := Patch(b, "camera2", "url5", "streams") - require.Nil(t, err) - - require.Equal(t, `streams: - camera1: url4 - camera2: url5 -`, string(b)) -} - -func TestNoNewLineEnd2(t *testing.T) { - b := []byte(`streams: - camera1: url1 -homekit: - camera1: - pin: 123-45-678`) - - // 1. Add new key - pairings := []string{"client1", "client2"} - - b, err := Patch(b, "pairings", pairings, "homekit", "camera1") - require.Nil(t, err) - - require.Equal(t, `streams: - camera1: url1 -homekit: - camera1: - pin: 123-45-678 - pairings: - - client1 - - client2 -`, string(b)) + tests := []struct { + name string + src string + path []string + value any + expect string + }{ + { + name: "empty config", + src: "", + path: []string{"streams", "camera1"}, + value: "val1", + expect: "streams:\n camera1: val1\n", + }, + { + name: "empty main key", + src: "#dummy", + path: []string{"streams", "camera1"}, + value: "val1", + expect: "#dummy\nstreams:\n camera1: val1\n", + }, + { + name: "single line value", + src: "streams:\n camera1: url1\n camera2: url2", + path: []string{"streams", "camera1"}, + value: "val1", + expect: "streams:\n camera1: val1\n camera2: url2", + }, + { + name: "next line value", + src: "streams:\n camera1:\n url1\n camera2: url2", + path: []string{"streams", "camera1"}, + value: "val1", + expect: "streams:\n camera1: val1\n camera2: url2", + }, + { + name: "two lines value", + src: "streams:\n camera1: url1\n url2\n camera2: url2", + path: []string{"streams", "camera1"}, + value: "val1", + expect: "streams:\n camera1: val1\n camera2: url2", + }, + { + name: "next two lines value", + src: "streams:\n camera1:\n url1\n url2\n camera2: url2", + path: []string{"streams", "camera1"}, + value: "val1", + expect: "streams:\n camera1: val1\n camera2: url2", + }, + { + name: "add array", + src: "", + path: []string{"streams", "camera1"}, + value: []string{"val1", "val2"}, + expect: "streams:\n camera1:\n - val1\n - val2\n", + }, + { + name: "remove value", + src: "streams:\n camera1: url1\n camera2: url2", + path: []string{"streams", "camera1"}, + value: nil, + expect: "streams:\n camera2: url2", + }, + { + name: "add pairings", + src: "homekit:\n camera1:\nstreams:\n camera1: url1", + path: []string{"homekit", "camera1", "pairings"}, + value: []string{"val1"}, + expect: "homekit:\n camera1:\n pairings:\n - val1\nstreams:\n camera1: url1", + }, + { + name: "remove pairings", + src: "homekit:\n camera1:\n pairings:\n - val1\nstreams:\n camera1: url1", + path: []string{"homekit", "camera1", "pairings"}, + value: nil, + expect: "homekit:\n camera1:\nstreams:\n camera1: url1", + }, + { + name: "no new line", + src: "streams:\n camera1: url1", + path: []string{"streams", "camera1"}, + value: "val1", + expect: "streams:\n camera1: val1\n", + }, + { + name: "no new line", + src: "streams:\n camera1: url1\nhomekit:\n camera1:\n name: dummy", + path: []string{"homekit", "camera1", "pairings"}, + value: []string{"val1"}, + expect: "streams:\n camera1: url1\nhomekit:\n camera1:\n name: dummy\n pairings:\n - val1\n", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b, err := Patch([]byte(tt.src), tt.path, tt.value) + require.NoError(t, err) + require.Equal(t, tt.expect, string(b)) + }) + } }