...

Source file src/github.com/ory/fosite/access_request_handler_test.go

Documentation: github.com/ory/fosite

     1  /*
     2   * Copyright © 2015-2018 Aeneas Rekkas <aeneas+oss@aeneas.io>
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   * @author		Aeneas Rekkas <aeneas+oss@aeneas.io>
    17   * @copyright 	2015-2018 Aeneas Rekkas <aeneas+oss@aeneas.io>
    18   * @license 	Apache-2.0
    19   *
    20   */
    21  
    22  package fosite_test
    23  
    24  import (
    25  	"context"
    26  	"encoding/base64"
    27  	"fmt"
    28  	"net/http"
    29  	"net/url"
    30  	"testing"
    31  
    32  	"github.com/golang/mock/gomock"
    33  	"github.com/pkg/errors"
    34  	"github.com/stretchr/testify/assert"
    35  	"github.com/stretchr/testify/require"
    36  
    37  	. "github.com/ory/fosite"
    38  	"github.com/ory/fosite/internal"
    39  )
    40  
    41  func TestNewAccessRequest(t *testing.T) {
    42  	ctrl := gomock.NewController(t)
    43  	store := internal.NewMockStorage(ctrl)
    44  	handler := internal.NewMockTokenEndpointHandler(ctrl)
    45  	handler.EXPECT().CanHandleTokenEndpointRequest(gomock.Any()).Return(true).AnyTimes()
    46  	handler.EXPECT().CanSkipClientAuth(gomock.Any()).Return(false).AnyTimes()
    47  	hasher := internal.NewMockHasher(ctrl)
    48  	defer ctrl.Finish()
    49  
    50  	ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))
    51  
    52  	client := &DefaultClient{}
    53  	fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
    54  	for k, c := range []struct {
    55  		header    http.Header
    56  		form      url.Values
    57  		mock      func()
    58  		method    string
    59  		expectErr error
    60  		expect    *AccessRequest
    61  		handlers  TokenEndpointHandlers
    62  	}{
    63  		{
    64  			header:    http.Header{},
    65  			expectErr: ErrInvalidRequest,
    66  			form:      url.Values{},
    67  			method:    "POST",
    68  			mock:      func() {},
    69  		},
    70  		{
    71  			header: http.Header{},
    72  			method: "POST",
    73  			form: url.Values{
    74  				"grant_type": {"foo"},
    75  			},
    76  			mock:      func() {},
    77  			expectErr: ErrInvalidRequest,
    78  		},
    79  		{
    80  			header: http.Header{},
    81  			method: "POST",
    82  			form: url.Values{
    83  				"grant_type": {"foo"},
    84  				"client_id":  {""},
    85  			},
    86  			expectErr: ErrInvalidRequest,
    87  			mock:      func() {},
    88  		},
    89  		{
    90  			header: http.Header{
    91  				"Authorization": {basicAuth("foo", "bar")},
    92  			},
    93  			method: "POST",
    94  			form: url.Values{
    95  				"grant_type": {"foo"},
    96  			},
    97  			expectErr: ErrInvalidClient,
    98  			mock: func() {
    99  				store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New(""))
   100  			},
   101  			handlers: TokenEndpointHandlers{handler},
   102  		},
   103  		{
   104  			header: http.Header{
   105  				"Authorization": {basicAuth("foo", "bar")},
   106  			},
   107  			method: "GET",
   108  			form: url.Values{
   109  				"grant_type": {"foo"},
   110  			},
   111  			expectErr: ErrInvalidRequest,
   112  			mock:      func() {},
   113  		},
   114  		{
   115  			header: http.Header{
   116  				"Authorization": {basicAuth("foo", "bar")},
   117  			},
   118  			method: "POST",
   119  			form: url.Values{
   120  				"grant_type": {"foo"},
   121  			},
   122  			expectErr: ErrInvalidClient,
   123  			mock: func() {
   124  				store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New(""))
   125  			},
   126  			handlers: TokenEndpointHandlers{handler},
   127  		},
   128  		{
   129  			header: http.Header{
   130  				"Authorization": {basicAuth("foo", "bar")},
   131  			},
   132  			method: "POST",
   133  			form: url.Values{
   134  				"grant_type": {"foo"},
   135  			},
   136  			expectErr: ErrInvalidClient,
   137  			mock: func() {
   138  				store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
   139  				client.Public = false
   140  				client.Secret = []byte("foo")
   141  				hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New(""))
   142  			},
   143  			handlers: TokenEndpointHandlers{handler},
   144  		},
   145  		{
   146  			header: http.Header{
   147  				"Authorization": {basicAuth("foo", "bar")},
   148  			},
   149  			method: "POST",
   150  			form: url.Values{
   151  				"grant_type": {"foo"},
   152  			},
   153  			expectErr: ErrServerError,
   154  			mock: func() {
   155  				store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
   156  				client.Public = false
   157  				client.Secret = []byte("foo")
   158  				hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
   159  				handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(ErrServerError)
   160  			},
   161  			handlers: TokenEndpointHandlers{handler},
   162  		},
   163  		{
   164  			header: http.Header{
   165  				"Authorization": {basicAuth("foo", "bar")},
   166  			},
   167  			method: "POST",
   168  			form: url.Values{
   169  				"grant_type": {"foo"},
   170  			},
   171  			mock: func() {
   172  				store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
   173  				client.Public = false
   174  				client.Secret = []byte("foo")
   175  				hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
   176  				handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
   177  			},
   178  			handlers: TokenEndpointHandlers{handler},
   179  			expect: &AccessRequest{
   180  				GrantTypes: Arguments{"foo"},
   181  				Request: Request{
   182  					Client: client,
   183  				},
   184  			},
   185  		},
   186  		{
   187  			header: http.Header{
   188  				"Authorization": {basicAuth("foo", "bar")},
   189  			},
   190  			method: "POST",
   191  			form: url.Values{
   192  				"grant_type": {"foo"},
   193  			},
   194  			mock: func() {
   195  				store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
   196  				client.Public = true
   197  				handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
   198  			},
   199  			handlers: TokenEndpointHandlers{handler},
   200  			expect: &AccessRequest{
   201  				GrantTypes: Arguments{"foo"},
   202  				Request: Request{
   203  					Client: client,
   204  				},
   205  			},
   206  		},
   207  	} {
   208  		t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
   209  			r := &http.Request{
   210  				Header:   c.header,
   211  				PostForm: c.form,
   212  				Form:     c.form,
   213  				Method:   c.method,
   214  			}
   215  			c.mock()
   216  			ctx := NewContext()
   217  			fosite.TokenEndpointHandlers = c.handlers
   218  			ar, err := fosite.NewAccessRequest(ctx, r, new(DefaultSession))
   219  
   220  			if c.expectErr != nil {
   221  				assert.EqualError(t, err, c.expectErr.Error())
   222  			} else {
   223  				require.NoError(t, err)
   224  				AssertObjectKeysEqual(t, c.expect, ar, "GrantTypes", "Client")
   225  				assert.NotNil(t, ar.GetRequestedAt())
   226  			}
   227  		})
   228  	}
   229  }
   230  
   231  func TestNewAccessRequestWithoutClientAuth(t *testing.T) {
   232  	ctrl := gomock.NewController(t)
   233  	store := internal.NewMockStorage(ctrl)
   234  	handler := internal.NewMockTokenEndpointHandler(ctrl)
   235  	handler.EXPECT().CanHandleTokenEndpointRequest(gomock.Any()).Return(true).AnyTimes()
   236  	handler.EXPECT().CanSkipClientAuth(gomock.Any()).Return(true).AnyTimes()
   237  	hasher := internal.NewMockHasher(ctrl)
   238  	defer ctrl.Finish()
   239  
   240  	client := &DefaultClient{}
   241  	anotherClient := &DefaultClient{ID: "another"}
   242  	fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
   243  	for k, c := range []struct {
   244  		header    http.Header
   245  		form      url.Values
   246  		mock      func()
   247  		method    string
   248  		expectErr error
   249  		expect    *AccessRequest
   250  		handlers  TokenEndpointHandlers
   251  	}{
   252  		// No grant type -> error
   253  		{
   254  			form: url.Values{},
   255  			mock: func() {
   256  				store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Times(0)
   257  			},
   258  			method:    "POST",
   259  			expectErr: ErrInvalidRequest,
   260  		},
   261  		// No registered handlers -> error
   262  		{
   263  			form: url.Values{
   264  				"grant_type": {"foo"},
   265  			},
   266  			mock: func() {
   267  				store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Times(0)
   268  			},
   269  			method:    "POST",
   270  			expectErr: ErrInvalidRequest,
   271  			handlers:  TokenEndpointHandlers{},
   272  		},
   273  		// Handler can skip client auth and ignores missing client.
   274  		{
   275  			header: http.Header{
   276  				"Authorization": {basicAuth("foo", "bar")},
   277  			},
   278  			form: url.Values{
   279  				"grant_type": {"foo"},
   280  			},
   281  			mock: func() {
   282  				// despite error from storage, we should success, because client auth is not required
   283  				store.EXPECT().GetClient(gomock.Any(), "foo").Return(nil, errors.New("no client")).Times(1)
   284  				handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
   285  			},
   286  			method: "POST",
   287  			expect: &AccessRequest{
   288  				GrantTypes: Arguments{"foo"},
   289  				Request: Request{
   290  					Client: client,
   291  				},
   292  			},
   293  			handlers: TokenEndpointHandlers{handler},
   294  		},
   295  		// Should pass if no auth is set in the header and can skip!
   296  		{
   297  			form: url.Values{
   298  				"grant_type": {"foo"},
   299  			},
   300  			mock: func() {
   301  				handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
   302  			},
   303  			method: "POST",
   304  			expect: &AccessRequest{
   305  				GrantTypes: Arguments{"foo"},
   306  				Request: Request{
   307  					Client: client,
   308  				},
   309  			},
   310  			handlers: TokenEndpointHandlers{handler},
   311  		},
   312  		// Should also pass if client auth is set!
   313  		{
   314  			header: http.Header{
   315  				"Authorization": {basicAuth("foo", "bar")},
   316  			},
   317  			form: url.Values{
   318  				"grant_type": {"foo"},
   319  			},
   320  			mock: func() {
   321  				store.EXPECT().GetClient(gomock.Any(), "foo").Return(anotherClient, nil).Times(1)
   322  				hasher.EXPECT().Compare(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1)
   323  				handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
   324  			},
   325  			method: "POST",
   326  			expect: &AccessRequest{
   327  				GrantTypes: Arguments{"foo"},
   328  				Request: Request{
   329  					Client: anotherClient,
   330  				},
   331  			},
   332  			handlers: TokenEndpointHandlers{handler},
   333  		},
   334  	} {
   335  		t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
   336  			r := &http.Request{
   337  				Header:   c.header,
   338  				PostForm: c.form,
   339  				Form:     c.form,
   340  				Method:   c.method,
   341  			}
   342  			c.mock()
   343  			ctx := NewContext()
   344  			fosite.TokenEndpointHandlers = c.handlers
   345  			ar, err := fosite.NewAccessRequest(ctx, r, new(DefaultSession))
   346  
   347  			if c.expectErr != nil {
   348  				assert.EqualError(t, err, c.expectErr.Error())
   349  			} else {
   350  				require.NoError(t, err)
   351  				AssertObjectKeysEqual(t, c.expect, ar, "GrantTypes", "Client")
   352  				assert.NotNil(t, ar.GetRequestedAt())
   353  			}
   354  		})
   355  	}
   356  }
   357  
   358  // In this test case one handler requires client auth and another handler not.
   359  func TestNewAccessRequestWithMixedClientAuth(t *testing.T) {
   360  	ctrl := gomock.NewController(t)
   361  	store := internal.NewMockStorage(ctrl)
   362  
   363  	handlerWithClientAuth := internal.NewMockTokenEndpointHandler(ctrl)
   364  	handlerWithClientAuth.EXPECT().CanHandleTokenEndpointRequest(gomock.Any()).Return(true).AnyTimes()
   365  	handlerWithClientAuth.EXPECT().CanSkipClientAuth(gomock.Any()).Return(false).AnyTimes()
   366  
   367  	handlerWithoutClientAuth := internal.NewMockTokenEndpointHandler(ctrl)
   368  	handlerWithoutClientAuth.EXPECT().CanHandleTokenEndpointRequest(gomock.Any()).Return(true).AnyTimes()
   369  	handlerWithoutClientAuth.EXPECT().CanSkipClientAuth(gomock.Any()).Return(true).AnyTimes()
   370  
   371  	hasher := internal.NewMockHasher(ctrl)
   372  	defer ctrl.Finish()
   373  
   374  	ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil))
   375  
   376  	client := &DefaultClient{}
   377  	fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}
   378  	for k, c := range []struct {
   379  		header    http.Header
   380  		form      url.Values
   381  		mock      func()
   382  		method    string
   383  		expectErr error
   384  		expect    *AccessRequest
   385  		handlers  TokenEndpointHandlers
   386  	}{
   387  		{
   388  			header: http.Header{
   389  				"Authorization": {basicAuth("foo", "bar")},
   390  			},
   391  			form: url.Values{
   392  				"grant_type": {"foo"},
   393  			},
   394  			mock: func() {
   395  				store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
   396  				client.Public = false
   397  				client.Secret = []byte("foo")
   398  				hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("hash err"))
   399  				handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
   400  			},
   401  			method:    "POST",
   402  			expectErr: ErrInvalidClient,
   403  			handlers:  TokenEndpointHandlers{handlerWithoutClientAuth, handlerWithClientAuth},
   404  		},
   405  		{
   406  			header: http.Header{
   407  				"Authorization": {basicAuth("foo", "bar")},
   408  			},
   409  			form: url.Values{
   410  				"grant_type": {"foo"},
   411  			},
   412  			mock: func() {
   413  				store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil)
   414  				client.Public = false
   415  				client.Secret = []byte("foo")
   416  				hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil)
   417  				handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
   418  				handlerWithClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
   419  			},
   420  			method: "POST",
   421  			expect: &AccessRequest{
   422  				GrantTypes: Arguments{"foo"},
   423  				Request: Request{
   424  					Client: client,
   425  				},
   426  			},
   427  			handlers: TokenEndpointHandlers{handlerWithoutClientAuth, handlerWithClientAuth},
   428  		},
   429  		{
   430  			header: http.Header{},
   431  			form: url.Values{
   432  				"grant_type": {"foo"},
   433  			},
   434  			mock: func() {
   435  				store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Times(0)
   436  				handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil)
   437  			},
   438  			method:    "POST",
   439  			expectErr: ErrInvalidRequest,
   440  			handlers:  TokenEndpointHandlers{handlerWithoutClientAuth, handlerWithClientAuth},
   441  		},
   442  	} {
   443  		t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
   444  			r := &http.Request{
   445  				Header:   c.header,
   446  				PostForm: c.form,
   447  				Form:     c.form,
   448  				Method:   c.method,
   449  			}
   450  			c.mock()
   451  			ctx := NewContext()
   452  			fosite.TokenEndpointHandlers = c.handlers
   453  			ar, err := fosite.NewAccessRequest(ctx, r, new(DefaultSession))
   454  
   455  			if c.expectErr != nil {
   456  				assert.EqualError(t, err, c.expectErr.Error())
   457  			} else {
   458  				require.NoError(t, err)
   459  				AssertObjectKeysEqual(t, c.expect, ar, "GrantTypes", "Client")
   460  				assert.NotNil(t, ar.GetRequestedAt())
   461  			}
   462  		})
   463  	}
   464  }
   465  
   466  func basicAuth(username, password string) string {
   467  	return "Basic " + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", username, password)))
   468  }
   469  

View as plain text