1 package d2graph
2
3 import (
4 "encoding/json"
5 "fmt"
6 "strings"
7
8 "oss.terrastruct.com/d2/d2target"
9 "oss.terrastruct.com/util-go/go2"
10 )
11
// Wire types used by SerializeGraph and DeserializeGraph.
type (
	// SerializedGraph is the top-level JSON document for a Graph. Objects and
	// edges are stored as generic maps so that pointer links between them can
	// be written as absolute IDs (the "AbsID", "ChildrenArray", "Src" and
	// "Dst" keys) and re-linked on deserialization.
	SerializedGraph struct {
		Root      SerializedObject   `json:"root"`
		Edges     []SerializedEdge   `json:"edges"`
		Objects   []SerializedObject `json:"objects"`
		RootLevel int                `json:"rootLevel"`
	}

	// SerializedObject is the generic JSON-map form of a single Object.
	SerializedObject map[string]any

	// SerializedEdge is the generic JSON-map form of a single Edge.
	SerializedEdge map[string]any
)
22
23 func DeserializeGraph(bytes []byte, g *Graph) error {
24 var sg *SerializedGraph
25 err := json.Unmarshal(bytes, &sg)
26 if err != nil {
27 return err
28 }
29
30 var root Object
31 convert(sg.Root, &root)
32 g.Root = &root
33 root.Graph = g
34 g.RootLevel = sg.RootLevel
35
36 idToObj := make(map[string]*Object)
37 idToObj[""] = g.Root
38 var objects []*Object
39 for _, so := range sg.Objects {
40 var o Object
41 if err := convert(so, &o); err != nil {
42 return err
43 }
44 o.Graph = g
45 objects = append(objects, &o)
46 idToObj[so["AbsID"].(string)] = &o
47 }
48
49 for _, so := range append(sg.Objects, sg.Root) {
50 if so["ChildrenArray"] != nil {
51 children := make(map[string]*Object)
52 var childrenArray []*Object
53
54 for _, id := range so["ChildrenArray"].([]interface{}) {
55 o := idToObj[id.(string)]
56 childrenArray = append(childrenArray, o)
57 children[strings.ToLower(o.ID)] = o
58
59 o.Parent = idToObj[so["AbsID"].(string)]
60 }
61
62 idToObj[so["AbsID"].(string)].Children = children
63 idToObj[so["AbsID"].(string)].ChildrenArray = childrenArray
64 }
65 }
66
67 var edges []*Edge
68 for _, se := range sg.Edges {
69 var e Edge
70 if err := convert(se, &e); err != nil {
71 return err
72 }
73
74 if se["Src"] != nil {
75 e.Src = idToObj[se["Src"].(string)]
76 }
77 if se["Dst"] != nil {
78 e.Dst = idToObj[se["Dst"].(string)]
79 }
80 edges = append(edges, &e)
81 }
82
83 g.Objects = objects
84 g.Edges = edges
85
86 return nil
87 }
88
89 func SerializeGraph(g *Graph) ([]byte, error) {
90 sg := SerializedGraph{}
91
92 root, err := toSerializedObject(g.Root)
93 if err != nil {
94 return nil, err
95 }
96 sg.Root = root
97 sg.RootLevel = g.RootLevel
98
99 var sobjects []SerializedObject
100 for _, o := range g.Objects {
101 so, err := toSerializedObject(o)
102 if err != nil {
103 return nil, err
104 }
105 sobjects = append(sobjects, so)
106 }
107 sg.Objects = sobjects
108
109 var sedges []SerializedEdge
110 for _, e := range g.Edges {
111 se, err := toSerializedEdge(e)
112 if err != nil {
113 return nil, err
114 }
115 sedges = append(sedges, se)
116 }
117 sg.Edges = sedges
118
119 return json.Marshal(sg)
120 }
121
122 func toSerializedObject(o *Object) (SerializedObject, error) {
123 var so SerializedObject
124 if err := convert(o, &so); err != nil {
125 return nil, err
126 }
127
128 so["AbsID"] = o.AbsID()
129
130 if len(o.ChildrenArray) > 0 {
131 var children []string
132 for _, c := range o.ChildrenArray {
133 children = append(children, c.AbsID())
134 }
135 so["ChildrenArray"] = children
136 }
137
138 return so, nil
139 }
140
141 func toSerializedEdge(e *Edge) (SerializedEdge, error) {
142 var se SerializedEdge
143 if err := convert(e, &se); err != nil {
144 return nil, err
145 }
146
147 if e.Src != nil {
148 se["Src"] = go2.Pointer(e.Src.AbsID())
149 }
150 if e.Dst != nil {
151 se["Dst"] = go2.Pointer(e.Dst.AbsID())
152 }
153
154 return se, nil
155 }
156
// convert copies from into *to by round-tripping through JSON: from is
// marshaled, and the bytes are unmarshaled into to. Errors from either step
// are returned unchanged.
func convert[T, Q any](from T, to *Q) error {
	encoded, err := json.Marshal(from)
	if err != nil {
		return err
	}
	return json.Unmarshal(encoded, to)
}
167
168 func CompareSerializedGraph(g, other *Graph) error {
169 if len(g.Objects) != len(other.Objects) {
170 return fmt.Errorf("object count differs: g=%d, other=%d", len(g.Objects), len(other.Objects))
171 }
172
173 if len(g.Edges) != len(other.Edges) {
174 return fmt.Errorf("edge count differs: g=%d, other=%d", len(g.Edges), len(other.Edges))
175 }
176
177 if err := CompareSerializedObject(g.Root, other.Root); err != nil {
178 return fmt.Errorf("root differs: %v", err)
179 }
180
181 for i := 0; i < len(g.Objects); i++ {
182 if err := CompareSerializedObject(g.Objects[i], other.Objects[i]); err != nil {
183 return fmt.Errorf(
184 "objects differ at %d [g=%s, other=%s]: %v",
185 i,
186 g.Objects[i].ID,
187 other.Objects[i].ID,
188 err,
189 )
190 }
191 }
192
193 for i := 0; i < len(g.Edges); i++ {
194 if err := CompareSerializedEdge(g.Edges[i], other.Edges[i]); err != nil {
195 return fmt.Errorf(
196 "edges differ at %d [g=%s, other=%s]: %v",
197 i,
198 g.Edges[i].AbsID(),
199 other.Edges[i].AbsID(),
200 err,
201 )
202 }
203 }
204
205 return nil
206 }
207
208 func CompareSerializedObject(obj, other *Object) error {
209 if obj != nil && other == nil {
210 return fmt.Errorf("other is nil")
211 } else if obj == nil && other != nil {
212 return fmt.Errorf("obj is nil")
213 } else if obj == nil {
214
215 return nil
216 }
217
218 if obj.ID != other.ID {
219 return fmt.Errorf("ids differ: obj=%s, other=%s", obj.ID, other.ID)
220 }
221
222 if obj.AbsID() != other.AbsID() {
223 return fmt.Errorf("absolute ids differ: obj=%s, other=%s", obj.AbsID(), other.AbsID())
224 }
225
226 if obj.Box != nil && other.Box == nil {
227 return fmt.Errorf("other should have a box")
228 } else if obj.Box == nil && other.Box != nil {
229 return fmt.Errorf("other should not have a box")
230 } else if obj.Box != nil {
231 if obj.Width != other.Width {
232 return fmt.Errorf("widths differ: obj=%f, other=%f", obj.Width, other.Width)
233 }
234
235 if obj.Height != other.Height {
236 return fmt.Errorf("heights differ: obj=%f, other=%f", obj.Height, other.Height)
237 }
238 }
239
240 if obj.Parent != nil && other.Parent == nil {
241 return fmt.Errorf("other should have a parent")
242 } else if obj.Parent == nil && other.Parent != nil {
243 return fmt.Errorf("other should not have a parent")
244 } else if obj.Parent != nil && obj.Parent.ID != other.Parent.ID {
245 return fmt.Errorf("parent differs: obj=%s, other=%s", obj.Parent.ID, other.Parent.ID)
246 }
247
248 if len(obj.Children) != len(other.Children) {
249 return fmt.Errorf("children count differs: obj=%d, other=%d", len(obj.Children), len(other.Children))
250 }
251
252 for childID, objChild := range obj.Children {
253 if otherChild, exists := other.Children[childID]; exists {
254 if err := CompareSerializedObject(objChild, otherChild); err != nil {
255 return fmt.Errorf("children differ at key %s: %v", childID, err)
256 }
257 } else {
258 return fmt.Errorf("child %s does not exist in other", childID)
259 }
260 }
261
262 if len(obj.ChildrenArray) != len(other.ChildrenArray) {
263 return fmt.Errorf("childrenArray count differs: obj=%d, other=%d", len(obj.ChildrenArray), len(other.ChildrenArray))
264 }
265
266 for i := 0; i < len(obj.ChildrenArray); i++ {
267 if err := CompareSerializedObject(obj.ChildrenArray[i], other.ChildrenArray[i]); err != nil {
268 return fmt.Errorf("childrenArray differs at %d: %v", i, err)
269 }
270 }
271
272 if d2target.IsShape(obj.Shape.Value) != d2target.IsShape(other.Shape.Value) {
273 return fmt.Errorf(
274 "shapes differ: obj=%s, other=%s",
275 obj.Shape.Value,
276 other.Shape.Value,
277 )
278 }
279
280 if obj.Icon == nil && other.Icon != nil {
281 return fmt.Errorf("other does not have an icon")
282 } else if obj.Icon != nil && other.Icon == nil {
283 return fmt.Errorf("obj does not have an icon")
284 }
285
286 if obj.Direction.Value != other.Direction.Value {
287 return fmt.Errorf(
288 "directions differ: obj=%s, other=%s",
289 obj.Direction.Value,
290 other.Direction.Value,
291 )
292 }
293
294 if obj.Label.Value != other.Label.Value {
295 return fmt.Errorf(
296 "labels differ: obj=%s, other=%s",
297 obj.Label.Value,
298 other.Label.Value,
299 )
300 }
301
302 if obj.NearKey != nil {
303 if other.NearKey == nil {
304 return fmt.Errorf("other does not have near")
305 }
306 objKey := strings.Join(Key(obj.NearKey), ".")
307 deserKey := strings.Join(Key(other.NearKey), ".")
308 if objKey != deserKey {
309 return fmt.Errorf(
310 "near differs: obj=%s, other=%s",
311 objKey,
312 deserKey,
313 )
314 }
315 } else if other.NearKey != nil {
316 return fmt.Errorf("other should not have near")
317 }
318
319 if obj.LabelDimensions.Width != other.LabelDimensions.Width {
320 return fmt.Errorf(
321 "label width differs: obj=%d, other=%d",
322 obj.LabelDimensions.Width,
323 other.LabelDimensions.Width,
324 )
325 }
326
327 if obj.LabelDimensions.Height != other.LabelDimensions.Height {
328 return fmt.Errorf(
329 "label height differs: obj=%d, other=%d",
330 obj.LabelDimensions.Height,
331 other.LabelDimensions.Height,
332 )
333 }
334
335 if obj.SQLTable == nil && other.SQLTable != nil {
336 return fmt.Errorf("other is not a sql table")
337 } else if obj.SQLTable != nil && other.SQLTable == nil {
338 return fmt.Errorf("obj is not a sql table")
339 }
340
341 if obj.SQLTable != nil {
342 if len(obj.SQLTable.Columns) != len(other.SQLTable.Columns) {
343 return fmt.Errorf(
344 "table columns count differ: obj=%d, other=%d",
345 len(obj.SQLTable.Columns),
346 len(other.SQLTable.Columns),
347 )
348 }
349 }
350
351 return nil
352 }
353
354 func CompareSerializedEdge(edge, other *Edge) error {
355 if edge.AbsID() != other.AbsID() {
356 return fmt.Errorf(
357 "absolute ids differ: edge=%s, other=%s",
358 edge.AbsID(),
359 other.AbsID(),
360 )
361 }
362
363 if edge.Src.AbsID() != other.Src.AbsID() {
364 return fmt.Errorf(
365 "sources differ: edge=%s, other=%s",
366 edge.Src.AbsID(),
367 other.Src.AbsID(),
368 )
369 }
370
371 if edge.Dst.AbsID() != other.Dst.AbsID() {
372 return fmt.Errorf(
373 "targets differ: edge=%s, other=%s",
374 edge.Dst.AbsID(),
375 other.Dst.AbsID(),
376 )
377 }
378
379 if edge.SrcArrow != other.SrcArrow {
380 return fmt.Errorf(
381 "source arrows differ: edge=%t, other=%t",
382 edge.SrcArrow,
383 other.SrcArrow,
384 )
385 }
386
387 if edge.DstArrow != other.DstArrow {
388 return fmt.Errorf(
389 "target arrows differ: edge=%t, other=%t",
390 edge.DstArrow,
391 other.DstArrow,
392 )
393 }
394
395 if edge.Label.Value != other.Label.Value {
396 return fmt.Errorf(
397 "labels differ: edge=%s, other=%s",
398 edge.Label.Value,
399 other.Label.Value,
400 )
401 }
402
403 if edge.LabelDimensions.Width != other.LabelDimensions.Width {
404 return fmt.Errorf(
405 "label width differs: edge=%d, other=%d",
406 edge.LabelDimensions.Width,
407 other.LabelDimensions.Width,
408 )
409 }
410
411 if edge.LabelDimensions.Height != other.LabelDimensions.Height {
412 return fmt.Errorf(
413 "label height differs: edge=%d, other=%d",
414 edge.LabelDimensions.Height,
415 other.LabelDimensions.Height,
416 )
417 }
418
419 if edge.SrcTableColumnIndex != nil && other.SrcTableColumnIndex == nil {
420 return fmt.Errorf("other should have src column index")
421 } else if other.SrcTableColumnIndex != nil && edge.SrcTableColumnIndex == nil {
422 return fmt.Errorf("other should not have src column index")
423 } else if other.SrcTableColumnIndex != nil {
424 edgeColumn := *edge.SrcTableColumnIndex
425 otherColumn := *other.SrcTableColumnIndex
426 if edgeColumn != otherColumn {
427 return fmt.Errorf("src column differs: edge=%d, other=%d", edgeColumn, otherColumn)
428 }
429 }
430
431 if edge.DstTableColumnIndex != nil && other.DstTableColumnIndex == nil {
432 return fmt.Errorf("other should have dst column index")
433 } else if other.DstTableColumnIndex != nil && edge.DstTableColumnIndex == nil {
434 return fmt.Errorf("other should not have dst column index")
435 } else if other.DstTableColumnIndex != nil {
436 edgeColumn := *edge.DstTableColumnIndex
437 otherColumn := *other.DstTableColumnIndex
438 if edgeColumn != otherColumn {
439 return fmt.Errorf("dst column differs: edge=%d, other=%d", edgeColumn, otherColumn)
440 }
441 }
442 return nil
443 }
444
View as plain text