...

Source file src/github.com/palantir/go-githubapp/oauth2/handler.go

Documentation: github.com/palantir/go-githubapp/oauth2

     1  // Copyright 2018 Palantir Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package oauth2
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"net/http"
    21  
    22  	"golang.org/x/oauth2"
    23  )
    24  
    25  const (
    26  	queryCode  = "code"
    27  	queryError = "error"
    28  	queryState = "state"
    29  )
    30  
    31  var (
    32  	ErrInvalidState = errors.New("oauth2: invalid state value")
    33  )
    34  
// Login contains information about the result of a successful auth flow.
// It is passed to the LoginCallback after the authorization code has been
// exchanged. Token is the token returned by that exchange. Client is the
// *http.Client returned by the handler's oauth2.Config.Client for Token,
// created with the callback request's context.
type Login struct {
	Token  *oauth2.Token
	Client *http.Client
}
    40  
    41  // LoginError is an error returned as a parameter by the OAuth provider.
    42  type LoginError string
    43  
    44  func (err LoginError) Error() string {
    45  	return string(err)
    46  }
    47  
    48  type Param func(*handler)
    49  
    50  type ErrorCallback func(w http.ResponseWriter, r *http.Request, err error)
    51  type LoginCallback func(w http.ResponseWriter, r *http.Request, login *Login)
    52  
    53  type handler struct {
    54  	config *oauth2.Config
    55  
    56  	onError ErrorCallback
    57  	onLogin LoginCallback
    58  
    59  	forceTLS bool
    60  	store    StateStore
    61  }
    62  
    63  // NewHandler returns an http.Hander that implements the 3-leg OAuth2 flow on a
    64  // single endpoint. It accepts callbacks for both error and success conditions
    65  // so that clients can take action after the auth flow is complete.
    66  func NewHandler(c *oauth2.Config, params ...Param) http.Handler {
    67  	h := &handler{
    68  		config:  c,
    69  		onError: DefaultErrorCallback,
    70  		onLogin: DefaultLoginCallback,
    71  		store:   insecureStateStore{},
    72  	}
    73  
    74  	for _, p := range params {
    75  		p(h)
    76  	}
    77  
    78  	return h
    79  }
    80  
    81  func DefaultErrorCallback(w http.ResponseWriter, r *http.Request, err error) {
    82  	if err == ErrInvalidState {
    83  		http.Error(w, "invalid state parameter", http.StatusBadRequest)
    84  		return
    85  	}
    86  	if _, ok := err.(LoginError); ok {
    87  		http.Error(w, fmt.Sprintf("oauth2 error: %v", err.Error()), http.StatusBadRequest)
    88  		return
    89  	}
    90  	http.Error(w, err.Error(), http.StatusInternalServerError)
    91  }
    92  
    93  func DefaultLoginCallback(w http.ResponseWriter, r *http.Request, login *Login) {
    94  	w.WriteHeader(http.StatusOK)
    95  }
    96  
    97  // ForceTLS determines if generated URLs always use HTTPS. By default, the
    98  // protocol of the request is used.
    99  func ForceTLS(forceTLS bool) Param {
   100  	return func(h *handler) {
   101  		h.forceTLS = forceTLS
   102  	}
   103  }
   104  
   105  // WithStore sets the StateStore used to create and verify OAuth2 states. The
   106  // default state store uses a static value, is insecure, and is not suitable
   107  // for production use.
   108  func WithStore(ss StateStore) Param {
   109  	return func(h *handler) {
   110  		h.store = ss
   111  	}
   112  }
   113  
   114  // OnError sets the error callback.
   115  func OnError(c ErrorCallback) Param {
   116  	return func(h *handler) {
   117  		h.onError = c
   118  	}
   119  }
   120  
   121  // OnLogin sets the login callback.
   122  func OnLogin(c LoginCallback) Param {
   123  	return func(h *handler) {
   124  		h.onLogin = c
   125  	}
   126  }
   127  
   128  // WithRedirectURL sets a static redirect URL. By default, the redirect URL is
   129  // generated using the request path, the Host header, and the ForceTLS option.
   130  func WithRedirectURL(uri string) Param {
   131  	return func(h *handler) {
   132  		h.config.RedirectURL = uri
   133  	}
   134  }
   135  
   136  func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   137  	// copy config for modification
   138  	conf := *h.config
   139  	if conf.RedirectURL == "" {
   140  		conf.RedirectURL = redirectURL(r, h.forceTLS)
   141  	}
   142  
   143  	// if the provider returned an error, abort the processes
   144  	if r.FormValue(queryError) != "" {
   145  		h.onError(w, r, LoginError(r.FormValue(queryError)))
   146  		return
   147  	}
   148  
   149  	// if this is an initial request, redirect to the provider
   150  	if isInitial(r) {
   151  		state, err := h.store.GenerateState(w, r)
   152  		if err != nil {
   153  			h.onError(w, r, err)
   154  			return
   155  		}
   156  
   157  		url := conf.AuthCodeURL(state, oauth2.AccessTypeOnline)
   158  		http.Redirect(w, r, url, http.StatusFound)
   159  		return
   160  	}
   161  
   162  	// otherwise, verify the state and complete the flow
   163  	isValid, err := h.store.VerifyState(r, r.FormValue(queryState))
   164  	if err != nil {
   165  		h.onError(w, r, err)
   166  		return
   167  	}
   168  
   169  	if !isValid {
   170  		h.onError(w, r, ErrInvalidState)
   171  		return
   172  	}
   173  
   174  	tok, err := conf.Exchange(r.Context(), r.FormValue(queryCode))
   175  	if err != nil {
   176  		h.onError(w, r, err)
   177  		return
   178  	}
   179  
   180  	h.onLogin(w, r, &Login{
   181  		Token:  tok,
   182  		Client: conf.Client(r.Context(), tok),
   183  	})
   184  }
   185  
   186  func isInitial(r *http.Request) bool {
   187  	return r.FormValue(queryCode) == ""
   188  }
   189  
   190  func redirectURL(r *http.Request, forceTLS bool) string {
   191  	u := *r.URL
   192  	u.Host = r.Host
   193  
   194  	if forceTLS || r.TLS != nil {
   195  		u.Scheme = "https"
   196  	} else {
   197  		u.Scheme = "http"
   198  	}
   199  
   200  	q := u.Query()
   201  	q.Del(queryCode)
   202  	q.Del(queryState)
   203  	u.RawQuery = q.Encode()
   204  
   205  	return u.String()
   206  }
   207  

View as plain text