...

Source file src/github.com/go-openapi/runtime/security/authenticator.go

Documentation: github.com/go-openapi/runtime/security

     1  // Copyright 2015 go-swagger maintainers
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package security
    16  
    17  import (
    18  	"context"
    19  	"net/http"
    20  	"strings"
    21  
    22  	"github.com/go-openapi/errors"
    23  
    24  	"github.com/go-openapi/runtime"
    25  )
    26  
    27  const (
    28  	query            = "query"
    29  	header           = "header"
    30  	accessTokenParam = "access_token"
    31  )
    32  
    33  // HttpAuthenticator is a function that authenticates a HTTP request
    34  func HttpAuthenticator(handler func(*http.Request) (bool, interface{}, error)) runtime.Authenticator { //nolint:revive,stylecheck
    35  	return runtime.AuthenticatorFunc(func(params interface{}) (bool, interface{}, error) {
    36  		if request, ok := params.(*http.Request); ok {
    37  			return handler(request)
    38  		}
    39  		if scoped, ok := params.(*ScopedAuthRequest); ok {
    40  			return handler(scoped.Request)
    41  		}
    42  		return false, nil, nil
    43  	})
    44  }
    45  
    46  // ScopedAuthenticator is a function that authenticates a HTTP request against a list of valid scopes
    47  func ScopedAuthenticator(handler func(*ScopedAuthRequest) (bool, interface{}, error)) runtime.Authenticator {
    48  	return runtime.AuthenticatorFunc(func(params interface{}) (bool, interface{}, error) {
    49  		if request, ok := params.(*ScopedAuthRequest); ok {
    50  			return handler(request)
    51  		}
    52  		return false, nil, nil
    53  	})
    54  }
    55  
    56  // UserPassAuthentication authentication function
    57  type UserPassAuthentication func(string, string) (interface{}, error)
    58  
    59  // UserPassAuthenticationCtx authentication function with context.Context
    60  type UserPassAuthenticationCtx func(context.Context, string, string) (context.Context, interface{}, error)
    61  
    62  // TokenAuthentication authentication function
    63  type TokenAuthentication func(string) (interface{}, error)
    64  
    65  // TokenAuthenticationCtx authentication function with context.Context
    66  type TokenAuthenticationCtx func(context.Context, string) (context.Context, interface{}, error)
    67  
    68  // ScopedTokenAuthentication authentication function
    69  type ScopedTokenAuthentication func(string, []string) (interface{}, error)
    70  
    71  // ScopedTokenAuthenticationCtx authentication function with context.Context
    72  type ScopedTokenAuthenticationCtx func(context.Context, string, []string) (context.Context, interface{}, error)
    73  
    74  var DefaultRealmName = "API"
    75  
    76  type secCtxKey uint8
    77  
    78  const (
    79  	failedBasicAuth secCtxKey = iota
    80  	oauth2SchemeName
    81  )
    82  
    83  func FailedBasicAuth(r *http.Request) string {
    84  	return FailedBasicAuthCtx(r.Context())
    85  }
    86  
    87  func FailedBasicAuthCtx(ctx context.Context) string {
    88  	v, ok := ctx.Value(failedBasicAuth).(string)
    89  	if !ok {
    90  		return ""
    91  	}
    92  	return v
    93  }
    94  
    95  func OAuth2SchemeName(r *http.Request) string {
    96  	return OAuth2SchemeNameCtx(r.Context())
    97  }
    98  
    99  func OAuth2SchemeNameCtx(ctx context.Context) string {
   100  	v, ok := ctx.Value(oauth2SchemeName).(string)
   101  	if !ok {
   102  		return ""
   103  	}
   104  	return v
   105  }
   106  
   107  // BasicAuth creates a basic auth authenticator with the provided authentication function
   108  func BasicAuth(authenticate UserPassAuthentication) runtime.Authenticator {
   109  	return BasicAuthRealm(DefaultRealmName, authenticate)
   110  }
   111  
   112  // BasicAuthRealm creates a basic auth authenticator with the provided authentication function and realm name
   113  func BasicAuthRealm(realm string, authenticate UserPassAuthentication) runtime.Authenticator {
   114  	if realm == "" {
   115  		realm = DefaultRealmName
   116  	}
   117  
   118  	return HttpAuthenticator(func(r *http.Request) (bool, interface{}, error) {
   119  		if usr, pass, ok := r.BasicAuth(); ok {
   120  			p, err := authenticate(usr, pass)
   121  			if err != nil {
   122  				*r = *r.WithContext(context.WithValue(r.Context(), failedBasicAuth, realm))
   123  			}
   124  			return true, p, err
   125  		}
   126  		*r = *r.WithContext(context.WithValue(r.Context(), failedBasicAuth, realm))
   127  		return false, nil, nil
   128  	})
   129  }
   130  
   131  // BasicAuthCtx creates a basic auth authenticator with the provided authentication function with support for context.Context
   132  func BasicAuthCtx(authenticate UserPassAuthenticationCtx) runtime.Authenticator {
   133  	return BasicAuthRealmCtx(DefaultRealmName, authenticate)
   134  }
   135  
   136  // BasicAuthRealmCtx creates a basic auth authenticator with the provided authentication function and realm name with support for context.Context
   137  func BasicAuthRealmCtx(realm string, authenticate UserPassAuthenticationCtx) runtime.Authenticator {
   138  	if realm == "" {
   139  		realm = DefaultRealmName
   140  	}
   141  
   142  	return HttpAuthenticator(func(r *http.Request) (bool, interface{}, error) {
   143  		if usr, pass, ok := r.BasicAuth(); ok {
   144  			ctx, p, err := authenticate(r.Context(), usr, pass)
   145  			if err != nil {
   146  				ctx = context.WithValue(ctx, failedBasicAuth, realm)
   147  			}
   148  			*r = *r.WithContext(ctx)
   149  			return true, p, err
   150  		}
   151  		*r = *r.WithContext(context.WithValue(r.Context(), failedBasicAuth, realm))
   152  		return false, nil, nil
   153  	})
   154  }
   155  
   156  // APIKeyAuth creates an authenticator that uses a token for authorization.
   157  // This token can be obtained from either a header or a query string
   158  func APIKeyAuth(name, in string, authenticate TokenAuthentication) runtime.Authenticator {
   159  	inl := strings.ToLower(in)
   160  	if inl != query && inl != header {
   161  		// panic because this is most likely a typo
   162  		panic(errors.New(500, "api key auth: in value needs to be either \"query\" or \"header\""))
   163  	}
   164  
   165  	var getToken func(*http.Request) string
   166  	switch inl {
   167  	case header:
   168  		getToken = func(r *http.Request) string { return r.Header.Get(name) }
   169  	case query:
   170  		getToken = func(r *http.Request) string { return r.URL.Query().Get(name) }
   171  	}
   172  
   173  	return HttpAuthenticator(func(r *http.Request) (bool, interface{}, error) {
   174  		token := getToken(r)
   175  		if token == "" {
   176  			return false, nil, nil
   177  		}
   178  
   179  		p, err := authenticate(token)
   180  		return true, p, err
   181  	})
   182  }
   183  
   184  // APIKeyAuthCtx creates an authenticator that uses a token for authorization with support for context.Context.
   185  // This token can be obtained from either a header or a query string
   186  func APIKeyAuthCtx(name, in string, authenticate TokenAuthenticationCtx) runtime.Authenticator {
   187  	inl := strings.ToLower(in)
   188  	if inl != query && inl != header {
   189  		// panic because this is most likely a typo
   190  		panic(errors.New(500, "api key auth: in value needs to be either \"query\" or \"header\""))
   191  	}
   192  
   193  	var getToken func(*http.Request) string
   194  	switch inl {
   195  	case header:
   196  		getToken = func(r *http.Request) string { return r.Header.Get(name) }
   197  	case query:
   198  		getToken = func(r *http.Request) string { return r.URL.Query().Get(name) }
   199  	}
   200  
   201  	return HttpAuthenticator(func(r *http.Request) (bool, interface{}, error) {
   202  		token := getToken(r)
   203  		if token == "" {
   204  			return false, nil, nil
   205  		}
   206  
   207  		ctx, p, err := authenticate(r.Context(), token)
   208  		*r = *r.WithContext(ctx)
   209  		return true, p, err
   210  	})
   211  }
   212  
// ScopedAuthRequest contains both a http request and the required scopes for a
// particular operation. It is the parameter type that authenticators built with
// ScopedAuthenticator (such as BearerAuth) act on. Authenticators built with
// HttpAuthenticator also accept it and use only its Request.
//
// NOTE(review): keep the field order stable. Callers may build this struct
// with positional (unkeyed) literals.
type ScopedAuthRequest struct {
	// Request is the incoming HTTP request. Authenticators in this package may
	// overwrite *Request in place (via WithContext) to attach context values.
	Request        *http.Request
	// RequiredScopes lists the scopes the operation requires. BearerAuth and
	// BearerAuthCtx pass it to their authentication callback.
	RequiredScopes []string
}
   218  
   219  // BearerAuth for use with oauth2 flows
   220  func BearerAuth(name string, authenticate ScopedTokenAuthentication) runtime.Authenticator {
   221  	const prefix = "Bearer "
   222  	return ScopedAuthenticator(func(r *ScopedAuthRequest) (bool, interface{}, error) {
   223  		var token string
   224  		hdr := r.Request.Header.Get(runtime.HeaderAuthorization)
   225  		if strings.HasPrefix(hdr, prefix) {
   226  			token = strings.TrimPrefix(hdr, prefix)
   227  		}
   228  		if token == "" {
   229  			qs := r.Request.URL.Query()
   230  			token = qs.Get(accessTokenParam)
   231  		}
   232  		//#nosec
   233  		ct, _, _ := runtime.ContentType(r.Request.Header)
   234  		if token == "" && (ct == "application/x-www-form-urlencoded" || ct == "multipart/form-data") {
   235  			token = r.Request.FormValue(accessTokenParam)
   236  		}
   237  
   238  		if token == "" {
   239  			return false, nil, nil
   240  		}
   241  
   242  		rctx := context.WithValue(r.Request.Context(), oauth2SchemeName, name)
   243  		*r.Request = *r.Request.WithContext(rctx)
   244  		p, err := authenticate(token, r.RequiredScopes)
   245  		return true, p, err
   246  	})
   247  }
   248  
   249  // BearerAuthCtx for use with oauth2 flows with support for context.Context.
   250  func BearerAuthCtx(name string, authenticate ScopedTokenAuthenticationCtx) runtime.Authenticator {
   251  	const prefix = "Bearer "
   252  	return ScopedAuthenticator(func(r *ScopedAuthRequest) (bool, interface{}, error) {
   253  		var token string
   254  		hdr := r.Request.Header.Get(runtime.HeaderAuthorization)
   255  		if strings.HasPrefix(hdr, prefix) {
   256  			token = strings.TrimPrefix(hdr, prefix)
   257  		}
   258  		if token == "" {
   259  			qs := r.Request.URL.Query()
   260  			token = qs.Get(accessTokenParam)
   261  		}
   262  		//#nosec
   263  		ct, _, _ := runtime.ContentType(r.Request.Header)
   264  		if token == "" && (ct == "application/x-www-form-urlencoded" || ct == "multipart/form-data") {
   265  			token = r.Request.FormValue(accessTokenParam)
   266  		}
   267  
   268  		if token == "" {
   269  			return false, nil, nil
   270  		}
   271  
   272  		rctx := context.WithValue(r.Request.Context(), oauth2SchemeName, name)
   273  		ctx, p, err := authenticate(rctx, token, r.RequiredScopes)
   274  		*r.Request = *r.Request.WithContext(ctx)
   275  		return true, p, err
   276  	})
   277  }
   278  

View as plain text