1 package jsonpatch
2
3 import (
4 "bytes"
5 "encoding/json"
6 "fmt"
7 "reflect"
8 )
9
10 func merge(cur, patch *lazyNode, mergeMerge bool) *lazyNode {
11 curDoc, err := cur.intoDoc()
12
13 if err != nil {
14 pruneNulls(patch)
15 return patch
16 }
17
18 patchDoc, err := patch.intoDoc()
19
20 if err != nil {
21 return patch
22 }
23
24 mergeDocs(curDoc, patchDoc, mergeMerge)
25
26 return cur
27 }
28
29 func mergeDocs(doc, patch *partialDoc, mergeMerge bool) {
30 for k, v := range *patch {
31 if v == nil {
32 if mergeMerge {
33 (*doc)[k] = nil
34 } else {
35 delete(*doc, k)
36 }
37 } else {
38 cur, ok := (*doc)[k]
39
40 if !ok || cur == nil {
41 if !mergeMerge {
42 pruneNulls(v)
43 }
44
45 (*doc)[k] = v
46 } else {
47 (*doc)[k] = merge(cur, v, mergeMerge)
48 }
49 }
50 }
51 }
52
53 func pruneNulls(n *lazyNode) {
54 sub, err := n.intoDoc()
55
56 if err == nil {
57 pruneDocNulls(sub)
58 } else {
59 ary, err := n.intoAry()
60
61 if err == nil {
62 pruneAryNulls(ary)
63 }
64 }
65 }
66
67 func pruneDocNulls(doc *partialDoc) *partialDoc {
68 for k, v := range *doc {
69 if v == nil {
70 delete(*doc, k)
71 } else {
72 pruneNulls(v)
73 }
74 }
75
76 return doc
77 }
78
79 func pruneAryNulls(ary *partialArray) *partialArray {
80 newAry := []*lazyNode{}
81
82 for _, v := range *ary {
83 if v != nil {
84 pruneNulls(v)
85 }
86 newAry = append(newAry, v)
87 }
88
89 *ary = newAry
90
91 return ary
92 }
93
// Sentinel errors reported by the merge-patch functions in this file.
var (
	// ErrBadJSONDoc is returned when an input document cannot be decoded into
	// the shape the operation needs. Examples are malformed JSON, a literal
	// null document, or arrays of differing length passed to CreateMergePatch.
	ErrBadJSONDoc = fmt.Errorf("Invalid JSON Document")

	// ErrBadJSONPatch is returned when a merge patch is malformed, is a literal
	// null, or is neither an object nor an array.
	ErrBadJSONPatch = fmt.Errorf("Invalid JSON Patch")

	// errBadMergeTypes is returned by CreateMergePatch when one input looks
	// like a JSON array and the other does not.
	errBadMergeTypes = fmt.Errorf("Mismatched JSON Documents")
)
97
98
99
100
101 func MergeMergePatches(patch1Data, patch2Data []byte) ([]byte, error) {
102 return doMergePatch(patch1Data, patch2Data, true)
103 }
104
105
106 func MergePatch(docData, patchData []byte) ([]byte, error) {
107 return doMergePatch(docData, patchData, false)
108 }
109
110 func doMergePatch(docData, patchData []byte, mergeMerge bool) ([]byte, error) {
111 doc := &partialDoc{}
112
113 docErr := json.Unmarshal(docData, doc)
114
115 patch := &partialDoc{}
116
117 patchErr := json.Unmarshal(patchData, patch)
118
119 if _, ok := docErr.(*json.SyntaxError); ok {
120 return nil, ErrBadJSONDoc
121 }
122
123 if _, ok := patchErr.(*json.SyntaxError); ok {
124 return nil, ErrBadJSONPatch
125 }
126
127 if docErr == nil && *doc == nil {
128 return nil, ErrBadJSONDoc
129 }
130
131 if patchErr == nil && *patch == nil {
132 return nil, ErrBadJSONPatch
133 }
134
135 if docErr != nil || patchErr != nil {
136
137 if patchErr == nil {
138 if mergeMerge {
139 doc = patch
140 } else {
141 doc = pruneDocNulls(patch)
142 }
143 } else {
144 patchAry := &partialArray{}
145 patchErr = json.Unmarshal(patchData, patchAry)
146
147 if patchErr != nil {
148 return nil, ErrBadJSONPatch
149 }
150
151 pruneAryNulls(patchAry)
152
153 out, patchErr := json.Marshal(patchAry)
154
155 if patchErr != nil {
156 return nil, ErrBadJSONPatch
157 }
158
159 return out, nil
160 }
161 } else {
162 mergeDocs(doc, patch, mergeMerge)
163 }
164
165 return json.Marshal(doc)
166 }
167
168
169
170
171
172
// resemblesJSONArray reports whether input starts with '[' and ends with ']',
// ignoring surrounding whitespace.
//
// It is a cheap shape check, not a validation: the bytes between the
// brackets are never inspected.
func resemblesJSONArray(input []byte) bool {
	trimmed := bytes.TrimSpace(input)
	last := len(trimmed) - 1

	// A single byte can never be both '[' and ']', so at least two are needed.
	return last > 0 && trimmed[0] == '[' && trimmed[last] == ']'
}
181
182
183
184
185
186
187 func CreateMergePatch(originalJSON, modifiedJSON []byte) ([]byte, error) {
188 originalResemblesArray := resemblesJSONArray(originalJSON)
189 modifiedResemblesArray := resemblesJSONArray(modifiedJSON)
190
191
192 if originalResemblesArray && modifiedResemblesArray {
193 return createArrayMergePatch(originalJSON, modifiedJSON)
194 }
195
196
197 if !originalResemblesArray && !modifiedResemblesArray {
198 return createObjectMergePatch(originalJSON, modifiedJSON)
199 }
200
201
202 return nil, errBadMergeTypes
203 }
204
205
206
207 func createObjectMergePatch(originalJSON, modifiedJSON []byte) ([]byte, error) {
208 originalDoc := map[string]interface{}{}
209 modifiedDoc := map[string]interface{}{}
210
211 err := json.Unmarshal(originalJSON, &originalDoc)
212 if err != nil {
213 return nil, ErrBadJSONDoc
214 }
215
216 err = json.Unmarshal(modifiedJSON, &modifiedDoc)
217 if err != nil {
218 return nil, ErrBadJSONDoc
219 }
220
221 dest, err := getDiff(originalDoc, modifiedDoc)
222 if err != nil {
223 return nil, err
224 }
225
226 return json.Marshal(dest)
227 }
228
229
230
231
232
233 func createArrayMergePatch(originalJSON, modifiedJSON []byte) ([]byte, error) {
234 originalDocs := []json.RawMessage{}
235 modifiedDocs := []json.RawMessage{}
236
237 err := json.Unmarshal(originalJSON, &originalDocs)
238 if err != nil {
239 return nil, ErrBadJSONDoc
240 }
241
242 err = json.Unmarshal(modifiedJSON, &modifiedDocs)
243 if err != nil {
244 return nil, ErrBadJSONDoc
245 }
246
247 total := len(originalDocs)
248 if len(modifiedDocs) != total {
249 return nil, ErrBadJSONDoc
250 }
251
252 result := []json.RawMessage{}
253 for i := 0; i < len(originalDocs); i++ {
254 original := originalDocs[i]
255 modified := modifiedDocs[i]
256
257 patch, err := createObjectMergePatch(original, modified)
258 if err != nil {
259 return nil, err
260 }
261
262 result = append(result, json.RawMessage(patch))
263 }
264
265 return json.Marshal(result)
266 }
267
268
269
270 func matchesArray(a, b []interface{}) bool {
271 if len(a) != len(b) {
272 return false
273 }
274 if (a == nil && b != nil) || (a != nil && b == nil) {
275 return false
276 }
277 for i := range a {
278 if !matchesValue(a[i], b[i]) {
279 return false
280 }
281 }
282 return true
283 }
284
285
286
287
288 func matchesValue(av, bv interface{}) bool {
289 if reflect.TypeOf(av) != reflect.TypeOf(bv) {
290 return false
291 }
292 switch at := av.(type) {
293 case string:
294 bt := bv.(string)
295 if bt == at {
296 return true
297 }
298 case float64:
299 bt := bv.(float64)
300 if bt == at {
301 return true
302 }
303 case bool:
304 bt := bv.(bool)
305 if bt == at {
306 return true
307 }
308 case nil:
309
310 return true
311 case map[string]interface{}:
312 bt := bv.(map[string]interface{})
313 if len(bt) != len(at) {
314 return false
315 }
316 for key := range bt {
317 av, aOK := at[key]
318 bv, bOK := bt[key]
319 if aOK != bOK {
320 return false
321 }
322 if !matchesValue(av, bv) {
323 return false
324 }
325 }
326 return true
327 case []interface{}:
328 bt := bv.([]interface{})
329 return matchesArray(at, bt)
330 }
331 return false
332 }
333
334
335 func getDiff(a, b map[string]interface{}) (map[string]interface{}, error) {
336 into := map[string]interface{}{}
337 for key, bv := range b {
338 av, ok := a[key]
339
340 if !ok {
341 into[key] = bv
342 continue
343 }
344
345 if reflect.TypeOf(av) != reflect.TypeOf(bv) {
346 into[key] = bv
347 continue
348 }
349
350 switch at := av.(type) {
351 case map[string]interface{}:
352 bt := bv.(map[string]interface{})
353 dst := make(map[string]interface{}, len(bt))
354 dst, err := getDiff(at, bt)
355 if err != nil {
356 return nil, err
357 }
358 if len(dst) > 0 {
359 into[key] = dst
360 }
361 case string, float64, bool:
362 if !matchesValue(av, bv) {
363 into[key] = bv
364 }
365 case []interface{}:
366 bt := bv.([]interface{})
367 if !matchesArray(at, bt) {
368 into[key] = bv
369 }
370 case nil:
371 switch bv.(type) {
372 case nil:
373
374 default:
375 into[key] = bv
376 }
377 default:
378 panic(fmt.Sprintf("Unknown type:%T in key %s", av, key))
379 }
380 }
381
382 for key := range a {
383 _, found := b[key]
384 if !found {
385 into[key] = nil
386 }
387 }
388 return into, nil
389 }
390
View as plain text