1 package database
2
3 import (
4 "bytes"
5 "context"
6 "crypto/rand"
7 "encoding/base32"
8 "encoding/gob"
9 "errors"
10 "io"
11 "net/http"
12 "strings"
13 "time"
14
15 "edge-infra.dev/pkg/edge/iam/config"
16 "edge-infra.dev/pkg/edge/iam/crypto"
17
18 "github.com/go-redis/redis"
19 "github.com/gorilla/sessions"
20 )
21
22
// SessionStore is a gorilla/sessions store that keeps session values in Redis
// under keyPrefix + session ID, while the session ID itself travels in a cookie.
type SessionStore struct {
	// client is the Redis connection used for every read, write and delete.
	client redis.UniversalClient

	// options is the default configuration; New copies it into each session.
	options sessions.Options

	// keyPrefix is prepended to a session ID to form its Redis key.
	keyPrefix string

	// keyGen produces IDs for sessions that do not have one yet (see Save).
	keyGen KeyGenFunc

	// serializer converts session values to and from the bytes kept in Redis.
	serializer SessionSerializer
}
35
36
// KeyGenFunc returns a new session ID, or an error if one cannot be produced.
// The store calls it when saving a session whose ID is still empty.
type KeyGenFunc func() (string, error)
38
39
40 func NewRedisSessionStore(_ context.Context, client redis.UniversalClient) (*SessionStore, error) {
41 rs := &SessionStore{
42 options: sessions.Options{
43 Path: "/",
44 MaxAge: config.SessionCookieMaxAge(),
45 HttpOnly: true,
46
47
48 SameSite: http.SameSiteLaxMode,
49 },
50 client: client,
51 keyPrefix: "session:",
52 keyGen: generateRandomKey,
53 serializer: GobSerializer{},
54 }
55
56 return rs, rs.client.Ping().Err()
57 }
58
59
60 func (s *SessionStore) Get(r *http.Request, name string) (*sessions.Session, error) {
61 return s.New(r, name)
62 }
63
64
65 func (s *SessionStore) New(r *http.Request, name string) (*sessions.Session, error) {
66 session := sessions.NewSession(s, name)
67 opts := s.options
68 session.Options = &opts
69 session.IsNew = true
70
71 c, err := r.Cookie(name)
72 if err != nil {
73 return session, nil
74 }
75 session.ID = c.Value
76
77 err = s.load(session)
78 if err == nil {
79 session.IsNew = false
80 } else if err == redis.Nil {
81 err = nil
82 }
83 return session, err
84 }
85
86
87
88
89
90
91
92 func (s *SessionStore) Save(_ *http.Request, w http.ResponseWriter, session *sessions.Session) error {
93
94 if session.Options.MaxAge <= 0 {
95 if err := s.delete(session); err != nil {
96 return err
97 }
98 http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
99 return nil
100 }
101
102 if session.ID == "" {
103 id, err := s.keyGen()
104 if err != nil {
105 return errors.New("SessionStore: failed to generate session id")
106 }
107 session.ID = id
108 }
109 if err := s.save(session); err != nil {
110 return err
111 }
112
113 http.SetCookie(w, sessions.NewCookie(session.Name(), session.ID, session.Options))
114 return nil
115 }
116
117
118 func (s *SessionStore) Options(opts sessions.Options) {
119 s.options = opts
120 }
121
122
123 func (s *SessionStore) KeyPrefix(keyPrefix string) {
124 s.keyPrefix = keyPrefix
125 }
126
127
128 func (s *SessionStore) KeyGen(f KeyGenFunc) {
129 s.keyGen = f
130 }
131
132
133 func (s *SessionStore) Serializer(ss SessionSerializer) {
134 s.serializer = ss
135 }
136
137
138 func (s *SessionStore) Close() error {
139 return s.client.Close()
140 }
141
142
143 func (s *SessionStore) save(session *sessions.Session) error {
144 b, err := s.serializer.Serialize(session)
145 if err != nil {
146 return err
147 }
148
149 if config.EncryptionEnabled() {
150 encryptedValue, err := crypto.EncryptRedis(b, config.EncryptionKey())
151 if err != nil {
152 return err
153 }
154 return s.client.Set(s.keyPrefix+session.ID, encryptedValue, time.Duration(session.Options.MaxAge)*time.Second).Err()
155 }
156
157 return s.client.Set(s.keyPrefix+session.ID, b, time.Duration(session.Options.MaxAge)*time.Second).Err()
158 }
159
160
161 func (s *SessionStore) load(session *sessions.Session) error {
162 cmd := s.client.Get(s.keyPrefix + session.ID)
163 if cmd.Err() != nil {
164 return cmd.Err()
165 }
166
167 b, err := cmd.Bytes()
168 if err != nil {
169 return err
170 }
171
172 if config.EncryptionEnabled() {
173
174 if !isRedisDataEncrypted(b) {
175 ttl := s.client.TTL(s.keyPrefix + session.ID)
176 encryptedVal, err := crypto.EncryptRedis(b, config.EncryptionKey())
177 if err != nil {
178 return err
179 }
180
181 err = s.client.Set(s.keyPrefix+session.ID, encryptedVal, ttl.Val()).Err()
182 if err != nil {
183 return err
184 }
185
186 return s.serializer.Deserialize(b, session)
187 }
188
189 decryptedValue, err := crypto.DecryptRedis(string(b), config.EncryptionKey())
190 if err != nil {
191 return err
192 }
193 return s.serializer.Deserialize(decryptedValue, session)
194 }
195 return s.serializer.Deserialize(b, session)
196 }
197
198
199 func (s *SessionStore) delete(session *sessions.Session) error {
200 return s.client.Del(s.keyPrefix + session.ID).Err()
201 }
202
203
// SessionSerializer converts a session's data to bytes for storage and back.
type SessionSerializer interface {
	// Serialize encodes the session's data into bytes to be written to Redis.
	Serialize(s *sessions.Session) ([]byte, error)
	// Deserialize decodes bytes read from Redis into s.
	Deserialize(b []byte, s *sessions.Session) error
}
208
209
210 type GobSerializer struct{}
211
212 func (gs GobSerializer) Serialize(s *sessions.Session) ([]byte, error) {
213 buf := new(bytes.Buffer)
214 enc := gob.NewEncoder(buf)
215 err := enc.Encode(s.Values)
216 if err == nil {
217 return buf.Bytes(), nil
218 }
219 return nil, err
220 }
221
222 func (gs GobSerializer) Deserialize(d []byte, s *sessions.Session) error {
223 dec := gob.NewDecoder(bytes.NewBuffer(d))
224 return dec.Decode(&s.Values)
225 }
226
227
// generateRandomKey returns 64 bytes from crypto/rand encoded as unpadded
// standard base32 (103 characters from A-Z and 2-7). It is the default
// KeyGenFunc.
func generateRandomKey() (string, error) {
	entropy := make([]byte, 64)
	_, readErr := io.ReadFull(rand.Reader, entropy)
	if readErr != nil {
		return "", readErr
	}
	encoded := base32.StdEncoding.EncodeToString(entropy)
	return strings.TrimRight(encoded, "="), nil
}
235
236
// isRedisDataEncrypted reports whether data contains the "EncryptedData:"
// marker anywhere, which load uses to tell encrypted values from legacy
// plaintext ones.
//
// NOTE(review): this is a substring match, not a prefix match. A plaintext gob
// payload whose values happen to contain "EncryptedData:" would be treated as
// encrypted and fail to decrypt. If crypto.EncryptRedis always emits the marker
// as a prefix, bytes.HasPrefix would be stricter — confirm its output format.
func isRedisDataEncrypted(data []byte) bool {
	const marker = "EncryptedData:"
	return bytes.Contains(data, []byte(marker))
}
241
View as plain text