1 package imgbundler
2
3 import (
4 "bytes"
5 "context"
6 "encoding/base64"
7 "fmt"
8 "html"
9 "io/ioutil"
10 "mime"
11 "net/http"
12 "net/url"
13 "os"
14 "path"
15 "regexp"
16 "strings"
17 "sync"
18 "time"
19
20 "golang.org/x/xerrors"
21
22 "oss.terrastruct.com/d2/lib/simplelog"
23 "oss.terrastruct.com/util-go/xdefer"
24 )
25
// imgCache memoizes bundled images across calls. worker stores the finished
// `<image href="data:..."` replacement bytes under the raw href string as it
// appears in the SVG, and consults the cache only when cacheImages is true.
var imgCache sync.Map

// maxImageSize is the largest response body, in bytes, that httpGet will read
// for a single remote image (1 << 25 = 32 MiB).
const maxImageSize int64 = 1 << 25

// imageRegex matches the opening of an SVG image element up to and including
// the closing quote of its href attribute. Submatch 1 is the href value.
var imageRegex = regexp.MustCompile(`<image href="([^"]+)"`)
31
32 func BundleLocal(ctx context.Context, l simplelog.Logger, in []byte, cacheImages bool) ([]byte, error) {
33 return bundle(ctx, l, in, false, cacheImages)
34 }
35
36 func BundleRemote(ctx context.Context, l simplelog.Logger, in []byte, cacheImages bool) ([]byte, error) {
37 return bundle(ctx, l, in, true, cacheImages)
38 }
39
// repl describes one substitution to apply to the SVG: every occurrence of
// from is to be replaced with to.
type repl struct {
	from, to []byte
}
44
45 func bundle(ctx context.Context, l simplelog.Logger, svg []byte, isRemote, cacheImages bool) (_ []byte, err error) {
46 if isRemote {
47 defer xdefer.Errorf(&err, "failed to bundle remote images")
48 } else {
49 defer xdefer.Errorf(&err, "failed to bundle local images")
50 }
51 imgs := imageRegex.FindAllSubmatch(svg, -1)
52 imgs = filterImageElements(imgs, isRemote)
53
54 ctx, cancel := context.WithTimeout(ctx, time.Minute*5)
55 defer cancel()
56
57 return runWorkers(ctx, l, svg, imgs, isRemote, cacheImages)
58 }
59
60
61
// filterImageElements narrows a list of image regex matches to the ones that
// should be bundled in the current mode.
//
// Each element of imgs is one submatch set: index 0 is the full match and
// index 1 is the href. An entry is dropped when its href was already seen
// earlier in the list, when it is already a data: URI, or when its
// remote/local classification does not equal isRemote. An href counts as
// remote when, after HTML-unescaping, it parses as a URL whose scheme starts
// with "http"; hrefs that fail to parse are treated as local.
//
// The result reuses the backing array of imgs, so the input slice must not be
// used after the call.
func filterImageElements(imgs [][][]byte, isRemote bool) [][][]byte {
	seen := make(map[string]struct{}, len(imgs))
	kept := imgs[:0]
	for i := range imgs {
		match := imgs[i]
		href := string(match[1])

		if _, dup := seen[href]; dup {
			continue
		}
		seen[href] = struct{}{}

		// Already inlined; nothing to fetch or read.
		if strings.HasPrefix(href, "data:") {
			continue
		}

		remote := false
		if u, err := url.Parse(html.UnescapeString(href)); err == nil {
			remote = strings.HasPrefix(u.Scheme, "http")
		}
		if remote != isRemote {
			continue
		}
		kept = append(kept, match)
	}
	return kept
}
86
// runWorkers bundles every image in imgs concurrently and splices the results
// into svg.
//
// One goroutine is started per image, with at most 16 running at a time. Each
// calls worker with the image's href (img[1]); on success every occurrence of
// the full match (img[0]) in svg is replaced with the bytes worker returned.
// Failures are logged and their hrefs collected.
//
// It returns once all workers have finished or ctx is done, whichever happens
// first. In both error cases the returned slice is svg with whatever
// replacements had been applied so far:
//   - ctx done: the error wraps ctx.Err().
//   - one or more workers failed: the error message is the list of failed hrefs.
func runWorkers(ctx context.Context, l simplelog.Logger, svg []byte, imgs [][][]byte, isRemote, cacheImages bool) (_ []byte, err error) {
	var wg sync.WaitGroup
	replc := make(chan repl)

	wg.Add(len(imgs))
	go func() {
		// Close replc once every worker has called wg.Done; the closed channel
		// is what ends the receive loop at the bottom of this function.
		wg.Wait()
		close(replc)
	}()

	// sema caps the number of workers running at once at 16.
	sema := make(chan struct{}, 16)

	// errhrefs collects the hrefs that failed to bundle; workers append to it
	// under errhrefsMu.
	var errhrefsMu sync.Mutex
	var errhrefs []string

	// Workers are launched from their own goroutine so that this function can
	// drain replc below while the launcher is blocked waiting for a sema slot.
	go func() {
		for _, img := range imgs {
			// Per-iteration copy so each closure below captures its own img.
			img := img
			sema <- struct{}{}
			go func() {
				defer func() {
					wg.Done()
					<-sema
				}()

				bundledImage, err := worker(ctx, l, img[1], isRemote, cacheImages)
				if err != nil {
					l.Error(fmt.Sprintf("failed to bundle %s: %v", img[1], err))
					errhrefsMu.Lock()
					errhrefs = append(errhrefs, string(img[1]))
					errhrefsMu.Unlock()
					return
				}
				// replc is unbuffered; also watch ctx so the send cannot block
				// forever after the receiver below has returned on ctx.Done.
				select {
				case <-ctx.Done():
				case replc <- repl{
					from: img[0],
					to:   bundledImage,
				}:
				}
			}()
		}
	}()

	// Emit a progress message every five seconds while waiting.
	t := time.NewTicker(time.Second * 5)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return svg, xerrors.Errorf("failed to wait for workers: %w", ctx.Err())
		case <-t.C:
			l.Info("fetching images...")
		case repl, ok := <-replc: // note: this variable shadows the repl type inside the case
			if !ok {
				// replc is closed only after wg.Wait returns, so every worker
				// has finished and errhrefs is no longer being appended to.
				if len(errhrefs) > 0 {
					return svg, xerrors.Errorf("%v", errhrefs)
				}
				return svg, nil
			}
			svg = bytes.Replace(svg, repl.from, repl.to, -1)
		}
	}
}
152
153 func worker(ctx context.Context, l simplelog.Logger, href []byte, isRemote, cacheImages bool) ([]byte, error) {
154 if cacheImages {
155 if hit, ok := imgCache.Load(string(href)); ok {
156 return hit.([]byte), nil
157 }
158 }
159 var buf []byte
160 var mimeType string
161 var err error
162 if isRemote {
163 l.Debug(fmt.Sprintf("fetching %s remotely", string(href)))
164 buf, mimeType, err = httpGet(ctx, html.UnescapeString(string(href)))
165 } else {
166 l.Debug(fmt.Sprintf("reading %s from disk", string(href)))
167 buf, err = os.ReadFile(html.UnescapeString(string(href)))
168 }
169 if err != nil {
170 return nil, err
171 }
172
173 if mimeType == "" {
174 mimeType = sniffMimeType(href, buf, isRemote)
175 }
176 mimeType = strings.Replace(mimeType, "text/xml", "image/svg+xml", 1)
177 b64 := base64.StdEncoding.EncodeToString(buf)
178
179 out := []byte(fmt.Sprintf(`<image href="data:%s;base64,%s"`, mimeType, b64))
180 if cacheImages {
181 imgCache.Store(string(href), out)
182 }
183 return out, nil
184 }
185
// httpClient is the shared HTTP client used by httpGet. It sets no
// client-level timeout; httpGet bounds each request with a context deadline.
var httpClient = &http.Client{}
187
188 func httpGet(ctx context.Context, href string) ([]byte, string, error) {
189 ctx, cancel := context.WithTimeout(ctx, time.Minute)
190 defer cancel()
191
192 req, err := http.NewRequestWithContext(ctx, "GET", href, nil)
193 if err != nil {
194 return nil, "", err
195 }
196
197 resp, err := httpClient.Do(req)
198 if err != nil {
199 return nil, "", err
200 }
201 defer resp.Body.Close()
202 if resp.StatusCode != 200 {
203 return nil, "", fmt.Errorf("expected status 200 but got %d %s", resp.StatusCode, resp.Status)
204 }
205 r := http.MaxBytesReader(nil, resp.Body, maxImageSize)
206 buf, err := ioutil.ReadAll(r)
207 if err != nil {
208 return nil, "", err
209 }
210 return buf, resp.Header.Get("Content-Type"), nil
211 }
212
213
// sniffMimeType guesses the MIME type of an image when none was supplied.
//
// The file extension is tried first. For remote images the extension is taken
// from the path component of the HTML-unescaped URL (so query strings are
// ignored, and an unparseable URL yields no extension); for local images it is
// taken from href as-is. If the extension maps to no known type, the type is
// detected from the content in buf instead.
func sniffMimeType(href, buf []byte, isRemote bool) string {
	name := string(href)
	if isRemote {
		name = ""
		if u, err := url.Parse(html.UnescapeString(string(href))); err == nil {
			name = u.Path
		}
	}

	if byExt := mime.TypeByExtension(path.Ext(name)); byExt != "" {
		return byExt
	}
	return http.DetectContentType(buf)
}
230
View as plain text