...

Source file src/github.com/go-kit/kit/auth/jwt/middleware_test.go

Documentation: github.com/go-kit/kit/auth/jwt

     1  package jwt
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"testing"
     7  	"time"
     8  
     9  	"crypto/subtle"
    10  
    11  	"github.com/go-kit/kit/endpoint"
    12  	"github.com/golang-jwt/jwt/v4"
    13  )
    14  
    15  type customClaims struct {
    16  	MyProperty string `json:"my_property"`
    17  	jwt.StandardClaims
    18  }
    19  
    20  func (c customClaims) VerifyMyProperty(p string) bool {
    21  	return subtle.ConstantTimeCompare([]byte(c.MyProperty), []byte(p)) != 0
    22  }
    23  
    24  var (
    25  	kid            = "kid"
    26  	key            = []byte("test_signing_key")
    27  	myProperty     = "some value"
    28  	method         = jwt.SigningMethodHS256
    29  	invalidMethod  = jwt.SigningMethodRS256
    30  	mapClaims      = jwt.MapClaims{"user": "go-kit"}
    31  	standardClaims = jwt.StandardClaims{Audience: "go-kit"}
    32  	myCustomClaims = customClaims{MyProperty: myProperty, StandardClaims: standardClaims}
    33  	// Signed tokens generated at https://jwt.io/
    34  	signedKey         = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
    35  	standardSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJnby1raXQifQ.L5ypIJjCOOv3jJ8G5SelaHvR04UJuxmcBN5QW3m_aoY"
    36  	customSignedKey   = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJteV9wcm9wZXJ0eSI6InNvbWUgdmFsdWUiLCJhdWQiOiJnby1raXQifQ.s8F-IDrV4WPJUsqr7qfDi-3GRlcKR0SRnkTeUT_U-i0"
    37  	invalidKey        = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA"
    38  	malformedKey      = "malformed.jwt.token"
    39  )
    40  
    41  func signingValidator(t *testing.T, signer endpoint.Endpoint, expectedKey string) {
    42  	ctx, err := signer(context.Background(), struct{}{})
    43  	if err != nil {
    44  		t.Fatalf("Signer returned error: %s", err)
    45  	}
    46  
    47  	token, ok := ctx.(context.Context).Value(JWTContextKey).(string)
    48  	if !ok {
    49  		t.Fatal("Token did not exist in context")
    50  	}
    51  
    52  	if token != expectedKey {
    53  		t.Fatalf("JWTs did not match: expecting %s got %s", expectedKey, token)
    54  	}
    55  }
    56  
    57  func TestNewSigner(t *testing.T) {
    58  	e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }
    59  
    60  	signer := NewSigner(kid, key, method, mapClaims)(e)
    61  	signingValidator(t, signer, signedKey)
    62  
    63  	signer = NewSigner(kid, key, method, standardClaims)(e)
    64  	signingValidator(t, signer, standardSignedKey)
    65  
    66  	signer = NewSigner(kid, key, method, myCustomClaims)(e)
    67  	signingValidator(t, signer, customSignedKey)
    68  }
    69  
    70  func TestJWTParser(t *testing.T) {
    71  	e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }
    72  
    73  	keys := func(token *jwt.Token) (interface{}, error) {
    74  		return key, nil
    75  	}
    76  
    77  	parser := NewParser(keys, method, MapClaimsFactory)(e)
    78  
    79  	// No Token is passed into the parser
    80  	_, err := parser(context.Background(), struct{}{})
    81  	if err == nil {
    82  		t.Error("Parser should have returned an error")
    83  	}
    84  
    85  	if err != ErrTokenContextMissing {
    86  		t.Errorf("unexpected error returned, expected: %s got: %s", ErrTokenContextMissing, err)
    87  	}
    88  
    89  	// Invalid Token is passed into the parser
    90  	ctx := context.WithValue(context.Background(), JWTContextKey, invalidKey)
    91  	_, err = parser(ctx, struct{}{})
    92  	if err == nil {
    93  		t.Error("Parser should have returned an error")
    94  	}
    95  
    96  	// Invalid Method is used in the parser
    97  	badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e)
    98  	ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
    99  	_, err = badParser(ctx, struct{}{})
   100  	if err == nil {
   101  		t.Error("Parser should have returned an error")
   102  	}
   103  
   104  	if err != ErrUnexpectedSigningMethod {
   105  		t.Errorf("unexpected error returned, expected: %s got: %s", ErrUnexpectedSigningMethod, err)
   106  	}
   107  
   108  	// Invalid key is used in the parser
   109  	invalidKeys := func(token *jwt.Token) (interface{}, error) {
   110  		return []byte("bad"), nil
   111  	}
   112  
   113  	badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e)
   114  	ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
   115  	_, err = badParser(ctx, struct{}{})
   116  	if err == nil {
   117  		t.Error("Parser should have returned an error")
   118  	}
   119  
   120  	// Correct token is passed into the parser
   121  	ctx = context.WithValue(context.Background(), JWTContextKey, signedKey)
   122  	ctx1, err := parser(ctx, struct{}{})
   123  	if err != nil {
   124  		t.Fatalf("Parser returned error: %s", err)
   125  	}
   126  
   127  	cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(jwt.MapClaims)
   128  	if !ok {
   129  		t.Fatal("Claims were not passed into context correctly")
   130  	}
   131  
   132  	if cl["user"] != mapClaims["user"] {
   133  		t.Fatalf("JWT Claims.user did not match: expecting %s got %s", mapClaims["user"], cl["user"])
   134  	}
   135  
   136  	// Test for malformed token error response
   137  	parser = NewParser(keys, method, StandardClaimsFactory)(e)
   138  	ctx = context.WithValue(context.Background(), JWTContextKey, malformedKey)
   139  	ctx1, err = parser(ctx, struct{}{})
   140  	if want, have := ErrTokenMalformed, err; want != have {
   141  		t.Fatalf("Expected %+v, got %+v", want, have)
   142  	}
   143  
   144  	// Test for expired token error response
   145  	parser = NewParser(keys, method, StandardClaimsFactory)(e)
   146  	expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100})
   147  	token, err := expired.SignedString(key)
   148  	if err != nil {
   149  		t.Fatalf("Unable to Sign Token: %+v", err)
   150  	}
   151  	ctx = context.WithValue(context.Background(), JWTContextKey, token)
   152  	ctx1, err = parser(ctx, struct{}{})
   153  	if want, have := ErrTokenExpired, err; want != have {
   154  		t.Fatalf("Expected %+v, got %+v", want, have)
   155  	}
   156  
   157  	// Test for not activated token error response
   158  	parser = NewParser(keys, method, StandardClaimsFactory)(e)
   159  	notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100})
   160  	token, err = notactive.SignedString(key)
   161  	if err != nil {
   162  		t.Fatalf("Unable to Sign Token: %+v", err)
   163  	}
   164  	ctx = context.WithValue(context.Background(), JWTContextKey, token)
   165  	ctx1, err = parser(ctx, struct{}{})
   166  	if want, have := ErrTokenNotActive, err; want != have {
   167  		t.Fatalf("Expected %+v, got %+v", want, have)
   168  	}
   169  
   170  	// test valid standard claims token
   171  	parser = NewParser(keys, method, StandardClaimsFactory)(e)
   172  	ctx = context.WithValue(context.Background(), JWTContextKey, standardSignedKey)
   173  	ctx1, err = parser(ctx, struct{}{})
   174  	if err != nil {
   175  		t.Fatalf("Parser returned error: %s", err)
   176  	}
   177  	stdCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*jwt.StandardClaims)
   178  	if !ok {
   179  		t.Fatal("Claims were not passed into context correctly")
   180  	}
   181  	if !stdCl.VerifyAudience("go-kit", true) {
   182  		t.Fatalf("JWT jwt.StandardClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, stdCl.Audience)
   183  	}
   184  
   185  	// test valid customized claims token
   186  	parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e)
   187  	ctx = context.WithValue(context.Background(), JWTContextKey, customSignedKey)
   188  	ctx1, err = parser(ctx, struct{}{})
   189  	if err != nil {
   190  		t.Fatalf("Parser returned error: %s", err)
   191  	}
   192  	custCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*customClaims)
   193  	if !ok {
   194  		t.Fatal("Claims were not passed into context correctly")
   195  	}
   196  	if !custCl.VerifyAudience("go-kit", true) {
   197  		t.Fatalf("JWT customClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, custCl.Audience)
   198  	}
   199  	if !custCl.VerifyMyProperty(myProperty) {
   200  		t.Fatalf("JWT customClaims.MyProperty did not match: expecting %s got %s", myProperty, custCl.MyProperty)
   201  	}
   202  }
   203  
   204  func TestIssue562(t *testing.T) {
   205  	var (
   206  		kf  = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil }
   207  		e   = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop)
   208  		key = JWTContextKey
   209  		val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
   210  		ctx = context.WithValue(context.Background(), key, val)
   211  	)
   212  	wg := sync.WaitGroup{}
   213  	for i := 0; i < 100; i++ {
   214  		wg.Add(1)
   215  		go func() {
   216  			defer wg.Done()
   217  			e(ctx, struct{}{}) // fatal error: concurrent map read and map write
   218  		}()
   219  	}
   220  	wg.Wait()
   221  }
   222  

View as plain text