1
16
17 package oidc
18
19 import (
20 "crypto"
21 "crypto/ecdsa"
22 "crypto/rsa"
23 "crypto/tls"
24 "encoding/hex"
25 "encoding/json"
26 "errors"
27 "fmt"
28 "net/http"
29 "net/http/httptest"
30 "net/url"
31 "os"
32 "testing"
33
34 "github.com/golang/mock/gomock"
35 "github.com/stretchr/testify/require"
36 "gopkg.in/square/go-jose.v2"
37 )
38
// URL paths served by the test OIDC provider. The discovery document lives at
// openIDWellKnownWebPath; the remaining paths are advertised inside it.
const (
	openIDWellKnownWebPath = "/.well-known/openid-configuration"
	authWebPath            = "/auth"
	tokenWebPath           = "/token"
	jwksWebPath            = "/jwks"
)
45
46 var (
47 ErrRefreshTokenExpired = errors.New("refresh token is expired")
48 ErrBadClientID = errors.New("client ID is bad")
49 )
50
51 type TestServer struct {
52 httpServer *httptest.Server
53 tokenHandler *MockTokenHandler
54 jwksHandler *MockJWKsHandler
55 }
56
57
58 func (ts *TestServer) JwksHandler() *MockJWKsHandler {
59 return ts.jwksHandler
60 }
61
62
63 func (ts *TestServer) TokenHandler() *MockTokenHandler {
64 return ts.tokenHandler
65 }
66
67
68 func (ts *TestServer) URL() string {
69 return ts.httpServer.URL
70 }
71
72
// TokenURL returns the absolute URL of the server's token endpoint.
//
// On failure the returned error wraps the url.JoinPath error with %w, so
// callers can inspect it with errors.Is / errors.As.
func (ts *TestServer) TokenURL() (string, error) {
	// Named tokenURL rather than url so the net/url package is not shadowed.
	tokenURL, err := url.JoinPath(ts.httpServer.URL, tokenWebPath)
	if err != nil {
		return "", fmt.Errorf("error joining paths: %w", err)
	}

	return tokenURL, nil
}
81
82
// BuildAndRunTestServer starts a TLS httptest server that imitates a minimal
// OpenID Connect provider and returns a handle to it.
//
// caPath and caKeyPath name PEM files holding the server certificate and its
// private key. When issuerOverride is non-empty it is advertised as the issuer
// in the discovery document instead of the server's own URL. The token and
// JWKS endpoints delegate to gomock mocks; set expectations on TokenHandler()
// and JwksHandler() before exercising them. The server is shut down and the
// mock controller finished via t.Cleanup.
//
// The HTTP handlers run on the server's goroutines, so they report failures
// with t.Errorf. require (t.FailNow) may only be called from the test goroutine.
func BuildAndRunTestServer(t *testing.T, caPath, caKeyPath, issuerOverride string) *TestServer {
	t.Helper()

	certContent, err := os.ReadFile(caPath)
	require.NoError(t, err)
	keyContent, err := os.ReadFile(caKeyPath)
	require.NoError(t, err)

	cert, err := tls.X509KeyPair(certContent, keyContent)
	require.NoError(t, err)

	mux := http.NewServeMux()
	httpServer := httptest.NewUnstartedServer(mux)
	httpServer.TLS = &tls.Config{
		Certificates: []tls.Certificate{cert},
	}
	httpServer.StartTLS()

	mockCtrl := gomock.NewController(t)

	t.Cleanup(func() {
		// Close first: it blocks until outstanding requests finish, so every
		// mock call has completed before the controller verifies expectations.
		httpServer.Close()
		mockCtrl.Finish()
	})

	oidcServer := &TestServer{
		httpServer:   httpServer,
		tokenHandler: NewMockTokenHandler(mockCtrl),
		jwksHandler:  NewMockJWKsHandler(mockCtrl),
	}

	issuer := httpServer.URL
	if issuerOverride != "" {
		issuer = issuerOverride
	}

	serveDiscovery := func(writer http.ResponseWriter, _ *http.Request) {
		discoveryDocHandler(t, writer, httpServer.URL, issuer)
	}
	mux.HandleFunc(openIDWellKnownWebPath, serveDiscovery)
	// NOTE(review): the discovery document is also served under a nested path,
	// presumably for tests whose issuer carries a path component — confirm with callers.
	mux.HandleFunc("/c/d/bar"+openIDWellKnownWebPath, serveDiscovery)

	mux.HandleFunc(tokenWebPath, func(writer http.ResponseWriter, _ *http.Request) {
		token, err := oidcServer.tokenHandler.Token()
		if err != nil {
			http.Error(writer, err.Error(), http.StatusBadRequest)
			return
		}

		writer.Header().Set("Content-Type", "application/json")
		writer.WriteHeader(http.StatusOK)

		if err = json.NewEncoder(writer).Encode(token); err != nil {
			t.Errorf("encoding token response: %v", err)
		}
	})

	mux.HandleFunc(authWebPath, func(writer http.ResponseWriter, _ *http.Request) {
		writer.WriteHeader(http.StatusOK)
	})

	mux.HandleFunc(jwksWebPath, func(writer http.ResponseWriter, _ *http.Request) {
		keySet := oidcServer.jwksHandler.KeySet()

		writer.Header().Set("Content-Type", "application/json")
		writer.WriteHeader(http.StatusOK)

		if encErr := json.NewEncoder(writer).Encode(keySet); encErr != nil {
			t.Errorf("encoding JWKS response: %v", encErr)
		}
	})

	return oidcServer
}
161
162 func discoveryDocHandler(t *testing.T, writer http.ResponseWriter, httpServerURL, issuer string) {
163 authURL, err := url.JoinPath(httpServerURL + authWebPath)
164 require.NoError(t, err)
165 tokenURL, err := url.JoinPath(httpServerURL + tokenWebPath)
166 require.NoError(t, err)
167 jwksURL, err := url.JoinPath(httpServerURL + jwksWebPath)
168 require.NoError(t, err)
169 userInfoURL, err := url.JoinPath(httpServerURL + authWebPath)
170 require.NoError(t, err)
171
172 writer.Header().Add("Content-Type", "application/json")
173
174 err = json.NewEncoder(writer).Encode(struct {
175 Issuer string `json:"issuer"`
176 AuthURL string `json:"authorization_endpoint"`
177 TokenURL string `json:"token_endpoint"`
178 JWKSURL string `json:"jwks_uri"`
179 UserInfoURL string `json:"userinfo_endpoint"`
180 }{
181 Issuer: issuer,
182 AuthURL: authURL,
183 TokenURL: tokenURL,
184 JWKSURL: jwksURL,
185 UserInfoURL: userInfoURL,
186 })
187 require.NoError(t, err)
188 }
189
190 type JosePrivateKey interface {
191 *rsa.PrivateKey | *ecdsa.PrivateKey
192 }
193
194
195
// TokenHandlerBehaviorReturningPredefinedJWT returns a func shaped like the
// Token method that the token endpoint calls on its mock handler. It is
// presumably installed via gomock's DoAndReturn — confirm with callers.
//
// Each call marshals claims to JSON at call time, so later changes to the map
// are picked up. The JSON is signed with privateKey using the algorithm from
// GetSignatureAlgorithm. The compact-serialized JWS is returned as the ID token
// together with the fixed accessToken and refreshToken.
//
// The returned func typically runs on the HTTP server's goroutine, where
// require's t.FailNow must not be called. Failures are therefore recorded with
// t.Errorf and also returned. The token endpoint in BuildAndRunTestServer turns
// a returned error into an HTTP 400.
func TokenHandlerBehaviorReturningPredefinedJWT[K JosePrivateKey](
	t *testing.T,
	privateKey K,
	claims map[string]interface{}, accessToken, refreshToken string,
) func() (Token, error) {
	t.Helper()

	return func() (Token, error) {
		fail := func(step string, err error) (Token, error) {
			err = fmt.Errorf("%s: %w", step, err)
			t.Errorf("building predefined token: %v", err)
			return Token{}, err
		}

		signer, err := jose.NewSigner(jose.SigningKey{Algorithm: GetSignatureAlgorithm(privateKey), Key: privateKey}, nil)
		if err != nil {
			return fail("creating signer", err)
		}

		payloadJSON, err := json.Marshal(claims)
		if err != nil {
			return fail("marshaling claims", err)
		}

		idTokenSignature, err := signer.Sign(payloadJSON)
		if err != nil {
			return fail("signing claims", err)
		}

		idToken, err := idTokenSignature.CompactSerialize()
		if err != nil {
			return fail("serializing ID token", err)
		}

		return Token{
			IDToken:      idToken,
			AccessToken:  accessToken,
			RefreshToken: refreshToken,
		}, nil
	}
}
222
223 type JosePublicKey interface {
224 *rsa.PublicKey | *ecdsa.PublicKey
225 }
226
227
228
// DefaultJwksHandlerBehavior returns a func that yields a JSON Web Key Set
// containing only verificationPublicKey. The key is marked for signature use
// ("sig") and labeled with the algorithm from GetSignatureAlgorithm. Its key ID
// is the hex-encoded SHA-256 JWK thumbprint.
//
// The JWK is built eagerly, on the caller's (test) goroutine, where require is
// allowed. The returned func is presumably installed as the JWKS mock's KeySet
// behavior and then runs on the HTTP server's goroutine, where require's
// t.FailNow must not be called. The thumbprint depends only on the key, so
// building the key once yields the same key set on every call.
func DefaultJwksHandlerBehavior[K JosePublicKey](t *testing.T, verificationPublicKey K) func() jose.JSONWebKeySet {
	t.Helper()

	key := jose.JSONWebKey{
		Key:       verificationPublicKey,
		Use:       "sig",
		Algorithm: string(GetSignatureAlgorithm(verificationPublicKey)),
	}

	thumbprint, err := key.Thumbprint(crypto.SHA256)
	require.NoError(t, err)
	key.KeyID = hex.EncodeToString(thumbprint)

	return func() jose.JSONWebKeySet {
		// A fresh slice per call, so no caller can mutate another call's key set.
		return jose.JSONWebKeySet{
			Keys: []jose.JSONWebKey{key},
		}
	}
}
244
245 type JoseKey interface{ JosePrivateKey | JosePublicKey }
246
// GetSignatureAlgorithm returns the JWS algorithm used to sign with, or to
// advertise for, key.
//
// RSA keys map to RS256. ECDSA keys map to the algorithm that matches their
// curve, as RFC 7518 section 3.4 requires:
//   - ES384 for 384-bit curves (P-384)
//   - ES512 for 521-bit curves (P-521)
//   - ES256 otherwise, which covers P-256, nil keys and keys without curve
//     parameters and so keeps the earlier behavior for them.
//
// The default branch cannot be reached with the types JoseKey admits. It
// panics to flag a programmer error if that constraint is ever widened.
func GetSignatureAlgorithm[K JoseKey](key K) jose.SignatureAlgorithm {
	ecdsaAlgorithm := func(pub *ecdsa.PublicKey) jose.SignatureAlgorithm {
		if pub == nil || pub.Curve == nil {
			return jose.ES256
		}
		params := pub.Curve.Params()
		if params == nil {
			return jose.ES256
		}
		switch params.BitSize {
		case 384:
			return jose.ES384
		case 521:
			return jose.ES512
		default:
			return jose.ES256
		}
	}

	switch k := any(key).(type) {
	case *rsa.PrivateKey, *rsa.PublicKey:
		return jose.RS256
	case *ecdsa.PrivateKey:
		if k == nil {
			return jose.ES256
		}
		return ecdsaAlgorithm(&k.PublicKey)
	case *ecdsa.PublicKey:
		return ecdsaAlgorithm(k)
	default:
		panic("unknown key type")
	}
}
257
View as plain text