...

Source file src/github.com/MicahParks/keyfunc/v2/checksum_test.go

Documentation: github.com/MicahParks/keyfunc/v2

     1  package keyfunc_test
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"os"
     8  	"path/filepath"
     9  	"reflect"
    10  	"testing"
    11  
    12  	"github.com/golang-jwt/jwt/v5"
    13  
    14  	"github.com/MicahParks/keyfunc/v2"
    15  )
    16  
    17  // TestChecksum confirms that the JWKS will only perform a refresh if a new JWKS is read from the remote resource.
    18  func TestChecksum(t *testing.T) {
    19  	tempDir, err := os.MkdirTemp("", "*")
    20  	if err != nil {
    21  		t.Fatalf(logFmt, "Failed to create a temporary directory.", err)
    22  	}
    23  	defer func() {
    24  		err = os.RemoveAll(tempDir)
    25  		if err != nil {
    26  			t.Fatalf(logFmt, "Failed to remove temporary directory.", err)
    27  		}
    28  	}()
    29  
    30  	jwksFile := filepath.Join(tempDir, jwksFilePath)
    31  
    32  	err = os.WriteFile(jwksFile, []byte(jwksJSON), 0600)
    33  	if err != nil {
    34  		t.Fatalf(logFmt, "Failed to write JWKS file to temporary directory.", err)
    35  	}
    36  
    37  	server := httptest.NewServer(http.FileServer(http.Dir(tempDir)))
    38  	defer server.Close()
    39  
    40  	testingRefreshErrorHandler := func(err error) {
    41  		panic(fmt.Sprintf(logFmt, "Unhandled JWKS error.", err))
    42  	}
    43  	opts := keyfunc.Options{
    44  		RefreshErrorHandler: testingRefreshErrorHandler,
    45  		RefreshUnknownKID:   true,
    46  	}
    47  
    48  	jwksURL := server.URL + jwksFilePath
    49  
    50  	jwks, err := keyfunc.Get(jwksURL, opts)
    51  	if err != nil {
    52  		t.Fatalf(logFmt, "Failed to get JWKS from testing URL.", err)
    53  	}
    54  	defer jwks.EndBackground()
    55  
    56  	cryptoKeyPointers := make(map[string]interface{})
    57  	for kid, cryptoKey := range jwks.ReadOnlyKeys() {
    58  		cryptoKeyPointers[kid] = cryptoKey
    59  	}
    60  
    61  	// Create a JWT that will not be in the JWKS.
    62  	token := jwt.New(jwt.SigningMethodHS256)
    63  	token.Header["kid"] = "unknown"
    64  	signed, err := token.SignedString([]byte("test"))
    65  	if err != nil {
    66  		t.Fatalf(logFmt, "Failed to sign test JWT.", err)
    67  	}
    68  
    69  	// Force the JWKS to refresh.
    70  	_, _ = jwt.Parse(signed, jwks.Keyfunc)
    71  
    72  	// Confirm the keys in the JWKS have not been refreshed.
    73  	newKeys := jwks.ReadOnlyKeys()
    74  	if len(newKeys) != len(cryptoKeyPointers) {
    75  		t.Fatalf("The number of keys should not be different.")
    76  	}
    77  	for kid, cryptoKey := range newKeys {
    78  		if !reflect.DeepEqual(cryptoKeyPointers[kid], cryptoKey) {
    79  			t.Fatalf("The JWKS should not have refreshed without a checksum change.")
    80  		}
    81  	}
    82  
    83  	// Write a different JWKS.
    84  	_, _, jwksBytes, _, err := keysAndJWKS()
    85  	if err != nil {
    86  		t.Fatalf(logFmt, "Failed to create a test JWKS.", err)
    87  	}
    88  	err = os.WriteFile(jwksFile, jwksBytes, 0600)
    89  	if err != nil {
    90  		t.Fatalf(logFmt, "Failed to write JWKS file to temporary directory.", err)
    91  	}
    92  
    93  	// Force the JWKS to refresh.
    94  	_, _ = jwt.Parse(signed, jwks.Keyfunc)
    95  
    96  	// Confirm the keys in the JWKS have been refreshed.
    97  	newKeys = jwks.ReadOnlyKeys()
    98  	different := false
    99  	for kid, cryptoKey := range newKeys {
   100  		if !reflect.DeepEqual(cryptoKeyPointers[kid], cryptoKey) {
   101  			different = true
   102  			break
   103  		}
   104  	}
   105  	if !different {
   106  		t.Fatalf("A different JWKS checksum should have triggered a JWKS refresh.")
   107  	}
   108  }
   109  

View as plain text