...

Source file src/oss.terrastruct.com/d2/lib/textmeasure/markdown.go

Documentation: oss.terrastruct.com/d2/lib/textmeasure

     1  package textmeasure
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"math"
     7  	"strings"
     8  
     9  	"github.com/PuerkitoBio/goquery"
    10  	"github.com/yuin/goldmark"
    11  	"github.com/yuin/goldmark/extension"
    12  	goldmarkHtml "github.com/yuin/goldmark/renderer/html"
    13  	"golang.org/x/net/html"
    14  
    15  	"oss.terrastruct.com/util-go/go2"
    16  
    17  	"oss.terrastruct.com/d2/d2renderers/d2fonts"
    18  )
    19  
// markdownRenderer is the goldmark instance used by RenderMarkdown.
// It is assigned once in init.
var markdownRenderer goldmark.Markdown
    21  
// These are CSS values from github-markdown.css so we can accurately compute
// the rendered dimensions.
//
// Constants whose names end in _em are scale factors rather than absolute
// values: all of them except FontSize_pre_code_em are multiplied by the current
// font size where they are used, while FontSize_pre_code_em scales the measured
// width and height of code text. The remaining margin/padding/border constants
// are added to dimensions as-is.
const (
	MarkdownFontSize   = d2fonts.FONT_SIZE_M
	// MarkdownLineHeight is installed as the ruler's line height factor while
	// markdown is being measured.
	MarkdownLineHeight = 1.5

	PaddingLeft_ul_ol_em = 2.
	MarginBottom_ul      = 16.

	MarginTop_li_p  = 16.
	MarginTop_li_em = 0.25
	MarginBottom_p  = 16.

	// LineHeight_h replaces the ruler's line height factor inside h1-h6.
	LineHeight_h           = 1.25
	MarginTop_h            = 24
	MarginBottom_h         = 16
	PaddingBottom_h1_h2_em = 0.3
	BorderBottom_h1_h2     = 1

	Height_hr_em       = 0.25
	MarginTopBottom_hr = 24

	Padding_pre          = 16
	MarginBottom_pre     = 16
	// LineHeight_pre replaces the ruler's line height factor for text directly
	// inside pre.
	LineHeight_pre       = 1.45
	FontSize_pre_code_em = 0.85

	PaddingTopBottom_code_em = 0.2
	PaddingLeftRight_code_em = 0.4

	PaddingLR_blockquote_em  = 1.
	MarginBottom_blockquote  = 16
	BorderLeft_blockquote_em = 0.25

	// Heading font sizes relative to the base font size; see HeaderToFontSize.
	h1_em = 2.
	h2_em = 1.5
	h3_em = 1.25
	h4_em = 1.
	h5_em = 0.875
	h6_em = 0.85
)
    62  
    63  func HeaderToFontSize(baseFontSize int, header string) int {
    64  	switch header {
    65  	case "h1":
    66  		return int(h1_em * float64(baseFontSize))
    67  	case "h2":
    68  		return int(h2_em * float64(baseFontSize))
    69  	case "h3":
    70  		return int(h3_em * float64(baseFontSize))
    71  	case "h4":
    72  		return int(h4_em * float64(baseFontSize))
    73  	case "h5":
    74  		return int(h5_em * float64(baseFontSize))
    75  	case "h6":
    76  		return int(h6_em * float64(baseFontSize))
    77  	}
    78  	return 0
    79  }
    80  
    81  func RenderMarkdown(m string) (string, error) {
    82  	var output bytes.Buffer
    83  	if err := markdownRenderer.Convert([]byte(m), &output); err != nil {
    84  		return "", err
    85  	}
    86  	return output.String(), nil
    87  }
    88  
    89  func init() {
    90  	markdownRenderer = goldmark.New(
    91  		goldmark.WithRendererOptions(
    92  			goldmarkHtml.WithUnsafe(),
    93  			goldmarkHtml.WithXHTML(),
    94  		),
    95  		goldmark.WithExtensions(
    96  			extension.Strikethrough,
    97  		),
    98  	)
    99  }
   100  
   101  func MeasureMarkdown(mdText string, ruler *Ruler, fontFamily *d2fonts.FontFamily, fontSize int) (width, height int, err error) {
   102  	render, err := RenderMarkdown(mdText)
   103  	if err != nil {
   104  		return width, height, err
   105  	}
   106  
   107  	doc, err := goquery.NewDocumentFromReader(strings.NewReader(render))
   108  	if err != nil {
   109  		return width, height, err
   110  	}
   111  
   112  	{
   113  		originalLineHeight := ruler.LineHeightFactor
   114  		ruler.boundsWithDot = true
   115  		ruler.LineHeightFactor = MarkdownLineHeight
   116  		defer func() {
   117  			ruler.LineHeightFactor = originalLineHeight
   118  			ruler.boundsWithDot = false
   119  		}()
   120  	}
   121  
   122  	// TODO consider setting a max width + (manual) text wrapping
   123  	bodyNode := doc.Find("body").First().Nodes[0]
   124  	bodyAttrs := ruler.measureNode(0, bodyNode, fontFamily, fontSize, d2fonts.FONT_STYLE_REGULAR)
   125  
   126  	return int(math.Ceil(bodyAttrs.width)), int(math.Ceil(bodyAttrs.height)), nil
   127  }
   128  
   129  func hasPrev(n *html.Node) bool {
   130  	if n.PrevSibling == nil {
   131  		return false
   132  	}
   133  	if strings.TrimSpace(n.PrevSibling.Data) == "" {
   134  		return hasPrev(n.PrevSibling)
   135  	}
   136  	return true
   137  }
   138  
   139  func hasNext(n *html.Node) bool {
   140  	if n.NextSibling == nil {
   141  		return false
   142  	}
   143  	// skip over empty text nodes
   144  	if strings.TrimSpace(n.NextSibling.Data) == "" {
   145  		return hasNext(n.NextSibling)
   146  	}
   147  	return true
   148  }
   149  
   150  func getPrev(n *html.Node) *html.Node {
   151  	if n == nil {
   152  		return nil
   153  	}
   154  	if strings.TrimSpace(n.Data) == "" {
   155  		if next := getNext(n.PrevSibling); next != nil {
   156  			return next
   157  		}
   158  	}
   159  	return n
   160  }
   161  
   162  func getNext(n *html.Node) *html.Node {
   163  	if n == nil {
   164  		return nil
   165  	}
   166  	if strings.TrimSpace(n.Data) == "" {
   167  		if next := getNext(n.NextSibling); next != nil {
   168  			return next
   169  		}
   170  	}
   171  	return n
   172  }
   173  
   174  func isBlockElement(elType string) bool {
   175  	switch elType {
   176  	case "blockquote",
   177  		"div",
   178  		"h1", "h2", "h3", "h4", "h5", "h6",
   179  		"hr",
   180  		"li",
   181  		"ol",
   182  		"p",
   183  		"pre",
   184  		"ul":
   185  		return true
   186  	default:
   187  		return false
   188  	}
   189  }
   190  
   191  func hasAncestorElement(n *html.Node, elType string) bool {
   192  	if n.Parent == nil {
   193  		return false
   194  	}
   195  	if n.Parent.Type == html.ElementNode && n.Parent.Data == elType {
   196  		return true
   197  	}
   198  	return hasAncestorElement(n.Parent, elType)
   199  }
   200  
   201  type blockAttrs struct {
   202  	width, height, marginTop, marginBottom float64
   203  }
   204  
   205  func (b *blockAttrs) isNotEmpty() bool {
   206  	return b != nil && *b != blockAttrs{}
   207  }
   208  
// measureNode measures node dimensions to match rendering with styles in
// github-markdown.css.
//
// depth is only used to indent the optional debug trace. fontFamily, fontSize
// and fontStyle describe the font inherited from the caller; element nodes may
// override them for their own subtree (headings, em, b/strong, pre/code). A nil
// fontFamily falls back to d2fonts.SourceSansPro.
//
// The returned blockAttrs carries the node's width and height together with the
// top and bottom margins it asks for. Node types other than text and element
// produce an empty blockAttrs.
//
// For headings and for text directly inside pre, ruler.LineHeightFactor is
// changed for the duration of the call and restored by a deferred function.
func (ruler *Ruler) measureNode(depth int, n *html.Node, fontFamily *d2fonts.FontFamily, fontSize int, fontStyle d2fonts.FontStyle) blockAttrs {
	if fontFamily == nil {
		fontFamily = go2.Pointer(d2fonts.SourceSansPro)
	}
	font := fontFamily.Font(fontSize, fontStyle)

	// tag name of the parent, or "" when the parent is missing or not an element
	var parentElementType string
	if n.Parent != nil && n.Parent.Type == html.ElementNode {
		parentElementType = n.Parent.Data
	}

	// set to true to print one line per measured node, indented by depth
	debugMeasure := false
	var depthStr string
	if debugMeasure {
		if depth == 0 {
			fmt.Println()
		}
		depthStr = "┌"
		for i := 0; i < depth; i++ {
			depthStr += "-"
		}
	}

	switch n.Type {
	case html.TextNode:
		// Text consisting only of newlines, tabs or backspaces contributes nothing.
		// Spaces are not in the cutset, so a space-only node is still processed.
		if strings.Trim(n.Data, "\n\t\b") == "" {
			return blockAttrs{}
		}
		str := n.Data
		isCode := parentElementType == "pre" || parentElementType == "code"
		// width added back for a leading/trailing space stripped below
		spaceWidths := 0.

		if !isCode {
			spaceWidth := ruler.spaceWidth(font)
			// MeasurePrecise will not include leading or trailing whitespace, so we account for it here
			str = strings.ReplaceAll(str, "\n", " ")
			str = strings.ReplaceAll(str, "\t", " ")
			if strings.HasPrefix(str, " ") {
				// consecutive leading/trailing spaces end up rendered as a single space
				str = strings.TrimPrefix(str, " ")
				// the space only counts when there is a non-blank sibling before this text
				if hasPrev(n) {
					spaceWidths += spaceWidth
				}
			}
			if strings.HasSuffix(str, " ") {
				str = strings.TrimSuffix(str, " ")
				// likewise, only when a non-blank sibling follows
				if hasNext(n) {
					spaceWidths += spaceWidth
				}
			}
		}

		if parentElementType == "pre" {
			// text directly inside pre is measured with the pre line height;
			// the previous factor is restored when this call returns
			originalLineHeight := ruler.LineHeightFactor
			ruler.LineHeightFactor = LineHeight_pre
			defer func() {
				ruler.LineHeightFactor = originalLineHeight
			}()
		}
		w, h := ruler.MeasurePrecise(font, str)
		if isCode {
			// code text is measured at the inherited size and then scaled down
			w *= FontSize_pre_code_em
			h *= FontSize_pre_code_em
		} else {
			w = ruler.scaleUnicode(w, font, str)
		}
		if debugMeasure {
			fmt.Printf("%stext(%v,%v)\n", depthStr, w, h)
		}
		return blockAttrs{w + spaceWidths, h, 0, 0}
	case html.ElementNode:
		isCode := false
		// adjust the font (and, for headings, the line height) used for this
		// element's children
		switch n.Data {
		case "h1", "h2", "h3", "h4", "h5", "h6":
			fontSize = HeaderToFontSize(fontSize, n.Data)
			fontStyle = d2fonts.FONT_STYLE_SEMIBOLD
			originalLineHeight := ruler.LineHeightFactor
			ruler.LineHeightFactor = LineHeight_h
			defer func() {
				ruler.LineHeightFactor = originalLineHeight
			}()
		case "em":
			fontStyle = d2fonts.FONT_STYLE_ITALIC
		case "b", "strong":
			fontStyle = d2fonts.FONT_STYLE_BOLD
		case "pre", "code":
			fontFamily = go2.Pointer(d2fonts.SourceCodePro)
			fontStyle = d2fonts.FONT_STYLE_REGULAR
			isCode = true
		}

		block := blockAttrs{}
		// computed after the heading adjustments above, so it reflects this
		// element's own font size and line height factor
		lineHeightPx := float64(fontSize) * ruler.LineHeightFactor

		if n.FirstChild != nil {
			// first/last children after skipping blank nodes at either end;
			// only consulted for the blockquote margin handling below
			first := getNext(n.FirstChild)
			last := getPrev(n.LastChild)

			var blocks []blockAttrs
			var inlineBlock *blockAttrs
			// first create blocks from combined inline elements, then combine all blocks
			// inlineBlock will be non-nil while inline elements are being combined into a block
			endInlineBlock := func() {
				// a non-code inline run is at least one line tall
				if !isCode && inlineBlock.height > 0 && inlineBlock.height < lineHeightPx {
					inlineBlock.height = lineHeightPx
				}
				blocks = append(blocks, *inlineBlock)
				inlineBlock = nil
			}
			for child := n.FirstChild; child != nil; child = child.NextSibling {
				childBlock := ruler.measureNode(depth+1, child, fontFamily, fontSize, fontStyle)

				if child.Type == html.ElementNode && isBlockElement(child.Data) {
					// a block-level child ends any pending inline run and becomes its own block
					if inlineBlock != nil {
						endInlineBlock()
					}
					newBlock := &blockAttrs{}
					newBlock.width = childBlock.width
					newBlock.height = childBlock.height
					// inside a blockquote the first child's top margin and the
					// last child's bottom margin are dropped
					if child == first && n.Data == "blockquote" {
						newBlock.marginTop = 0.
					} else {
						newBlock.marginTop = childBlock.marginTop
					}
					if child == last && n.Data == "blockquote" {
						newBlock.marginBottom = 0.
					} else {
						newBlock.marginBottom = childBlock.marginBottom
					}

					blocks = append(blocks, *newBlock)
				} else if child.Type == html.ElementNode && child.Data == "br" {
					// br ends the current inline run; with no run in progress it
					// adds one empty line of height instead
					if inlineBlock != nil {
						endInlineBlock()
					} else {
						block.height += lineHeightPx
					}
				} else if childBlock.isNotEmpty() {
					if inlineBlock == nil {
						// start inline block with child
						inlineBlock = &childBlock
					} else {
						// stack inline element dimensions horizontally
						inlineBlock.width += childBlock.width
						inlineBlock.height = go2.Max(inlineBlock.height, childBlock.height)

						inlineBlock.marginTop = go2.Max(inlineBlock.marginTop, childBlock.marginTop)
						inlineBlock.marginBottom = go2.Max(inlineBlock.marginBottom, childBlock.marginBottom)
					}
				}
			}
			if inlineBlock != nil {
				endInlineBlock()
			}

			// Stack the blocks vertically. The first block's top margin and the
			// last block's bottom margin are handed up to this element's own
			// margins; between two blocks only the larger of the adjacent
			// bottom/top margins is added to the height.
			var prevMarginBottom float64
			for i, b := range blocks {
				if i == 0 {
					block.marginTop = go2.Max(block.marginTop, b.marginTop)
				} else {
					// prevMarginBottom was already added; add only what the top margin exceeds it by
					marginDiff := b.marginTop - prevMarginBottom
					if marginDiff > 0 {
						block.height += marginDiff
					}
				}
				if i == len(blocks)-1 {
					block.marginBottom = go2.Max(block.marginBottom, b.marginBottom)
				} else {
					block.height += b.marginBottom
					prevMarginBottom = b.marginBottom
				}

				block.height += b.height
				block.width = go2.Max(block.width, b.width)
			}
		}

		// element-specific padding, borders and margins
		switch n.Data {
		case "blockquote":
			// left and right padding plus the left border, all relative to the font size
			block.width += (2*PaddingLR_blockquote_em + BorderLeft_blockquote_em) * float64(fontSize)
			block.marginBottom = go2.Max(block.marginBottom, MarginBottom_blockquote)
		case "p":
			if parentElementType == "li" {
				block.marginTop = go2.Max(block.marginTop, MarginTop_li_p)
			}
			block.marginBottom = go2.Max(block.marginBottom, MarginBottom_p)
		case "h1", "h2", "h3", "h4", "h5", "h6":
			block.marginTop = go2.Max(block.marginTop, MarginTop_h)
			block.marginBottom = go2.Max(block.marginBottom, MarginBottom_h)
			switch n.Data {
			case "h1", "h2":
				// h1 and h2 additionally get bottom padding and a bottom border
				block.height += PaddingBottom_h1_h2_em*float64(fontSize) + BorderBottom_h1_h2
			}
		case "li":
			block.width += PaddingLeft_ul_ol_em * float64(fontSize)
			// only items with a non-blank preceding sibling get the top margin
			if hasPrev(n) {
				block.marginTop = go2.Max(block.marginTop, MarginTop_li_em*float64(fontSize))
			}
		case "ol", "ul":
			// lists nested inside another list carry no margins of their own
			if hasAncestorElement(n, "ul") || hasAncestorElement(n, "ol") {
				block.marginTop = 0
				block.marginBottom = 0
			} else {
				block.marginBottom = go2.Max(block.marginBottom, MarginBottom_ul)
			}
		case "pre":
			block.width += 2 * Padding_pre
			block.height += 2 * Padding_pre
			block.marginBottom = go2.Max(block.marginBottom, MarginBottom_pre)
		case "code":
			// code that is not directly inside pre gets its own padding
			if parentElementType != "pre" {
				block.width += 2 * PaddingLeftRight_code_em * float64(fontSize)
				block.height += 2 * PaddingTopBottom_code_em * float64(fontSize)
			}
		case "hr":
			block.height += Height_hr_em * float64(fontSize)
			block.marginTop = go2.Max(block.marginTop, MarginTopBottom_hr)
			block.marginBottom = go2.Max(block.marginBottom, MarginTopBottom_hr)
		}
		// an element with any height at all is at least one line tall
		if block.height > 0 && block.height < lineHeightPx {
			block.height = lineHeightPx
		}
		if debugMeasure {
			fmt.Printf("%s%s(%v,%v) mt:%v mb:%v\n", depthStr, n.Data, block.width, block.height, block.marginTop, block.marginBottom)
		}
		return block
	}
	return blockAttrs{}
}
   439  

View as plain text