...

Source file src/golang.org/x/tools/cmd/digraph/digraph.go

Documentation: golang.org/x/tools/cmd/digraph

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  package main // import "golang.org/x/tools/cmd/digraph"
     5  
     6  // TODO(adonovan):
     7  // - support input files other than stdin
     8  // - support alternative formats (AT&T GraphViz, CSV, etc),
     9  //   a comment syntax, etc.
    10  // - allow queries to nest, like Blaze query language.
    11  
    12  import (
    13  	"bufio"
    14  	"bytes"
    15  	_ "embed"
    16  	"errors"
    17  	"flag"
    18  	"fmt"
    19  	"io"
    20  	"os"
    21  	"sort"
    22  	"strconv"
    23  	"strings"
    24  	"unicode"
    25  	"unicode/utf8"
    26  )
    27  
    28  func usage() {
    29  	// Extract the content of the /* ... */ comment in doc.go.
    30  	_, after, _ := strings.Cut(doc, "/*")
    31  	doc, _, _ := strings.Cut(after, "*/")
    32  	io.WriteString(flag.CommandLine.Output(), doc)
    33  	flag.PrintDefaults()
    34  
    35  	os.Exit(2)
    36  }
    37  
// doc holds the full text of doc.go, embedded at build time.
// usage prints the contents of its /* ... */ comment as the help message.
//go:embed doc.go
var doc string
    40  
    41  func main() {
    42  	flag.Usage = usage
    43  	flag.Parse()
    44  
    45  	args := flag.Args()
    46  	if len(args) == 0 {
    47  		usage()
    48  	}
    49  
    50  	if err := digraph(args[0], args[1:]); err != nil {
    51  		fmt.Fprintf(os.Stderr, "digraph: %s\n", err)
    52  		os.Exit(1)
    53  	}
    54  }
    55  
    56  type nodelist []string
    57  
    58  func (l nodelist) println(sep string) {
    59  	for i, node := range l {
    60  		if i > 0 {
    61  			fmt.Fprint(stdout, sep)
    62  		}
    63  		fmt.Fprint(stdout, node)
    64  	}
    65  	fmt.Fprintln(stdout)
    66  }
    67  
    68  type nodeset map[string]bool
    69  
    70  func (s nodeset) sort() nodelist {
    71  	nodes := make(nodelist, len(s))
    72  	var i int
    73  	for node := range s {
    74  		nodes[i] = node
    75  		i++
    76  	}
    77  	sort.Strings(nodes)
    78  	return nodes
    79  }
    80  
    81  func (s nodeset) addAll(x nodeset) {
    82  	for node := range x {
    83  		s[node] = true
    84  	}
    85  }
    86  
    87  // A graph maps nodes to the non-nil set of their immediate successors.
    88  type graph map[string]nodeset
    89  
    90  func (g graph) addNode(node string) nodeset {
    91  	edges := g[node]
    92  	if edges == nil {
    93  		edges = make(nodeset)
    94  		g[node] = edges
    95  	}
    96  	return edges
    97  }
    98  
    99  func (g graph) addEdges(from string, to ...string) {
   100  	edges := g.addNode(from)
   101  	for _, to := range to {
   102  		g.addNode(to)
   103  		edges[to] = true
   104  	}
   105  }
   106  
   107  func (g graph) nodelist() nodelist {
   108  	nodes := make(nodeset)
   109  	for node := range g {
   110  		nodes[node] = true
   111  	}
   112  	return nodes.sort()
   113  }
   114  
   115  func (g graph) reachableFrom(roots nodeset) nodeset {
   116  	seen := make(nodeset)
   117  	var visit func(node string)
   118  	visit = func(node string) {
   119  		if !seen[node] {
   120  			seen[node] = true
   121  			for e := range g[node] {
   122  				visit(e)
   123  			}
   124  		}
   125  	}
   126  	for root := range roots {
   127  		visit(root)
   128  	}
   129  	return seen
   130  }
   131  
   132  func (g graph) transpose() graph {
   133  	rev := make(graph)
   134  	for node, edges := range g {
   135  		rev.addNode(node)
   136  		for succ := range edges {
   137  			rev.addEdges(succ, node)
   138  		}
   139  	}
   140  	return rev
   141  }
   142  
// sccs returns the nontrivial strongly connected components of g: those
// with more than one node, or consisting of a single node that has an edge
// to itself. The order of the returned components depends on map
// iteration order and is therefore not deterministic.
//
// Both passes below are recursive, so the Go stack depth grows with the
// length of the longest depth-first search path.
func (g graph) sccs() []nodeset {
	// Kosaraju's algorithm---Tarjan is overkill here.

	// Forward pass.
	// Depth-first search of g from every node. A node is appended to S only
	// after the searches from all of its successors have finished, so S
	// holds the nodes in postorder.
	S := make(nodelist, 0, len(g)) // postorder stack
	seen := make(nodeset)
	var visit func(node string)
	visit = func(node string) {
		if !seen[node] {
			seen[node] = true
			for e := range g[node] {
				visit(e)
			}
			S = append(S, node)
		}
	}
	for node := range g {
		visit(node)
	}

	// Reverse pass.
	// Depth-first search of the transposed graph, taking start nodes in
	// reverse postorder (by popping S). Each search started from a node not
	// yet seen gathers one component into scc.
	rev := g.transpose()
	var scc nodeset // component being gathered by the current rvisit search
	seen = make(nodeset)
	var rvisit func(node string)
	rvisit = func(node string) {
		if !seen[node] {
			seen[node] = true
			scc[node] = true
			for e := range rev[node] {
				rvisit(e)
			}
		}
	}
	var sccs []nodeset
	for len(S) > 0 {
		top := S[len(S)-1]
		S = S[:len(S)-1] // pop
		if !seen[top] {
			scc = make(nodeset)
			rvisit(top)
			// Skip trivial components: a lone node without a self-loop.
			if len(scc) == 1 && !g[top][top] {
				continue
			}
			sccs = append(sccs, scc)
		}
	}
	return sccs
}
   192  
   193  func (g graph) allpaths(from, to string) error {
   194  	// Mark all nodes to "to".
   195  	seen := make(nodeset) // value of seen[x] indicates whether x is on some path to "to"
   196  	var visit func(node string) bool
   197  	visit = func(node string) bool {
   198  		reachesTo, ok := seen[node]
   199  		if !ok {
   200  			reachesTo = node == to
   201  			seen[node] = reachesTo
   202  			for e := range g[node] {
   203  				if visit(e) {
   204  					reachesTo = true
   205  				}
   206  			}
   207  			if reachesTo && node != to {
   208  				seen[node] = true
   209  			}
   210  		}
   211  		return reachesTo
   212  	}
   213  	visit(from)
   214  
   215  	// For each marked node, collect its marked successors.
   216  	var edges []string
   217  	for n := range seen {
   218  		for succ := range g[n] {
   219  			if seen[succ] {
   220  				edges = append(edges, n+" "+succ)
   221  			}
   222  		}
   223  	}
   224  
   225  	// Sort (so that this method is deterministic) and print edges.
   226  	sort.Strings(edges)
   227  	for _, e := range edges {
   228  		fmt.Fprintln(stdout, e)
   229  	}
   230  
   231  	return nil
   232  }
   233  
   234  func (g graph) somepath(from, to string) error {
   235  	// Search breadth-first so that we return a minimal path.
   236  
   237  	// A path is a linked list whose head is a candidate "to" node
   238  	// and whose tail is the path ending in the "from" node.
   239  	type path struct {
   240  		node string
   241  		tail *path
   242  	}
   243  
   244  	seen := nodeset{from: true}
   245  
   246  	var queue []*path
   247  	queue = append(queue, &path{node: from, tail: nil})
   248  	for len(queue) > 0 {
   249  		p := queue[0]
   250  		queue = queue[1:]
   251  
   252  		if p.node == to {
   253  			// Found a path. Print, tail first.
   254  			var print func(p *path)
   255  			print = func(p *path) {
   256  				if p.tail != nil {
   257  					print(p.tail)
   258  					fmt.Fprintln(stdout, p.tail.node+" "+p.node)
   259  				}
   260  			}
   261  			print(p)
   262  			return nil
   263  		}
   264  
   265  		for succ := range g[p.node] {
   266  			if !seen[succ] {
   267  				seen[succ] = true
   268  				queue = append(queue, &path{node: succ, tail: p})
   269  			}
   270  		}
   271  	}
   272  	return fmt.Errorf("no path from %q to %q", from, to)
   273  }
   274  
   275  func (g graph) toDot(w *bytes.Buffer) {
   276  	fmt.Fprintln(w, "digraph {")
   277  	for _, src := range g.nodelist() {
   278  		for _, dst := range g[src].sort() {
   279  			// Dot's quoting rules appear to align with Go's for escString,
   280  			// which is the syntax of node IDs. Labels require significantly
   281  			// more quoting, but that appears not to be necessary if the node ID
   282  			// is implicitly used as the label.
   283  			fmt.Fprintf(w, "\t%q -> %q;\n", src, dst)
   284  		}
   285  	}
   286  	fmt.Fprintln(w, "}")
   287  }
   288  
   289  func parse(rd io.Reader) (graph, error) {
   290  	g := make(graph)
   291  
   292  	var linenum int
   293  	// We avoid bufio.Scanner as it imposes a (configurable) limit
   294  	// on line length, whereas Reader.ReadString does not.
   295  	in := bufio.NewReader(rd)
   296  	for {
   297  		linenum++
   298  		line, err := in.ReadString('\n')
   299  		eof := false
   300  		if err == io.EOF {
   301  			eof = true
   302  		} else if err != nil {
   303  			return nil, err
   304  		}
   305  		// Split into words, honoring double-quotes per Go spec.
   306  		words, err := split(line)
   307  		if err != nil {
   308  			return nil, fmt.Errorf("at line %d: %v", linenum, err)
   309  		}
   310  		if len(words) > 0 {
   311  			g.addEdges(words[0], words[1:]...)
   312  		}
   313  		if eof {
   314  			break
   315  		}
   316  	}
   317  	return g, nil
   318  }
   319  
   320  // Overridable for redirection.
   321  var stdin io.Reader = os.Stdin
   322  var stdout io.Writer = os.Stdout
   323  
   324  func digraph(cmd string, args []string) error {
   325  	// Parse the input graph.
   326  	g, err := parse(stdin)
   327  	if err != nil {
   328  		return err
   329  	}
   330  
   331  	// Parse the command line.
   332  	switch cmd {
   333  	case "nodes":
   334  		if len(args) != 0 {
   335  			return fmt.Errorf("usage: digraph nodes")
   336  		}
   337  		g.nodelist().println("\n")
   338  
   339  	case "degree":
   340  		if len(args) != 0 {
   341  			return fmt.Errorf("usage: digraph degree")
   342  		}
   343  		nodes := make(nodeset)
   344  		for node := range g {
   345  			nodes[node] = true
   346  		}
   347  		rev := g.transpose()
   348  		for _, node := range nodes.sort() {
   349  			fmt.Fprintf(stdout, "%d\t%d\t%s\n", len(rev[node]), len(g[node]), node)
   350  		}
   351  
   352  	case "transpose":
   353  		if len(args) != 0 {
   354  			return fmt.Errorf("usage: digraph transpose")
   355  		}
   356  		var revEdges []string
   357  		for node, succs := range g.transpose() {
   358  			for succ := range succs {
   359  				revEdges = append(revEdges, fmt.Sprintf("%s %s", node, succ))
   360  			}
   361  		}
   362  		sort.Strings(revEdges) // make output deterministic
   363  		for _, e := range revEdges {
   364  			fmt.Fprintln(stdout, e)
   365  		}
   366  
   367  	case "succs", "preds":
   368  		if len(args) == 0 {
   369  			return fmt.Errorf("usage: digraph %s <node> ... ", cmd)
   370  		}
   371  		g := g
   372  		if cmd == "preds" {
   373  			g = g.transpose()
   374  		}
   375  		result := make(nodeset)
   376  		for _, root := range args {
   377  			edges := g[root]
   378  			if edges == nil {
   379  				return fmt.Errorf("no such node %q", root)
   380  			}
   381  			result.addAll(edges)
   382  		}
   383  		result.sort().println("\n")
   384  
   385  	case "forward", "reverse":
   386  		if len(args) == 0 {
   387  			return fmt.Errorf("usage: digraph %s <node> ... ", cmd)
   388  		}
   389  		roots := make(nodeset)
   390  		for _, root := range args {
   391  			if g[root] == nil {
   392  				return fmt.Errorf("no such node %q", root)
   393  			}
   394  			roots[root] = true
   395  		}
   396  		g := g
   397  		if cmd == "reverse" {
   398  			g = g.transpose()
   399  		}
   400  		g.reachableFrom(roots).sort().println("\n")
   401  
   402  	case "somepath":
   403  		if len(args) != 2 {
   404  			return fmt.Errorf("usage: digraph somepath <from> <to>")
   405  		}
   406  		from, to := args[0], args[1]
   407  		if g[from] == nil {
   408  			return fmt.Errorf("no such 'from' node %q", from)
   409  		}
   410  		if g[to] == nil {
   411  			return fmt.Errorf("no such 'to' node %q", to)
   412  		}
   413  		if err := g.somepath(from, to); err != nil {
   414  			return err
   415  		}
   416  
   417  	case "allpaths":
   418  		if len(args) != 2 {
   419  			return fmt.Errorf("usage: digraph allpaths <from> <to>")
   420  		}
   421  		from, to := args[0], args[1]
   422  		if g[from] == nil {
   423  			return fmt.Errorf("no such 'from' node %q", from)
   424  		}
   425  		if g[to] == nil {
   426  			return fmt.Errorf("no such 'to' node %q", to)
   427  		}
   428  		if err := g.allpaths(from, to); err != nil {
   429  			return err
   430  		}
   431  
   432  	case "sccs":
   433  		if len(args) != 0 {
   434  			return fmt.Errorf("usage: digraph sccs")
   435  		}
   436  		buf := new(bytes.Buffer)
   437  		oldStdout := stdout
   438  		stdout = buf
   439  		for _, scc := range g.sccs() {
   440  			scc.sort().println(" ")
   441  		}
   442  		lines := strings.SplitAfter(buf.String(), "\n")
   443  		sort.Strings(lines)
   444  		stdout = oldStdout
   445  		io.WriteString(stdout, strings.Join(lines, ""))
   446  
   447  	case "scc":
   448  		if len(args) != 1 {
   449  			return fmt.Errorf("usage: digraph scc <node>")
   450  		}
   451  		node := args[0]
   452  		if g[node] == nil {
   453  			return fmt.Errorf("no such node %q", node)
   454  		}
   455  		for _, scc := range g.sccs() {
   456  			if scc[node] {
   457  				scc.sort().println("\n")
   458  				break
   459  			}
   460  		}
   461  
   462  	case "focus":
   463  		if len(args) != 1 {
   464  			return fmt.Errorf("usage: digraph focus <node>")
   465  		}
   466  		node := args[0]
   467  		if g[node] == nil {
   468  			return fmt.Errorf("no such node %q", node)
   469  		}
   470  
   471  		edges := make(map[string]struct{})
   472  		for from := range g.reachableFrom(nodeset{node: true}) {
   473  			for to := range g[from] {
   474  				edges[fmt.Sprintf("%s %s", from, to)] = struct{}{}
   475  			}
   476  		}
   477  
   478  		gtrans := g.transpose()
   479  		for from := range gtrans.reachableFrom(nodeset{node: true}) {
   480  			for to := range gtrans[from] {
   481  				edges[fmt.Sprintf("%s %s", to, from)] = struct{}{}
   482  			}
   483  		}
   484  
   485  		edgesSorted := make([]string, 0, len(edges))
   486  		for e := range edges {
   487  			edgesSorted = append(edgesSorted, e)
   488  		}
   489  		sort.Strings(edgesSorted)
   490  		fmt.Fprintln(stdout, strings.Join(edgesSorted, "\n"))
   491  
   492  	case "to":
   493  		if len(args) != 1 || args[0] != "dot" {
   494  			return fmt.Errorf("usage: digraph to dot")
   495  		}
   496  		var b bytes.Buffer
   497  		g.toDot(&b)
   498  		stdout.Write(b.Bytes())
   499  
   500  	default:
   501  		return fmt.Errorf("no such command %q", cmd)
   502  	}
   503  
   504  	return nil
   505  }
   506  
   507  // -- Utilities --------------------------------------------------------
   508  
   509  // split splits a line into words, which are generally separated by
   510  // spaces, but Go-style double-quoted string literals are also supported.
   511  // (This approximates the behaviour of the Bourne shell.)
   512  //
   513  //	`one "two three"` -> ["one" "two three"]
   514  //	`a"\n"b` -> ["a\nb"]
   515  func split(line string) ([]string, error) {
   516  	var (
   517  		words   []string
   518  		inWord  bool
   519  		current bytes.Buffer
   520  	)
   521  
   522  	for len(line) > 0 {
   523  		r, size := utf8.DecodeRuneInString(line)
   524  		if unicode.IsSpace(r) {
   525  			if inWord {
   526  				words = append(words, current.String())
   527  				current.Reset()
   528  				inWord = false
   529  			}
   530  		} else if r == '"' {
   531  			var ok bool
   532  			size, ok = quotedLength(line)
   533  			if !ok {
   534  				return nil, errors.New("invalid quotation")
   535  			}
   536  			s, err := strconv.Unquote(line[:size])
   537  			if err != nil {
   538  				return nil, err
   539  			}
   540  			current.WriteString(s)
   541  			inWord = true
   542  		} else {
   543  			current.WriteRune(r)
   544  			inWord = true
   545  		}
   546  		line = line[size:]
   547  	}
   548  	if inWord {
   549  		words = append(words, current.String())
   550  	}
   551  	return words, nil
   552  }
   553  
   554  // quotedLength returns the length in bytes of the prefix of input that
   555  // contain a possibly-valid double-quoted Go string literal.
   556  //
   557  // On success, n is at least two (""); input[:n] may be passed to
   558  // strconv.Unquote to interpret its value, and input[n:] contains the
   559  // rest of the input.
   560  //
   561  // On failure, quotedLength returns false, and the entire input can be
   562  // passed to strconv.Unquote if an informative error message is desired.
   563  //
   564  // quotedLength does not and need not detect all errors, such as
   565  // invalid hex or octal escape sequences, since it assumes
   566  // strconv.Unquote will be applied to the prefix.  It guarantees only
   567  // that if there is a prefix of input containing a valid string literal,
   568  // its length is returned.
   569  //
   570  // TODO(adonovan): move this into a strconv-like utility package.
   571  func quotedLength(input string) (n int, ok bool) {
   572  	var offset int
   573  
   574  	// next returns the rune at offset, or -1 on EOF.
   575  	// offset advances to just after that rune.
   576  	next := func() rune {
   577  		if offset < len(input) {
   578  			r, size := utf8.DecodeRuneInString(input[offset:])
   579  			offset += size
   580  			return r
   581  		}
   582  		return -1
   583  	}
   584  
   585  	if next() != '"' {
   586  		return // error: not a quotation
   587  	}
   588  
   589  	for {
   590  		r := next()
   591  		if r == '\n' || r < 0 {
   592  			return // error: string literal not terminated
   593  		}
   594  		if r == '"' {
   595  			return offset, true // success
   596  		}
   597  		if r == '\\' {
   598  			var skip int
   599  			switch next() {
   600  			case 'a', 'b', 'f', 'n', 'r', 't', 'v', '\\', '"':
   601  				skip = 0
   602  			case '0', '1', '2', '3', '4', '5', '6', '7':
   603  				skip = 2
   604  			case 'x':
   605  				skip = 2
   606  			case 'u':
   607  				skip = 4
   608  			case 'U':
   609  				skip = 8
   610  			default:
   611  				return // error: invalid escape
   612  			}
   613  
   614  			for i := 0; i < skip; i++ {
   615  				next()
   616  			}
   617  		}
   618  	}
   619  }
   620  

View as plain text