...

Source file src/oss.terrastruct.com/d2/d2graph/serde.go

Documentation: oss.terrastruct.com/d2/d2graph

     1  package d2graph
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"strings"
     7  
     8  	"oss.terrastruct.com/d2/d2target"
     9  	"oss.terrastruct.com/util-go/go2"
    10  )
    11  
    12  type SerializedGraph struct {
    13  	Root      SerializedObject   `json:"root"`
    14  	Edges     []SerializedEdge   `json:"edges"`
    15  	Objects   []SerializedObject `json:"objects"`
    16  	RootLevel int                `json:"rootLevel"`
    17  }
    18  
    19  type SerializedObject map[string]interface{}
    20  
    21  type SerializedEdge map[string]interface{}
    22  
    23  func DeserializeGraph(bytes []byte, g *Graph) error {
    24  	var sg *SerializedGraph
    25  	err := json.Unmarshal(bytes, &sg)
    26  	if err != nil {
    27  		return err
    28  	}
    29  
    30  	var root Object
    31  	convert(sg.Root, &root)
    32  	g.Root = &root
    33  	root.Graph = g
    34  	g.RootLevel = sg.RootLevel
    35  
    36  	idToObj := make(map[string]*Object)
    37  	idToObj[""] = g.Root
    38  	var objects []*Object
    39  	for _, so := range sg.Objects {
    40  		var o Object
    41  		if err := convert(so, &o); err != nil {
    42  			return err
    43  		}
    44  		o.Graph = g
    45  		objects = append(objects, &o)
    46  		idToObj[so["AbsID"].(string)] = &o
    47  	}
    48  
    49  	for _, so := range append(sg.Objects, sg.Root) {
    50  		if so["ChildrenArray"] != nil {
    51  			children := make(map[string]*Object)
    52  			var childrenArray []*Object
    53  
    54  			for _, id := range so["ChildrenArray"].([]interface{}) {
    55  				o := idToObj[id.(string)]
    56  				childrenArray = append(childrenArray, o)
    57  				children[strings.ToLower(o.ID)] = o
    58  
    59  				o.Parent = idToObj[so["AbsID"].(string)]
    60  			}
    61  
    62  			idToObj[so["AbsID"].(string)].Children = children
    63  			idToObj[so["AbsID"].(string)].ChildrenArray = childrenArray
    64  		}
    65  	}
    66  
    67  	var edges []*Edge
    68  	for _, se := range sg.Edges {
    69  		var e Edge
    70  		if err := convert(se, &e); err != nil {
    71  			return err
    72  		}
    73  
    74  		if se["Src"] != nil {
    75  			e.Src = idToObj[se["Src"].(string)]
    76  		}
    77  		if se["Dst"] != nil {
    78  			e.Dst = idToObj[se["Dst"].(string)]
    79  		}
    80  		edges = append(edges, &e)
    81  	}
    82  
    83  	g.Objects = objects
    84  	g.Edges = edges
    85  
    86  	return nil
    87  }
    88  
    89  func SerializeGraph(g *Graph) ([]byte, error) {
    90  	sg := SerializedGraph{}
    91  
    92  	root, err := toSerializedObject(g.Root)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  	sg.Root = root
    97  	sg.RootLevel = g.RootLevel
    98  
    99  	var sobjects []SerializedObject
   100  	for _, o := range g.Objects {
   101  		so, err := toSerializedObject(o)
   102  		if err != nil {
   103  			return nil, err
   104  		}
   105  		sobjects = append(sobjects, so)
   106  	}
   107  	sg.Objects = sobjects
   108  
   109  	var sedges []SerializedEdge
   110  	for _, e := range g.Edges {
   111  		se, err := toSerializedEdge(e)
   112  		if err != nil {
   113  			return nil, err
   114  		}
   115  		sedges = append(sedges, se)
   116  	}
   117  	sg.Edges = sedges
   118  
   119  	return json.Marshal(sg)
   120  }
   121  
   122  func toSerializedObject(o *Object) (SerializedObject, error) {
   123  	var so SerializedObject
   124  	if err := convert(o, &so); err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	so["AbsID"] = o.AbsID()
   129  
   130  	if len(o.ChildrenArray) > 0 {
   131  		var children []string
   132  		for _, c := range o.ChildrenArray {
   133  			children = append(children, c.AbsID())
   134  		}
   135  		so["ChildrenArray"] = children
   136  	}
   137  
   138  	return so, nil
   139  }
   140  
   141  func toSerializedEdge(e *Edge) (SerializedEdge, error) {
   142  	var se SerializedEdge
   143  	if err := convert(e, &se); err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	if e.Src != nil {
   148  		se["Src"] = go2.Pointer(e.Src.AbsID())
   149  	}
   150  	if e.Dst != nil {
   151  		se["Dst"] = go2.Pointer(e.Dst.AbsID())
   152  	}
   153  
   154  	return se, nil
   155  }
   156  
   157  func convert[T, Q any](from T, to *Q) error {
   158  	b, err := json.Marshal(from)
   159  	if err != nil {
   160  		return err
   161  	}
   162  	if err := json.Unmarshal(b, to); err != nil {
   163  		return err
   164  	}
   165  	return nil
   166  }
   167  
   168  func CompareSerializedGraph(g, other *Graph) error {
   169  	if len(g.Objects) != len(other.Objects) {
   170  		return fmt.Errorf("object count differs: g=%d, other=%d", len(g.Objects), len(other.Objects))
   171  	}
   172  
   173  	if len(g.Edges) != len(other.Edges) {
   174  		return fmt.Errorf("edge count differs: g=%d, other=%d", len(g.Edges), len(other.Edges))
   175  	}
   176  
   177  	if err := CompareSerializedObject(g.Root, other.Root); err != nil {
   178  		return fmt.Errorf("root differs: %v", err)
   179  	}
   180  
   181  	for i := 0; i < len(g.Objects); i++ {
   182  		if err := CompareSerializedObject(g.Objects[i], other.Objects[i]); err != nil {
   183  			return fmt.Errorf(
   184  				"objects differ at %d [g=%s, other=%s]: %v",
   185  				i,
   186  				g.Objects[i].ID,
   187  				other.Objects[i].ID,
   188  				err,
   189  			)
   190  		}
   191  	}
   192  
   193  	for i := 0; i < len(g.Edges); i++ {
   194  		if err := CompareSerializedEdge(g.Edges[i], other.Edges[i]); err != nil {
   195  			return fmt.Errorf(
   196  				"edges differ at %d [g=%s, other=%s]: %v",
   197  				i,
   198  				g.Edges[i].AbsID(),
   199  				other.Edges[i].AbsID(),
   200  				err,
   201  			)
   202  		}
   203  	}
   204  
   205  	return nil
   206  }
   207  
   208  func CompareSerializedObject(obj, other *Object) error {
   209  	if obj != nil && other == nil {
   210  		return fmt.Errorf("other is nil")
   211  	} else if obj == nil && other != nil {
   212  		return fmt.Errorf("obj is nil")
   213  	} else if obj == nil {
   214  		// both are nil
   215  		return nil
   216  	}
   217  
   218  	if obj.ID != other.ID {
   219  		return fmt.Errorf("ids differ: obj=%s, other=%s", obj.ID, other.ID)
   220  	}
   221  
   222  	if obj.AbsID() != other.AbsID() {
   223  		return fmt.Errorf("absolute ids differ: obj=%s, other=%s", obj.AbsID(), other.AbsID())
   224  	}
   225  
   226  	if obj.Box != nil && other.Box == nil {
   227  		return fmt.Errorf("other should have a box")
   228  	} else if obj.Box == nil && other.Box != nil {
   229  		return fmt.Errorf("other should not have a box")
   230  	} else if obj.Box != nil {
   231  		if obj.Width != other.Width {
   232  			return fmt.Errorf("widths differ: obj=%f, other=%f", obj.Width, other.Width)
   233  		}
   234  
   235  		if obj.Height != other.Height {
   236  			return fmt.Errorf("heights differ: obj=%f, other=%f", obj.Height, other.Height)
   237  		}
   238  	}
   239  
   240  	if obj.Parent != nil && other.Parent == nil {
   241  		return fmt.Errorf("other should have a parent")
   242  	} else if obj.Parent == nil && other.Parent != nil {
   243  		return fmt.Errorf("other should not have a parent")
   244  	} else if obj.Parent != nil && obj.Parent.ID != other.Parent.ID {
   245  		return fmt.Errorf("parent differs: obj=%s, other=%s", obj.Parent.ID, other.Parent.ID)
   246  	}
   247  
   248  	if len(obj.Children) != len(other.Children) {
   249  		return fmt.Errorf("children count differs: obj=%d, other=%d", len(obj.Children), len(other.Children))
   250  	}
   251  
   252  	for childID, objChild := range obj.Children {
   253  		if otherChild, exists := other.Children[childID]; exists {
   254  			if err := CompareSerializedObject(objChild, otherChild); err != nil {
   255  				return fmt.Errorf("children differ at key %s: %v", childID, err)
   256  			}
   257  		} else {
   258  			return fmt.Errorf("child %s does not exist in other", childID)
   259  		}
   260  	}
   261  
   262  	if len(obj.ChildrenArray) != len(other.ChildrenArray) {
   263  		return fmt.Errorf("childrenArray count differs: obj=%d, other=%d", len(obj.ChildrenArray), len(other.ChildrenArray))
   264  	}
   265  
   266  	for i := 0; i < len(obj.ChildrenArray); i++ {
   267  		if err := CompareSerializedObject(obj.ChildrenArray[i], other.ChildrenArray[i]); err != nil {
   268  			return fmt.Errorf("childrenArray differs at %d: %v", i, err)
   269  		}
   270  	}
   271  
   272  	if d2target.IsShape(obj.Shape.Value) != d2target.IsShape(other.Shape.Value) {
   273  		return fmt.Errorf(
   274  			"shapes differ: obj=%s, other=%s",
   275  			obj.Shape.Value,
   276  			other.Shape.Value,
   277  		)
   278  	}
   279  
   280  	if obj.Icon == nil && other.Icon != nil {
   281  		return fmt.Errorf("other does not have an icon")
   282  	} else if obj.Icon != nil && other.Icon == nil {
   283  		return fmt.Errorf("obj does not have an icon")
   284  	}
   285  
   286  	if obj.Direction.Value != other.Direction.Value {
   287  		return fmt.Errorf(
   288  			"directions differ: obj=%s, other=%s",
   289  			obj.Direction.Value,
   290  			other.Direction.Value,
   291  		)
   292  	}
   293  
   294  	if obj.Label.Value != other.Label.Value {
   295  		return fmt.Errorf(
   296  			"labels differ: obj=%s, other=%s",
   297  			obj.Label.Value,
   298  			other.Label.Value,
   299  		)
   300  	}
   301  
   302  	if obj.NearKey != nil {
   303  		if other.NearKey == nil {
   304  			return fmt.Errorf("other does not have near")
   305  		}
   306  		objKey := strings.Join(Key(obj.NearKey), ".")
   307  		deserKey := strings.Join(Key(other.NearKey), ".")
   308  		if objKey != deserKey {
   309  			return fmt.Errorf(
   310  				"near differs: obj=%s, other=%s",
   311  				objKey,
   312  				deserKey,
   313  			)
   314  		}
   315  	} else if other.NearKey != nil {
   316  		return fmt.Errorf("other should not have near")
   317  	}
   318  
   319  	if obj.LabelDimensions.Width != other.LabelDimensions.Width {
   320  		return fmt.Errorf(
   321  			"label width differs: obj=%d, other=%d",
   322  			obj.LabelDimensions.Width,
   323  			other.LabelDimensions.Width,
   324  		)
   325  	}
   326  
   327  	if obj.LabelDimensions.Height != other.LabelDimensions.Height {
   328  		return fmt.Errorf(
   329  			"label height differs: obj=%d, other=%d",
   330  			obj.LabelDimensions.Height,
   331  			other.LabelDimensions.Height,
   332  		)
   333  	}
   334  
   335  	if obj.SQLTable == nil && other.SQLTable != nil {
   336  		return fmt.Errorf("other is not a sql table")
   337  	} else if obj.SQLTable != nil && other.SQLTable == nil {
   338  		return fmt.Errorf("obj is not a sql table")
   339  	}
   340  
   341  	if obj.SQLTable != nil {
   342  		if len(obj.SQLTable.Columns) != len(other.SQLTable.Columns) {
   343  			return fmt.Errorf(
   344  				"table columns count differ: obj=%d, other=%d",
   345  				len(obj.SQLTable.Columns),
   346  				len(other.SQLTable.Columns),
   347  			)
   348  		}
   349  	}
   350  
   351  	return nil
   352  }
   353  
   354  func CompareSerializedEdge(edge, other *Edge) error {
   355  	if edge.AbsID() != other.AbsID() {
   356  		return fmt.Errorf(
   357  			"absolute ids differ: edge=%s, other=%s",
   358  			edge.AbsID(),
   359  			other.AbsID(),
   360  		)
   361  	}
   362  
   363  	if edge.Src.AbsID() != other.Src.AbsID() {
   364  		return fmt.Errorf(
   365  			"sources differ: edge=%s, other=%s",
   366  			edge.Src.AbsID(),
   367  			other.Src.AbsID(),
   368  		)
   369  	}
   370  
   371  	if edge.Dst.AbsID() != other.Dst.AbsID() {
   372  		return fmt.Errorf(
   373  			"targets differ: edge=%s, other=%s",
   374  			edge.Dst.AbsID(),
   375  			other.Dst.AbsID(),
   376  		)
   377  	}
   378  
   379  	if edge.SrcArrow != other.SrcArrow {
   380  		return fmt.Errorf(
   381  			"source arrows differ: edge=%t, other=%t",
   382  			edge.SrcArrow,
   383  			other.SrcArrow,
   384  		)
   385  	}
   386  
   387  	if edge.DstArrow != other.DstArrow {
   388  		return fmt.Errorf(
   389  			"target arrows differ: edge=%t, other=%t",
   390  			edge.DstArrow,
   391  			other.DstArrow,
   392  		)
   393  	}
   394  
   395  	if edge.Label.Value != other.Label.Value {
   396  		return fmt.Errorf(
   397  			"labels differ: edge=%s, other=%s",
   398  			edge.Label.Value,
   399  			other.Label.Value,
   400  		)
   401  	}
   402  
   403  	if edge.LabelDimensions.Width != other.LabelDimensions.Width {
   404  		return fmt.Errorf(
   405  			"label width differs: edge=%d, other=%d",
   406  			edge.LabelDimensions.Width,
   407  			other.LabelDimensions.Width,
   408  		)
   409  	}
   410  
   411  	if edge.LabelDimensions.Height != other.LabelDimensions.Height {
   412  		return fmt.Errorf(
   413  			"label height differs: edge=%d, other=%d",
   414  			edge.LabelDimensions.Height,
   415  			other.LabelDimensions.Height,
   416  		)
   417  	}
   418  
   419  	if edge.SrcTableColumnIndex != nil && other.SrcTableColumnIndex == nil {
   420  		return fmt.Errorf("other should have src column index")
   421  	} else if other.SrcTableColumnIndex != nil && edge.SrcTableColumnIndex == nil {
   422  		return fmt.Errorf("other should not have src column index")
   423  	} else if other.SrcTableColumnIndex != nil {
   424  		edgeColumn := *edge.SrcTableColumnIndex
   425  		otherColumn := *other.SrcTableColumnIndex
   426  		if edgeColumn != otherColumn {
   427  			return fmt.Errorf("src column differs: edge=%d, other=%d", edgeColumn, otherColumn)
   428  		}
   429  	}
   430  
   431  	if edge.DstTableColumnIndex != nil && other.DstTableColumnIndex == nil {
   432  		return fmt.Errorf("other should have dst column index")
   433  	} else if other.DstTableColumnIndex != nil && edge.DstTableColumnIndex == nil {
   434  		return fmt.Errorf("other should not have dst column index")
   435  	} else if other.DstTableColumnIndex != nil {
   436  		edgeColumn := *edge.DstTableColumnIndex
   437  		otherColumn := *other.DstTableColumnIndex
   438  		if edgeColumn != otherColumn {
   439  			return fmt.Errorf("dst column differs: edge=%d, other=%d", edgeColumn, otherColumn)
   440  		}
   441  	}
   442  	return nil
   443  }
   444  

View as plain text