package imgbundler

import (
	"context"
	"crypto/rand"
	_ "embed"
	"fmt"
	"net/http"
	"net/http/httptest"
	"path/filepath"
	"strings"
	"sync"
	"testing"

	"cdr.dev/slog/sloggers/slogtest"
	tassert "github.com/stretchr/testify/assert"

	"oss.terrastruct.com/d2/lib/log"
	"oss.terrastruct.com/d2/lib/simplelog"
)

//go:embed test_png.png
var testPNGFile []byte

// sampleSVGTemplate is a minimal diagram with two <image> elements, labelled
// "a" and "b". The two %s verbs receive the image hrefs under test.
const sampleSVGTemplate = `<?xml version="1.0" encoding="utf-8"?>
<svg xmlns="http://www.w3.org/2000/svg" width="328" height="587" viewBox="-100 -131 328 587">
<g id="a"><image href="%s" x="0" y="0" width="128" height="128" /><text x="64" y="-15">a</text></g>
<g id="b"><image href="%s" x="0" y="228" width="128" height="128" /><text x="64" y="213">b</text></g>
</svg>
`

// testSVGBody is the payload that stubbed transports serve for remote SVG
// URLs. It begins with an XML declaration so it is recognizable as an
// SVG/XML document, like a real icon file.
const testSVGBody = `<?xml version="1.0" encoding="utf-8"?>
<svg version="1.1" xmlns="http://www.w3.org/2000/svg" width="64" height="64" viewBox="0 0 64 64">
<rect x="4" y="4" width="56" height="56" fill="#87cefa" />
</svg>
`

// roundTripFunc adapts a plain function into an http.RoundTripper so tests
// can stub out network access.
type roundTripFunc func(req *http.Request) *http.Response

// RoundTrip implements http.RoundTripper; the stub never returns an error.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req), nil
}

// stubTransport installs fn as the package httpClient's transport for the
// rest of the test and restores the previous transport on cleanup, so the
// stub cannot leak into other tests in this package.
func stubTransport(t *testing.T, fn roundTripFunc) {
	t.Helper()
	prev := httpClient.Transport
	httpClient.Transport = fn
	t.Cleanup(func() { httpClient.Transport = prev })
}

// respond builds an *http.Response with the given status code and body.
// The status is written before the body; writing it afterwards would be
// silently ignored by httptest.ResponseRecorder.
func respond(status int, body []byte) *http.Response {
	rec := httptest.NewRecorder()
	rec.WriteHeader(status)
	_, _ = rec.Write(body) // httptest.ResponseRecorder.Write never fails
	return rec.Result()
}

// fillRandom fills b with bytes from crypto/rand and returns it, failing the
// test if the random source errors. Call it from the test goroutine.
func fillRandom(t *testing.T, b []byte) []byte {
	t.Helper()
	if _, err := rand.Read(b); err != nil {
		t.Fatal(err)
	}
	return b
}

// TestRegex checks that imageRegex finds exactly one <image> href for both
// remote URLs and local paths.
func TestRegex(t *testing.T) {
	urls := []string{
		"https://icons.terrastruct.com/essentials/004-picture.svg",
		"http://icons.terrastruct.com/essentials/004-picture.svg",
	}

	notURLs := []string{
		"hi.png",
		"./cat.png",
		"/cat.png",
	}

	for _, href := range append(urls, notURLs...) {
		str := fmt.Sprintf(`<image href="%s" />`, href)
		matches := imageRegex.FindAllStringSubmatch(str, -1)
		if len(matches) != 1 {
			t.Fatalf("uri regex didn't match %s", str)
		}
	}
}

// TestInlineRemote checks that BundleRemote inlines remote SVG and PNG images,
// accepts a body of exactly maxImageSize bytes, and errors on oversized
// bodies and on a 500 response.
func TestInlineRemote(t *testing.T) {
	imgCache = sync.Map{}
	// we don't want log.Error to cause this test to fail
	ctx := log.WithTB(context.Background(), t, &slogtest.Options{IgnoreErrors: true})
	l := simplelog.FromLibLog(ctx)

	svgURL := "https://icons.terrastruct.com/essentials/004-picture.svg"
	pngURL := "https://cdn4.iconfinder.com/data/icons/smart-phones-technologies/512/android-phone.png"
	sampleSVG := fmt.Sprintf(sampleSVGTemplate, svgURL, pngURL)

	// The transport may run off the test goroutine, where t.Fatal is not
	// allowed, so unexpected requests are reported with t.Errorf.
	stubTransport(t, func(req *http.Request) *http.Response {
		switch req.URL.String() {
		case svgURL:
			return respond(http.StatusOK, []byte(testSVGBody))
		case pngURL:
			return respond(http.StatusOK, testPNGFile)
		default:
			t.Errorf("unexpected request to %s", req.URL)
			return respond(http.StatusNotFound, nil)
		}
	})
	out, err := BundleRemote(ctx, l, []byte(sampleSVG), false)
	if err != nil {
		t.Fatal(err)
	}
	if strings.Contains(string(out), "https://") {
		t.Fatal("links still exist")
	}
	if !strings.Contains(string(out), "image/svg+xml") {
		t.Fatal("no svg image inserted")
	}
	if !strings.Contains(string(out), "image/png") {
		t.Fatal("no png image inserted")
	}

	// A body of exactly maxImageSize bytes must still be accepted.
	imgCache = sync.Map{}
	maxBody := fillRandom(t, make([]byte, maxImageSize))
	stubTransport(t, func(*http.Request) *http.Response {
		return respond(http.StatusOK, maxBody)
	})
	if _, err := BundleRemote(ctx, l, []byte(sampleSVG), false); err != nil {
		t.Fatal(err)
	}

	// A body one byte larger than maxImageSize must be rejected.
	imgCache = sync.Map{}
	oversizedBody := fillRandom(t, make([]byte, maxImageSize+1))
	stubTransport(t, func(*http.Request) *http.Response {
		return respond(http.StatusOK, oversizedBody)
	})
	if _, err := BundleRemote(ctx, l, []byte(sampleSVG), false); err == nil {
		t.Fatal("expected error")
	}

	// A 500 response must be reported as an error.
	imgCache = sync.Map{}
	stubTransport(t, func(*http.Request) *http.Response {
		return respond(http.StatusInternalServerError, nil)
	})
	if _, err := BundleRemote(ctx, l, []byte(sampleSVG), false); err == nil {
		t.Fatal("expected error")
	}
}

// TestInlineLocal checks that BundleLocal inlines SVG and PNG images
// referenced by absolute local paths and removes those paths from the output.
func TestInlineLocal(t *testing.T) {
	imgCache = sync.Map{}
	ctx := log.WithTB(context.Background(), t, nil)
	svgURL, err := filepath.Abs("./test_svg.svg")
	if err != nil {
		t.Fatal(err)
	}
	pngURL, err := filepath.Abs("./test_png.png")
	if err != nil {
		t.Fatal(err)
	}
	sampleSVG := fmt.Sprintf(sampleSVGTemplate, svgURL, pngURL)

	l := simplelog.FromLibLog(ctx)
	out, err := BundleLocal(ctx, l, []byte(sampleSVG), false)
	if err != nil {
		t.Fatal(err)
	}
	if strings.Contains(string(out), svgURL) || strings.Contains(string(out), pngURL) {
		t.Fatal("links still exist")
	}
	if !strings.Contains(string(out), "image/svg+xml") {
		t.Fatal("no svg image inserted")
	}
	if !strings.Contains(string(out), "image/png") {
		t.Fatal("no png image inserted")
	}
}

// TestDuplicateURL ensures that we don't fetch the same image twice
func TestDuplicateURL(t *testing.T) {
	imgCache = sync.Map{}
	ctx := log.WithTB(context.Background(), t, nil)
	url1 := "https://icons.terrastruct.com/essentials/004-picture.svg"
	url2 := "https://icons.terrastruct.com/essentials/004-picture.svg"
	sampleSVG := fmt.Sprintf(sampleSVGTemplate, url1, url2)

	count := 0
	stubTransport(t, func(*http.Request) *http.Response {
		count++
		return respond(http.StatusOK, []byte(testSVGBody))
	})

	l := simplelog.FromLibLog(ctx)
	out, err := BundleRemote(ctx, l, []byte(sampleSVG), false)
	if err != nil {
		t.Fatal(err)
	}
	tassert.Equal(t, 1, count)
	if strings.Contains(string(out), url1) {
		t.Fatal("links still exist")
	}
	tassert.Equal(t, 2, strings.Count(string(out), "image/svg+xml"))
}

// TestImgCache checks that, with caching enabled, repeated BundleRemote runs
// reuse the cached image, and that with caching disabled every run refetches.
func TestImgCache(t *testing.T) {
	imgCache = sync.Map{}
	ctx := log.WithTB(context.Background(), t, nil)
	url1 := "https://icons.terrastruct.com/essentials/004-picture.svg"
	url2 := "https://icons.terrastruct.com/essentials/004-picture.svg"
	sampleSVG := fmt.Sprintf(sampleSVGTemplate, url1, url2)

	count := 0
	stubTransport(t, func(*http.Request) *http.Response {
		count++
		return respond(http.StatusOK, []byte(testSVGBody))
	})

	l := simplelog.FromLibLog(ctx)

	// Using a cache, imgs are not refetched on multiple runs
	_, err := BundleRemote(ctx, l, []byte(sampleSVG), true)
	if err != nil {
		t.Fatal(err)
	}
	_, err = BundleRemote(ctx, l, []byte(sampleSVG), true)
	if err != nil {
		t.Fatal(err)
	}
	tassert.Equal(t, 1, count)

	// With cache disabled, it refetches
	count = 0
	_, err = BundleRemote(ctx, l, []byte(sampleSVG), false)
	if err != nil {
		t.Fatal(err)
	}
	_, err = BundleRemote(ctx, l, []byte(sampleSVG), false)
	if err != nil {
		t.Fatal(err)
	}
	tassert.Equal(t, 2, count)
}