...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package oauth2
16
17 import (
18 "crypto/rand"
19 "crypto/subtle"
20 "encoding/hex"
21 "net/http"
22
23 "github.com/gorilla/sessions"
24 "github.com/pkg/errors"
25 )
26
var (
	// DefaultSessionName is the name of the session that
	// SessionStateStore loads and saves when handling OAuth2 state.
	DefaultSessionName = "oauth2"

	// sessionStateKey is the key under which the generated state value
	// is kept in the session's Values map.
	sessionStateKey = "state"
)
31
32 type SessionStateStore struct {
33 Sessions sessions.Store
34 }
35
36 func (s *SessionStateStore) GenerateState(w http.ResponseWriter, r *http.Request) (string, error) {
37
38 sess, _ := s.Sessions.Get(r, DefaultSessionName)
39
40 b := make([]byte, 20)
41 if _, err := rand.Read(b); err != nil {
42 return "", errors.Wrap(err, "failed to generate state value")
43 }
44
45 state := hex.EncodeToString(b)
46 sess.Values[sessionStateKey] = state
47 return state, sess.Save(r, w)
48 }
49
50 func (s *SessionStateStore) VerifyState(r *http.Request, expected string) (bool, error) {
51 sess, err := s.Sessions.Get(r, DefaultSessionName)
52 if err != nil {
53 return false, err
54 }
55 st, ok := sess.Values[sessionStateKey]
56 if !ok {
57 return false, errors.New("no state value found in the session")
58 }
59
60 state, ok := st.(string)
61 if !ok {
62 return false, errors.New("session state value was an incorrect type")
63 }
64 return subtle.ConstantTimeCompare([]byte(expected), []byte(state)) == 1, nil
65 }
66
View as plain text