1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package security
16
17 import (
18 "context"
19 "net/http"
20 "strings"
21
22 "github.com/go-openapi/errors"
23
24 "github.com/go-openapi/runtime"
25 )
26
// Locations from which APIKeyAuth / APIKeyAuthCtx may read an API key. The
// "in" argument of those builders is lower-cased before being compared with
// these values.
const (
	header = "header"
	query  = "query"
)

// accessTokenParam is the query-string and form field name that the bearer
// authenticators fall back to when no token is found in the Authorization
// header.
const accessTokenParam = "access_token"
32
33
34 func HttpAuthenticator(handler func(*http.Request) (bool, interface{}, error)) runtime.Authenticator {
35 return runtime.AuthenticatorFunc(func(params interface{}) (bool, interface{}, error) {
36 if request, ok := params.(*http.Request); ok {
37 return handler(request)
38 }
39 if scoped, ok := params.(*ScopedAuthRequest); ok {
40 return handler(scoped.Request)
41 }
42 return false, nil, nil
43 })
44 }
45
46
47 func ScopedAuthenticator(handler func(*ScopedAuthRequest) (bool, interface{}, error)) runtime.Authenticator {
48 return runtime.AuthenticatorFunc(func(params interface{}) (bool, interface{}, error) {
49 if request, ok := params.(*ScopedAuthRequest); ok {
50 return handler(request)
51 }
52 return false, nil, nil
53 })
54 }
55
56
// Credential-checking callbacks consumed by the authenticator builders in
// this file.
//
// Each callback returns an arbitrary value, which the builders hand back as
// the authentication result (conventionally the principal), together with an
// error. The ...Ctx variants also receive the request context and return a
// (possibly derived) context that the builders install on the request.
type (
	// UserPassAuthentication checks a username and password (HTTP Basic).
	UserPassAuthentication func(string, string) (interface{}, error)

	// UserPassAuthenticationCtx is UserPassAuthentication with access to the
	// request context.
	UserPassAuthenticationCtx func(context.Context, string, string) (context.Context, interface{}, error)

	// TokenAuthentication checks an API key token.
	TokenAuthentication func(string) (interface{}, error)

	// TokenAuthenticationCtx is TokenAuthentication with access to the request
	// context.
	TokenAuthenticationCtx func(context.Context, string) (context.Context, interface{}, error)

	// ScopedTokenAuthentication checks a bearer token against the scopes the
	// operation requires (the []string argument).
	ScopedTokenAuthentication func(string, []string) (interface{}, error)

	// ScopedTokenAuthenticationCtx is ScopedTokenAuthentication with access to
	// the request context.
	ScopedTokenAuthenticationCtx func(context.Context, string, []string) (context.Context, interface{}, error)
)
73
// DefaultRealmName is the realm used by BasicAuth and BasicAuthCtx. It is
// also used by BasicAuthRealm and BasicAuthRealmCtx when they are given an
// empty realm. It is read when those builders are called.
var DefaultRealmName = "API"

// secCtxKey is the unexported key type for values this package stores on a
// request context. Because the type is private, other packages cannot
// construct colliding keys.
type secCtxKey uint8

// Context keys used by the authenticators in this package.
const (
	// failedBasicAuth maps to the realm string recorded when Basic
	// authentication is missing or fails.
	failedBasicAuth secCtxKey = 0

	// oauth2SchemeName maps to the scheme name given to BearerAuth /
	// BearerAuthCtx, recorded once a bearer token has been found.
	oauth2SchemeName secCtxKey = 1
)
82
83 func FailedBasicAuth(r *http.Request) string {
84 return FailedBasicAuthCtx(r.Context())
85 }
86
87 func FailedBasicAuthCtx(ctx context.Context) string {
88 v, ok := ctx.Value(failedBasicAuth).(string)
89 if !ok {
90 return ""
91 }
92 return v
93 }
94
95 func OAuth2SchemeName(r *http.Request) string {
96 return OAuth2SchemeNameCtx(r.Context())
97 }
98
99 func OAuth2SchemeNameCtx(ctx context.Context) string {
100 v, ok := ctx.Value(oauth2SchemeName).(string)
101 if !ok {
102 return ""
103 }
104 return v
105 }
106
107
108 func BasicAuth(authenticate UserPassAuthentication) runtime.Authenticator {
109 return BasicAuthRealm(DefaultRealmName, authenticate)
110 }
111
112
113 func BasicAuthRealm(realm string, authenticate UserPassAuthentication) runtime.Authenticator {
114 if realm == "" {
115 realm = DefaultRealmName
116 }
117
118 return HttpAuthenticator(func(r *http.Request) (bool, interface{}, error) {
119 if usr, pass, ok := r.BasicAuth(); ok {
120 p, err := authenticate(usr, pass)
121 if err != nil {
122 *r = *r.WithContext(context.WithValue(r.Context(), failedBasicAuth, realm))
123 }
124 return true, p, err
125 }
126 *r = *r.WithContext(context.WithValue(r.Context(), failedBasicAuth, realm))
127 return false, nil, nil
128 })
129 }
130
131
132 func BasicAuthCtx(authenticate UserPassAuthenticationCtx) runtime.Authenticator {
133 return BasicAuthRealmCtx(DefaultRealmName, authenticate)
134 }
135
136
137 func BasicAuthRealmCtx(realm string, authenticate UserPassAuthenticationCtx) runtime.Authenticator {
138 if realm == "" {
139 realm = DefaultRealmName
140 }
141
142 return HttpAuthenticator(func(r *http.Request) (bool, interface{}, error) {
143 if usr, pass, ok := r.BasicAuth(); ok {
144 ctx, p, err := authenticate(r.Context(), usr, pass)
145 if err != nil {
146 ctx = context.WithValue(ctx, failedBasicAuth, realm)
147 }
148 *r = *r.WithContext(ctx)
149 return true, p, err
150 }
151 *r = *r.WithContext(context.WithValue(r.Context(), failedBasicAuth, realm))
152 return false, nil, nil
153 })
154 }
155
156
157
158 func APIKeyAuth(name, in string, authenticate TokenAuthentication) runtime.Authenticator {
159 inl := strings.ToLower(in)
160 if inl != query && inl != header {
161
162 panic(errors.New(500, "api key auth: in value needs to be either \"query\" or \"header\""))
163 }
164
165 var getToken func(*http.Request) string
166 switch inl {
167 case header:
168 getToken = func(r *http.Request) string { return r.Header.Get(name) }
169 case query:
170 getToken = func(r *http.Request) string { return r.URL.Query().Get(name) }
171 }
172
173 return HttpAuthenticator(func(r *http.Request) (bool, interface{}, error) {
174 token := getToken(r)
175 if token == "" {
176 return false, nil, nil
177 }
178
179 p, err := authenticate(token)
180 return true, p, err
181 })
182 }
183
184
185
186 func APIKeyAuthCtx(name, in string, authenticate TokenAuthenticationCtx) runtime.Authenticator {
187 inl := strings.ToLower(in)
188 if inl != query && inl != header {
189
190 panic(errors.New(500, "api key auth: in value needs to be either \"query\" or \"header\""))
191 }
192
193 var getToken func(*http.Request) string
194 switch inl {
195 case header:
196 getToken = func(r *http.Request) string { return r.Header.Get(name) }
197 case query:
198 getToken = func(r *http.Request) string { return r.URL.Query().Get(name) }
199 }
200
201 return HttpAuthenticator(func(r *http.Request) (bool, interface{}, error) {
202 token := getToken(r)
203 if token == "" {
204 return false, nil, nil
205 }
206
207 ctx, p, err := authenticate(r.Context(), token)
208 *r = *r.WithContext(ctx)
209 return true, p, err
210 })
211 }
212
213
// ScopedAuthRequest pairs an HTTP request with the scopes required for it.
// It is the parameter type that ScopedAuthenticator handlers receive.
// HttpAuthenticator also accepts it and uses only its Request field.
type ScopedAuthRequest struct {
	// Request is the HTTP request being authenticated.
	Request *http.Request
	// RequiredScopes is forwarded unchanged to the token callbacks of
	// BearerAuth / BearerAuthCtx. Presumably these are the scopes the target
	// operation demands; confirm against the caller that populates them.
	RequiredScopes []string
}
218
219
220 func BearerAuth(name string, authenticate ScopedTokenAuthentication) runtime.Authenticator {
221 const prefix = "Bearer "
222 return ScopedAuthenticator(func(r *ScopedAuthRequest) (bool, interface{}, error) {
223 var token string
224 hdr := r.Request.Header.Get(runtime.HeaderAuthorization)
225 if strings.HasPrefix(hdr, prefix) {
226 token = strings.TrimPrefix(hdr, prefix)
227 }
228 if token == "" {
229 qs := r.Request.URL.Query()
230 token = qs.Get(accessTokenParam)
231 }
232
233 ct, _, _ := runtime.ContentType(r.Request.Header)
234 if token == "" && (ct == "application/x-www-form-urlencoded" || ct == "multipart/form-data") {
235 token = r.Request.FormValue(accessTokenParam)
236 }
237
238 if token == "" {
239 return false, nil, nil
240 }
241
242 rctx := context.WithValue(r.Request.Context(), oauth2SchemeName, name)
243 *r.Request = *r.Request.WithContext(rctx)
244 p, err := authenticate(token, r.RequiredScopes)
245 return true, p, err
246 })
247 }
248
249
250 func BearerAuthCtx(name string, authenticate ScopedTokenAuthenticationCtx) runtime.Authenticator {
251 const prefix = "Bearer "
252 return ScopedAuthenticator(func(r *ScopedAuthRequest) (bool, interface{}, error) {
253 var token string
254 hdr := r.Request.Header.Get(runtime.HeaderAuthorization)
255 if strings.HasPrefix(hdr, prefix) {
256 token = strings.TrimPrefix(hdr, prefix)
257 }
258 if token == "" {
259 qs := r.Request.URL.Query()
260 token = qs.Get(accessTokenParam)
261 }
262
263 ct, _, _ := runtime.ContentType(r.Request.Header)
264 if token == "" && (ct == "application/x-www-form-urlencoded" || ct == "multipart/form-data") {
265 token = r.Request.FormValue(accessTokenParam)
266 }
267
268 if token == "" {
269 return false, nil, nil
270 }
271
272 rctx := context.WithValue(r.Request.Context(), oauth2SchemeName, name)
273 ctx, p, err := authenticate(rctx, token, r.RequiredScopes)
274 *r.Request = *r.Request.WithContext(ctx)
275 return true, p, err
276 })
277 }
278
View as plain text