...

Source file src/github.com/MicahParks/keyfunc/override_test.go

Documentation: github.com/MicahParks/keyfunc

     1  package keyfunc_test
     2  
     3  import (
     4  	"crypto/rand"
     5  	"crypto/rsa"
     6  	"encoding/base64"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"math/big"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"os"
    14  	"path/filepath"
    15  	"testing"
    16  
    17  	"github.com/golang-jwt/jwt/v4"
    18  
    19  	"github.com/MicahParks/keyfunc"
    20  )
    21  
    22  const (
    23  	givenKID  = "givenKID"
    24  	remoteKID = "remoteKID"
    25  )
    26  
    27  // pseudoJWKS is a data structure used to JSON marshal a JWKS but is not fully featured.
    28  type pseudoJWKS struct {
    29  	Keys []pseudoJSONKey `json:"keys"`
    30  }
    31  
    32  // pseudoJSONKey is a data structure that is used to JSON marshal a JWK that is not fully featured.
    33  type pseudoJSONKey struct {
    34  	KID string `json:"kid"`
    35  	KTY string `json:"kty"`
    36  	E   string `json:"e"`
    37  	N   string `json:"n"`
    38  }
    39  
    40  // TestNewGiven tests that given keys will be added to a JWKS with a remote resource.
    41  func TestNewGiven(t *testing.T) {
    42  	tempDir, err := os.MkdirTemp("", "*")
    43  	if err != nil {
    44  		t.Fatalf(logFmt, "Failed to create a temporary directory.", err)
    45  	}
    46  	defer func() {
    47  		err = os.RemoveAll(tempDir)
    48  		if err != nil {
    49  			t.Fatalf(logFmt, "Failed to remove temporary directory.", err)
    50  		}
    51  	}()
    52  
    53  	jwksFile := filepath.Join(tempDir, jwksFilePath)
    54  
    55  	givenKeys, givenPrivateKeys, jwksBytes, remotePrivateKeys, err := keysAndJWKS()
    56  	if err != nil {
    57  		t.Fatalf(logFmt, "Failed to create cryptographic keys for the test.", err)
    58  	}
    59  
    60  	err = os.WriteFile(jwksFile, jwksBytes, 0600)
    61  	if err != nil {
    62  		t.Fatalf(logFmt, "Failed to write JWKS file to temporary directory.", err)
    63  	}
    64  
    65  	server := httptest.NewServer(http.FileServer(http.Dir(tempDir)))
    66  	defer server.Close()
    67  
    68  	testingRefreshErrorHandler := func(err error) {
    69  		panic(fmt.Sprintf(logFmt, "Unhandled JWKS error.", err))
    70  	}
    71  
    72  	jwksURL := server.URL + jwksFilePath
    73  
    74  	options := keyfunc.Options{
    75  		GivenKeys:           givenKeys,
    76  		RefreshErrorHandler: testingRefreshErrorHandler,
    77  	}
    78  
    79  	jwks, err := keyfunc.Get(jwksURL, options)
    80  	if err != nil {
    81  		t.Fatalf(logFmt, "Failed to get the JWKS the testing URL.", err)
    82  	}
    83  
    84  	// Test the given key with a unique key ID.
    85  	createSignParseValidate(t, givenPrivateKeys, jwks, givenKID, true)
    86  
    87  	// Test the given key with a non-unique key ID that should be overwritten.
    88  	createSignParseValidate(t, givenPrivateKeys, jwks, remoteKID, false)
    89  
    90  	// Test the remote key that should not have been overwritten.
    91  	createSignParseValidate(t, remotePrivateKeys, jwks, remoteKID, true)
    92  
    93  	// Change the JWKS options to overwrite remote keys.
    94  	options.GivenKIDOverride = true
    95  	jwks, err = keyfunc.Get(jwksURL, options)
    96  	if err != nil {
    97  		t.Fatalf(logFmt, "Failed to recreate JWKS.", err)
    98  	}
    99  
   100  	// Test the given key with a unique key ID.
   101  	createSignParseValidate(t, givenPrivateKeys, jwks, givenKID, true)
   102  
   103  	// Test the given key with a non-unique key ID that should overwrite the remote key.
   104  	createSignParseValidate(t, givenPrivateKeys, jwks, remoteKID, true)
   105  
   106  	// Test the remote key that should have been overwritten.
   107  	createSignParseValidate(t, remotePrivateKeys, jwks, remoteKID, false)
   108  }
   109  
   110  // createSignParseValidate creates, signs, parses, and validates a JWT.
   111  func createSignParseValidate(t *testing.T, keys map[string]*rsa.PrivateKey, jwks *keyfunc.JWKS, kid string, shouldValidate bool) {
   112  	unsignedToken := jwt.New(jwt.SigningMethodRS256)
   113  	unsignedToken.Header[kidAttribute] = kid
   114  
   115  	jwtB64, err := unsignedToken.SignedString(keys[kid])
   116  	if err != nil {
   117  		t.Fatalf(logFmt, "Failed to sign the JWT.", err)
   118  	}
   119  
   120  	token, err := jwt.Parse(jwtB64, jwks.Keyfunc)
   121  	if err != nil {
   122  		if !shouldValidate && errors.Is(err, rsa.ErrVerification) {
   123  			return
   124  		}
   125  		t.Fatalf(logFmt, "Failed to parse the JWT.", err)
   126  	}
   127  
   128  	if !shouldValidate {
   129  		t.Fatalf("The token should not have parsed properly.")
   130  	}
   131  
   132  	if !token.Valid {
   133  		t.Fatalf("The JWT is not valid.")
   134  	}
   135  }
   136  
   137  // keysAndJWKS creates a couple of cryptographic keys and the remote JWKS associated with them.
   138  func keysAndJWKS() (givenKeys map[string]keyfunc.GivenKey, givenPrivateKeys map[string]*rsa.PrivateKey, jwksBytes []byte, remotePrivateKeys map[string]*rsa.PrivateKey, err error) {
   139  	const rsaErrMessage = "failed to create RSA key: %w"
   140  	givenKeys = make(map[string]keyfunc.GivenKey)
   141  	givenPrivateKeys = make(map[string]*rsa.PrivateKey)
   142  	remotePrivateKeys = make(map[string]*rsa.PrivateKey)
   143  
   144  	// Create a key not in the remote JWKS.
   145  	key1, err := addRSA(givenKeys, givenKID)
   146  	if err != nil {
   147  		return nil, nil, nil, nil, fmt.Errorf(rsaErrMessage, err)
   148  	}
   149  	givenPrivateKeys[givenKID] = key1
   150  
   151  	// Create a key to be overwritten by or override the one with the same key ID in the remote JWKS.
   152  	key2, err := addRSA(givenKeys, remoteKID)
   153  	if err != nil {
   154  		return nil, nil, nil, nil, fmt.Errorf(rsaErrMessage, err)
   155  	}
   156  	givenPrivateKeys[remoteKID] = key2
   157  
   158  	// Create a key that exists in the remote JWKS.
   159  	key3, err := rsa.GenerateKey(rand.Reader, 2048)
   160  	if err != nil {
   161  		return nil, nil, nil, nil, fmt.Errorf(rsaErrMessage, err)
   162  	}
   163  	remotePrivateKeys[remoteKID] = key3
   164  
   165  	jwks := pseudoJWKS{Keys: []pseudoJSONKey{{
   166  		KID: remoteKID,
   167  		KTY: "RSA",
   168  		E:   base64.RawURLEncoding.EncodeToString(big.NewInt(int64(key3.PublicKey.E)).Bytes()),
   169  		N:   base64.RawURLEncoding.EncodeToString(key3.PublicKey.N.Bytes()),
   170  	}}}
   171  
   172  	jwksBytes, err = json.Marshal(jwks)
   173  	if err != nil {
   174  		return nil, nil, nil, nil, fmt.Errorf("failed to marshal the JWKS to JSON: %w", err)
   175  	}
   176  
   177  	return givenKeys, givenPrivateKeys, jwksBytes, remotePrivateKeys, nil
   178  }
   179  

View as plain text