...
1 package keyfunc_test
2
3 import (
4 "fmt"
5 "net/http"
6 "net/http/httptest"
7 "os"
8 "path/filepath"
9 "reflect"
10 "testing"
11
12 "github.com/golang-jwt/jwt/v5"
13
14 "github.com/MicahParks/keyfunc/v2"
15 )
16
17
18 func TestChecksum(t *testing.T) {
19 tempDir, err := os.MkdirTemp("", "*")
20 if err != nil {
21 t.Fatalf(logFmt, "Failed to create a temporary directory.", err)
22 }
23 defer func() {
24 err = os.RemoveAll(tempDir)
25 if err != nil {
26 t.Fatalf(logFmt, "Failed to remove temporary directory.", err)
27 }
28 }()
29
30 jwksFile := filepath.Join(tempDir, jwksFilePath)
31
32 err = os.WriteFile(jwksFile, []byte(jwksJSON), 0600)
33 if err != nil {
34 t.Fatalf(logFmt, "Failed to write JWKS file to temporary directory.", err)
35 }
36
37 server := httptest.NewServer(http.FileServer(http.Dir(tempDir)))
38 defer server.Close()
39
40 testingRefreshErrorHandler := func(err error) {
41 panic(fmt.Sprintf(logFmt, "Unhandled JWKS error.", err))
42 }
43 opts := keyfunc.Options{
44 RefreshErrorHandler: testingRefreshErrorHandler,
45 RefreshUnknownKID: true,
46 }
47
48 jwksURL := server.URL + jwksFilePath
49
50 jwks, err := keyfunc.Get(jwksURL, opts)
51 if err != nil {
52 t.Fatalf(logFmt, "Failed to get JWKS from testing URL.", err)
53 }
54 defer jwks.EndBackground()
55
56 cryptoKeyPointers := make(map[string]interface{})
57 for kid, cryptoKey := range jwks.ReadOnlyKeys() {
58 cryptoKeyPointers[kid] = cryptoKey
59 }
60
61
62 token := jwt.New(jwt.SigningMethodHS256)
63 token.Header["kid"] = "unknown"
64 signed, err := token.SignedString([]byte("test"))
65 if err != nil {
66 t.Fatalf(logFmt, "Failed to sign test JWT.", err)
67 }
68
69
70 _, _ = jwt.Parse(signed, jwks.Keyfunc)
71
72
73 newKeys := jwks.ReadOnlyKeys()
74 if len(newKeys) != len(cryptoKeyPointers) {
75 t.Fatalf("The number of keys should not be different.")
76 }
77 for kid, cryptoKey := range newKeys {
78 if !reflect.DeepEqual(cryptoKeyPointers[kid], cryptoKey) {
79 t.Fatalf("The JWKS should not have refreshed without a checksum change.")
80 }
81 }
82
83
84 _, _, jwksBytes, _, err := keysAndJWKS()
85 if err != nil {
86 t.Fatalf(logFmt, "Failed to create a test JWKS.", err)
87 }
88 err = os.WriteFile(jwksFile, jwksBytes, 0600)
89 if err != nil {
90 t.Fatalf(logFmt, "Failed to write JWKS file to temporary directory.", err)
91 }
92
93
94 _, _ = jwt.Parse(signed, jwks.Keyfunc)
95
96
97 newKeys = jwks.ReadOnlyKeys()
98 different := false
99 for kid, cryptoKey := range newKeys {
100 if !reflect.DeepEqual(cryptoKeyPointers[kid], cryptoKey) {
101 different = true
102 break
103 }
104 }
105 if !different {
106 t.Fatalf("A different JWKS checksum should have triggered a JWKS refresh.")
107 }
108 }
109
View as plain text