From 01d25e96dbe8b71b1ce5835f6f367d1de8a52300 Mon Sep 17 00:00:00 2001 From: Evan Phoenix Date: Thu, 11 Jan 2024 13:50:37 -0800 Subject: [PATCH] Handle from="" more properly. Fixes #192 --- v5/merge.go | 8 ++-- v5/merge_test.go | 1 + v5/patch.go | 105 +++++++++++++++++++++++++++++++---------------- v5/patch_test.go | 27 +++++++++--- 4 files changed, 95 insertions(+), 46 deletions(-) diff --git a/v5/merge.go b/v5/merge.go index a7c4573..cec19a6 100644 --- a/v5/merge.go +++ b/v5/merge.go @@ -88,14 +88,14 @@ func pruneDocNulls(doc *partialDoc) *partialDoc { func pruneAryNulls(ary *partialArray) *partialArray { newAry := []*lazyNode{} - for _, v := range *ary { + for _, v := range ary.nodes { if v != nil { pruneNulls(v) } newAry = append(newAry, v) } - *ary = newAry + ary.nodes = newAry return ary } @@ -151,7 +151,7 @@ func doMergePatch(docData, patchData []byte, mergeMerge bool) ([]byte, error) { } } else { patchAry := &partialArray{} - patchErr = json.Unmarshal(patchData, patchAry) + patchErr = json.Unmarshal(patchData, &patchAry.nodes) if patchErr != nil { return nil, errBadJSONPatch @@ -159,7 +159,7 @@ func doMergePatch(docData, patchData []byte, mergeMerge bool) ([]byte, error) { pruneAryNulls(patchAry) - out, patchErr := json.Marshal(patchAry) + out, patchErr := json.Marshal(patchAry.nodes) if patchErr != nil { return nil, errBadJSONPatch diff --git a/v5/merge_test.go b/v5/merge_test.go index 72e3ea8..e974b7b 100644 --- a/v5/merge_test.go +++ b/v5/merge_test.go @@ -70,6 +70,7 @@ func TestMergePatchNilArray(t *testing.T) { } for _, c := range cases { + t.Log(c.original) act := mergePatch(c.original, c.patch) if !compareJSON(c.res, act) { diff --git a/v5/patch.go b/v5/patch.go index 73ff2c5..3b113e7 100644 --- a/v5/patch.go +++ b/v5/patch.go @@ -45,7 +45,7 @@ var ( type lazyNode struct { raw *json.RawMessage doc *partialDoc - ary partialArray + ary *partialArray which int } @@ -56,11 +56,15 @@ type Operation map[string]*json.RawMessage type Patch []Operation type partialDoc struct { + self *lazyNode keys []string obj map[string]*lazyNode } -type partialArray []*lazyNode +type partialArray struct { + self *lazyNode + nodes []*lazyNode +} type container interface { get(key string, options *ApplyOptions) (*lazyNode, error) @@ -114,7 +118,7 @@ func (n *lazyNode) MarshalJSON() ([]byte, error) { case eDoc: return json.Marshal(n.doc) case eAry: - return json.Marshal(n.ary) + return json.Marshal(n.ary.nodes) default: return nil, ErrUnknownType } @@ -199,6 +203,14 @@ func (n *partialDoc) UnmarshalJSON(data []byte) error { return nil } +func (n *partialArray) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &n.nodes) +} + +func (n *partialArray) MarshalJSON() ([]byte, error) { + return json.Marshal(n.nodes) +} + func skipValue(d *json.Decoder) error { t, err := d.Token() if err != nil { @@ -264,7 +276,7 @@ func (n *lazyNode) intoDoc() (*partialDoc, error) { func (n *lazyNode) intoAry() (*partialArray, error) { if n.which == eAry { - return &n.ary, nil + return n.ary, nil } if n.raw == nil { @@ -278,7 +290,7 @@ func (n *lazyNode) intoAry() (*partialArray, error) { } n.which = eAry - return &n.ary, nil + return n.ary, nil } func (n *lazyNode) compact() []byte { @@ -380,12 +392,12 @@ func (n *lazyNode) equal(o *lazyNode) bool { return false } - if len(n.ary) != len(o.ary) { + if len(n.ary.nodes) != len(o.ary.nodes) { return false } - for idx, val := range n.ary { - if !val.equal(o.ary[idx]) { + for idx, val := range n.ary.nodes { + if !val.equal(o.ary.nodes[idx]) { return false } } @@ -497,6 +509,9 @@ func findObject(pd *container, path string, options *ApplyOptions) (container, s split := strings.Split(path, "/") if len(split) < 2 { + if path == "" { + return doc, "" + } return nil, "" } @@ -552,6 +567,9 @@ func (d *partialDoc) add(key string, val *lazyNode, options *ApplyOptions) error } func (d *partialDoc) get(key string, options *ApplyOptions) (*lazyNode, error) { + if key == "" { + return d.self, nil + } v, ok := d.obj[key] if !ok { return v, errors.Wrapf(ErrMissing, "unable to get nonexistent key: %s", key) @@ -591,19 +609,19 @@ func (d *partialArray) set(key string, val *lazyNode, options *ApplyOptions) err if !options.SupportNegativeIndices { return errors.Wrapf(ErrInvalidIndex, "Unable to access invalid index: %d", idx) } - if idx < -len(*d) { + if idx < -len(d.nodes) { return errors.Wrapf(ErrInvalidIndex, "Unable to access invalid index: %d", idx) } - idx += len(*d) + idx += len(d.nodes) } - (*d)[idx] = val + d.nodes[idx] = val return nil } func (d *partialArray) add(key string, val *lazyNode, options *ApplyOptions) error { if key == "-" { - *d = append(*d, val) + d.nodes = append(d.nodes, val) return nil } @@ -612,11 +630,11 @@ func (d *partialArray) add(key string, val *lazyNode, options *ApplyOptions) err return errors.Wrapf(err, "value was not a proper array index: '%s'", key) } - sz := len(*d) + 1 + sz := len(d.nodes) + 1 ary := make([]*lazyNode, sz) - cur := *d + cur := d if idx >= len(ary) { return errors.Wrapf(ErrInvalidIndex, "Unable to access invalid index: %d", idx) @@ -632,15 +650,19 @@ func (d *partialArray) add(key string, val *lazyNode, options *ApplyOptions) err idx += len(ary) } - copy(ary[0:idx], cur[0:idx]) + copy(ary[0:idx], cur.nodes[0:idx]) ary[idx] = val - copy(ary[idx+1:], cur[idx:]) + copy(ary[idx+1:], cur.nodes[idx:]) - *d = ary + d.nodes = ary return nil } func (d *partialArray) get(key string, options *ApplyOptions) (*lazyNode, error) { + if key == "" { + return d.self, nil + } + idx, err := strconv.Atoi(key) if err != nil { @@ -651,17 +673,17 @@ func (d *partialArray) get(key string, options *ApplyOptions) (*lazyNode, error) if !options.SupportNegativeIndices { return nil, errors.Wrapf(ErrInvalidIndex, "Unable to access invalid index: %d", idx) } - if idx < -len(*d) { + if idx < -len(d.nodes) { return nil, errors.Wrapf(ErrInvalidIndex, "Unable to access invalid index: %d", idx) } - idx += len(*d) + idx += len(d.nodes) } - if idx >= len(*d) { + if idx >= len(d.nodes) { return nil, errors.Wrapf(ErrInvalidIndex, "Unable to access invalid index: %d", idx) } - return (*d)[idx], nil + return d.nodes[idx], nil } func (d *partialArray) remove(key string, options *ApplyOptions) error { @@ -670,9 +692,9 @@ func (d *partialArray) remove(key string, options *ApplyOptions) error { return err } - cur := *d + cur := d - if idx >= len(cur) { + if idx >= len(cur.nodes) { if options.AllowMissingPathOnRemove { return nil } @@ -683,21 +705,21 @@ func (d *partialArray) remove(key string, options *ApplyOptions) error { if !options.SupportNegativeIndices { return errors.Wrapf(ErrInvalidIndex, "Unable to access invalid index: %d", idx) } - if idx < -len(cur) { + if idx < -len(cur.nodes) { if options.AllowMissingPathOnRemove { return nil } return errors.Wrapf(ErrInvalidIndex, "Unable to access invalid index: %d", idx) } - idx += len(cur) + idx += len(cur.nodes) } - ary := make([]*lazyNode, len(cur)-1) + ary := make([]*lazyNode, len(cur.nodes)-1) - copy(ary[0:idx], cur[0:idx]) - copy(ary[idx:], cur[idx+1:]) + copy(ary[0:idx], cur.nodes[0:idx]) + copy(ary[idx:], cur.nodes[idx+1:]) - *d = ary + d.nodes = ary return nil } @@ -762,9 +784,9 @@ func ensurePathExists(pd *container, path string, options *ApplyOptions) error { if arrIndex, err = strconv.Atoi(part); err == nil { pa, ok := doc.(*partialArray) - if ok && arrIndex >= len(*pa)+1 { + if ok && arrIndex >= len(pa.nodes)+1 { // Pad the array with null values up to the required index. - for i := len(*pa); i <= arrIndex-1; i++ { + for i := len(pa.nodes); i <= arrIndex-1; i++ { doc.add(strconv.Itoa(i), newLazyNode(newRawMessage(rawJSONNull)), options) } } @@ -899,7 +921,7 @@ func (p Patch) replace(doc *container, op Operation, options *ApplyOptions) erro switch val.which { case eAry: - *doc = &val.ary + *doc = val.ary case eDoc: *doc = val.doc case eRaw: @@ -934,6 +956,10 @@ func (p Patch) move(doc *container, op Operation, options *ApplyOptions) error { return errors.Wrapf(err, "move operation failed to decode from") } + if from == "" { + return errors.Wrapf(ErrInvalid, "unable to move entire document to another path") + } + con, key := findObject(doc, from, options) if con == nil { @@ -983,7 +1009,7 @@ func (p Patch) test(doc *container, op Operation, options *ApplyOptions) error { self.doc = sv self.which = eDoc case *partialArray: - self.ary = *sv + self.ary = sv self.which = eAry } @@ -1030,7 +1056,7 @@ func (p Patch) copy(doc *container, op Operation, accumulatedCopySize *int64, op con, key := findObject(doc, from, options) if con == nil { - return errors.Wrapf(ErrMissing, "copy operation does not apply: doc is missing from path: %s", from) + return errors.Wrapf(ErrMissing, "copy operation does not apply: doc is missing from path: \"%s\"", from) } val, err := con.get(key, options) @@ -1117,11 +1143,18 @@ func (p Patch) ApplyIndentWithOptions(doc []byte, indent string, options *ApplyO return doc, nil } + raw := json.RawMessage(doc) + self := newLazyNode(&raw) + var pd container if doc[0] == '[' { - pd = &partialArray{} + pd = &partialArray{ + self: self, + } } else { - pd = &partialDoc{} + pd = &partialDoc{ + self: self, + } } err := json.Unmarshal(doc, pd) diff --git a/v5/patch_test.go b/v5/patch_test.go index 58bb7f1..5591cbe 100644 --- a/v5/patch_test.go +++ b/v5/patch_test.go @@ -578,6 +578,20 @@ var Cases = []Case{ false, false, }, + { + `{"foo": 1}`, + `[ { "op": "copy", "from": "", "path": "/bar"}]`, + `{"foo": 1, "bar": {"foo": 1}}`, + false, + false, + }, + { + `[{"foo": 1}]`, + `[ { "op": "copy", "from": "", "path": "/1"}]`, + `[{"foo": 1}, [{"foo": 1}]]`, + false, + false, + }, } type BadCase struct { @@ -635,11 +649,6 @@ var BadCases = []BadCase{ `[ { "op": "add", "pathz": "/baz", "value": "qux" } ]`, true, }, - { - `{ "foo": "bar" }`, - `[ { "op": "add", "path": "", "value": "qux" } ]`, - false, - }, { `{ "foo": ["bar","baz"]}`, `[ { "op": "replace", "path": "/foo/2", "value": "bum"}]`, @@ -727,6 +736,11 @@ var BadCases = []BadCase{ `[{"op": "copy", "path": "/qux", "from": "/baz"}]`, false, }, + { + `{ "foo": "bar"}`, + `[{"op": "move", "path": "/qux", "from": ""}]`, + false, + }, } // This is not thread safe, so we cannot run patch tests in parallel. @@ -802,9 +816,10 @@ func TestAllCases(t *testing.T) { } if err == nil && !c.failOnDecode { - _, err = p.Apply([]byte(c.doc)) + out, err := p.Apply([]byte(c.doc)) if err == nil { + t.Log(string(out)) t.Errorf("Patch %q should have failed to apply but it did not", c.patch) }