...

Source file src/oss.terrastruct.com/d2/d2cli/watch.go

Documentation: oss.terrastruct.com/d2/d2cli

     1  package d2cli
     2  
     3  import (
     4  	"context"
     5  	"embed"
     6  	_ "embed"
     7  	"errors"
     8  	"fmt"
     9  	"io/fs"
    10  	"net"
    11  	"net/http"
    12  	"os"
    13  	"path/filepath"
    14  	"runtime"
    15  	"sort"
    16  	"strings"
    17  	"sync"
    18  	"time"
    19  
    20  	"github.com/fsnotify/fsnotify"
    21  	"nhooyr.io/websocket"
    22  	"nhooyr.io/websocket/wsjson"
    23  
    24  	"oss.terrastruct.com/util-go/xbrowser"
    25  
    26  	"oss.terrastruct.com/util-go/xhttp"
    27  
    28  	"oss.terrastruct.com/util-go/xmain"
    29  
    30  	"oss.terrastruct.com/d2/d2plugin"
    31  	"oss.terrastruct.com/d2/d2renderers/d2fonts"
    32  	"oss.terrastruct.com/d2/d2renderers/d2svg"
    33  	"oss.terrastruct.com/d2/lib/png"
    34  )
    35  
    36  // Enabled with the build tag "dev".
    37  // See watch_dev.go
    38  // Controls whether the embedded staticFS is used or if files are served directly from the
    39  // file system. Useful for quick iteration in development.
    40  var devMode = false
    41  
    42  //go:embed static
    43  var staticFS embed.FS
    44  
    45  type watcherOpts struct {
    46  	layout          *string
    47  	plugins         []d2plugin.Plugin
    48  	renderOpts      d2svg.RenderOpts
    49  	animateInterval int64
    50  	host            string
    51  	port            string
    52  	inputPath       string
    53  	outputPath      string
    54  	boardPath       string
    55  	pwd             string
    56  	bundle          bool
    57  	forceAppendix   bool
    58  	pw              png.Playwright
    59  	fontFamily      *d2fonts.FontFamily
    60  }
    61  
    62  type watcher struct {
    63  	ctx     context.Context
    64  	cancel  context.CancelFunc
    65  	wg      sync.WaitGroup
    66  	devMode bool
    67  
    68  	ms *xmain.State
    69  	watcherOpts
    70  
    71  	compileCh chan struct{}
    72  
    73  	fw               *fsnotify.Watcher
    74  	l                net.Listener
    75  	staticFileServer http.Handler
    76  
    77  	boardpathMu sync.Mutex
    78  	wsclientsMu sync.Mutex
    79  	closing     bool
    80  	wsclientsWG sync.WaitGroup
    81  	wsclients   map[*wsclient]struct{}
    82  
    83  	errMu sync.Mutex
    84  	err   error
    85  
    86  	resMu sync.Mutex
    87  	res   *compileResult
    88  }
    89  
    90  type compileResult struct {
    91  	SVG   string   `json:"svg"`
    92  	Scale *float64 `json:"scale,omitEmpty"`
    93  	Err   string   `json:"err"`
    94  }
    95  
    96  func newWatcher(ctx context.Context, ms *xmain.State, opts watcherOpts) (*watcher, error) {
    97  	ctx, cancel := context.WithCancel(ctx)
    98  
    99  	w := &watcher{
   100  		ctx:     ctx,
   101  		cancel:  cancel,
   102  		devMode: devMode,
   103  
   104  		ms:          ms,
   105  		watcherOpts: opts,
   106  
   107  		compileCh: make(chan struct{}, 1),
   108  		wsclients: make(map[*wsclient]struct{}),
   109  	}
   110  	err := w.init()
   111  	if err != nil {
   112  		return nil, err
   113  	}
   114  	return w, nil
   115  }
   116  
   117  func (w *watcher) init() error {
   118  	fw, err := fsnotify.NewWatcher()
   119  	if err != nil {
   120  		return err
   121  	}
   122  	w.fw = fw
   123  	err = w.initStaticFileServer()
   124  	if err != nil {
   125  		return err
   126  	}
   127  	return w.listen()
   128  }
   129  
   130  func (w *watcher) initStaticFileServer() error {
   131  	// Serve files directly in dev mode for fast iteration.
   132  	if w.devMode {
   133  		_, file, _, ok := runtime.Caller(0)
   134  		if !ok {
   135  			return errors.New("d2: runtime failed to provide path of watch.go")
   136  		}
   137  
   138  		staticFilesDir := filepath.Join(filepath.Dir(file), "./static")
   139  		w.staticFileServer = http.FileServer(http.Dir(staticFilesDir))
   140  		return nil
   141  	}
   142  
   143  	sfs, err := fs.Sub(staticFS, "static")
   144  	if err != nil {
   145  		return err
   146  	}
   147  	w.staticFileServer = http.FileServer(http.FS(sfs))
   148  	return nil
   149  }
   150  
   151  func (w *watcher) run() error {
   152  	defer w.close()
   153  
   154  	w.goFunc(w.watchLoop)
   155  	w.goFunc(w.compileLoop)
   156  
   157  	err := w.goServe()
   158  	if err != nil {
   159  		return err
   160  	}
   161  
   162  	w.wg.Wait()
   163  	w.close()
   164  	return w.err
   165  }
   166  
   167  func (w *watcher) close() {
   168  	w.wsclientsMu.Lock()
   169  	if w.closing {
   170  		w.wsclientsMu.Unlock()
   171  		return
   172  	}
   173  	w.closing = true
   174  	w.wsclientsMu.Unlock()
   175  
   176  	w.cancel()
   177  	if w.fw != nil {
   178  		err := w.fw.Close()
   179  		w.setErr(err)
   180  	}
   181  	if w.l != nil {
   182  		err := w.l.Close()
   183  		w.setErr(err)
   184  	}
   185  
   186  	w.wsclientsWG.Wait()
   187  }
   188  
   189  func (w *watcher) setErr(err error) {
   190  	w.errMu.Lock()
   191  	if w.err == nil {
   192  		w.err = err
   193  	}
   194  	w.errMu.Unlock()
   195  }
   196  
   197  func (w *watcher) goFunc(fn func(context.Context) error) {
   198  	w.wg.Add(1)
   199  	go func() {
   200  		defer w.wg.Done()
   201  		defer w.cancel()
   202  
   203  		err := fn(w.ctx)
   204  		w.setErr(err)
   205  	}()
   206  }
   207  
   208  /*
   209   * IMPORTANT
   210   *
   211   * Do not touch watchLoop or ensureAddWatch without consulting @nhooyr
   212   * fsnotify and file system watching APIs in general are notoriously hard
   213   * to use correctly.
   214   *
   215   * This issue is a good summary though it too contains confusion and misunderstandings:
   216   *   https://github.com/fsnotify/fsnotify/issues/372
   217   *
   218   * The code was thoroughly considered and experimentally vetted.
   219   *
   220   * TODO: Abstract out file system and fsnotify to test this with 100% coverage. See comment in main_test.go
   221   */
   222  func (w *watcher) watchLoop(ctx context.Context) error {
   223  	lastModified := make(map[string]time.Time)
   224  
   225  	mt, err := w.ensureAddWatch(ctx, w.inputPath)
   226  	if err != nil {
   227  		return err
   228  	}
   229  	lastModified[w.inputPath] = mt
   230  	w.ms.Log.Info.Printf("compiling %v...", w.ms.HumanPath(w.inputPath))
   231  	w.requestCompile()
   232  
   233  	eatBurstTimer := time.NewTimer(0)
   234  	<-eatBurstTimer.C
   235  	pollTicker := time.NewTicker(time.Second * 10)
   236  	defer pollTicker.Stop()
   237  
   238  	changed := make(map[string]struct{})
   239  
   240  	for {
   241  		select {
   242  		case <-pollTicker.C:
   243  			// In case we missed an event indicating the path is unwatchable and we won't be
   244  			// getting any more events.
   245  			// File notification APIs are notoriously unreliable. I've personally experienced
   246  			// many quirks and so feel this check is justified even if excessive.
   247  			missedChanges := false
   248  			for _, watched := range w.fw.WatchList() {
   249  				mt, err := w.ensureAddWatch(ctx, watched)
   250  				if err != nil {
   251  					return err
   252  				}
   253  				if mt2, ok := lastModified[watched]; !ok || !mt.Equal(mt2) {
   254  					missedChanges = true
   255  					lastModified[watched] = mt
   256  				}
   257  			}
   258  			if missedChanges {
   259  				w.requestCompile()
   260  			}
   261  		case ev, ok := <-w.fw.Events:
   262  			if !ok {
   263  				return errors.New("fsnotify watcher closed")
   264  			}
   265  			w.ms.Log.Debug.Printf("received file system event %v", ev)
   266  			mt, err := w.ensureAddWatch(ctx, ev.Name)
   267  			if err != nil {
   268  				return err
   269  			}
   270  			if ev.Op == fsnotify.Chmod {
   271  				if mt.Equal(lastModified[ev.Name]) {
   272  					// Benign Chmod.
   273  					// See https://github.com/fsnotify/fsnotify/issues/15
   274  					continue
   275  				}
   276  				// We missed changes.
   277  				lastModified[ev.Name] = mt
   278  			}
   279  			changed[ev.Name] = struct{}{}
   280  			// The purpose of eatBurstTimer is to wait at least 16 milliseconds after a sequence of
   281  			// events to ensure that whomever is editing the file is now done.
   282  			//
   283  			// For example, On macOS editing with neovim, every write I see a chmod immediately
   284  			// followed by a write followed by another chmod. We don't want the three events to
   285  			// be treated as two or three compilations, we want them to be batched into one.
   286  			//
   287  			// Another example would be a very large file where one logical edit becomes write
   288  			// events. We wouldn't want to try to compile an incomplete file and then report a
   289  			// misleading error.
   290  			eatBurstTimer.Reset(time.Millisecond * 16)
   291  		case <-eatBurstTimer.C:
   292  			var changedList []string
   293  			for k := range changed {
   294  				changedList = append(changedList, k)
   295  				delete(changed, k)
   296  			}
   297  			sort.Strings(changedList)
   298  			changedStr := w.ms.HumanPath(changedList[0])
   299  			for i := 1; i < len(changedList); i++ {
   300  				changedStr += fmt.Sprintf(", %s", w.ms.HumanPath(changedList[i]))
   301  			}
   302  			w.ms.Log.Info.Printf("detected change in %s: recompiling...", changedStr)
   303  			w.requestCompile()
   304  		case err, ok := <-w.fw.Errors:
   305  			if !ok {
   306  				return errors.New("fsnotify watcher closed")
   307  			}
   308  			w.ms.Log.Error.Printf("fsnotify error: %v", err)
   309  		case <-ctx.Done():
   310  			return ctx.Err()
   311  		}
   312  	}
   313  }
   314  
   315  func (w *watcher) requestCompile() {
   316  	select {
   317  	case w.compileCh <- struct{}{}:
   318  	default:
   319  	}
   320  }
   321  
   322  func (w *watcher) ensureAddWatch(ctx context.Context, path string) (time.Time, error) {
   323  	interval := time.Millisecond * 16
   324  	tc := time.NewTimer(0)
   325  	<-tc.C
   326  	for {
   327  		mt, err := w.addWatch(ctx, path)
   328  		if err == nil {
   329  			return mt, nil
   330  		}
   331  		if interval >= time.Second {
   332  			w.ms.Log.Error.Printf("failed to watch %q: %v (retrying in %v)", w.ms.HumanPath(path), err, interval)
   333  		}
   334  
   335  		tc.Reset(interval)
   336  		select {
   337  		case <-tc.C:
   338  			if interval < time.Second {
   339  				interval = time.Second
   340  			}
   341  			if interval < time.Second*16 {
   342  				interval *= 2
   343  			}
   344  		case <-ctx.Done():
   345  			return time.Time{}, ctx.Err()
   346  		}
   347  	}
   348  }
   349  
   350  func (w *watcher) addWatch(ctx context.Context, path string) (time.Time, error) {
   351  	err := w.fw.Add(path)
   352  	if err != nil {
   353  		return time.Time{}, err
   354  	}
   355  	var d os.FileInfo
   356  	d, err = os.Stat(path)
   357  	if err != nil {
   358  		return time.Time{}, err
   359  	}
   360  	return d.ModTime(), nil
   361  }
   362  
   363  func (w *watcher) replaceWatchList(ctx context.Context, paths []string) error {
   364  	// First remove the files no longer being watched
   365  	for _, watched := range w.fw.WatchList() {
   366  		if watched == w.inputPath {
   367  			continue
   368  		}
   369  		found := false
   370  		for _, p := range paths {
   371  			if watched == p {
   372  				found = true
   373  				break
   374  			}
   375  		}
   376  		if !found {
   377  			// Don't mind errors here
   378  			w.fw.Remove(watched)
   379  		}
   380  	}
   381  	// Then add the files newly being watched
   382  	for _, p := range paths {
   383  		found := false
   384  		for _, watched := range w.fw.WatchList() {
   385  			if watched == p {
   386  				found = true
   387  				break
   388  			}
   389  		}
   390  		if !found {
   391  			_, err := w.ensureAddWatch(ctx, p)
   392  			if err != nil {
   393  				return err
   394  			}
   395  		}
   396  	}
   397  	return nil
   398  }
   399  
   400  func (w *watcher) compileLoop(ctx context.Context) error {
   401  	firstCompile := true
   402  	for {
   403  		select {
   404  		case <-w.compileCh:
   405  		case <-ctx.Done():
   406  			return ctx.Err()
   407  		}
   408  
   409  		recompiledPrefix := ""
   410  		if !firstCompile {
   411  			recompiledPrefix = "re"
   412  		}
   413  
   414  		if (filepath.Ext(w.outputPath) == ".png" || filepath.Ext(w.outputPath) == ".pdf") && !w.pw.Browser.IsConnected() {
   415  			newPW, err := w.pw.RestartBrowser()
   416  			if err != nil {
   417  				broadcastErr := fmt.Errorf("issue encountered with PNG exporter: %w", err)
   418  				w.ms.Log.Error.Print(broadcastErr)
   419  				w.broadcast(&compileResult{
   420  					Err: broadcastErr.Error(),
   421  				})
   422  				continue
   423  			}
   424  			w.pw = newPW
   425  		}
   426  
   427  		fs := trackedFS{}
   428  		w.boardpathMu.Lock()
   429  		var boardPath []string
   430  		if w.boardPath != "" {
   431  			boardPath = strings.Split(w.boardPath, string(os.PathSeparator))
   432  		}
   433  		svg, _, err := compile(ctx, w.ms, w.plugins, &fs, w.layout, w.renderOpts, w.fontFamily, w.animateInterval, w.inputPath, w.outputPath, boardPath, false, w.bundle, w.forceAppendix, w.pw.Page)
   434  		w.boardpathMu.Unlock()
   435  		errs := ""
   436  		if err != nil {
   437  			if len(svg) > 0 {
   438  				err = fmt.Errorf("failed to fully %scompile (rendering partial svg): %w", recompiledPrefix, err)
   439  			} else {
   440  				err = fmt.Errorf("failed to %scompile: %w", recompiledPrefix, err)
   441  			}
   442  			errs = err.Error()
   443  			w.ms.Log.Error.Print(errs)
   444  		}
   445  		err = w.replaceWatchList(ctx, fs.opened)
   446  		if err != nil {
   447  			return err
   448  		}
   449  
   450  		w.broadcast(&compileResult{
   451  			SVG:   string(svg),
   452  			Scale: w.renderOpts.Scale,
   453  			Err:   errs,
   454  		})
   455  
   456  		if firstCompile {
   457  			firstCompile = false
   458  			url := fmt.Sprintf("http://%s", w.l.Addr())
   459  			err = xbrowser.Open(ctx, w.ms.Env, url)
   460  			if err != nil {
   461  				w.ms.Log.Warn.Printf("failed to open browser to %v: %v", url, err)
   462  			}
   463  		}
   464  	}
   465  }
   466  
   467  func (w *watcher) listen() error {
   468  	l, err := net.Listen("tcp", net.JoinHostPort(w.host, w.port))
   469  	if err != nil {
   470  		return err
   471  	}
   472  	w.l = l
   473  	w.ms.Log.Success.Printf("listening on http://%v", w.l.Addr())
   474  	return nil
   475  }
   476  
   477  func (w *watcher) goServe() error {
   478  	m := http.NewServeMux()
   479  	// TODO: Add cmdlog logging and error reporting middleware
   480  	// TODO: Add standard debug/profiling routes
   481  	m.HandleFunc("/", w.handleRoot)
   482  	m.Handle("/static/", http.StripPrefix("/static", w.staticFileServer))
   483  	m.Handle("/watch", xhttp.HandlerFuncAdapter{Log: w.ms.Log, Func: w.handleWatch})
   484  
   485  	s := xhttp.NewServer(w.ms.Log.Warn, xhttp.Log(w.ms.Log, m))
   486  	w.goFunc(func(ctx context.Context) error {
   487  		return xhttp.Serve(ctx, time.Second*30, s, w.l)
   488  	})
   489  
   490  	return nil
   491  }
   492  
   493  func (w *watcher) getRes() *compileResult {
   494  	w.resMu.Lock()
   495  	defer w.resMu.Unlock()
   496  	return w.res
   497  }
   498  
   499  func (w *watcher) handleRoot(hw http.ResponseWriter, r *http.Request) {
   500  	hw.Header().Set("Content-Type", "text/html; charset=utf-8")
   501  	fmt.Fprintf(hw, `<!DOCTYPE html>
   502  <html lang="en">
   503  <head>
   504  	<meta charset="UTF-8">
   505  	<meta name="viewport" content="width=device-width, initial-scale=1.0">
   506  	<title>%s</title>
   507  	<script src="/static/watch.js"></script>
   508  	<link rel="stylesheet" href="/static/watch.css">
   509  	<link id="favicon" rel="icon" href="/static/favicon.ico">
   510  </head>
   511  <body data-d2-dev-mode=%t>
   512  	<div id="d2-err" style="display: none"></div>
   513  	<div id="d2-svg-container"></div>
   514  </body>
   515  </html>`, filepath.Base(w.outputPath), w.devMode)
   516  
   517  	w.boardpathMu.Lock()
   518  	// if path is "/x.svg", we just want "x"
   519  	boardPath := strings.TrimPrefix(r.URL.Path, "/")
   520  	if idx := strings.LastIndexByte(boardPath, '.'); idx != -1 {
   521  		boardPath = boardPath[:idx]
   522  	}
   523  	recompile := false
   524  	if boardPath != w.boardPath {
   525  		w.boardPath = boardPath
   526  		recompile = true
   527  	}
   528  	w.boardpathMu.Unlock()
   529  	if recompile {
   530  		w.requestCompile()
   531  	}
   532  }
   533  
   534  func (w *watcher) handleWatch(hw http.ResponseWriter, r *http.Request) error {
   535  	w.wsclientsMu.Lock()
   536  	if w.closing {
   537  		w.wsclientsMu.Unlock()
   538  		return xhttp.Errorf(http.StatusServiceUnavailable, "server shutting down...", "server shutting down...")
   539  	}
   540  	// We must register ourselves before we even upgrade the connection to ensure that
   541  	// w.close() will wait for us. If we instead registered afterwards, then there is a
   542  	// brief period between the hijack and the registration where close may return without
   543  	// waiting for us to finish.
   544  	w.wsclientsWG.Add(1)
   545  	w.wsclientsMu.Unlock()
   546  
   547  	c, err := websocket.Accept(hw, r, &websocket.AcceptOptions{
   548  		CompressionMode: websocket.CompressionDisabled,
   549  	})
   550  	if err != nil {
   551  		w.wsclientsWG.Done()
   552  		return err
   553  	}
   554  
   555  	go func() {
   556  		defer w.wsclientsWG.Done()
   557  		defer c.Close(websocket.StatusInternalError, "the sky is falling")
   558  
   559  		ctx, cancel := context.WithTimeout(w.ctx, time.Hour)
   560  		defer cancel()
   561  
   562  		cl := &wsclient{
   563  			w:         w,
   564  			resultsCh: make(chan struct{}, 1),
   565  			c:         c,
   566  		}
   567  
   568  		w.wsclientsMu.Lock()
   569  		w.wsclients[cl] = struct{}{}
   570  		w.wsclientsMu.Unlock()
   571  		defer func() {
   572  			w.wsclientsMu.Lock()
   573  			delete(w.wsclients, cl)
   574  			w.wsclientsMu.Unlock()
   575  		}()
   576  
   577  		ctx = cl.c.CloseRead(ctx)
   578  		go wsHeartbeat(ctx, cl.c)
   579  		_ = cl.writeLoop(ctx)
   580  	}()
   581  	return nil
   582  }
   583  
   584  type wsclient struct {
   585  	w         *watcher
   586  	resultsCh chan struct{}
   587  	c         *websocket.Conn
   588  }
   589  
   590  func (cl *wsclient) writeLoop(ctx context.Context) error {
   591  	for {
   592  		res := cl.w.getRes()
   593  		if res != nil {
   594  			err := cl.write(ctx, res)
   595  			if err != nil {
   596  				return err
   597  			}
   598  		}
   599  
   600  		select {
   601  		case <-cl.resultsCh:
   602  		case <-ctx.Done():
   603  			cl.c.Close(websocket.StatusGoingAway, "server shutting down...")
   604  			return ctx.Err()
   605  		}
   606  	}
   607  }
   608  
   609  func (cl *wsclient) write(ctx context.Context, res *compileResult) error {
   610  	ctx, cancel := context.WithTimeout(ctx, time.Second*30)
   611  	defer cancel()
   612  
   613  	return wsjson.Write(ctx, cl.c, res)
   614  }
   615  
   616  func (w *watcher) broadcast(res *compileResult) {
   617  	w.resMu.Lock()
   618  	w.res = res
   619  	w.resMu.Unlock()
   620  
   621  	w.wsclientsMu.Lock()
   622  	defer w.wsclientsMu.Unlock()
   623  	clientsSuffix := ""
   624  	if len(w.wsclients) != 1 {
   625  		clientsSuffix = "s"
   626  	}
   627  	w.ms.Log.Info.Printf("broadcasting update to %d client%s", len(w.wsclients), clientsSuffix)
   628  	for cl := range w.wsclients {
   629  		select {
   630  		case cl.resultsCh <- struct{}{}:
   631  		default:
   632  		}
   633  	}
   634  }
   635  
   636  func wsHeartbeat(ctx context.Context, c *websocket.Conn) {
   637  	defer c.Close(websocket.StatusInternalError, "the sky is falling")
   638  
   639  	t := time.NewTimer(0)
   640  	<-t.C
   641  	for {
   642  		err := c.Ping(ctx)
   643  		if err != nil {
   644  			return
   645  		}
   646  
   647  		t.Reset(time.Second * 30)
   648  		select {
   649  		case <-t.C:
   650  		case <-ctx.Done():
   651  			return
   652  		}
   653  	}
   654  }
   655  
   656  // trackedFS is OS's FS with the addition that it tracks which files are opened successfully
   657  type trackedFS struct {
   658  	opened []string
   659  }
   660  
   661  func (tfs *trackedFS) Open(name string) (fs.File, error) {
   662  	f, err := os.Open(name)
   663  	if err == nil {
   664  		tfs.opened = append(tfs.opened, name)
   665  	}
   666  	return f, err
   667  }
   668  

View as plain text