...

Source file src/github.com/docker/distribution/contrib/token-server/main.go

Documentation: github.com/docker/distribution/contrib/token-server

     1  package main
     2  
import (
	"context"
	"crypto/rand"
	"encoding/json"
	"flag"
	"math/big"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	dcontext "github.com/docker/distribution/context"
	"github.com/docker/distribution/registry/api/errcode"
	"github.com/docker/distribution/registry/auth"
	_ "github.com/docker/distribution/registry/auth/htpasswd"
	"github.com/docker/libtrust"
	"github.com/gorilla/mux"
	"github.com/sirupsen/logrus"
)
    22  
    23  var (
    24  	enforceRepoClass bool
    25  )
    26  
    27  func main() {
    28  	var (
    29  		issuer = &TokenIssuer{}
    30  		pkFile string
    31  		addr   string
    32  		debug  bool
    33  		err    error
    34  
    35  		passwdFile string
    36  		realm      string
    37  
    38  		cert    string
    39  		certKey string
    40  	)
    41  
    42  	flag.StringVar(&issuer.Issuer, "issuer", "distribution-token-server", "Issuer string for token")
    43  	flag.StringVar(&pkFile, "key", "", "Private key file")
    44  	flag.StringVar(&addr, "addr", "localhost:8080", "Address to listen on")
    45  	flag.BoolVar(&debug, "debug", false, "Debug mode")
    46  
    47  	flag.StringVar(&passwdFile, "passwd", ".htpasswd", "Passwd file")
    48  	flag.StringVar(&realm, "realm", "", "Authentication realm")
    49  
    50  	flag.StringVar(&cert, "tlscert", "", "Certificate file for TLS")
    51  	flag.StringVar(&certKey, "tlskey", "", "Certificate key for TLS")
    52  
    53  	flag.BoolVar(&enforceRepoClass, "enforce-class", false, "Enforce policy for single repository class")
    54  
    55  	flag.Parse()
    56  
    57  	if debug {
    58  		logrus.SetLevel(logrus.DebugLevel)
    59  	}
    60  
    61  	if pkFile == "" {
    62  		issuer.SigningKey, err = libtrust.GenerateECP256PrivateKey()
    63  		if err != nil {
    64  			logrus.Fatalf("Error generating private key: %v", err)
    65  		}
    66  		logrus.Debugf("Using newly generated key with id %s", issuer.SigningKey.KeyID())
    67  	} else {
    68  		issuer.SigningKey, err = libtrust.LoadKeyFile(pkFile)
    69  		if err != nil {
    70  			logrus.Fatalf("Error loading key file %s: %v", pkFile, err)
    71  		}
    72  		logrus.Debugf("Loaded private key with id %s", issuer.SigningKey.KeyID())
    73  	}
    74  
    75  	if realm == "" {
    76  		logrus.Fatalf("Must provide realm")
    77  	}
    78  
    79  	ac, err := auth.GetAccessController("htpasswd", map[string]interface{}{
    80  		"realm": realm,
    81  		"path":  passwdFile,
    82  	})
    83  	if err != nil {
    84  		logrus.Fatalf("Error initializing access controller: %v", err)
    85  	}
    86  
    87  	// TODO: Make configurable
    88  	issuer.Expiration = 15 * time.Minute
    89  
    90  	ctx := dcontext.Background()
    91  
    92  	ts := &tokenServer{
    93  		issuer:           issuer,
    94  		accessController: ac,
    95  		refreshCache:     map[string]refreshToken{},
    96  	}
    97  
    98  	router := mux.NewRouter()
    99  	router.Path("/token/").Methods("GET").Handler(handlerWithContext(ctx, ts.getToken))
   100  	router.Path("/token/").Methods("POST").Handler(handlerWithContext(ctx, ts.postToken))
   101  
   102  	if cert == "" {
   103  		err = http.ListenAndServe(addr, router)
   104  	} else if certKey == "" {
   105  		logrus.Fatalf("Must provide certficate (-tlscert) and key (-tlskey)")
   106  	} else {
   107  		err = http.ListenAndServeTLS(addr, cert, certKey, router)
   108  	}
   109  
   110  	if err != nil {
   111  		logrus.Infof("Error serving: %v", err)
   112  	}
   113  
   114  }
   115  
   116  // handlerWithContext wraps the given context-aware handler by setting up the
   117  // request context from a base context.
   118  func handlerWithContext(ctx context.Context, handler func(context.Context, http.ResponseWriter, *http.Request)) http.Handler {
   119  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   120  		ctx := dcontext.WithRequest(ctx, r)
   121  		logger := dcontext.GetRequestLogger(ctx)
   122  		ctx = dcontext.WithLogger(ctx, logger)
   123  
   124  		handler(ctx, w, r)
   125  	})
   126  }
   127  
   128  func handleError(ctx context.Context, err error, w http.ResponseWriter) {
   129  	ctx, w = dcontext.WithResponseWriter(ctx, w)
   130  
   131  	if serveErr := errcode.ServeJSON(w, err); serveErr != nil {
   132  		dcontext.GetResponseLogger(ctx).Errorf("error sending error response: %v", serveErr)
   133  		return
   134  	}
   135  
   136  	dcontext.GetResponseLogger(ctx).Info("application error")
   137  }
   138  
   139  var refreshCharacters = []rune("0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
   140  
   141  const refreshTokenLength = 15
   142  
   143  func newRefreshToken() string {
   144  	s := make([]rune, refreshTokenLength)
   145  	max := int64(len(refreshCharacters))
   146  	for i := range s {
   147  		randInt, err := rand.Int(rand.Reader, big.NewInt(max))
   148  		// let '0' serves the failure case
   149  		if err != nil {
   150  			logrus.Infof("Error on making refersh token: %v", err)
   151  			randInt = big.NewInt(0)
   152  		}
   153  		s[i] = refreshCharacters[randInt.Int64()]
   154  	}
   155  	return string(s)
   156  }
   157  
   158  type refreshToken struct {
   159  	subject string
   160  	service string
   161  }
   162  
   163  type tokenServer struct {
   164  	issuer           *TokenIssuer
   165  	accessController auth.AccessController
   166  	refreshCache     map[string]refreshToken
   167  }
   168  
   169  type tokenResponse struct {
   170  	Token        string `json:"access_token"`
   171  	RefreshToken string `json:"refresh_token,omitempty"`
   172  	ExpiresIn    int    `json:"expires_in,omitempty"`
   173  }
   174  
   175  var repositoryClassCache = map[string]string{}
   176  
   177  func filterAccessList(ctx context.Context, scope string, requestedAccessList []auth.Access) []auth.Access {
   178  	if !strings.HasSuffix(scope, "/") {
   179  		scope = scope + "/"
   180  	}
   181  	grantedAccessList := make([]auth.Access, 0, len(requestedAccessList))
   182  	for _, access := range requestedAccessList {
   183  		if access.Type == "repository" {
   184  			if !strings.HasPrefix(access.Name, scope) {
   185  				dcontext.GetLogger(ctx).Debugf("Resource scope not allowed: %s", access.Name)
   186  				continue
   187  			}
   188  			if enforceRepoClass {
   189  				if class, ok := repositoryClassCache[access.Name]; ok {
   190  					if class != access.Class {
   191  						dcontext.GetLogger(ctx).Debugf("Different repository class: %q, previously %q", access.Class, class)
   192  						continue
   193  					}
   194  				} else if strings.EqualFold(access.Action, "push") {
   195  					repositoryClassCache[access.Name] = access.Class
   196  				}
   197  			}
   198  		} else if access.Type == "registry" {
   199  			if access.Name != "catalog" {
   200  				dcontext.GetLogger(ctx).Debugf("Unknown registry resource: %s", access.Name)
   201  				continue
   202  			}
   203  			// TODO: Limit some actions to "admin" users
   204  		} else {
   205  			dcontext.GetLogger(ctx).Debugf("Skipping unsupported resource type: %s", access.Type)
   206  			continue
   207  		}
   208  		grantedAccessList = append(grantedAccessList, access)
   209  	}
   210  	return grantedAccessList
   211  }
   212  
   213  type acctSubject struct{}
   214  
   215  func (acctSubject) String() string { return "acctSubject" }
   216  
   217  type requestedAccess struct{}
   218  
   219  func (requestedAccess) String() string { return "requestedAccess" }
   220  
   221  type grantedAccess struct{}
   222  
   223  func (grantedAccess) String() string { return "grantedAccess" }
   224  
   225  // getToken handles authenticating the request and authorizing access to the
   226  // requested scopes.
   227  func (ts *tokenServer) getToken(ctx context.Context, w http.ResponseWriter, r *http.Request) {
   228  	dcontext.GetLogger(ctx).Info("getToken")
   229  
   230  	params := r.URL.Query()
   231  	service := params.Get("service")
   232  	scopeSpecifiers := params["scope"]
   233  	var offline bool
   234  	if offlineStr := params.Get("offline_token"); offlineStr != "" {
   235  		var err error
   236  		offline, err = strconv.ParseBool(offlineStr)
   237  		if err != nil {
   238  			handleError(ctx, ErrorBadTokenOption.WithDetail(err), w)
   239  			return
   240  		}
   241  	}
   242  
   243  	requestedAccessList := ResolveScopeSpecifiers(ctx, scopeSpecifiers)
   244  
   245  	authorizedCtx, err := ts.accessController.Authorized(ctx, requestedAccessList...)
   246  	if err != nil {
   247  		challenge, ok := err.(auth.Challenge)
   248  		if !ok {
   249  			handleError(ctx, err, w)
   250  			return
   251  		}
   252  
   253  		// Get response context.
   254  		ctx, w = dcontext.WithResponseWriter(ctx, w)
   255  
   256  		challenge.SetHeaders(r, w)
   257  		handleError(ctx, errcode.ErrorCodeUnauthorized.WithDetail(challenge.Error()), w)
   258  
   259  		dcontext.GetResponseLogger(ctx).Info("get token authentication challenge")
   260  
   261  		return
   262  	}
   263  	ctx = authorizedCtx
   264  
   265  	username := dcontext.GetStringValue(ctx, "auth.user.name")
   266  
   267  	ctx = context.WithValue(ctx, acctSubject{}, username)
   268  	ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, acctSubject{}))
   269  
   270  	dcontext.GetLogger(ctx).Info("authenticated client")
   271  
   272  	ctx = context.WithValue(ctx, requestedAccess{}, requestedAccessList)
   273  	ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, requestedAccess{}))
   274  
   275  	grantedAccessList := filterAccessList(ctx, username, requestedAccessList)
   276  	ctx = context.WithValue(ctx, grantedAccess{}, grantedAccessList)
   277  	ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, grantedAccess{}))
   278  
   279  	token, err := ts.issuer.CreateJWT(username, service, grantedAccessList)
   280  	if err != nil {
   281  		handleError(ctx, err, w)
   282  		return
   283  	}
   284  
   285  	dcontext.GetLogger(ctx).Info("authorized client")
   286  
   287  	response := tokenResponse{
   288  		Token:     token,
   289  		ExpiresIn: int(ts.issuer.Expiration.Seconds()),
   290  	}
   291  
   292  	if offline {
   293  		response.RefreshToken = newRefreshToken()
   294  		ts.refreshCache[response.RefreshToken] = refreshToken{
   295  			subject: username,
   296  			service: service,
   297  		}
   298  	}
   299  
   300  	ctx, w = dcontext.WithResponseWriter(ctx, w)
   301  
   302  	w.Header().Set("Content-Type", "application/json")
   303  	json.NewEncoder(w).Encode(response)
   304  
   305  	dcontext.GetResponseLogger(ctx).Info("get token complete")
   306  }
   307  
   308  type postTokenResponse struct {
   309  	Token        string `json:"access_token"`
   310  	Scope        string `json:"scope,omitempty"`
   311  	ExpiresIn    int    `json:"expires_in,omitempty"`
   312  	IssuedAt     string `json:"issued_at,omitempty"`
   313  	RefreshToken string `json:"refresh_token,omitempty"`
   314  }
   315  
   316  // postToken handles authenticating the request and authorizing access to the
   317  // requested scopes.
   318  func (ts *tokenServer) postToken(ctx context.Context, w http.ResponseWriter, r *http.Request) {
   319  	grantType := r.PostFormValue("grant_type")
   320  	if grantType == "" {
   321  		handleError(ctx, ErrorMissingRequiredField.WithDetail("missing grant_type value"), w)
   322  		return
   323  	}
   324  
   325  	service := r.PostFormValue("service")
   326  	if service == "" {
   327  		handleError(ctx, ErrorMissingRequiredField.WithDetail("missing service value"), w)
   328  		return
   329  	}
   330  
   331  	clientID := r.PostFormValue("client_id")
   332  	if clientID == "" {
   333  		handleError(ctx, ErrorMissingRequiredField.WithDetail("missing client_id value"), w)
   334  		return
   335  	}
   336  
   337  	var offline bool
   338  	switch r.PostFormValue("access_type") {
   339  	case "", "online":
   340  	case "offline":
   341  		offline = true
   342  	default:
   343  		handleError(ctx, ErrorUnsupportedValue.WithDetail("unknown access_type value"), w)
   344  		return
   345  	}
   346  
   347  	requestedAccessList := ResolveScopeList(ctx, r.PostFormValue("scope"))
   348  
   349  	var subject string
   350  	var rToken string
   351  	switch grantType {
   352  	case "refresh_token":
   353  		rToken = r.PostFormValue("refresh_token")
   354  		if rToken == "" {
   355  			handleError(ctx, ErrorUnsupportedValue.WithDetail("missing refresh_token value"), w)
   356  			return
   357  		}
   358  		rt, ok := ts.refreshCache[rToken]
   359  		if !ok || rt.service != service {
   360  			handleError(ctx, errcode.ErrorCodeUnauthorized.WithDetail("invalid refresh token"), w)
   361  			return
   362  		}
   363  		subject = rt.subject
   364  	case "password":
   365  		ca, ok := ts.accessController.(auth.CredentialAuthenticator)
   366  		if !ok {
   367  			handleError(ctx, ErrorUnsupportedValue.WithDetail("password grant type not supported"), w)
   368  			return
   369  		}
   370  		subject = r.PostFormValue("username")
   371  		if subject == "" {
   372  			handleError(ctx, ErrorUnsupportedValue.WithDetail("missing username value"), w)
   373  			return
   374  		}
   375  		password := r.PostFormValue("password")
   376  		if password == "" {
   377  			handleError(ctx, ErrorUnsupportedValue.WithDetail("missing password value"), w)
   378  			return
   379  		}
   380  		if err := ca.AuthenticateUser(subject, password); err != nil {
   381  			handleError(ctx, errcode.ErrorCodeUnauthorized.WithDetail("invalid credentials"), w)
   382  			return
   383  		}
   384  	default:
   385  		handleError(ctx, ErrorUnsupportedValue.WithDetail("unknown grant_type value"), w)
   386  		return
   387  	}
   388  
   389  	ctx = context.WithValue(ctx, acctSubject{}, subject)
   390  	ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, acctSubject{}))
   391  
   392  	dcontext.GetLogger(ctx).Info("authenticated client")
   393  
   394  	ctx = context.WithValue(ctx, requestedAccess{}, requestedAccessList)
   395  	ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, requestedAccess{}))
   396  
   397  	grantedAccessList := filterAccessList(ctx, subject, requestedAccessList)
   398  	ctx = context.WithValue(ctx, grantedAccess{}, grantedAccessList)
   399  	ctx = dcontext.WithLogger(ctx, dcontext.GetLogger(ctx, grantedAccess{}))
   400  
   401  	token, err := ts.issuer.CreateJWT(subject, service, grantedAccessList)
   402  	if err != nil {
   403  		handleError(ctx, err, w)
   404  		return
   405  	}
   406  
   407  	dcontext.GetLogger(ctx).Info("authorized client")
   408  
   409  	response := postTokenResponse{
   410  		Token:     token,
   411  		ExpiresIn: int(ts.issuer.Expiration.Seconds()),
   412  		IssuedAt:  time.Now().UTC().Format(time.RFC3339),
   413  		Scope:     ToScopeList(grantedAccessList),
   414  	}
   415  
   416  	if offline {
   417  		rToken = newRefreshToken()
   418  		ts.refreshCache[rToken] = refreshToken{
   419  			subject: subject,
   420  			service: service,
   421  		}
   422  	}
   423  
   424  	if rToken != "" {
   425  		response.RefreshToken = rToken
   426  	}
   427  
   428  	ctx, w = dcontext.WithResponseWriter(ctx, w)
   429  
   430  	w.Header().Set("Content-Type", "application/json")
   431  	json.NewEncoder(w).Encode(response)
   432  
   433  	dcontext.GetResponseLogger(ctx).Info("post token complete")
   434  }
   435  

View as plain text