1
21
22 package fosite_test
23
24 import (
25 "encoding/json"
26 "fmt"
27 "net/http"
28 "net/http/httptest"
29 "testing"
30
31 "github.com/golang/mock/gomock"
32 "github.com/stretchr/testify/assert"
33 "github.com/stretchr/testify/require"
34
35 . "github.com/ory/fosite"
36 . "github.com/ory/fosite/internal"
37 )
38
39 func TestWriteAccessError(t *testing.T) {
40 f := &Fosite{}
41 header := http.Header{}
42 ctrl := gomock.NewController(t)
43 rw := NewMockResponseWriter(ctrl)
44 defer ctrl.Finish()
45
46 rw.EXPECT().Header().AnyTimes().Return(header)
47 rw.EXPECT().WriteHeader(http.StatusBadRequest)
48 rw.EXPECT().Write(gomock.Any())
49
50 f.WriteAccessError(rw, nil, ErrInvalidRequest)
51 }
52
53 func TestWriteAccessError_RFC6749(t *testing.T) {
54
55
56 f := &Fosite{}
57
58 for k, c := range []struct {
59 err *RFC6749Error
60 code string
61 debug bool
62 expectDebugMessage string
63 includeExtraFields bool
64 }{
65 {ErrInvalidRequest.WithDebug("some-debug"), "invalid_request", true, "some-debug", true},
66 {ErrInvalidRequest.WithDebugf("some-debug-%d", 1234), "invalid_request", true, "some-debug-1234", true},
67 {ErrInvalidRequest.WithDebug("some-debug"), "invalid_request", false, "some-debug", true},
68 {ErrInvalidClient.WithDebug("some-debug"), "invalid_client", false, "some-debug", true},
69 {ErrInvalidGrant.WithDebug("some-debug"), "invalid_grant", false, "some-debug", true},
70 {ErrInvalidScope.WithDebug("some-debug"), "invalid_scope", false, "some-debug", true},
71 {ErrUnauthorizedClient.WithDebug("some-debug"), "unauthorized_client", false, "some-debug", true},
72 {ErrUnsupportedGrantType.WithDebug("some-debug"), "unsupported_grant_type", false, "some-debug", true},
73 {ErrUnsupportedGrantType.WithDebug("some-debug"), "unsupported_grant_type", false, "some-debug", false},
74 {ErrUnsupportedGrantType.WithDebug("some-debug"), "unsupported_grant_type", true, "some-debug", false},
75 } {
76 t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
77 f.SendDebugMessagesToClients = c.debug
78 f.UseLegacyErrorFormat = c.includeExtraFields
79
80 rw := httptest.NewRecorder()
81 f.WriteAccessError(rw, nil, c.err)
82
83 var params struct {
84 Error string `json:"error"`
85 Description string `json:"error_description"`
86 Debug string `json:"error_debug"`
87 Hint string `json:"error_hint"`
88 }
89
90 require.NotNil(t, rw.Body)
91 err := json.NewDecoder(rw.Body).Decode(¶ms)
92 require.NoError(t, err)
93
94 assert.Equal(t, c.code, params.Error)
95 if !c.includeExtraFields {
96 assert.Empty(t, params.Debug)
97 assert.Empty(t, params.Hint)
98 assert.Contains(t, params.Description, c.err.DescriptionField)
99 assert.Contains(t, params.Description, c.err.HintField)
100
101 if c.debug {
102 assert.Contains(t, params.Description, c.err.DebugField)
103 } else {
104 assert.NotContains(t, params.Description, c.err.DebugField)
105 }
106 } else {
107 assert.EqualValues(t, c.err.DescriptionField, params.Description)
108 assert.EqualValues(t, c.err.HintField, params.Hint)
109
110 if !c.debug {
111 assert.Empty(t, params.Debug)
112 } else {
113 assert.EqualValues(t, c.err.DebugField, params.Debug)
114 }
115 }
116 })
117 }
118 }
119
View as plain text