1 package d2cli
2
3 import (
4 "context"
5 "embed"
6 _ "embed"
7 "errors"
8 "fmt"
9 "io/fs"
10 "net"
11 "net/http"
12 "os"
13 "path/filepath"
14 "runtime"
15 "sort"
16 "strings"
17 "sync"
18 "time"
19
20 "github.com/fsnotify/fsnotify"
21 "nhooyr.io/websocket"
22 "nhooyr.io/websocket/wsjson"
23
24 "oss.terrastruct.com/util-go/xbrowser"
25
26 "oss.terrastruct.com/util-go/xhttp"
27
28 "oss.terrastruct.com/util-go/xmain"
29
30 "oss.terrastruct.com/d2/d2plugin"
31 "oss.terrastruct.com/d2/d2renderers/d2fonts"
32 "oss.terrastruct.com/d2/d2renderers/d2svg"
33 "oss.terrastruct.com/d2/lib/png"
34 )
35
36
37
38
39
40 var devMode = false
41
42
43 var staticFS embed.FS
44
45 type watcherOpts struct {
46 layout *string
47 plugins []d2plugin.Plugin
48 renderOpts d2svg.RenderOpts
49 animateInterval int64
50 host string
51 port string
52 inputPath string
53 outputPath string
54 boardPath string
55 pwd string
56 bundle bool
57 forceAppendix bool
58 pw png.Playwright
59 fontFamily *d2fonts.FontFamily
60 }
61
62 type watcher struct {
63 ctx context.Context
64 cancel context.CancelFunc
65 wg sync.WaitGroup
66 devMode bool
67
68 ms *xmain.State
69 watcherOpts
70
71 compileCh chan struct{}
72
73 fw *fsnotify.Watcher
74 l net.Listener
75 staticFileServer http.Handler
76
77 boardpathMu sync.Mutex
78 wsclientsMu sync.Mutex
79 closing bool
80 wsclientsWG sync.WaitGroup
81 wsclients map[*wsclient]struct{}
82
83 errMu sync.Mutex
84 err error
85
86 resMu sync.Mutex
87 res *compileResult
88 }
89
// compileResult is the JSON message pushed to websocket clients after each
// compile.
//
// Scale is omitted when nil. The previous tag spelled the option "omitEmpty";
// encoding/json options are case-sensitive, so it was silently ignored and a
// nil scale was sent as null. NOTE(review): clients now see scale absent
// rather than null when unset — confirm watch.js treats both the same.
type compileResult struct {
	SVG   string   `json:"svg"`
	Scale *float64 `json:"scale,omitempty"`
	Err   string   `json:"err"`
}
95
// newWatcher builds a watcher for opts with a context derived from ctx. It
// creates the fsnotify watcher, prepares the static file handler and binds the
// HTTP listener (see init); the loops themselves are started by run.
//
// If initialization fails, the derived context is cancelled before returning
// so it is not left registered with its parent.
func newWatcher(ctx context.Context, ms *xmain.State, opts watcherOpts) (*watcher, error) {
	ctx, cancel := context.WithCancel(ctx)

	w := &watcher{
		ctx:     ctx,
		cancel:  cancel,
		devMode: devMode,

		ms:          ms,
		watcherOpts: opts,

		// Capacity 1 lets requestCompile coalesce pending requests.
		compileCh: make(chan struct{}, 1),
		wsclients: make(map[*wsclient]struct{}),
	}
	if err := w.init(); err != nil {
		cancel()
		return nil, err
	}
	return w, nil
}
116
117 func (w *watcher) init() error {
118 fw, err := fsnotify.NewWatcher()
119 if err != nil {
120 return err
121 }
122 w.fw = fw
123 err = w.initStaticFileServer()
124 if err != nil {
125 return err
126 }
127 return w.listen()
128 }
129
// initStaticFileServer picks the handler for /static/ assets. Normally it
// serves the "static" subtree of the embedded staticFS. In dev mode it serves
// the static directory next to this source file on disk, so assets can be
// edited without rebuilding.
func (w *watcher) initStaticFileServer() error {
	if !w.devMode {
		sub, err := fs.Sub(staticFS, "static")
		if err != nil {
			return err
		}
		w.staticFileServer = http.FileServer(http.FS(sub))
		return nil
	}

	// runtime.Caller(0) reports the source file of this function.
	_, thisFile, _, ok := runtime.Caller(0)
	if !ok {
		return errors.New("d2: runtime failed to provide path of watch.go")
	}
	assetDir := filepath.Join(filepath.Dir(thisFile), "./static")
	w.staticFileServer = http.FileServer(http.Dir(assetDir))
	return nil
}
150
// run starts the watch loop, the compile loop and the HTTP server. It waits
// until all of them have returned, shuts the watcher down, and returns the
// first non-nil error recorded by setErr.
func (w *watcher) run() error {
	defer w.close()

	for _, loop := range []func(context.Context) error{w.watchLoop, w.compileLoop} {
		w.goFunc(loop)
	}
	if err := w.goServe(); err != nil {
		return err
	}

	w.wg.Wait()
	w.close()
	return w.err
}
166
167 func (w *watcher) close() {
168 w.wsclientsMu.Lock()
169 if w.closing {
170 w.wsclientsMu.Unlock()
171 return
172 }
173 w.closing = true
174 w.wsclientsMu.Unlock()
175
176 w.cancel()
177 if w.fw != nil {
178 err := w.fw.Close()
179 w.setErr(err)
180 }
181 if w.l != nil {
182 err := w.l.Close()
183 w.setErr(err)
184 }
185
186 w.wsclientsWG.Wait()
187 }
188
189 func (w *watcher) setErr(err error) {
190 w.errMu.Lock()
191 if w.err == nil {
192 w.err = err
193 }
194 w.errMu.Unlock()
195 }
196
// goFunc runs fn(w.ctx) on a new goroutine tracked by w.wg. fn's error goes
// to setErr, and when fn returns the watcher context is cancelled, so one
// loop exiting stops all of them.
func (w *watcher) goFunc(fn func(context.Context) error) {
	w.wg.Add(1)
	go func() {
		defer w.wg.Done()
		defer w.cancel()
		w.setErr(fn(w.ctx))
	}()
}
207
208
222 func (w *watcher) watchLoop(ctx context.Context) error {
223 lastModified := make(map[string]time.Time)
224
225 mt, err := w.ensureAddWatch(ctx, w.inputPath)
226 if err != nil {
227 return err
228 }
229 lastModified[w.inputPath] = mt
230 w.ms.Log.Info.Printf("compiling %v...", w.ms.HumanPath(w.inputPath))
231 w.requestCompile()
232
233 eatBurstTimer := time.NewTimer(0)
234 <-eatBurstTimer.C
235 pollTicker := time.NewTicker(time.Second * 10)
236 defer pollTicker.Stop()
237
238 changed := make(map[string]struct{})
239
240 for {
241 select {
242 case <-pollTicker.C:
243
244
245
246
247 missedChanges := false
248 for _, watched := range w.fw.WatchList() {
249 mt, err := w.ensureAddWatch(ctx, watched)
250 if err != nil {
251 return err
252 }
253 if mt2, ok := lastModified[watched]; !ok || !mt.Equal(mt2) {
254 missedChanges = true
255 lastModified[watched] = mt
256 }
257 }
258 if missedChanges {
259 w.requestCompile()
260 }
261 case ev, ok := <-w.fw.Events:
262 if !ok {
263 return errors.New("fsnotify watcher closed")
264 }
265 w.ms.Log.Debug.Printf("received file system event %v", ev)
266 mt, err := w.ensureAddWatch(ctx, ev.Name)
267 if err != nil {
268 return err
269 }
270 if ev.Op == fsnotify.Chmod {
271 if mt.Equal(lastModified[ev.Name]) {
272
273
274 continue
275 }
276
277 lastModified[ev.Name] = mt
278 }
279 changed[ev.Name] = struct{}{}
280
281
282
283
284
285
286
287
288
289
290 eatBurstTimer.Reset(time.Millisecond * 16)
291 case <-eatBurstTimer.C:
292 var changedList []string
293 for k := range changed {
294 changedList = append(changedList, k)
295 delete(changed, k)
296 }
297 sort.Strings(changedList)
298 changedStr := w.ms.HumanPath(changedList[0])
299 for i := 1; i < len(changedList); i++ {
300 changedStr += fmt.Sprintf(", %s", w.ms.HumanPath(changedList[i]))
301 }
302 w.ms.Log.Info.Printf("detected change in %s: recompiling...", changedStr)
303 w.requestCompile()
304 case err, ok := <-w.fw.Errors:
305 if !ok {
306 return errors.New("fsnotify watcher closed")
307 }
308 w.ms.Log.Error.Printf("fsnotify error: %v", err)
309 case <-ctx.Done():
310 return ctx.Err()
311 }
312 }
313 }
314
315 func (w *watcher) requestCompile() {
316 select {
317 case w.compileCh <- struct{}{}:
318 default:
319 }
320 }
321
// ensureAddWatch calls addWatch until it succeeds or ctx is cancelled, and
// returns the path's modification time.
//
// Retries wait 16ms, then 2s, 4s and 8s, and 16s thereafter. Failures are
// logged once the delay is at least one second, so brief gaps (such as an
// editor replacing a file) stay quiet.
func (w *watcher) ensureAddWatch(ctx context.Context, path string) (time.Time, error) {
	delay := time.Millisecond * 16
	retry := time.NewTimer(0)
	<-retry.C
	for {
		modTime, err := w.addWatch(ctx, path)
		if err == nil {
			return modTime, nil
		}
		if delay >= time.Second {
			w.ms.Log.Error.Printf("failed to watch %q: %v (retrying in %v)", w.ms.HumanPath(path), err, delay)
		}

		retry.Reset(delay)
		select {
		case <-ctx.Done():
			return time.Time{}, ctx.Err()
		case <-retry.C:
		}

		switch {
		case delay < time.Second:
			delay = 2 * time.Second
		case delay < 16*time.Second:
			delay *= 2
		}
	}
}
349
350 func (w *watcher) addWatch(ctx context.Context, path string) (time.Time, error) {
351 err := w.fw.Add(path)
352 if err != nil {
353 return time.Time{}, err
354 }
355 var d os.FileInfo
356 d, err = os.Stat(path)
357 if err != nil {
358 return time.Time{}, err
359 }
360 return d.ModTime(), nil
361 }
362
// replaceWatchList makes the fsnotify watch list match paths, the files the
// latest compile opened. Watches on files no longer in paths are removed; the
// input file is always kept. Paths not yet watched are added via
// ensureAddWatch, whose error is returned.
func (w *watcher) replaceWatchList(ctx context.Context, paths []string) error {
	contains := func(list []string, s string) bool {
		for _, v := range list {
			if v == s {
				return true
			}
		}
		return false
	}

	for _, watched := range w.fw.WatchList() {
		if watched != w.inputPath && !contains(paths, watched) {
			// Best effort: a watch that can't be removed only causes an
			// extra recompile if that file changes.
			_ = w.fw.Remove(watched)
		}
	}

	for _, p := range paths {
		if contains(w.fw.WatchList(), p) {
			continue
		}
		if _, err := w.ensureAddWatch(ctx, p); err != nil {
			return err
		}
	}
	return nil
}
399
// compileLoop handles compile requests (see requestCompile) until ctx is
// cancelled or updating the watch list fails. For each request it compiles
// w.inputPath for the current board and replaces the fsnotify watch list with
// the files that compile opened. It then broadcasts the result to websocket
// clients. After the first compile it tries to open a browser at the server
// address.
func (w *watcher) compileLoop(ctx context.Context) error {
	firstCompile := true
	for {
		select {
		case <-w.compileCh:
		case <-ctx.Done():
			return ctx.Err()
		}

		recompiledPrefix := ""
		if !firstCompile {
			recompiledPrefix = "re"
		}

		// PNG and PDF output is rendered through the Playwright browser in
		// w.pw; restart it if it is no longer connected.
		if (filepath.Ext(w.outputPath) == ".png" || filepath.Ext(w.outputPath) == ".pdf") && !w.pw.Browser.IsConnected() {
			newPW, err := w.pw.RestartBrowser()
			if err != nil {
				broadcastErr := fmt.Errorf("issue encountered with PNG exporter: %w", err)
				w.ms.Log.Error.Print(broadcastErr)
				w.broadcast(&compileResult{
					Err: broadcastErr.Error(),
				})
				continue
			}
			w.pw = newPW
		}

		// Snapshot the board path and release the lock before compiling, so
		// handleRoot is not blocked for the whole compile. If handleRoot
		// switches boards meanwhile, it requests another compile.
		w.boardpathMu.Lock()
		rawBoardPath := w.boardPath
		w.boardpathMu.Unlock()
		var boardPath []string
		if rawBoardPath != "" {
			// handleRoot derives the board path from the URL, which always uses
			// '/'. ToSlash also accepts OS-separated paths (a no-op on Unix).
			boardPath = strings.Split(filepath.ToSlash(rawBoardPath), "/")
		}

		// tfs records every file the compile opens; they become the watch list.
		tfs := trackedFS{}
		svg, _, err := compile(ctx, w.ms, w.plugins, &tfs, w.layout, w.renderOpts, w.fontFamily, w.animateInterval, w.inputPath, w.outputPath, boardPath, false, w.bundle, w.forceAppendix, w.pw.Page)
		errs := ""
		if err != nil {
			if len(svg) > 0 {
				err = fmt.Errorf("failed to fully %scompile (rendering partial svg): %w", recompiledPrefix, err)
			} else {
				err = fmt.Errorf("failed to %scompile: %w", recompiledPrefix, err)
			}
			errs = err.Error()
			w.ms.Log.Error.Print(errs)
		}
		err = w.replaceWatchList(ctx, tfs.opened)
		if err != nil {
			return err
		}

		w.broadcast(&compileResult{
			SVG:   string(svg),
			Scale: w.renderOpts.Scale,
			Err:   errs,
		})

		if firstCompile {
			firstCompile = false
			url := fmt.Sprintf("http://%s", w.l.Addr())
			err = xbrowser.Open(ctx, w.ms.Env, url)
			if err != nil {
				w.ms.Log.Warn.Printf("failed to open browser to %v: %v", url, err)
			}
		}
	}
}
466
// listen binds a TCP listener on host:port and stores it in w.l. It logs the
// bound address, which matters when port is 0 and the OS picks one.
func (w *watcher) listen() error {
	addr := net.JoinHostPort(w.host, w.port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	w.l = listener
	w.ms.Log.Success.Printf("listening on http://%v", listener.Addr())
	return nil
}
476
// goServe registers the preview routes and starts serving on w.l in a goFunc
// goroutine, with a 30s value passed to xhttp.Serve. The routes are the page
// at "/", assets under "/static/", and the websocket at "/watch". It currently
// always returns nil.
func (w *watcher) goServe() error {
	mux := http.NewServeMux()
	mux.HandleFunc("/", w.handleRoot)
	mux.Handle("/static/", http.StripPrefix("/static", w.staticFileServer))
	mux.Handle("/watch", xhttp.HandlerFuncAdapter{Log: w.ms.Log, Func: w.handleWatch})

	srv := xhttp.NewServer(w.ms.Log.Warn, xhttp.Log(w.ms.Log, mux))
	w.goFunc(func(ctx context.Context) error {
		return xhttp.Serve(ctx, 30*time.Second, srv, w.l)
	})
	return nil
}
492
493 func (w *watcher) getRes() *compileResult {
494 w.resMu.Lock()
495 defer w.resMu.Unlock()
496 return w.res
497 }
498
// handleRoot serves the preview page for every path that no more specific
// route matches. The request path, minus its leading "/" and any extension on
// its final segment, selects the board to render. If that differs from the
// current board, the board path is updated and a recompile is requested.
func (w *watcher) handleRoot(hw http.ResponseWriter, r *http.Request) {
	hw.Header().Set("Content-Type", "text/html; charset=utf-8")
	fmt.Fprintf(hw, `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>%s</title>
<script src="/static/watch.js"></script>
<link rel="stylesheet" href="/static/watch.css">
<link id="favicon" rel="icon" href="/static/favicon.ico">
</head>
<body data-d2-dev-mode=%t>
<div id="d2-err" style="display: none"></div>
<div id="d2-svg-container"></div>
</body>
</html>`, filepath.Base(w.outputPath), w.devMode)

	boardPath := strings.TrimPrefix(r.URL.Path, "/")
	// Strip only an extension on the last segment ("x/y.svg" -> "x/y").
	// Cutting at the last '.' anywhere would turn "a.b/c" into "a". URL paths
	// use '/', which filepath.Ext treats as a separator on every platform.
	boardPath = strings.TrimSuffix(boardPath, filepath.Ext(boardPath))

	w.boardpathMu.Lock()
	recompile := boardPath != w.boardPath
	if recompile {
		w.boardPath = boardPath
	}
	w.boardpathMu.Unlock()
	if recompile {
		w.requestCompile()
	}
}
533
// handleWatch upgrades the request to a websocket and, on a new goroutine,
// streams compile results to it via wsclient.writeLoop. Once close has
// started it refuses the connection with an xhttp error carrying
// http.StatusServiceUnavailable.
func (w *watcher) handleWatch(hw http.ResponseWriter, r *http.Request) error {
	w.wsclientsMu.Lock()
	if w.closing {
		w.wsclientsMu.Unlock()
		return xhttp.Errorf(http.StatusServiceUnavailable, "server shutting down...", "server shutting down...")
	}
	// Register with wsclientsWG while still holding wsclientsMu. close sets
	// closing under the same mutex before calling wsclientsWG.Wait, so a client
	// is either refused above or counted before Wait can start.
	w.wsclientsWG.Add(1)
	w.wsclientsMu.Unlock()

	c, err := websocket.Accept(hw, r, &websocket.AcceptOptions{
		CompressionMode: websocket.CompressionDisabled,
	})
	if err != nil {
		// Undo the registration above; no goroutine will be started.
		w.wsclientsWG.Done()
		return err
	}

	go func() {
		// Deferred first so it runs last, after the connection is closed and
		// the client is unregistered.
		defer w.wsclientsWG.Done()
		// Fallback close; writeLoop closes with StatusGoingAway itself when its
		// context ends.
		defer c.Close(websocket.StatusInternalError, "the sky is falling")

		// Each connection is served for at most one hour, and ends earlier if
		// the watcher context is cancelled.
		ctx, cancel := context.WithTimeout(w.ctx, time.Hour)
		defer cancel()

		cl := &wsclient{
			w:         w,
			resultsCh: make(chan struct{}, 1),
			c:         c,
		}

		w.wsclientsMu.Lock()
		w.wsclients[cl] = struct{}{}
		w.wsclientsMu.Unlock()
		defer func() {
			w.wsclientsMu.Lock()
			delete(w.wsclients, cl)
			w.wsclientsMu.Unlock()
		}()

		// NOTE(review): per the nhooyr websocket docs, CloseRead discards
		// incoming messages and returns a ctx that is cancelled when the
		// connection closes — confirm. The heartbeat goroutine is not counted
		// in wsclientsWG; it stops once this ctx is cancelled.
		ctx = cl.c.CloseRead(ctx)
		go wsHeartbeat(ctx, cl.c)
		// A write or context error simply ends this client.
		_ = cl.writeLoop(ctx)
	}()
	return nil
}
583
584 type wsclient struct {
585 w *watcher
586 resultsCh chan struct{}
587 c *websocket.Conn
588 }
589
// writeLoop sends the latest result (if any) to the client, then again each
// time resultsCh is signalled. It returns the first write error. If ctx ends,
// it closes the connection with StatusGoingAway and returns ctx.Err().
func (cl *wsclient) writeLoop(ctx context.Context) error {
	for {
		if latest := cl.w.getRes(); latest != nil {
			if err := cl.write(ctx, latest); err != nil {
				return err
			}
		}

		select {
		case <-ctx.Done():
			cl.c.Close(websocket.StatusGoingAway, "server shutting down...")
			return ctx.Err()
		case <-cl.resultsCh:
		}
	}
}
608
// write sends res to the client as JSON, giving up after 30 seconds.
func (cl *wsclient) write(ctx context.Context, res *compileResult) error {
	writeCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	return wsjson.Write(writeCtx, cl.c, res)
}
615
616 func (w *watcher) broadcast(res *compileResult) {
617 w.resMu.Lock()
618 w.res = res
619 w.resMu.Unlock()
620
621 w.wsclientsMu.Lock()
622 defer w.wsclientsMu.Unlock()
623 clientsSuffix := ""
624 if len(w.wsclients) != 1 {
625 clientsSuffix = "s"
626 }
627 w.ms.Log.Info.Printf("broadcasting update to %d client%s", len(w.wsclients), clientsSuffix)
628 for cl := range w.wsclients {
629 select {
630 case cl.resultsCh <- struct{}{}:
631 default:
632 }
633 }
634 }
635
// wsHeartbeat pings c immediately and then every 30 seconds. It stops on the
// first failed ping or when ctx is done, closing c on the way out.
func wsHeartbeat(ctx context.Context, c *websocket.Conn) {
	defer c.Close(websocket.StatusInternalError, "the sky is falling")

	wait := time.NewTimer(0)
	<-wait.C
	for c.Ping(ctx) == nil {
		wait.Reset(time.Second * 30)
		select {
		case <-ctx.Done():
			return
		case <-wait.C:
		}
	}
}
655
656
657 type trackedFS struct {
658 opened []string
659 }
660
661 func (tfs *trackedFS) Open(name string) (fs.File, error) {
662 f, err := os.Open(name)
663 if err == nil {
664 tfs.opened = append(tfs.opened, name)
665 }
666 return f, err
667 }
668
View as plain text