/* * Copyright © 2015-2018 Aeneas Rekkas * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * @author Aeneas Rekkas * @copyright 2015-2018 Aeneas Rekkas * @license Apache-2.0 * */ package fosite_test import ( "context" "encoding/base64" "fmt" "net/http" "net/url" "testing" "github.com/golang/mock/gomock" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" . "github.com/ory/fosite" "github.com/ory/fosite/internal" ) func TestNewAccessRequest(t *testing.T) { ctrl := gomock.NewController(t) store := internal.NewMockStorage(ctrl) handler := internal.NewMockTokenEndpointHandler(ctrl) handler.EXPECT().CanHandleTokenEndpointRequest(gomock.Any()).Return(true).AnyTimes() handler.EXPECT().CanSkipClientAuth(gomock.Any()).Return(false).AnyTimes() hasher := internal.NewMockHasher(ctrl) defer ctrl.Finish() ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil)) client := &DefaultClient{} fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy} for k, c := range []struct { header http.Header form url.Values mock func() method string expectErr error expect *AccessRequest handlers TokenEndpointHandlers }{ { header: http.Header{}, expectErr: ErrInvalidRequest, form: url.Values{}, method: "POST", mock: func() {}, }, { header: http.Header{}, method: "POST", form: url.Values{ "grant_type": {"foo"}, }, mock: func() {}, expectErr: ErrInvalidRequest, }, { header: http.Header{}, method: "POST", form: url.Values{ "grant_type": {"foo"}, "client_id": {""}, }, expectErr: ErrInvalidRequest, mock: func() {}, }, { header: http.Header{ "Authorization": {basicAuth("foo", "bar")}, }, method: "POST", form: url.Values{ "grant_type": {"foo"}, }, expectErr: ErrInvalidClient, mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New("")) }, handlers: TokenEndpointHandlers{handler}, }, { header: http.Header{ "Authorization": {basicAuth("foo", "bar")}, }, method: "GET", form: url.Values{ "grant_type": {"foo"}, }, expectErr: ErrInvalidRequest, mock: func() {}, }, { header: http.Header{ "Authorization": {basicAuth("foo", "bar")}, }, method: "POST", form: url.Values{ "grant_type": {"foo"}, }, expectErr: ErrInvalidClient, mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(nil, errors.New("")) }, handlers: TokenEndpointHandlers{handler}, }, { header: http.Header{ "Authorization": {basicAuth("foo", "bar")}, }, method: "POST", form: url.Values{ "grant_type": {"foo"}, }, expectErr: ErrInvalidClient, mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("")) }, handlers: TokenEndpointHandlers{handler}, }, { header: http.Header{ "Authorization": {basicAuth("foo", "bar")}, }, method: "POST", form: url.Values{ "grant_type": {"foo"}, }, expectErr: ErrServerError, mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(ErrServerError) }, handlers: TokenEndpointHandlers{handler}, }, { header: http.Header{ "Authorization": {basicAuth("foo", "bar")}, }, method: "POST", form: url.Values{ "grant_type": {"foo"}, }, mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) }, handlers: TokenEndpointHandlers{handler}, expect: &AccessRequest{ GrantTypes: Arguments{"foo"}, Request: Request{ Client: client, }, }, }, { header: http.Header{ "Authorization": {basicAuth("foo", "bar")}, }, method: "POST", form: url.Values{ "grant_type": {"foo"}, }, mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = true handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) }, handlers: TokenEndpointHandlers{handler}, expect: &AccessRequest{ GrantTypes: Arguments{"foo"}, Request: Request{ Client: client, }, }, }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { r := &http.Request{ Header: c.header, PostForm: c.form, Form: c.form, Method: c.method, } c.mock() ctx := NewContext() fosite.TokenEndpointHandlers = c.handlers ar, err := fosite.NewAccessRequest(ctx, r, new(DefaultSession)) if c.expectErr != nil { assert.EqualError(t, err, c.expectErr.Error()) } else { require.NoError(t, err) AssertObjectKeysEqual(t, c.expect, ar, "GrantTypes", "Client") assert.NotNil(t, ar.GetRequestedAt()) } }) } } func TestNewAccessRequestWithoutClientAuth(t *testing.T) { ctrl := gomock.NewController(t) store := internal.NewMockStorage(ctrl) handler := internal.NewMockTokenEndpointHandler(ctrl) handler.EXPECT().CanHandleTokenEndpointRequest(gomock.Any()).Return(true).AnyTimes() handler.EXPECT().CanSkipClientAuth(gomock.Any()).Return(true).AnyTimes() hasher := internal.NewMockHasher(ctrl) defer ctrl.Finish() client := &DefaultClient{} anotherClient := &DefaultClient{ID: "another"} fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy} for k, c := range []struct { header http.Header form url.Values mock func() method string expectErr error expect *AccessRequest handlers TokenEndpointHandlers }{ // No grant type -> error { form: url.Values{}, mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Times(0) }, method: "POST", expectErr: ErrInvalidRequest, }, // No registered handlers -> error { form: url.Values{ "grant_type": {"foo"}, }, mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Times(0) }, method: "POST", expectErr: ErrInvalidRequest, handlers: TokenEndpointHandlers{}, }, // Handler can skip client auth and ignores missing client. { header: http.Header{ "Authorization": {basicAuth("foo", "bar")}, }, form: url.Values{ "grant_type": {"foo"}, }, mock: func() { // despite error from storage, we should success, because client auth is not required store.EXPECT().GetClient(gomock.Any(), "foo").Return(nil, errors.New("no client")).Times(1) handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) }, method: "POST", expect: &AccessRequest{ GrantTypes: Arguments{"foo"}, Request: Request{ Client: client, }, }, handlers: TokenEndpointHandlers{handler}, }, // Should pass if no auth is set in the header and can skip! { form: url.Values{ "grant_type": {"foo"}, }, mock: func() { handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) }, method: "POST", expect: &AccessRequest{ GrantTypes: Arguments{"foo"}, Request: Request{ Client: client, }, }, handlers: TokenEndpointHandlers{handler}, }, // Should also pass if client auth is set! { header: http.Header{ "Authorization": {basicAuth("foo", "bar")}, }, form: url.Values{ "grant_type": {"foo"}, }, mock: func() { store.EXPECT().GetClient(gomock.Any(), "foo").Return(anotherClient, nil).Times(1) hasher.EXPECT().Compare(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1) handler.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) }, method: "POST", expect: &AccessRequest{ GrantTypes: Arguments{"foo"}, Request: Request{ Client: anotherClient, }, }, handlers: TokenEndpointHandlers{handler}, }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { r := &http.Request{ Header: c.header, PostForm: c.form, Form: c.form, Method: c.method, } c.mock() ctx := NewContext() fosite.TokenEndpointHandlers = c.handlers ar, err := fosite.NewAccessRequest(ctx, r, new(DefaultSession)) if c.expectErr != nil { assert.EqualError(t, err, c.expectErr.Error()) } else { require.NoError(t, err) AssertObjectKeysEqual(t, c.expect, ar, "GrantTypes", "Client") assert.NotNil(t, ar.GetRequestedAt()) } }) } } // In this test case one handler requires client auth and another handler not. func TestNewAccessRequestWithMixedClientAuth(t *testing.T) { ctrl := gomock.NewController(t) store := internal.NewMockStorage(ctrl) handlerWithClientAuth := internal.NewMockTokenEndpointHandler(ctrl) handlerWithClientAuth.EXPECT().CanHandleTokenEndpointRequest(gomock.Any()).Return(true).AnyTimes() handlerWithClientAuth.EXPECT().CanSkipClientAuth(gomock.Any()).Return(false).AnyTimes() handlerWithoutClientAuth := internal.NewMockTokenEndpointHandler(ctrl) handlerWithoutClientAuth.EXPECT().CanHandleTokenEndpointRequest(gomock.Any()).Return(true).AnyTimes() handlerWithoutClientAuth.EXPECT().CanSkipClientAuth(gomock.Any()).Return(true).AnyTimes() hasher := internal.NewMockHasher(ctrl) defer ctrl.Finish() ctx := gomock.AssignableToTypeOf(context.WithValue(context.TODO(), ContextKey("test"), nil)) client := &DefaultClient{} fosite := &Fosite{Store: store, Hasher: hasher, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy} for k, c := range []struct { header http.Header form url.Values mock func() method string expectErr error expect *AccessRequest handlers TokenEndpointHandlers }{ { header: http.Header{ "Authorization": {basicAuth("foo", "bar")}, }, form: url.Values{ "grant_type": {"foo"}, }, mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(errors.New("hash err")) handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) }, method: "POST", expectErr: ErrInvalidClient, handlers: TokenEndpointHandlers{handlerWithoutClientAuth, handlerWithClientAuth}, }, { header: http.Header{ "Authorization": {basicAuth("foo", "bar")}, }, form: url.Values{ "grant_type": {"foo"}, }, mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("foo")).Return(client, nil) client.Public = false client.Secret = []byte("foo") hasher.EXPECT().Compare(ctx, gomock.Eq([]byte("foo")), gomock.Eq([]byte("bar"))).Return(nil) handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) handlerWithClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) }, method: "POST", expect: &AccessRequest{ GrantTypes: Arguments{"foo"}, Request: Request{ Client: client, }, }, handlers: TokenEndpointHandlers{handlerWithoutClientAuth, handlerWithClientAuth}, }, { header: http.Header{}, form: url.Values{ "grant_type": {"foo"}, }, mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Any()).Times(0) handlerWithoutClientAuth.EXPECT().HandleTokenEndpointRequest(gomock.Any(), gomock.Any()).Return(nil) }, method: "POST", expectErr: ErrInvalidRequest, handlers: TokenEndpointHandlers{handlerWithoutClientAuth, handlerWithClientAuth}, }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { r := &http.Request{ Header: c.header, PostForm: c.form, Form: c.form, Method: c.method, } c.mock() ctx := NewContext() fosite.TokenEndpointHandlers = c.handlers ar, err := fosite.NewAccessRequest(ctx, r, new(DefaultSession)) if c.expectErr != nil { assert.EqualError(t, err, c.expectErr.Error()) } else { require.NoError(t, err) AssertObjectKeysEqual(t, c.expect, ar, "GrantTypes", "Client") assert.NotNil(t, ar.GetRequestedAt()) } }) } } func basicAuth(username, password string) string { return "Basic " + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", username, password))) }