...

Source file src/gonum.org/v1/plot/plotter/sankey.go

Documentation: gonum.org/v1/plot/plotter

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package plotter
     6  
     7  import (
     8  	"fmt"
     9  	"image/color"
    10  	"math"
    11  	"sort"
    12  
    13  	"gonum.org/v1/plot"
    14  	"gonum.org/v1/plot/font"
    15  	"gonum.org/v1/plot/text"
    16  	"gonum.org/v1/plot/tools/bezier"
    17  	"gonum.org/v1/plot/vg"
    18  	"gonum.org/v1/plot/vg/draw"
    19  )
    20  
    21  // A Sankey diagram presents stock and flow data as rectangles representing
    22  // the amount of each stock and lines between the stocks representing the
    23  // amount of each flow.
    24  type Sankey struct {
    25  	// Color specifies the default fill
    26  	// colors for the stocks and flows. If Color is not nil,
    27  	// each stock and flow is rendered filled with Color,
    28  	// otherwise no fill is performed. Colors can be
    29  	// modified for individual stocks and flows.
    30  	Color color.Color
    31  
    32  	// StockBarWidth is the widths of the bars representing
    33  	// the stocks. The default value is 15% larger than the
    34  	// height of the stock label text.
    35  	StockBarWidth vg.Length
    36  
    37  	// LineStyle specifies the default border
    38  	// line style for the stocks and flows. Styles can be
    39  	// modified for individual stocks and flows.
    40  	LineStyle draw.LineStyle
    41  
    42  	// TextStyle specifies the default stock label
    43  	// text style. Styles can be modified for
    44  	// individual stocks.
    45  	TextStyle text.Style
    46  
    47  	flows []Flow
    48  
    49  	// FlowStyle is a function that specifies the
    50  	// background color and border line style of the
    51  	// flow based on its group name. The default
    52  	// function uses the default Color and LineStyle
    53  	// specified above for all groups.
    54  	FlowStyle func(group string) (color.Color, draw.LineStyle)
    55  
    56  	// StockStyle is a function that specifies, for a stock
    57  	// identified by its label and category, the label text
    58  	// to be printed on the plot (lbl), the style of the text (ts),
    59  	// the horizontal and vertical offsets for printing the text (xOff and yOff),
    60  	// the color of the fill for the bar representing the stock (c),
    61  	// and the style of the outline of the bar representing the stock (ls).
    62  	// The default function uses the default TextStyle, color and LineStyle
    63  	// specified above for all stocks; zero horizontal and vertical offsets;
    64  	// and the stock label as the text to be printed on the plot.
    65  	StockStyle func(label string, category int) (lbl string, ts text.Style, xOff, yOff vg.Length, c color.Color, ls draw.LineStyle)
    66  
    67  	// stocks arranges the stocks by category.
    68  	// The first key is the category and the seond
    69  	// key is the label.
    70  	stocks map[int]map[string]*stock
    71  }
    72  
    73  // StockRange returns the minimum and maximum value on the value axis
    74  // for the stock with the specified label and category.
    75  func (s *Sankey) StockRange(label string, category int) (min, max float64, err error) {
    76  	stk, ok := s.stocks[category][label]
    77  	if !ok {
    78  		return 0, 0, fmt.Errorf("plotter: sankey diagram does not contain stock with label=%s and category=%d", label, category)
    79  	}
    80  	return stk.min, stk.max, nil
    81  }
    82  
    83  // stock represents the amount of a stock and its plotting order.
    84  type stock struct {
    85  	// receptorValue and sourceValue are the totals of the values
    86  	// of flows coming into and going out of this stock, respectively.
    87  	receptorValue, sourceValue float64
    88  
    89  	// label is the label of this stock, and category represents
    90  	// its placement on the category axis. Together they make up a
    91  	// unique identifier.
    92  	label    string
    93  	category int
    94  
    95  	// order is the plotting order of this stock compared
    96  	// to other stocks in the same category.
    97  	order int
    98  
    99  	// min represents the beginning of the plotting location
   100  	// on the value axis.
   101  	min float64
   102  
   103  	// max is min plus the larger of receptorValue and sourceValue.
   104  	max float64
   105  }
   106  
   107  // A Flow represents the amount of an entity flowing between two stocks.
   108  type Flow struct {
   109  	// SourceLabel and ReceptorLabel are the labels
   110  	// of the stocks that originate and receive the flow,
   111  	// respectively.
   112  	SourceLabel, ReceptorLabel string
   113  
   114  	// SourceCategory and ReceptorCategory define
   115  	// the locations on the category axis of the stocks that
   116  	// originate and receive the flow, respectively. The
   117  	// SourceCategory must be a lower number than
   118  	// the ReceptorCategory.
   119  	SourceCategory, ReceptorCategory int
   120  
   121  	// Value represents the magnitute of the flow.
   122  	// It must be greater than or equal to zero.
   123  	Value float64
   124  
   125  	// Group specifies the group that a flow belongs
   126  	// to. It is used in assigning styles to groups
   127  	// and creating legends.
   128  	Group string
   129  }
   130  
   131  // NewSankey creates a new Sankey diagram with the specified
   132  // flows and stocks.
   133  func NewSankey(flows ...Flow) (*Sankey, error) {
   134  	var s Sankey
   135  
   136  	s.stocks = make(map[int]map[string]*stock)
   137  
   138  	s.flows = flows
   139  	for i, f := range flows {
   140  		// Here we make sure the stock categories are in the proper order.
   141  		if f.SourceCategory >= f.ReceptorCategory {
   142  			return nil, fmt.Errorf("plotter: Flow %d SourceCategory (%d) >= ReceptorCategory (%d)", i, f.SourceCategory, f.ReceptorCategory)
   143  		}
   144  		if f.Value < 0 {
   145  			return nil, fmt.Errorf("plotter: Flow %d value (%g) < 0", i, f.Value)
   146  		}
   147  
   148  		// Here we initialize the stock holders.
   149  		if _, ok := s.stocks[f.SourceCategory]; !ok {
   150  			s.stocks[f.SourceCategory] = make(map[string]*stock)
   151  		}
   152  		if _, ok := s.stocks[f.ReceptorCategory]; !ok {
   153  			s.stocks[f.ReceptorCategory] = make(map[string]*stock)
   154  		}
   155  
   156  		// Here we figure out the plotting order of the stocks.
   157  		if _, ok := s.stocks[f.SourceCategory][f.SourceLabel]; !ok {
   158  			s.stocks[f.SourceCategory][f.SourceLabel] = &stock{
   159  				order:    len(s.stocks[f.SourceCategory]),
   160  				label:    f.SourceLabel,
   161  				category: f.SourceCategory,
   162  			}
   163  		}
   164  		if _, ok := s.stocks[f.ReceptorCategory][f.ReceptorLabel]; !ok {
   165  			s.stocks[f.ReceptorCategory][f.ReceptorLabel] = &stock{
   166  				order:    len(s.stocks[f.ReceptorCategory]),
   167  				label:    f.ReceptorLabel,
   168  				category: f.ReceptorCategory,
   169  			}
   170  		}
   171  
   172  		// Here we add the current value to the total value of the stocks
   173  		s.stocks[f.SourceCategory][f.SourceLabel].sourceValue += f.Value
   174  		s.stocks[f.ReceptorCategory][f.ReceptorLabel].receptorValue += f.Value
   175  	}
   176  
   177  	s.LineStyle = DefaultLineStyle
   178  
   179  	s.TextStyle = text.Style{
   180  		Font:     font.From(DefaultFont, DefaultFontSize),
   181  		Rotation: math.Pi / 2,
   182  		XAlign:   draw.XCenter,
   183  		YAlign:   draw.YCenter,
   184  		Handler:  plot.DefaultTextHandler,
   185  	}
   186  	s.StockBarWidth = s.TextStyle.FontExtents().Height * 1.15
   187  
   188  	s.FlowStyle = func(_ string) (color.Color, draw.LineStyle) {
   189  		return s.Color, s.LineStyle
   190  	}
   191  
   192  	s.StockStyle = func(label string, category int) (string, text.Style, vg.Length, vg.Length, color.Color, draw.LineStyle) {
   193  		return label, s.TextStyle, 0, 0, s.Color, s.LineStyle
   194  	}
   195  
   196  	stocks := s.stockList()
   197  	s.setStockRange(&stocks)
   198  
   199  	return &s, nil
   200  }
   201  
   202  // Plot implements the plot.Plotter interface.
   203  func (s *Sankey) Plot(c draw.Canvas, plt *plot.Plot) {
   204  	trCat, trVal := plt.Transforms(&c)
   205  
   206  	// sourceFlowPlaceholder and receptorFlowPlaceholder track
   207  	// the current plotting location during
   208  	// the plotting process.
   209  	sourceFlowPlaceholder := make(map[*stock]float64, len(s.flows))
   210  	receptorFlowPlaceholder := make(map[*stock]float64, len(s.flows))
   211  
   212  	// Here we draw the flows.
   213  	for _, f := range s.flows {
   214  		startStock := s.stocks[f.SourceCategory][f.SourceLabel]
   215  		endStock := s.stocks[f.ReceptorCategory][f.ReceptorLabel]
   216  		catStart := trCat(float64(f.SourceCategory)) + s.StockBarWidth/2
   217  		catEnd := trCat(float64(f.ReceptorCategory)) - s.StockBarWidth/2
   218  		valStartLow := trVal(startStock.min + sourceFlowPlaceholder[startStock])
   219  		valEndLow := trVal(endStock.min + receptorFlowPlaceholder[endStock])
   220  		valStartHigh := trVal(startStock.min + sourceFlowPlaceholder[startStock] + f.Value)
   221  		valEndHigh := trVal(endStock.min + receptorFlowPlaceholder[endStock] + f.Value)
   222  		sourceFlowPlaceholder[startStock] += f.Value
   223  		receptorFlowPlaceholder[endStock] += f.Value
   224  
   225  		ptsLow := s.bezier(
   226  			vg.Point{X: catStart, Y: valStartLow},
   227  			vg.Point{X: catEnd, Y: valEndLow},
   228  		)
   229  		ptsHigh := s.bezier(
   230  			vg.Point{X: catEnd, Y: valEndHigh},
   231  			vg.Point{X: catStart, Y: valStartHigh},
   232  		)
   233  
   234  		color, lineStyle := s.FlowStyle(f.Group)
   235  
   236  		// Here we fill the flow polygons.
   237  		if color != nil {
   238  			poly := c.ClipPolygonX(append(ptsLow, ptsHigh...))
   239  			c.FillPolygon(color, poly)
   240  		}
   241  
   242  		// Here we draw the flow edges.
   243  		outline := c.ClipLinesX(ptsLow)
   244  		c.StrokeLines(lineStyle, outline...)
   245  		outline = c.ClipLinesX(ptsHigh)
   246  		c.StrokeLines(lineStyle, outline...)
   247  	}
   248  
   249  	// Here we draw the stocks.
   250  	for _, stk := range s.stockList() {
   251  		catLoc := trCat(float64(stk.category))
   252  		if !c.ContainsX(catLoc) {
   253  			continue
   254  		}
   255  		catMin, catMax := catLoc-s.StockBarWidth/2, catLoc+s.StockBarWidth/2
   256  		valMin, valMax := trVal(stk.min), trVal(stk.max)
   257  
   258  		label, textStyle, xOff, yOff, color, lineStyle := s.StockStyle(stk.label, stk.category)
   259  
   260  		// Here we fill the stock bars.
   261  		pts := []vg.Point{
   262  			{X: catMin, Y: valMin},
   263  			{X: catMin, Y: valMax},
   264  			{X: catMax, Y: valMax},
   265  			{X: catMax, Y: valMin},
   266  		}
   267  		if color != nil {
   268  			c.FillPolygon(color, pts) // poly)
   269  		}
   270  		txtPt := vg.Point{X: (catMin+catMax)/2 + xOff, Y: (valMin+valMax)/2 + yOff}
   271  		c.FillText(textStyle, txtPt, label)
   272  
   273  		// Here we draw the bottom edge.
   274  		pts = []vg.Point{
   275  			{X: catMin, Y: valMin},
   276  			{X: catMax, Y: valMin},
   277  		}
   278  		c.StrokeLines(lineStyle, pts)
   279  
   280  		// Here we draw the top edge plus vertical edges where there are
   281  		// no flows connected.
   282  		pts = []vg.Point{
   283  			{X: catMin, Y: valMax},
   284  			{X: catMax, Y: valMax},
   285  		}
   286  		if stk.receptorValue < stk.sourceValue {
   287  			y := trVal(stk.max - (stk.sourceValue - stk.receptorValue))
   288  			pts = append([]vg.Point{{X: catMin, Y: y}}, pts...)
   289  		} else if stk.sourceValue < stk.receptorValue {
   290  			y := trVal(stk.max - (stk.receptorValue - stk.sourceValue))
   291  			pts = append(pts, vg.Point{X: catMax, Y: y})
   292  		}
   293  		c.StrokeLines(lineStyle, pts)
   294  	}
   295  }
   296  
   297  // stockList returns a sorted list of the stocks in the diagram.
   298  func (s *Sankey) stockList() []*stock {
   299  	var stocks []*stock
   300  	for _, ss := range s.stocks {
   301  		for _, sss := range ss {
   302  			stocks = append(stocks, sss)
   303  		}
   304  	}
   305  	sort.Sort(stockSorter(stocks))
   306  	return stocks
   307  }
   308  
   309  // stockSorter is a wrapper for a list of *stocks that implements
   310  // sort.Interface.
   311  type stockSorter []*stock
   312  
   313  func (s stockSorter) Len() int      { return len(s) }
   314  func (s stockSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
   315  func (s stockSorter) Less(i, j int) bool {
   316  	if s[i].category != s[j].category {
   317  		return s[i].category < s[j].category
   318  	}
   319  	if s[i].order != s[j].order {
   320  		return s[i].order < s[j].order
   321  	}
   322  	panic(fmt.Errorf("plotter: can't sort stocks:\n%+v\n%+v", s[i], s[j]))
   323  }
   324  
   325  // setStockRange sets the minimum and maximum values of the stock plotting locations.
   326  func (s *Sankey) setStockRange(stocks *[]*stock) {
   327  	var cat int
   328  	var min float64
   329  	for _, stk := range *stocks {
   330  		if stk.category != cat {
   331  			min = 0
   332  		}
   333  		cat = stk.category
   334  		stk.min = min
   335  		if stk.sourceValue > stk.receptorValue {
   336  			stk.max = stk.min + stk.sourceValue
   337  		} else {
   338  			stk.max = stk.min + stk.receptorValue
   339  		}
   340  		min = stk.max
   341  	}
   342  }
   343  
   344  // bezier creates a bezier curve between the begin and end points.
   345  func (s *Sankey) bezier(begin, end vg.Point) []vg.Point {
   346  	// directionOffsetFrac is the fraction of the distance between begin.X and
   347  	// end.X for the bezier control points.
   348  	const directionOffsetFrac = 0.3
   349  	inPts := []vg.Point{
   350  		begin,
   351  		{X: begin.X + (end.X-begin.X)*directionOffsetFrac, Y: begin.Y},
   352  		{X: begin.X + (end.X-begin.X)*(1-directionOffsetFrac), Y: end.Y},
   353  		end,
   354  	}
   355  	curve := bezier.New(inPts...)
   356  
   357  	// nPoints is the number of points for bezier interpolation.
   358  	const nPoints = 20
   359  	outPts := make([]vg.Point, nPoints)
   360  	curve.Curve(outPts)
   361  	return outPts
   362  }
   363  
   364  // DataRange implements the plot.DataRanger interface.
   365  func (s *Sankey) DataRange() (xmin, xmax, ymin, ymax float64) {
   366  	catMin := math.Inf(1)
   367  	catMax := math.Inf(-1)
   368  	for cat := range s.stocks {
   369  		c := float64(cat)
   370  		catMin = math.Min(catMin, c)
   371  		catMax = math.Max(catMax, c)
   372  	}
   373  
   374  	stocks := s.stockList()
   375  	valMin := math.Inf(1)
   376  	valMax := math.Inf(-1)
   377  	for _, stk := range stocks {
   378  		valMin = math.Min(valMin, stk.min)
   379  		valMax = math.Max(valMax, stk.max)
   380  	}
   381  	return catMin, catMax, valMin, valMax
   382  }
   383  
   384  // GlyphBoxes implements the GlyphBoxer interface.
   385  func (s *Sankey) GlyphBoxes(plt *plot.Plot) []plot.GlyphBox {
   386  	stocks := s.stockList()
   387  	boxes := make([]plot.GlyphBox, 0, len(s.flows)+len(stocks))
   388  
   389  	for _, stk := range stocks {
   390  		b1 := plot.GlyphBox{
   391  			X: plt.X.Norm(float64(stk.category)),
   392  			Y: plt.Y.Norm((stk.min + stk.max) / 2),
   393  			Rectangle: vg.Rectangle{
   394  				Min: vg.Point{X: -s.StockBarWidth / 2},
   395  				Max: vg.Point{X: s.StockBarWidth / 2},
   396  			},
   397  		}
   398  		label, textStyle, xOff, yOff, _, _ := s.StockStyle(stk.label, stk.category)
   399  		rect := textStyle.Rectangle(label)
   400  		rect.Min.X += xOff
   401  		rect.Max.X += xOff
   402  		rect.Min.Y += yOff
   403  		rect.Max.Y += yOff
   404  		b2 := plot.GlyphBox{
   405  			X:         plt.X.Norm(float64(stk.category)),
   406  			Y:         plt.Y.Norm((stk.min + stk.max) / 2),
   407  			Rectangle: rect,
   408  		}
   409  		boxes = append(boxes, b1, b2)
   410  	}
   411  	return boxes
   412  }
   413  
   414  // Thumbnailers creates a group of objects that can be used to
   415  // add legend entries for the different flow groups in this
   416  // diagram, as well as the flow group labels that correspond to them.
   417  func (s *Sankey) Thumbnailers() (legendLabels []string, thumbnailers []plot.Thumbnailer) {
   418  	type empty struct{}
   419  	flowGroups := make(map[string]empty)
   420  	for _, f := range s.flows {
   421  		flowGroups[f.Group] = empty{}
   422  	}
   423  	legendLabels = make([]string, len(flowGroups))
   424  	thumbnailers = make([]plot.Thumbnailer, len(flowGroups))
   425  	i := 0
   426  	for g := range flowGroups {
   427  		legendLabels[i] = g
   428  		i++
   429  	}
   430  	sort.Strings(legendLabels)
   431  
   432  	for i, g := range legendLabels {
   433  		var thmb sankeyFlowThumbnailer
   434  		thmb.Color, thmb.LineStyle = s.FlowStyle(g)
   435  		thumbnailers[i] = plot.Thumbnailer(thmb)
   436  	}
   437  	return
   438  }
   439  
   440  // sankeyFlowThumbnailer implements the Thumbnailer interface
   441  // for Sankey flow groups.
   442  type sankeyFlowThumbnailer struct {
   443  	draw.LineStyle
   444  	color.Color
   445  }
   446  
   447  // Thumbnail fulfills the plot.Thumbnailer interface.
   448  func (t sankeyFlowThumbnailer) Thumbnail(c *draw.Canvas) {
   449  	// Here we draw the fill.
   450  	pts := []vg.Point{
   451  		{X: c.Min.X, Y: c.Min.Y},
   452  		{X: c.Min.X, Y: c.Max.Y},
   453  		{X: c.Max.X, Y: c.Max.Y},
   454  		{X: c.Max.X, Y: c.Min.Y},
   455  	}
   456  	poly := c.ClipPolygonY(pts)
   457  	c.FillPolygon(t.Color, poly)
   458  
   459  	// Here we draw the upper border.
   460  	pts = []vg.Point{
   461  		{X: c.Min.X, Y: c.Max.Y},
   462  		{X: c.Max.X, Y: c.Max.Y},
   463  	}
   464  	outline := c.ClipLinesY(pts)
   465  	c.StrokeLines(t.LineStyle, outline...)
   466  
   467  	// Here we draw the lower border.
   468  	pts = []vg.Point{
   469  		{X: c.Min.X, Y: c.Min.Y},
   470  		{X: c.Max.X, Y: c.Min.Y},
   471  	}
   472  	outline = c.ClipLinesY(pts)
   473  	c.StrokeLines(t.LineStyle, outline...)
   474  }
   475  

View as plain text