1 package store
2
3 import (
4 "database/sql"
5 "encoding/base32"
6 "fmt"
7 "net/http"
8 "strings"
9 "time"
10
11 "errors"
12
13 "github.com/go-logr/logr"
14 "github.com/gorilla/securecookie"
15 "github.com/gorilla/sessions"
16
17
18 _ "github.com/lib/pq"
19
20 "edge-infra.dev/pkg/edge/audit"
21 )
22
23
// PGStore is a gorilla/sessions store that keeps session values in the
// Postgres table http_sessions; the cookie sent to the client carries only
// the codec-encoded session ID.
type PGStore struct {
	// Codecs encode/decode both the session-ID cookie and the stored values.
	Codecs []securecookie.Codec
	// Options is the template copied into every session created by New.
	Options *sessions.Options
	// Path is not read anywhere in this file. NOTE(review): confirm it is used elsewhere.
	Path string
	// DbPool is the database handle used for every session query.
	DbPool *sql.DB
	// log receives errors such as failed session deletes.
	log logr.Logger
	// auditLog is set by the constructors. NOTE(review): presumably consumed by db.audit (defined outside this file).
	auditLog *audit.Sink
}
32
33
// PGSession mirrors one row of the http_sessions table.
type PGSession struct {
	ID         int64     // BIGSERIAL primary key assigned by Postgres
	Key        string    // session ID; every lookup matches WHERE key = $1
	Data       string    // session.Values encoded with the store's codecs
	CreatedOn  time.Time // first save time (or the "created_on" session value)
	ModifiedOn time.Time // time of the most recent save
	ExpiresOn  time.Time // written on save; not checked when loading in this file
}
42
43
44
45 func NewPGStore(dbURL string, logger logr.Logger, auditLog *audit.Sink, keyPairs ...[]byte) (*PGStore, error) {
46 db, err := sql.Open("postgres", dbURL)
47 if err != nil {
48
49 return nil, err
50 }
51 return NewPGStoreFromPool(db, logger, auditLog, keyPairs...)
52 }
53
54
55
56
57 func NewPGStoreFromPool(db *sql.DB, logger logr.Logger, auditLog *audit.Sink, keyPairs ...[]byte) (*PGStore, error) {
58 dbStore := &PGStore{
59 Codecs: securecookie.CodecsFromPairs(keyPairs...),
60 Options: &sessions.Options{
61 Path: "/",
62 MaxAge: 86400 * 30,
63 },
64 DbPool: db,
65 log: logger,
66 auditLog: auditLog,
67 }
68
69
70 err := dbStore.createSessionsTable()
71 if err != nil {
72 return nil, err
73 }
74
75 return dbStore, nil
76 }
77
78
79 func (db *PGStore) Close() {
80 db.DbPool.Close()
81 }
82
83
84
85 func (db *PGStore) Get(r *http.Request, name string) (*sessions.Session, error) {
86 return sessions.GetRegistry(r).Get(db, name)
87 }
88
89
90 func (db *PGStore) New(r *http.Request, name string) (*sessions.Session, error) {
91 session := sessions.NewSession(db, name)
92 if session == nil {
93 return nil, nil
94 }
95
96 opts := *db.Options
97 session.Options = &(opts)
98 session.IsNew = true
99
100 var err error
101 if c, errCookie := r.Cookie(name); errCookie == nil {
102 err = securecookie.DecodeMulti(name, c.Value, &session.ID, db.Codecs...)
103 if err == nil {
104 err = db.load(session)
105 if err == nil {
106 session.IsNew = false
107 } else if errors.Is(err, sql.ErrNoRows) {
108 err = nil
109 }
110 }
111 }
112
113 db.MaxAge(db.Options.MaxAge)
114
115 return session, err
116 }
117
118
119 func (db *PGStore) Save(_ *http.Request, w http.ResponseWriter, session *sessions.Session) error {
120
121 if session.Options.MaxAge < 0 {
122 if err := db.destroy(session); err != nil {
123 return err
124 }
125 http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
126 return nil
127 }
128
129 if session.ID == "" {
130
131 session.ID = strings.TrimRight(
132 base32.StdEncoding.EncodeToString(
133 securecookie.GenerateRandomKey(32),
134 ), "=")
135 }
136
137 if err := db.save(session); err != nil {
138 return err
139 }
140
141
142 encoded, err := securecookie.EncodeMulti(session.Name(), session.ID, db.Codecs...)
143 if err != nil {
144 return err
145 }
146
147 http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
148 return nil
149 }
150
151
152
153
154
155 func (db *PGStore) MaxLength(l int) {
156 for _, c := range db.Codecs {
157 if codec, ok := c.(*securecookie.SecureCookie); ok {
158 codec.MaxLength(l)
159 }
160 }
161 }
162
163
164
165
166 func (db *PGStore) MaxAge(age int) {
167 db.Options.MaxAge = age
168
169
170 for _, codec := range db.Codecs {
171 if sc, ok := codec.(*securecookie.SecureCookie); ok {
172 sc.MaxAge(age)
173 }
174 }
175 }
176
177
178
179 func (db *PGStore) load(session *sessions.Session) error {
180 var s PGSession
181 err := db.selectOne(&s, session.ID)
182 if err != nil {
183 return err
184 }
185 return securecookie.DecodeMulti(session.Name(), s.Data, &session.Values, db.Codecs...)
186 }
187
188
189
190 func (db *PGStore) save(session *sessions.Session) error {
191 status := "Success"
192 defer func() {
193 db.audit(SaveSessionOp, status, session)
194 }()
195 encoded, err := securecookie.EncodeMulti(session.Name(), session.Values, db.Codecs...)
196 if err != nil {
197 status = Failure
198 return err
199 }
200
201 crOn := session.Values["created_on"]
202 exOn := session.Values["expires_on"]
203
204 var expiresOn time.Time
205
206 createdOn, ok := crOn.(time.Time)
207 if !ok {
208 createdOn = time.Now()
209 }
210
211 if exOn == nil {
212 expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge))
213 } else {
214 expiresOn = exOn.(time.Time)
215 if expiresOn.Sub(time.Now().Add(time.Second*time.Duration(session.Options.MaxAge))) < 0 {
216 expiresOn = time.Now().Add(time.Second * time.Duration(session.Options.MaxAge))
217 }
218 }
219
220 s := PGSession{
221 Key: session.ID,
222 Data: encoded,
223 CreatedOn: createdOn,
224 ExpiresOn: expiresOn,
225 ModifiedOn: time.Now(),
226 }
227
228 if session.IsNew {
229 return db.insert(&s)
230 }
231
232 return db.update(&s)
233 }
234
235
236 func (db *PGStore) destroy(session *sessions.Session) error {
237
238
239
240
241
242
243
244
245
246
247 status := Success
248 defer func() {
249 db.audit(DestroySessionOp, status, session)
250 }()
251 _, err := db.DbPool.Exec("DELETE FROM http_sessions WHERE key = $1", session.ID)
252 if err != nil {
253 if errors.Is(err, sql.ErrNoRows) {
254 return nil
255 }
256
257 db.log.Error(err, "an error occurred deleting session", enrichLogWithSession(DestroySessionOp, session)...)
258 return err
259 }
260 return nil
261 }
262
263 func (db *PGStore) createSessionsTable() error {
264 stmt := `DO $$
265 BEGIN
266 CREATE TABLE IF NOT EXISTS http_sessions (
267 id BIGSERIAL PRIMARY KEY,
268 key BYTEA,
269 data BYTEA,
270 created_on TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP,
271 modified_on TIMESTAMPTZ,
272 expires_on TIMESTAMPTZ);
273 CREATE INDEX IF NOT EXISTS http_sessions_expiry_idx ON http_sessions (expires_on);
274 CREATE INDEX IF NOT EXISTS http_sessions_key_idx ON http_sessions (key);
275 EXCEPTION WHEN insufficient_privilege THEN
276 IF NOT EXISTS (SELECT FROM pg_catalog.pg_tables WHERE schemaname = current_schema() AND tablename = 'http_sessions') THEN
277 RAISE;
278 END IF;
279 WHEN others THEN RAISE;
280 END;
281 $$;`
282
283 _, err := db.DbPool.Exec(stmt)
284 if err != nil {
285 return fmt.Errorf("unable to create http_sessions table in the database, err: %w", err)
286 }
287
288 return nil
289 }
290
291 func (db *PGStore) selectOne(s *PGSession, key string) error {
292 stmt := "SELECT id, key, data, created_on, modified_on, expires_on FROM http_sessions WHERE key = $1"
293 err := db.DbPool.QueryRow(stmt, key).Scan(&s.ID, &s.Key, &s.Data, &s.CreatedOn, &s.ModifiedOn, &s.ExpiresOn)
294 if err != nil {
295 if !errors.Is(err, sql.ErrNoRows) {
296 return fmt.Errorf("unable to find session in the database, err: %w", err)
297 }
298 }
299 return nil
300 }
301
302 func (db *PGStore) insert(s *PGSession) error {
303 stmt := `INSERT INTO http_sessions (key, data, created_on, modified_on, expires_on)
304 VALUES ($1, $2, $3, $4, $5)`
305 _, err := db.DbPool.Exec(stmt, s.Key, s.Data, s.CreatedOn, s.ModifiedOn, s.ExpiresOn)
306
307 return err
308 }
309
310 func (db *PGStore) update(s *PGSession) error {
311 stmt := `UPDATE http_sessions SET data=$1, modified_on=$2, expires_on=$3 WHERE key=$4`
312 _, err := db.DbPool.Exec(stmt, s.Data, s.ModifiedOn, s.ExpiresOn, s.Key)
313
314 return err
315 }
316
317
318
319
320
321
322
View as plain text