1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package oauth2
16
17 import (
18 "errors"
19 "fmt"
20 "net/http"
21
22 "golang.org/x/oauth2"
23 )
24
// Names of the query/form parameters read from requests handled by the
// authorization-code flow, listed in the order ServeHTTP inspects them.
const (
	// queryError holds the error value reported back by the authorization
	// server; ServeHTTP surfaces it as a LoginError.
	queryError = "error"
	// queryCode holds the authorization code exchanged for a token; its
	// absence marks the initial (login-starting) request.
	queryCode = "code"
	// queryState holds the state value checked against the StateStore.
	queryState = "state"
)
30
// ErrInvalidState is passed to the error callback when the configured
// StateStore does not accept the state value carried by a callback request.
var ErrInvalidState = errors.New("oauth2: invalid state value")
34
35
36 type Login struct {
37 Token *oauth2.Token
38 Client *http.Client
39 }
40
41
// LoginError is reported when a request arrives with an "error" parameter,
// i.e. the authorization server refused or failed the login. Its value is the
// parameter's value, unchanged.
type LoginError string

// Error implements the error interface and returns the reported value as-is.
func (e LoginError) Error() string {
	msg := string(e)
	return msg
}
47
48 type Param func(*handler)
49
50 type ErrorCallback func(w http.ResponseWriter, r *http.Request, err error)
51 type LoginCallback func(w http.ResponseWriter, r *http.Request, login *Login)
52
53 type handler struct {
54 config *oauth2.Config
55
56 onError ErrorCallback
57 onLogin LoginCallback
58
59 forceTLS bool
60 store StateStore
61 }
62
63
64
65
66 func NewHandler(c *oauth2.Config, params ...Param) http.Handler {
67 h := &handler{
68 config: c,
69 onError: DefaultErrorCallback,
70 onLogin: DefaultLoginCallback,
71 store: insecureStateStore{},
72 }
73
74 for _, p := range params {
75 p(h)
76 }
77
78 return h
79 }
80
81 func DefaultErrorCallback(w http.ResponseWriter, r *http.Request, err error) {
82 if err == ErrInvalidState {
83 http.Error(w, "invalid state parameter", http.StatusBadRequest)
84 return
85 }
86 if _, ok := err.(LoginError); ok {
87 http.Error(w, fmt.Sprintf("oauth2 error: %v", err.Error()), http.StatusBadRequest)
88 return
89 }
90 http.Error(w, err.Error(), http.StatusInternalServerError)
91 }
92
93 func DefaultLoginCallback(w http.ResponseWriter, r *http.Request, login *Login) {
94 w.WriteHeader(http.StatusOK)
95 }
96
97
98
99 func ForceTLS(forceTLS bool) Param {
100 return func(h *handler) {
101 h.forceTLS = forceTLS
102 }
103 }
104
105
106
107
108 func WithStore(ss StateStore) Param {
109 return func(h *handler) {
110 h.store = ss
111 }
112 }
113
114
115 func OnError(c ErrorCallback) Param {
116 return func(h *handler) {
117 h.onError = c
118 }
119 }
120
121
122 func OnLogin(c LoginCallback) Param {
123 return func(h *handler) {
124 h.onLogin = c
125 }
126 }
127
128
129
130 func WithRedirectURL(uri string) Param {
131 return func(h *handler) {
132 h.config.RedirectURL = uri
133 }
134 }
135
136 func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
137
138 conf := *h.config
139 if conf.RedirectURL == "" {
140 conf.RedirectURL = redirectURL(r, h.forceTLS)
141 }
142
143
144 if r.FormValue(queryError) != "" {
145 h.onError(w, r, LoginError(r.FormValue(queryError)))
146 return
147 }
148
149
150 if isInitial(r) {
151 state, err := h.store.GenerateState(w, r)
152 if err != nil {
153 h.onError(w, r, err)
154 return
155 }
156
157 url := conf.AuthCodeURL(state, oauth2.AccessTypeOnline)
158 http.Redirect(w, r, url, http.StatusFound)
159 return
160 }
161
162
163 isValid, err := h.store.VerifyState(r, r.FormValue(queryState))
164 if err != nil {
165 h.onError(w, r, err)
166 return
167 }
168
169 if !isValid {
170 h.onError(w, r, ErrInvalidState)
171 return
172 }
173
174 tok, err := conf.Exchange(r.Context(), r.FormValue(queryCode))
175 if err != nil {
176 h.onError(w, r, err)
177 return
178 }
179
180 h.onLogin(w, r, &Login{
181 Token: tok,
182 Client: conf.Client(r.Context(), tok),
183 })
184 }
185
186 func isInitial(r *http.Request) bool {
187 return r.FormValue(queryCode) == ""
188 }
189
190 func redirectURL(r *http.Request, forceTLS bool) string {
191 u := *r.URL
192 u.Host = r.Host
193
194 if forceTLS || r.TLS != nil {
195 u.Scheme = "https"
196 } else {
197 u.Scheme = "http"
198 }
199
200 q := u.Query()
201 q.Del(queryCode)
202 q.Del(queryState)
203 u.RawQuery = q.Encode()
204
205 return u.String()
206 }
207
View as plain text