...

Source file src/oss.terrastruct.com/d2/lib/imgbundler/imgbundler.go

Documentation: oss.terrastruct.com/d2/lib/imgbundler

     1  package imgbundler
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"html"
     9  	"io/ioutil"
    10  	"mime"
    11  	"net/http"
    12  	"net/url"
    13  	"os"
    14  	"path"
    15  	"regexp"
    16  	"strings"
    17  	"sync"
    18  	"time"
    19  
    20  	"golang.org/x/xerrors"
    21  
    22  	"oss.terrastruct.com/d2/lib/simplelog"
    23  	"oss.terrastruct.com/util-go/xdefer"
    24  )
    25  
    26  var imgCache sync.Map
    27  
    28  const maxImageSize int64 = 1 << 25 // 33_554_432
    29  
    30  var imageRegex = regexp.MustCompile(`<image href="([^"]+)"`)
    31  
    32  func BundleLocal(ctx context.Context, l simplelog.Logger, in []byte, cacheImages bool) ([]byte, error) {
    33  	return bundle(ctx, l, in, false, cacheImages)
    34  }
    35  
    36  func BundleRemote(ctx context.Context, l simplelog.Logger, in []byte, cacheImages bool) ([]byte, error) {
    37  	return bundle(ctx, l, in, true, cacheImages)
    38  }
    39  
    40  type repl struct {
    41  	from []byte
    42  	to   []byte
    43  }
    44  
    45  func bundle(ctx context.Context, l simplelog.Logger, svg []byte, isRemote, cacheImages bool) (_ []byte, err error) {
    46  	if isRemote {
    47  		defer xdefer.Errorf(&err, "failed to bundle remote images")
    48  	} else {
    49  		defer xdefer.Errorf(&err, "failed to bundle local images")
    50  	}
    51  	imgs := imageRegex.FindAllSubmatch(svg, -1)
    52  	imgs = filterImageElements(imgs, isRemote)
    53  
    54  	ctx, cancel := context.WithTimeout(ctx, time.Minute*5)
    55  	defer cancel()
    56  
    57  	return runWorkers(ctx, l, svg, imgs, isRemote, cacheImages)
    58  }
    59  
    60  // filterImageElements finds all unique image elements in imgs that are
    61  // eligible for bundling in the current context.
    62  func filterImageElements(imgs [][][]byte, isRemote bool) [][][]byte {
    63  	unq := make(map[string]struct{})
    64  	imgs2 := imgs[:0]
    65  	for _, img := range imgs {
    66  		href := string(img[1])
    67  		if _, ok := unq[href]; ok {
    68  			continue
    69  		}
    70  		unq[href] = struct{}{}
    71  
    72  		// Skip already bundled images.
    73  		if strings.HasPrefix(href, "data:") {
    74  			continue
    75  		}
    76  
    77  		u, err := url.Parse(html.UnescapeString(href))
    78  		isRemoteImg := err == nil && strings.HasPrefix(u.Scheme, "http")
    79  
    80  		if isRemoteImg == isRemote {
    81  			imgs2 = append(imgs2, img)
    82  		}
    83  	}
    84  	return imgs2
    85  }
    86  
    87  func runWorkers(ctx context.Context, l simplelog.Logger, svg []byte, imgs [][][]byte, isRemote, cacheImages bool) (_ []byte, err error) {
    88  	var wg sync.WaitGroup
    89  	replc := make(chan repl)
    90  
    91  	wg.Add(len(imgs))
    92  	go func() {
    93  		wg.Wait()
    94  		close(replc)
    95  	}()
    96  
    97  	// Limits the number of workers to 16.
    98  	sema := make(chan struct{}, 16)
    99  
   100  	var errhrefsMu sync.Mutex
   101  	var errhrefs []string
   102  
   103  	// Start workers as the sema allows.
   104  	go func() {
   105  		for _, img := range imgs {
   106  			img := img
   107  			sema <- struct{}{}
   108  			go func() {
   109  				defer func() {
   110  					wg.Done()
   111  					<-sema
   112  				}()
   113  
   114  				bundledImage, err := worker(ctx, l, img[1], isRemote, cacheImages)
   115  				if err != nil {
   116  					l.Error(fmt.Sprintf("failed to bundle %s: %v", img[1], err))
   117  					errhrefsMu.Lock()
   118  					errhrefs = append(errhrefs, string(img[1]))
   119  					errhrefsMu.Unlock()
   120  					return
   121  				}
   122  				select {
   123  				case <-ctx.Done():
   124  				case replc <- repl{
   125  					from: img[0],
   126  					to:   bundledImage,
   127  				}:
   128  				}
   129  			}()
   130  		}
   131  	}()
   132  
   133  	t := time.NewTicker(time.Second * 5)
   134  	defer t.Stop()
   135  	for {
   136  		select {
   137  		case <-ctx.Done():
   138  			return svg, xerrors.Errorf("failed to wait for workers: %w", ctx.Err())
   139  		case <-t.C:
   140  			l.Info("fetching images...")
   141  		case repl, ok := <-replc:
   142  			if !ok {
   143  				if len(errhrefs) > 0 {
   144  					return svg, xerrors.Errorf("%v", errhrefs)
   145  				}
   146  				return svg, nil
   147  			}
   148  			svg = bytes.Replace(svg, repl.from, repl.to, -1)
   149  		}
   150  	}
   151  }
   152  
   153  func worker(ctx context.Context, l simplelog.Logger, href []byte, isRemote, cacheImages bool) ([]byte, error) {
   154  	if cacheImages {
   155  		if hit, ok := imgCache.Load(string(href)); ok {
   156  			return hit.([]byte), nil
   157  		}
   158  	}
   159  	var buf []byte
   160  	var mimeType string
   161  	var err error
   162  	if isRemote {
   163  		l.Debug(fmt.Sprintf("fetching %s remotely", string(href)))
   164  		buf, mimeType, err = httpGet(ctx, html.UnescapeString(string(href)))
   165  	} else {
   166  		l.Debug(fmt.Sprintf("reading %s from disk", string(href)))
   167  		buf, err = os.ReadFile(html.UnescapeString(string(href)))
   168  	}
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  
   173  	if mimeType == "" {
   174  		mimeType = sniffMimeType(href, buf, isRemote)
   175  	}
   176  	mimeType = strings.Replace(mimeType, "text/xml", "image/svg+xml", 1)
   177  	b64 := base64.StdEncoding.EncodeToString(buf)
   178  
   179  	out := []byte(fmt.Sprintf(`<image href="data:%s;base64,%s"`, mimeType, b64))
   180  	if cacheImages {
   181  		imgCache.Store(string(href), out)
   182  	}
   183  	return out, nil
   184  }
   185  
   186  var httpClient = &http.Client{}
   187  
   188  func httpGet(ctx context.Context, href string) ([]byte, string, error) {
   189  	ctx, cancel := context.WithTimeout(ctx, time.Minute)
   190  	defer cancel()
   191  
   192  	req, err := http.NewRequestWithContext(ctx, "GET", href, nil)
   193  	if err != nil {
   194  		return nil, "", err
   195  	}
   196  
   197  	resp, err := httpClient.Do(req)
   198  	if err != nil {
   199  		return nil, "", err
   200  	}
   201  	defer resp.Body.Close()
   202  	if resp.StatusCode != 200 {
   203  		return nil, "", fmt.Errorf("expected status 200 but got %d %s", resp.StatusCode, resp.Status)
   204  	}
   205  	r := http.MaxBytesReader(nil, resp.Body, maxImageSize)
   206  	buf, err := ioutil.ReadAll(r)
   207  	if err != nil {
   208  		return nil, "", err
   209  	}
   210  	return buf, resp.Header.Get("Content-Type"), nil
   211  }
   212  
   213  // sniffMimeType sniffs the mime type of href based on its file extension and contents.
   214  func sniffMimeType(href, buf []byte, isRemote bool) string {
   215  	p := string(href)
   216  	if isRemote {
   217  		u, err := url.Parse(html.UnescapeString(p))
   218  		if err != nil {
   219  			p = ""
   220  		} else {
   221  			p = u.Path
   222  		}
   223  	}
   224  	mimeType := mime.TypeByExtension(path.Ext(p))
   225  	if mimeType == "" {
   226  		mimeType = http.DetectContentType(buf)
   227  	}
   228  	return mimeType
   229  }
   230  

View as plain text