...

Source file src/github.com/ory/fosite/token/jwt/jwt_test.go

Documentation: github.com/ory/fosite/token/jwt

     1  /*
     2   * Copyright © 2015-2018 Aeneas Rekkas <aeneas+oss@aeneas.io>
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   * @author		Aeneas Rekkas <aeneas+oss@aeneas.io>
    17   * @copyright 	2015-2018 Aeneas Rekkas <aeneas+oss@aeneas.io>
    18   * @license 	Apache-2.0
    19   *
    20   */
    21  
    22  package jwt
    23  
    24  import (
    25  	"context"
    26  	"fmt"
    27  	"strings"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/stretchr/testify/assert"
    32  	"github.com/stretchr/testify/require"
    33  )
    34  
    35  var header = &Headers{
    36  	Extra: map[string]interface{}{
    37  		"foo": "bar",
    38  	},
    39  }
    40  
    41  func TestHash(t *testing.T) {
    42  	for k, tc := range []struct {
    43  		d        string
    44  		strategy JWTStrategy
    45  	}{
    46  		{
    47  			d: "RS256JWTStrategy",
    48  			strategy: &RS256JWTStrategy{
    49  				PrivateKey: MustRSAKey(),
    50  			},
    51  		},
    52  		{
    53  			d: "ES256JWTStrategy",
    54  			strategy: &ES256JWTStrategy{
    55  				PrivateKey: MustECDSAKey(),
    56  			},
    57  		},
    58  	} {
    59  		t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) {
    60  			in := []byte("foo")
    61  			out, err := tc.strategy.Hash(context.TODO(), in)
    62  			assert.NoError(t, err)
    63  			assert.NotEqual(t, in, out)
    64  		})
    65  	}
    66  }
    67  
    68  func TestAssign(t *testing.T) {
    69  	for k, c := range [][]map[string]interface{}{
    70  		{
    71  			{"foo": "bar"},
    72  			{"baz": "bar"},
    73  			{"foo": "bar", "baz": "bar"},
    74  		},
    75  		{
    76  			{"foo": "bar"},
    77  			{"foo": "baz"},
    78  			{"foo": "bar"},
    79  		},
    80  		{
    81  			{},
    82  			{"foo": "baz"},
    83  			{"foo": "baz"},
    84  		},
    85  		{
    86  			{"foo": "bar"},
    87  			{"foo": "baz", "bar": "baz"},
    88  			{"foo": "bar", "bar": "baz"},
    89  		},
    90  	} {
    91  		assert.EqualValues(t, c[2], assign(c[0], c[1]), "Case %d", k)
    92  	}
    93  }
    94  
    95  func TestGenerateJWT(t *testing.T) {
    96  	for k, tc := range []struct {
    97  		d        string
    98  		strategy JWTStrategy
    99  		resetKey func(strategy JWTStrategy)
   100  	}{
   101  		{
   102  			d: "RS256JWTStrategy",
   103  			strategy: &RS256JWTStrategy{
   104  				PrivateKey: MustRSAKey(),
   105  			},
   106  			resetKey: func(strategy JWTStrategy) {
   107  				strategy.(*RS256JWTStrategy).PrivateKey = MustRSAKey()
   108  			},
   109  		},
   110  		{
   111  			d: "ES256JWTStrategy",
   112  			strategy: &ES256JWTStrategy{
   113  				PrivateKey: MustECDSAKey(),
   114  			},
   115  			resetKey: func(strategy JWTStrategy) {
   116  				strategy.(*ES256JWTStrategy).PrivateKey = MustECDSAKey()
   117  			},
   118  		},
   119  	} {
   120  		t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) {
   121  			claims := &JWTClaims{
   122  				ExpiresAt: time.Now().UTC().Add(time.Hour),
   123  			}
   124  
   125  			token, sig, err := tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header)
   126  			require.NoError(t, err)
   127  			require.NotNil(t, token)
   128  
   129  			sig, err = tc.strategy.Validate(context.TODO(), token)
   130  			require.NoError(t, err)
   131  
   132  			sig, err = tc.strategy.Validate(context.TODO(), token+"."+"0123456789")
   133  			require.Error(t, err)
   134  
   135  			partToken := strings.Split(token, ".")[2]
   136  
   137  			sig, err = tc.strategy.Validate(context.TODO(), partToken)
   138  			require.Error(t, err)
   139  
   140  			// Reset private key
   141  			tc.resetKey(tc.strategy)
   142  
   143  			// Lets validate the exp claim
   144  			claims = &JWTClaims{
   145  				ExpiresAt: time.Now().UTC().Add(-time.Hour),
   146  			}
   147  			token, sig, err = tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header)
   148  			require.NoError(t, err)
   149  			require.NotNil(t, token)
   150  
   151  			sig, err = tc.strategy.Validate(context.TODO(), token)
   152  			require.Error(t, err)
   153  
   154  			// Lets validate the nbf claim
   155  			claims = &JWTClaims{
   156  				NotBefore: time.Now().UTC().Add(time.Hour),
   157  			}
   158  			token, sig, err = tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header)
   159  			require.NoError(t, err)
   160  			require.NotNil(t, token)
   161  			//t.Logf("%s.%s", token, sig)
   162  			sig, err = tc.strategy.Validate(context.TODO(), token)
   163  			require.Error(t, err)
   164  			require.Empty(t, sig, "%s", err)
   165  		})
   166  	}
   167  }
   168  
   169  func TestValidateSignatureRejectsJWT(t *testing.T) {
   170  	for k, tc := range []struct {
   171  		d        string
   172  		strategy JWTStrategy
   173  	}{
   174  		{
   175  			d: "RS256JWTStrategy",
   176  			strategy: &RS256JWTStrategy{
   177  				PrivateKey: MustRSAKey(),
   178  			},
   179  		},
   180  		{
   181  			d: "ES256JWTStrategy",
   182  			strategy: &ES256JWTStrategy{
   183  				PrivateKey: MustECDSAKey(),
   184  			},
   185  		},
   186  	} {
   187  		t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) {
   188  			for k, c := range []string{
   189  				"",
   190  				" ",
   191  				"foo.bar",
   192  				"foo.",
   193  				".foo",
   194  			} {
   195  				_, err := tc.strategy.Validate(context.TODO(), c)
   196  				assert.Error(t, err)
   197  				t.Logf("Passed test case %d", k)
   198  			}
   199  		})
   200  	}
   201  }
   202  

View as plain text