1
2
3
4
5 package plotter
6
7 import (
8 "fmt"
9 "image/color"
10 "math"
11 "sort"
12
13 "gonum.org/v1/plot"
14 "gonum.org/v1/plot/font"
15 "gonum.org/v1/plot/text"
16 "gonum.org/v1/plot/tools/bezier"
17 "gonum.org/v1/plot/vg"
18 "gonum.org/v1/plot/vg/draw"
19 )
20
21
22
23
// Sankey draws a Sankey diagram. Stocks are drawn as bars placed along the
// x axis according to their category, and flows are drawn as bands that
// connect a source stock to a receptor stock. Create one with NewSankey,
// which also installs the default FlowStyle and StockStyle functions.
type Sankey struct {
	// Color is the default fill color for flows and stocks; it is what
	// the FlowStyle and StockStyle functions installed by NewSankey
	// return. Plot skips filling when the style color is nil.
	Color color.Color

	// StockBarWidth is the width of each stock bar. NewSankey sets it to
	// 1.15 times the height of TextStyle's font.
	StockBarWidth vg.Length

	// LineStyle is the default outline style for flows and stocks; it is
	// what the FlowStyle and StockStyle functions installed by NewSankey
	// return.
	LineStyle draw.LineStyle

	// TextStyle is the default style for stock labels, returned by the
	// StockStyle function installed by NewSankey. NewSankey sets it to
	// the default font, rotated by π/2 and centered on the label point.
	TextStyle text.Style

	// flows holds the flows drawn by Plot, in the order given to NewSankey.
	flows []Flow

	// FlowStyle returns the fill color and outline style used for flows
	// in the given group. It is also used for the legend thumbnails. The
	// function installed by NewSankey returns Color and LineStyle.
	FlowStyle func(group string) (color.Color, draw.LineStyle)

	// StockStyle returns, for the stock with the given label and
	// category: the label text to draw, its text style, the x and y
	// offsets of the label from the center of the stock bar, the bar's
	// fill color, and its outline style. The function installed by
	// NewSankey returns the stock's own label, TextStyle, zero offsets,
	// Color and LineStyle.
	StockStyle func(label string, category int) (lbl string, ts text.Style, xOff, yOff vg.Length, c color.Color, ls draw.LineStyle)

	// stocks holds every stock, indexed first by category and then by
	// label.
	stocks map[int]map[string]*stock
}
72
73
74
75 func (s *Sankey) StockRange(label string, category int) (min, max float64, err error) {
76 stk, ok := s.stocks[category][label]
77 if !ok {
78 return 0, 0, fmt.Errorf("plotter: sankey diagram does not contain stock with label=%s and category=%d", label, category)
79 }
80 return stk.min, stk.max, nil
81 }
82
83
// stock is a node of the diagram: a bar that flows leave from and/or
// arrive at.
type stock struct {
	// receptorValue is the total Value of the flows arriving at this
	// stock, and sourceValue is the total Value of the flows leaving it.
	receptorValue, sourceValue float64

	// label and category identify the stock. The category also sets the
	// stock's position along the x axis.
	label    string
	category int

	// order is the number of stocks that already existed in this
	// category when this stock was first created from a flow. It fixes
	// the stock's stacking position within the category.
	order int

	// min is the value-axis position of the bottom of the stock bar.
	min float64

	// max is the value-axis position of the top of the stock bar.
	max float64
}
106
107
// Flow represents an amount flowing from one stock to another in a Sankey
// diagram.
type Flow struct {
	// SourceLabel and ReceptorLabel are the labels of the stocks that the
	// flow leaves from and arrives at. Each label, together with the
	// matching category, identifies a stock.
	SourceLabel, ReceptorLabel string

	// SourceCategory and ReceptorCategory are the categories of the
	// source and receptor stocks. Categories position stocks along the x
	// axis. NewSankey requires SourceCategory to be less than
	// ReceptorCategory.
	SourceCategory, ReceptorCategory int

	// Value is the magnitude of the flow. NewSankey rejects negative
	// values.
	Value float64

	// Group is passed to Sankey.FlowStyle to choose the flow's appearance.
	// Each distinct Group gets its own legend entry.
	Group string
}
130
131
132
133 func NewSankey(flows ...Flow) (*Sankey, error) {
134 var s Sankey
135
136 s.stocks = make(map[int]map[string]*stock)
137
138 s.flows = flows
139 for i, f := range flows {
140
141 if f.SourceCategory >= f.ReceptorCategory {
142 return nil, fmt.Errorf("plotter: Flow %d SourceCategory (%d) >= ReceptorCategory (%d)", i, f.SourceCategory, f.ReceptorCategory)
143 }
144 if f.Value < 0 {
145 return nil, fmt.Errorf("plotter: Flow %d value (%g) < 0", i, f.Value)
146 }
147
148
149 if _, ok := s.stocks[f.SourceCategory]; !ok {
150 s.stocks[f.SourceCategory] = make(map[string]*stock)
151 }
152 if _, ok := s.stocks[f.ReceptorCategory]; !ok {
153 s.stocks[f.ReceptorCategory] = make(map[string]*stock)
154 }
155
156
157 if _, ok := s.stocks[f.SourceCategory][f.SourceLabel]; !ok {
158 s.stocks[f.SourceCategory][f.SourceLabel] = &stock{
159 order: len(s.stocks[f.SourceCategory]),
160 label: f.SourceLabel,
161 category: f.SourceCategory,
162 }
163 }
164 if _, ok := s.stocks[f.ReceptorCategory][f.ReceptorLabel]; !ok {
165 s.stocks[f.ReceptorCategory][f.ReceptorLabel] = &stock{
166 order: len(s.stocks[f.ReceptorCategory]),
167 label: f.ReceptorLabel,
168 category: f.ReceptorCategory,
169 }
170 }
171
172
173 s.stocks[f.SourceCategory][f.SourceLabel].sourceValue += f.Value
174 s.stocks[f.ReceptorCategory][f.ReceptorLabel].receptorValue += f.Value
175 }
176
177 s.LineStyle = DefaultLineStyle
178
179 s.TextStyle = text.Style{
180 Font: font.From(DefaultFont, DefaultFontSize),
181 Rotation: math.Pi / 2,
182 XAlign: draw.XCenter,
183 YAlign: draw.YCenter,
184 Handler: plot.DefaultTextHandler,
185 }
186 s.StockBarWidth = s.TextStyle.FontExtents().Height * 1.15
187
188 s.FlowStyle = func(_ string) (color.Color, draw.LineStyle) {
189 return s.Color, s.LineStyle
190 }
191
192 s.StockStyle = func(label string, category int) (string, text.Style, vg.Length, vg.Length, color.Color, draw.LineStyle) {
193 return label, s.TextStyle, 0, 0, s.Color, s.LineStyle
194 }
195
196 stocks := s.stockList()
197 s.setStockRange(&stocks)
198
199 return &s, nil
200 }
201
202
203 func (s *Sankey) Plot(c draw.Canvas, plt *plot.Plot) {
204 trCat, trVal := plt.Transforms(&c)
205
206
207
208
209 sourceFlowPlaceholder := make(map[*stock]float64, len(s.flows))
210 receptorFlowPlaceholder := make(map[*stock]float64, len(s.flows))
211
212
213 for _, f := range s.flows {
214 startStock := s.stocks[f.SourceCategory][f.SourceLabel]
215 endStock := s.stocks[f.ReceptorCategory][f.ReceptorLabel]
216 catStart := trCat(float64(f.SourceCategory)) + s.StockBarWidth/2
217 catEnd := trCat(float64(f.ReceptorCategory)) - s.StockBarWidth/2
218 valStartLow := trVal(startStock.min + sourceFlowPlaceholder[startStock])
219 valEndLow := trVal(endStock.min + receptorFlowPlaceholder[endStock])
220 valStartHigh := trVal(startStock.min + sourceFlowPlaceholder[startStock] + f.Value)
221 valEndHigh := trVal(endStock.min + receptorFlowPlaceholder[endStock] + f.Value)
222 sourceFlowPlaceholder[startStock] += f.Value
223 receptorFlowPlaceholder[endStock] += f.Value
224
225 ptsLow := s.bezier(
226 vg.Point{X: catStart, Y: valStartLow},
227 vg.Point{X: catEnd, Y: valEndLow},
228 )
229 ptsHigh := s.bezier(
230 vg.Point{X: catEnd, Y: valEndHigh},
231 vg.Point{X: catStart, Y: valStartHigh},
232 )
233
234 color, lineStyle := s.FlowStyle(f.Group)
235
236
237 if color != nil {
238 poly := c.ClipPolygonX(append(ptsLow, ptsHigh...))
239 c.FillPolygon(color, poly)
240 }
241
242
243 outline := c.ClipLinesX(ptsLow)
244 c.StrokeLines(lineStyle, outline...)
245 outline = c.ClipLinesX(ptsHigh)
246 c.StrokeLines(lineStyle, outline...)
247 }
248
249
250 for _, stk := range s.stockList() {
251 catLoc := trCat(float64(stk.category))
252 if !c.ContainsX(catLoc) {
253 continue
254 }
255 catMin, catMax := catLoc-s.StockBarWidth/2, catLoc+s.StockBarWidth/2
256 valMin, valMax := trVal(stk.min), trVal(stk.max)
257
258 label, textStyle, xOff, yOff, color, lineStyle := s.StockStyle(stk.label, stk.category)
259
260
261 pts := []vg.Point{
262 {X: catMin, Y: valMin},
263 {X: catMin, Y: valMax},
264 {X: catMax, Y: valMax},
265 {X: catMax, Y: valMin},
266 }
267 if color != nil {
268 c.FillPolygon(color, pts)
269 }
270 txtPt := vg.Point{X: (catMin+catMax)/2 + xOff, Y: (valMin+valMax)/2 + yOff}
271 c.FillText(textStyle, txtPt, label)
272
273
274 pts = []vg.Point{
275 {X: catMin, Y: valMin},
276 {X: catMax, Y: valMin},
277 }
278 c.StrokeLines(lineStyle, pts)
279
280
281
282 pts = []vg.Point{
283 {X: catMin, Y: valMax},
284 {X: catMax, Y: valMax},
285 }
286 if stk.receptorValue < stk.sourceValue {
287 y := trVal(stk.max - (stk.sourceValue - stk.receptorValue))
288 pts = append([]vg.Point{{X: catMin, Y: y}}, pts...)
289 } else if stk.sourceValue < stk.receptorValue {
290 y := trVal(stk.max - (stk.receptorValue - stk.sourceValue))
291 pts = append(pts, vg.Point{X: catMax, Y: y})
292 }
293 c.StrokeLines(lineStyle, pts)
294 }
295 }
296
297
298 func (s *Sankey) stockList() []*stock {
299 var stocks []*stock
300 for _, ss := range s.stocks {
301 for _, sss := range ss {
302 stocks = append(stocks, sss)
303 }
304 }
305 sort.Sort(stockSorter(stocks))
306 return stocks
307 }
308
309
310
311 type stockSorter []*stock
312
313 func (s stockSorter) Len() int { return len(s) }
314 func (s stockSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
315 func (s stockSorter) Less(i, j int) bool {
316 if s[i].category != s[j].category {
317 return s[i].category < s[j].category
318 }
319 if s[i].order != s[j].order {
320 return s[i].order < s[j].order
321 }
322 panic(fmt.Errorf("plotter: can't sort stocks:\n%+v\n%+v", s[i], s[j]))
323 }
324
325
326 func (s *Sankey) setStockRange(stocks *[]*stock) {
327 var cat int
328 var min float64
329 for _, stk := range *stocks {
330 if stk.category != cat {
331 min = 0
332 }
333 cat = stk.category
334 stk.min = min
335 if stk.sourceValue > stk.receptorValue {
336 stk.max = stk.min + stk.sourceValue
337 } else {
338 stk.max = stk.min + stk.receptorValue
339 }
340 min = stk.max
341 }
342 }
343
344
345 func (s *Sankey) bezier(begin, end vg.Point) []vg.Point {
346
347
348 const directionOffsetFrac = 0.3
349 inPts := []vg.Point{
350 begin,
351 {X: begin.X + (end.X-begin.X)*directionOffsetFrac, Y: begin.Y},
352 {X: begin.X + (end.X-begin.X)*(1-directionOffsetFrac), Y: end.Y},
353 end,
354 }
355 curve := bezier.New(inPts...)
356
357
358 const nPoints = 20
359 outPts := make([]vg.Point, nPoints)
360 curve.Curve(outPts)
361 return outPts
362 }
363
364
365 func (s *Sankey) DataRange() (xmin, xmax, ymin, ymax float64) {
366 catMin := math.Inf(1)
367 catMax := math.Inf(-1)
368 for cat := range s.stocks {
369 c := float64(cat)
370 catMin = math.Min(catMin, c)
371 catMax = math.Max(catMax, c)
372 }
373
374 stocks := s.stockList()
375 valMin := math.Inf(1)
376 valMax := math.Inf(-1)
377 for _, stk := range stocks {
378 valMin = math.Min(valMin, stk.min)
379 valMax = math.Max(valMax, stk.max)
380 }
381 return catMin, catMax, valMin, valMax
382 }
383
384
385 func (s *Sankey) GlyphBoxes(plt *plot.Plot) []plot.GlyphBox {
386 stocks := s.stockList()
387 boxes := make([]plot.GlyphBox, 0, len(s.flows)+len(stocks))
388
389 for _, stk := range stocks {
390 b1 := plot.GlyphBox{
391 X: plt.X.Norm(float64(stk.category)),
392 Y: plt.Y.Norm((stk.min + stk.max) / 2),
393 Rectangle: vg.Rectangle{
394 Min: vg.Point{X: -s.StockBarWidth / 2},
395 Max: vg.Point{X: s.StockBarWidth / 2},
396 },
397 }
398 label, textStyle, xOff, yOff, _, _ := s.StockStyle(stk.label, stk.category)
399 rect := textStyle.Rectangle(label)
400 rect.Min.X += xOff
401 rect.Max.X += xOff
402 rect.Min.Y += yOff
403 rect.Max.Y += yOff
404 b2 := plot.GlyphBox{
405 X: plt.X.Norm(float64(stk.category)),
406 Y: plt.Y.Norm((stk.min + stk.max) / 2),
407 Rectangle: rect,
408 }
409 boxes = append(boxes, b1, b2)
410 }
411 return boxes
412 }
413
414
415
416
417 func (s *Sankey) Thumbnailers() (legendLabels []string, thumbnailers []plot.Thumbnailer) {
418 type empty struct{}
419 flowGroups := make(map[string]empty)
420 for _, f := range s.flows {
421 flowGroups[f.Group] = empty{}
422 }
423 legendLabels = make([]string, len(flowGroups))
424 thumbnailers = make([]plot.Thumbnailer, len(flowGroups))
425 i := 0
426 for g := range flowGroups {
427 legendLabels[i] = g
428 i++
429 }
430 sort.Strings(legendLabels)
431
432 for i, g := range legendLabels {
433 var thmb sankeyFlowThumbnailer
434 thmb.Color, thmb.LineStyle = s.FlowStyle(g)
435 thumbnailers[i] = plot.Thumbnailer(thmb)
436 }
437 return
438 }
439
440
441
// sankeyFlowThumbnailer draws the legend entry for one flow group, using
// the color and line style that Sankey.FlowStyle returned for the group.
type sankeyFlowThumbnailer struct {
	// LineStyle strokes the top and bottom edges of the thumbnail.
	draw.LineStyle
	// Color fills the thumbnail.
	color.Color
}
446
447
448 func (t sankeyFlowThumbnailer) Thumbnail(c *draw.Canvas) {
449
450 pts := []vg.Point{
451 {X: c.Min.X, Y: c.Min.Y},
452 {X: c.Min.X, Y: c.Max.Y},
453 {X: c.Max.X, Y: c.Max.Y},
454 {X: c.Max.X, Y: c.Min.Y},
455 }
456 poly := c.ClipPolygonY(pts)
457 c.FillPolygon(t.Color, poly)
458
459
460 pts = []vg.Point{
461 {X: c.Min.X, Y: c.Max.Y},
462 {X: c.Max.X, Y: c.Max.Y},
463 }
464 outline := c.ClipLinesY(pts)
465 c.StrokeLines(t.LineStyle, outline...)
466
467
468 pts = []vg.Point{
469 {X: c.Min.X, Y: c.Min.Y},
470 {X: c.Max.X, Y: c.Min.Y},
471 }
472 outline = c.ClipLinesY(pts)
473 c.StrokeLines(t.LineStyle, outline...)
474 }
475
View as plain text